diff options
author | Nicolas Vasilache <nicolas.vasilache@gmail.com> | 2020-12-04 13:51:30 +0000 |
---|---|---|
committer | Nicolas Vasilache <nicolas.vasilache@gmail.com> | 2020-12-04 14:00:54 +0000 |
commit | a1cd559ce500d18eb15750ac776e7e73b3819832 (patch) | |
tree | 35da9a7742de81dd5dcc949a6c56064157e3fd3e /mlir | |
parent | 16b1f6e3858b7082ae9f8eea65aff8a04c692099 (diff) | |
download | llvm-project-a1cd559ce500d18eb15750ac776e7e73b3819832.tar.gz |
[mlir][Linalg] Properly use distribution options.
Let tiling to scf.for actually use the distribution method.
For now only Cyclic is supported.
Differential Revision: https://reviews.llvm.org/D92653
Diffstat (limited to 'mlir')
4 files changed, 18 insertions, 7 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index b37a14f0eb7a..90c6a0374e94 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -389,6 +389,11 @@ OwningRewritePatternList getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx); struct LinalgBaseTilingPattern : public RewritePattern { + // Entry point to match any LinalgOp OpInterface. + LinalgBaseTilingPattern(LinalgTilingOptions options, + LinalgMarker marker = LinalgMarker(), + PatternBenefit benefit = 1); + // Entry point to match a specific Linalg op. LinalgBaseTilingPattern(StringRef opName, MLIRContext *context, LinalgTilingOptions options, LinalgMarker marker = LinalgMarker(), diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 97c3dafe57a8..804ae6681f8c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -111,6 +111,11 @@ mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern( : RewritePattern(opName, {}, benefit, context), marker(marker), options(options) {} +mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern( + LinalgTilingOptions options, LinalgMarker marker, PatternBenefit benefit) + : RewritePattern(benefit, MatchAnyOpTypeTag()), marker(marker), + options(options) {} + LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase( Operation *op, PatternRewriter &rewriter, SmallVectorImpl<Value> &tensorResults) const { diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 8e60312bf4fd..f44bb6769e61 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -128,12 +128,12 @@ void GenerateLoopNest<scf::ForOp>::doit( ArrayRef<Attribute> iteratorTypes, function_ref<scf::ValueVector(ValueRange, ValueRange)> bodyBuilderFn, Optional<LinalgLoopDistributionOptions> distributionOptions) { - // Create procInfo so it dominate loops, if appropriate. + // Create procInfo so it dominates loops, if appropriate. OpBuilder &builder = edsc::ScopedContext::getBuilderRef(); Location loc = edsc::ScopedContext::getLocation(); SmallVector<ProcInfo, 2> procInfo; if (distributionOptions.hasValue()) - procInfo = distributionOptions->procInfo(builder, loc, ArrayRef<Range>{}); + procInfo = distributionOptions->procInfo(builder, loc, loopRanges); SmallVector<Value, 4> lbs, ubs, steps; unpackRanges(loopRanges, lbs, ubs, steps); @@ -143,11 +143,12 @@ void GenerateLoopNest<scf::ForOp>::doit( if (!distributionOptions.hasValue() || loopNest.loops.empty()) return; - // TODO: support distributionMethod, which is currently ignored. + // Only supports cyclic distribution for now. for (auto it : llvm::zip(loopNest.loops, procInfo, distributionOptions->distributionMethod)) - mapLoopToProcessorIds(std::get<0>(it), std::get<1>(it).procId, - std::get<1>(it).nprocs); + if (std::get<2>(it) == DistributionMethod::Cyclic) + mapLoopToProcessorIds(std::get<0>(it), std::get<1>(it).procId, + std::get<1>(it).nprocs); } /// Specialization to build affine "for" nest. diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp index 9e3efcf41664..c2b4c7b9c821 100644 --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -415,8 +415,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context, { LinalgLoopDistributionOptions cyclicNprocsEqNiters; - cyclicNprocsEqNiters.distributionMethod.resize( - 2, DistributionMethod::CyclicNumProcsEqNumIters); + cyclicNprocsEqNiters.distributionMethod.resize(2, + DistributionMethod::Cyclic); cyclicNprocsEqNiters.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; patterns.insert<LinalgTilingPattern<MatmulOp>>( |