aboutsummaryrefslogtreecommitdiff
path: root/source/opt/folding_rules.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/opt/folding_rules.cpp')
-rw-r--r--source/opt/folding_rules.cpp361
1 files changed, 352 insertions, 9 deletions
diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp
index 4904f186..3f10bd00 100644
--- a/source/opt/folding_rules.cpp
+++ b/source/opt/folding_rules.cpp
@@ -277,6 +277,11 @@ uint32_t Reciprocal(analysis::ConstantManager* const_mgr,
uint32_t width = c->type()->AsFloat()->width();
assert(width == 32 || width == 64);
std::vector<uint32_t> words;
+
+ if (c->IsZero()) {
+ return 0;
+ }
+
if (width == 64) {
spvtools::utils::FloatProxy<double> result(1.0 / c->GetDouble());
if (!IsValidResult(result.getAsFloat())) return 0;
@@ -1430,6 +1435,132 @@ FoldingRule FactorAddMuls() {
};
}
+// Rewrites |inst| in place as the GLSL.std.450 Fma extended instruction
+// computing |(x*y)+a|.  If the feature manager reports no GLSL.std.450 import,
+// one is added to the module first.
+void ReplaceWithFma(Instruction* inst, uint32_t x, uint32_t y, uint32_t a) {
+  auto* ctx = inst->context();
+
+  // The feature manager is fetched again after adding the import rather than
+  // cached, exactly as the id lookup is repeated.
+  uint32_t glsl_set_id =
+      ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
+  if (glsl_set_id == 0) {
+    ctx->AddExtInstImport("GLSL.std.450");
+    glsl_set_id = ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
+    assert(glsl_set_id != 0 &&
+           "Could not add the GLSL.std.450 extended instruction set");
+  }
+
+  // In-operands of the OpExtInst: <set> <instruction> <x> <y> <a>.
+  std::vector<Operand> fma_operands = {
+      {SPV_OPERAND_TYPE_ID, {glsl_set_id}},
+      {SPV_OPERAND_TYPE_LITERAL_INTEGER, {GLSLstd450Fma}},
+      {SPV_OPERAND_TYPE_ID, {x}},
+      {SPV_OPERAND_TYPE_ID, {y}},
+      {SPV_OPERAND_TYPE_ID, {a}},
+  };
+
+  inst->SetOpcode(SpvOpExtInst);
+  inst->SetInOperands(std::move(fma_operands));
+}
+
+// Folds a multiply and add into an Fma.
+//
+// Cases:
+//   (x * y) + a = Fma x y a
+//   a + (x * y) = Fma x y a
+//
+// Returns true if |inst| was rewritten.  Nothing is folded unless floating
+// point folding is allowed on both the add and the multiply.
+bool MergeMulAddArithmetic(IRContext* context, Instruction* inst,
+                           const std::vector<const analysis::Constant*>&) {
+  assert(inst->opcode() == SpvOpFAdd);
+
+  if (!inst->IsFloatingPointFoldingAllowed()) return false;
+
+  analysis::DefUseManager* defs = context->get_def_use_mgr();
+  for (uint32_t mul_idx = 0; mul_idx < 2; ++mul_idx) {
+    Instruction* candidate =
+        defs->GetDef(inst->GetSingleWordInOperand(mul_idx));
+
+    const bool foldable_mul = candidate->opcode() == SpvOpFMul &&
+                              candidate->IsFloatingPointFoldingAllowed();
+    if (!foldable_mul) continue;
+
+    // The addend is whichever operand of the add is not the multiply.
+    const uint32_t addend = inst->GetSingleWordInOperand(1 - mul_idx);
+    const uint32_t lhs = candidate->GetSingleWordInOperand(0);
+    const uint32_t rhs = candidate->GetSingleWordInOperand(1);
+    ReplaceWithFma(inst, lhs, rhs, addend);
+    return true;
+  }
+  return false;
+}
+
+// Rewrites |sub| in place as the GLSL.std.450 Fma |(x*y)+a| after negating one
+// of its inputs: |a| when |negate_addition| is true, |x| otherwise.  The
+// OpFNegate is created through an InstructionBuilder anchored at |sub|.
+void ReplaceWithFmaAndNegate(Instruction* sub, uint32_t x, uint32_t y,
+                             uint32_t a, bool negate_addition) {
+  auto* ctx = sub->context();
+
+  // Make sure the GLSL.std.450 set is imported; the feature manager is
+  // queried again after the import is added.
+  uint32_t glsl_set_id =
+      ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
+  if (glsl_set_id == 0) {
+    ctx->AddExtInstImport("GLSL.std.450");
+    glsl_set_id = ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
+    assert(glsl_set_id != 0 &&
+           "Could not add the GLSL.std.450 extended instruction set");
+  }
+
+  InstructionBuilder builder(
+      ctx, sub,
+      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+
+  const uint32_t id_to_negate = negate_addition ? a : x;
+  Instruction* negate =
+      builder.AddUnaryOp(sub->type_id(), SpvOpFNegate, id_to_negate);
+  const uint32_t negated_id = negate->result_id();  // -a or -x
+
+  const uint32_t mul_lhs = negate_addition ? x : negated_id;
+  const uint32_t addend = negate_addition ? negated_id : a;
+
+  // In-operands of the OpExtInst: <set> <instruction> <x> <y> <a>.
+  std::vector<Operand> fma_operands = {
+      {SPV_OPERAND_TYPE_ID, {glsl_set_id}},
+      {SPV_OPERAND_TYPE_LITERAL_INTEGER, {GLSLstd450Fma}},
+      {SPV_OPERAND_TYPE_ID, {mul_lhs}},
+      {SPV_OPERAND_TYPE_ID, {y}},
+      {SPV_OPERAND_TYPE_ID, {addend}},
+  };
+
+  sub->SetOpcode(SpvOpExtInst);
+  sub->SetInOperands(std::move(fma_operands));
+}
+
+// Folds a multiply and subtract into a negation followed by an Fma.
+//
+// Cases:
+//   (x * y) - a = Fma x y -a
+//   a - (x * y) = Fma -x y a
+//
+// Returns true if |sub| was rewritten.  Nothing is folded unless floating
+// point folding is allowed on both the subtract and the multiply.
+bool MergeMulSubArithmetic(IRContext* context, Instruction* sub,
+                           const std::vector<const analysis::Constant*>&) {
+  assert(sub->opcode() == SpvOpFSub);
+
+  if (!sub->IsFloatingPointFoldingAllowed()) return false;
+
+  analysis::DefUseManager* defs = context->get_def_use_mgr();
+  for (uint32_t mul_idx = 0; mul_idx < 2; ++mul_idx) {
+    Instruction* candidate =
+        defs->GetDef(sub->GetSingleWordInOperand(mul_idx));
+
+    const bool foldable_mul = candidate->opcode() == SpvOpFMul &&
+                              candidate->IsFloatingPointFoldingAllowed();
+    if (!foldable_mul) continue;
+
+    const uint32_t other = sub->GetSingleWordInOperand(1 - mul_idx);
+    const uint32_t lhs = candidate->GetSingleWordInOperand(0);
+    const uint32_t rhs = candidate->GetSingleWordInOperand(1);
+    // When the multiply is the minuend (operand 0) the other operand is
+    // negated; otherwise one multiply factor is negated.
+    const bool mul_is_minuend = (mul_idx == 0);
+    ReplaceWithFmaAndNegate(sub, lhs, rhs, other, mul_is_minuend);
+    return true;
+  }
+  return false;
+}
+
FoldingRule IntMultipleBy1() {
return [](IRContext*, Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
@@ -1573,6 +1704,57 @@ bool CompositeConstructFeedingExtract(
return true;
}
+// Walks the index chain [|start|, |end|) of an OpCompositeInsert or
+// OpCompositeExtract instruction, starting from the type with id |type_id|,
+// and returns the type of the final element being accessed.
+//
+// Returns nullptr as soon as an index cannot be applied: the current type is
+// not an array, matrix or struct, or a struct member index is out of range.
+// With an empty index range the type registered for |type_id| is returned
+// unchanged.
+const analysis::Type* GetElementType(uint32_t type_id,
+                                     Instruction::iterator start,
+                                     Instruction::iterator end,
+                                     const analysis::TypeManager* type_mgr) {
+  const analysis::Type* type = type_mgr->GetType(type_id);
+  for (const auto& index : make_range(std::move(start), std::move(end))) {
+    assert(index.type == SPV_OPERAND_TYPE_LITERAL_INTEGER &&
+           index.words.size() == 1);
+    // Stop here instead of dereferencing a null type on the next step.
+    if (type == nullptr) {
+      return nullptr;
+    }
+    if (auto* array_type = type->AsArray()) {
+      type = array_type->element_type();
+    } else if (auto* matrix_type = type->AsMatrix()) {
+      type = matrix_type->element_type();
+    } else if (auto* struct_type = type->AsStruct()) {
+      const auto& members = struct_type->element_types();
+      if (index.words[0] >= members.size()) {
+        return nullptr;
+      }
+      type = members[index.words[0]];
+    } else {
+      // Any other type cannot be indexed further by this walk.
+      return nullptr;
+    }
+  }
+  return type;
+}
+
+// Returns true if |inst_1| and |inst_2| use the same indexes to reach into a
+// composite object, ignoring the last index of each.  Both instructions must
+// have the same opcode, which must be OpCompositeExtract or
+// OpCompositeInsert.  Instructions with a different number of in-operands
+// never match.
+bool HaveSameIndexesExceptForLast(Instruction* inst_1, Instruction* inst_2) {
+  assert(inst_1->opcode() == inst_2->opcode() &&
+         "Expecting the opcodes to be the same.");
+  assert((inst_1->opcode() == SpvOpCompositeInsert ||
+          inst_1->opcode() == SpvOpCompositeExtract) &&
+         "Instructions must be OpCompositeInsert or OpCompositeExtract.");
+
+  const uint32_t operand_count = inst_1->NumInOperands();
+  if (operand_count != inst_2->NumInOperands()) return false;
+
+  // Indexes start after the object and composite operands of an insert, or
+  // after the composite operand of an extract.
+  uint32_t pos = 1;
+  if (inst_1->opcode() == SpvOpCompositeInsert) pos = 2;
+
+  const uint32_t last = operand_count - 1;
+  while (pos < last) {
+    const uint32_t lhs = inst_1->GetSingleWordInOperand(pos);
+    const uint32_t rhs = inst_2->GetSingleWordInOperand(pos);
+    if (lhs != rhs) return false;
+    ++pos;
+  }
+  return true;
+}
+
// If the OpCompositeConstruct is simply putting back together elements that
// where extracted from the same source, we can simply reuse the source.
//
@@ -1595,19 +1777,24 @@ bool CompositeExtractFeedingConstruct(
// - extractions
// - extracting the same position they are inserting
// - all extract from the same id.
+ Instruction* first_element_inst = nullptr;
for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
const uint32_t element_id = inst->GetSingleWordInOperand(i);
Instruction* element_inst = def_use_mgr->GetDef(element_id);
+ if (first_element_inst == nullptr) {
+ first_element_inst = element_inst;
+ }
if (element_inst->opcode() != SpvOpCompositeExtract) {
return false;
}
- if (element_inst->NumInOperands() != 2) {
+ if (!HaveSameIndexesExceptForLast(element_inst, first_element_inst)) {
return false;
}
- if (element_inst->GetSingleWordInOperand(1) != i) {
+ if (element_inst->GetSingleWordInOperand(element_inst->NumInOperands() -
+ 1) != i) {
return false;
}
@@ -1623,13 +1810,31 @@ bool CompositeExtractFeedingConstruct(
// The last check it to see that the object being extracted from is the
// correct type.
Instruction* original_inst = def_use_mgr->GetDef(original_id);
- if (original_inst->type_id() != inst->type_id()) {
+ analysis::TypeManager* type_mgr = context->get_type_mgr();
+ const analysis::Type* original_type =
+ GetElementType(original_inst->type_id(), first_element_inst->begin() + 3,
+ first_element_inst->end() - 1, type_mgr);
+
+ if (original_type == nullptr) {
return false;
}
- // Simplify by using the original object.
- inst->SetOpcode(SpvOpCopyObject);
- inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
+ if (inst->type_id() != type_mgr->GetId(original_type)) {
+ return false;
+ }
+
+ if (first_element_inst->NumInOperands() == 2) {
+ // Simplify by using the original object.
+ inst->SetOpcode(SpvOpCopyObject);
+ inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
+ return true;
+ }
+
+ // Copies the original id and all indexes except for the last to the new
+ // extract instruction.
+ inst->SetOpcode(SpvOpCompositeExtract);
+ inst->SetInOperands(std::vector<Operand>(first_element_inst->begin() + 2,
+ first_element_inst->end() - 1));
return true;
}
@@ -1833,6 +2038,139 @@ FoldingRule FMixFeedingExtract() {
};
}
+// Returns the number of elements in the composite type |type|.  Returns 0 if
+// |type| is not a vector, matrix, struct or array.  Returns 0xFFFFFFFF for an
+// array whose length is not a plain single-word constant, so that callers
+// comparing against a count of inserted values never see a match.
+uint32_t GetNumberOfElements(const analysis::Type* type) {
+  const uint32_t kUnknownElementCount = 0xFFFFFFFFu;
+
+  if (auto* vector_type = type->AsVector()) {
+    return vector_type->element_count();
+  }
+  if (auto* matrix_type = type->AsMatrix()) {
+    return matrix_type->element_count();
+  }
+  if (auto* struct_type = type->AsStruct()) {
+    return static_cast<uint32_t>(struct_type->element_types().size());
+  }
+  if (auto* array_type = type->AsArray()) {
+    // In LengthInfo, words[0] is the case tag (kConstant, ...) and the
+    // literal length follows it; words[0] itself is not the element count.
+    const auto& length_words = array_type->length_info().words;
+    if (length_words.size() == 2 &&
+        length_words[0] == analysis::Array::LengthInfo::kConstant) {
+      return length_words[1];
+    }
+    return kUnknownElementCount;
+  }
+  return 0;
+}
+
+// Follows the chain of OpCompositeInsert instructions that ends at |inst|,
+// walking backwards through the composite operand, and collects the values
+// that were inserted.  The returned map goes from the last index of every
+// insert sharing |inst|'s leading indexes to the id of the inserted object.
+// The walk starts at the newest insert and existing entries are never
+// overwritten, so the most recent value for an index wins.
+//
+// Returns an empty map if an element was only partially written by a deeper
+// insert and not fully replaced later in the chain.
+std::map<uint32_t, uint32_t> GetInsertedValues(Instruction* inst) {
+  analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr();
+  const uint32_t depth = inst->NumInOperands();
+
+  std::map<uint32_t, uint32_t> collected;
+  for (Instruction* link = inst; link->opcode() == SpvOpCompositeInsert;
+       link = def_use_mgr->GetDef(
+           link->GetSingleWordInOperand(kInsertCompositeIdInIdx))) {
+    const uint32_t link_operand_count = link->NumInOperands();
+
+    if (link_operand_count > depth) {
+      // This is to catch the case
+      //   %2 = OpCompositeInsert %m2x2int %v2int_1_0 %m2x2int_undef 0
+      //   %3 = OpCompositeInsert %m2x2int %int_4 %2 0 0
+      //   %4 = OpCompositeInsert %m2x2int %v2int_2_3 %3 1
+      // In this case we cannot do a single construct to get the matrix.
+      const uint32_t partially_written =
+          link->GetSingleWordInOperand(depth - 1);
+      if (collected.find(partially_written) == collected.end()) {
+        return {};
+      }
+    }
+
+    if (HaveSameIndexesExceptForLast(inst, link)) {
+      const uint32_t last_index =
+          link->GetSingleWordInOperand(link_operand_count - 1);
+      const uint32_t object_id =
+          link->GetSingleWordInOperand(kInsertObjectIdInIdx);
+      // emplace keeps an existing entry, so later inserts take precedence.
+      collected.emplace(last_index, object_id);
+    }
+  }
+  return collected;
+}
+
+// Returns true if there is an entry in |values_inserted| for every element of
+// |type|.
+//
+// The map keys are unique, so having exactly as many entries as |type| has
+// elements, with the largest key below that count, means the keys are exactly
+// 0 .. count-1.  An empty map never covers an object; checking that first
+// also keeps rbegin() from being dereferenced on an empty container.
+bool DoInsertedValuesCoverEntireObject(
+    const analysis::Type* type,
+    const std::map<uint32_t, uint32_t>& values_inserted) {
+  if (values_inserted.empty()) {
+    return false;
+  }
+
+  const uint32_t container_size = GetNumberOfElements(type);
+  if (container_size != values_inserted.size()) {
+    return false;
+  }
+
+  // std::map is ordered, so rbegin() holds the largest inserted index.
+  return values_inserted.rbegin()->first < container_size;
+}
+
+// Returns the type of the element that immediately contains the element being
+// inserted by the OpCompositeInsert instruction |inst|, i.e. the type reached
+// by applying every index of |inst| except the last to its result type.
+const analysis::Type* GetContainerType(Instruction* inst) {
+  assert(inst->opcode() == SpvOpCompositeInsert);
+  // Operands before the indexes are skipped (begin() + 4); the last index is
+  // dropped (end() - 1).
+  Instruction::iterator first_index = inst->begin() + 4;
+  Instruction::iterator last_index = inst->end() - 1;
+  return GetElementType(inst->type_id(), first_index, last_index,
+                        inst->context()->get_type_mgr());
+}
+
+// Returns an OpCompositeConstruct instruction that builds an object of type
+// |type_id| out of the values in |values_inserted|.  The constituents are
+// listed in increasing index order, which is the iteration order of the map.
+// The new instruction is placed before |insert_before|.
+Instruction* BuildCompositeConstruct(
+    uint32_t type_id, const std::map<uint32_t, uint32_t>& values_inserted,
+    Instruction* insert_before) {
+  std::vector<uint32_t> constituent_ids;
+  for (const auto& index_and_id : values_inserted) {
+    constituent_ids.push_back(index_and_id.second);
+  }
+
+  InstructionBuilder builder(
+      insert_before->context(), insert_before,
+      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+  return builder.AddCompositeConstruct(type_id, constituent_ids);
+}
+
+// Rewrites the OpCompositeInsert |inst| so that it inserts |construct| into
+// the same object with the final index removed.  If |inst| has exactly three
+// in-operands (object, composite and a single index), no index would remain,
+// so |inst| becomes an OpCopyObject of |construct| instead.
+void InsertConstructedObject(Instruction* inst, const Instruction* construct) {
+  const uint32_t constructed_id = construct->result_id();
+
+  if (inst->NumInOperands() != 3) {
+    // Keep the insert: swap in the constructed object and drop the last index.
+    inst->SetInOperand(kInsertObjectIdInIdx, {constructed_id});
+    inst->RemoveOperand(inst->NumOperands() - 1);
+    return;
+  }
+
+  inst->SetOpcode(SpvOpCopyObject);
+  inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {constructed_id}}});
+}
+
+// Replaces a series of |OpCompositeInsert| instructions that together cover an
+// entire object with an |OpCompositeConstruct|.  Returns true if |inst| was
+// rewritten, false if the inserts ending at |inst| do not cover the whole
+// container.
+bool CompositeInsertToCompositeConstruct(
+    IRContext* context, Instruction* inst,
+    const std::vector<const analysis::Constant*>&) {
+  assert(inst->opcode() == SpvOpCompositeInsert &&
+         "Wrong opcode. Should be OpCompositeInsert.");
+  if (inst->NumInOperands() < 3) return false;
+
+  std::map<uint32_t, uint32_t> inserted = GetInsertedValues(inst);
+  const analysis::Type* container = GetContainerType(inst);
+
+  const bool can_fold =
+      container != nullptr &&
+      DoInsertedValuesCoverEntireObject(container, inserted);
+  if (!can_fold) return false;
+
+  const uint32_t container_type_id = context->get_type_mgr()->GetId(container);
+  Instruction* construct =
+      BuildCompositeConstruct(container_type_id, inserted, inst);
+  InsertConstructedObject(inst, construct);
+  return true;
+}
+
FoldingRule RedundantPhi() {
// An OpPhi instruction where all values are the same or the result of the phi
// itself, can be replaced by the value itself.
@@ -2368,7 +2706,7 @@ FoldingRule VectorShuffleFeedingShuffle() {
// fold.
return false;
}
- } else {
+ } else if (component_index != undef_literal) {
if (new_feeder_id == 0) {
// First time through, save the id of the operand the element comes
// from.
@@ -2382,7 +2720,7 @@ FoldingRule VectorShuffleFeedingShuffle() {
component_index -= feeder_op0_length;
}
- if (!feeder_is_op0) {
+ if (!feeder_is_op0 && component_index != undef_literal) {
component_index += op0_length;
}
}
@@ -2410,7 +2748,8 @@ FoldingRule VectorShuffleFeedingShuffle() {
if (adjustment != 0) {
for (uint32_t i = 2; i < new_operands.size(); i++) {
- if (inst->GetSingleWordInOperand(i) >= op0_length) {
+ uint32_t operand = inst->GetSingleWordInOperand(i);
+ if (operand >= op0_length && operand != undef_literal) {
new_operands[i].words[0] -= adjustment;
}
}
@@ -2533,6 +2872,8 @@ void FoldingRules::AddFoldingRules() {
rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract());
rules_[SpvOpCompositeExtract].push_back(FMixFeedingExtract());
+ rules_[SpvOpCompositeInsert].push_back(CompositeInsertToCompositeConstruct);
+
rules_[SpvOpDot].push_back(DotProductDoingExtract());
rules_[SpvOpEntryPoint].push_back(RemoveRedundantOperands());
@@ -2543,6 +2884,7 @@ void FoldingRules::AddFoldingRules() {
rules_[SpvOpFAdd].push_back(MergeAddSubArithmetic());
rules_[SpvOpFAdd].push_back(MergeGenericAddSubArithmetic());
rules_[SpvOpFAdd].push_back(FactorAddMuls());
+ rules_[SpvOpFAdd].push_back(MergeMulAddArithmetic);
rules_[SpvOpFDiv].push_back(RedundantFDiv());
rules_[SpvOpFDiv].push_back(ReciprocalFDiv());
@@ -2563,6 +2905,7 @@ void FoldingRules::AddFoldingRules() {
rules_[SpvOpFSub].push_back(MergeSubNegateArithmetic());
rules_[SpvOpFSub].push_back(MergeSubAddArithmetic());
rules_[SpvOpFSub].push_back(MergeSubSubArithmetic());
+ rules_[SpvOpFSub].push_back(MergeMulSubArithmetic);
rules_[SpvOpIAdd].push_back(RedundantIAdd());
rules_[SpvOpIAdd].push_back(MergeAddNegateArithmetic());