aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTres Popp <tpopp@google.com>2020-12-04 22:13:14 +0100
committerTres Popp <tpopp@google.com>2020-12-08 17:30:01 +0100
commit111ae220a3bff944e10a0760ce344630f4efc40d (patch)
treeb5399ded45e38b981a5f8baf1ba298427e94553a
parentc0428b3c0c1f3b78d39ceaf909908800fb7aabe3 (diff)
downloadllvm-project-111ae220a3bff944e10a0760ce344630f4efc40d.tar.gz
[mlir] Use rewriting infrastructure in AsyncToLLVM
This is needed so a listener hears all changes during the dialect conversion to allow correct rollbacks upon failure. Differential Revision: https://reviews.llvm.org/D92685
-rw-r--r--mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp44
1 files changed, 26 insertions, 18 deletions
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index c36cde1054ed..361bfa2b6fad 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -386,8 +386,10 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
suspendBlock};
}
-// Adds a suspension point before the `op`, and moves `op` and all operations
-// after it into the resume block. Returns a pointer to the resume block.
+// Add a LLVM coroutine suspension point to the end of suspended block, to
+// resume execution in resume block. The caller is responsible for creating the
+// two suspended/resume blocks with the desired ops contained in each block.
+// This function merely provides the required control flow logic.
//
// `coroState` must be a value returned from the call to @llvm.coro.save(...)
// intrinsic (saved coroutine state).
@@ -399,6 +401,8 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
// "op"(...)
// ^cleanup: ...
// ^suspend: ...
+// ^resume:
+// "op"(...)
//
// After:
//
@@ -411,20 +415,17 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
// ^cleanup: ...
// ^suspend: ...
//
-static Block *addSuspensionPoint(CoroMachinery coro, Value coroState,
- Operation *op) {
+static void addSuspensionPoint(CoroMachinery coro, Value coroState,
+ Operation *op, Block *suspended, Block *resume,
+ OpBuilder &builder) {
+ Location loc = op->getLoc();
MLIRContext *ctx = op->getContext();
auto i1 = LLVM::LLVMType::getInt1Ty(ctx);
auto i8 = LLVM::LLVMType::getInt8Ty(ctx);
- Location loc = op->getLoc();
- Block *splitBlock = op->getBlock();
-
- // Split the block before `op`, newly added block is the resume block.
- Block *resume = splitBlock->splitBlock(op);
-
// Add a coroutine suspension in place of original `op` in the split block.
- OpBuilder builder = OpBuilder::atBlockEnd(splitBlock);
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointToEnd(suspended);
auto constFalse =
builder.create<LLVM::ConstantOp>(loc, i1, builder.getBoolAttr(false));
@@ -445,7 +446,7 @@ static Block *addSuspensionPoint(CoroMachinery coro, Value coroState,
Block *resumeOrCleanup = builder.createBlock(resume);
// Suspend the coroutine ...?
- builder.setInsertionPointToEnd(splitBlock);
+ builder.setInsertionPointToEnd(suspended);
auto isNegOne = builder.create<LLVM::ICmpOp>(
loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constNegOne);
builder.create<LLVM::CondBrOp>(loc, isNegOne, /*trueDest=*/coro.suspend,
@@ -457,12 +458,12 @@ static Block *addSuspensionPoint(CoroMachinery coro, Value coroState,
loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constZero);
builder.create<LLVM::CondBrOp>(loc, isZero, /*trueDest=*/resume,
/*falseDest=*/coro.cleanup);
-
- return resume;
}
// Outline the body region attached to the `async.execute` op into a standalone
// function.
+//
+// Note that this is not reversible transformation.
static std::pair<FuncOp, CoroMachinery>
outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
ModuleOp module = execute.getParentOfType<ModuleOp>();
@@ -518,8 +519,11 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
builder.create<CallOp>(loc, TypeRange(), kExecute, executeArgs);
// Split the entry block before the terminator.
- Block *resume = addSuspensionPoint(coro, coroSave.getResult(0),
- entryBlock->getTerminator());
+ auto *terminatorOp = entryBlock->getTerminator();
+ Block *suspended = terminatorOp->getBlock();
+ Block *resume = suspended->splitBlock(terminatorOp);
+ addSuspensionPoint(coro, coroSave.getResult(0), terminatorOp, suspended,
+ resume, builder);
// Await on all dependencies before starting to execute the body region.
builder.setInsertionPointToStart(resume);
@@ -740,7 +744,7 @@ public:
if (isInCoroutine) {
const CoroMachinery &coro = outlined->getSecond();
- OpBuilder builder(op);
+ OpBuilder builder(op, rewriter.getListener());
MLIRContext *ctx = op->getContext();
// A pointer to coroutine resume intrinsic wrapper.
@@ -760,8 +764,12 @@ public:
builder.create<CallOp>(loc, TypeRange(), coroAwaitFuncName,
awaitAndExecuteArgs);
+ Block *suspended = op->getBlock();
+
// Split the entry block before the await operation.
- addSuspensionPoint(coro, coroSave.getResult(0), op);
+ Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
+ addSuspensionPoint(coro, coroSave.getResult(0), op, suspended, resume,
+ builder);
}
// Original operation was replaced by function call or suspension point.