aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td')
-rw-r--r--mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td11
1 files changed, 11 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
index c57ad8c8d17c..43c670a8582e 100644
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -1,4 +1,5 @@
include "mlir/Dialect/Shape/IR/ShapeOps.td"
+include "mlir/Dialect/StandardOps/IR/Ops.td"
def AllInputShapesEq : Constraint<CPred< [{
llvm::all_of($0, [&](mlir::Value val) {
@@ -6,8 +7,16 @@ def AllInputShapesEq : Constraint<CPred< [{
})
}]>>;
+def HasSingleElement : Constraint<CPred< [{
+ $0.size() == 1
+}]>>;
+
// Canonicalization patterns.
+def AssumingAllOneOp : Pat<(Shape_AssumingAllOp $args),
+ (replaceWithValue $args),
+ [(HasSingleElement $args)]>;
+
def CstrBroadcastableEqOps : Pat<(Shape_CstrBroadcastableOp:$op $x, $x),
(Shape_ConstWitnessOp ConstBoolAttrTrue)>;
@@ -23,3 +32,5 @@ def SizeToIndexToSizeCanonicalization : Pat<
(Shape_IndexToSizeOp (Shape_SizeToIndexOp $arg)),
(replaceWithValue $arg)>;
+def TensorCastConstShape : Pat <
+ (TensorCastOp (Shape_ConstShapeOp:$c $ty)), (replaceWithValue $c)>;