diff options
author | Thomas Raoux <thomasraoux@google.com> | 2020-12-03 23:44:36 -0800 |
---|---|---|
committer | Thomas Raoux <thomasraoux@google.com> | 2020-12-04 09:53:01 -0800 |
commit | 3e3e276d22ca6917a721c4173b00b37850d8020c (patch) | |
tree | 6778793fecb5c8aef8e4f6c6c441eb29a556cdd4 /mlir/test/lib/Transforms/TestVectorTransforms.cpp | |
parent | 840e651dc6d7fe652667eb8b4d04ef4daf4769df (diff) | |
download | llvm-project-3e3e276d22ca6917a721c4173b00b37850d8020c.tar.gz |
[mlir][vector][NFC] Change UnrollVectorPattern to not be statically dependent on an op type
Make UnrollVectorPattern inherit from RewritePattern instead of
OpRewritePattern so that we don't need to create many patterns when applying to
many different type of ops. Since we may want to apply the pattern to all
arithmetic op, it is more convenient to filter dynamically.
Differential Revision: https://reviews.llvm.org/D92635
Diffstat (limited to 'mlir/test/lib/Transforms/TestVectorTransforms.cpp')
-rw-r--r-- | mlir/test/lib/Transforms/TestVectorTransforms.cpp | 52 |
1 files changed, 37 insertions, 15 deletions
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp index 602bf8148cd8..99c336ef0565 100644 --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -27,14 +27,22 @@ struct TestVectorToVectorConversion void runOnFunction() override { OwningRewritePatternList patterns; auto *ctx = &getContext(); - patterns.insert<UnrollVectorPattern<AddFOp>>( - ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2})); - patterns.insert<UnrollVectorPattern<vector::ContractionOp>>( - ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2, 2})); + patterns.insert<UnrollVectorPattern>( + ctx, UnrollVectorOptions().setNativeShapeFn(getShape)); populateVectorToVectorCanonicalizationPatterns(patterns, ctx); populateVectorToVectorTransformationPatterns(patterns, ctx); applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } + +private: + // Return the target shape based on op type. + static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) { + if (isa<AddFOp>(op)) + return SmallVector<int64_t, 4>(2, 2); + if (isa<vector::ContractionOp>(op)) + return SmallVector<int64_t, 4>(3, 2); + return llvm::None; + } }; struct TestVectorSlicesConversion @@ -120,8 +128,11 @@ struct TestVectorUnrollingPatterns void runOnFunction() override { MLIRContext *ctx = &getContext(); OwningRewritePatternList patterns; - patterns.insert<UnrollVectorPattern<AddFOp>>( - ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2})); + patterns.insert<UnrollVectorPattern>( + ctx, UnrollVectorOptions() + .setNativeShape(ArrayRef<int64_t>{2, 2}) + .setFilterConstraint( + [](Operation *op) { return success(isa<AddFOp>(op)); })); if (unrollBasedOnType) { UnrollVectorOptions::NativeShapeFnType nativeShapeFn = @@ -137,12 +148,19 @@ struct TestVectorUnrollingPatterns } return nativeShape; }; - patterns.insert<UnrollVectorPattern<vector::ContractionOp>>( - ctx, UnrollVectorOptions().setNativeShapeFn(nativeShapeFn)); + patterns.insert<UnrollVectorPattern>( + ctx, UnrollVectorOptions() + .setNativeShapeFn(nativeShapeFn) + .setFilterConstraint([](Operation *op) { + return success(isa<ContractionOp>(op)); + })); } else { - patterns.insert<UnrollVectorPattern<vector::ContractionOp>>( - ctx, - UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2, 2})); + patterns.insert<UnrollVectorPattern>( + ctx, UnrollVectorOptions() + .setNativeShape(ArrayRef<int64_t>{2, 2, 2}) + .setFilterConstraint([](Operation *op) { + return success(isa<ContractionOp>(op)); + })); } populateVectorToVectorCanonicalizationPatterns(patterns, ctx); populateVectorToVectorTransformationPatterns(patterns, ctx); @@ -273,10 +291,14 @@ struct TestVectorTransferUnrollingPatterns void runOnFunction() override { MLIRContext *ctx = &getContext(); OwningRewritePatternList patterns; - patterns.insert<UnrollVectorPattern<vector::TransferReadOp>>( - ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2})); - patterns.insert<UnrollVectorPattern<vector::TransferWriteOp>>( - ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2})); + patterns.insert<UnrollVectorPattern>( + ctx, + UnrollVectorOptions() + .setNativeShape(ArrayRef<int64_t>{2, 2}) + .setFilterConstraint([](Operation *op) { + return success( + isa<vector::TransferReadOp, vector::TransferWriteOp>(op)); + })); populateVectorToVectorCanonicalizationPatterns(patterns, ctx); populateVectorToVectorTransformationPatterns(patterns, ctx); applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); |