author     Benjamin Kramer <benny.kra@googlemail.com>   2020-12-08 15:37:32 +0100
committer  Benjamin Kramer <benny.kra@googlemail.com>   2020-12-08 17:07:24 +0100
commit     5844bc540cafb4330e7625b83371f1dab90528c3 (patch)
tree       46a968b4a538e3e171de972c1fea4d7bbe83e5b1 /mlir
parent     febe75032f6f8322cce1dcbba11a44559aaa14e3 (diff)
[mlir][Shape] Canonicalize assuming_all with one input and tensor_cast of const_shape
This allows simplifying some more complicated shape expressions.

Differential Revision: https://reviews.llvm.org/D92843
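As a sketch of the effect (hypothetical IR, not part of this patch, modeled on the new tests below; "test.source" stands in for an arbitrary witness producer as in the test file):

    // Before: a single-operand assuming_all and a tensor_cast of a
    // const_shape, both of which the new patterns remove.
    %w = "test.source"() : () -> !shape.witness
    %0 = shape.assuming_all %w
    %c = shape.const_shape [2] : tensor<?xindex>
    %1 = tensor_cast %c : tensor<?xindex> to tensor<1xindex>

    // After canonicalization: uses of %0 are replaced by %w, and
    // uses of %1 by %c.
    %w = "test.source"() : () -> !shape.witness
    %c = shape.const_shape [2] : tensor<?xindex>

Note that replacing the tensor_cast hands its users the tensor<?xindex> type of the constant; the shape ops exercised in the tests (e.g. shape.cstr_broadcastable) accept either extent tensor type.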
Diffstat (limited to 'mlir')
-rw-r--r--  mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td       2
-rw-r--r--  mlir/lib/Dialect/Shape/IR/Shape.cpp                 11
-rw-r--r--  mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td  11
-rw-r--r--  mlir/test/Dialect/Shape/canonicalize.mlir           38
4 files changed, 57 insertions, 5 deletions
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 52768e49001d..552de7e78f91 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -105,6 +105,7 @@ def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> {
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
let hasFolder = 1;
+ let hasCanonicalizer = 1;
}
def Shape_ConstSizeOp : Shape_Op<"const_size", [
@@ -630,6 +631,7 @@ def Shape_AssumingAllOp : Shape_Op<"assuming_all", [Commutative, NoSideEffect]>
let assemblyFormat = "$inputs attr-dict";
let hasFolder = 1;
+ let hasCanonicalizer = 1;
let verifier = [{ return ::verify(*this); }];
}
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index fe57f7d7a52e..acb35b916f7e 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -271,6 +271,12 @@ void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
//===----------------------------------------------------------------------===//
// AssumingAllOp
//===----------------------------------------------------------------------===//
+
+void AssumingAllOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *context) {
+ patterns.insert<AssumingAllOneOp>(context);
+}
+
OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
// Iterate in reverse to first handle all constant operands. They are
// guaranteed to be the tail of the inputs because this is commutative.
@@ -394,6 +400,11 @@ static ParseResult parseConstShapeOp(OpAsmParser &parser,
OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
+void ConstShapeOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *context) {
+ patterns.insert<TensorCastConstShape>(context);
+}
+
//===----------------------------------------------------------------------===//
// CstrBroadcastableOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
index c57ad8c8d17c..43c670a8582e 100644
--- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
+++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td
@@ -1,4 +1,5 @@
include "mlir/Dialect/Shape/IR/ShapeOps.td"
+include "mlir/Dialect/StandardOps/IR/Ops.td"
def AllInputShapesEq : Constraint<CPred< [{
llvm::all_of($0, [&](mlir::Value val) {
@@ -6,8 +7,16 @@ def AllInputShapesEq : Constraint<CPred< [{
})
}]>>;
+def HasSingleElement : Constraint<CPred< [{
+ $0.size() == 1
+}]>>;
+
// Canonicalization patterns.
+def AssumingAllOneOp : Pat<(Shape_AssumingAllOp $args),
+ (replaceWithValue $args),
+ [(HasSingleElement $args)]>;
+
def CstrBroadcastableEqOps : Pat<(Shape_CstrBroadcastableOp:$op $x, $x),
(Shape_ConstWitnessOp ConstBoolAttrTrue)>;
@@ -23,3 +32,5 @@ def SizeToIndexToSizeCanonicalization : Pat<
(Shape_IndexToSizeOp (Shape_SizeToIndexOp $arg)),
(replaceWithValue $arg)>;
+def TensorCastConstShape : Pat <
+ (TensorCastOp (Shape_ConstShapeOp:$c $ty)), (replaceWithValue $c)>;
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 56a6ef74f54e..9cb01da75901 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -427,20 +427,23 @@ func @f() {
// -----
-// assuming_all should not be removed if not all witnesses are statically passing.
+// assuming_all should not be removed if more than one witness is not
+// statically passing
//
// Additionally check that the attribute is moved to the end as this op is
// commutative.
// CHECK-LABEL: func @f
func @f() {
- // CHECK-NEXT: %[[UNKNOWN:.*]] = "test.source"
- // CHECK-NEXT: shape.assuming_all %[[UNKNOWN]]
+ // CHECK-NEXT: %[[UNKNOWN1:.*]] = "test.source"
+ // CHECK-NEXT: %[[UNKNOWN2:.*]] = "test.source"
+ // CHECK-NEXT: shape.assuming_all %[[UNKNOWN1]], %[[UNKNOWN2]]
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
%0 = shape.const_witness true
%1 = "test.source"() : () -> !shape.witness
- %2 = shape.assuming_all %0, %1
- "consume.witness"(%2) : (!shape.witness) -> ()
+ %2 = "test.source"() : () -> !shape.witness
+ %3 = shape.assuming_all %0, %1, %2
+ "consume.witness"(%3) : (!shape.witness) -> ()
return
}
@@ -854,3 +857,28 @@ func @fold_to_extent_tensor_on_tensor(%arg: tensor<?xindex>) -> tensor<?xindex>
%casted = shape.to_extent_tensor %arg : tensor<?xindex> -> tensor<?xindex>
return %casted : tensor<?xindex>
}
+
+// -----
+
+// Fold assuming_all with a single input
+// CHECK-LABEL: @fold_assuming_all_single_element
+func @fold_assuming_all_single_element(%arg: tensor<?xindex>) {
+ // CHECK-NOT: assuming_all
+ %0 = "test.source"() : () -> (!shape.witness)
+ %1 = shape.assuming_all %0
+ "consume.witness"(%1) : (!shape.witness) -> ()
+ return
+}
+
+// -----
+
+// Fold tensor_cast of a const_shape to const_shape
+// CHECK-LABEL: @fold_tensor_cast_of_const_shape
+func @fold_tensor_cast_of_const_shape(%arg: tensor<?xindex>) {
+ // CHECK-NOT: tensor_cast
+ %0 = shape.const_shape [2] : tensor<?xindex>
+ %1 = tensor_cast %0 : tensor<?xindex> to tensor<1xindex>
+ %2 = shape.cstr_broadcastable %1, %0 : tensor<1xindex>, tensor<?xindex>
+ "consume.witness"(%2) : (!shape.witness) -> ()
+ return
+}