diff options
author | Eugene Zhulenev <ezhulenev@google.com> | 2020-12-08 04:35:27 -0800 |
---|---|---|
committer | Eugene Zhulenev <ezhulenev@google.com> | 2020-12-08 10:30:14 -0800 |
commit | 94e645f9cce8fba26b4aec069103794f1779065f (patch) | |
tree | bf3d284bc8da54804d82bc225d7a143d33eec95d | |
parent | 4fede8bc8a015477f2a8feeb30a1d2a2e155106d (diff) | |
download | llvm-project-94e645f9cce8fba26b4aec069103794f1779065f.tar.gz |
[mlir] Async: Add numWorkerThreads argument to createAsyncParallelForPass
Add an option to pass the number of worker threads to select the number of async regions for parallel for transformation.
```
std::unique_ptr<OperationPass<FuncOp>> createAsyncParallelForPass(int numWorkerThreads);
```
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D92835
-rw-r--r-- | mlir/include/mlir/Dialect/Async/Passes.h | 3 | ||||
-rw-r--r-- | mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp | 9 |
2 files changed, 12 insertions, 0 deletions
diff --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h index 9716bde76593..ab5abdc28611 100644 --- a/mlir/include/mlir/Dialect/Async/Passes.h +++ b/mlir/include/mlir/Dialect/Async/Passes.h @@ -19,6 +19,9 @@ namespace mlir { std::unique_ptr<OperationPass<FuncOp>> createAsyncParallelForPass(); +std::unique_ptr<OperationPass<FuncOp>> +createAsyncParallelForPass(int numWorkerThreads); + std::unique_ptr<OperationPass<FuncOp>> createAsyncRefCountingPass(); std::unique_ptr<OperationPass<FuncOp>> createAsyncRefCountingOptimizationPass(); diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp index c6508610c796..d6553974bc38 100644 --- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp @@ -96,6 +96,10 @@ private: struct AsyncParallelForPass : public AsyncParallelForBase<AsyncParallelForPass> { AsyncParallelForPass() = default; + AsyncParallelForPass(int numWorkerThreads) { + assert(numWorkerThreads >= 1); + numConcurrentAsyncExecute = numWorkerThreads; + } void runOnFunction() override; }; @@ -276,3 +280,8 @@ void AsyncParallelForPass::runOnFunction() { std::unique_ptr<OperationPass<FuncOp>> mlir::createAsyncParallelForPass() { return std::make_unique<AsyncParallelForPass>(); } + +std::unique_ptr<OperationPass<FuncOp>> +mlir::createAsyncParallelForPass(int numWorkerThreads) { + return std::make_unique<AsyncParallelForPass>(numWorkerThreads); +} |