diff options
Diffstat (limited to 'source/opt/folding_rules.cpp')
-rw-r--r-- | source/opt/folding_rules.cpp | 361 |
1 files changed, 352 insertions, 9 deletions
diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp index 4904f186..3f10bd00 100644 --- a/source/opt/folding_rules.cpp +++ b/source/opt/folding_rules.cpp @@ -277,6 +277,11 @@ uint32_t Reciprocal(analysis::ConstantManager* const_mgr, uint32_t width = c->type()->AsFloat()->width(); assert(width == 32 || width == 64); std::vector<uint32_t> words; + + if (c->IsZero()) { + return 0; + } + if (width == 64) { spvtools::utils::FloatProxy<double> result(1.0 / c->GetDouble()); if (!IsValidResult(result.getAsFloat())) return 0; @@ -1430,6 +1435,132 @@ FoldingRule FactorAddMuls() { }; } +// Replaces |inst| inplace with an FMA instruction |(x*y)+a|. +void ReplaceWithFma(Instruction* inst, uint32_t x, uint32_t y, uint32_t a) { + uint32_t ext = + inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); + + if (ext == 0) { + inst->context()->AddExtInstImport("GLSL.std.450"); + ext = inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); + assert(ext != 0 && + "Could not add the GLSL.std.450 extended instruction set"); + } + + std::vector<Operand> operands; + operands.push_back({SPV_OPERAND_TYPE_ID, {ext}}); + operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {GLSLstd450Fma}}); + operands.push_back({SPV_OPERAND_TYPE_ID, {x}}); + operands.push_back({SPV_OPERAND_TYPE_ID, {y}}); + operands.push_back({SPV_OPERAND_TYPE_ID, {a}}); + + inst->SetOpcode(SpvOpExtInst); + inst->SetInOperands(std::move(operands)); +} + +// Folds a multiple and add into an Fma. +// +// Cases: +// (x * y) + a = Fma x y a +// a + (x * y) = Fma x y a +bool MergeMulAddArithmetic(IRContext* context, Instruction* inst, + const std::vector<const analysis::Constant*>&) { + assert(inst->opcode() == SpvOpFAdd); + + if (!inst->IsFloatingPointFoldingAllowed()) { + return false; + } + + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + for (int i = 0; i < 2; i++) { + uint32_t op_id = inst->GetSingleWordInOperand(i); + Instruction* op_inst = def_use_mgr->GetDef(op_id); + + if (op_inst->opcode() != SpvOpFMul) { + continue; + } + + if (!op_inst->IsFloatingPointFoldingAllowed()) { + continue; + } + + uint32_t x = op_inst->GetSingleWordInOperand(0); + uint32_t y = op_inst->GetSingleWordInOperand(1); + uint32_t a = inst->GetSingleWordInOperand((i + 1) % 2); + ReplaceWithFma(inst, x, y, a); + return true; + } + return false; +} + +// Replaces |sub| inplace with an FMA instruction |(x*y)+a| where |a| first gets +// negated if |negate_addition| is true, otherwise |x| gets negated. +void ReplaceWithFmaAndNegate(Instruction* sub, uint32_t x, uint32_t y, + uint32_t a, bool negate_addition) { + uint32_t ext = + sub->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); + + if (ext == 0) { + sub->context()->AddExtInstImport("GLSL.std.450"); + ext = sub->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); + assert(ext != 0 && + "Could not add the GLSL.std.450 extended instruction set"); + } + + InstructionBuilder ir_builder( + sub->context(), sub, + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); + + Instruction* neg = ir_builder.AddUnaryOp(sub->type_id(), SpvOpFNegate, + negate_addition ? a : x); + uint32_t neg_op = neg->result_id(); // -a : -x + + std::vector<Operand> operands; + operands.push_back({SPV_OPERAND_TYPE_ID, {ext}}); + operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {GLSLstd450Fma}}); + operands.push_back({SPV_OPERAND_TYPE_ID, {negate_addition ? x : neg_op}}); + operands.push_back({SPV_OPERAND_TYPE_ID, {y}}); + operands.push_back({SPV_OPERAND_TYPE_ID, {negate_addition ? neg_op : a}}); + + sub->SetOpcode(SpvOpExtInst); + sub->SetInOperands(std::move(operands)); +} + +// Folds a multiply and subtract into an Fma and negation. +// +// Cases: +// (x * y) - a = Fma x y -a +// a - (x * y) = Fma -x y a +bool MergeMulSubArithmetic(IRContext* context, Instruction* sub, + const std::vector<const analysis::Constant*>&) { + assert(sub->opcode() == SpvOpFSub); + + if (!sub->IsFloatingPointFoldingAllowed()) { + return false; + } + + analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); + for (int i = 0; i < 2; i++) { + uint32_t op_id = sub->GetSingleWordInOperand(i); + Instruction* mul = def_use_mgr->GetDef(op_id); + + if (mul->opcode() != SpvOpFMul) { + continue; + } + + if (!mul->IsFloatingPointFoldingAllowed()) { + continue; + } + + uint32_t x = mul->GetSingleWordInOperand(0); + uint32_t y = mul->GetSingleWordInOperand(1); + uint32_t a = sub->GetSingleWordInOperand((i + 1) % 2); + ReplaceWithFmaAndNegate(sub, x, y, a, i == 0); + return true; + } + return false; +} + FoldingRule IntMultipleBy1() { return [](IRContext*, Instruction* inst, const std::vector<const analysis::Constant*>& constants) { @@ -1573,6 +1704,57 @@ bool CompositeConstructFeedingExtract( return true; } +// Walks the indexes chain from |start| to |end| of an OpCompositeInsert or +// OpCompositeExtract instruction, and returns the type of the final element +// being accessed. +const analysis::Type* GetElementType(uint32_t type_id, + Instruction::iterator start, + Instruction::iterator end, + const analysis::TypeManager* type_mgr) { + const analysis::Type* type = type_mgr->GetType(type_id); + for (auto index : make_range(std::move(start), std::move(end))) { + assert(index.type == SPV_OPERAND_TYPE_LITERAL_INTEGER && + index.words.size() == 1); + if (auto* array_type = type->AsArray()) { + type = array_type->element_type(); + } else if (auto* matrix_type = type->AsMatrix()) { + type = matrix_type->element_type(); + } else if (auto* struct_type = type->AsStruct()) { + type = struct_type->element_types()[index.words[0]]; + } else { + type = nullptr; + } + } + return type; +} + +// Returns true of |inst_1| and |inst_2| have the same indexes that will be used +// to index into a composite object, excluding the last index. The two +// instructions must have the same opcode, and be either OpCompositeExtract or +// OpCompositeInsert instructions. +bool HaveSameIndexesExceptForLast(Instruction* inst_1, Instruction* inst_2) { + assert(inst_1->opcode() == inst_2->opcode() && + "Expecting the opcodes to be the same."); + assert((inst_1->opcode() == SpvOpCompositeInsert || + inst_1->opcode() == SpvOpCompositeExtract) && + "Instructions must be OpCompositeInsert or OpCompositeExtract."); + + if (inst_1->NumInOperands() != inst_2->NumInOperands()) { + return false; + } + + uint32_t first_index_position = + (inst_1->opcode() == SpvOpCompositeInsert ? 2 : 1); + for (uint32_t i = first_index_position; i < inst_1->NumInOperands() - 1; + i++) { + if (inst_1->GetSingleWordInOperand(i) != + inst_2->GetSingleWordInOperand(i)) { + return false; + } + } + return true; +} + // If the OpCompositeConstruct is simply putting back together elements that // where extracted from the same source, we can simply reuse the source. // @@ -1595,19 +1777,24 @@ bool CompositeExtractFeedingConstruct( // - extractions // - extracting the same position they are inserting // - all extract from the same id. + Instruction* first_element_inst = nullptr; for (uint32_t i = 0; i < inst->NumInOperands(); ++i) { const uint32_t element_id = inst->GetSingleWordInOperand(i); Instruction* element_inst = def_use_mgr->GetDef(element_id); + if (first_element_inst == nullptr) { + first_element_inst = element_inst; + } if (element_inst->opcode() != SpvOpCompositeExtract) { return false; } - if (element_inst->NumInOperands() != 2) { + if (!HaveSameIndexesExceptForLast(element_inst, first_element_inst)) { return false; } - if (element_inst->GetSingleWordInOperand(1) != i) { + if (element_inst->GetSingleWordInOperand(element_inst->NumInOperands() - + 1) != i) { return false; } @@ -1623,13 +1810,31 @@ bool CompositeExtractFeedingConstruct( // The last check it to see that the object being extracted from is the // correct type. Instruction* original_inst = def_use_mgr->GetDef(original_id); - if (original_inst->type_id() != inst->type_id()) { + analysis::TypeManager* type_mgr = context->get_type_mgr(); + const analysis::Type* original_type = + GetElementType(original_inst->type_id(), first_element_inst->begin() + 3, + first_element_inst->end() - 1, type_mgr); + + if (original_type == nullptr) { return false; } - // Simplify by using the original object. - inst->SetOpcode(SpvOpCopyObject); - inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}}); + if (inst->type_id() != type_mgr->GetId(original_type)) { + return false; + } + + if (first_element_inst->NumInOperands() == 2) { + // Simplify by using the original object. + inst->SetOpcode(SpvOpCopyObject); + inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}}); + return true; + } + + // Copies the original id and all indexes except for the last to the new + // extract instruction. + inst->SetOpcode(SpvOpCompositeExtract); + inst->SetInOperands(std::vector<Operand>(first_element_inst->begin() + 2, + first_element_inst->end() - 1)); return true; } @@ -1833,6 +2038,139 @@ FoldingRule FMixFeedingExtract() { }; } +// Returns the number of elements in the composite type |type|. Returns 0 if +// |type| is a scalar value. +uint32_t GetNumberOfElements(const analysis::Type* type) { + if (auto* vector_type = type->AsVector()) { + return vector_type->element_count(); + } + if (auto* matrix_type = type->AsMatrix()) { + return matrix_type->element_count(); + } + if (auto* struct_type = type->AsStruct()) { + return static_cast<uint32_t>(struct_type->element_types().size()); + } + if (auto* array_type = type->AsArray()) { + return array_type->length_info().words[0]; + } + return 0; +} + +// Returns a map with the set of values that were inserted into an object by +// the chain of OpCompositeInsertInstruction starting with |inst|. +// The map will map the index to the value inserted at that index. +std::map<uint32_t, uint32_t> GetInsertedValues(Instruction* inst) { + analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr(); + std::map<uint32_t, uint32_t> values_inserted; + Instruction* current_inst = inst; + while (current_inst->opcode() == SpvOpCompositeInsert) { + if (current_inst->NumInOperands() > inst->NumInOperands()) { + // This is the catch the case + // %2 = OpCompositeInsert %m2x2int %v2int_1_0 %m2x2int_undef 0 + // %3 = OpCompositeInsert %m2x2int %int_4 %2 0 0 + // %4 = OpCompositeInsert %m2x2int %v2int_2_3 %3 1 + // In this case we cannot do a single construct to get the matrix. + uint32_t partially_inserted_element_index = + current_inst->GetSingleWordInOperand(inst->NumInOperands() - 1); + if (values_inserted.count(partially_inserted_element_index) == 0) + return {}; + } + if (HaveSameIndexesExceptForLast(inst, current_inst)) { + values_inserted.insert( + {current_inst->GetSingleWordInOperand(current_inst->NumInOperands() - + 1), + current_inst->GetSingleWordInOperand(kInsertObjectIdInIdx)}); + } + current_inst = def_use_mgr->GetDef( + current_inst->GetSingleWordInOperand(kInsertCompositeIdInIdx)); + } + return values_inserted; +} + +// Returns true of there is an entry in |values_inserted| for every element of +// |Type|. +bool DoInsertedValuesCoverEntireObject( + const analysis::Type* type, std::map<uint32_t, uint32_t>& values_inserted) { + uint32_t container_size = GetNumberOfElements(type); + if (container_size != values_inserted.size()) { + return false; + } + + if (values_inserted.rbegin()->first >= container_size) { + return false; + } + return true; +} + +// Returns the type of the element that immediately contains the element being +// inserted by the OpCompositeInsert instruction |inst|. +const analysis::Type* GetContainerType(Instruction* inst) { + assert(inst->opcode() == SpvOpCompositeInsert); + analysis::TypeManager* type_mgr = inst->context()->get_type_mgr(); + return GetElementType(inst->type_id(), inst->begin() + 4, inst->end() - 1, + type_mgr); +} + +// Returns an OpCompositeConstruct instruction that build an object with +// |type_id| out of the values in |values_inserted|. Each value will be +// placed at the index corresponding to the value. The new instruction will +// be placed before |insert_before|. +Instruction* BuildCompositeConstruct( + uint32_t type_id, const std::map<uint32_t, uint32_t>& values_inserted, + Instruction* insert_before) { + InstructionBuilder ir_builder( + insert_before->context(), insert_before, + IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); + + std::vector<uint32_t> ids_in_order; + for (auto it : values_inserted) { + ids_in_order.push_back(it.second); + } + Instruction* construct = + ir_builder.AddCompositeConstruct(type_id, ids_in_order); + return construct; +} + +// Replaces the OpCompositeInsert |inst| that inserts |construct| into the same +// object as |inst| with final index removed. If the resulting +// OpCompositeInsert instruction would have no remaining indexes, the +// instruction is replaced with an OpCopyObject instead. +void InsertConstructedObject(Instruction* inst, const Instruction* construct) { + if (inst->NumInOperands() == 3) { + inst->SetOpcode(SpvOpCopyObject); + inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {construct->result_id()}}}); + } else { + inst->SetInOperand(kInsertObjectIdInIdx, {construct->result_id()}); + inst->RemoveOperand(inst->NumOperands() - 1); + } +} + +// Replaces a series of |OpCompositeInsert| instruction that cover the entire +// object with an |OpCompositeConstruct|. +bool CompositeInsertToCompositeConstruct( + IRContext* context, Instruction* inst, + const std::vector<const analysis::Constant*>&) { + assert(inst->opcode() == SpvOpCompositeInsert && + "Wrong opcode. Should be OpCompositeInsert."); + if (inst->NumInOperands() < 3) return false; + + std::map<uint32_t, uint32_t> values_inserted = GetInsertedValues(inst); + const analysis::Type* container_type = GetContainerType(inst); + if (container_type == nullptr) { + return false; + } + + if (!DoInsertedValuesCoverEntireObject(container_type, values_inserted)) { + return false; + } + + analysis::TypeManager* type_mgr = context->get_type_mgr(); + Instruction* construct = BuildCompositeConstruct( + type_mgr->GetId(container_type), values_inserted, inst); + InsertConstructedObject(inst, construct); + return true; +} + FoldingRule RedundantPhi() { // An OpPhi instruction where all values are the same or the result of the phi // itself, can be replaced by the value itself. @@ -2368,7 +2706,7 @@ FoldingRule VectorShuffleFeedingShuffle() { // fold. return false; } - } else { + } else if (component_index != undef_literal) { if (new_feeder_id == 0) { // First time through, save the id of the operand the element comes // from. @@ -2382,7 +2720,7 @@ FoldingRule VectorShuffleFeedingShuffle() { component_index -= feeder_op0_length; } - if (!feeder_is_op0) { + if (!feeder_is_op0 && component_index != undef_literal) { component_index += op0_length; } } @@ -2410,7 +2748,8 @@ FoldingRule VectorShuffleFeedingShuffle() { if (adjustment != 0) { for (uint32_t i = 2; i < new_operands.size(); i++) { - if (inst->GetSingleWordInOperand(i) >= op0_length) { + uint32_t operand = inst->GetSingleWordInOperand(i); + if (operand >= op0_length && operand != undef_literal) { new_operands[i].words[0] -= adjustment; } } @@ -2533,6 +2872,8 @@ void FoldingRules::AddFoldingRules() { rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract()); rules_[SpvOpCompositeExtract].push_back(FMixFeedingExtract()); + rules_[SpvOpCompositeInsert].push_back(CompositeInsertToCompositeConstruct); + rules_[SpvOpDot].push_back(DotProductDoingExtract()); rules_[SpvOpEntryPoint].push_back(RemoveRedundantOperands()); @@ -2543,6 +2884,7 @@ void FoldingRules::AddFoldingRules() { rules_[SpvOpFAdd].push_back(MergeAddSubArithmetic()); rules_[SpvOpFAdd].push_back(MergeGenericAddSubArithmetic()); rules_[SpvOpFAdd].push_back(FactorAddMuls()); + rules_[SpvOpFAdd].push_back(MergeMulAddArithmetic); rules_[SpvOpFDiv].push_back(RedundantFDiv()); rules_[SpvOpFDiv].push_back(ReciprocalFDiv()); @@ -2563,6 +2905,7 @@ void FoldingRules::AddFoldingRules() { rules_[SpvOpFSub].push_back(MergeSubNegateArithmetic()); rules_[SpvOpFSub].push_back(MergeSubAddArithmetic()); rules_[SpvOpFSub].push_back(MergeSubSubArithmetic()); + rules_[SpvOpFSub].push_back(MergeMulSubArithmetic); rules_[SpvOpIAdd].push_back(RedundantIAdd()); rules_[SpvOpIAdd].push_back(MergeAddNegateArithmetic()); |