diff options
author | River Riddle <riddleriver@gmail.com> | 2020-04-29 16:09:43 -0700 |
---|---|---|
committer | River Riddle <riddleriver@gmail.com> | 2020-04-29 16:48:15 -0700 |
commit | 0752d98ccf8771b41718170d46d11f4020b62818 (patch) | |
tree | 07bff970dd5e8b175b454c2b87aae774a457329e | |
parent | 91dae5708708c0c0b3e2383b419005bfe0402ae0 (diff) | |
download | llvm-project-0752d98ccf8771b41718170d46d11f4020b62818.tar.gz |
[mlir] Simplify BranchOpInterface by using MutableOperandRange
This range allows for performing many different operations on successor operands, including erasing/adding/setting. This removes the need for the explicit canEraseSuccessorOperand and eraseSuccessorOperand methods.
Differential Revision: https://reviews.llvm.org/D79077
-rw-r--r-- | flang/include/flang/Optimizer/Dialect/FIROps.td | 1 | ||||
-rw-r--r-- | flang/lib/Optimizer/Dialect/FIROps.cpp | 60 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/StandardOps/IR/Ops.td | 4 | ||||
-rw-r--r-- | mlir/include/mlir/IR/OperationSupport.h | 7 | ||||
-rw-r--r-- | mlir/include/mlir/Interfaces/ControlFlowInterfaces.h | 5 | ||||
-rw-r--r-- | mlir/include/mlir/Interfaces/ControlFlowInterfaces.td | 26 | ||||
-rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 21 | ||||
-rw-r--r-- | mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 17 | ||||
-rw-r--r-- | mlir/lib/Dialect/StandardOps/IR/Ops.cpp | 15 | ||||
-rw-r--r-- | mlir/lib/IR/OperationSupport.cpp | 12 | ||||
-rw-r--r-- | mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 33 | ||||
-rw-r--r-- | mlir/lib/Transforms/Utils/RegionUtils.cpp | 9 | ||||
-rw-r--r-- | mlir/test/lib/Dialect/Test/TestDialect.cpp | 7 | ||||
-rw-r--r-- | mlir/tools/mlir-tblgen/OpInterfacesGen.cpp | 8 |
14 files changed, 102 insertions, 123 deletions
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td index 46c39a1d498b..383256c3916f 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -585,6 +585,7 @@ class fir_SwitchTerminatorOp<string mnemonic, list<OpTrait> traits = []> : llvm::Optional<llvm::ArrayRef<mlir::Value>> getSuccessorOperands( llvm::ArrayRef<mlir::Value> operands, unsigned cond); + using BranchOpInterfaceTrait::getSuccessorOperands; // Helper function to deal with Optional operand forms void printSuccessorAtIndex(mlir::OpAsmPrinter &p, unsigned i) { diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index 1dd15fc959be..e2d94885e8fc 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -997,14 +997,26 @@ static constexpr llvm::StringRef getTargetOffsetAttr() { return "target_operand_offsets"; } -template <typename A> +template <typename A, typename... AdditionalArgs> static A getSubOperands(unsigned pos, A allArgs, - mlir::DenseIntElementsAttr ranges) { + mlir::DenseIntElementsAttr ranges, + AdditionalArgs &&... additionalArgs) { unsigned start = 0; for (unsigned i = 0; i < pos; ++i) start += (*(ranges.begin() + i)).getZExtValue(); - unsigned end = start + (*(ranges.begin() + pos)).getZExtValue(); - return {std::next(allArgs.begin(), start), std::next(allArgs.begin(), end)}; + return allArgs.slice(start, (*(ranges.begin() + pos)).getZExtValue(), + std::forward<AdditionalArgs>(additionalArgs)...); +} + +static mlir::MutableOperandRange +getMutableSuccessorOperands(unsigned pos, mlir::MutableOperandRange operands, + StringRef offsetAttr) { + Operation *owner = operands.getOwner(); + NamedAttribute targetOffsetAttr = + *owner->getMutableAttrDict().getNamed(offsetAttr); + return getSubOperands( + pos, operands, targetOffsetAttr.second.cast<DenseIntElementsAttr>(), + mlir::MutableOperandRange::OperandSegment(pos, targetOffsetAttr)); } static unsigned denseElementsSize(mlir::DenseIntElementsAttr attr) { @@ -1020,10 +1032,10 @@ fir::SelectOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { return {}; } -llvm::Optional<mlir::OperandRange> -fir::SelectOp::getSuccessorOperands(unsigned oper) { - auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); - return {getSubOperands(oper, targetArgs(), a)}; +llvm::Optional<mlir::MutableOperandRange> +fir::SelectOp::getMutableSuccessorOperands(unsigned oper) { + return ::getMutableSuccessorOperands(oper, targetArgsMutable(), + getTargetOffsetAttr()); } llvm::Optional<llvm::ArrayRef<mlir::Value>> @@ -1035,8 +1047,6 @@ fir::SelectOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } -bool fir::SelectOp::canEraseSuccessorOperand() { return true; } - unsigned fir::SelectOp::targetOffsetSize() { return denseElementsSize( getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr())); @@ -1061,10 +1071,10 @@ fir::SelectCaseOp::getCompareOperands(llvm::ArrayRef<mlir::Value> operands, return {getSubOperands(cond, getSubOperands(1, operands, segments), a)}; } -llvm::Optional<mlir::OperandRange> -fir::SelectCaseOp::getSuccessorOperands(unsigned oper) { - auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); - return {getSubOperands(oper, targetArgs(), a)}; +llvm::Optional<mlir::MutableOperandRange> +fir::SelectCaseOp::getMutableSuccessorOperands(unsigned oper) { + return ::getMutableSuccessorOperands(oper, targetArgsMutable(), + getTargetOffsetAttr()); } llvm::Optional<llvm::ArrayRef<mlir::Value>> @@ -1076,8 +1086,6 @@ fir::SelectCaseOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } -bool fir::SelectCaseOp::canEraseSuccessorOperand() { return true; } - // parser for fir.select_case Op static mlir::ParseResult parseSelectCase(mlir::OpAsmParser &parser, mlir::OperationState &result) { @@ -1254,10 +1262,10 @@ fir::SelectRankOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { return {}; } -llvm::Optional<mlir::OperandRange> -fir::SelectRankOp::getSuccessorOperands(unsigned oper) { - auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); - return {getSubOperands(oper, targetArgs(), a)}; +llvm::Optional<mlir::MutableOperandRange> +fir::SelectRankOp::getMutableSuccessorOperands(unsigned oper) { + return ::getMutableSuccessorOperands(oper, targetArgsMutable(), + getTargetOffsetAttr()); } llvm::Optional<llvm::ArrayRef<mlir::Value>> @@ -1269,8 +1277,6 @@ fir::SelectRankOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } -bool fir::SelectRankOp::canEraseSuccessorOperand() { return true; } - unsigned fir::SelectRankOp::targetOffsetSize() { return denseElementsSize( getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr())); @@ -1290,10 +1296,10 @@ fir::SelectTypeOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { return {}; } -llvm::Optional<mlir::OperandRange> -fir::SelectTypeOp::getSuccessorOperands(unsigned oper) { - auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); - return {getSubOperands(oper, targetArgs(), a)}; +llvm::Optional<mlir::MutableOperandRange> +fir::SelectTypeOp::getMutableSuccessorOperands(unsigned oper) { + return ::getMutableSuccessorOperands(oper, targetArgsMutable(), + getTargetOffsetAttr()); } llvm::Optional<llvm::ArrayRef<mlir::Value>> @@ -1305,8 +1311,6 @@ fir::SelectTypeOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } -bool fir::SelectTypeOp::canEraseSuccessorOperand() { return true; } - static ParseResult parseSelectType(OpAsmParser &parser, OperationState &result) { mlir::OpAsmParser::OperandType selector; diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td index 48ed99051642..87f8e629a5c6 100644 --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -1074,7 +1074,7 @@ def CondBranchOp : Std_Op<"cond_br", /// Erase the operand at 'index' from the true operand list. void eraseTrueOperand(unsigned index) { - eraseSuccessorOperand(trueIndex, index); + trueDestOperandsMutable().erase(index); } // Accessors for operands to the 'false' destination. @@ -1093,7 +1093,7 @@ def CondBranchOp : Std_Op<"cond_br", /// Erase the operand at 'index' from the false operand list. void eraseFalseOperand(unsigned index) { - eraseSuccessorOperand(falseIndex, index); + falseDestOperandsMutable().erase(index); } private: diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index 2214b5db2f20..edfe89ad97f2 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -678,6 +678,10 @@ public: ArrayRef<OperandSegment> operandSegments = llvm::None); MutableOperandRange(Operation *owner); + /// Slice this range into a sub range, with the additional operand segment. + MutableOperandRange slice(unsigned subStart, unsigned subLen, + Optional<OperandSegment> segment = llvm::None); + /// Append the given values to the range. void append(ValueRange values); @@ -699,6 +703,9 @@ public: /// Allow implicit conversion to an OperandRange. operator OperandRange() const; + /// Returns the owning operation. + Operation *getOwner() const { return owner; } + private: /// Update the length of this range to the one provided. void updateLength(unsigned newLength); diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h index e22454538343..e18c46f745a2 100644 --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h @@ -24,11 +24,6 @@ class BranchOpInterface; //===----------------------------------------------------------------------===// namespace detail { -/// Erase an operand from a branch operation that is used as a successor -/// operand. `operandIndex` is the operand within `operands` to be erased. -void eraseBranchSuccessorOperand(OperandRange operands, unsigned operandIndex, - Operation *op); - /// Return the `BlockArgument` corresponding to operand `operandIndex` in some /// successor if `operandIndex` is within the range of `operands`, or None if /// `operandIndex` isn't a successor operand index. diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td index 5c02482394b7..591ca11830e9 100644 --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td @@ -27,29 +27,25 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> { }]; let methods = [ InterfaceMethod<[{ - Returns a set of values that correspond to the arguments to the + Returns a mutable range of operands that correspond to the arguments of successor at the given index. Returns None if the operands to the successor are non-materialized values, i.e. they are internal to the operation. }], - "Optional<OperandRange>", "getSuccessorOperands", (ins "unsigned":$index) + "Optional<MutableOperandRange>", "getMutableSuccessorOperands", + (ins "unsigned":$index) >, InterfaceMethod<[{ - Return true if this operation can erase an operand to a successor block. - }], - "bool", "canEraseSuccessorOperand" - >, - InterfaceMethod<[{ - Erase the operand at `operandIndex` from the `index`-th successor. This - should only be called if `canEraseSuccessorOperand` returns true. + Returns a range of operands that correspond to the arguments of + successor at the given index. Returns None if the operands to the + successor are non-materialized values, i.e. they are internal to the + operation. }], - "void", "eraseSuccessorOperand", - (ins "unsigned":$index, "unsigned":$operandIndex), [{}], - /*defaultImplementation=*/[{ + "Optional<OperandRange>", "getSuccessorOperands", + (ins "unsigned":$index), [{}], [{ ConcreteOp *op = static_cast<ConcreteOp *>(this); - Optional<OperandRange> operands = op->getSuccessorOperands(index); - assert(operands && "unable to query operands for successor"); - detail::eraseBranchSuccessorOperand(*operands, operandIndex, *op); + auto operands = op->getMutableSuccessorOperands(index); + return operands ? Optional<OperandRange>(*operands) : llvm::None; }] >, InterfaceMethod<[{ diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 0a462d0239e3..5c112710ec55 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -160,24 +160,22 @@ static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) { // LLVM::BrOp //===----------------------------------------------------------------------===// -Optional<OperandRange> BrOp::getSuccessorOperands(unsigned index) { +Optional<MutableOperandRange> +BrOp::getMutableSuccessorOperands(unsigned index) { assert(index == 0 && "invalid successor index"); - return getOperands(); + return destOperandsMutable(); } -bool BrOp::canEraseSuccessorOperand() { return true; } - //===----------------------------------------------------------------------===// // LLVM::CondBrOp //===----------------------------------------------------------------------===// -Optional<OperandRange> CondBrOp::getSuccessorOperands(unsigned index) { +Optional<MutableOperandRange> +CondBrOp::getMutableSuccessorOperands(unsigned index) { assert(index < getNumSuccessors() && "invalid successor index"); - return index == 0 ? trueDestOperands() : falseDestOperands(); + return index == 0 ? trueDestOperandsMutable() : falseDestOperandsMutable(); } -bool CondBrOp::canEraseSuccessorOperand() { return true; } - //===----------------------------------------------------------------------===// // Printing/parsing for LLVM::LoadOp. //===----------------------------------------------------------------------===// @@ -257,13 +255,12 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) { /// LLVM::InvokeOp ///===---------------------------------------------------------------------===// -Optional<OperandRange> InvokeOp::getSuccessorOperands(unsigned index) { +Optional<MutableOperandRange> +InvokeOp::getMutableSuccessorOperands(unsigned index) { assert(index < getNumSuccessors() && "invalid successor index"); - return index == 0 ? normalDestOperands() : unwindDestOperands(); + return index == 0 ? normalDestOperandsMutable() : unwindDestOperandsMutable(); } -bool InvokeOp::canEraseSuccessorOperand() { return true; } - static LogicalResult verify(InvokeOp op) { if (op.getNumResults() > 1) return op.emitOpError("must have 0 or 1 result"); diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index ed98d3745d6f..5d4e309a2e96 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -987,26 +987,23 @@ static LogicalResult verify(spirv::BitcastOp bitcastOp) { // spv.BranchOp //===----------------------------------------------------------------------===// -Optional<OperandRange> spirv::BranchOp::getSuccessorOperands(unsigned index) { +Optional<MutableOperandRange> +spirv::BranchOp::getMutableSuccessorOperands(unsigned index) { assert(index == 0 && "invalid successor index"); - return getOperands(); + return targetOperandsMutable(); } -bool spirv::BranchOp::canEraseSuccessorOperand() { return true; } - //===----------------------------------------------------------------------===// // spv.BranchConditionalOp //===----------------------------------------------------------------------===// -Optional<OperandRange> -spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) { +Optional<MutableOperandRange> +spirv::BranchConditionalOp::getMutableSuccessorOperands(unsigned index) { assert(index < 2 && "invalid successor index"); - return index == kTrueIndex ? getTrueBlockArguments() - : getFalseBlockArguments(); + return index == kTrueIndex ? trueTargetOperandsMutable() + : falseTargetOperandsMutable(); } -bool spirv::BranchConditionalOp::canEraseSuccessorOperand() { return true; } - static ParseResult parseBranchConditionalOp(OpAsmParser &parser, OperationState &state) { auto &builder = parser.getBuilder(); diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index 85efc4391234..8ef24e239152 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -677,13 +677,12 @@ void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results, context); } -Optional<OperandRange> BranchOp::getSuccessorOperands(unsigned index) { +Optional<MutableOperandRange> +BranchOp::getMutableSuccessorOperands(unsigned index) { assert(index == 0 && "invalid successor index"); - return getOperands(); + return destOperandsMutable(); } -bool BranchOp::canEraseSuccessorOperand() { return true; } - Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) { return dest(); } //===----------------------------------------------------------------------===// @@ -1021,13 +1020,13 @@ void CondBranchOp::getCanonicalizationPatterns( SimplifyCondBranchIdenticalSuccessors>(context); } -Optional<OperandRange> CondBranchOp::getSuccessorOperands(unsigned index) { +Optional<MutableOperandRange> +CondBranchOp::getMutableSuccessorOperands(unsigned index) { assert(index < getNumSuccessors() && "invalid successor index"); - return index == trueIndex ? getTrueOperands() : getFalseOperands(); + return index == trueIndex ? trueDestOperandsMutable() + : falseDestOperandsMutable(); } -bool CondBranchOp::canEraseSuccessorOperand() { return true; } - Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) { if (BoolAttr condAttr = operands.front().dyn_cast_or_null<BoolAttr>()) return condAttr.getValue() ? trueDest() : falseDest(); diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp index 83b4f0bf176e..a08762326143 100644 --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -287,6 +287,18 @@ MutableOperandRange::MutableOperandRange( MutableOperandRange::MutableOperandRange(Operation *owner) : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {} +/// Slice this range into a sub range, with the additional operand segment. +MutableOperandRange +MutableOperandRange::slice(unsigned subStart, unsigned subLen, + Optional<OperandSegment> segment) { + assert((subStart + subLen) <= length && "invalid sub-range"); + MutableOperandRange subSlice(owner, start + subStart, subLen, + operandSegments); + if (segment) + subSlice.operandSegments.push_back(*segment); + return subSlice; +} + /// Append the given values to the range. void MutableOperandRange::append(ValueRange values) { if (values.empty()) diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp index 746dd402a35a..c1fa833f26da 100644 --- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp +++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp @@ -21,39 +21,6 @@ using namespace mlir; // BranchOpInterface //===----------------------------------------------------------------------===// -/// Erase an operand from a branch operation that is used as a successor -/// operand. 'operandIndex' is the operand within 'operands' to be erased. -void mlir::detail::eraseBranchSuccessorOperand(OperandRange operands, - unsigned operandIndex, - Operation *op) { - assert(operandIndex < operands.size() && - "invalid index for successor operands"); - - // Erase the operand from the operation. - size_t fullOperandIndex = operands.getBeginOperandIndex() + operandIndex; - op->eraseOperand(fullOperandIndex); - - // If this operation has an OperandSegmentSizeAttr, keep it up to date. - auto operandSegmentAttr = - op->getAttrOfType<DenseElementsAttr>("operand_segment_sizes"); - if (!operandSegmentAttr) - return; - - // Find the segment containing the full operand index and decrement it. - // TODO: This seems like a general utility that could be added somewhere. - SmallVector<int32_t, 4> values(operandSegmentAttr.getValues<int32_t>()); - unsigned currentSize = 0; - for (unsigned i = 0, e = values.size(); i != e; ++i) { - currentSize += values[i]; - if (fullOperandIndex < currentSize) { - --values[i]; - break; - } - } - op->setAttr("operand_segment_sizes", - DenseIntElementsAttr::get(operandSegmentAttr.getType(), values)); -} - /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some /// successor if 'operandIndex' is within the range of 'operands', or None if /// `operandIndex` isn't a successor operand index. diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp index 162091cd53de..7a00032650b2 100644 --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -209,7 +209,7 @@ static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) { // Check to see if we can reason about the successor operands and mutate them. BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(op); - if (!branchInterface || !branchInterface.canEraseSuccessorOperand()) { + if (!branchInterface) { for (Block *successor : op->getSuccessors()) for (BlockArgument arg : successor->getArguments()) liveMap.setProvedLive(arg); @@ -219,7 +219,7 @@ static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) { // If we can't reason about the operands to a successor, conservatively mark // all arguments as live. for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) { - if (!branchInterface.getSuccessorOperands(i)) + if (!branchInterface.getMutableSuccessorOperands(i)) for (BlockArgument arg : op->getSuccessor(i)->getArguments()) liveMap.setProvedLive(arg); } @@ -278,7 +278,8 @@ static void eraseTerminatorSuccessorOperands(Operation *terminator, // since it will promote later operands of the terminator being erased // first, reducing the quadratic-ness. unsigned succ = succE - succI - 1; - Optional<OperandRange> succOperands = branchOp.getSuccessorOperands(succ); + Optional<MutableOperandRange> succOperands = + branchOp.getMutableSuccessorOperands(succ); if (!succOperands) continue; Block *successor = terminator->getSuccessor(succ); @@ -288,7 +289,7 @@ static void eraseTerminatorSuccessorOperands(Operation *terminator, // shifting later args when earlier args are erased. unsigned arg = argE - argI - 1; if (!liveMap.wasProvenLive(successor->getArgument(arg))) - branchOp.eraseSuccessorOperand(succ, arg); + succOperands->erase(arg); } } } diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index 4c67310e3705..1a40f9989eae 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -167,13 +167,12 @@ TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, // TestBranchOp //===----------------------------------------------------------------------===// -Optional<OperandRange> TestBranchOp::getSuccessorOperands(unsigned index) { +Optional<MutableOperandRange> +TestBranchOp::getMutableSuccessorOperands(unsigned index) { assert(index == 0 && "invalid successor index"); - return getOperands(); + return targetOperandsMutable(); } -bool TestBranchOp::canEraseSuccessorOperand() { return true; } - //===----------------------------------------------------------------------===// // Test IsolatedRegionOp - parse passthrough region arguments. //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp index 12ba8d43c9c1..ae86f713c462 100644 --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -146,7 +146,7 @@ static void emitTraitDecl(OpInterface &interface, raw_ostream &os, StringRef interfaceName, StringRef interfaceTraitsName) { os << " template <typename ConcreteOp>\n " - << llvm::formatv("struct Trait : public OpInterface<{0}," + << llvm::formatv("struct {0}Trait : public OpInterface<{0}," " detail::{1}>::Trait<ConcreteOp> {{\n", interfaceName, interfaceTraitsName); @@ -171,13 +171,17 @@ static void emitTraitDecl(OpInterface &interface, raw_ostream &os, tblgen::FmtContext traitCtx; traitCtx.withOp("op"); if (auto verify = interface.getVerify()) { - os << " static LogicalResult verifyTrait(Operation* op) {\n" + os << " static LogicalResult verifyTrait(Operation* op) {\n" << std::string(tblgen::tgfmt(*verify, &traitCtx)) << "\n }\n"; } if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration()) os << extraTraitDecls << "\n"; os << " };\n"; + + // Emit a utility using directive for the trait class. + os << " template <typename ConcreteOp>\n " + << llvm::formatv("using Trait = {0}Trait<ConcreteOp>;\n", interfaceName); } static void emitInterfaceDecl(OpInterface &interface, raw_ostream &os) { |