diff options
Diffstat (limited to 'mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td')
-rw-r--r-- | mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td | 11 |
1 files changed, 11 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td index c57ad8c8d17c..43c670a8582e 100644 --- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td +++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td @@ -1,4 +1,5 @@ include "mlir/Dialect/Shape/IR/ShapeOps.td" +include "mlir/Dialect/StandardOps/IR/Ops.td" def AllInputShapesEq : Constraint<CPred< [{ llvm::all_of($0, [&](mlir::Value val) { @@ -6,8 +7,16 @@ def AllInputShapesEq : Constraint<CPred< [{ }) }]>>; +def HasSingleElement : Constraint<CPred< [{ + $0.size() == 1 +}]>>; + // Canonicalization patterns. +def AssumingAllOneOp : Pat<(Shape_AssumingAllOp $args), + (replaceWithValue $args), + [(HasSingleElement $args)]>; + def CstrBroadcastableEqOps : Pat<(Shape_CstrBroadcastableOp:$op $x, $x), (Shape_ConstWitnessOp ConstBoolAttrTrue)>; @@ -23,3 +32,5 @@ def SizeToIndexToSizeCanonicalization : Pat< (Shape_IndexToSizeOp (Shape_SizeToIndexOp $arg)), (replaceWithValue $arg)>; +def TensorCastConstShape : Pat < + (TensorCastOp (Shape_ConstShapeOp:$c $ty)), (replaceWithValue $c)>; |