aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEugene Zhulenev <ezhulenev@google.com>2020-12-08 04:35:27 -0800
committerEugene Zhulenev <ezhulenev@google.com>2020-12-08 10:30:14 -0800
commit94e645f9cce8fba26b4aec069103794f1779065f (patch)
treebf3d284bc8da54804d82bc225d7a143d33eec95d
parent4fede8bc8a015477f2a8feeb30a1d2a2e155106d (diff)
downloadllvm-project-94e645f9cce8fba26b4aec069103794f1779065f.tar.gz
[mlir] Async: Add numWorkerThreads argument to createAsyncParallelForPass
Add an option to pass the number of worker threads to select the number of async regions for parallel for transformation. ``` std::unique_ptr<OperationPass<FuncOp>> createAsyncParallelForPass(int numWorkerThreads); ``` Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D92835
-rw-r--r--mlir/include/mlir/Dialect/Async/Passes.h3
-rw-r--r--mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp9
2 files changed, 12 insertions, 0 deletions
diff --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h
index 9716bde76593..ab5abdc28611 100644
--- a/mlir/include/mlir/Dialect/Async/Passes.h
+++ b/mlir/include/mlir/Dialect/Async/Passes.h
@@ -19,6 +19,9 @@ namespace mlir {
std::unique_ptr<OperationPass<FuncOp>> createAsyncParallelForPass();
+std::unique_ptr<OperationPass<FuncOp>>
+createAsyncParallelForPass(int numWorkerThreads);
+
std::unique_ptr<OperationPass<FuncOp>> createAsyncRefCountingPass();
std::unique_ptr<OperationPass<FuncOp>> createAsyncRefCountingOptimizationPass();
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index c6508610c796..d6553974bc38 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -96,6 +96,10 @@ private:
struct AsyncParallelForPass
: public AsyncParallelForBase<AsyncParallelForPass> {
AsyncParallelForPass() = default;
+ AsyncParallelForPass(int numWorkerThreads) {
+ assert(numWorkerThreads >= 1);
+ numConcurrentAsyncExecute = numWorkerThreads;
+ }
void runOnFunction() override;
};
@@ -276,3 +280,8 @@ void AsyncParallelForPass::runOnFunction() {
std::unique_ptr<OperationPass<FuncOp>> mlir::createAsyncParallelForPass() {
return std::make_unique<AsyncParallelForPass>();
}
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createAsyncParallelForPass(int numWorkerThreads) {
+ return std::make_unique<AsyncParallelForPass>(numWorkerThreads);
+}