diff options
author | Christian Sigg <csigg@google.com> | 2020-12-04 19:12:02 +0100 |
---|---|---|
committer | Christian Sigg <csigg@google.com> | 2020-12-08 16:44:51 +0100 |
commit | 02c9050155dff70497b3423ae95ed7d2ab7675a8 (patch) | |
tree | 0cfa696dfe4a1056f1987606537a5598bae213ff /mlir | |
parent | 2812c1515627904e31605bbd4f25a887a1f8eb12 (diff) | |
download | llvm-project-02c9050155dff70497b3423ae95ed7d2ab7675a8.tar.gz |
[mlir] Tighten access of RewritePattern methods.
In RewritePattern, only expose `matchAndRewrite` as a public function. `match` can be protected (but needs to be protected because we want to call it from an override of `matchAndRewrite`). `rewrite` can be private.
For classes deriving from RewritePattern, all 3 functions can be private.
Side note: I didn't understand the need for the `using RewritePattern::matchAndRewrite` in derived classes, and started poking around. They are gone now, and I think the result is (only very slightly) cleaner.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D92670
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h | 18 | ||||
-rw-r--r-- | mlir/include/mlir/IR/PatternMatch.h | 28 | ||||
-rw-r--r-- | mlir/include/mlir/Transforms/DialectConversion.h | 61 |
3 files changed, 56 insertions, 51 deletions
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h index bf41f29749de..5b605c165be6 100644 --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -571,11 +571,9 @@ public: &typeConverter.getContext(), typeConverter, benefit) {} - /// Wrappers around the RewritePattern methods that pass the derived op type. - void rewrite(Operation *op, ArrayRef<Value> operands, - ConversionPatternRewriter &rewriter) const final { - rewrite(cast<SourceOp>(op), operands, rewriter); - } +private: + /// Wrappers around the ConversionPattern methods that pass the derived op + /// type. LogicalResult match(Operation *op) const final { return match(cast<SourceOp>(op)); } @@ -584,6 +582,10 @@ public: ConversionPatternRewriter &rewriter) const final { return matchAndRewrite(cast<SourceOp>(op), operands, rewriter); } + void rewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const final { + rewrite(cast<SourceOp>(op), operands, rewriter); + } /// Rewrite and Match methods that operate on the SourceOp type. These must be /// overridden by the derived pattern class. @@ -603,10 +605,6 @@ public: } return failure(); } - -private: - using ConvertToLLVMPattern::match; - using ConvertToLLVMPattern::matchAndRewrite; }; namespace LLVM { @@ -636,6 +634,7 @@ public: using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern; using Super = OneToOneConvertToLLVMPattern<SourceOp, TargetOp>; +private: /// Converts the type of the result to an LLVM type, pass operands as is, /// preserve attributes. LogicalResult @@ -655,6 +654,7 @@ public: using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern; using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>; +private: LogicalResult matchAndRewrite(SourceOp op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override { diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 0bbb2216ee7b..1739cfa4a80c 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -156,17 +156,6 @@ class RewritePattern : public Pattern { public: virtual ~RewritePattern() {} - /// Rewrite the IR rooted at the specified operation with the result of - /// this pattern, generating any new operations with the specified - /// builder. If an unexpected error is encountered (an internal - /// compiler error), it is emitted through the normal MLIR diagnostic - /// hooks and the IR is left in a valid state. - virtual void rewrite(Operation *op, PatternRewriter &rewriter) const; - - /// Attempt to match against code rooted at the specified operation, - /// which is the same operation code as getRootKind(). - virtual LogicalResult match(Operation *op) const; - /// Attempt to match against code rooted at the specified operation, /// which is the same operation code as getRootKind(). If successful, this /// function will automatically perform the rewrite. @@ -183,6 +172,18 @@ protected: /// Inherit the base constructors from `Pattern`. using Pattern::Pattern; + /// Attempt to match against code rooted at the specified operation, + /// which is the same operation code as getRootKind(). + virtual LogicalResult match(Operation *op) const; + +private: + /// Rewrite the IR rooted at the specified operation with the result of + /// this pattern, generating any new operations with the specified + /// builder. If an unexpected error is encountered (an internal + /// compiler error), it is emitted through the normal MLIR diagnostic + /// hooks and the IR is left in a valid state. + virtual void rewrite(Operation *op, PatternRewriter &rewriter) const; + /// An anchor for the virtual table. virtual void anchor(); }; @@ -190,12 +191,15 @@ protected: /// OpRewritePattern is a wrapper around RewritePattern that allows for /// matching and rewriting against an instance of a derived operation class as /// opposed to a raw Operation. -template <typename SourceOp> struct OpRewritePattern : public RewritePattern { +template <typename SourceOp> +class OpRewritePattern : public RewritePattern { +public: /// Patterns must specify the root operation name they match against, and can /// also specify the benefit of the pattern matching. OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1) : RewritePattern(SourceOp::getOperationName(), benefit, context) {} +private: /// Wrappers around the RewritePattern methods that pass the derived op type. void rewrite(Operation *op, PatternRewriter &rewriter) const final { rewrite(cast<SourceOp>(op), rewriter); diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index e02cf8fe4c0a..ecbb653f7ed9 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -313,6 +313,30 @@ private: /// patterns of this type can only be used with the 'apply*' methods below. class ConversionPattern : public RewritePattern { public: + /// Return the type converter held by this pattern, or nullptr if the pattern + /// does not require type conversion. + TypeConverter *getTypeConverter() const { return typeConverter; } + +protected: + /// See `RewritePattern::RewritePattern` for information on the other + /// available constructors. + using RewritePattern::RewritePattern; + /// Construct a conversion pattern that matches an operation with the given + /// root name. This constructor allows for providing a type converter to use + /// within the pattern. + ConversionPattern(StringRef rootName, PatternBenefit benefit, + TypeConverter &typeConverter, MLIRContext *ctx) + : RewritePattern(rootName, benefit, ctx), typeConverter(&typeConverter) {} + /// Construct a conversion pattern that matches any operation type. This + /// constructor allows for providing a type converter to use within the + /// pattern. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any" + /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should + /// always be supplied here. + ConversionPattern(PatternBenefit benefit, TypeConverter &typeConverter, + MatchAnyOpTypeTag tag) + : RewritePattern(benefit, tag), typeConverter(&typeConverter) {} + +private: /// Hook for derived classes to implement rewriting. `op` is the (first) /// operation matched by the pattern, `operands` is a list of the rewritten /// operand values that are passed to `op`, `rewriter` can be used to emit the @@ -323,6 +347,10 @@ public: llvm_unreachable("unimplemented rewrite"); } + void rewrite(Operation *op, PatternRewriter &rewriter) const final { + llvm_unreachable("never called"); + } + /// Hook for derived classes to implement combined matching and rewriting. virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, @@ -337,42 +365,17 @@ public: LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final; - /// Return the type converter held by this pattern, or nullptr if the pattern - /// does not require type conversion. - TypeConverter *getTypeConverter() const { return typeConverter; } - -protected: - /// See `RewritePattern::RewritePattern` for information on the other - /// available constructors. - using RewritePattern::RewritePattern; - /// Construct a conversion pattern that matches an operation with the given - /// root name. This constructor allows for providing a type converter to use - /// within the pattern. - ConversionPattern(StringRef rootName, PatternBenefit benefit, - TypeConverter &typeConverter, MLIRContext *ctx) - : RewritePattern(rootName, benefit, ctx), typeConverter(&typeConverter) {} - /// Construct a conversion pattern that matches any operation type. This - /// constructor allows for providing a type converter to use within the - /// pattern. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any" - /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should - /// always be supplied here. - ConversionPattern(PatternBenefit benefit, TypeConverter &typeConverter, - MatchAnyOpTypeTag tag) - : RewritePattern(benefit, tag), typeConverter(&typeConverter) {} - protected: /// An optional type converter for use by this pattern. TypeConverter *typeConverter = nullptr; - -private: - using RewritePattern::rewrite; }; /// OpConversionPattern is a wrapper around ConversionPattern that allows for /// matching and rewriting against an instance of a derived operation class as /// opposed to a raw Operation. template <typename SourceOp> -struct OpConversionPattern : public ConversionPattern { +class OpConversionPattern : public ConversionPattern { +public: OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1) : ConversionPattern(SourceOp::getOperationName(), benefit, context) {} OpConversionPattern(TypeConverter &typeConverter, MLIRContext *context, @@ -380,6 +383,7 @@ struct OpConversionPattern : public ConversionPattern { : ConversionPattern(SourceOp::getOperationName(), benefit, typeConverter, context) {} +private: /// Wrappers around the ConversionPattern methods that pass the derived op /// type. void rewrite(Operation *op, ArrayRef<Value> operands, @@ -409,9 +413,6 @@ struct OpConversionPattern : public ConversionPattern { rewrite(op, operands, rewriter); return success(); } - -private: - using ConversionPattern::matchAndRewrite; }; /// Add a pattern to the given pattern list to convert the signature of a FuncOp |