aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRiver Riddle <riddleriver@gmail.com>2020-04-29 16:09:43 -0700
committerRiver Riddle <riddleriver@gmail.com>2020-04-29 16:48:15 -0700
commit0752d98ccf8771b41718170d46d11f4020b62818 (patch)
tree07bff970dd5e8b175b454c2b87aae774a457329e
parent91dae5708708c0c0b3e2383b419005bfe0402ae0 (diff)
downloadllvm-project-0752d98ccf8771b41718170d46d11f4020b62818.tar.gz
[mlir] Simplify BranchOpInterface by using MutableOperandRange
This range allows for performing many different operations on successor operands, including erasing/adding/setting. This removes the need for the explicit canEraseSuccessorOperand and eraseSuccessorOperand methods. Differential Revision: https://reviews.llvm.org/D79077
-rw-r--r--flang/include/flang/Optimizer/Dialect/FIROps.td1
-rw-r--r--flang/lib/Optimizer/Dialect/FIROps.cpp60
-rw-r--r--mlir/include/mlir/Dialect/StandardOps/IR/Ops.td4
-rw-r--r--mlir/include/mlir/IR/OperationSupport.h7
-rw-r--r--mlir/include/mlir/Interfaces/ControlFlowInterfaces.h5
-rw-r--r--mlir/include/mlir/Interfaces/ControlFlowInterfaces.td26
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp21
-rw-r--r--mlir/lib/Dialect/SPIRV/SPIRVOps.cpp17
-rw-r--r--mlir/lib/Dialect/StandardOps/IR/Ops.cpp15
-rw-r--r--mlir/lib/IR/OperationSupport.cpp12
-rw-r--r--mlir/lib/Interfaces/ControlFlowInterfaces.cpp33
-rw-r--r--mlir/lib/Transforms/Utils/RegionUtils.cpp9
-rw-r--r--mlir/test/lib/Dialect/Test/TestDialect.cpp7
-rw-r--r--mlir/tools/mlir-tblgen/OpInterfacesGen.cpp8
14 files changed, 102 insertions, 123 deletions
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 46c39a1d498b..383256c3916f 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -585,6 +585,7 @@ class fir_SwitchTerminatorOp<string mnemonic, list<OpTrait> traits = []> :
llvm::Optional<llvm::ArrayRef<mlir::Value>> getSuccessorOperands(
llvm::ArrayRef<mlir::Value> operands, unsigned cond);
+ using BranchOpInterfaceTrait::getSuccessorOperands;
// Helper function to deal with Optional operand forms
void printSuccessorAtIndex(mlir::OpAsmPrinter &p, unsigned i) {
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 1dd15fc959be..e2d94885e8fc 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -997,14 +997,26 @@ static constexpr llvm::StringRef getTargetOffsetAttr() {
return "target_operand_offsets";
}
-template <typename A>
+template <typename A, typename... AdditionalArgs>
static A getSubOperands(unsigned pos, A allArgs,
- mlir::DenseIntElementsAttr ranges) {
+ mlir::DenseIntElementsAttr ranges,
+ AdditionalArgs &&... additionalArgs) {
unsigned start = 0;
for (unsigned i = 0; i < pos; ++i)
start += (*(ranges.begin() + i)).getZExtValue();
- unsigned end = start + (*(ranges.begin() + pos)).getZExtValue();
- return {std::next(allArgs.begin(), start), std::next(allArgs.begin(), end)};
+ return allArgs.slice(start, (*(ranges.begin() + pos)).getZExtValue(),
+ std::forward<AdditionalArgs>(additionalArgs)...);
+}
+
+static mlir::MutableOperandRange
+getMutableSuccessorOperands(unsigned pos, mlir::MutableOperandRange operands,
+ StringRef offsetAttr) {
+ Operation *owner = operands.getOwner();
+ NamedAttribute targetOffsetAttr =
+ *owner->getMutableAttrDict().getNamed(offsetAttr);
+ return getSubOperands(
+ pos, operands, targetOffsetAttr.second.cast<DenseIntElementsAttr>(),
+ mlir::MutableOperandRange::OperandSegment(pos, targetOffsetAttr));
}
static unsigned denseElementsSize(mlir::DenseIntElementsAttr attr) {
@@ -1020,10 +1032,10 @@ fir::SelectOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) {
return {};
}
-llvm::Optional<mlir::OperandRange>
-fir::SelectOp::getSuccessorOperands(unsigned oper) {
- auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
- return {getSubOperands(oper, targetArgs(), a)};
+llvm::Optional<mlir::MutableOperandRange>
+fir::SelectOp::getMutableSuccessorOperands(unsigned oper) {
+ return ::getMutableSuccessorOperands(oper, targetArgsMutable(),
+ getTargetOffsetAttr());
}
llvm::Optional<llvm::ArrayRef<mlir::Value>>
@@ -1035,8 +1047,6 @@ fir::SelectOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
}
-bool fir::SelectOp::canEraseSuccessorOperand() { return true; }
-
unsigned fir::SelectOp::targetOffsetSize() {
return denseElementsSize(
getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()));
@@ -1061,10 +1071,10 @@ fir::SelectCaseOp::getCompareOperands(llvm::ArrayRef<mlir::Value> operands,
return {getSubOperands(cond, getSubOperands(1, operands, segments), a)};
}
-llvm::Optional<mlir::OperandRange>
-fir::SelectCaseOp::getSuccessorOperands(unsigned oper) {
- auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
- return {getSubOperands(oper, targetArgs(), a)};
+llvm::Optional<mlir::MutableOperandRange>
+fir::SelectCaseOp::getMutableSuccessorOperands(unsigned oper) {
+ return ::getMutableSuccessorOperands(oper, targetArgsMutable(),
+ getTargetOffsetAttr());
}
llvm::Optional<llvm::ArrayRef<mlir::Value>>
@@ -1076,8 +1086,6 @@ fir::SelectCaseOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
}
-bool fir::SelectCaseOp::canEraseSuccessorOperand() { return true; }
-
// parser for fir.select_case Op
static mlir::ParseResult parseSelectCase(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
@@ -1254,10 +1262,10 @@ fir::SelectRankOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) {
return {};
}
-llvm::Optional<mlir::OperandRange>
-fir::SelectRankOp::getSuccessorOperands(unsigned oper) {
- auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
- return {getSubOperands(oper, targetArgs(), a)};
+llvm::Optional<mlir::MutableOperandRange>
+fir::SelectRankOp::getMutableSuccessorOperands(unsigned oper) {
+ return ::getMutableSuccessorOperands(oper, targetArgsMutable(),
+ getTargetOffsetAttr());
}
llvm::Optional<llvm::ArrayRef<mlir::Value>>
@@ -1269,8 +1277,6 @@ fir::SelectRankOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
}
-bool fir::SelectRankOp::canEraseSuccessorOperand() { return true; }
-
unsigned fir::SelectRankOp::targetOffsetSize() {
return denseElementsSize(
getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()));
@@ -1290,10 +1296,10 @@ fir::SelectTypeOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) {
return {};
}
-llvm::Optional<mlir::OperandRange>
-fir::SelectTypeOp::getSuccessorOperands(unsigned oper) {
- auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
- return {getSubOperands(oper, targetArgs(), a)};
+llvm::Optional<mlir::MutableOperandRange>
+fir::SelectTypeOp::getMutableSuccessorOperands(unsigned oper) {
+ return ::getMutableSuccessorOperands(oper, targetArgsMutable(),
+ getTargetOffsetAttr());
}
llvm::Optional<llvm::ArrayRef<mlir::Value>>
@@ -1305,8 +1311,6 @@ fir::SelectTypeOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
}
-bool fir::SelectTypeOp::canEraseSuccessorOperand() { return true; }
-
static ParseResult parseSelectType(OpAsmParser &parser,
OperationState &result) {
mlir::OpAsmParser::OperandType selector;
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 48ed99051642..87f8e629a5c6 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1074,7 +1074,7 @@ def CondBranchOp : Std_Op<"cond_br",
/// Erase the operand at 'index' from the true operand list.
void eraseTrueOperand(unsigned index) {
- eraseSuccessorOperand(trueIndex, index);
+ trueDestOperandsMutable().erase(index);
}
// Accessors for operands to the 'false' destination.
@@ -1093,7 +1093,7 @@ def CondBranchOp : Std_Op<"cond_br",
/// Erase the operand at 'index' from the false operand list.
void eraseFalseOperand(unsigned index) {
- eraseSuccessorOperand(falseIndex, index);
+ falseDestOperandsMutable().erase(index);
}
private:
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 2214b5db2f20..edfe89ad97f2 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -678,6 +678,10 @@ public:
ArrayRef<OperandSegment> operandSegments = llvm::None);
MutableOperandRange(Operation *owner);
+ /// Slice this range into a sub range, with the additional operand segment.
+ MutableOperandRange slice(unsigned subStart, unsigned subLen,
+ Optional<OperandSegment> segment = llvm::None);
+
/// Append the given values to the range.
void append(ValueRange values);
@@ -699,6 +703,9 @@ public:
/// Allow implicit conversion to an OperandRange.
operator OperandRange() const;
+ /// Returns the owning operation.
+ Operation *getOwner() const { return owner; }
+
private:
/// Update the length of this range to the one provided.
void updateLength(unsigned newLength);
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index e22454538343..e18c46f745a2 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -24,11 +24,6 @@ class BranchOpInterface;
//===----------------------------------------------------------------------===//
namespace detail {
-/// Erase an operand from a branch operation that is used as a successor
-/// operand. `operandIndex` is the operand within `operands` to be erased.
-void eraseBranchSuccessorOperand(OperandRange operands, unsigned operandIndex,
- Operation *op);
-
/// Return the `BlockArgument` corresponding to operand `operandIndex` in some
/// successor if `operandIndex` is within the range of `operands`, or None if
/// `operandIndex` isn't a successor operand index.
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 5c02482394b7..591ca11830e9 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -27,29 +27,25 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
}];
let methods = [
InterfaceMethod<[{
- Returns a set of values that correspond to the arguments to the
+ Returns a mutable range of operands that correspond to the arguments of
successor at the given index. Returns None if the operands to the
successor are non-materialized values, i.e. they are internal to the
operation.
}],
- "Optional<OperandRange>", "getSuccessorOperands", (ins "unsigned":$index)
+ "Optional<MutableOperandRange>", "getMutableSuccessorOperands",
+ (ins "unsigned":$index)
>,
InterfaceMethod<[{
- Return true if this operation can erase an operand to a successor block.
- }],
- "bool", "canEraseSuccessorOperand"
- >,
- InterfaceMethod<[{
- Erase the operand at `operandIndex` from the `index`-th successor. This
- should only be called if `canEraseSuccessorOperand` returns true.
+ Returns a range of operands that correspond to the arguments of
+ successor at the given index. Returns None if the operands to the
+ successor are non-materialized values, i.e. they are internal to the
+ operation.
}],
- "void", "eraseSuccessorOperand",
- (ins "unsigned":$index, "unsigned":$operandIndex), [{}],
- /*defaultImplementation=*/[{
+ "Optional<OperandRange>", "getSuccessorOperands",
+ (ins "unsigned":$index), [{}], [{
ConcreteOp *op = static_cast<ConcreteOp *>(this);
- Optional<OperandRange> operands = op->getSuccessorOperands(index);
- assert(operands && "unable to query operands for successor");
- detail::eraseBranchSuccessorOperand(*operands, operandIndex, *op);
+ auto operands = op->getMutableSuccessorOperands(index);
+ return operands ? Optional<OperandRange>(*operands) : llvm::None;
}]
>,
InterfaceMethod<[{
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 0a462d0239e3..5c112710ec55 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -160,24 +160,22 @@ static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) {
// LLVM::BrOp
//===----------------------------------------------------------------------===//
-Optional<OperandRange> BrOp::getSuccessorOperands(unsigned index) {
+Optional<MutableOperandRange>
+BrOp::getMutableSuccessorOperands(unsigned index) {
assert(index == 0 && "invalid successor index");
- return getOperands();
+ return destOperandsMutable();
}
-bool BrOp::canEraseSuccessorOperand() { return true; }
-
//===----------------------------------------------------------------------===//
// LLVM::CondBrOp
//===----------------------------------------------------------------------===//
-Optional<OperandRange> CondBrOp::getSuccessorOperands(unsigned index) {
+Optional<MutableOperandRange>
+CondBrOp::getMutableSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
- return index == 0 ? trueDestOperands() : falseDestOperands();
+ return index == 0 ? trueDestOperandsMutable() : falseDestOperandsMutable();
}
-bool CondBrOp::canEraseSuccessorOperand() { return true; }
-
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::LoadOp.
//===----------------------------------------------------------------------===//
@@ -257,13 +255,12 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
/// LLVM::InvokeOp
///===---------------------------------------------------------------------===//
-Optional<OperandRange> InvokeOp::getSuccessorOperands(unsigned index) {
+Optional<MutableOperandRange>
+InvokeOp::getMutableSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
- return index == 0 ? normalDestOperands() : unwindDestOperands();
+ return index == 0 ? normalDestOperandsMutable() : unwindDestOperandsMutable();
}
-bool InvokeOp::canEraseSuccessorOperand() { return true; }
-
static LogicalResult verify(InvokeOp op) {
if (op.getNumResults() > 1)
return op.emitOpError("must have 0 or 1 result");
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index ed98d3745d6f..5d4e309a2e96 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -987,26 +987,23 @@ static LogicalResult verify(spirv::BitcastOp bitcastOp) {
// spv.BranchOp
//===----------------------------------------------------------------------===//
-Optional<OperandRange> spirv::BranchOp::getSuccessorOperands(unsigned index) {
+Optional<MutableOperandRange>
+spirv::BranchOp::getMutableSuccessorOperands(unsigned index) {
assert(index == 0 && "invalid successor index");
- return getOperands();
+ return targetOperandsMutable();
}
-bool spirv::BranchOp::canEraseSuccessorOperand() { return true; }
-
//===----------------------------------------------------------------------===//
// spv.BranchConditionalOp
//===----------------------------------------------------------------------===//
-Optional<OperandRange>
-spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) {
+Optional<MutableOperandRange>
+spirv::BranchConditionalOp::getMutableSuccessorOperands(unsigned index) {
assert(index < 2 && "invalid successor index");
- return index == kTrueIndex ? getTrueBlockArguments()
- : getFalseBlockArguments();
+ return index == kTrueIndex ? trueTargetOperandsMutable()
+ : falseTargetOperandsMutable();
}
-bool spirv::BranchConditionalOp::canEraseSuccessorOperand() { return true; }
-
static ParseResult parseBranchConditionalOp(OpAsmParser &parser,
OperationState &state) {
auto &builder = parser.getBuilder();
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 85efc4391234..8ef24e239152 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -677,13 +677,12 @@ void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
context);
}
-Optional<OperandRange> BranchOp::getSuccessorOperands(unsigned index) {
+Optional<MutableOperandRange>
+BranchOp::getMutableSuccessorOperands(unsigned index) {
assert(index == 0 && "invalid successor index");
- return getOperands();
+ return destOperandsMutable();
}
-bool BranchOp::canEraseSuccessorOperand() { return true; }
-
Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) { return dest(); }
//===----------------------------------------------------------------------===//
@@ -1021,13 +1020,13 @@ void CondBranchOp::getCanonicalizationPatterns(
SimplifyCondBranchIdenticalSuccessors>(context);
}
-Optional<OperandRange> CondBranchOp::getSuccessorOperands(unsigned index) {
+Optional<MutableOperandRange>
+CondBranchOp::getMutableSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
- return index == trueIndex ? getTrueOperands() : getFalseOperands();
+ return index == trueIndex ? trueDestOperandsMutable()
+ : falseDestOperandsMutable();
}
-bool CondBranchOp::canEraseSuccessorOperand() { return true; }
-
Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
if (BoolAttr condAttr = operands.front().dyn_cast_or_null<BoolAttr>())
return condAttr.getValue() ? trueDest() : falseDest();
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 83b4f0bf176e..a08762326143 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -287,6 +287,18 @@ MutableOperandRange::MutableOperandRange(
MutableOperandRange::MutableOperandRange(Operation *owner)
: MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
+/// Slice this range into a sub range, with the additional operand segment.
+MutableOperandRange
+MutableOperandRange::slice(unsigned subStart, unsigned subLen,
+ Optional<OperandSegment> segment) {
+ assert((subStart + subLen) <= length && "invalid sub-range");
+ MutableOperandRange subSlice(owner, start + subStart, subLen,
+ operandSegments);
+ if (segment)
+ subSlice.operandSegments.push_back(*segment);
+ return subSlice;
+}
+
/// Append the given values to the range.
void MutableOperandRange::append(ValueRange values) {
if (values.empty())
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 746dd402a35a..c1fa833f26da 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -21,39 +21,6 @@ using namespace mlir;
// BranchOpInterface
//===----------------------------------------------------------------------===//
-/// Erase an operand from a branch operation that is used as a successor
-/// operand. 'operandIndex' is the operand within 'operands' to be erased.
-void mlir::detail::eraseBranchSuccessorOperand(OperandRange operands,
- unsigned operandIndex,
- Operation *op) {
- assert(operandIndex < operands.size() &&
- "invalid index for successor operands");
-
- // Erase the operand from the operation.
- size_t fullOperandIndex = operands.getBeginOperandIndex() + operandIndex;
- op->eraseOperand(fullOperandIndex);
-
- // If this operation has an OperandSegmentSizeAttr, keep it up to date.
- auto operandSegmentAttr =
- op->getAttrOfType<DenseElementsAttr>("operand_segment_sizes");
- if (!operandSegmentAttr)
- return;
-
- // Find the segment containing the full operand index and decrement it.
- // TODO: This seems like a general utility that could be added somewhere.
- SmallVector<int32_t, 4> values(operandSegmentAttr.getValues<int32_t>());
- unsigned currentSize = 0;
- for (unsigned i = 0, e = values.size(); i != e; ++i) {
- currentSize += values[i];
- if (fullOperandIndex < currentSize) {
- --values[i];
- break;
- }
- }
- op->setAttr("operand_segment_sizes",
- DenseIntElementsAttr::get(operandSegmentAttr.getType(), values));
-}
-
/// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
/// successor if 'operandIndex' is within the range of 'operands', or None if
/// `operandIndex` isn't a successor operand index.
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 162091cd53de..7a00032650b2 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -209,7 +209,7 @@ static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
// Check to see if we can reason about the successor operands and mutate them.
BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(op);
- if (!branchInterface || !branchInterface.canEraseSuccessorOperand()) {
+ if (!branchInterface) {
for (Block *successor : op->getSuccessors())
for (BlockArgument arg : successor->getArguments())
liveMap.setProvedLive(arg);
@@ -219,7 +219,7 @@ static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
// If we can't reason about the operands to a successor, conservatively mark
// all arguments as live.
for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
- if (!branchInterface.getSuccessorOperands(i))
+ if (!branchInterface.getMutableSuccessorOperands(i))
for (BlockArgument arg : op->getSuccessor(i)->getArguments())
liveMap.setProvedLive(arg);
}
@@ -278,7 +278,8 @@ static void eraseTerminatorSuccessorOperands(Operation *terminator,
// since it will promote later operands of the terminator being erased
// first, reducing the quadratic-ness.
unsigned succ = succE - succI - 1;
- Optional<OperandRange> succOperands = branchOp.getSuccessorOperands(succ);
+ Optional<MutableOperandRange> succOperands =
+ branchOp.getMutableSuccessorOperands(succ);
if (!succOperands)
continue;
Block *successor = terminator->getSuccessor(succ);
@@ -288,7 +289,7 @@ static void eraseTerminatorSuccessorOperands(Operation *terminator,
// shifting later args when earlier args are erased.
unsigned arg = argE - argI - 1;
if (!liveMap.wasProvenLive(successor->getArgument(arg)))
- branchOp.eraseSuccessorOperand(succ, arg);
+ succOperands->erase(arg);
}
}
}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 4c67310e3705..1a40f9989eae 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -167,13 +167,12 @@ TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
// TestBranchOp
//===----------------------------------------------------------------------===//
-Optional<OperandRange> TestBranchOp::getSuccessorOperands(unsigned index) {
+Optional<MutableOperandRange>
+TestBranchOp::getMutableSuccessorOperands(unsigned index) {
assert(index == 0 && "invalid successor index");
- return getOperands();
+ return targetOperandsMutable();
}
-bool TestBranchOp::canEraseSuccessorOperand() { return true; }
-
//===----------------------------------------------------------------------===//
// Test IsolatedRegionOp - parse passthrough region arguments.
//===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 12ba8d43c9c1..ae86f713c462 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -146,7 +146,7 @@ static void emitTraitDecl(OpInterface &interface, raw_ostream &os,
StringRef interfaceName,
StringRef interfaceTraitsName) {
os << " template <typename ConcreteOp>\n "
- << llvm::formatv("struct Trait : public OpInterface<{0},"
+ << llvm::formatv("struct {0}Trait : public OpInterface<{0},"
" detail::{1}>::Trait<ConcreteOp> {{\n",
interfaceName, interfaceTraitsName);
@@ -171,13 +171,17 @@ static void emitTraitDecl(OpInterface &interface, raw_ostream &os,
tblgen::FmtContext traitCtx;
traitCtx.withOp("op");
if (auto verify = interface.getVerify()) {
- os << " static LogicalResult verifyTrait(Operation* op) {\n"
+ os << " static LogicalResult verifyTrait(Operation* op) {\n"
<< std::string(tblgen::tgfmt(*verify, &traitCtx)) << "\n }\n";
}
if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration())
os << extraTraitDecls << "\n";
os << " };\n";
+
+ // Emit a utility using directive for the trait class.
+ os << " template <typename ConcreteOp>\n "
+ << llvm::formatv("using Trait = {0}Trait<ConcreteOp>;\n", interfaceName);
}
static void emitInterfaceDecl(OpInterface &interface, raw_ostream &os) {