aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorNicolas Vasilache <nicolas.vasilache@gmail.com>2020-12-04 13:51:30 +0000
committerNicolas Vasilache <nicolas.vasilache@gmail.com>2020-12-04 14:00:54 +0000
commita1cd559ce500d18eb15750ac776e7e73b3819832 (patch)
tree35da9a7742de81dd5dcc949a6c56064157e3fd3e /mlir
parent16b1f6e3858b7082ae9f8eea65aff8a04c692099 (diff)
downloadllvm-project-a1cd559ce500d18eb15750ac776e7e73b3819832.tar.gz
[mlir][Linalg] Properly use distribution options.
Let tiling to scf.for actually use the distribution method. For now only Cyclic is supported. Differential Revision: https://reviews.llvm.org/D92653
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h5
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp5
-rw-r--r--mlir/lib/Dialect/Linalg/Utils/Utils.cpp11
-rw-r--r--mlir/test/lib/Transforms/TestLinalgTransforms.cpp4
4 files changed, 18 insertions, 7 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index b37a14f0eb7a..90c6a0374e94 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -389,6 +389,11 @@ OwningRewritePatternList
getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx);
struct LinalgBaseTilingPattern : public RewritePattern {
+ // Entry point to match any LinalgOp OpInterface.
+ LinalgBaseTilingPattern(LinalgTilingOptions options,
+ LinalgMarker marker = LinalgMarker(),
+ PatternBenefit benefit = 1);
+ // Entry point to match a specific Linalg op.
LinalgBaseTilingPattern(StringRef opName, MLIRContext *context,
LinalgTilingOptions options,
LinalgMarker marker = LinalgMarker(),
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 97c3dafe57a8..804ae6681f8c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -111,6 +111,11 @@ mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
: RewritePattern(opName, {}, benefit, context), marker(marker),
options(options) {}
+mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
+ LinalgTilingOptions options, LinalgMarker marker, PatternBenefit benefit)
+ : RewritePattern(benefit, MatchAnyOpTypeTag()), marker(marker),
+ options(options) {}
+
LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
Operation *op, PatternRewriter &rewriter,
SmallVectorImpl<Value> &tensorResults) const {
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 8e60312bf4fd..f44bb6769e61 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -128,12 +128,12 @@ void GenerateLoopNest<scf::ForOp>::doit(
ArrayRef<Attribute> iteratorTypes,
function_ref<scf::ValueVector(ValueRange, ValueRange)> bodyBuilderFn,
Optional<LinalgLoopDistributionOptions> distributionOptions) {
- // Create procInfo so it dominate loops, if appropriate.
+ // Create procInfo so it dominates loops, if appropriate.
OpBuilder &builder = edsc::ScopedContext::getBuilderRef();
Location loc = edsc::ScopedContext::getLocation();
SmallVector<ProcInfo, 2> procInfo;
if (distributionOptions.hasValue())
- procInfo = distributionOptions->procInfo(builder, loc, ArrayRef<Range>{});
+ procInfo = distributionOptions->procInfo(builder, loc, loopRanges);
SmallVector<Value, 4> lbs, ubs, steps;
unpackRanges(loopRanges, lbs, ubs, steps);
@@ -143,11 +143,12 @@ void GenerateLoopNest<scf::ForOp>::doit(
if (!distributionOptions.hasValue() || loopNest.loops.empty())
return;
- // TODO: support distributionMethod, which is currently ignored.
+ // Only supports cyclic distribution for now.
for (auto it : llvm::zip(loopNest.loops, procInfo,
distributionOptions->distributionMethod))
- mapLoopToProcessorIds(std::get<0>(it), std::get<1>(it).procId,
- std::get<1>(it).nprocs);
+ if (std::get<2>(it) == DistributionMethod::Cyclic)
+ mapLoopToProcessorIds(std::get<0>(it), std::get<1>(it).procId,
+ std::get<1>(it).nprocs);
}
/// Specialization to build affine "for" nest.
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index 9e3efcf41664..c2b4c7b9c821 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -415,8 +415,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
{
LinalgLoopDistributionOptions cyclicNprocsEqNiters;
- cyclicNprocsEqNiters.distributionMethod.resize(
- 2, DistributionMethod::CyclicNumProcsEqNumIters);
+ cyclicNprocsEqNiters.distributionMethod.resize(2,
+ DistributionMethod::Cyclic);
cyclicNprocsEqNiters.procInfo =
getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
patterns.insert<LinalgTilingPattern<MatmulOp>>(