aboutsummaryrefslogtreecommitdiff
path: root/source/opt
diff options
context:
space:
mode:
Diffstat (limited to 'source/opt')
-rw-r--r--source/opt/CMakeLists.txt10
-rw-r--r--source/opt/aggressive_dead_code_elim_pass.cpp28
-rw-r--r--source/opt/amd_ext_to_khr.cpp24
-rw-r--r--source/opt/amd_ext_to_khr.h2
-rw-r--r--source/opt/basic_block.h2
-rw-r--r--source/opt/ccp_pass.cpp44
-rw-r--r--source/opt/ccp_pass.h16
-rw-r--r--source/opt/cfg.h2
-rw-r--r--source/opt/compact_ids_pass.cpp3
-rw-r--r--source/opt/const_folding_rules.cpp128
-rw-r--r--source/opt/constants.cpp20
-rw-r--r--source/opt/constants.h13
-rw-r--r--source/opt/convert_to_half_pass.cpp2
-rw-r--r--source/opt/copy_prop_arrays.cpp8
-rw-r--r--source/opt/copy_prop_arrays.h2
-rw-r--r--source/opt/dead_branch_elim_pass.cpp2
-rw-r--r--source/opt/dead_branch_elim_pass.h4
-rw-r--r--source/opt/debug_info_manager.cpp2
-rw-r--r--source/opt/def_use_manager.cpp221
-rw-r--r--source/opt/def_use_manager.h102
-rw-r--r--source/opt/desc_sroa.cpp11
-rw-r--r--source/opt/desc_sroa.h11
-rw-r--r--source/opt/dominator_tree.cpp4
-rw-r--r--source/opt/dominator_tree.h2
-rw-r--r--source/opt/eliminate_dead_input_components_pass.cpp146
-rw-r--r--source/opt/eliminate_dead_input_components_pass.h59
-rw-r--r--source/opt/eliminate_dead_members_pass.cpp4
-rw-r--r--source/opt/feature_manager.cpp3
-rw-r--r--source/opt/fold.cpp2
-rw-r--r--source/opt/fold_spec_constant_op_and_composite_pass.cpp2
-rw-r--r--source/opt/folding_rules.cpp4
-rw-r--r--source/opt/graphics_robust_access_pass.cpp16
-rw-r--r--source/opt/graphics_robust_access_pass.h2
-rw-r--r--source/opt/inst_bindless_check_pass.cpp7
-rw-r--r--source/opt/inst_bindless_check_pass.h4
-rw-r--r--source/opt/inst_debug_printf_pass.cpp12
-rw-r--r--source/opt/instruction.cpp10
-rw-r--r--source/opt/instruction.h28
-rw-r--r--source/opt/instrument_pass.h6
-rw-r--r--source/opt/interp_fixup_pass.cpp7
-rw-r--r--source/opt/ir_context.cpp17
-rw-r--r--source/opt/ir_context.h34
-rw-r--r--source/opt/ir_loader.cpp3
-rw-r--r--source/opt/local_access_chain_convert_pass.cpp13
-rw-r--r--source/opt/local_access_chain_convert_pass.h2
-rw-r--r--source/opt/local_single_block_elim_pass.cpp13
-rw-r--r--source/opt/local_single_store_elim_pass.cpp13
-rw-r--r--source/opt/loop_descriptor.cpp6
-rw-r--r--source/opt/loop_descriptor.h2
-rw-r--r--source/opt/loop_fission.cpp4
-rw-r--r--source/opt/loop_fission.h2
-rw-r--r--source/opt/loop_fusion.cpp4
-rw-r--r--source/opt/loop_fusion.h2
-rw-r--r--source/opt/loop_fusion_pass.h2
-rw-r--r--source/opt/loop_peeling.h2
-rw-r--r--source/opt/loop_unroller.cpp10
-rw-r--r--source/opt/loop_unswitch_pass.cpp2
-rw-r--r--source/opt/loop_unswitch_pass.h2
-rw-r--r--source/opt/loop_utils.h4
-rw-r--r--source/opt/merge_return_pass.cpp1
-rw-r--r--source/opt/merge_return_pass.h4
-rw-r--r--source/opt/module.cpp6
-rw-r--r--source/opt/module.h6
-rw-r--r--source/opt/optimizer.cpp33
-rw-r--r--source/opt/pass.h9
-rw-r--r--source/opt/pass_manager.cpp14
-rw-r--r--source/opt/pass_manager.h4
-rw-r--r--source/opt/passes.h5
-rw-r--r--source/opt/private_to_local_pass.cpp2
-rw-r--r--source/opt/private_to_local_pass.h2
-rw-r--r--source/opt/redundancy_elimination.h2
-rw-r--r--source/opt/register_pressure.cpp2
-rw-r--r--source/opt/remove_dontinline_pass.cpp49
-rw-r--r--source/opt/remove_dontinline_pass.h42
-rw-r--r--source/opt/remove_duplicates_pass.cpp5
-rw-r--r--source/opt/replace_desc_array_access_using_var_index.cpp6
-rw-r--r--source/opt/replace_desc_array_access_using_var_index.h4
-rw-r--r--source/opt/replace_invalid_opc.cpp5
-rw-r--r--source/opt/scalar_analysis.cpp2
-rw-r--r--source/opt/scalar_analysis_nodes.h2
-rw-r--r--source/opt/scalar_analysis_simplification.cpp4
-rw-r--r--source/opt/scalar_replacement_pass.cpp25
-rw-r--r--source/opt/scalar_replacement_pass.h10
-rw-r--r--source/opt/spread_volatile_semantics.cpp318
-rw-r--r--source/opt/spread_volatile_semantics.h117
-rw-r--r--source/opt/strength_reduction_pass.h2
-rw-r--r--source/opt/strip_debug_info_pass.cpp13
-rw-r--r--source/opt/strip_nonsemantic_info_pass.cpp (renamed from source/opt/strip_reflect_info_pass.cpp)47
-rw-r--r--source/opt/strip_nonsemantic_info_pass.h (renamed from source/opt/strip_reflect_info_pass.h)10
-rw-r--r--source/opt/type_manager.cpp11
-rw-r--r--source/opt/type_manager.h7
-rw-r--r--source/opt/types.cpp165
-rw-r--r--source/opt/types.h109
-rw-r--r--source/opt/unify_const_pass.cpp2
-rw-r--r--source/opt/upgrade_memory_model.cpp8
-rw-r--r--source/opt/vector_dce.h2
96 files changed, 1535 insertions, 632 deletions
diff --git a/source/opt/CMakeLists.txt b/source/opt/CMakeLists.txt
index 7d522fb5..61e7a981 100644
--- a/source/opt/CMakeLists.txt
+++ b/source/opt/CMakeLists.txt
@@ -45,6 +45,7 @@ set(SPIRV_TOOLS_OPT_SOURCES
eliminate_dead_constant_pass.h
eliminate_dead_functions_pass.h
eliminate_dead_functions_util.h
+ eliminate_dead_input_components_pass.h
eliminate_dead_members_pass.h
empty_pass.h
feature_manager.h
@@ -99,6 +100,7 @@ set(SPIRV_TOOLS_OPT_SOURCES
reflect.h
register_pressure.h
relax_float_ops_pass.h
+ remove_dontinline_pass.h
remove_duplicates_pass.h
remove_unused_interface_variables_pass.h
replace_desc_array_access_using_var_index.h
@@ -108,10 +110,11 @@ set(SPIRV_TOOLS_OPT_SOURCES
scalar_replacement_pass.h
set_spec_constant_default_value_pass.h
simplification_pass.h
+ spread_volatile_semantics.h
ssa_rewrite_pass.h
strength_reduction_pass.h
strip_debug_info_pass.h
- strip_reflect_info_pass.h
+ strip_nonsemantic_info_pass.h
struct_cfg_analysis.h
tree_iterator.h
type_manager.h
@@ -156,6 +159,7 @@ set(SPIRV_TOOLS_OPT_SOURCES
eliminate_dead_constant_pass.cpp
eliminate_dead_functions_pass.cpp
eliminate_dead_functions_util.cpp
+ eliminate_dead_input_components_pass.cpp
eliminate_dead_members_pass.cpp
feature_manager.cpp
fix_storage_class.cpp
@@ -206,6 +210,7 @@ set(SPIRV_TOOLS_OPT_SOURCES
redundancy_elimination.cpp
register_pressure.cpp
relax_float_ops_pass.cpp
+ remove_dontinline_pass.cpp
remove_duplicates_pass.cpp
remove_unused_interface_variables_pass.cpp
replace_desc_array_access_using_var_index.cpp
@@ -215,10 +220,11 @@ set(SPIRV_TOOLS_OPT_SOURCES
scalar_replacement_pass.cpp
set_spec_constant_default_value_pass.cpp
simplification_pass.cpp
+ spread_volatile_semantics.cpp
ssa_rewrite_pass.cpp
strength_reduction_pass.cpp
strip_debug_info_pass.cpp
- strip_reflect_info_pass.cpp
+ strip_nonsemantic_info_pass.cpp
struct_cfg_analysis.cpp
type_manager.cpp
types.cpp
diff --git a/source/opt/aggressive_dead_code_elim_pass.cpp b/source/opt/aggressive_dead_code_elim_pass.cpp
index 0b54d5e8..04737521 100644
--- a/source/opt/aggressive_dead_code_elim_pass.cpp
+++ b/source/opt/aggressive_dead_code_elim_pass.cpp
@@ -27,6 +27,7 @@
#include "source/opt/iterator.h"
#include "source/opt/reflect.h"
#include "source/spirv_constant.h"
+#include "source/util/string_utils.h"
namespace spvtools {
namespace opt {
@@ -146,8 +147,7 @@ void AggressiveDCEPass::AddStores(Function* func, uint32_t ptrId) {
bool AggressiveDCEPass::AllExtensionsSupported() const {
// If any extension not in allowlist, return false
for (auto& ei : get_module()->extensions()) {
- const char* extName =
- reinterpret_cast<const char*>(&ei.GetInOperand(0).words[0]);
+ const std::string extName = ei.GetInOperand(0).AsString();
if (extensions_allowlist_.find(extName) == extensions_allowlist_.end())
return false;
}
@@ -156,11 +156,9 @@ bool AggressiveDCEPass::AllExtensionsSupported() const {
for (auto& inst : context()->module()->ext_inst_imports()) {
assert(inst.opcode() == SpvOpExtInstImport &&
"Expecting an import of an extension's instruction set.");
- const char* extension_name =
- reinterpret_cast<const char*>(&inst.GetInOperand(0).words[0]);
- if (0 == std::strncmp(extension_name, "NonSemantic.", 12) &&
- 0 != std::strncmp(extension_name, "NonSemantic.Shader.DebugInfo.100",
- 32)) {
+ const std::string extension_name = inst.GetInOperand(0).AsString();
+ if (spvtools::utils::starts_with(extension_name, "NonSemantic.") &&
+ extension_name != "NonSemantic.Shader.DebugInfo.100") {
return false;
}
}
@@ -569,12 +567,7 @@ void AggressiveDCEPass::InitializeModuleScopeLiveInstructions() {
}
// Keep all entry points.
for (auto& entry : get_module()->entry_points()) {
- if (get_module()->version() >= SPV_SPIRV_VERSION_WORD(1, 4) &&
- !preserve_interface_) {
- // In SPIR-V 1.4 and later, entry points must list all global variables
- // used. DCE can still remove non-input/output variables and update the
- // interface list. Mark the entry point as live and inputs and outputs as
- // live, but defer decisions all other interfaces.
+ if (!preserve_interface_) {
live_insts_.Set(entry.unique_id());
// The actual function is live always.
AddToWorklist(
@@ -582,8 +575,9 @@ void AggressiveDCEPass::InitializeModuleScopeLiveInstructions() {
for (uint32_t i = 3; i < entry.NumInOperands(); ++i) {
auto* var = get_def_use_mgr()->GetDef(entry.GetSingleWordInOperand(i));
auto storage_class = var->GetSingleWordInOperand(0u);
- if (storage_class == SpvStorageClassInput ||
- storage_class == SpvStorageClassOutput) {
+ // Vulkan support outputs without an associated input, but not inputs
+ // without an associated output.
+ if (storage_class == SpvStorageClassOutput) {
AddToWorklist(var);
}
}
@@ -885,8 +879,7 @@ bool AggressiveDCEPass::ProcessGlobalValues() {
}
}
- if (get_module()->version() >= SPV_SPIRV_VERSION_WORD(1, 4) &&
- !preserve_interface_) {
+ if (!preserve_interface_) {
// Remove the dead interface variables from the entry point interface list.
for (auto& entry : get_module()->entry_points()) {
std::vector<Operand> new_operands;
@@ -974,6 +967,7 @@ void AggressiveDCEPass::InitExtensions() {
"SPV_KHR_integer_dot_product",
"SPV_EXT_shader_image_int64",
"SPV_KHR_non_semantic_info",
+ "SPV_KHR_uniform_group_instructions",
});
}
diff --git a/source/opt/amd_ext_to_khr.cpp b/source/opt/amd_ext_to_khr.cpp
index ccedc0bc..dd9bafda 100644
--- a/source/opt/amd_ext_to_khr.cpp
+++ b/source/opt/amd_ext_to_khr.cpp
@@ -584,9 +584,9 @@ bool ReplaceCubeFaceCoord(IRContext* ctx, Instruction* inst,
}
// Get the constants that will be used.
- uint32_t f0_const_id = const_mgr->GetFloatConst(0.0);
- uint32_t f2_const_id = const_mgr->GetFloatConst(2.0);
- uint32_t f0_5_const_id = const_mgr->GetFloatConst(0.5);
+ uint32_t f0_const_id = const_mgr->GetFloatConstId(0.0);
+ uint32_t f2_const_id = const_mgr->GetFloatConstId(2.0);
+ uint32_t f0_5_const_id = const_mgr->GetFloatConstId(0.5);
const analysis::Constant* vec_const =
const_mgr->GetConstant(v2_float_type, {f0_5_const_id, f0_5_const_id});
uint32_t vec_const_id =
@@ -731,12 +731,12 @@ bool ReplaceCubeFaceIndex(IRContext* ctx, Instruction* inst,
}
// Get the constants that will be used.
- uint32_t f0_const_id = const_mgr->GetFloatConst(0.0);
- uint32_t f1_const_id = const_mgr->GetFloatConst(1.0);
- uint32_t f2_const_id = const_mgr->GetFloatConst(2.0);
- uint32_t f3_const_id = const_mgr->GetFloatConst(3.0);
- uint32_t f4_const_id = const_mgr->GetFloatConst(4.0);
- uint32_t f5_const_id = const_mgr->GetFloatConst(5.0);
+ uint32_t f0_const_id = const_mgr->GetFloatConstId(0.0);
+ uint32_t f1_const_id = const_mgr->GetFloatConstId(1.0);
+ uint32_t f2_const_id = const_mgr->GetFloatConstId(2.0);
+ uint32_t f3_const_id = const_mgr->GetFloatConstId(3.0);
+ uint32_t f4_const_id = const_mgr->GetFloatConstId(4.0);
+ uint32_t f5_const_id = const_mgr->GetFloatConstId(5.0);
// Extract the input values.
Instruction* x = ir_builder.AddCompositeExtract(float_type_id, input_id, {0});
@@ -935,8 +935,7 @@ Pass::Status AmdExtensionToKhrPass::Process() {
std::vector<Instruction*> to_be_killed;
for (Instruction& inst : context()->module()->extensions()) {
if (inst.opcode() == SpvOpExtension) {
- if (ext_to_remove.count(reinterpret_cast<const char*>(
- &(inst.GetInOperand(0).words[0]))) != 0) {
+ if (ext_to_remove.count(inst.GetInOperand(0).AsString()) != 0) {
to_be_killed.push_back(&inst);
}
}
@@ -944,8 +943,7 @@ Pass::Status AmdExtensionToKhrPass::Process() {
for (Instruction& inst : context()->ext_inst_imports()) {
if (inst.opcode() == SpvOpExtInstImport) {
- if (ext_to_remove.count(reinterpret_cast<const char*>(
- &(inst.GetInOperand(0).words[0]))) != 0) {
+ if (ext_to_remove.count(inst.GetInOperand(0).AsString()) != 0) {
to_be_killed.push_back(&inst);
}
}
diff --git a/source/opt/amd_ext_to_khr.h b/source/opt/amd_ext_to_khr.h
index fd3dab4e..6a39d953 100644
--- a/source/opt/amd_ext_to_khr.h
+++ b/source/opt/amd_ext_to_khr.h
@@ -23,7 +23,7 @@ namespace spvtools {
namespace opt {
// Replaces the extensions VK_AMD_shader_ballot, VK_AMD_gcn_shader, and
-// VK_AMD_shader_trinary_minmax with equivalant code using core instructions and
+// VK_AMD_shader_trinary_minmax with equivalent code using core instructions and
// capabilities.
class AmdExtensionToKhrPass : public Pass {
public:
diff --git a/source/opt/basic_block.h b/source/opt/basic_block.h
index 6741a50f..dd3b2e28 100644
--- a/source/opt/basic_block.h
+++ b/source/opt/basic_block.h
@@ -83,7 +83,7 @@ class BasicBlock {
const Instruction* GetMergeInst() const;
Instruction* GetMergeInst();
- // Returns the OpLoopMerge instruciton in this basic block, if it exists.
+ // Returns the OpLoopMerge instruction in this basic block, if it exists.
// Otherwise return null. May be used whenever tail() can be used.
const Instruction* GetLoopMergeInst() const;
Instruction* GetLoopMergeInst();
diff --git a/source/opt/ccp_pass.cpp b/source/opt/ccp_pass.cpp
index 8b896d50..5f855027 100644
--- a/source/opt/ccp_pass.cpp
+++ b/source/opt/ccp_pass.cpp
@@ -102,6 +102,34 @@ SSAPropagator::PropStatus CCPPass::VisitPhi(Instruction* phi) {
return SSAPropagator::kInteresting;
}
+uint32_t CCPPass::ComputeLatticeMeet(Instruction* instr, uint32_t val2) {
+ // Given two values val1 and val2, the meet operation in the constant
+ // lattice uses the following rules:
+ //
+ // meet(val1, UNDEFINED) = val1
+ // meet(val1, VARYING) = VARYING
+ // meet(val1, val2) = val1 if val1 == val2
+ // meet(val1, val2) = VARYING if val1 != val2
+ //
+ // When two different values meet, the result is always varying because CCP
+ // does not allow lateral transitions in the lattice. This prevents
+ // infinite cycles during propagation.
+ auto val1_it = values_.find(instr->result_id());
+ if (val1_it == values_.end()) {
+ return val2;
+ }
+
+ uint32_t val1 = val1_it->second;
+ if (IsVaryingValue(val1)) {
+ return val1;
+ } else if (IsVaryingValue(val2)) {
+ return val2;
+ } else if (val1 != val2) {
+ return kVaryingSSAId;
+ }
+ return val2;
+}
+
SSAPropagator::PropStatus CCPPass::VisitAssignment(Instruction* instr) {
assert(instr->result_id() != 0 &&
"Expecting an instruction that produces a result");
@@ -115,8 +143,10 @@ SSAPropagator::PropStatus CCPPass::VisitAssignment(Instruction* instr) {
if (IsVaryingValue(it->second)) {
return MarkInstructionVarying(instr);
} else {
- values_[instr->result_id()] = it->second;
- return SSAPropagator::kInteresting;
+ uint32_t new_val = ComputeLatticeMeet(instr, it->second);
+ values_[instr->result_id()] = new_val;
+ return IsVaryingValue(new_val) ? SSAPropagator::kVarying
+ : SSAPropagator::kInteresting;
}
}
return SSAPropagator::kNotInteresting;
@@ -142,9 +172,13 @@ SSAPropagator::PropStatus CCPPass::VisitAssignment(Instruction* instr) {
if (folded_inst != nullptr) {
// We do not want to change the body of the function by adding new
// instructions. When folding we can only generate new constants.
- assert(folded_inst->IsConstant() && "CCP is only interested in constant.");
- values_[instr->result_id()] = folded_inst->result_id();
- return SSAPropagator::kInteresting;
+ assert((folded_inst->IsConstant() ||
+ IsSpecConstantInst(folded_inst->opcode())) &&
+ "CCP is only interested in constant values.");
+ uint32_t new_val = ComputeLatticeMeet(instr, folded_inst->result_id());
+ values_[instr->result_id()] = new_val;
+ return IsVaryingValue(new_val) ? SSAPropagator::kVarying
+ : SSAPropagator::kInteresting;
}
// Conservatively mark this instruction as varying if any input id is varying.
diff --git a/source/opt/ccp_pass.h b/source/opt/ccp_pass.h
index fb20c780..77ea9f80 100644
--- a/source/opt/ccp_pass.h
+++ b/source/opt/ccp_pass.h
@@ -92,6 +92,22 @@ class CCPPass : public MemPass {
// generated during propagation.
analysis::ConstantManager* const_mgr_;
+ // Returns a new value for |instr| by computing the meet operation between
+ // its existing value and |val2|.
+ //
+ // Given two values val1 and val2, the meet operation in the constant
+ // lattice uses the following rules:
+ //
+ // meet(val1, UNDEFINED) = val1
+ // meet(val1, VARYING) = VARYING
+ // meet(val1, val2) = val1 if val1 == val2
+ // meet(val1, val2) = VARYING if val1 != val2
+ //
+ // When two different values meet, the result is always varying because CCP
+ // does not allow lateral transitions in the lattice. This prevents
+ // infinite cycles during propagation.
+ uint32_t ComputeLatticeMeet(Instruction* instr, uint32_t val2);
+
// Constant value table. Each entry <id, const_decl_id> in this map
// represents the compile-time constant value for |id| as declared by
// |const_decl_id|. Each |const_decl_id| in this table is an OpConstant
diff --git a/source/opt/cfg.h b/source/opt/cfg.h
index f2806822..33412f18 100644
--- a/source/opt/cfg.h
+++ b/source/opt/cfg.h
@@ -30,7 +30,7 @@ class CFG {
public:
explicit CFG(Module* module);
- // Return the list of predecesors for basic block with label |blkid|.
+ // Return the list of predecessors for basic block with label |blkid|.
// TODO(dnovillo): Move this to BasicBlock.
const std::vector<uint32_t>& preds(uint32_t blk_id) const {
assert(label2preds_.count(blk_id));
diff --git a/source/opt/compact_ids_pass.cpp b/source/opt/compact_ids_pass.cpp
index 8815b8c6..70848d79 100644
--- a/source/opt/compact_ids_pass.cpp
+++ b/source/opt/compact_ids_pass.cpp
@@ -86,7 +86,8 @@ Pass::Status CompactIdsPass::Process() {
},
true);
- if (modified) {
+ if (context()->module()->id_bound() != result_id_mapping.size() + 1) {
+ modified = true;
context()->module()->SetIdBound(
static_cast<uint32_t>(result_id_mapping.size() + 1));
// There are ids in the feature manager that could now be invalid
diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp
index 515a3ed5..249e11e5 100644
--- a/source/opt/const_folding_rules.cpp
+++ b/source/opt/const_folding_rules.cpp
@@ -22,6 +22,45 @@ namespace {
const uint32_t kExtractCompositeIdInIdx = 0;
+// Returns a constants with the value NaN of the given type. Only works for
+// 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs.
+const analysis::Constant* GetNan(const analysis::Type* type,
+ analysis::ConstantManager* const_mgr) {
+ const analysis::Float* float_type = type->AsFloat();
+ if (float_type == nullptr) {
+ return nullptr;
+ }
+
+ switch (float_type->width()) {
+ case 32:
+ return const_mgr->GetFloatConst(std::numeric_limits<float>::quiet_NaN());
+ case 64:
+ return const_mgr->GetDoubleConst(
+ std::numeric_limits<double>::quiet_NaN());
+ default:
+ return nullptr;
+ }
+}
+
+// Returns a constants with the value INF of the given type. Only works for
+// 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs.
+const analysis::Constant* GetInf(const analysis::Type* type,
+ analysis::ConstantManager* const_mgr) {
+ const analysis::Float* float_type = type->AsFloat();
+ if (float_type == nullptr) {
+ return nullptr;
+ }
+
+ switch (float_type->width()) {
+ case 32:
+ return const_mgr->GetFloatConst(std::numeric_limits<float>::infinity());
+ case 64:
+ return const_mgr->GetDoubleConst(std::numeric_limits<double>::infinity());
+ default:
+ return nullptr;
+ }
+}
+
// Returns true if |type| is Float or a vector of Float.
bool HasFloatingPoint(const analysis::Type* type) {
if (type->AsFloat()) {
@@ -33,6 +72,23 @@ bool HasFloatingPoint(const analysis::Type* type) {
return false;
}
+// Returns a constants with the value |-val| of the given type. Only works for
+// 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs.
+const analysis::Constant* negateFPConst(const analysis::Type* result_type,
+ const analysis::Constant* val,
+ analysis::ConstantManager* const_mgr) {
+ const analysis::Float* float_type = result_type->AsFloat();
+ assert(float_type != nullptr);
+ if (float_type->width() == 32) {
+ float fa = val->GetFloat();
+ return const_mgr->GetFloatConst(-fa);
+ } else if (float_type->width() == 64) {
+ double da = val->GetDouble();
+ return const_mgr->GetDoubleConst(-da);
+ }
+ return nullptr;
+}
+
// Folds an OpcompositeExtract where input is a composite constant.
ConstantFoldingRule FoldExtractWithConstants() {
return [](IRContext* context, Instruction* inst,
@@ -492,7 +548,60 @@ ConstantFoldingRule FoldQuantizeToF16() {
ConstantFoldingRule FoldFSub() { return FoldFPBinaryOp(FOLD_FPARITH_OP(-)); }
ConstantFoldingRule FoldFAdd() { return FoldFPBinaryOp(FOLD_FPARITH_OP(+)); }
ConstantFoldingRule FoldFMul() { return FoldFPBinaryOp(FOLD_FPARITH_OP(*)); }
-ConstantFoldingRule FoldFDiv() { return FoldFPBinaryOp(FOLD_FPARITH_OP(/)); }
+
+// Returns the constant that results from evaluating |numerator| / 0.0. Returns
+// |nullptr| if the result could not be evaluated.
+const analysis::Constant* FoldFPScalarDivideByZero(
+ const analysis::Type* result_type, const analysis::Constant* numerator,
+ analysis::ConstantManager* const_mgr) {
+ if (numerator == nullptr) {
+ return nullptr;
+ }
+
+ if (numerator->IsZero()) {
+ return GetNan(result_type, const_mgr);
+ }
+
+ const analysis::Constant* result = GetInf(result_type, const_mgr);
+ if (result == nullptr) {
+ return nullptr;
+ }
+
+ if (numerator->AsFloatConstant()->GetValueAsDouble() < 0.0) {
+ result = negateFPConst(result_type, result, const_mgr);
+ }
+ return result;
+}
+
+// Returns the result of folding |numerator| / |denominator|. Returns |nullptr|
+// if it cannot be folded.
+const analysis::Constant* FoldScalarFPDivide(
+ const analysis::Type* result_type, const analysis::Constant* numerator,
+ const analysis::Constant* denominator,
+ analysis::ConstantManager* const_mgr) {
+ if (denominator == nullptr) {
+ return nullptr;
+ }
+
+ if (denominator->IsZero()) {
+ return FoldFPScalarDivideByZero(result_type, numerator, const_mgr);
+ }
+
+ const analysis::FloatConstant* denominator_float =
+ denominator->AsFloatConstant();
+ if (denominator_float && denominator->GetValueAsDouble() == -0.0) {
+ const analysis::Constant* result =
+ FoldFPScalarDivideByZero(result_type, numerator, const_mgr);
+ if (result != nullptr)
+ result = negateFPConst(result_type, result, const_mgr);
+ return result;
+ } else {
+ return FOLD_FPARITH_OP(/)(result_type, numerator, denominator, const_mgr);
+ }
+}
+
+// Returns the constant folding rule to fold |OpFDiv| with two constants.
+ConstantFoldingRule FoldFDiv() { return FoldFPBinaryOp(FoldScalarFPDivide); }
bool CompareFloatingPoint(bool op_result, bool op_unordered,
bool need_ordered) {
@@ -655,20 +764,7 @@ UnaryScalarFoldingRule FoldFNegateOp() {
analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
assert(result_type != nullptr && a != nullptr);
assert(result_type == a->type());
- const analysis::Float* float_type = result_type->AsFloat();
- assert(float_type != nullptr);
- if (float_type->width() == 32) {
- float fa = a->GetFloat();
- utils::FloatProxy<float> result(-fa);
- std::vector<uint32_t> words = result.GetWords();
- return const_mgr->GetConstant(result_type, words);
- } else if (float_type->width() == 64) {
- double da = a->GetDouble();
- utils::FloatProxy<double> result(-da);
- std::vector<uint32_t> words = result.GetWords();
- return const_mgr->GetConstant(result_type, words);
- }
- return nullptr;
+ return negateFPConst(result_type, a, const_mgr);
};
}
@@ -1250,7 +1346,7 @@ void ConstantFoldingRules::AddFoldingRules() {
FoldFPUnaryOp(FoldFTranscendentalUnary(std::log)));
#ifdef __ANDROID__
- // Android NDK r15c tageting ABI 15 doesn't have full support for C++11
+ // Android NDK r15c targeting ABI 15 doesn't have full support for C++11
// (no std::exp2/log2). ::exp2 is available from C99 but ::log2 isn't
// available up until ABI 18 so we use a shim
auto log2_shim = [](double v) -> double { return log(v) / log(2.0); };
diff --git a/source/opt/constants.cpp b/source/opt/constants.cpp
index 020e248b..bcff08c1 100644
--- a/source/opt/constants.cpp
+++ b/source/opt/constants.cpp
@@ -158,6 +158,7 @@ Type* ConstantManager::GetType(const Instruction* inst) const {
std::vector<const Constant*> ConstantManager::GetOperandConstants(
const Instruction* inst) const {
std::vector<const Constant*> constants;
+ constants.reserve(inst->NumInOperands());
for (uint32_t i = 0; i < inst->NumInOperands(); i++) {
const Operand* operand = &inst->GetInOperand(i);
if (operand->type != SPV_OPERAND_TYPE_ID) {
@@ -420,13 +421,30 @@ const Constant* ConstantManager::GetNumericVectorConstantWithWords(
return GetConstant(type, element_ids);
}
-uint32_t ConstantManager::GetFloatConst(float val) {
+uint32_t ConstantManager::GetFloatConstId(float val) {
+ const Constant* c = GetFloatConst(val);
+ return GetDefiningInstruction(c)->result_id();
+}
+
+const Constant* ConstantManager::GetFloatConst(float val) {
Type* float_type = context()->get_type_mgr()->GetFloatType();
utils::FloatProxy<float> v(val);
const Constant* c = GetConstant(float_type, v.GetWords());
+ return c;
+}
+
+uint32_t ConstantManager::GetDoubleConstId(double val) {
+ const Constant* c = GetDoubleConst(val);
return GetDefiningInstruction(c)->result_id();
}
+const Constant* ConstantManager::GetDoubleConst(double val) {
+ Type* float_type = context()->get_type_mgr()->GetDoubleType();
+ utils::FloatProxy<double> v(val);
+ const Constant* c = GetConstant(float_type, v.GetWords());
+ return c;
+}
+
uint32_t ConstantManager::GetSIntConst(int32_t val) {
Type* sint_type = context()->get_type_mgr()->GetSIntType();
const Constant* c = GetConstant(sint_type, {static_cast<uint32_t>(val)});
diff --git a/source/opt/constants.h b/source/opt/constants.h
index 52bd809a..c039ae08 100644
--- a/source/opt/constants.h
+++ b/source/opt/constants.h
@@ -541,7 +541,7 @@ class ConstantManager {
// instruction at the end of the current module's types section.
//
// |type_id| is an optional argument for disambiguating equivalent types. If
- // |type_id| is specified, the contant returned will have that type id.
+ // |type_id| is specified, the constant returned will have that type id.
Instruction* GetDefiningInstruction(const Constant* c, uint32_t type_id = 0,
Module::inst_iterator* pos = nullptr);
@@ -637,7 +637,16 @@ class ConstantManager {
}
// Returns the id of a 32-bit floating point constant with value |val|.
- uint32_t GetFloatConst(float val);
+ uint32_t GetFloatConstId(float val);
+
+ // Returns a 32-bit float constant with the given value.
+ const Constant* GetFloatConst(float val);
+
+ // Returns the id of a 64-bit floating point constant with value |val|.
+ uint32_t GetDoubleConstId(double val);
+
+ // Returns a 64-bit float constant with the given value.
+ const Constant* GetDoubleConst(double val);
// Returns the id of a 32-bit signed integer constant with value |val|.
uint32_t GetSIntConst(int32_t val);
diff --git a/source/opt/convert_to_half_pass.cpp b/source/opt/convert_to_half_pass.cpp
index b127eabe..4086e31a 100644
--- a/source/opt/convert_to_half_pass.cpp
+++ b/source/opt/convert_to_half_pass.cpp
@@ -181,7 +181,7 @@ bool ConvertToHalfPass::ProcessPhi(Instruction* inst, uint32_t from_width,
uint32_t to_width) {
// Add converts of any float operands to to_width if they are of from_width.
// If converting to 16, change type of phi to float16 equivalent and remember
- // result id. Converts need to be added to preceeding blocks.
+ // result id. Converts need to be added to preceding blocks.
uint32_t ocnt = 0;
uint32_t* prev_idp;
bool modified = false;
diff --git a/source/opt/copy_prop_arrays.cpp b/source/opt/copy_prop_arrays.cpp
index 62ed5e77..321d4969 100644
--- a/source/opt/copy_prop_arrays.cpp
+++ b/source/opt/copy_prop_arrays.cpp
@@ -745,11 +745,11 @@ void CopyPropagateArrays::UpdateUses(Instruction* original_ptr_inst,
context()->AnalyzeUses(use);
}
break;
+ case SpvOpDecorate:
+ // We treat an OpImageTexelPointer as a load. The result type should
+ // always have the Image storage class, and should not need to be
+ // updated.
case SpvOpImageTexelPointer:
- // We treat an OpImageTexelPointer as a load. The result type should
- // always have the Image storage class, and should not need to be
- // updated.
-
// Replace the actual use.
context()->ForgetUses(use);
use->SetOperand(index, {new_ptr_inst->result_id()});
diff --git a/source/opt/copy_prop_arrays.h b/source/opt/copy_prop_arrays.h
index f4314a74..46a508cf 100644
--- a/source/opt/copy_prop_arrays.h
+++ b/source/opt/copy_prop_arrays.h
@@ -35,7 +35,7 @@ namespace opt {
//
// The hard part is keeping all of the types correct. We do not want to
// have to do too large a search to update everything, which may not be
-// possible, do we give up if we see any instruction that might be hard to
+// possible, so we give up if we see any instruction that might be hard to
// update.
class CopyPropagateArrays : public MemPass {
diff --git a/source/opt/dead_branch_elim_pass.cpp b/source/opt/dead_branch_elim_pass.cpp
index 356dbcb3..cc616ca6 100644
--- a/source/opt/dead_branch_elim_pass.cpp
+++ b/source/opt/dead_branch_elim_pass.cpp
@@ -207,7 +207,7 @@ bool DeadBranchElimPass::SimplifyBranch(BasicBlock* block,
Instruction::OperandList new_operands;
new_operands.push_back(terminator->GetInOperand(0));
new_operands.push_back({SPV_OPERAND_TYPE_ID, {live_lab_id}});
- terminator->SetInOperands(move(new_operands));
+ terminator->SetInOperands(std::move(new_operands));
context()->UpdateDefUse(terminator);
} else {
// Check if the merge instruction is still needed because of a
diff --git a/source/opt/dead_branch_elim_pass.h b/source/opt/dead_branch_elim_pass.h
index 7841bc47..198bad2d 100644
--- a/source/opt/dead_branch_elim_pass.h
+++ b/source/opt/dead_branch_elim_pass.h
@@ -98,7 +98,7 @@ class DeadBranchElimPass : public MemPass {
// Fix phis in reachable blocks so that only live (or unremovable) incoming
// edges are present. If the block now only has a single live incoming edge,
// remove the phi and replace its uses with its data input. If the single
- // remaining incoming edge is from the phi itself, the the phi is in an
+ // remaining incoming edge is from the phi itself, the phi is in an
// unreachable single block loop. Either the block is dead and will be
// removed, or it's reachable from an unreachable continue target. In the
// latter case that continue target block will be collapsed into a block that
@@ -158,7 +158,7 @@ class DeadBranchElimPass : public MemPass {
uint32_t cont_id, uint32_t header_id, uint32_t merge_id,
std::unordered_set<BasicBlock*>* blocks_with_back_edges);
- // Returns true if there is a brach to the merge node of the selection
+ // Returns true if there is a branch to the merge node of the selection
// construct |switch_header_id| that is inside a nested selection construct or
// in the header of the nested selection construct.
bool SwitchHasNestedBreak(uint32_t switch_header_id);
diff --git a/source/opt/debug_info_manager.cpp b/source/opt/debug_info_manager.cpp
index 060e0d93..c1df6258 100644
--- a/source/opt/debug_info_manager.cpp
+++ b/source/opt/debug_info_manager.cpp
@@ -149,7 +149,7 @@ void DebugInfoManager::RegisterDbgDeclare(uint32_t var_id,
// Create new constant directly into global value area, bypassing the
// Constant manager. This is used when the DefUse or Constant managers
// are invalid and cannot be regenerated due to the module being in an
-// inconsistant state e.g. in the middle of significant modification
+// inconsistent state e.g. in the middle of significant modification
// such as inlining. Invalidate Constant and DefUse managers if used.
uint32_t AddNewConstInGlobals(IRContext* context, uint32_t const_value) {
uint32_t id = context->TakeNextId();
diff --git a/source/opt/def_use_manager.cpp b/source/opt/def_use_manager.cpp
index 394b9fa1..e1e441e0 100644
--- a/source/opt/def_use_manager.cpp
+++ b/source/opt/def_use_manager.cpp
@@ -13,16 +13,23 @@
// limitations under the License.
#include "source/opt/def_use_manager.h"
-
-#include <iostream>
-
-#include "source/opt/log.h"
-#include "source/opt/reflect.h"
+#include "source/util/make_unique.h"
namespace spvtools {
namespace opt {
namespace analysis {
+// Don't compact before we have a reasonable number of ids allocated (~32kb).
+static const size_t kCompactThresholdMinTotalIds = (8 * 1024);
+// Compact when fewer than this fraction of the storage is used (should be 2^n
+// for performance).
+static const size_t kCompactThresholdFractionFreeIds = 8;
+
+DefUseManager::DefUseManager() {
+ use_pool_ = MakeUnique<UseListPool>();
+ used_id_pool_ = MakeUnique<UsedIdListPool>();
+}
+
void DefUseManager::AnalyzeInstDef(Instruction* inst) {
const uint32_t def_id = inst->result_id();
if (def_id != 0) {
@@ -39,15 +46,15 @@ void DefUseManager::AnalyzeInstDef(Instruction* inst) {
}
void DefUseManager::AnalyzeInstUse(Instruction* inst) {
+ // It might have existed before.
+ EraseUseRecordsOfOperandIds(inst);
+
// Create entry for the given instruction. Note that the instruction may
// not have any in-operands. In such cases, we still need a entry for those
// instructions so this manager knows it has seen the instruction later.
- auto* used_ids = &inst_to_used_ids_[inst];
- if (used_ids->size()) {
- EraseUseRecordsOfOperandIds(inst);
- used_ids = &inst_to_used_ids_[inst];
- }
- used_ids->clear(); // It might have existed before.
+ UsedIdList& used_ids =
+ inst_to_used_id_.insert({inst, UsedIdList(used_id_pool_.get())})
+ .first->second;
for (uint32_t i = 0; i < inst->NumOperands(); ++i) {
switch (inst->GetOperand(i).type) {
@@ -58,9 +65,18 @@ void DefUseManager::AnalyzeInstUse(Instruction* inst) {
case SPV_OPERAND_TYPE_SCOPE_ID: {
uint32_t use_id = inst->GetSingleWordOperand(i);
Instruction* def = GetDef(use_id);
- if (!def) assert(false && "Definition is not registered.");
- id_to_users_.insert(UserEntry(def, inst));
- used_ids->push_back(use_id);
+ assert(def && "Definition is not registered.");
+
+ // Add to inst's use records
+ used_ids.push_back(use_id);
+
+ // Add to the users, taking care to avoid adding duplicates. We know
+ // the duplicate for this instruction will always be at the tail.
+ UseList& list = inst_to_users_.insert({def, UseList(use_pool_.get())})
+ .first->second;
+ if (list.empty() || list.back() != inst) {
+ list.push_back(inst);
+ }
} break;
default:
break;
@@ -99,23 +115,6 @@ const Instruction* DefUseManager::GetDef(uint32_t id) const {
return iter->second;
}
-DefUseManager::IdToUsersMap::const_iterator DefUseManager::UsersBegin(
- const Instruction* def) const {
- return id_to_users_.lower_bound(
- UserEntry(const_cast<Instruction*>(def), nullptr));
-}
-
-bool DefUseManager::UsersNotEnd(const IdToUsersMap::const_iterator& iter,
- const IdToUsersMap::const_iterator& cached_end,
- const Instruction* inst) const {
- return (iter != cached_end && iter->first == inst);
-}
-
-bool DefUseManager::UsersNotEnd(const IdToUsersMap::const_iterator& iter,
- const Instruction* inst) const {
- return UsersNotEnd(iter, id_to_users_.end(), inst);
-}
-
bool DefUseManager::WhileEachUser(
const Instruction* def, const std::function<bool(Instruction*)>& f) const {
// Ensure that |def| has been registered.
@@ -123,9 +122,11 @@ bool DefUseManager::WhileEachUser(
"Definition is not registered.");
if (!def->HasResultId()) return true;
- auto end = id_to_users_.end();
- for (auto iter = UsersBegin(def); UsersNotEnd(iter, end, def); ++iter) {
- if (!f(iter->second)) return false;
+ auto iter = inst_to_users_.find(def);
+ if (iter != inst_to_users_.end()) {
+ for (Instruction* user : iter->second) {
+ if (!f(user)) return false;
+ }
}
return true;
}
@@ -156,14 +157,15 @@ bool DefUseManager::WhileEachUse(
"Definition is not registered.");
if (!def->HasResultId()) return true;
- auto end = id_to_users_.end();
- for (auto iter = UsersBegin(def); UsersNotEnd(iter, end, def); ++iter) {
- Instruction* user = iter->second;
- for (uint32_t idx = 0; idx != user->NumOperands(); ++idx) {
- const Operand& op = user->GetOperand(idx);
- if (op.type != SPV_OPERAND_TYPE_RESULT_ID && spvIsIdType(op.type)) {
- if (def->result_id() == op.words[0]) {
- if (!f(user, idx)) return false;
+ auto iter = inst_to_users_.find(def);
+ if (iter != inst_to_users_.end()) {
+ for (Instruction* user : iter->second) {
+ for (uint32_t idx = 0; idx != user->NumOperands(); ++idx) {
+ const Operand& op = user->GetOperand(idx);
+ if (op.type != SPV_OPERAND_TYPE_RESULT_ID && spvIsIdType(op.type)) {
+ if (def->result_id() == op.words[0]) {
+ if (!f(user, idx)) return false;
+ }
}
}
}
@@ -235,17 +237,18 @@ void DefUseManager::AnalyzeDefUse(Module* module) {
}
void DefUseManager::ClearInst(Instruction* inst) {
- auto iter = inst_to_used_ids_.find(inst);
- if (iter != inst_to_used_ids_.end()) {
+ if (inst_to_used_id_.find(inst) != inst_to_used_id_.end()) {
EraseUseRecordsOfOperandIds(inst);
- if (inst->result_id() != 0) {
- // Remove all uses of this inst.
- auto users_begin = UsersBegin(inst);
- auto end = id_to_users_.end();
- auto new_end = users_begin;
- for (; UsersNotEnd(new_end, end, inst); ++new_end) {
+ uint32_t const result_id = inst->result_id();
+ if (result_id != 0) {
+ // For each using instruction, remove result_id from their used ids.
+ auto iter = inst_to_users_.find(inst);
+ if (iter != inst_to_users_.end()) {
+ for (Instruction* use : iter->second) {
+ inst_to_used_id_.at(use).remove_first(result_id);
+ }
+ inst_to_users_.erase(iter);
}
- id_to_users_.erase(users_begin, new_end);
id_to_def_.erase(inst->result_id());
}
}
@@ -254,59 +257,113 @@ void DefUseManager::ClearInst(Instruction* inst) {
void DefUseManager::EraseUseRecordsOfOperandIds(const Instruction* inst) {
// Go through all ids used by this instruction, remove this instruction's
// uses of them.
- auto iter = inst_to_used_ids_.find(inst);
- if (iter != inst_to_used_ids_.end()) {
- for (auto use_id : iter->second) {
- id_to_users_.erase(
- UserEntry(GetDef(use_id), const_cast<Instruction*>(inst)));
+ auto iter = inst_to_used_id_.find(inst);
+ if (iter != inst_to_used_id_.end()) {
+ const UsedIdList& used_ids = iter->second;
+ for (uint32_t def_id : used_ids) {
+ auto def_iter = inst_to_users_.find(GetDef(def_id));
+ if (def_iter != inst_to_users_.end()) {
+ def_iter->second.remove_first(const_cast<Instruction*>(inst));
+ }
+ }
+ inst_to_used_id_.erase(inst);
+
+ // If we're using only a fraction of the space in used_ids_, compact storage
+ // to prevent memory usage from being unbounded.
+ if (used_id_pool_->total_nodes() > kCompactThresholdMinTotalIds &&
+ used_id_pool_->used_nodes() <
+ used_id_pool_->total_nodes() / kCompactThresholdFractionFreeIds) {
+ CompactStorage();
}
- inst_to_used_ids_.erase(inst);
}
}
-bool operator==(const DefUseManager& lhs, const DefUseManager& rhs) {
+void DefUseManager::CompactStorage() {
+ CompactUseRecords();
+ CompactUsedIds();
+}
+
+void DefUseManager::CompactUseRecords() {
+ std::unique_ptr<UseListPool> new_pool = MakeUnique<UseListPool>();
+ for (auto& iter : inst_to_users_) {
+ iter.second.move_nodes(new_pool.get());
+ }
+ use_pool_ = std::move(new_pool);
+}
+
+void DefUseManager::CompactUsedIds() {
+ std::unique_ptr<UsedIdListPool> new_pool = MakeUnique<UsedIdListPool>();
+ for (auto& iter : inst_to_used_id_) {
+ iter.second.move_nodes(new_pool.get());
+ }
+ used_id_pool_ = std::move(new_pool);
+}
+
+bool CompareAndPrintDifferences(const DefUseManager& lhs,
+ const DefUseManager& rhs) {
+ bool same = true;
+
if (lhs.id_to_def_ != rhs.id_to_def_) {
for (auto p : lhs.id_to_def_) {
if (rhs.id_to_def_.find(p.first) == rhs.id_to_def_.end()) {
- return false;
+ printf("Diff in id_to_def: missing value in rhs\n");
}
}
for (auto p : rhs.id_to_def_) {
if (lhs.id_to_def_.find(p.first) == lhs.id_to_def_.end()) {
- return false;
+ printf("Diff in id_to_def: missing value in lhs\n");
}
}
- return false;
+ same = false;
}
- if (lhs.id_to_users_ != rhs.id_to_users_) {
- for (auto p : lhs.id_to_users_) {
- if (rhs.id_to_users_.count(p) == 0) {
- return false;
- }
+ for (const auto& l : lhs.inst_to_used_id_) {
+ std::set<uint32_t> ul, ur;
+ lhs.ForEachUse(l.first,
+ [&ul](Instruction*, uint32_t id) { ul.insert(id); });
+ rhs.ForEachUse(l.first,
+ [&ur](Instruction*, uint32_t id) { ur.insert(id); });
+ if (ul.size() != ur.size()) {
+ printf(
+ "Diff in inst_to_used_id_: different number of used ids (%zu != %zu)",
+ ul.size(), ur.size());
+ same = false;
+ } else if (ul != ur) {
+ printf("Diff in inst_to_used_id_: different used ids\n");
+ same = false;
}
- for (auto p : rhs.id_to_users_) {
- if (lhs.id_to_users_.count(p) == 0) {
- return false;
- }
+ }
+ for (const auto& r : rhs.inst_to_used_id_) {
+ auto iter_l = lhs.inst_to_used_id_.find(r.first);
+ if (r.second.empty() &&
+ !(iter_l == lhs.inst_to_used_id_.end() || iter_l->second.empty())) {
+ printf("Diff in inst_to_used_id_: unexpected instr in rhs\n");
+ same = false;
}
- return false;
}
- if (lhs.inst_to_used_ids_ != rhs.inst_to_used_ids_) {
- for (auto p : lhs.inst_to_used_ids_) {
- if (rhs.inst_to_used_ids_.count(p.first) == 0) {
- return false;
- }
+ for (const auto& l : lhs.inst_to_users_) {
+ std::set<Instruction*> ul, ur;
+ lhs.ForEachUser(l.first, [&ul](Instruction* use) { ul.insert(use); });
+ rhs.ForEachUser(l.first, [&ur](Instruction* use) { ur.insert(use); });
+ if (ul.size() != ur.size()) {
+ printf("Diff in inst_to_users_: different number of users (%zu != %zu)",
+ ul.size(), ur.size());
+ same = false;
+ } else if (ul != ur) {
+ printf("Diff in inst_to_users_: different users\n");
+ same = false;
}
- for (auto p : rhs.inst_to_used_ids_) {
- if (lhs.inst_to_used_ids_.count(p.first) == 0) {
- return false;
- }
+ }
+ for (const auto& r : rhs.inst_to_users_) {
+ auto iter_l = lhs.inst_to_users_.find(r.first);
+ if (r.second.empty() &&
+ !(iter_l == lhs.inst_to_users_.end() || iter_l->second.empty())) {
+ printf("Diff in inst_to_users_: unexpected instr in rhs\n");
+ same = false;
}
- return false;
}
- return true;
+ return same;
}
} // namespace analysis
diff --git a/source/opt/def_use_manager.h b/source/opt/def_use_manager.h
index 0499e82b..cf6cbdf5 100644
--- a/source/opt/def_use_manager.h
+++ b/source/opt/def_use_manager.h
@@ -15,14 +15,13 @@
#ifndef SOURCE_OPT_DEF_USE_MANAGER_H_
#define SOURCE_OPT_DEF_USE_MANAGER_H_
-#include <list>
#include <set>
#include <unordered_map>
-#include <utility>
#include <vector>
#include "source/opt/instruction.h"
#include "source/opt/module.h"
+#include "source/util/pooled_linked_list.h"
#include "spirv-tools/libspirv.hpp"
namespace spvtools {
@@ -51,59 +50,16 @@ inline bool operator<(const Use& lhs, const Use& rhs) {
return lhs.operand_index < rhs.operand_index;
}
-// Definition and user pair.
-//
-// The first element of the pair is the definition.
-// The second element of the pair is the user.
-//
-// Definition should never be null. User can be null, however, such an entry
-// should be used only for searching (e.g. all users of a particular definition)
-// and never stored in a container.
-using UserEntry = std::pair<Instruction*, Instruction*>;
-
-// Orders UserEntry for use in associative containers (i.e. less than ordering).
-//
-// The definition of an UserEntry is treated as the major key and the users as
-// the minor key so that all the users of a particular definition are
-// consecutive in a container.
-//
-// A null user always compares less than a real user. This is done to provide
-// easy values to search for the beginning of the users of a particular
-// definition (i.e. using {def, nullptr}).
-struct UserEntryLess {
- bool operator()(const UserEntry& lhs, const UserEntry& rhs) const {
- // If lhs.first and rhs.first are both null, fall through to checking the
- // second entries.
- if (!lhs.first && rhs.first) return true;
- if (lhs.first && !rhs.first) return false;
-
- // If neither definition is null, then compare unique ids.
- if (lhs.first && rhs.first) {
- if (lhs.first->unique_id() < rhs.first->unique_id()) return true;
- if (rhs.first->unique_id() < lhs.first->unique_id()) return false;
- }
-
- // Return false on equality.
- if (!lhs.second && !rhs.second) return false;
- if (!lhs.second) return true;
- if (!rhs.second) return false;
-
- // If neither user is null then compare unique ids.
- return lhs.second->unique_id() < rhs.second->unique_id();
- }
-};
-
// A class for analyzing and managing defs and uses in an Module.
class DefUseManager {
public:
using IdToDefMap = std::unordered_map<uint32_t, Instruction*>;
- using IdToUsersMap = std::set<UserEntry, UserEntryLess>;
// Constructs a def-use manager from the given |module|. All internal messages
// will be communicated to the outside via the given message |consumer|. This
// instance only keeps a reference to the |consumer|, so the |consumer| should
// outlive this instance.
- DefUseManager(Module* module) { AnalyzeDefUse(module); }
+ DefUseManager(Module* module) : DefUseManager() { AnalyzeDefUse(module); }
DefUseManager(const DefUseManager&) = delete;
DefUseManager(DefUseManager&&) = delete;
@@ -191,14 +147,12 @@ class DefUseManager {
// Returns the annotation instrunctions which are a direct use of the given
// |id|. This means when the decorations are applied through decoration
// group(s), this function will just return the OpGroupDecorate
- // instrcution(s) which refer to the given id as an operand. The OpDecorate
+ // instruction(s) which refer to the given id as an operand. The OpDecorate
// instructions which decorate the decoration group will not be returned.
std::vector<Instruction*> GetAnnotations(uint32_t id) const;
// Returns the map from ids to their def instructions.
const IdToDefMap& id_to_defs() const { return id_to_def_; }
- // Returns the map from instructions to their users.
- const IdToUsersMap& id_to_users() const { return id_to_users_; }
// Clear the internal def-use record of the given instruction |inst|. This
// method will update the use information of the operand ids of |inst|. The
@@ -210,43 +164,43 @@ class DefUseManager {
// Erases the records that a given instruction uses its operand ids.
void EraseUseRecordsOfOperandIds(const Instruction* inst);
- friend bool operator==(const DefUseManager&, const DefUseManager&);
- friend bool operator!=(const DefUseManager& lhs, const DefUseManager& rhs) {
- return !(lhs == rhs);
- }
+ friend bool CompareAndPrintDifferences(const DefUseManager&,
+ const DefUseManager&);
- // If |inst| has not already been analysed, then analyses its defintion and
+ // If |inst| has not already been analysed, then analyses its definition and
// uses.
void UpdateDefUse(Instruction* inst);
+ // Compacts any internal storage to save memory.
+ void CompactStorage();
+
private:
- using InstToUsedIdsMap =
- std::unordered_map<const Instruction*, std::vector<uint32_t>>;
+ using UseList = spvtools::utils::PooledLinkedList<Instruction*>;
+ using UseListPool = spvtools::utils::PooledLinkedListNodes<Instruction*>;
+ // Stores linked lists of Instructions using a def.
+ using InstToUsersMap = std::unordered_map<const Instruction*, UseList>;
- // Returns the first location that {|def|, nullptr} could be inserted into the
- // users map without violating ordering.
- IdToUsersMap::const_iterator UsersBegin(const Instruction* def) const;
+ using UsedIdList = spvtools::utils::PooledLinkedList<uint32_t>;
+ using UsedIdListPool = spvtools::utils::PooledLinkedListNodes<uint32_t>;
+ // Stores mapping from instruction to their UsedIdRange.
+ using InstToUsedIdMap = std::unordered_map<const Instruction*, UsedIdList>;
- // Returns true if |iter| has not reached the end of |def|'s users.
- //
- // In the first version |iter| is compared against the end of the map for
- // validity before other checks. In the second version, |iter| is compared
- // against |cached_end| for validity before other checks. This allows caching
- // the map's end which is a performance improvement on some platforms.
- bool UsersNotEnd(const IdToUsersMap::const_iterator& iter,
- const Instruction* def) const;
- bool UsersNotEnd(const IdToUsersMap::const_iterator& iter,
- const IdToUsersMap::const_iterator& cached_end,
- const Instruction* def) const;
+ DefUseManager();
// Analyzes the defs and uses in the given |module| and populates data
// structures in this class. Does nothing if |module| is nullptr.
void AnalyzeDefUse(Module* module);
- IdToDefMap id_to_def_; // Mapping from ids to their definitions
- IdToUsersMap id_to_users_; // Mapping from ids to their users
- // Mapping from instructions to the ids used in the instruction.
- InstToUsedIdsMap inst_to_used_ids_;
+ // Removes unused entries in used_records_ and used_ids_.
+ void CompactUseRecords();
+ void CompactUsedIds();
+
+ IdToDefMap id_to_def_; // Mapping from ids to their definitions
+ InstToUsersMap inst_to_users_; // Map from def to uses.
+ std::unique_ptr<UseListPool> use_pool_;
+
+ std::unique_ptr<UsedIdListPool> used_id_pool_;
+ InstToUsedIdMap inst_to_used_id_; // Map from instruction to used ids.
};
} // namespace analysis
diff --git a/source/opt/desc_sroa.cpp b/source/opt/desc_sroa.cpp
index bcbdde94..b130ca80 100644
--- a/source/opt/desc_sroa.cpp
+++ b/source/opt/desc_sroa.cpp
@@ -118,7 +118,7 @@ bool DescriptorScalarReplacement::ReplaceAccessChain(Instruction* var,
if (use->NumInOperands() == 2) {
// We are not indexing into the replacement variable. We can replaces the
- // access chain with the replacement varibale itself.
+ // access chain with the replacement variable itself.
context()->ReplaceAllUsesWith(use->result_id(), replacement_var);
context()->KillInst(use);
return true;
@@ -135,8 +135,8 @@ bool DescriptorScalarReplacement::ReplaceAccessChain(Instruction* var,
// Use the replacement variable as the base address.
new_operands.push_back({SPV_OPERAND_TYPE_ID, {replacement_var}});
- // Drop the first index because it is consumed by the replacment, and copy the
- // rest.
+ // Drop the first index because it is consumed by the replacement, and copy
+ // the rest.
for (uint32_t i = 4; i < use->NumOperands(); i++) {
new_operands.emplace_back(use->GetOperand(i));
}
@@ -169,7 +169,7 @@ void DescriptorScalarReplacement::CopyDecorationsForNewVariable(
Instruction* old_var, uint32_t index, uint32_t new_var_id,
uint32_t new_var_ptr_type_id, const bool is_old_var_array,
const bool is_old_var_struct, Instruction* old_var_type) {
- // Handle OpDecorate instructions.
+ // Handle OpDecorate and OpDecorateString instructions.
for (auto old_decoration :
get_decoration_mgr()->GetDecorationsFor(old_var->result_id(), true)) {
uint32_t new_binding = 0;
@@ -212,7 +212,8 @@ uint32_t DescriptorScalarReplacement::GetNewBindingForElement(
void DescriptorScalarReplacement::CreateNewDecorationForNewVariable(
Instruction* old_decoration, uint32_t new_var_id, uint32_t new_binding) {
- assert(old_decoration->opcode() == SpvOpDecorate);
+ assert(old_decoration->opcode() == SpvOpDecorate ||
+ old_decoration->opcode() == SpvOpDecorateString);
std::unique_ptr<Instruction> new_decoration(old_decoration->Clone(context()));
new_decoration->SetInOperand(0, {new_var_id});
diff --git a/source/opt/desc_sroa.h b/source/opt/desc_sroa.h
index fea06255..6a24fd87 100644
--- a/source/opt/desc_sroa.h
+++ b/source/opt/desc_sroa.h
@@ -115,10 +115,11 @@ class DescriptorScalarReplacement : public Pass {
const bool is_old_var_struct,
Instruction* old_var_type);
- // Create a new OpDecorate instruction by cloning |old_decoration|. The new
- // OpDecorate instruction will be used for a variable whose id is
- // |new_var_ptr_type_id|. If |old_decoration| is a decoration for a binding,
- // the new OpDecorate instruction will have |new_binding| as its binding.
+ // Create a new OpDecorate(String) instruction by cloning |old_decoration|.
+ // The new OpDecorate(String) instruction will be used for a variable whose id
+ // is |new_var_ptr_type_id|. If |old_decoration| is a decoration for a
+ // binding, the new OpDecorate(String) instruction will have |new_binding| as
+ // its binding.
void CreateNewDecorationForNewVariable(Instruction* old_decoration,
uint32_t new_var_id,
uint32_t new_binding);
@@ -131,7 +132,7 @@ class DescriptorScalarReplacement : public Pass {
// A map from an OpVariable instruction to the set of variables that will be
// used to replace it. The entry |replacement_variables_[var][i]| is the id of
- // a variable that will be used in the place of the the ith element of the
+ // a variable that will be used in the place of the ith element of the
// array |var|. If the entry is |0|, then the variable has not been
// created yet.
std::map<Instruction*, std::vector<uint32_t>> replacement_variables_;
diff --git a/source/opt/dominator_tree.cpp b/source/opt/dominator_tree.cpp
index 55287f44..d86de151 100644
--- a/source/opt/dominator_tree.cpp
+++ b/source/opt/dominator_tree.cpp
@@ -48,7 +48,7 @@ namespace {
// BBType - BasicBlock type. Will either be BasicBlock or DominatorTreeNode
// SuccessorLambda - Lamdba matching the signature of 'const
// std::vector<BBType>*(const BBType *A)'. Will return a vector of the nodes
-// succeding BasicBlock A.
+// succeeding BasicBlock A.
// PostLambda - Lamdba matching the signature of 'void (const BBType*)' will be
// called on each node traversed AFTER their children.
// PreLambda - Lamdba matching the signature of 'void (const BBType*)' will be
@@ -69,7 +69,7 @@ static void DepthFirstSearch(const BBType* bb, SuccessorLambda successors,
// BBType - BasicBlock type. Will either be BasicBlock or DominatorTreeNode
// SuccessorLambda - Lamdba matching the signature of 'const
// std::vector<BBType>*(const BBType *A)'. Will return a vector of the nodes
-// succeding BasicBlock A.
+// succeeding BasicBlock A.
// PostLambda - Lamdba matching the signature of 'void (const BBType*)' will be
// called on each node traversed after their children.
template <typename BBType, typename SuccessorLambda, typename PostLambda>
diff --git a/source/opt/dominator_tree.h b/source/opt/dominator_tree.h
index 0024bc50..1674b228 100644
--- a/source/opt/dominator_tree.h
+++ b/source/opt/dominator_tree.h
@@ -278,7 +278,7 @@ class DominatorTree {
private:
// Wrapper function which gets the list of pairs of each BasicBlocks to its
- // immediately dominating BasicBlock and stores the result in the the edges
+ // immediately dominating BasicBlock and stores the result in the edges
// parameter.
//
// The |edges| vector will contain the dominator tree as pairs of nodes.
diff --git a/source/opt/eliminate_dead_input_components_pass.cpp b/source/opt/eliminate_dead_input_components_pass.cpp
new file mode 100644
index 00000000..f383136d
--- /dev/null
+++ b/source/opt/eliminate_dead_input_components_pass.cpp
@@ -0,0 +1,146 @@
+// Copyright (c) 2022 The Khronos Group Inc.
+// Copyright (c) 2022 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/opt/eliminate_dead_input_components_pass.h"
+
+#include <set>
+#include <vector>
+
+#include "source/opt/instruction.h"
+#include "source/opt/ir_builder.h"
+#include "source/opt/ir_context.h"
+#include "source/util/bit_vector.h"
+
+namespace {
+
+const uint32_t kAccessChainBaseInIdx = 0;
+const uint32_t kAccessChainIndex0InIdx = 1;
+const uint32_t kConstantValueInIdx = 0;
+const uint32_t kVariableStorageClassInIdx = 0;
+
+} // namespace
+
+namespace spvtools {
+namespace opt {
+
+Pass::Status EliminateDeadInputComponentsPass::Process() {
+ // Current functionality assumes shader capability
+ if (!context()->get_feature_mgr()->HasCapability(SpvCapabilityShader))
+ return Status::SuccessWithoutChange;
+ analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
+ analysis::TypeManager* type_mgr = context()->get_type_mgr();
+ bool modified = false;
+ std::vector<std::pair<Instruction*, unsigned>> arrays_to_change;
+ for (auto& var : context()->types_values()) {
+ if (var.opcode() != SpvOpVariable) {
+ continue;
+ }
+ analysis::Type* var_type = type_mgr->GetType(var.type_id());
+ analysis::Pointer* ptr_type = var_type->AsPointer();
+ if (ptr_type == nullptr) {
+ continue;
+ }
+ if (ptr_type->storage_class() != SpvStorageClassInput) {
+ continue;
+ }
+ const analysis::Array* arr_type = ptr_type->pointee_type()->AsArray();
+ if (arr_type == nullptr) {
+ continue;
+ }
+ unsigned arr_len_id = arr_type->LengthId();
+ Instruction* arr_len_inst = def_use_mgr->GetDef(arr_len_id);
+ if (arr_len_inst->opcode() != SpvOpConstant) {
+ continue;
+ }
+ // SPIR-V requires array size is >= 1, so this works for signed or
+ // unsigned size
+ unsigned original_max =
+ arr_len_inst->GetSingleWordInOperand(kConstantValueInIdx) - 1;
+ unsigned max_idx = FindMaxIndex(var, original_max);
+ if (max_idx != original_max) {
+ ChangeArrayLength(var, max_idx + 1);
+ modified = true;
+ }
+ }
+
+ return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+}
+
+unsigned EliminateDeadInputComponentsPass::FindMaxIndex(Instruction& var,
+ unsigned original_max) {
+ unsigned max = 0;
+ bool seen_non_const_ac = false;
+ assert(var.opcode() == SpvOpVariable && "must be variable");
+ context()->get_def_use_mgr()->WhileEachUser(
+ var.result_id(), [&max, &seen_non_const_ac, var, this](Instruction* use) {
+ auto use_opcode = use->opcode();
+ if (use_opcode == SpvOpLoad || use_opcode == SpvOpCopyMemory ||
+ use_opcode == SpvOpCopyMemorySized ||
+ use_opcode == SpvOpCopyObject) {
+ seen_non_const_ac = true;
+ return false;
+ }
+ if (use->opcode() != SpvOpAccessChain &&
+ use->opcode() != SpvOpInBoundsAccessChain) {
+ return true;
+ }
+ // OpAccessChain with no indices currently not optimized
+ if (use->NumInOperands() == 1) {
+ seen_non_const_ac = true;
+ return false;
+ }
+ unsigned base_id = use->GetSingleWordInOperand(kAccessChainBaseInIdx);
+ USE_ASSERT(base_id == var.result_id() && "unexpected base");
+ unsigned idx_id = use->GetSingleWordInOperand(kAccessChainIndex0InIdx);
+ Instruction* idx_inst = context()->get_def_use_mgr()->GetDef(idx_id);
+ if (idx_inst->opcode() != SpvOpConstant) {
+ seen_non_const_ac = true;
+ return false;
+ }
+ unsigned value = idx_inst->GetSingleWordInOperand(kConstantValueInIdx);
+ if (value > max) max = value;
+ return true;
+ });
+ return seen_non_const_ac ? original_max : max;
+}
+
+void EliminateDeadInputComponentsPass::ChangeArrayLength(Instruction& arr,
+ unsigned length) {
+ analysis::TypeManager* type_mgr = context()->get_type_mgr();
+ analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
+ analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
+ analysis::Pointer* ptr_type = type_mgr->GetType(arr.type_id())->AsPointer();
+ const analysis::Array* arr_ty = ptr_type->pointee_type()->AsArray();
+ assert(arr_ty && "expecting array type");
+ uint32_t length_id = const_mgr->GetUIntConst(length);
+ analysis::Array new_arr_ty(arr_ty->element_type(),
+ arr_ty->GetConstantLengthInfo(length_id, length));
+ analysis::Type* reg_new_arr_ty = type_mgr->GetRegisteredType(&new_arr_ty);
+ analysis::Pointer new_ptr_ty(reg_new_arr_ty, SpvStorageClassInput);
+ analysis::Type* reg_new_ptr_ty = type_mgr->GetRegisteredType(&new_ptr_ty);
+ uint32_t new_ptr_ty_id = type_mgr->GetTypeInstruction(reg_new_ptr_ty);
+ arr.SetResultType(new_ptr_ty_id);
+ def_use_mgr->AnalyzeInstUse(&arr);
+ // Move array OpVariable instruction after its new type to preserve order
+ USE_ASSERT(arr.GetSingleWordInOperand(kVariableStorageClassInIdx) !=
+ SpvStorageClassFunction &&
+ "cannot move Function variable");
+ Instruction* new_ptr_ty_inst = def_use_mgr->GetDef(new_ptr_ty_id);
+ arr.RemoveFromList();
+ arr.InsertAfter(new_ptr_ty_inst);
+}
+
+} // namespace opt
+} // namespace spvtools
diff --git a/source/opt/eliminate_dead_input_components_pass.h b/source/opt/eliminate_dead_input_components_pass.h
new file mode 100644
index 00000000..b77857f4
--- /dev/null
+++ b/source/opt/eliminate_dead_input_components_pass.h
@@ -0,0 +1,59 @@
+// Copyright (c) 2022 The Khronos Group Inc.
+// Copyright (c) 2022 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_OPT_ELIMINATE_DEAD_INPUT_COMPONENTS_H_
+#define SOURCE_OPT_ELIMINATE_DEAD_INPUT_COMPONENTS_H_
+
+#include <unordered_map>
+
+#include "source/opt/ir_context.h"
+#include "source/opt/module.h"
+#include "source/opt/pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// See optimizer.hpp for documentation.
+class EliminateDeadInputComponentsPass : public Pass {
+ public:
+ explicit EliminateDeadInputComponentsPass() {}
+
+ const char* name() const override { return "reduce-load-size"; }
+ Status Process() override;
+
+ // Return the mask of preserved Analyses.
+ IRContext::Analysis GetPreservedAnalyses() override {
+ return IRContext::kAnalysisDefUse |
+ IRContext::kAnalysisInstrToBlockMapping |
+ IRContext::kAnalysisCombinators | IRContext::kAnalysisCFG |
+ IRContext::kAnalysisDominatorAnalysis |
+ IRContext::kAnalysisLoopAnalysis | IRContext::kAnalysisNameMap |
+ IRContext::kAnalysisConstants | IRContext::kAnalysisTypes;
+ }
+
+ private:
+ // Find the max constant used to index the variable declared by |var|
+ // through OpAccessChain or OpInBoundsAccessChain. If any non-constant
+ // indices or non-Op*AccessChain use of |var|, return |original_max|.
+ unsigned FindMaxIndex(Instruction& var, unsigned original_max);
+
+ // Change the length of the array |inst| to |length|
+ void ChangeArrayLength(Instruction& inst, unsigned length);
+};
+
+} // namespace opt
+} // namespace spvtools
+
+#endif // SOURCE_OPT_ELIMINATE_DEAD_INPUT_COMPONENTS_H_
diff --git a/source/opt/eliminate_dead_members_pass.cpp b/source/opt/eliminate_dead_members_pass.cpp
index a24ba8f4..52aca525 100644
--- a/source/opt/eliminate_dead_members_pass.cpp
+++ b/source/opt/eliminate_dead_members_pass.cpp
@@ -38,7 +38,7 @@ Pass::Status EliminateDeadMembersPass::Process() {
}
void EliminateDeadMembersPass::FindLiveMembers() {
- // Until we have implemented the rewritting of OpSpecConsantOp instructions,
+ // Until we have implemented the rewriting of OpSpecConsantOp instructions,
// we have to mark them as fully used just to be safe.
for (auto& inst : get_module()->types_values()) {
if (inst.opcode() == SpvOpSpecConstantOp) {
@@ -570,7 +570,7 @@ bool EliminateDeadMembersPass::UpdateCompsiteExtract(Instruction* inst) {
Instruction* type_inst = get_def_use_mgr()->GetDef(type_id);
switch (type_inst->opcode()) {
case SpvOpTypeStruct:
- // The type will have already been rewriten, so use the new member
+ // The type will have already been rewritten, so use the new member
// index.
type_id = type_inst->GetSingleWordInOperand(new_member_idx);
break;
diff --git a/source/opt/feature_manager.cpp b/source/opt/feature_manager.cpp
index 39a4a348..a5902716 100644
--- a/source/opt/feature_manager.cpp
+++ b/source/opt/feature_manager.cpp
@@ -39,8 +39,7 @@ void FeatureManager::AddExtension(Instruction* ext) {
assert(ext->opcode() == SpvOpExtension &&
"Expecting an extension instruction.");
- const std::string name =
- reinterpret_cast<const char*>(ext->GetInOperand(0u).words.data());
+ const std::string name = ext->GetInOperand(0u).AsString();
Extension extension;
if (GetExtensionFromString(name.c_str(), &extension)) {
extensions_.Add(extension);
diff --git a/source/opt/fold.cpp b/source/opt/fold.cpp
index 6550fb4f..b903da6a 100644
--- a/source/opt/fold.cpp
+++ b/source/opt/fold.cpp
@@ -540,7 +540,7 @@ std::vector<uint32_t> InstructionFolder::FoldVectors(
// in 32-bit words here. The reason of not using FoldScalars() here
// is that we do not create temporary null constants as components
// when the vector operand is a NullConstant because Constant creation
- // may need extra checks for the validity and that is not manageed in
+ // may need extra checks for the validity and that is not managed in
// here.
if (const analysis::ScalarConstant* scalar_component =
vector_operand->GetComponents().at(d)->AsScalarConstant()) {
diff --git a/source/opt/fold_spec_constant_op_and_composite_pass.cpp b/source/opt/fold_spec_constant_op_and_composite_pass.cpp
index 8ab717ea..85f11fde 100644
--- a/source/opt/fold_spec_constant_op_and_composite_pass.cpp
+++ b/source/opt/fold_spec_constant_op_and_composite_pass.cpp
@@ -115,7 +115,7 @@ bool FoldSpecConstantOpAndCompositePass::ProcessOpSpecConstantOp(
Instruction* folded_inst = nullptr;
assert(inst->GetInOperand(0).type ==
SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER &&
- "The first in-operand of OpSpecContantOp instruction must be of "
+ "The first in-operand of OpSpecConstantOp instruction must be of "
"SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER type");
switch (static_cast<SpvOp>(inst->GetSingleWordInOperand(0))) {
diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp
index 4904f186..c879a0c5 100644
--- a/source/opt/folding_rules.cpp
+++ b/source/opt/folding_rules.cpp
@@ -2368,7 +2368,7 @@ FoldingRule VectorShuffleFeedingShuffle() {
// fold.
return false;
}
- } else {
+ } else if (component_index != undef_literal) {
if (new_feeder_id == 0) {
// First time through, save the id of the operand the element comes
// from.
@@ -2382,7 +2382,7 @@ FoldingRule VectorShuffleFeedingShuffle() {
component_index -= feeder_op0_length;
}
- if (!feeder_is_op0) {
+ if (!feeder_is_op0 && component_index != undef_literal) {
component_index += op0_length;
}
}
diff --git a/source/opt/graphics_robust_access_pass.cpp b/source/opt/graphics_robust_access_pass.cpp
index 1b28f9b5..4652d72d 100644
--- a/source/opt/graphics_robust_access_pass.cpp
+++ b/source/opt/graphics_robust_access_pass.cpp
@@ -13,7 +13,7 @@
// limitations under the License.
// This pass injects code in a graphics shader to implement guarantees
-// satisfying Vulkan's robustBufferAcces rules. Robust access rules permit
+// satisfying Vulkan's robustBufferAccess rules. Robust access rules permit
// an out-of-bounds access to be redirected to an access of the same type
// (load, store, etc.) but within the same root object.
//
@@ -74,7 +74,7 @@
// Pointers are always (correctly) typed and so the address and number of
// consecutive locations are fully determined by the pointer.
//
-// - A pointer value orginates as one of few cases:
+// - A pointer value originates as one of few cases:
//
// - OpVariable for an interface object or an array of them: image,
// buffer (UBO or SSBO), sampler, sampled-image, push-constant, input
@@ -559,21 +559,17 @@ uint32_t GraphicsRobustAccessPass::GetGlslInsts() {
if (module_status_.glsl_insts_id == 0) {
// This string serves double-duty as raw data for a string and for a vector
// of 32-bit words
- const char glsl[] = "GLSL.std.450\0\0\0\0";
- const size_t glsl_str_byte_len = 16;
+ const char glsl[] = "GLSL.std.450";
// Use an existing import if we can.
for (auto& inst : context()->module()->ext_inst_imports()) {
- const auto& name_words = inst.GetInOperand(0).words;
- if (0 == std::strncmp(reinterpret_cast<const char*>(name_words.data()),
- glsl, glsl_str_byte_len)) {
+ if (inst.GetInOperand(0).AsString() == glsl) {
module_status_.glsl_insts_id = inst.result_id();
}
}
if (module_status_.glsl_insts_id == 0) {
// Make a new import instruction.
module_status_.glsl_insts_id = TakeNextId();
- std::vector<uint32_t> words(glsl_str_byte_len / sizeof(uint32_t));
- std::memcpy(words.data(), glsl, glsl_str_byte_len);
+ std::vector<uint32_t> words = spvtools::utils::MakeVector(glsl);
auto import_inst = MakeUnique<Instruction>(
context(), SpvOpExtInstImport, 0, module_status_.glsl_insts_id,
std::initializer_list<Operand>{
@@ -962,7 +958,7 @@ spv_result_t GraphicsRobustAccessPass::ClampCoordinateForImageTexelPointer(
constant_mgr->GetDefiningInstruction(component_0)->result_id();
// If the image is a cube array, then the last component of the queried
- // size is the layer count. In the query, we have to accomodate folding
+ // size is the layer count. In the query, we have to accommodate folding
// in the face index ranging from 0 through 5. The inclusive upper bound
// on the third coordinate therefore is multiplied by 6.
auto* query_size_including_faces = query_size;
diff --git a/source/opt/graphics_robust_access_pass.h b/source/opt/graphics_robust_access_pass.h
index 6fc692c1..8f4c9dc7 100644
--- a/source/opt/graphics_robust_access_pass.h
+++ b/source/opt/graphics_robust_access_pass.h
@@ -111,7 +111,7 @@ class GraphicsRobustAccessPass : public Pass {
Instruction* max, Instruction* where);
// Returns a new instruction which evaluates to the length the runtime array
- // referenced by the access chain at the specfied index. The instruction is
+ // referenced by the access chain at the specified index. The instruction is
// inserted before the access chain instruction. Returns a null pointer in
// some cases if assumptions are violated (rather than asserting out).
opt::Instruction* MakeRuntimeArrayLengthInst(Instruction* access_chain,
diff --git a/source/opt/inst_bindless_check_pass.cpp b/source/opt/inst_bindless_check_pass.cpp
index 5607239a..c2c5d6cb 100644
--- a/source/opt/inst_bindless_check_pass.cpp
+++ b/source/opt/inst_bindless_check_pass.cpp
@@ -39,13 +39,6 @@ static const int kSpvTypeImageMS = 4;
static const int kSpvTypeImageSampled = 5;
} // anonymous namespace
-// Avoid unused variable warning/error on Linux
-#ifndef NDEBUG
-#define USE_ASSERT(x) assert(x)
-#else
-#define USE_ASSERT(x) ((void)(x))
-#endif
-
namespace spvtools {
namespace opt {
diff --git a/source/opt/inst_bindless_check_pass.h b/source/opt/inst_bindless_check_pass.h
index cd961805..e6e6ef4f 100644
--- a/source/opt/inst_bindless_check_pass.h
+++ b/source/opt/inst_bindless_check_pass.h
@@ -147,11 +147,11 @@ class InstBindlessCheckPass : public InstrumentPass {
uint32_t GenLastByteIdx(RefAnalysis* ref, InstructionBuilder* builder);
// Clone original image computation starting at |image_id| into |builder|.
- // This may generate more than one instruction if neccessary.
+ // This may generate more than one instruction if necessary.
uint32_t CloneOriginalImage(uint32_t image_id, InstructionBuilder* builder);
// Clone original original reference encapsulated by |ref| into |builder|.
- // This may generate more than one instruction if neccessary.
+ // This may generate more than one instruction if necessary.
uint32_t CloneOriginalReference(RefAnalysis* ref,
InstructionBuilder* builder);
diff --git a/source/opt/inst_debug_printf_pass.cpp b/source/opt/inst_debug_printf_pass.cpp
index c0e6bc3f..4218138f 100644
--- a/source/opt/inst_debug_printf_pass.cpp
+++ b/source/opt/inst_debug_printf_pass.cpp
@@ -16,6 +16,7 @@
#include "inst_debug_printf_pass.h"
+#include "source/util/string_utils.h"
#include "spirv/unified1/NonSemanticDebugPrintf.h"
namespace spvtools {
@@ -231,10 +232,8 @@ Pass::Status InstDebugPrintfPass::ProcessImpl() {
bool non_sem_set_seen = false;
for (auto c_itr = context()->module()->ext_inst_import_begin();
c_itr != context()->module()->ext_inst_import_end(); ++c_itr) {
- const char* set_name =
- reinterpret_cast<const char*>(&c_itr->GetInOperand(0).words[0]);
- const char* non_sem_str = "NonSemantic.";
- if (!strncmp(set_name, non_sem_str, strlen(non_sem_str))) {
+ const std::string set_name = c_itr->GetInOperand(0).AsString();
+ if (spvtools::utils::starts_with(set_name, "NonSemantic.")) {
non_sem_set_seen = true;
break;
}
@@ -242,9 +241,8 @@ Pass::Status InstDebugPrintfPass::ProcessImpl() {
if (!non_sem_set_seen) {
for (auto c_itr = context()->module()->extension_begin();
c_itr != context()->module()->extension_end(); ++c_itr) {
- const char* ext_name =
- reinterpret_cast<const char*>(&c_itr->GetInOperand(0).words[0]);
- if (!strcmp(ext_name, "SPV_KHR_non_semantic_info")) {
+ const std::string ext_name = c_itr->GetInOperand(0).AsString();
+ if (ext_name == "SPV_KHR_non_semantic_info") {
context()->KillInst(&*c_itr);
break;
}
diff --git a/source/opt/instruction.cpp b/source/opt/instruction.cpp
index 2461e41e..418f1213 100644
--- a/source/opt/instruction.cpp
+++ b/source/opt/instruction.cpp
@@ -76,10 +76,9 @@ Instruction::Instruction(IRContext* c, const spv_parsed_instruction_t& inst,
dbg_scope_(kNoDebugScope, kNoInlinedAt) {
for (uint32_t i = 0; i < inst.num_operands; ++i) {
const auto& current_payload = inst.operands[i];
- std::vector<uint32_t> words(
- inst.words + current_payload.offset,
+ operands_.emplace_back(
+ current_payload.type, inst.words + current_payload.offset,
inst.words + current_payload.offset + current_payload.num_words);
- operands_.emplace_back(current_payload.type, std::move(words));
}
assert((!IsLineInst() || dbg_line.empty()) &&
"Op(No)Line attaching to Op(No)Line found");
@@ -96,10 +95,9 @@ Instruction::Instruction(IRContext* c, const spv_parsed_instruction_t& inst,
dbg_scope_(dbg_scope) {
for (uint32_t i = 0; i < inst.num_operands; ++i) {
const auto& current_payload = inst.operands[i];
- std::vector<uint32_t> words(
- inst.words + current_payload.offset,
+ operands_.emplace_back(
+ current_payload.type, inst.words + current_payload.offset,
inst.words + current_payload.offset + current_payload.num_words);
- operands_.emplace_back(current_payload.type, std::move(words));
}
}
diff --git a/source/opt/instruction.h b/source/opt/instruction.h
index ce568f66..2163d99b 100644
--- a/source/opt/instruction.h
+++ b/source/opt/instruction.h
@@ -24,6 +24,7 @@
#include "NonSemanticShaderDebugInfo100.h"
#include "OpenCLDebugInfo100.h"
+#include "source/binary.h"
#include "source/common_debug_info.h"
#include "source/latest_version_glsl_std_450_header.h"
#include "source/latest_version_spirv_header.h"
@@ -32,6 +33,7 @@
#include "source/opt/reflect.h"
#include "source/util/ilist_node.h"
#include "source/util/small_vector.h"
+#include "source/util/string_utils.h"
#include "spirv-tools/libspirv.h"
const uint32_t kNoDebugScope = 0;
@@ -82,21 +84,32 @@ struct Operand {
Operand(spv_operand_type_t t, const OperandData& w) : type(t), words(w) {}
+ template <class InputIt>
+ Operand(spv_operand_type_t t, InputIt firstOperandData,
+ InputIt lastOperandData)
+ : type(t), words(firstOperandData, lastOperandData) {}
+
spv_operand_type_t type; // Type of this logical operand.
OperandData words; // Binary segments of this logical operand.
- // Returns a string operand as a C-style string.
- const char* AsCString() const {
- assert(type == SPV_OPERAND_TYPE_LITERAL_STRING);
- return reinterpret_cast<const char*>(words.data());
+ uint32_t AsId() const {
+ assert(spvIsIdType(type));
+ assert(words.size() == 1);
+ return words[0];
}
// Returns a string operand as a std::string.
- std::string AsString() const { return AsCString(); }
+ std::string AsString() const {
+ assert(type == SPV_OPERAND_TYPE_LITERAL_STRING);
+ return spvtools::utils::MakeString(words);
+ }
// Returns a literal integer operand as a uint64_t
uint64_t AsLiteralUint64() const {
- assert(type == SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER);
+ assert(type == SPV_OPERAND_TYPE_LITERAL_INTEGER ||
+ type == SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER ||
+ type == SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER ||
+ type == SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER);
assert(1 <= words.size());
assert(words.size() <= 2);
uint64_t result = 0;
@@ -123,7 +136,7 @@ inline bool operator!=(const Operand& o1, const Operand& o2) {
}
// This structure is used to represent a DebugScope instruction from
-// the OpenCL.100.DebugInfo extened instruction set. Note that we can
+// the OpenCL.100.DebugInfo extended instruction set. Note that we can
// ignore the result id of DebugScope instruction because it is not
// used for anything. We do not keep it to reduce the size of
// structure.
@@ -295,6 +308,7 @@ class Instruction : public utils::IntrusiveNodeBase<Instruction> {
inline void SetInOperands(OperandList&& new_operands);
// Sets the result type id.
inline void SetResultType(uint32_t ty_id);
+ inline bool HasResultType() const { return has_type_id_; }
// Sets the result id
inline void SetResultId(uint32_t res_id);
inline bool HasResultId() const { return has_result_id_; }
diff --git a/source/opt/instrument_pass.h b/source/opt/instrument_pass.h
index 12b939d4..90c1dd47 100644
--- a/source/opt/instrument_pass.h
+++ b/source/opt/instrument_pass.h
@@ -50,7 +50,7 @@
// A validation pass may read or write multiple buffers. All such buffers
// are located in a single debug descriptor set whose index is passed at the
// creation of the instrumentation pass. The bindings of the buffers used by
-// a validation pass are permanantly assigned and fixed and documented by
+// a validation pass are permanently assigned and fixed and documented by
// the kDebugOutput* static consts.
namespace spvtools {
@@ -179,8 +179,8 @@ class InstrumentPass : public Pass {
// the error. Every stage will write a fixed number of words. Vertex shaders
// will write the Vertex and Instance ID. Fragment shaders will write
// FragCoord.xy. Compute shaders will write the GlobalInvocation ID.
- // The tesselation eval shader will write the Primitive ID and TessCoords.uv.
- // The tesselation control shader and geometry shader will write the
+ // The tessellation eval shader will write the Primitive ID and TessCoords.uv.
+ // The tessellation control shader and geometry shader will write the
// Primitive ID and Invocation ID.
//
// The Validation Error Code specifies the exact error which has occurred.
diff --git a/source/opt/interp_fixup_pass.cpp b/source/opt/interp_fixup_pass.cpp
index ad29e6a7..e8cdd99f 100644
--- a/source/opt/interp_fixup_pass.cpp
+++ b/source/opt/interp_fixup_pass.cpp
@@ -31,13 +31,6 @@ namespace {
// Input Operand Indices
static const int kSpvVariableStorageClassInIdx = 0;
-// Avoid unused variable warning/error on Linux
-#ifndef NDEBUG
-#define USE_ASSERT(x) assert(x)
-#else
-#define USE_ASSERT(x) ((void)(x))
-#endif
-
// Folding rule function which attempts to replace |op(OpLoad(a),...)|
// by |op(a,...)|, where |op| is one of the GLSLstd450 InterpolateAt*
// instructions. Returns true if replaced, false otherwise.
diff --git a/source/opt/ir_context.cpp b/source/opt/ir_context.cpp
index 612a831a..a80d4f2d 100644
--- a/source/opt/ir_context.cpp
+++ b/source/opt/ir_context.cpp
@@ -41,6 +41,8 @@ namespace spvtools {
namespace opt {
void IRContext::BuildInvalidAnalyses(IRContext::Analysis set) {
+ set = Analysis(set & ~valid_analyses_);
+
if (set & kAnalysisDefUse) {
BuildDefUseManager();
}
@@ -106,7 +108,7 @@ void IRContext::InvalidateAnalyses(IRContext::Analysis analyses_to_invalidate) {
analyses_to_invalidate |= kAnalysisDebugInfo;
}
- // The dominator analysis hold the psuedo entry and exit nodes from the CFG.
+ // The dominator analysis hold the pseudo entry and exit nodes from the CFG.
// Also if the CFG change the dominators many changed as well, so the
// dominator analysis should be invalidated as well.
if (analyses_to_invalidate & kAnalysisCFG) {
@@ -317,7 +319,7 @@ bool IRContext::IsConsistent() {
#else
if (AreAnalysesValid(kAnalysisDefUse)) {
analysis::DefUseManager new_def_use(module());
- if (*get_def_use_mgr() != new_def_use) {
+ if (!CompareAndPrintDifferences(*get_def_use_mgr(), new_def_use)) {
return false;
}
}
@@ -623,9 +625,8 @@ void IRContext::AddCombinatorsForCapability(uint32_t capability) {
void IRContext::AddCombinatorsForExtension(Instruction* extension) {
assert(extension->opcode() == SpvOpExtInstImport &&
"Expecting an import of an extension's instruction set.");
- const char* extension_name =
- reinterpret_cast<const char*>(&extension->GetInOperand(0).words[0]);
- if (!strcmp(extension_name, "GLSL.std.450")) {
+ const std::string extension_name = extension->GetInOperand(0).AsString();
+ if (extension_name == "GLSL.std.450") {
combinator_ops_[extension->result_id()] = {GLSLstd450Round,
GLSLstd450RoundEven,
GLSLstd450Trunc,
@@ -944,11 +945,11 @@ void IRContext::EmitErrorMessage(std::string message, Instruction* inst) {
uint32_t line_number = 0;
uint32_t col_number = 0;
- char* source = nullptr;
+ std::string source;
if (line_inst != nullptr) {
Instruction* file_name =
get_def_use_mgr()->GetDef(line_inst->GetSingleWordInOperand(0));
- source = reinterpret_cast<char*>(&file_name->GetInOperand(0).words[0]);
+ source = file_name->GetInOperand(0).AsString();
// Get the line number and column number.
line_number = line_inst->GetSingleWordInOperand(1);
@@ -957,7 +958,7 @@ void IRContext::EmitErrorMessage(std::string message, Instruction* inst) {
message +=
"\n " + inst->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
- consumer()(SPV_MSG_ERROR, source, {line_number, col_number, 0},
+ consumer()(SPV_MSG_ERROR, source.c_str(), {line_number, col_number, 0},
message.c_str());
}
diff --git a/source/opt/ir_context.h b/source/opt/ir_context.h
index 65853476..946f9e9d 100644
--- a/source/opt/ir_context.h
+++ b/source/opt/ir_context.h
@@ -43,6 +43,7 @@
#include "source/opt/type_manager.h"
#include "source/opt/value_number_table.h"
#include "source/util/make_unique.h"
+#include "source/util/string_utils.h"
namespace spvtools {
namespace opt {
@@ -301,7 +302,7 @@ class IRContext {
}
}
- // Returns a pointer the decoration manager. If the decoration manger is
+ // Returns a pointer the decoration manager. If the decoration manager is
// invalid, it is rebuilt first.
analysis::DecorationManager* get_decoration_mgr() {
if (!AreAnalysesValid(kAnalysisDecorations)) {
@@ -384,7 +385,7 @@ class IRContext {
// Deletes the instruction defining the given |id|. Returns true on
// success, false if the given |id| is not defined at all. This method also
- // erases the name, decorations, and defintion of |id|.
+ // erases the name, decorations, and definition of |id|.
//
// Pointers and iterators pointing to the deleted instructions become invalid.
// However other pointers and iterators are still valid.
@@ -518,6 +519,18 @@ class IRContext {
std::string message = "ID overflow. Try running compact-ids.";
consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
}
+#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION
+ // If TakeNextId returns 0, it is very likely that execution will
+ // subsequently fail. Such failures are false alarms from a fuzzing point
+ // of view: they are due to the fact that too many ids were used, rather
+ // than being due to an actual bug. Thus, during a fuzzing build, it is
+ // preferable to bail out when ID overflow occurs.
+ //
+ // A zero exit code is returned here because a non-zero code would cause
+ // ClusterFuzz/OSS-Fuzz to regard the termination as a crash, and spurious
+ // crash reports is what this guard aims to avoid.
+ exit(0);
+#endif
}
return next_id;
}
@@ -789,7 +802,7 @@ class IRContext {
// iterators to traverse instructions.
std::unordered_map<uint32_t, Function*> id_to_func_;
- // A bitset indicating which analyes are currently valid.
+ // A bitset indicating which analyzes are currently valid.
Analysis valid_analyses_;
// Opcodes of shader capability core executable instructions
@@ -854,8 +867,7 @@ inline IRContext::Analysis operator|(IRContext::Analysis lhs,
inline IRContext::Analysis& operator|=(IRContext::Analysis& lhs,
IRContext::Analysis rhs) {
- lhs = static_cast<IRContext::Analysis>(static_cast<int>(lhs) |
- static_cast<int>(rhs));
+ lhs = lhs | rhs;
return lhs;
}
@@ -1020,11 +1032,7 @@ void IRContext::AddCapability(std::unique_ptr<Instruction>&& c) {
}
void IRContext::AddExtension(const std::string& ext_name) {
- const auto num_chars = ext_name.size();
- // Compute num words, accommodate the terminating null character.
- const auto num_words = (num_chars + 1 + 3) / 4;
- std::vector<uint32_t> ext_words(num_words, 0u);
- std::memcpy(ext_words.data(), ext_name.data(), num_chars);
+ std::vector<uint32_t> ext_words = spvtools::utils::MakeVector(ext_name);
AddExtension(std::unique_ptr<Instruction>(
new Instruction(this, SpvOpExtension, 0u, 0u,
{{SPV_OPERAND_TYPE_LITERAL_STRING, ext_words}})));
@@ -1041,11 +1049,7 @@ void IRContext::AddExtension(std::unique_ptr<Instruction>&& e) {
}
void IRContext::AddExtInstImport(const std::string& name) {
- const auto num_chars = name.size();
- // Compute num words, accommodate the terminating null character.
- const auto num_words = (num_chars + 1 + 3) / 4;
- std::vector<uint32_t> ext_words(num_words, 0u);
- std::memcpy(ext_words.data(), name.data(), num_chars);
+ std::vector<uint32_t> ext_words = spvtools::utils::MakeVector(name);
AddExtInstImport(std::unique_ptr<Instruction>(
new Instruction(this, SpvOpExtInstImport, 0u, TakeNextId(),
{{SPV_OPERAND_TYPE_LITERAL_STRING, ext_words}})));
diff --git a/source/opt/ir_loader.cpp b/source/opt/ir_loader.cpp
index a82b530e..97db9d8f 100644
--- a/source/opt/ir_loader.cpp
+++ b/source/opt/ir_loader.cpp
@@ -189,7 +189,8 @@ bool IrLoader::AddInstruction(const spv_parsed_instruction_t* inst) {
module_->SetMemoryModel(std::move(spv_inst));
} else if (opcode == SpvOpEntryPoint) {
module_->AddEntryPoint(std::move(spv_inst));
- } else if (opcode == SpvOpExecutionMode) {
+ } else if (opcode == SpvOpExecutionMode ||
+ opcode == SpvOpExecutionModeId) {
module_->AddExecutionMode(std::move(spv_inst));
} else if (IsDebug1Inst(opcode)) {
module_->AddDebug1Inst(std::move(spv_inst));
diff --git a/source/opt/local_access_chain_convert_pass.cpp b/source/opt/local_access_chain_convert_pass.cpp
index da9ba8cc..0c6d0c24 100644
--- a/source/opt/local_access_chain_convert_pass.cpp
+++ b/source/opt/local_access_chain_convert_pass.cpp
@@ -19,6 +19,7 @@
#include "ir_builder.h"
#include "ir_context.h"
#include "iterator.h"
+#include "source/util/string_utils.h"
namespace spvtools {
namespace opt {
@@ -328,8 +329,7 @@ bool LocalAccessChainConvertPass::AllExtensionsSupported() const {
return false;
// If any extension not in allowlist, return false
for (auto& ei : get_module()->extensions()) {
- const char* extName =
- reinterpret_cast<const char*>(&ei.GetInOperand(0).words[0]);
+ const std::string extName = ei.GetInOperand(0).AsString();
if (extensions_allowlist_.find(extName) == extensions_allowlist_.end())
return false;
}
@@ -339,11 +339,9 @@ bool LocalAccessChainConvertPass::AllExtensionsSupported() const {
for (auto& inst : context()->module()->ext_inst_imports()) {
assert(inst.opcode() == SpvOpExtInstImport &&
"Expecting an import of an extension's instruction set.");
- const char* extension_name =
- reinterpret_cast<const char*>(&inst.GetInOperand(0).words[0]);
- if (0 == std::strncmp(extension_name, "NonSemantic.", 12) &&
- 0 != std::strncmp(extension_name, "NonSemantic.Shader.DebugInfo.100",
- 32)) {
+ const std::string extension_name = inst.GetInOperand(0).AsString();
+ if (spvtools::utils::starts_with(extension_name, "NonSemantic.") &&
+ extension_name != "NonSemantic.Shader.DebugInfo.100") {
return false;
}
}
@@ -436,6 +434,7 @@ void LocalAccessChainConvertPass::InitExtensions() {
"SPV_KHR_integer_dot_product",
"SPV_EXT_shader_image_int64",
"SPV_KHR_non_semantic_info",
+ "SPV_KHR_uniform_group_instructions",
});
}
diff --git a/source/opt/local_access_chain_convert_pass.h b/source/opt/local_access_chain_convert_pass.h
index 552062e5..a51660f1 100644
--- a/source/opt/local_access_chain_convert_pass.h
+++ b/source/opt/local_access_chain_convert_pass.h
@@ -81,7 +81,7 @@ class LocalAccessChainConvertPass : public MemPass {
std::vector<Operand>* in_opnds);
// Create a load/insert/store equivalent to a store of
- // |valId| through (constant index) access chaing |ptrInst|.
+ // |valId| through (constant index) access chain |ptrInst|.
// Append to |newInsts|. Returns true if successful.
bool GenAccessChainStoreReplacement(
const Instruction* ptrInst, uint32_t valId,
diff --git a/source/opt/local_single_block_elim_pass.cpp b/source/opt/local_single_block_elim_pass.cpp
index 5fd4f658..33c8bdf8 100644
--- a/source/opt/local_single_block_elim_pass.cpp
+++ b/source/opt/local_single_block_elim_pass.cpp
@@ -19,6 +19,7 @@
#include <vector>
#include "source/opt/iterator.h"
+#include "source/util/string_utils.h"
namespace spvtools {
namespace opt {
@@ -183,8 +184,7 @@ void LocalSingleBlockLoadStoreElimPass::Initialize() {
bool LocalSingleBlockLoadStoreElimPass::AllExtensionsSupported() const {
// If any extension not in allowlist, return false
for (auto& ei : get_module()->extensions()) {
- const char* extName =
- reinterpret_cast<const char*>(&ei.GetInOperand(0).words[0]);
+ const std::string extName = ei.GetInOperand(0).AsString();
if (extensions_allowlist_.find(extName) == extensions_allowlist_.end())
return false;
}
@@ -194,11 +194,9 @@ bool LocalSingleBlockLoadStoreElimPass::AllExtensionsSupported() const {
for (auto& inst : context()->module()->ext_inst_imports()) {
assert(inst.opcode() == SpvOpExtInstImport &&
"Expecting an import of an extension's instruction set.");
- const char* extension_name =
- reinterpret_cast<const char*>(&inst.GetInOperand(0).words[0]);
- if (0 == std::strncmp(extension_name, "NonSemantic.", 12) &&
- 0 != std::strncmp(extension_name, "NonSemantic.Shader.DebugInfo.100",
- 32)) {
+ const std::string extension_name = inst.GetInOperand(0).AsString();
+ if (spvtools::utils::starts_with(extension_name, "NonSemantic.") &&
+ extension_name != "NonSemantic.Shader.DebugInfo.100") {
return false;
}
}
@@ -288,6 +286,7 @@ void LocalSingleBlockLoadStoreElimPass::InitExtensions() {
"SPV_KHR_integer_dot_product",
"SPV_EXT_shader_image_int64",
"SPV_KHR_non_semantic_info",
+ "SPV_KHR_uniform_group_instructions",
});
}
diff --git a/source/opt/local_single_store_elim_pass.cpp b/source/opt/local_single_store_elim_pass.cpp
index 051bcada..f22b1911 100644
--- a/source/opt/local_single_store_elim_pass.cpp
+++ b/source/opt/local_single_store_elim_pass.cpp
@@ -19,6 +19,7 @@
#include "source/cfa.h"
#include "source/latest_version_glsl_std_450_header.h"
#include "source/opt/iterator.h"
+#include "source/util/string_utils.h"
namespace spvtools {
namespace opt {
@@ -48,8 +49,7 @@ bool LocalSingleStoreElimPass::LocalSingleStoreElim(Function* func) {
bool LocalSingleStoreElimPass::AllExtensionsSupported() const {
// If any extension not in allowlist, return false
for (auto& ei : get_module()->extensions()) {
- const char* extName =
- reinterpret_cast<const char*>(&ei.GetInOperand(0).words[0]);
+ const std::string extName = ei.GetInOperand(0).AsString();
if (extensions_allowlist_.find(extName) == extensions_allowlist_.end())
return false;
}
@@ -59,11 +59,9 @@ bool LocalSingleStoreElimPass::AllExtensionsSupported() const {
for (auto& inst : context()->module()->ext_inst_imports()) {
assert(inst.opcode() == SpvOpExtInstImport &&
"Expecting an import of an extension's instruction set.");
- const char* extension_name =
- reinterpret_cast<const char*>(&inst.GetInOperand(0).words[0]);
- if (0 == std::strncmp(extension_name, "NonSemantic.", 12) &&
- 0 != std::strncmp(extension_name, "NonSemantic.Shader.DebugInfo.100",
- 32)) {
+ const std::string extension_name = inst.GetInOperand(0).AsString();
+ if (spvtools::utils::starts_with(extension_name, "NonSemantic.") &&
+ extension_name != "NonSemantic.Shader.DebugInfo.100") {
return false;
}
}
@@ -141,6 +139,7 @@ void LocalSingleStoreElimPass::InitExtensionAllowList() {
"SPV_KHR_integer_dot_product",
"SPV_EXT_shader_image_int64",
"SPV_KHR_non_semantic_info",
+ "SPV_KHR_uniform_group_instructions",
});
}
bool LocalSingleStoreElimPass::ProcessVariable(Instruction* var_inst) {
diff --git a/source/opt/loop_descriptor.cpp b/source/opt/loop_descriptor.cpp
index b5b56309..9bc495e5 100644
--- a/source/opt/loop_descriptor.cpp
+++ b/source/opt/loop_descriptor.cpp
@@ -719,7 +719,7 @@ bool Loop::FindNumberOfIterations(const Instruction* induction,
step_value = -step_value;
}
- // Find the inital value of the loop and make sure it is a constant integer.
+ // Find the initial value of the loop and make sure it is a constant integer.
int64_t init_value = 0;
if (!GetInductionInitValue(induction, &init_value)) return false;
@@ -751,7 +751,7 @@ bool Loop::FindNumberOfIterations(const Instruction* induction,
// We retrieve the number of iterations using the following formula, diff /
// |step_value| where diff is calculated differently according to the
// |condition| and uses the |condition_value| and |init_value|. If diff /
-// |step_value| is NOT cleanly divisable then we add one to the sum.
+// |step_value| is NOT cleanly divisible then we add one to the sum.
int64_t Loop::GetIterations(SpvOp condition, int64_t condition_value,
int64_t init_value, int64_t step_value) const {
int64_t diff = 0;
@@ -795,7 +795,7 @@ int64_t Loop::GetIterations(SpvOp condition, int64_t condition_value,
// If the condition is not met to begin with the loop will never iterate.
if (!(init_value >= condition_value)) return 0;
- // We subract one to make it the same as SpvOpGreaterThan as it is
+ // We subtract one to make it the same as SpvOpGreaterThan as it is
// functionally equivalent.
diff = init_value - (condition_value - 1);
diff --git a/source/opt/loop_descriptor.h b/source/opt/loop_descriptor.h
index 4b4f8bc7..e88ff936 100644
--- a/source/opt/loop_descriptor.h
+++ b/source/opt/loop_descriptor.h
@@ -395,7 +395,7 @@ class Loop {
// Sets |merge| as the loop merge block. No checks are performed here.
inline void SetMergeBlockImpl(BasicBlock* merge) { loop_merge_ = merge; }
- // Each differnt loop |condition| affects how we calculate the number of
+ // Each different loop |condition| affects how we calculate the number of
// iterations using the |condition_value|, |init_value|, and |step_values| of
// the induction variable. This method will return the number of iterations in
// a loop with those values for a given |condition|.
diff --git a/source/opt/loop_fission.cpp b/source/opt/loop_fission.cpp
index 0678113c..b4df8c62 100644
--- a/source/opt/loop_fission.cpp
+++ b/source/opt/loop_fission.cpp
@@ -29,7 +29,7 @@
// 2 - For each loop in the list, group each instruction into a set of related
// instructions by traversing each instructions users and operands recursively.
// We stop if we encounter an instruction we have seen before or an instruction
-// which we don't consider relevent (i.e OpLoopMerge). We then group these
+// which we don't consider relevant (i.e OpLoopMerge). We then group these
// groups into two different sets, one for the first loop and one for the
// second.
//
@@ -453,7 +453,7 @@ Pass::Status LoopFissionPass::Process() {
for (Function& f : *context()->module()) {
// We collect all the inner most loops in the function and run the loop
// splitting util on each. The reason we do this is to allow us to iterate
- // over each, as creating new loops will invalidate the the loop iterator.
+ // over each, as creating new loops will invalidate the loop iterator.
std::vector<Loop*> inner_most_loops{};
LoopDescriptor& loop_descriptor = *context()->GetLoopDescriptor(&f);
for (Loop& loop : loop_descriptor) {
diff --git a/source/opt/loop_fission.h b/source/opt/loop_fission.h
index e7a59c18..9bc12c0f 100644
--- a/source/opt/loop_fission.h
+++ b/source/opt/loop_fission.h
@@ -33,7 +33,7 @@ namespace opt {
class LoopFissionPass : public Pass {
public:
- // Fuction used to determine if a given loop should be split. Takes register
+ // Function used to determine if a given loop should be split. Takes register
// pressure region for that loop as a parameter and returns true if the loop
// should be split.
using FissionCriteriaFunction =
diff --git a/source/opt/loop_fusion.cpp b/source/opt/loop_fusion.cpp
index 07d171a0..f3aab283 100644
--- a/source/opt/loop_fusion.cpp
+++ b/source/opt/loop_fusion.cpp
@@ -165,7 +165,7 @@ bool LoopFusion::AreCompatible() {
// Check adjacency, |loop_0_| should come just before |loop_1_|.
// There is always at least one block between loops, even if it's empty.
- // We'll check at most 2 preceeding blocks.
+ // We'll check at most 2 preceding blocks.
auto pre_header_1 = loop_1_->GetPreHeaderBlock();
@@ -712,7 +712,7 @@ void LoopFusion::Fuse() {
ld->RemoveLoop(loop_1_);
- // Kill unnessecary instructions and remove all empty blocks.
+ // Kill unnecessary instructions and remove all empty blocks.
for (auto inst : instr_to_delete) {
context_->KillInst(inst);
}
diff --git a/source/opt/loop_fusion.h b/source/opt/loop_fusion.h
index d61d6783..769da5f1 100644
--- a/source/opt/loop_fusion.h
+++ b/source/opt/loop_fusion.h
@@ -40,7 +40,7 @@ class LoopFusion {
// That means:
// * they both have one induction variable
// * they have the same upper and lower bounds
- // - same inital value
+ // - same initial value
// - same condition
// * they have the same update step
// * they are adjacent, with |loop_0| appearing before |loop_1|
diff --git a/source/opt/loop_fusion_pass.h b/source/opt/loop_fusion_pass.h
index 3a0be600..9d5b7ccd 100644
--- a/source/opt/loop_fusion_pass.h
+++ b/source/opt/loop_fusion_pass.h
@@ -33,7 +33,7 @@ class LoopFusionPass : public Pass {
// Processes the given |module|. Returns Status::Failure if errors occur when
// processing. Returns the corresponding Status::Success if processing is
- // succesful to indicate whether changes have been made to the modue.
+ // successful to indicate whether changes have been made to the module.
Status Process() override;
private:
diff --git a/source/opt/loop_peeling.h b/source/opt/loop_peeling.h
index 413f896f..2a55fe44 100644
--- a/source/opt/loop_peeling.h
+++ b/source/opt/loop_peeling.h
@@ -261,7 +261,7 @@ class LoopPeelingPass : public Pass {
// Processes the given |module|. Returns Status::Failure if errors occur when
// processing. Returns the corresponding Status::Success if processing is
- // succesful to indicate whether changes have been made to the modue.
+ // successful to indicate whether changes have been made to the module.
Pass::Status Process() override;
private:
diff --git a/source/opt/loop_unroller.cpp b/source/opt/loop_unroller.cpp
index aff191fe..28ff0729 100644
--- a/source/opt/loop_unroller.cpp
+++ b/source/opt/loop_unroller.cpp
@@ -163,7 +163,7 @@ struct LoopUnrollState {
};
// This class implements the actual unrolling. It uses a LoopUnrollState to
-// maintain the state of the unrolling inbetween steps.
+// maintain the state of the unrolling in between steps.
class LoopUnrollerUtilsImpl {
public:
using BasicBlockListTy = std::vector<std::unique_ptr<BasicBlock>>;
@@ -209,7 +209,7 @@ class LoopUnrollerUtilsImpl {
// Add all blocks_to_add_ to function_ at the |insert_point|.
void AddBlocksToFunction(const BasicBlock* insert_point);
- // Duplicates the |old_loop|, cloning each body and remaping the ids without
+ // Duplicates the |old_loop|, cloning each body and remapping the ids without
// removing instructions or changing relative structure. Result will be stored
// in |new_loop|.
void DuplicateLoop(Loop* old_loop, Loop* new_loop);
@@ -241,7 +241,7 @@ class LoopUnrollerUtilsImpl {
// Remap all the in |basic_block| to new IDs and keep the mapping of new ids
// to old
// ids. |loop| is used to identify special loop blocks (header, continue,
- // ect).
+ // etc).
void AssignNewResultIds(BasicBlock* basic_block);
// Using the map built by AssignNewResultIds, replace the uses in |inst|
@@ -320,7 +320,7 @@ class LoopUnrollerUtilsImpl {
// and then be remapped at the end.
std::vector<Instruction*> loop_phi_instructions_;
- // The number of loop iterations that the loop would preform pre-unroll.
+ // The number of loop iterations that the loop would perform pre-unroll.
size_t number_of_loop_iterations_;
// The amount that the loop steps each iteration.
@@ -839,7 +839,7 @@ void LoopUnrollerUtilsImpl::DuplicateLoop(Loop* old_loop, Loop* new_loop) {
new_loop->SetMergeBlock(new_merge);
}
-// Whenever the utility copies a block it stores it in a tempory buffer, this
+// Whenever the utility copies a block it stores it in a temporary buffer, this
// function adds the buffer into the Function. The blocks will be inserted
// after the block |insert_point|.
void LoopUnrollerUtilsImpl::AddBlocksToFunction(
diff --git a/source/opt/loop_unswitch_pass.cpp b/source/opt/loop_unswitch_pass.cpp
index d805ecf3..1ee7e5e2 100644
--- a/source/opt/loop_unswitch_pass.cpp
+++ b/source/opt/loop_unswitch_pass.cpp
@@ -118,7 +118,7 @@ class LoopUnswitch {
// Find a value that can be used to select the default path.
// If none are possible, then it will just use 0. The value does not matter
- // because this path will never be taken becaues the new switch outside of
+ // because this path will never be taken because the new switch outside of
// the loop cannot select this path either.
std::vector<uint32_t> existing_values;
for (uint32_t i = 2; i < switch_inst->NumInOperands(); i += 2) {
diff --git a/source/opt/loop_unswitch_pass.h b/source/opt/loop_unswitch_pass.h
index 3ecdd611..4f7295d4 100644
--- a/source/opt/loop_unswitch_pass.h
+++ b/source/opt/loop_unswitch_pass.h
@@ -30,7 +30,7 @@ class LoopUnswitchPass : public Pass {
// Processes the given |module|. Returns Status::Failure if errors occur when
// processing. Returns the corresponding Status::Success if processing is
- // succesful to indicate whether changes have been made to the modue.
+ // successful to indicate whether changes have been made to the module.
Pass::Status Process() override;
private:
diff --git a/source/opt/loop_utils.h b/source/opt/loop_utils.h
index a4e61900..70060fc4 100644
--- a/source/opt/loop_utils.h
+++ b/source/opt/loop_utils.h
@@ -123,7 +123,7 @@ class LoopUtils {
// Clone the |loop_| and make the new loop branch to the second loop on exit.
Loop* CloneAndAttachLoopToHeader(LoopCloningResult* cloning_result);
- // Perfom a partial unroll of |loop| by given |factor|. This will copy the
+ // Perform a partial unroll of |loop| by given |factor|. This will copy the
// body of the loop |factor| times. So a |factor| of one would give a new loop
// with the original body plus one unrolled copy body.
bool PartiallyUnroll(size_t factor);
@@ -139,7 +139,7 @@ class LoopUtils {
// 1. That the loop is in structured order.
// 2. That the continue block is a branch to the header.
// 3. That the only phi used in the loop is the induction variable.
- // TODO(stephen@codeplay.com): This is a temporary mesure, after the loop is
+ // TODO(stephen@codeplay.com): This is a temporary measure, after the loop is
// converted into LCSAA form and has a single entry and exit we can rewrite
// the other phis.
// 4. That this is an inner most loop, or that loops contained within this
diff --git a/source/opt/merge_return_pass.cpp b/source/opt/merge_return_pass.cpp
index a962a7cc..7710deae 100644
--- a/source/opt/merge_return_pass.cpp
+++ b/source/opt/merge_return_pass.cpp
@@ -431,6 +431,7 @@ bool MergeReturnPass::BreakFromConstruct(
std::list<BasicBlock*>* order, Instruction* break_merge_inst) {
// Make sure the CFG is build here. If we don't then it becomes very hard
// to know which new blocks need to be updated.
+ context()->InvalidateAnalyses(IRContext::kAnalysisCFG);
context()->BuildInvalidAnalyses(IRContext::kAnalysisCFG);
// When predicating, be aware of whether this block is a header block, a
diff --git a/source/opt/merge_return_pass.h b/source/opt/merge_return_pass.h
index 4096ce7d..a35cf269 100644
--- a/source/opt/merge_return_pass.h
+++ b/source/opt/merge_return_pass.h
@@ -247,7 +247,7 @@ class MergeReturnPass : public MemPass {
// Add new phi nodes for any id that no longer dominate all of it uses. A phi
// node is added to a block |bb| for an id if the id is defined between the
- // original immediate dominator of |bb| and its new immidiate dominator. It
+ // original immediate dominator of |bb| and its new immediate dominator. It
// is assumed that at this point there are no unreachable blocks in the
// control flow graph.
void AddNewPhiNodes();
@@ -273,7 +273,7 @@ class MergeReturnPass : public MemPass {
void InsertAfterElement(BasicBlock* element, BasicBlock* new_element,
std::list<BasicBlock*>* list);
- // Creates a single case switch around all of the exectuable code of the
+ // Creates a single case switch around all of the executable code of the
// current function where the switch and case value are both zero and the
// default is the merge block. Returns after the switch is executed. Sets
// |final_return_block_|.
diff --git a/source/opt/module.cpp b/source/opt/module.cpp
index f97defbd..5983abb1 100644
--- a/source/opt/module.cpp
+++ b/source/opt/module.cpp
@@ -139,7 +139,7 @@ void Module::ToBinary(std::vector<uint32_t>* binary, bool skip_nop) const {
// TODO(antiagainst): should we change the generator number?
binary->push_back(header_.generator);
binary->push_back(header_.bound);
- binary->push_back(header_.reserved);
+ binary->push_back(header_.schema);
size_t bound_idx = binary->size() - 2;
DebugScope last_scope(kNoDebugScope, kNoInlinedAt);
@@ -260,9 +260,7 @@ bool Module::HasExplicitCapability(uint32_t cap) {
uint32_t Module::GetExtInstImportId(const char* extstr) {
for (auto& ei : ext_inst_imports_)
- if (!strcmp(extstr,
- reinterpret_cast<const char*>(&(ei.GetInOperand(0).words[0]))))
- return ei.result_id();
+ if (!ei.GetInOperand(0).AsString().compare(extstr)) return ei.result_id();
return 0;
}
diff --git a/source/opt/module.h b/source/opt/module.h
index 0360b7d5..230be709 100644
--- a/source/opt/module.h
+++ b/source/opt/module.h
@@ -36,7 +36,7 @@ struct ModuleHeader {
uint32_t version;
uint32_t generator;
uint32_t bound;
- uint32_t reserved;
+ uint32_t schema;
};
// A SPIR-V module. It contains all the information for a SPIR-V module and
@@ -61,7 +61,7 @@ class Module {
}
// Returns the Id bound.
- uint32_t IdBound() { return header_.bound; }
+ uint32_t IdBound() const { return header_.bound; }
// Returns the current Id bound and increases it to the next available value.
// If the id bound has already reached its maximum value, then 0 is returned.
@@ -141,6 +141,8 @@ class Module {
inline uint32_t id_bound() const { return header_.bound; }
inline uint32_t version() const { return header_.version; }
+ inline uint32_t generator() const { return header_.generator; }
+ inline uint32_t schema() const { return header_.schema; }
inline void set_version(uint32_t v) { header_.version = v; }
diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp
index e74db26f..f28b1baf 100644
--- a/source/opt/optimizer.cpp
+++ b/source/opt/optimizer.cpp
@@ -288,6 +288,8 @@ bool Optimizer::RegisterPassFromFlag(const std::string& flag) {
RegisterPass(CreateStripDebugInfoPass());
} else if (pass_name == "strip-reflect") {
RegisterPass(CreateStripReflectInfoPass());
+ } else if (pass_name == "strip-nonsemantic") {
+ RegisterPass(CreateStripNonSemanticInfoPass());
} else if (pass_name == "set-spec-const-default-value") {
if (pass_args.size() > 0) {
auto spec_ids_vals =
@@ -322,6 +324,8 @@ bool Optimizer::RegisterPassFromFlag(const std::string& flag) {
RegisterPass(CreateLocalAccessChainConvertPass());
} else if (pass_name == "replace-desc-array-access-using-var-index") {
RegisterPass(CreateReplaceDescArrayAccessUsingVarIndexPass());
+ } else if (pass_name == "spread-volatile-semantics") {
+ RegisterPass(CreateSpreadVolatileSemanticsPass());
} else if (pass_name == "descriptor-scalar-replacement") {
RegisterPass(CreateDescriptorScalarReplacementPass());
} else if (pass_name == "eliminate-dead-code-aggressive") {
@@ -517,6 +521,10 @@ bool Optimizer::RegisterPassFromFlag(const std::string& flag) {
RegisterPass(CreateAmdExtToKhrPass());
} else if (pass_name == "interpolate-fixup") {
RegisterPass(CreateInterpolateFixupPass());
+ } else if (pass_name == "remove-dont-inline") {
+ RegisterPass(CreateRemoveDontInlinePass());
+ } else if (pass_name == "eliminate-dead-input-components") {
+ RegisterPass(CreateEliminateDeadInputComponentsPass());
} else if (pass_name == "convert-to-sampled-image") {
if (pass_args.size() > 0) {
auto descriptor_set_binding_pairs =
@@ -653,8 +661,12 @@ Optimizer::PassToken CreateStripDebugInfoPass() {
}
Optimizer::PassToken CreateStripReflectInfoPass() {
+ return CreateStripNonSemanticInfoPass();
+}
+
+Optimizer::PassToken CreateStripNonSemanticInfoPass() {
return MakeUnique<Optimizer::PassToken::Impl>(
- MakeUnique<opt::StripReflectInfoPass>());
+ MakeUnique<opt::StripNonSemanticInfoPass>());
}
Optimizer::PassToken CreateEliminateDeadFunctionsPass() {
@@ -764,6 +776,11 @@ Optimizer::PassToken CreateLocalMultiStoreElimPass() {
MakeUnique<opt::SSARewritePass>());
}
+Optimizer::PassToken CreateAggressiveDCEPass() {
+ return MakeUnique<Optimizer::PassToken::Impl>(
+ MakeUnique<opt::AggressiveDCEPass>(false));
+}
+
Optimizer::PassToken CreateAggressiveDCEPass(bool preserve_interface) {
return MakeUnique<Optimizer::PassToken::Impl>(
MakeUnique<opt::AggressiveDCEPass>(preserve_interface));
@@ -965,6 +982,11 @@ Optimizer::PassToken CreateReplaceDescArrayAccessUsingVarIndexPass() {
MakeUnique<opt::ReplaceDescArrayAccessUsingVarIndex>());
}
+Optimizer::PassToken CreateSpreadVolatileSemanticsPass() {
+ return MakeUnique<Optimizer::PassToken::Impl>(
+ MakeUnique<opt::SpreadVolatileSemantics>());
+}
+
Optimizer::PassToken CreateDescriptorScalarReplacementPass() {
return MakeUnique<Optimizer::PassToken::Impl>(
MakeUnique<opt::DescriptorScalarReplacement>());
@@ -984,6 +1006,11 @@ Optimizer::PassToken CreateInterpolateFixupPass() {
MakeUnique<opt::InterpFixupPass>());
}
+Optimizer::PassToken CreateEliminateDeadInputComponentsPass() {
+ return MakeUnique<Optimizer::PassToken::Impl>(
+ MakeUnique<opt::EliminateDeadInputComponentsPass>());
+}
+
Optimizer::PassToken CreateConvertToSampledImagePass(
const std::vector<opt::DescriptorSetAndBinding>&
descriptor_set_binding_pairs) {
@@ -991,4 +1018,8 @@ Optimizer::PassToken CreateConvertToSampledImagePass(
MakeUnique<opt::ConvertToSampledImagePass>(descriptor_set_binding_pairs));
}
+Optimizer::PassToken CreateRemoveDontInlinePass() {
+ return MakeUnique<Optimizer::PassToken::Impl>(
+ MakeUnique<opt::RemoveDontInline>());
+}
} // namespace spvtools
diff --git a/source/opt/pass.h b/source/opt/pass.h
index a8c9c4b4..b2303e23 100644
--- a/source/opt/pass.h
+++ b/source/opt/pass.h
@@ -28,6 +28,13 @@
#include "spirv-tools/libspirv.hpp"
#include "types.h"
+// Avoid unused variable warning/error on Linux
+#ifndef NDEBUG
+#define USE_ASSERT(x) assert(x)
+#else
+#define USE_ASSERT(x) ((void)(x))
+#endif
+
namespace spvtools {
namespace opt {
@@ -129,7 +136,7 @@ class Pass {
// Processes the given |module|. Returns Status::Failure if errors occur when
// processing. Returns the corresponding Status::Success if processing is
- // succesful to indicate whether changes are made to the module.
+ // successful to indicate whether changes are made to the module.
virtual Status Process() = 0;
// Return the next available SSA id and increment it.
diff --git a/source/opt/pass_manager.cpp b/source/opt/pass_manager.cpp
index be53d344..a73ff7cf 100644
--- a/source/opt/pass_manager.cpp
+++ b/source/opt/pass_manager.cpp
@@ -35,10 +35,18 @@ Pass::Status PassManager::Run(IRContext* context) {
if (print_all_stream_) {
std::vector<uint32_t> binary;
context->module()->ToBinary(&binary, false);
- SpirvTools t(SPV_ENV_UNIVERSAL_1_2);
+ SpirvTools t(target_env_);
+ t.SetMessageConsumer(consumer());
std::string disassembly;
- t.Disassemble(binary, &disassembly, 0);
- *print_all_stream_ << preamble << (pass ? pass->name() : "") << "\n"
+ std::string pass_name = (pass ? pass->name() : "");
+ if (!t.Disassemble(binary, &disassembly, 0)) {
+ std::string msg = "Disassembly failed before pass ";
+ msg += pass_name + "\n";
+ spv_position_t null_pos{0, 0, 0};
+ consumer()(SPV_MSG_WARNING, "", null_pos, msg.c_str());
+ return;
+ }
+ *print_all_stream_ << preamble << pass_name << "\n"
<< disassembly << std::endl;
}
};
diff --git a/source/opt/pass_manager.h b/source/opt/pass_manager.h
index 9686dddc..11961a33 100644
--- a/source/opt/pass_manager.h
+++ b/source/opt/pass_manager.h
@@ -54,7 +54,7 @@ class PassManager {
// Adds an externally constructed pass.
void AddPass(std::unique_ptr<Pass> pass);
// Uses the argument |args| to construct a pass instance of type |T|, and adds
- // the pass instance to this pass manger. The pass added will use this pass
+ // the pass instance to this pass manager. The pass added will use this pass
// manager's message consumer.
template <typename T, typename... Args>
void AddPass(Args&&... args);
@@ -70,7 +70,7 @@ class PassManager {
// Runs all passes on the given |module|. Returns Status::Failure if errors
// occur when processing using one of the registered passes. All passes
// registered after the error-reporting pass will be skipped. Returns the
- // corresponding Status::Success if processing is succesful to indicate
+ // corresponding Status::Success if processing is successful to indicate
// whether changes are made to the module.
//
// After running all the passes, they are removed from the list.
diff --git a/source/opt/passes.h b/source/opt/passes.h
index f3c30d57..a12c76b8 100644
--- a/source/opt/passes.h
+++ b/source/opt/passes.h
@@ -34,6 +34,7 @@
#include "source/opt/desc_sroa.h"
#include "source/opt/eliminate_dead_constant_pass.h"
#include "source/opt/eliminate_dead_functions_pass.h"
+#include "source/opt/eliminate_dead_input_components_pass.h"
#include "source/opt/eliminate_dead_members_pass.h"
#include "source/opt/empty_pass.h"
#include "source/opt/fix_storage_class.h"
@@ -64,6 +65,7 @@
#include "source/opt/reduce_load_size.h"
#include "source/opt/redundancy_elimination.h"
#include "source/opt/relax_float_ops_pass.h"
+#include "source/opt/remove_dontinline_pass.h"
#include "source/opt/remove_duplicates_pass.h"
#include "source/opt/remove_unused_interface_variables_pass.h"
#include "source/opt/replace_desc_array_access_using_var_index.h"
@@ -71,10 +73,11 @@
#include "source/opt/scalar_replacement_pass.h"
#include "source/opt/set_spec_constant_default_value_pass.h"
#include "source/opt/simplification_pass.h"
+#include "source/opt/spread_volatile_semantics.h"
#include "source/opt/ssa_rewrite_pass.h"
#include "source/opt/strength_reduction_pass.h"
#include "source/opt/strip_debug_info_pass.h"
-#include "source/opt/strip_reflect_info_pass.h"
+#include "source/opt/strip_nonsemantic_info_pass.h"
#include "source/opt/unify_const_pass.h"
#include "source/opt/upgrade_memory_model.h"
#include "source/opt/vector_dce.h"
diff --git a/source/opt/private_to_local_pass.cpp b/source/opt/private_to_local_pass.cpp
index 12a226d5..80fb4c53 100644
--- a/source/opt/private_to_local_pass.cpp
+++ b/source/opt/private_to_local_pass.cpp
@@ -135,7 +135,7 @@ bool PrivateToLocalPass::MoveVariable(Instruction* variable,
// Place the variable at the start of the first basic block.
context()->AnalyzeUses(variable);
context()->set_instr_block(variable, &*function->begin());
- function->begin()->begin()->InsertBefore(move(var));
+ function->begin()->begin()->InsertBefore(std::move(var));
// Update uses where the type may have changed.
return UpdateUses(variable);
diff --git a/source/opt/private_to_local_pass.h b/source/opt/private_to_local_pass.h
index c6127d67..e96a965e 100644
--- a/source/opt/private_to_local_pass.h
+++ b/source/opt/private_to_local_pass.h
@@ -44,7 +44,7 @@ class PrivateToLocalPass : public Pass {
// class of |function|. Returns false if the variable could not be moved.
bool MoveVariable(Instruction* variable, Function* function);
- // |inst| is an instruction declaring a varible. If that variable is
+ // |inst| is an instruction declaring a variable. If that variable is
// referenced in a single function and all of uses are valid as defined by
// |IsValidUse|, then that function is returned. Otherwise, the return
// value is |nullptr|.
diff --git a/source/opt/redundancy_elimination.h b/source/opt/redundancy_elimination.h
index 91809b5d..40451f40 100644
--- a/source/opt/redundancy_elimination.h
+++ b/source/opt/redundancy_elimination.h
@@ -41,7 +41,7 @@ class RedundancyEliminationPass : public LocalRedundancyEliminationPass {
// in the function containing |bb|.
//
// |value_to_ids| is a map from value number to ids. If {vn, id} is in
- // |value_to_ids| then vn is the value number of id, and the defintion of id
+ // |value_to_ids| then vn is the value number of id, and the definition of id
// dominates |bb|.
//
// Returns true if at least one instruction is deleted.
diff --git a/source/opt/register_pressure.cpp b/source/opt/register_pressure.cpp
index 5750c6d4..1ad33738 100644
--- a/source/opt/register_pressure.cpp
+++ b/source/opt/register_pressure.cpp
@@ -378,7 +378,7 @@ void RegisterLiveness::SimulateFusion(
// The loop fusion is injecting the l1 before the l2, the latch of l1 will be
// connected to the header of l2.
// To compute the register usage, we inject the loop live-in (union of l1 and
- // l2 live-in header blocks) into the the live in/out of each basic block of
+ // l2 live-in header blocks) into the live in/out of each basic block of
// l1 to get the peak register usage. We then repeat the operation to for l2
// basic blocks but in this case we inject the live-out of the latch of l1.
auto live_loop = MakeFilterIteratorRange(
diff --git a/source/opt/remove_dontinline_pass.cpp b/source/opt/remove_dontinline_pass.cpp
new file mode 100644
index 00000000..4dd1cd4f
--- /dev/null
+++ b/source/opt/remove_dontinline_pass.cpp
@@ -0,0 +1,49 @@
+// Copyright (c) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/opt/remove_dontinline_pass.h"
+
+namespace spvtools {
+namespace opt {
+
+Pass::Status RemoveDontInline::Process() {
+ bool modified = false;
+ modified = ClearDontInlineFunctionControl();
+ return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
+}
+
+bool RemoveDontInline::ClearDontInlineFunctionControl() {
+ bool modified = false;
+ for (auto& func : *get_module()) {
+ ClearDontInlineFunctionControl(&func);
+ }
+ return modified;
+}
+
+bool RemoveDontInline::ClearDontInlineFunctionControl(Function* function) {
+ constexpr uint32_t kFunctionControlInOperandIdx = 0;
+ Instruction* function_inst = &function->DefInst();
+ uint32_t function_control =
+ function_inst->GetSingleWordInOperand(kFunctionControlInOperandIdx);
+
+ if ((function_control & SpvFunctionControlDontInlineMask) == 0) {
+ return false;
+ }
+ function_control &= ~SpvFunctionControlDontInlineMask;
+ function_inst->SetInOperand(kFunctionControlInOperandIdx, {function_control});
+ return true;
+}
+
+} // namespace opt
+} // namespace spvtools
diff --git a/source/opt/remove_dontinline_pass.h b/source/opt/remove_dontinline_pass.h
new file mode 100644
index 00000000..16243199
--- /dev/null
+++ b/source/opt/remove_dontinline_pass.h
@@ -0,0 +1,42 @@
+// Copyright (c) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_OPT_REMOVE_DONTINLINE_PASS_H_
+#define SOURCE_OPT_REMOVE_DONTINLINE_PASS_H_
+
+#include "source/opt/pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// See optimizer.hpp for documentation.
+class RemoveDontInline : public Pass {
+ public:
+ const char* name() const override { return "remove-dont-inline"; }
+ Status Process() override;
+
+ private:
+ // Clears the DontInline function control from every function in the module.
+ // Returns true of a change was made.
+ bool ClearDontInlineFunctionControl();
+
+ // Clears the DontInline function control from |function|.
+ // Returns true of a change was made.
+ bool ClearDontInlineFunctionControl(Function* function);
+};
+
+} // namespace opt
+} // namespace spvtools
+
+#endif // SOURCE_OPT_REMOVE_DONTINLINE_PASS_H_
diff --git a/source/opt/remove_duplicates_pass.cpp b/source/opt/remove_duplicates_pass.cpp
index 0e65cc8d..1ed8e2a0 100644
--- a/source/opt/remove_duplicates_pass.cpp
+++ b/source/opt/remove_duplicates_pass.cpp
@@ -72,9 +72,8 @@ bool RemoveDuplicatesPass::RemoveDuplicatesExtInstImports() const {
std::unordered_map<std::string, SpvId> ext_inst_imports;
for (auto* i = &*context()->ext_inst_import_begin(); i;) {
- auto res = ext_inst_imports.emplace(
- reinterpret_cast<const char*>(i->GetInOperand(0u).words.data()),
- i->result_id());
+ auto res = ext_inst_imports.emplace(i->GetInOperand(0u).AsString(),
+ i->result_id());
if (res.second) {
// Never seen before, keep it.
i = i->NextNode();
diff --git a/source/opt/replace_desc_array_access_using_var_index.cpp b/source/opt/replace_desc_array_access_using_var_index.cpp
index 1082e679..4cadf600 100644
--- a/source/opt/replace_desc_array_access_using_var_index.cpp
+++ b/source/opt/replace_desc_array_access_using_var_index.cpp
@@ -253,8 +253,12 @@ void ReplaceDescArrayAccessUsingVarIndex::ReplaceNonUniformAccessWithSwitchCase(
Instruction* access_chain_final_user, Instruction* access_chain,
uint32_t number_of_elements,
const std::deque<Instruction*>& insts_to_be_cloned) const {
- // Create merge block and add terminator
auto* block = context()->get_instr_block(access_chain_final_user);
+ // If the instruction does not belong to a block (i.e. in the case of
+ // OpDecorate), no replacement is needed.
+ if (!block) return;
+
+ // Create merge block and add terminator
auto* merge_block = SeparateInstructionsIntoNewBlock(
block, access_chain_final_user->NextNode());
diff --git a/source/opt/replace_desc_array_access_using_var_index.h b/source/opt/replace_desc_array_access_using_var_index.h
index e18222c8..0c97f7eb 100644
--- a/source/opt/replace_desc_array_access_using_var_index.h
+++ b/source/opt/replace_desc_array_access_using_var_index.h
@@ -47,7 +47,7 @@ class ReplaceDescArrayAccessUsingVarIndex : public Pass {
}
private:
- // Replaces all acceses to |var| using variable indices with constant
+ // Replaces all accesses to |var| using variable indices with constant
// elements of the array |var|. Creates switch-case statements to determine
// the value of the variable index for all the possible cases. Returns
// whether replacement is done or not.
@@ -170,7 +170,7 @@ class ReplaceDescArrayAccessUsingVarIndex : public Pass {
// Creates and adds an OpSwitch used for the selection of OpAccessChain whose
// first Indexes operand is |access_chain_index_var_id|. The OpSwitch will be
// added at the end of |parent_block|. It will jump to |default_id| for the
- // default case and jumps to one of case blocks whoes ids are |case_block_ids|
+ // default case and jumps to one of case blocks whose ids are |case_block_ids|
// if |access_chain_index_var_id| matches the case number. |merge_id| is the
// merge block id.
void AddSwitchForAccessChain(
diff --git a/source/opt/replace_invalid_opc.cpp b/source/opt/replace_invalid_opc.cpp
index e3b9d3e4..1dcd06f5 100644
--- a/source/opt/replace_invalid_opc.cpp
+++ b/source/opt/replace_invalid_opc.cpp
@@ -112,8 +112,7 @@ bool ReplaceInvalidOpcodePass::RewriteFunction(Function* function,
}
Instruction* file_name =
context()->get_def_use_mgr()->GetDef(file_name_id);
- const char* source = reinterpret_cast<const char*>(
- &file_name->GetInOperand(0).words[0]);
+ const std::string source = file_name->GetInOperand(0).AsString();
// Get the line number and column number.
uint32_t line_number =
@@ -121,7 +120,7 @@ bool ReplaceInvalidOpcodePass::RewriteFunction(Function* function,
uint32_t col_number = last_line_dbg_inst->GetSingleWordInOperand(2);
// Replace the instruction.
- ReplaceInstruction(inst, source, line_number, col_number);
+ ReplaceInstruction(inst, source.c_str(), line_number, col_number);
}
}
},
diff --git a/source/opt/scalar_analysis.cpp b/source/opt/scalar_analysis.cpp
index 38555e64..2b0a824c 100644
--- a/source/opt/scalar_analysis.cpp
+++ b/source/opt/scalar_analysis.cpp
@@ -581,7 +581,7 @@ static void PushToString(T id, std::u32string* str) {
// Implements the hashing of SENodes.
size_t SENodeHash::operator()(const SENode* node) const {
- // Concatinate the terms into a string which we can hash.
+ // Concatenate the terms into a string which we can hash.
std::u32string hash_string{};
// Hashing the type as a string is safer than hashing the enum as the enum is
diff --git a/source/opt/scalar_analysis_nodes.h b/source/opt/scalar_analysis_nodes.h
index b0e3fefd..91ce446f 100644
--- a/source/opt/scalar_analysis_nodes.h
+++ b/source/opt/scalar_analysis_nodes.h
@@ -167,7 +167,7 @@ class SENode {
const ChildContainerType& GetChildren() const { return children_; }
ChildContainerType& GetChildren() { return children_; }
- // Return true if this node is a cant compute node.
+ // Return true if this node is a can't compute node.
bool IsCantCompute() const { return GetType() == CanNotCompute; }
// Implements a casting method for each type.
diff --git a/source/opt/scalar_analysis_simplification.cpp b/source/opt/scalar_analysis_simplification.cpp
index 52f2d6ad..3c1ecc08 100644
--- a/source/opt/scalar_analysis_simplification.cpp
+++ b/source/opt/scalar_analysis_simplification.cpp
@@ -88,7 +88,7 @@ class SENodeSimplifyImpl {
private:
// Recursively descend through the graph to build up the accumulator objects
- // which are used to flatten the graph. |child| is the node currenty being
+ // which are used to flatten the graph. |child| is the node currently being
// traversed and the |negation| flag is used to signify that this operation
// was preceded by a unary negative operation and as such the result should be
// negated.
@@ -134,7 +134,7 @@ class SENodeSimplifyImpl {
// offset.
SENode* EliminateZeroCoefficientRecurrents(SENode* node);
- // A reference the the analysis which requested the simplification.
+ // A reference the analysis which requested the simplification.
ScalarEvolutionAnalysis& analysis_;
// The node being simplified.
diff --git a/source/opt/scalar_replacement_pass.cpp b/source/opt/scalar_replacement_pass.cpp
index 4d6a7aad..e27c828b 100644
--- a/source/opt/scalar_replacement_pass.cpp
+++ b/source/opt/scalar_replacement_pass.cpp
@@ -24,6 +24,7 @@
#include "source/opt/reflect.h"
#include "source/opt/types.h"
#include "source/util/make_unique.h"
+#include "types.h"
static const uint32_t kDebugValueOperandValueIndex = 5;
static const uint32_t kDebugValueOperandExpressionIndex = 6;
@@ -395,7 +396,7 @@ bool ScalarReplacementPass::CreateReplacementVariables(
if (!components_used || components_used->count(elem)) {
CreateVariable(*id, inst, elem, replacements);
} else {
- replacements->push_back(CreateNullConstant(*id));
+ replacements->push_back(GetUndef(*id));
}
elem++;
});
@@ -406,8 +407,8 @@ bool ScalarReplacementPass::CreateReplacementVariables(
CreateVariable(type->GetSingleWordInOperand(0u), inst, i,
replacements);
} else {
- replacements->push_back(
- CreateNullConstant(type->GetSingleWordInOperand(0u)));
+ uint32_t element_type_id = type->GetSingleWordInOperand(0);
+ replacements->push_back(GetUndef(element_type_id));
}
}
break;
@@ -429,6 +430,10 @@ bool ScalarReplacementPass::CreateReplacementVariables(
replacements->end();
}
+Instruction* ScalarReplacementPass::GetUndef(uint32_t type_id) {
+ return get_def_use_mgr()->GetDef(Type2Undef(type_id));
+}
+
void ScalarReplacementPass::TransferAnnotations(
const Instruction* source, std::vector<Instruction*>* replacements) {
// Only transfer invariant and restrict decorations on the variable. There are
@@ -981,20 +986,6 @@ ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
return result;
}
-Instruction* ScalarReplacementPass::CreateNullConstant(uint32_t type_id) {
- analysis::TypeManager* type_mgr = context()->get_type_mgr();
- analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
-
- const analysis::Type* type = type_mgr->GetType(type_id);
- const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
- Instruction* null_inst =
- const_mgr->GetDefiningInstruction(null_const, type_id);
- if (null_inst != nullptr) {
- context()->UpdateDefUse(null_inst);
- }
- return null_inst;
-}
-
uint64_t ScalarReplacementPass::GetMaxLegalIndex(
const Instruction* var_inst) const {
assert(var_inst->opcode() == SpvOpVariable &&
diff --git a/source/opt/scalar_replacement_pass.h b/source/opt/scalar_replacement_pass.h
index 0928830c..76afc267 100644
--- a/source/opt/scalar_replacement_pass.h
+++ b/source/opt/scalar_replacement_pass.h
@@ -23,14 +23,14 @@
#include <vector>
#include "source/opt/function.h"
-#include "source/opt/pass.h"
+#include "source/opt/mem_pass.h"
#include "source/opt/type_manager.h"
namespace spvtools {
namespace opt {
// Documented in optimizer.hpp
-class ScalarReplacementPass : public Pass {
+class ScalarReplacementPass : public MemPass {
private:
static const uint32_t kDefaultLimit = 100;
@@ -234,10 +234,8 @@ class ScalarReplacementPass : public Pass {
std::unique_ptr<std::unordered_set<int64_t>> GetUsedComponents(
Instruction* inst);
- // Returns an instruction defining a null constant with type |type_id|. If
- // one already exists, it is returned. Otherwise a new one is created.
- // Returns |nullptr| if the new constant could not be created.
- Instruction* CreateNullConstant(uint32_t type_id);
+ // Returns an instruction defining an undefined value type |type_id|.
+ Instruction* GetUndef(uint32_t type_id);
// Maps storage type to a pointer type enclosing that type.
std::unordered_map<uint32_t, uint32_t> pointee_to_pointer_;
diff --git a/source/opt/spread_volatile_semantics.cpp b/source/opt/spread_volatile_semantics.cpp
new file mode 100644
index 00000000..a1d34329
--- /dev/null
+++ b/source/opt/spread_volatile_semantics.cpp
@@ -0,0 +1,318 @@
+// Copyright (c) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/opt/spread_volatile_semantics.h"
+
+#include "source/opt/decoration_manager.h"
+#include "source/opt/ir_builder.h"
+#include "source/spirv_constant.h"
+
+namespace spvtools {
+namespace opt {
+namespace {
+
+const uint32_t kOpDecorateInOperandBuiltinDecoration = 2u;
+const uint32_t kOpLoadInOperandMemoryOperands = 1u;
+const uint32_t kOpEntryPointInOperandEntryPoint = 1u;
+const uint32_t kOpEntryPointInOperandInterface = 3u;
+
+bool HasBuiltinDecoration(analysis::DecorationManager* decoration_manager,
+ uint32_t var_id, uint32_t built_in) {
+ return decoration_manager->FindDecoration(
+ var_id, SpvDecorationBuiltIn, [built_in](const Instruction& inst) {
+ return built_in == inst.GetSingleWordInOperand(
+ kOpDecorateInOperandBuiltinDecoration);
+ });
+}
+
+bool IsBuiltInForRayTracingVolatileSemantics(uint32_t built_in) {
+ switch (built_in) {
+ case SpvBuiltInSMIDNV:
+ case SpvBuiltInWarpIDNV:
+ case SpvBuiltInSubgroupSize:
+ case SpvBuiltInSubgroupLocalInvocationId:
+ case SpvBuiltInSubgroupEqMask:
+ case SpvBuiltInSubgroupGeMask:
+ case SpvBuiltInSubgroupGtMask:
+ case SpvBuiltInSubgroupLeMask:
+ case SpvBuiltInSubgroupLtMask:
+ return true;
+ default:
+ return false;
+ }
+}
+
+bool HasBuiltinForRayTracingVolatileSemantics(
+ analysis::DecorationManager* decoration_manager, uint32_t var_id) {
+ return decoration_manager->FindDecoration(
+ var_id, SpvDecorationBuiltIn, [](const Instruction& inst) {
+ uint32_t built_in =
+ inst.GetSingleWordInOperand(kOpDecorateInOperandBuiltinDecoration);
+ return IsBuiltInForRayTracingVolatileSemantics(built_in);
+ });
+}
+
+bool HasVolatileDecoration(analysis::DecorationManager* decoration_manager,
+ uint32_t var_id) {
+ return decoration_manager->HasDecoration(var_id, SpvDecorationVolatile);
+}
+
+bool HasOnlyEntryPointsAsFunctions(IRContext* context, Module* module) {
+ std::unordered_set<uint32_t> entry_function_ids;
+ for (Instruction& entry_point : module->entry_points()) {
+ entry_function_ids.insert(
+ entry_point.GetSingleWordInOperand(kOpEntryPointInOperandEntryPoint));
+ }
+ for (auto& function : *module) {
+ if (entry_function_ids.find(function.result_id()) ==
+ entry_function_ids.end()) {
+ std::string message(
+ "Functions of SPIR-V for spread-volatile-semantics pass input must "
+ "be inlined except entry points");
+ message += "\n " + function.DefInst().PrettyPrint(
+ SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
+ context->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
+ return false;
+ }
+ }
+ return true;
+}
+
+} // namespace
+
+Pass::Status SpreadVolatileSemantics::Process() {
+ if (HasNoExecutionModel()) {
+ return Status::SuccessWithoutChange;
+ }
+
+ if (!HasOnlyEntryPointsAsFunctions(context(), get_module())) {
+ return Status::Failure;
+ }
+
+ const bool is_vk_memory_model_enabled =
+ context()->get_feature_mgr()->HasCapability(
+ SpvCapabilityVulkanMemoryModel);
+ CollectTargetsForVolatileSemantics(is_vk_memory_model_enabled);
+
+ // If VulkanMemoryModel capability is not enabled, we have to set Volatile
+ // decoration for interface variables instead of setting Volatile for load
+ // instructions. If an interface (or pointers to it) is used by two load
+ // instructions in two entry points and one must be volatile while another
+ // is not, we have to report an error for the conflict.
+ if (!is_vk_memory_model_enabled &&
+ HasInterfaceInConflictOfVolatileSemantics()) {
+ return Status::Failure;
+ }
+
+ return SpreadVolatileSemanticsToVariables(is_vk_memory_model_enabled);
+}
+
+Pass::Status SpreadVolatileSemantics::SpreadVolatileSemanticsToVariables(
+ const bool is_vk_memory_model_enabled) {
+ Status status = Status::SuccessWithoutChange;
+ for (Instruction& var : context()->types_values()) {
+ auto entry_function_ids =
+ EntryFunctionsToSpreadVolatileSemanticsForVar(var.result_id());
+ if (entry_function_ids.empty()) {
+ continue;
+ }
+
+ if (is_vk_memory_model_enabled) {
+ SetVolatileForLoadsInEntries(&var, entry_function_ids);
+ } else {
+ DecorateVarWithVolatile(&var);
+ }
+ status = Status::SuccessWithChange;
+ }
+ return status;
+}
+
+bool SpreadVolatileSemantics::IsTargetUsedByNonVolatileLoadInEntryPoint(
+ uint32_t var_id, Instruction* entry_point) {
+ uint32_t entry_function_id =
+ entry_point->GetSingleWordInOperand(kOpEntryPointInOperandEntryPoint);
+ return !VisitLoadsOfPointersToVariableInEntries(
+ var_id,
+ [](Instruction* load) {
+ // If it has a load without volatile memory operand, finish traversal
+ // and return false.
+ if (load->NumInOperands() <= kOpLoadInOperandMemoryOperands) {
+ return false;
+ }
+ uint32_t memory_operands =
+ load->GetSingleWordInOperand(kOpLoadInOperandMemoryOperands);
+ return (memory_operands & SpvMemoryAccessVolatileMask) != 0;
+ },
+ {entry_function_id});
+}
+
+bool SpreadVolatileSemantics::HasInterfaceInConflictOfVolatileSemantics() {
+ for (Instruction& entry_point : get_module()->entry_points()) {
+ SpvExecutionModel execution_model =
+ static_cast<SpvExecutionModel>(entry_point.GetSingleWordInOperand(0));
+ for (uint32_t operand_index = kOpEntryPointInOperandInterface;
+ operand_index < entry_point.NumInOperands(); ++operand_index) {
+ uint32_t var_id = entry_point.GetSingleWordInOperand(operand_index);
+ if (!EntryFunctionsToSpreadVolatileSemanticsForVar(var_id).empty() &&
+ !IsTargetForVolatileSemantics(var_id, execution_model) &&
+ IsTargetUsedByNonVolatileLoadInEntryPoint(var_id, &entry_point)) {
+ Instruction* inst = context()->get_def_use_mgr()->GetDef(var_id);
+ context()->EmitErrorMessage(
+ "Variable is a target for Volatile semantics for an entry point, "
+ "but it is not for another entry point",
+ inst);
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+void SpreadVolatileSemantics::MarkVolatileSemanticsForVariable(
+ uint32_t var_id, Instruction* entry_point) {
+ uint32_t entry_function_id =
+ entry_point->GetSingleWordInOperand(kOpEntryPointInOperandEntryPoint);
+ auto itr = var_ids_to_entry_fn_for_volatile_semantics_.find(var_id);
+ if (itr == var_ids_to_entry_fn_for_volatile_semantics_.end()) {
+ var_ids_to_entry_fn_for_volatile_semantics_[var_id] = {entry_function_id};
+ return;
+ }
+ itr->second.insert(entry_function_id);
+}
+
+void SpreadVolatileSemantics::CollectTargetsForVolatileSemantics(
+ const bool is_vk_memory_model_enabled) {
+ for (Instruction& entry_point : get_module()->entry_points()) {
+ SpvExecutionModel execution_model =
+ static_cast<SpvExecutionModel>(entry_point.GetSingleWordInOperand(0));
+ for (uint32_t operand_index = kOpEntryPointInOperandInterface;
+ operand_index < entry_point.NumInOperands(); ++operand_index) {
+ uint32_t var_id = entry_point.GetSingleWordInOperand(operand_index);
+ if (!IsTargetForVolatileSemantics(var_id, execution_model)) {
+ continue;
+ }
+ if (is_vk_memory_model_enabled ||
+ IsTargetUsedByNonVolatileLoadInEntryPoint(var_id, &entry_point)) {
+ MarkVolatileSemanticsForVariable(var_id, &entry_point);
+ }
+ }
+ }
+}
+
+void SpreadVolatileSemantics::DecorateVarWithVolatile(Instruction* var) {
+ analysis::DecorationManager* decoration_manager =
+ context()->get_decoration_mgr();
+ uint32_t var_id = var->result_id();
+ if (HasVolatileDecoration(decoration_manager, var_id)) {
+ return;
+ }
+ get_decoration_mgr()->AddDecoration(
+ SpvOpDecorate, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {var_id}},
+ {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
+ {SpvDecorationVolatile}}});
+}
+
+bool SpreadVolatileSemantics::VisitLoadsOfPointersToVariableInEntries(
+ uint32_t var_id, const std::function<bool(Instruction*)>& handle_load,
+ const std::unordered_set<uint32_t>& entry_function_ids) {
+ std::vector<uint32_t> worklist({var_id});
+ auto* def_use_mgr = context()->get_def_use_mgr();
+ while (!worklist.empty()) {
+ uint32_t ptr_id = worklist.back();
+ worklist.pop_back();
+ bool finish_traversal = !def_use_mgr->WhileEachUser(
+ ptr_id, [this, &worklist, &ptr_id, handle_load,
+ &entry_function_ids](Instruction* user) {
+ BasicBlock* block = context()->get_instr_block(user);
+ if (block == nullptr ||
+ entry_function_ids.find(block->GetParent()->result_id()) ==
+ entry_function_ids.end()) {
+ return true;
+ }
+
+ if (user->opcode() == SpvOpAccessChain ||
+ user->opcode() == SpvOpInBoundsAccessChain ||
+ user->opcode() == SpvOpPtrAccessChain ||
+ user->opcode() == SpvOpInBoundsPtrAccessChain ||
+ user->opcode() == SpvOpCopyObject) {
+ if (ptr_id == user->GetSingleWordInOperand(0))
+ worklist.push_back(user->result_id());
+ return true;
+ }
+
+ if (user->opcode() != SpvOpLoad) {
+ return true;
+ }
+
+ return handle_load(user);
+ });
+ if (finish_traversal) return false;
+ }
+ return true;
+}
+
+void SpreadVolatileSemantics::SetVolatileForLoadsInEntries(
+ Instruction* var, const std::unordered_set<uint32_t>& entry_function_ids) {
+ // Set Volatile memory operand for all load instructions if they do not have
+ // it.
+ VisitLoadsOfPointersToVariableInEntries(
+ var->result_id(),
+ [](Instruction* load) {
+ if (load->NumInOperands() <= kOpLoadInOperandMemoryOperands) {
+ load->AddOperand(
+ {SPV_OPERAND_TYPE_MEMORY_ACCESS, {SpvMemoryAccessVolatileMask}});
+ return true;
+ }
+ uint32_t memory_operands =
+ load->GetSingleWordInOperand(kOpLoadInOperandMemoryOperands);
+ memory_operands |= SpvMemoryAccessVolatileMask;
+ load->SetInOperand(kOpLoadInOperandMemoryOperands, {memory_operands});
+ return true;
+ },
+ entry_function_ids);
+}
+
+bool SpreadVolatileSemantics::IsTargetForVolatileSemantics(
+ uint32_t var_id, SpvExecutionModel execution_model) {
+ analysis::DecorationManager* decoration_manager =
+ context()->get_decoration_mgr();
+ if (execution_model == SpvExecutionModelFragment) {
+ return get_module()->version() >= SPV_SPIRV_VERSION_WORD(1, 6) &&
+ HasBuiltinDecoration(decoration_manager, var_id,
+ SpvBuiltInHelperInvocation);
+ }
+
+ if (execution_model == SpvExecutionModelIntersectionKHR ||
+ execution_model == SpvExecutionModelIntersectionNV) {
+ if (HasBuiltinDecoration(decoration_manager, var_id,
+ SpvBuiltInRayTmaxKHR)) {
+ return true;
+ }
+ }
+
+ switch (execution_model) {
+ case SpvExecutionModelRayGenerationKHR:
+ case SpvExecutionModelClosestHitKHR:
+ case SpvExecutionModelMissKHR:
+ case SpvExecutionModelCallableKHR:
+ case SpvExecutionModelIntersectionKHR:
+ return HasBuiltinForRayTracingVolatileSemantics(decoration_manager,
+ var_id);
+ default:
+ return false;
+ }
+}
+
+} // namespace opt
+} // namespace spvtools
diff --git a/source/opt/spread_volatile_semantics.h b/source/opt/spread_volatile_semantics.h
new file mode 100644
index 00000000..531a21d5
--- /dev/null
+++ b/source/opt/spread_volatile_semantics.h
@@ -0,0 +1,117 @@
+// Copyright (c) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_OPT_SPREAD_VOLATILE_SEMANTICS_H_
+#define SOURCE_OPT_SPREAD_VOLATILE_SEMANTICS_H_
+
+#include "source/opt/pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// See optimizer.hpp for documentation.
+class SpreadVolatileSemantics : public Pass {
+ public:
+ SpreadVolatileSemantics() {}
+
+ const char* name() const override { return "spread-volatile-semantics"; }
+
+ Status Process() override;
+
+ IRContext::Analysis GetPreservedAnalyses() override {
+ return IRContext::kAnalysisDefUse | IRContext::kAnalysisDecorations |
+ IRContext::kAnalysisInstrToBlockMapping;
+ }
+
+ private:
+ // Returns true if it does not have an execution model. Linkage shaders do not
+ // have an execution model.
+ bool HasNoExecutionModel() {
+ return get_module()->entry_points().empty() &&
+ context()->get_feature_mgr()->HasCapability(SpvCapabilityLinkage);
+ }
+
+ // Iterates interface variables and spreads the Volatile semantics if it has
+ // load instructions for the Volatile semantics.
+ Pass::Status SpreadVolatileSemanticsToVariables(
+ const bool is_vk_memory_model_enabled);
+
+ // Returns whether |var_id| is the result id of a target builtin variable for
+ // the volatile semantics for |execution_model| based on the Vulkan spec
+ // VUID-StandaloneSpirv-VulkanMemoryModel-04678 or
+ // VUID-StandaloneSpirv-VulkanMemoryModel-04679.
+ bool IsTargetForVolatileSemantics(uint32_t var_id,
+ SpvExecutionModel execution_model);
+
+ // Collects interface variables that need the volatile semantics.
+ // |is_vk_memory_model_enabled| is true if VulkanMemoryModel capability is
+ // enabled.
+ void CollectTargetsForVolatileSemantics(
+ const bool is_vk_memory_model_enabled);
+
+ // Reports an error if an interface variable is used by two entry points and
+ // it needs the Volatile decoration for one but not for another. Returns true
+ // if the error must be reported.
+ bool HasInterfaceInConflictOfVolatileSemantics();
+
+ // Returns whether the variable whose result is |var_id| is used by a
+ // non-volatile load or a pointer to it is used by a non-volatile load in
+ // |entry_point| or not.
+ bool IsTargetUsedByNonVolatileLoadInEntryPoint(uint32_t var_id,
+ Instruction* entry_point);
+
+ // Visits load instructions of pointers to variable whose result id is
+ // |var_id| if the load instructions are in entry points whose
+ // function id is one of |entry_function_ids|. |handle_load| is a function to
+ // do some actions for the load instructions. Finishes the traversal and
+ // returns false if |handle_load| returns false for a load instruction.
+ // Otherwise, returns true after running |handle_load| for all the load
+ // instructions.
+ bool VisitLoadsOfPointersToVariableInEntries(
+ uint32_t var_id, const std::function<bool(Instruction*)>& handle_load,
+ const std::unordered_set<uint32_t>& entry_function_ids);
+
+ // Sets Memory Operands of OpLoad instructions that load |var| or pointers
+ // of |var| as Volatile if the function id of the OpLoad instruction is
+ // included in |entry_function_ids|.
+ void SetVolatileForLoadsInEntries(
+ Instruction* var, const std::unordered_set<uint32_t>& entry_function_ids);
+
+ // Adds OpDecorate Volatile for |var| if it does not exist.
+ void DecorateVarWithVolatile(Instruction* var);
+
+ // Returns a set of entry function ids to spread the volatile semantics for
+ // the variable with the result id |var_id|.
+ std::unordered_set<uint32_t> EntryFunctionsToSpreadVolatileSemanticsForVar(
+ uint32_t var_id) {
+ auto itr = var_ids_to_entry_fn_for_volatile_semantics_.find(var_id);
+ if (itr == var_ids_to_entry_fn_for_volatile_semantics_.end()) return {};
+ return itr->second;
+ }
+
+ // Specifies that we have to spread the volatile semantics for the
+ // variable with the result id |var_id| for the entry point |entry_point|.
+ void MarkVolatileSemanticsForVariable(uint32_t var_id,
+ Instruction* entry_point);
+
+ // Result ids of variables to entry function ids for the volatile semantics
+ // spread.
+ std::unordered_map<uint32_t, std::unordered_set<uint32_t>>
+ var_ids_to_entry_fn_for_volatile_semantics_;
+};
+
+} // namespace opt
+} // namespace spvtools
+
+#endif // SOURCE_OPT_SPREAD_VOLATILE_SEMANTICS_H_
diff --git a/source/opt/strength_reduction_pass.h b/source/opt/strength_reduction_pass.h
index 8dfeb307..1cbbbcc6 100644
--- a/source/opt/strength_reduction_pass.h
+++ b/source/opt/strength_reduction_pass.h
@@ -34,7 +34,7 @@ class StrengthReductionPass : public Pass {
// Returns true if something changed.
bool ReplaceMultiplyByPowerOf2(BasicBlock::iterator*);
- // Scan the types and constants in the module looking for the the integer
+ // Scan the types and constants in the module looking for the integer
// types that we are
// interested in. The shift operation needs a small unsigned integer. We
// need to find
diff --git a/source/opt/strip_debug_info_pass.cpp b/source/opt/strip_debug_info_pass.cpp
index c86ce578..6a0ebf24 100644
--- a/source/opt/strip_debug_info_pass.cpp
+++ b/source/opt/strip_debug_info_pass.cpp
@@ -14,6 +14,7 @@
#include "source/opt/strip_debug_info_pass.h"
#include "source/opt/ir_context.h"
+#include "source/util/string_utils.h"
namespace spvtools {
namespace opt {
@@ -21,9 +22,8 @@ namespace opt {
Pass::Status StripDebugInfoPass::Process() {
bool uses_non_semantic_info = false;
for (auto& inst : context()->module()->extensions()) {
- const char* ext_name =
- reinterpret_cast<const char*>(&inst.GetInOperand(0).words[0]);
- if (0 == std::strcmp(ext_name, "SPV_KHR_non_semantic_info")) {
+ const std::string ext_name = inst.GetInOperand(0).AsString();
+ if (ext_name == "SPV_KHR_non_semantic_info") {
uses_non_semantic_info = true;
}
}
@@ -46,9 +46,10 @@ Pass::Status StripDebugInfoPass::Process() {
if (use->opcode() == SpvOpExtInst) {
auto ext_inst_set =
def_use->GetDef(use->GetSingleWordInOperand(0u));
- const char* extension_name = reinterpret_cast<const char*>(
- &ext_inst_set->GetInOperand(0).words[0]);
- if (0 == std::strncmp(extension_name, "NonSemantic.", 12)) {
+ const std::string extension_name =
+ ext_inst_set->GetInOperand(0).AsString();
+ if (spvtools::utils::starts_with(extension_name,
+ "NonSemantic.")) {
// found a non-semantic use, return false as we cannot
// remove this OpString
return false;
diff --git a/source/opt/strip_reflect_info_pass.cpp b/source/opt/strip_nonsemantic_info_pass.cpp
index 8b0f2db7..cd1fbb63 100644
--- a/source/opt/strip_reflect_info_pass.cpp
+++ b/source/opt/strip_nonsemantic_info_pass.cpp
@@ -12,18 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "source/opt/strip_reflect_info_pass.h"
+#include "source/opt/strip_nonsemantic_info_pass.h"
#include <cstring>
#include <vector>
#include "source/opt/instruction.h"
#include "source/opt/ir_context.h"
+#include "source/util/string_utils.h"
namespace spvtools {
namespace opt {
-Pass::Status StripReflectInfoPass::Process() {
+Pass::Status StripNonSemanticInfoPass::Process() {
bool modified = false;
std::vector<Instruction*> to_remove;
@@ -32,7 +33,8 @@ Pass::Status StripReflectInfoPass::Process() {
for (auto& inst : context()->module()->annotations()) {
switch (inst.opcode()) {
case SpvOpDecorateStringGOOGLE:
- if (inst.GetSingleWordInOperand(1) == SpvDecorationHlslSemanticGOOGLE) {
+ if (inst.GetSingleWordInOperand(1) == SpvDecorationHlslSemanticGOOGLE ||
+ inst.GetSingleWordInOperand(1) == SpvDecorationUserTypeGOOGLE) {
to_remove.push_back(&inst);
} else {
other_uses_for_decorate_string = true;
@@ -40,7 +42,8 @@ Pass::Status StripReflectInfoPass::Process() {
break;
case SpvOpMemberDecorateStringGOOGLE:
- if (inst.GetSingleWordInOperand(2) == SpvDecorationHlslSemanticGOOGLE) {
+ if (inst.GetSingleWordInOperand(2) == SpvDecorationHlslSemanticGOOGLE ||
+ inst.GetSingleWordInOperand(2) == SpvDecorationUserTypeGOOGLE) {
to_remove.push_back(&inst);
} else {
other_uses_for_decorate_string = true;
@@ -60,33 +63,26 @@ Pass::Status StripReflectInfoPass::Process() {
}
for (auto& inst : context()->module()->extensions()) {
- const char* ext_name =
- reinterpret_cast<const char*>(&inst.GetInOperand(0).words[0]);
- if (0 == std::strcmp(ext_name, "SPV_GOOGLE_hlsl_functionality1")) {
+ const std::string ext_name = inst.GetInOperand(0).AsString();
+ if (ext_name == "SPV_GOOGLE_hlsl_functionality1") {
+ to_remove.push_back(&inst);
+ } else if (ext_name == "SPV_GOOGLE_user_type") {
to_remove.push_back(&inst);
} else if (!other_uses_for_decorate_string &&
- 0 == std::strcmp(ext_name, "SPV_GOOGLE_decorate_string")) {
+ ext_name == "SPV_GOOGLE_decorate_string") {
to_remove.push_back(&inst);
- } else if (0 == std::strcmp(ext_name, "SPV_KHR_non_semantic_info")) {
+ } else if (ext_name == "SPV_KHR_non_semantic_info") {
to_remove.push_back(&inst);
}
}
- // clear all debug data now if it hasn't been cleared already, to remove any
- // remaining OpString that may have been referenced by non-semantic extinsts
- for (auto& dbg : context()->debugs1()) to_remove.push_back(&dbg);
- for (auto& dbg : context()->debugs2()) to_remove.push_back(&dbg);
- for (auto& dbg : context()->debugs3()) to_remove.push_back(&dbg);
- for (auto& dbg : context()->ext_inst_debuginfo()) to_remove.push_back(&dbg);
-
// remove any extended inst imports that are non semantic
std::unordered_set<uint32_t> non_semantic_sets;
for (auto& inst : context()->module()->ext_inst_imports()) {
assert(inst.opcode() == SpvOpExtInstImport &&
"Expecting an import of an extension's instruction set.");
- const char* extension_name =
- reinterpret_cast<const char*>(&inst.GetInOperand(0).words[0]);
- if (0 == std::strncmp(extension_name, "NonSemantic.", 12)) {
+ const std::string extension_name = inst.GetInOperand(0).AsString();
+ if (spvtools::utils::starts_with(extension_name, "NonSemantic.")) {
non_semantic_sets.insert(inst.result_id());
to_remove.push_back(&inst);
}
@@ -103,19 +99,10 @@ Pass::Status StripReflectInfoPass::Process() {
to_remove.push_back(inst);
}
}
- });
+ },
+ true);
}
- // OpName must come first, since they may refer to other debug instructions.
- // If they are after the instructions that refer to, then they will be killed
- // when that instruction is killed, which will lead to a double kill.
- std::sort(to_remove.begin(), to_remove.end(),
- [](Instruction* lhs, Instruction* rhs) -> bool {
- if (lhs->opcode() == SpvOpName && rhs->opcode() != SpvOpName)
- return true;
- return false;
- });
-
for (auto* inst : to_remove) {
modified = true;
context()->KillInst(inst);
diff --git a/source/opt/strip_reflect_info_pass.h b/source/opt/strip_nonsemantic_info_pass.h
index 4e1999ed..ff4e2e1d 100644
--- a/source/opt/strip_reflect_info_pass.h
+++ b/source/opt/strip_nonsemantic_info_pass.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifndef SOURCE_OPT_STRIP_REFLECT_INFO_PASS_H_
-#define SOURCE_OPT_STRIP_REFLECT_INFO_PASS_H_
+#ifndef SOURCE_OPT_STRIP_NONSEMANTIC_INFO_PASS_H_
+#define SOURCE_OPT_STRIP_NONSEMANTIC_INFO_PASS_H_
#include "source/opt/ir_context.h"
#include "source/opt/module.h"
@@ -23,9 +23,9 @@ namespace spvtools {
namespace opt {
// See optimizer.hpp for documentation.
-class StripReflectInfoPass : public Pass {
+class StripNonSemanticInfoPass : public Pass {
public:
- const char* name() const override { return "strip-reflect"; }
+ const char* name() const override { return "strip-nonsemantic"; }
Status Process() override;
// Return the mask of preserved Analyses.
@@ -41,4 +41,4 @@ class StripReflectInfoPass : public Pass {
} // namespace opt
} // namespace spvtools
-#endif // SOURCE_OPT_STRIP_REFLECT_INFO_PASS_H_
+#endif // SOURCE_OPT_STRIP_NONSEMANTIC_INFO_PASS_H_
diff --git a/source/opt/type_manager.cpp b/source/opt/type_manager.cpp
index 7935ad33..a0006f55 100644
--- a/source/opt/type_manager.cpp
+++ b/source/opt/type_manager.cpp
@@ -23,6 +23,7 @@
#include "source/opt/log.h"
#include "source/opt/reflect.h"
#include "source/util/make_unique.h"
+#include "source/util/string_utils.h"
namespace spvtools {
namespace opt {
@@ -234,6 +235,7 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) {
DefineParameterlessCase(PipeStorage);
DefineParameterlessCase(NamedBarrier);
DefineParameterlessCase(AccelerationStructureNV);
+ DefineParameterlessCase(RayQueryKHR);
#undef DefineParameterlessCase
case Type::kInteger:
typeInst = MakeUnique<Instruction>(
@@ -349,11 +351,8 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) {
}
case Type::kOpaque: {
const Opaque* opaque = type->AsOpaque();
- size_t size = opaque->name().size();
// Convert to null-terminated packed UTF-8 string.
- std::vector<uint32_t> words(size / 4 + 1, 0);
- char* dst = reinterpret_cast<char*>(words.data());
- strncpy(dst, opaque->name().c_str(), size);
+ std::vector<uint32_t> words = spvtools::utils::MakeVector(opaque->name());
typeInst = MakeUnique<Instruction>(
context(), SpvOpTypeOpaque, 0, id,
std::initializer_list<Operand>{
@@ -529,6 +528,7 @@ Type* TypeManager::RebuildType(const Type& type) {
DefineNoSubtypeCase(PipeStorage);
DefineNoSubtypeCase(NamedBarrier);
DefineNoSubtypeCase(AccelerationStructureNV);
+ DefineNoSubtypeCase(RayQueryKHR);
#undef DefineNoSubtypeCase
case Type::kVector: {
const Vector* vec_ty = type.AsVector();
@@ -781,8 +781,7 @@ Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) {
}
} break;
case SpvOpTypeOpaque: {
- const uint32_t* data = inst.GetInOperand(0).words.data();
- type = new Opaque(reinterpret_cast<const char*>(data));
+ type = new Opaque(inst.GetInOperand(0).AsString());
} break;
case SpvOpTypePointer: {
uint32_t pointee_type_id = inst.GetSingleWordInOperand(1);
diff --git a/source/opt/type_manager.h b/source/opt/type_manager.h
index ce9d83d4..72e37f48 100644
--- a/source/opt/type_manager.h
+++ b/source/opt/type_manager.h
@@ -160,6 +160,13 @@ class TypeManager {
uint32_t GetFloatTypeId() { return GetTypeInstruction(GetFloatType()); }
+ Type* GetDoubleType() {
+ Float float_type(64);
+ return GetRegisteredType(&float_type);
+ }
+
+ uint32_t GetDoubleTypeId() { return GetTypeInstruction(GetDoubleType()); }
+
Type* GetUIntVectorType(uint32_t size) {
Vector vec_type(GetUIntType(), size);
return GetRegisteredType(&vec_type);
diff --git a/source/opt/types.cpp b/source/opt/types.cpp
index b1eb3a50..ebbdc367 100644
--- a/source/opt/types.cpp
+++ b/source/opt/types.cpp
@@ -21,6 +21,7 @@
#include <string>
#include <unordered_set>
+#include "source/util/hash_combine.h"
#include "source/util/make_unique.h"
#include "spirv/unified1/spirv.h"
@@ -28,6 +29,7 @@ namespace spvtools {
namespace opt {
namespace analysis {
+using spvtools::utils::hash_combine;
using U32VecVec = std::vector<std::vector<uint32_t>>;
namespace {
@@ -182,23 +184,26 @@ bool Type::operator==(const Type& other) const {
}
}
-void Type::GetHashWords(std::vector<uint32_t>* words,
- std::unordered_set<const Type*>* seen) const {
- if (!seen->insert(this).second) {
- return;
+size_t Type::ComputeHashValue(size_t hash, SeenTypes* seen) const {
+ // Linear search through a dense, cache coherent vector is faster than O(log
+ // n) search in a complex data structure (eg std::set) for the generally small
+ // number of nodes. It also skips the overhead of an new/delete per Type
+ // (when inserting/removing from a set).
+ if (std::find(seen->begin(), seen->end(), this) != seen->end()) {
+ return hash;
}
- words->push_back(kind_);
+ seen->push_back(this);
+
+ hash = hash_combine(hash, uint32_t(kind_));
for (const auto& d : decorations_) {
- for (auto w : d) {
- words->push_back(w);
- }
+ hash = hash_combine(hash, d);
}
switch (kind_) {
-#define DeclareKindCase(type) \
- case k##type: \
- As##type()->GetExtraHashWords(words, seen); \
+#define DeclareKindCase(type) \
+ case k##type: \
+ hash = As##type()->ComputeExtraStateHash(hash, seen); \
break
DeclareKindCase(Void);
DeclareKindCase(Bool);
@@ -232,18 +237,13 @@ void Type::GetHashWords(std::vector<uint32_t>* words,
break;
}
- seen->erase(this);
+ seen->pop_back();
+ return hash;
}
size_t Type::HashValue() const {
- std::u32string h;
- std::vector<uint32_t> words;
- GetHashWords(&words);
- for (auto w : words) {
- h.push_back(w);
- }
-
- return std::hash<std::u32string>()(h);
+ SeenTypes seen;
+ return ComputeHashValue(0, &seen);
}
bool Integer::IsSameImpl(const Type* that, IsSameCache*) const {
@@ -258,10 +258,8 @@ std::string Integer::str() const {
return oss.str();
}
-void Integer::GetExtraHashWords(std::vector<uint32_t>* words,
- std::unordered_set<const Type*>*) const {
- words->push_back(width_);
- words->push_back(signed_);
+size_t Integer::ComputeExtraStateHash(size_t hash, SeenTypes*) const {
+ return hash_combine(hash, width_, signed_);
}
bool Float::IsSameImpl(const Type* that, IsSameCache*) const {
@@ -275,9 +273,8 @@ std::string Float::str() const {
return oss.str();
}
-void Float::GetExtraHashWords(std::vector<uint32_t>* words,
- std::unordered_set<const Type*>*) const {
- words->push_back(width_);
+size_t Float::ComputeExtraStateHash(size_t hash, SeenTypes*) const {
+ return hash_combine(hash, width_);
}
Vector::Vector(const Type* type, uint32_t count)
@@ -299,10 +296,11 @@ std::string Vector::str() const {
return oss.str();
}
-void Vector::GetExtraHashWords(std::vector<uint32_t>* words,
- std::unordered_set<const Type*>* seen) const {
- element_type_->GetHashWords(words, seen);
- words->push_back(count_);
+size_t Vector::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const {
+ // prefer form that doesn't require push/pop from stack: add state and
+ // make tail call.
+ hash = hash_combine(hash, count_);
+ return element_type_->ComputeHashValue(hash, seen);
}
Matrix::Matrix(const Type* type, uint32_t count)
@@ -324,10 +322,9 @@ std::string Matrix::str() const {
return oss.str();
}
-void Matrix::GetExtraHashWords(std::vector<uint32_t>* words,
- std::unordered_set<const Type*>* seen) const {
- element_type_->GetHashWords(words, seen);
- words->push_back(count_);
+size_t Matrix::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const {
+ hash = hash_combine(hash, count_);
+ return element_type_->ComputeHashValue(hash, seen);
}
Image::Image(Type* type, SpvDim dimen, uint32_t d, bool array, bool multisample,
@@ -362,16 +359,10 @@ std::string Image::str() const {
return oss.str();
}
-void Image::GetExtraHashWords(std::vector<uint32_t>* words,
- std::unordered_set<const Type*>* seen) const {
- sampled_type_->GetHashWords(words, seen);
- words->push_back(dim_);
- words->push_back(depth_);
- words->push_back(arrayed_);
- words->push_back(ms_);
- words->push_back(sampled_);
- words->push_back(format_);
- words->push_back(access_qualifier_);
+size_t Image::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const {
+ hash = hash_combine(hash, uint32_t(dim_), depth_, arrayed_, ms_, sampled_,
+ uint32_t(format_), uint32_t(access_qualifier_));
+ return sampled_type_->ComputeHashValue(hash, seen);
}
bool SampledImage::IsSameImpl(const Type* that, IsSameCache* seen) const {
@@ -387,9 +378,8 @@ std::string SampledImage::str() const {
return oss.str();
}
-void SampledImage::GetExtraHashWords(
- std::vector<uint32_t>* words, std::unordered_set<const Type*>* seen) const {
- image_type_->GetHashWords(words, seen);
+size_t SampledImage::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const {
+ return image_type_->ComputeHashValue(hash, seen);
}
Array::Array(const Type* type, const Array::LengthInfo& length_info_arg)
@@ -422,16 +412,19 @@ std::string Array::str() const {
return oss.str();
}
-void Array::GetExtraHashWords(std::vector<uint32_t>* words,
- std::unordered_set<const Type*>* seen) const {
- element_type_->GetHashWords(words, seen);
- // This should mirror the logic in IsSameImpl
- words->insert(words->end(), length_info_.words.begin(),
- length_info_.words.end());
+size_t Array::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const {
+ hash = hash_combine(hash, length_info_.words);
+ return element_type_->ComputeHashValue(hash, seen);
}
void Array::ReplaceElementType(const Type* type) { element_type_ = type; }
+Array::LengthInfo Array::GetConstantLengthInfo(uint32_t const_id,
+ uint32_t length) const {
+ std::vector<uint32_t> extra_words{LengthInfo::Case::kConstant, length};
+ return {const_id, extra_words};
+}
+
RuntimeArray::RuntimeArray(const Type* type)
: Type(kRuntimeArray), element_type_(type) {
assert(!type->AsVoid());
@@ -450,9 +443,8 @@ std::string RuntimeArray::str() const {
return oss.str();
}
-void RuntimeArray::GetExtraHashWords(
- std::vector<uint32_t>* words, std::unordered_set<const Type*>* seen) const {
- element_type_->GetHashWords(words, seen);
+size_t RuntimeArray::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const {
+ return element_type_->ComputeHashValue(hash, seen);
}
void RuntimeArray::ReplaceElementType(const Type* type) {
@@ -509,19 +501,14 @@ std::string Struct::str() const {
return oss.str();
}
-void Struct::GetExtraHashWords(std::vector<uint32_t>* words,
- std::unordered_set<const Type*>* seen) const {
+size_t Struct::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const {
for (auto* t : element_types_) {
- t->GetHashWords(words, seen);
+ hash = t->ComputeHashValue(hash, seen);
}
for (const auto& pair : element_decorations_) {
- words->push_back(pair.first);
- for (const auto& d : pair.second) {
- for (auto w : d) {
- words->push_back(w);
- }
- }
+ hash = hash_combine(hash, pair.first, pair.second);
}
+ return hash;
}
bool Opaque::IsSameImpl(const Type* that, IsSameCache*) const {
@@ -536,11 +523,8 @@ std::string Opaque::str() const {
return oss.str();
}
-void Opaque::GetExtraHashWords(std::vector<uint32_t>* words,
- std::unordered_set<const Type*>*) const {
- for (auto c : name_) {
- words->push_back(static_cast<char32_t>(c));
- }
+size_t Opaque::ComputeExtraStateHash(size_t hash, SeenTypes*) const {
+ return hash_combine(hash, name_);
}
Pointer::Pointer(const Type* type, SpvStorageClass sc)
@@ -569,10 +553,9 @@ std::string Pointer::str() const {
return os.str();
}
-void Pointer::GetExtraHashWords(std::vector<uint32_t>* words,
- std::unordered_set<const Type*>* seen) const {
- pointee_type_->GetHashWords(words, seen);
- words->push_back(storage_class_);
+size_t Pointer::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const {
+ hash = hash_combine(hash, uint32_t(storage_class_));
+ return pointee_type_->ComputeHashValue(hash, seen);
}
void Pointer::SetPointeeType(const Type* type) { pointee_type_ = type; }
@@ -606,12 +589,11 @@ std::string Function::str() const {
return oss.str();
}
-void Function::GetExtraHashWords(std::vector<uint32_t>* words,
- std::unordered_set<const Type*>* seen) const {
- return_type_->GetHashWords(words, seen);
+size_t Function::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const {
for (const auto* t : param_types_) {
- t->GetHashWords(words, seen);
+ hash = t->ComputeHashValue(hash, seen);
}
+ return return_type_->ComputeHashValue(hash, seen);
}
void Function::SetReturnType(const Type* type) { return_type_ = type; }
@@ -628,9 +610,8 @@ std::string Pipe::str() const {
return oss.str();
}
-void Pipe::GetExtraHashWords(std::vector<uint32_t>* words,
- std::unordered_set<const Type*>*) const {
- words->push_back(access_qualifier_);
+size_t Pipe::ComputeExtraStateHash(size_t hash, SeenTypes*) const {
+ return hash_combine(hash, uint32_t(access_qualifier_));
}
bool ForwardPointer::IsSameImpl(const Type* that, IsSameCache*) const {
@@ -653,11 +634,11 @@ std::string ForwardPointer::str() const {
return oss.str();
}
-void ForwardPointer::GetExtraHashWords(
- std::vector<uint32_t>* words, std::unordered_set<const Type*>* seen) const {
- words->push_back(target_id_);
- words->push_back(storage_class_);
- if (pointer_) pointer_->GetHashWords(words, seen);
+size_t ForwardPointer::ComputeExtraStateHash(size_t hash,
+ SeenTypes* seen) const {
+ hash = hash_combine(hash, target_id_, uint32_t(storage_class_));
+ if (pointer_) hash = pointer_->ComputeHashValue(hash, seen);
+ return hash;
}
CooperativeMatrixNV::CooperativeMatrixNV(const Type* type, const uint32_t scope,
@@ -681,12 +662,10 @@ std::string CooperativeMatrixNV::str() const {
return oss.str();
}
-void CooperativeMatrixNV::GetExtraHashWords(
- std::vector<uint32_t>* words, std::unordered_set<const Type*>* pSet) const {
- component_type_->GetHashWords(words, pSet);
- words->push_back(scope_id_);
- words->push_back(rows_id_);
- words->push_back(columns_id_);
+size_t CooperativeMatrixNV::ComputeExtraStateHash(size_t hash,
+ SeenTypes* seen) const {
+ hash = hash_combine(hash, scope_id_, rows_id_, columns_id_);
+ return component_type_->ComputeHashValue(hash, seen);
}
bool CooperativeMatrixNV::IsSameImpl(const Type* that,
diff --git a/source/opt/types.h b/source/opt/types.h
index 9ecd41a6..f5a4a6be 100644
--- a/source/opt/types.h
+++ b/source/opt/types.h
@@ -28,6 +28,7 @@
#include "source/latest_version_spirv_header.h"
#include "source/opt/instruction.h"
+#include "source/util/small_vector.h"
#include "spirv-tools/libspirv.h"
namespace spvtools {
@@ -67,6 +68,8 @@ class Type {
public:
typedef std::set<std::pair<const Pointer*, const Pointer*>> IsSameCache;
+ using SeenTypes = spvtools::utils::SmallVector<const Type*, 8>;
+
// Available subtypes.
//
// When adding a new derived class of Type, please add an entry to the enum.
@@ -96,7 +99,8 @@ class Type {
kNamedBarrier,
kAccelerationStructureNV,
kCooperativeMatrixNV,
- kRayQueryKHR
+ kRayQueryKHR,
+ kLast
};
Type(Kind k) : kind_(k) {}
@@ -154,21 +158,7 @@ class Type {
// Returns the hash value of this type.
size_t HashValue() const;
- // Adds the necessary words to compute a hash value of this type to |words|.
- void GetHashWords(std::vector<uint32_t>* words) const {
- std::unordered_set<const Type*> seen;
- GetHashWords(words, &seen);
- }
-
- // Adds the necessary words to compute a hash value of this type to |words|.
- void GetHashWords(std::vector<uint32_t>* words,
- std::unordered_set<const Type*>* seen) const;
-
- // Adds necessary extra words for a subtype to calculate a hash value into
- // |words|.
- virtual void GetExtraHashWords(
- std::vector<uint32_t>* words,
- std::unordered_set<const Type*>* pSet) const = 0;
+ size_t ComputeHashValue(size_t hash, SeenTypes* seen) const;
// A bunch of methods for casting this type to a given type. Returns this if the
// cast can be done, nullptr otherwise.
@@ -204,6 +194,10 @@ class Type {
DeclareCastMethod(RayQueryKHR)
#undef DeclareCastMethod
+protected:
+ // Add any type-specific state to |hash| and returns new hash.
+ virtual size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const = 0;
+
protected:
// Decorations attached to this type. Each decoration is encoded as a vector
// of uint32_t numbers. The first uint32_t number is the decoration value,
@@ -232,8 +226,7 @@ class Integer : public Type {
uint32_t width() const { return width_; }
bool IsSigned() const { return signed_; }
- void GetExtraHashWords(std::vector<uint32_t>* words,
- std::unordered_set<const Type*>* pSet) const override;
+ size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
private:
bool IsSameImpl(const Type* that, IsSameCache*) const override;
@@ -253,8 +246,7 @@ class Float : public Type {
const Float* AsFloat() const override { return this; }
uint32_t width() const { return width_; }
- void GetExtraHashWords(std::vector<uint32_t>* words,
- std::unordered_set<const Type*>* pSet) const override;
+ size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
private:
bool IsSameImpl(const Type* that, IsSameCache*) const override;
@@ -274,8 +266,7 @@ class Vector : public Type {
Vector* AsVector() override { return this; }
const Vector* AsVector() const override { return this; }
- void GetExtraHashWords(std::vector<uint32_t>* words,
- std::unordered_set<const Type*>* pSet) const override;
+ size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
private:
bool IsSameImpl(const Type* that, IsSameCache*) const override;
@@ -296,8 +287,7 @@ class Matrix : public Type {
Matrix* AsMatrix() override { return this; }
const Matrix* AsMatrix() const override { return this; }
- void GetExtraHashWords(std::vector<uint32_t>* words,
- std::unordered_set<const Type*>* pSet) const override;
+ size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
private:
bool IsSameImpl(const Type* that, IsSameCache*) const override;
@@ -327,8 +317,7 @@ class Image : public Type {
SpvImageFormat format() const { return format_; }
SpvAccessQualifier access_qualifier() const { return access_qualifier_; }
- void GetExtraHashWords(std::vector<uint32_t>* words,
- std::unordered_set<const Type*>* pSet) const override;
+ size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
private:
bool IsSameImpl(const Type* that, IsSameCache*) const override;
@@ -355,8 +344,7 @@ class SampledImage : public Type {
const Type* image_type() const { return image_type_; }
- void GetExtraHashWords(std::vector<uint32_t>* words,
- std::unordered_set<const Type*>* pSet) const override;
+ size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
private:
bool IsSameImpl(const Type* that, IsSameCache*) const override;
@@ -399,10 +387,10 @@ class Array : public Type {
Array* AsArray() override { return this; }
const Array* AsArray() const override { return this; }
- void GetExtraHashWords(std::vector<uint32_t>* words,
- std::unordered_set<const Type*>* pSet) const override;
+ size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
void ReplaceElementType(const Type* element_type);
+ LengthInfo GetConstantLengthInfo(uint32_t const_id, uint32_t length) const;
private:
bool IsSameImpl(const Type* that, IsSameCache*) const override;
@@ -422,8 +410,7 @@ class RuntimeArray : public Type {
RuntimeArray* AsRuntimeArray() override { return this; }
const RuntimeArray* AsRuntimeArray() const override { return this; }
- void GetExtraHashWords(std::vector<uint32_t>* words,
- std::unordered_set<const Type*>* pSet) const override;
+ size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
void ReplaceElementType(const Type* element_type);
@@ -459,8 +446,7 @@ class Struct : public Type {
Struct* AsStruct() override { return this; }
const Struct* AsStruct() const override { return this; }
- void GetExtraHashWords(std::vector<uint32_t>* words,
- std::unordered_set<const Type*>* pSet) const override;
+ size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
private:
bool IsSameImpl(const Type* that, IsSameCache*) const override;
@@ -491,8 +477,7 @@ class Opaque : public Type {
const std::string& name() const { return name_; }
- void GetExtraHashWords(std::vector<uint32_t>* words,
- std::unordered_set<const Type*>* pSet) const override;
+ size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
private:
bool IsSameImpl(const Type* that, IsSameCache*) const override;
@@ -512,8 +497,7 @@ class Pointer : public Type {
Pointer* AsPointer() override { return this; }
const Pointer* AsPointer() const override { return this; }
- void GetExtraHashWords(std::vector<uint32_t>* words,
- std::unordered_set<const Type*>* pSet) const override;
+ size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
void SetPointeeType(const Type* type);
@@ -539,8 +523,7 @@ class Function : public Type {
const std::vector<const Type*>& param_types() const { return param_types_; }
std::vector<const Type*>& param_types() { return param_types_; }
- void GetExtraHashWords(std::vector<uint32_t>* words,
- std::unordered_set<const Type*>*) const override;
+ size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
void SetReturnType(const Type* type);
@@ -564,8 +547,7 @@ class Pipe : public Type {
SpvAccessQualifier access_qualifier() const { return access_qualifier_; }
- void GetExtraHashWords(std::vector<uint32_t>* words,
- std::unordered_set<const Type*>* pSet) const override;
+ size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
private:
bool IsSameImpl(const Type* that, IsSameCache*) const override;
@@ -592,8 +574,7 @@ class ForwardPointer : public Type {
ForwardPointer* AsForwardPointer() override { return this; }
const ForwardPointer* AsForwardPointer() const override { return this; }
- void GetExtraHashWords(std::vector<uint32_t>* words,
- std::unordered_set<const Type*>* pSet) const override;
+ size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
private:
bool IsSameImpl(const Type* that, IsSameCache*) const override;
@@ -616,8 +597,7 @@ class CooperativeMatrixNV : public Type {
return this;
}
- void GetExtraHashWords(std::vector<uint32_t>*,
- std::unordered_set<const Type*>*) const override;
+ size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
const Type* component_type() const { return component_type_; }
uint32_t scope_id() const { return scope_id_; }
@@ -633,24 +613,25 @@ class CooperativeMatrixNV : public Type {
const uint32_t columns_id_;
};
-#define DefineParameterlessType(type, name) \
- class type : public Type { \
- public: \
- type() : Type(k##type) {} \
- type(const type&) = default; \
- \
- std::string str() const override { return #name; } \
- \
- type* As##type() override { return this; } \
- const type* As##type() const override { return this; } \
- \
- void GetExtraHashWords(std::vector<uint32_t>*, \
- std::unordered_set<const Type*>*) const override {} \
- \
- private: \
- bool IsSameImpl(const Type* that, IsSameCache*) const override { \
- return that->As##type() && HasSameDecorations(that); \
- } \
+#define DefineParameterlessType(type, name) \
+ class type : public Type { \
+ public: \
+ type() : Type(k##type) {} \
+ type(const type&) = default; \
+ \
+ std::string str() const override { return #name; } \
+ \
+ type* As##type() override { return this; } \
+ const type* As##type() const override { return this; } \
+ \
+ size_t ComputeExtraStateHash(size_t hash, SeenTypes*) const override { \
+ return hash; \
+ } \
+ \
+ private: \
+ bool IsSameImpl(const Type* that, IsSameCache*) const override { \
+ return that->As##type() && HasSameDecorations(that); \
+ } \
}
DefineParameterlessType(Void, void);
DefineParameterlessType(Bool, bool);
diff --git a/source/opt/unify_const_pass.cpp b/source/opt/unify_const_pass.cpp
index 227fd61d..6bfa11a5 100644
--- a/source/opt/unify_const_pass.cpp
+++ b/source/opt/unify_const_pass.cpp
@@ -151,7 +151,7 @@ Pass::Status UnifyConstantPass::Process() {
// 'SpecId' decoration and all of them should be treated as unique.
// 'SpecId' is not applicable to SpecConstants defined with
// OpSpecConstant{Op|Composite}, their values are not necessary to be
- // unique. When all the operands/compoents are the same between two
+ // unique. When all the operands/components are the same between two
// OpSpecConstant{Op|Composite} results, their result values must be the
// same so are unifiable.
case SpvOp::SpvOpSpecConstantOp:
diff --git a/source/opt/upgrade_memory_model.cpp b/source/opt/upgrade_memory_model.cpp
index ab252059..9d6a5bce 100644
--- a/source/opt/upgrade_memory_model.cpp
+++ b/source/opt/upgrade_memory_model.cpp
@@ -20,6 +20,7 @@
#include "source/opt/ir_context.h"
#include "source/spirv_constant.h"
#include "source/util/make_unique.h"
+#include "source/util/string_utils.h"
namespace spvtools {
namespace opt {
@@ -58,9 +59,7 @@ void UpgradeMemoryModel::UpgradeMemoryModelInstruction() {
std::initializer_list<Operand>{
{SPV_OPERAND_TYPE_CAPABILITY, {SpvCapabilityVulkanMemoryModelKHR}}}));
const std::string extension = "SPV_KHR_vulkan_memory_model";
- std::vector<uint32_t> words(extension.size() / 4 + 1, 0);
- char* dst = reinterpret_cast<char*>(words.data());
- strncpy(dst, extension.c_str(), extension.size());
+ std::vector<uint32_t> words = spvtools::utils::MakeVector(extension);
context()->AddExtension(
MakeUnique<Instruction>(context(), SpvOpExtension, 0, 0,
std::initializer_list<Operand>{
@@ -85,8 +84,7 @@ void UpgradeMemoryModel::UpgradeInstructions() {
if (ext_inst == GLSLstd450Modf || ext_inst == GLSLstd450Frexp) {
auto import =
get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
- if (reinterpret_cast<char*>(import->GetInOperand(0u).words.data()) ==
- std::string("GLSL.std.450")) {
+ if (import->GetInOperand(0u).AsString() == "GLSL.std.450") {
UpgradeExtInst(inst);
}
}
diff --git a/source/opt/vector_dce.h b/source/opt/vector_dce.h
index 4d30b926..a55bda69 100644
--- a/source/opt/vector_dce.h
+++ b/source/opt/vector_dce.h
@@ -73,7 +73,7 @@ class VectorDCE : public MemPass {
bool RewriteInstructions(Function* function,
const LiveComponentMap& live_components);
- // Makrs all DebugValue instructions that use |composite| for their values as
+ // Makes all DebugValue instructions that use |composite| for their values as
// dead instructions by putting them into |dead_dbg_value|.
void MarkDebugValueUsesAsDead(Instruction* composite,
std::vector<Instruction*>* dead_dbg_value);