aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorThomas Raoux <thomasraoux@google.com>2020-12-02 20:45:26 -0800
committerThomas Raoux <thomasraoux@google.com>2020-12-03 15:31:13 -0800
commitc503dc1b8a52946e4daefa1a266e74a102382971 (patch)
treefd0b1800e83b79648abb2fd960f229de91a3e393 /mlir
parentbe162f4c0e8563c8de510121435281ae628c8654 (diff)
downloadllvm-project-c503dc1b8a52946e4daefa1a266e74a102382971.tar.gz
[mlir][linalg] Add vectorization for element-wise linalg ops
Add support for vectorization for linalg.generic representing element-wise ops. Those are converted to transfer_read + vector ops + transfer_write. Also re-organize the vectorization tests to be together. Implementation derived from the work of @burmako, @agrue and @fedelebron. Differential Revision: https://reviews.llvm.org/D92540
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/EDSC/Builders.h1
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp248
-rw-r--r--mlir/lib/EDSC/Builders.cpp3
-rw-r--r--mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir38
-rw-r--r--mlir/test/Dialect/Linalg/transform-patterns.mlir95
-rw-r--r--mlir/test/Dialect/Linalg/vectorization.mlir210
-rw-r--r--mlir/test/lib/Transforms/TestLinalgTransforms.cpp19
7 files changed, 426 insertions, 188 deletions
diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h
index 70c948d99cda..83b6634bf8e2 100644
--- a/mlir/include/mlir/EDSC/Builders.h
+++ b/mlir/include/mlir/EDSC/Builders.h
@@ -30,6 +30,7 @@ namespace edsc {
/// setting and restoring of insertion points.
class ScopedContext {
public:
+ ScopedContext(OpBuilder &b);
ScopedContext(OpBuilder &b, Location location);
/// Sets the insertion point of the builder to 'newInsertPt' for the duration
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 8860674ef847..a28b90b1d95c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -84,6 +84,195 @@ static LogicalResult isContraction(Operation *op) {
hasMultiplyAddBody(genericOp.region()));
}
+static bool hasOnlyScalarElementwiseOp(Region &r) {
+ if (!llvm::hasSingleElement(r))
+ return false;
+ for (Operation &op : r.front()) {
+ if (!(isa<ConstantOp, linalg::YieldOp>(op) ||
+ op.hasTrait<OpTrait::ElementwiseMappable>()) ||
+ llvm::any_of(op.getResultTypes(),
+ [](Type type) { return !type.isIntOrIndexOrFloat(); }))
+ return false;
+ }
+ return true;
+}
+
+// Return true if the op is an element-wise linalg op.
+static bool isElementwise(Operation *op) {
+ auto genericOp = dyn_cast<linalg::GenericOp>(op);
+ if (!genericOp)
+ return false;
+ if (genericOp.getNumLoops() != genericOp.getNumParallelLoops())
+ return false;
+ // TODO: relax the restrictions on indexing map.
+ for (unsigned i = 0, e = genericOp.getNumOutputs(); i < e; i++) {
+ if (!genericOp.getOutputIndexingMap(i).isIdentity())
+ return false;
+ }
+ // Currently limit the input indexing map to minor identity as other
+ // permutations might require adding transpose ops to convert the vector read
+ // to the right shape.
+ for (unsigned i = 0, e = genericOp.getNumInputs(); i < e; i++) {
+ if (!genericOp.getInputIndexingMap(i).isMinorIdentity())
+ return false;
+ }
+ return hasOnlyScalarElementwiseOp(genericOp.getRegion());
+}
+
+static VectorType extractVectorTypeFromScalarView(Value v) {
+ MemRefType mt = v.getType().cast<MemRefType>();
+ return mt.getShape().empty()
+ ? VectorType()
+ : VectorType::get(mt.getShape(), mt.getElementType());
+}
+
+static Value transferReadVector(OpBuilder &builder, Value memref) {
+ edsc::ScopedContext scope(builder);
+ auto memrefType = memref.getType().cast<MemRefType>();
+ if (VectorType vectorType = extractVectorTypeFromScalarView(memref)) {
+ SmallVector<Value, 4> indices(memrefType.getRank(), std_constant_index(0));
+ return vector_transfer_read(vectorType, memref, indices);
+ }
+ return std_load(memref);
+}
+
+static void transferWriteVector(OpBuilder &builder, Value value, Value memref) {
+ edsc::ScopedContext scope(builder);
+ auto memrefType = memref.getType().cast<MemRefType>();
+ if (VectorType vectorType = extractVectorTypeFromScalarView(memref)) {
+ SmallVector<Value, 4> indices(memrefType.getRank(), std_constant_index(0));
+ if (vectorType != value.getType())
+ value = vector_broadcast(vectorType, value);
+ vector_transfer_write(value, memref, indices);
+ } else {
+ std_store(value, memref);
+ }
+}
+
+namespace {
+// Transforms scalar operations into their vectorized counterparts,
+// while using the provided generic op to map:
+// * Its arguments to transfer reads from the views of the generic op.
+// * linalg.yield ops to transfer writes to the views of the generic op.
+class GenericVectorizer {
+public:
+ GenericVectorizer(OpBuilder &builder, linalg::GenericOp generic)
+ : builder(builder), generic(generic) {}
+
+ // Takes a scalar operation and builds its vectorized counterpart or
+ // counterparts using the underlying builder.
+ // If operands of the scalar operation are referring to previously vectorized
+ // operations, then in their vectorized form these operands will be referring
+ // to previous vectorization results.
+ void vectorize(Operation &scalarOp) {
+ auto yieldOp = dyn_cast<linalg::YieldOp>(scalarOp);
+ if (yieldOp) {
+ for (auto outputAndMemref :
+ llvm::zip(yieldOp.values(), generic.getOutputBuffers())) {
+ Value vectorValue = vectorize(std::get<0>(outputAndMemref));
+ transferWriteVector(builder, vectorValue, std::get<1>(outputAndMemref));
+ }
+ return;
+ }
+ Operation *vectorOp = uncachedVectorize(scalarOp);
+ assert(scalarOp.getNumResults() == vectorOp->getNumResults());
+ for (auto result :
+ llvm::zip(scalarOp.getResults(), vectorOp->getResults())) {
+ valueCache[std::get<0>(result)] = std::get<1>(result);
+ }
+ }
+
+private:
+ // Transforms a scalar value into its vectorized counterpart, recursively
+ // vectorizing operations as necessary using the underlying builder.
+ // Keeps track of previously vectorized values and reuses vectorization
+ // results if these values come up again.
+ Value vectorize(Value scalarValue) {
+ // Don't vectorize values coming from outside the region.
+ if (scalarValue.getParentRegion() != &generic.region())
+ return scalarValue;
+ auto vectorValueIt = valueCache.find(scalarValue);
+ if (vectorValueIt != valueCache.end())
+ return vectorValueIt->second;
+
+ // If the value is from the region but not in the cache it means it is a
+ // block argument.
+ auto scalarArg = scalarValue.cast<BlockArgument>();
+ assert(scalarArg.getOwner() == &generic.region().front());
+ Value vector_arg =
+ generic.getInputsAndOutputBuffers()[scalarArg.getArgNumber()];
+ Value vectorResult = transferReadVector(builder, vector_arg);
+ valueCache[scalarArg] = vectorResult;
+ return vectorResult;
+ }
+
+ // Return the largest shape of all the given values. Return an empty
+ // SmallVector if there are no vector value.
+ static SmallVector<int64_t, 4> getLargestShape(ArrayRef<Value> values) {
+ SmallVector<int64_t, 4> largestShape;
+ int64_t maxSize = 1;
+ for (Value value : values) {
+ auto vecType = value.getType().dyn_cast<VectorType>();
+ if (!vecType)
+ continue;
+ if (maxSize < vecType.getNumElements()) {
+ largestShape.assign(vecType.getShape().begin(),
+ vecType.getShape().end());
+ }
+ }
+ return largestShape;
+ }
+
+ // If the value's type doesn't have the given shape broadcast it.
+ Value broadcastIfNeeded(Value value, ArrayRef<int64_t> shape) {
+ auto vecType = value.getType().dyn_cast<VectorType>();
+ if (shape.empty() || (vecType != nullptr && vecType.getShape() == shape))
+ return value;
+ auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
+ : value.getType());
+ return builder.create<vector::BroadcastOp>(
+ builder.getInsertionPoint()->getLoc(), newVecType, value);
+ }
+
+ // Takes a scalar operation and builds its vectorized counterpart or
+ // counterparts using underlying builder without involving any caches.
+ Operation *uncachedVectorize(Operation &base_scalarOp) {
+ SmallVector<Value, 4> vectorizedOperands;
+ for (Value operand : base_scalarOp.getOperands()) {
+ vectorizedOperands.push_back(vectorize(operand));
+ }
+ SmallVector<int64_t, 4> shape = getLargestShape(vectorizedOperands);
+ for (Value &operand : vectorizedOperands)
+ operand = broadcastIfNeeded(operand, shape);
+ OperationState state(base_scalarOp.getLoc(), base_scalarOp.getName());
+ state.addAttributes(base_scalarOp.getAttrs());
+ state.addOperands(vectorizedOperands);
+ if (shape.empty()) {
+ state.addTypes(base_scalarOp.getResultTypes());
+ } else {
+ SmallVector<VectorType, 4> vectorizedTypes;
+ for (auto Type : base_scalarOp.getResultTypes())
+ vectorizedTypes.push_back(VectorType::get(shape, Type));
+ state.addTypes(vectorizedTypes);
+ }
+ return builder.createOperation(state);
+ }
+
+ OpBuilder &builder;
+ linalg::GenericOp generic;
+ llvm::DenseMap<Value, Value> valueCache;
+};
+} // namespace
+
+// Replaces elementwise linalg.generic ops with their bodies with scalar
+// operations from these bodies promoted to vector operations.
+static void vectorizeElementwise(linalg::GenericOp op, OpBuilder &builder) {
+ GenericVectorizer vectorizer(builder, op);
+ for (Operation &scalarOp : op.region().front()) {
+ vectorizer.vectorize(scalarOp);
+ }
+}
+
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
auto linalgOp = cast<linalg::LinalgOp>(op);
// All types must be static shape to go to vector.
@@ -96,7 +285,8 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
if (isa<linalg::FillOp, linalg::CopyOp>(op))
return success();
-
+ if (isElementwise(op))
+ return success();
return isContraction(op);
}
@@ -108,28 +298,11 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
edsc::ScopedContext scope(builder, op->getLoc());
// In the case of 0-D memrefs, return null and special case to scalar load or
// store later.
- auto extractVectorTypeFromScalarView = [](Value v) {
- MemRefType mt = v.getType().cast<MemRefType>();
- return mt.getShape().empty()
- ? VectorType()
- : VectorType::get(mt.getShape(), mt.getElementType());
- };
if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
// Vectorize fill as a vector.broadcast.
LLVM_DEBUG(dbgs() << dbgPref
<< "Rewrite linalg.fill as vector.broadcast: " << *op);
- Value viewOutput = fillOp.output();
- if (VectorType outputType = extractVectorTypeFromScalarView(viewOutput)) {
- auto vecType =
- VectorType::get(fillOp.getOutputBufferType(0).getShape(),
- fillOp.getOutputBufferType(0).getElementType());
- Value vector = vector_broadcast(vecType, fillOp.value());
- Value zero = std_constant_index(0);
- SmallVector<Value, 4> indicesOutput(outputType.getRank(), zero);
- vector_transfer_write(vector, viewOutput, indicesOutput);
- } else {
- std_store(fillOp.value(), viewOutput);
- }
+ transferWriteVector(builder, fillOp.value(), fillOp.output());
return;
}
if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
@@ -138,36 +311,19 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
<< "Rewrite linalg.copy as vector.transfer_read + "
"vector.transfer_write: "
<< *op);
- Value zero = std_constant_index(0);
- Value viewInput = copyOp.input();
- Value viewOutput = copyOp.output();
- Value vector;
- if (VectorType inputType = extractVectorTypeFromScalarView(viewInput)) {
- SmallVector<Value, 4> indicesInput(inputType.getRank(), zero);
- if (copyOp.inputPermutation())
- vector = vector_transfer_read(
- extractVectorTypeFromScalarView(viewInput), viewInput, indicesInput,
- copyOp.inputPermutation().getValue());
- else
- vector =
- vector_transfer_read(extractVectorTypeFromScalarView(viewInput),
- viewInput, indicesInput);
- } else {
- vector = std_load(viewInput).value;
- }
- if (VectorType outputType = extractVectorTypeFromScalarView(viewOutput)) {
- SmallVector<Value, 4> indicesOutput(outputType.getRank(), zero);
- if (copyOp.outputPermutation())
- vector_transfer_write(vector, viewOutput, indicesOutput,
- copyOp.outputPermutation().getValue());
- else
- vector_transfer_write(vector, viewOutput, indicesOutput);
- } else {
- std_store(vector, viewOutput);
- }
+ Value vector = transferReadVector(builder, copyOp.input());
+ transferWriteVector(builder, vector, copyOp.output());
return;
}
+ if (isElementwise(op)) {
+ LLVM_DEBUG(dbgs() << dbgPref
+ << "Rewrite linalg op as vector.transfer_read + "
+ "vector_op + vector.transfer_write: "
+ << *op);
+ return vectorizeElementwise(cast<linalg::GenericOp>(op), builder);
+ }
+
assert(succeeded(isContraction(op)) && "Expected contraction");
// Vectorize other ops as vector contraction.
diff --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp
index 54086c926373..21a6b922d91f 100644
--- a/mlir/lib/EDSC/Builders.cpp
+++ b/mlir/lib/EDSC/Builders.cpp
@@ -15,6 +15,9 @@
using namespace mlir;
using namespace mlir::edsc;
+mlir::edsc::ScopedContext::ScopedContext(OpBuilder &b)
+ : ScopedContext(b, b.getInsertionPoint()->getLoc()) {}
+
mlir::edsc::ScopedContext::ScopedContext(OpBuilder &b, Location location)
: builder(b), guard(builder), location(location),
enclosingScopedContext(ScopedContext::getCurrentScopedContext()) {
diff --git a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
index 155247a53806..dbdf19341920 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
@@ -1,6 +1,5 @@
// RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s
// RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-transform-patterns=test-contraction-to-vector-patterns | FileCheck %s --check-prefix=VECTOR-CONTRACTION
func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
%B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
@@ -26,40 +25,3 @@ func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
// CHECK-SAME: : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>
//
// CHECK: linalg.copy
-
-// VECTOR-CONTRACTION-LABEL: contraction_dot
-func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {
- // VECTOR-CONTRACTION: vector.contract
- // VECTOR-CONTRACTION-SAME: vector<1584xf32>, vector<1584xf32> into f32
- linalg.dot ins(%A, %B: memref<1584xf32>, memref<1584xf32>)
- outs(%C: memref<f32>)
- return
-}
-
-// VECTOR-CONTRACTION-LABEL: contraction_matvec
-func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) {
- // VECTOR-CONTRACTION: vector.contract
- // VECTOR-CONTRACTION-SAME: vector<1584x1584xf32>, vector<1584xf32> into vector<1584xf32>
- linalg.matvec ins(%A, %B: memref<1584x1584xf32>, memref<1584xf32>)
- outs(%C: memref<1584xf32>)
- return
-}
-
-// VECTOR-CONTRACTION-LABEL: contraction_matmul
-func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
- // VECTOR-CONTRACTION: vector.contract
- // VECTOR-CONTRACTION-SAME: vector<1584x1584xf32>, vector<1584x1584xf32> into vector<1584x1584xf32>
- linalg.matmul ins(%A, %B: memref<1584x1584xf32>, memref<1584x1584xf32>)
- outs(%C: memref<1584x1584xf32>)
- return
-}
-
-// VECTOR-CONTRACTION-LABEL: contraction_batch_matmul
-func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) {
- // VECTOR-CONTRACTION: vector.contract
- // VECTOR-CONTRACTION-SAME: vector<1584x1584x1584xf32>, vector<1584x1584x1584xf32> into vector<1584x1584x1584xf32>
- linalg.batch_matmul
- ins(%A, %B: memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>)
- outs(%C: memref<1584x1584x1584xf32>)
- return
-}
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index 9bdc4ad54826..83cb16ba0e3e 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -5,9 +5,7 @@
// CHECK-DAG: #[[$STRIDED_2D_u_1:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// Map corresponding to a 2D memory access where the stride along all dims are unknown.
// CHECK-DAG: #[[$STRIDED_2D:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
-// CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-DAG: #[[$kn:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
-// CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-DAG: #[[$nm:.*]] = affine_map<(d0, d1, d2) -> (d1, d0)>
// CHECK-DAG: #[[$km:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)>
@@ -92,99 +90,6 @@ func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// CHECK: ins({{.*}}, {{.*}}: memref<?x?xf32, #[[$STRIDED_2D]]>, memref<?x?xf32, #[[$STRIDED_2D]]>)
// CHECK: outs({{.*}}: memref<?x?xf32, #[[$STRIDED_2D]]>)
-#matmul_trait = {
- args_in = 2,
- args_out = 1,
- indexing_maps = [
- affine_map<(m, n, k) -> (m, k)>,
- affine_map<(m, n, k) -> (k, n)>,
- affine_map<(m, n, k) -> (m, n)>
- ],
- iterator_types = ["parallel", "parallel", "reduction"],
- __internal_linalg_transform__ = "VECTORIZE"
-}
-func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
- %C: memref<8x32xf32>) {
- linalg.generic #matmul_trait
- ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>)
- outs(%C : memref<8x32xf32>) {
- ^bb(%a: f32, %b: f32, %c: f32) :
- %d = mulf %a, %b: f32
- %e = addf %c, %d: f32
- linalg.yield %e : f32
- }
- return
-}
-// CHECK-LABEL: func @vectorization_test
-// CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32>
-// CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<16x32xf32>
-// CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
-// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
-// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
-
-func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
- %C: memref<8x32xi32>) {
- linalg.generic #matmul_trait
- ins(%A, %B : memref<8x16xi32>, memref<16x32xi32>)
- outs(%C : memref<8x32xi32>) {
- ^bb(%a: i32, %b: i32, %c: i32) :
- %d = muli %a, %b: i32
- %e = addi %c, %d: i32
- linalg.yield %e : i32
- }
- return
-}
-// CHECK-LABEL: func @vectorization_test_integer
-// CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x16xi32>
-// CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<16x32xi32>
-// CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32>
-// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xi32>, vector<16x32xi32> into vector<8x32xi32>
-// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32>
-
-func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
- %C: memref<8x32xf32>) {
- linalg.matmul { __internal_linalg_transform__ = "VECTORIZE"}
- ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>)
- outs(%C: memref<8x32xf32>)
- return
-}
-// CHECK-LABEL: func @vectorization_test_2
-// CHECK: vector.contract {{.*}} :
-// vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
-
-func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
- linalg.fill(%A, %arg0) { __internal_linalg_transform__ = "VECTORIZE"} : memref<8x16xf32>, f32
- return
-}
-// CHECK-LABEL: func @test_vectorize_fill
-// CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>
-// CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
-
-func @test_vectorize_fill_scalar(%A : memref<f32>, %arg0 : f32) {
- linalg.fill(%A, %arg0) { __internal_linalg_transform__ = "VECTORIZE"} : memref<f32>, f32
- return
-}
-// CHECK-LABEL: func @test_vectorize_fill
-// CHECK-SAME: (%[[M:.*]]: memref<f32>, %[[V:.*]]: f32)
-// CHECK: store %[[V]], %[[M]][] : memref<f32>
-
-func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) {
- linalg.copy(%A, %B) { __internal_linalg_transform__ = "VECTORIZE"} : memref<8x16xf32>, memref<8x16xf32>
- return
-}
-// CHECK-LABEL: func @test_vectorize_copy
-// CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32>
-// CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
-
-func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) {
- linalg.copy(%A, %B) { __internal_linalg_transform__ = "VECTORIZE"} : memref<f32>, memref<f32>
- return
-}
-// CHECK-LABEL: func @test_vectorize_copy_scalar
-// CHECK: %[[V:.*]] = load {{.*}} : memref<f32>
-// CHECK: store %[[V]], {{.*}} : memref<f32>
-
-
#matmul_accesses = [
affine_map<(m, n, k) -> (m, k)>,
affine_map<(m, n, k) -> (k, n)>,
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
new file mode 100644
index 000000000000..1c3533275e49
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -0,0 +1,210 @@
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-to-vector-patterns | FileCheck %s
+
+// CHECK-DAG: #[[$mk:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$kn:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$mn:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: contraction_dot
+func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {
+ // CHECK: vector.contract
+ // CHECK-SAME: vector<1584xf32>, vector<1584xf32> into f32
+ linalg.dot ins(%A, %B: memref<1584xf32>, memref<1584xf32>)
+ outs(%C: memref<f32>)
+ return
+}
+
+// CHECK-LABEL: contraction_matvec
+func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) {
+ // CHECK: vector.contract
+ // CHECK-SAME: vector<1584x1584xf32>, vector<1584xf32> into vector<1584xf32>
+ linalg.matvec ins(%A, %B: memref<1584x1584xf32>, memref<1584xf32>)
+ outs(%C: memref<1584xf32>)
+ return
+}
+
+// CHECK-LABEL: contraction_matmul
+func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
+ // CHECK: vector.contract
+ // CHECK-SAME: vector<1584x1584xf32>, vector<1584x1584xf32> into vector<1584x1584xf32>
+ linalg.matmul ins(%A, %B: memref<1584x1584xf32>, memref<1584x1584xf32>)
+ outs(%C: memref<1584x1584xf32>)
+ return
+}
+
+// CHECK-LABEL: contraction_batch_matmul
+func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) {
+ // CHECK: vector.contract
+ // CHECK-SAME: vector<1584x1584x1584xf32>, vector<1584x1584x1584xf32> into vector<1584x1584x1584xf32>
+ linalg.batch_matmul
+ ins(%A, %B: memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>)
+ outs(%C: memref<1584x1584x1584xf32>)
+ return
+}
+
+#matmul_trait = {
+ args_in = 2,
+ args_out = 1,
+ indexing_maps = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>
+ ],
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
+ %C: memref<8x32xf32>) {
+ linalg.generic #matmul_trait
+ ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>)
+ outs(%C : memref<8x32xf32>) {
+ ^bb(%a: f32, %b: f32, %c: f32) :
+ %d = mulf %a, %b: f32
+ %e = addf %c, %d: f32
+ linalg.yield %e : f32
+ }
+ return
+}
+// CHECK-LABEL: func @vectorization_test
+// CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32>
+// CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<16x32xf32>
+// CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
+// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
+// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
+
+func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
+ %C: memref<8x32xi32>) {
+ linalg.generic #matmul_trait
+ ins(%A, %B : memref<8x16xi32>, memref<16x32xi32>)
+ outs(%C : memref<8x32xi32>) {
+ ^bb(%a: i32, %b: i32, %c: i32) :
+ %d = muli %a, %b: i32
+ %e = addi %c, %d: i32
+ linalg.yield %e : i32
+ }
+ return
+}
+// CHECK-LABEL: func @vectorization_test_integer
+// CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x16xi32>
+// CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<16x32xi32>
+// CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32>
+// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xi32>, vector<16x32xi32> into vector<8x32xi32>
+// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32>
+
+func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
+ %C: memref<8x32xf32>) {
+ linalg.matmul
+ ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>)
+ outs(%C: memref<8x32xf32>)
+ return
+}
+// CHECK-LABEL: func @vectorization_test_2
+// CHECK: vector.contract {{.*}} :
+// vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
+
+func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
+ linalg.fill(%A, %arg0) : memref<8x16xf32>, f32
+ return
+}
+// CHECK-LABEL: func @test_vectorize_fill
+// CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>
+// CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
+
+func @test_vectorize_fill_scalar(%A : memref<f32>, %arg0 : f32) {
+ linalg.fill(%A, %arg0) : memref<f32>, f32
+ return
+}
+// CHECK-LABEL: func @test_vectorize_fill
+// CHECK-SAME: (%[[M:.*]]: memref<f32>, %[[V:.*]]: f32)
+// CHECK: store %[[V]], %[[M]][] : memref<f32>
+
+func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) {
+ linalg.copy(%A, %B) : memref<8x16xf32>, memref<8x16xf32>
+ return
+}
+// CHECK-LABEL: func @test_vectorize_copy
+// CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32>
+// CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
+
+func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) {
+ linalg.copy(%A, %B) : memref<f32>, memref<f32>
+ return
+}
+// CHECK-LABEL: func @test_vectorize_copy_scalar
+// CHECK: %[[V:.*]] = load {{.*}} : memref<f32>
+// CHECK: store %[[V]], {{.*}} : memref<f32>
+
+func @generic_vectorize(%arg0: memref<4x256xf32>, %arg1: memref<4x256xf32>,
+ %arg2: memref<256xf32>, %i: f32) {
+ %c1_f32 = constant 1.0 : f32
+ linalg.generic {
+ args_in = 0 : i64,
+ args_out = 10 : i64,
+ indexing_maps = [
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg1, %arg2: memref<4x256xf32>, memref<256xf32>)
+ outs(
+ %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0 :
+ memref<4x256xf32>, memref<4x256xf32>, memref<4x256xf32>, memref<4x256xf32>,
+ memref<4x256xf32>, memref<4x256xf32>, memref<4x256xf32>, memref<4x256xf32>,
+ memref<4x256xf32>, memref<4x256xf32>) {
+ ^bb0(%arg3 : f32, %arg4 : f32, %arg5: f32, %arg6: f32, %arg7: f32, %arg8: f32,
+ %arg9 : f32, %arg10 : f32, %arg11 : f32, %arg12 : f32, %arg13 : f32,
+ %arg14 : f32):
+ %6 = addf %arg4, %arg6 : f32
+ %7 = cmpf "ogt", %arg3, %arg6 : f32
+ %8 = constant 2.0 : f32
+ %9 = divf %arg5, %i : f32
+ %10 = exp2 %arg5 : f32
+ %11 = mulf %arg5, %8 : f32
+ %12 = rsqrt %arg5 : f32
+ %13 = select %7, %arg5, %arg6 : f32
+ %14 = subf %arg5, %arg6 : f32
+ %15 = tanh %arg5 : f32
+ linalg.yield %6, %8, %c1_f32, %9, %10, %11, %12, %13, %14, %15 : f32, f32,
+ f32, f32, f32, f32, f32, f32, f32, f32
+ }
+ return
+}
+
+// CHECK-LABEL: func @generic_vectorize
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x256xf32>, %[[ARG1:.*]]: memref<4x256xf32>,
+// CHECK-SAME: %[[ARG2:.*]]: memref<256xf32>, %[[ARG3:.*]]: f32)
+// CHECK-DAG: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<4x256xf32>
+// CHECK-DAG: %[[CST1:.*]] = constant dense<1.000000e+00> : vector<4x256xf32>
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : memref<256xf32>, vector<256xf32>
+// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
+// CHECK: %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
+// CHECK: %[[ADD:.*]] = addf %[[V0B]], %[[V1]] : vector<4x256xf32>
+// CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
+// CHECK: %[[CMP:.*]] = cmpf "ogt", %[[V2]], %[[V1]] : vector<4x256xf32>
+// CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32>
+// CHECK: %[[ARG3B:.*]] = vector.broadcast %[[ARG3]] : f32 to vector<4x256xf32>
+// CHECK: %[[DIV:.*]] = divf %[[V3]], %[[ARG3B]] : vector<4x256xf32>
+// CHECK: %[[EXP:.*]] = exp2 %[[V3]] : vector<4x256xf32>
+// CHECK: %[[MUL:.*]] = mulf %[[V3]], %[[CST0]] : vector<4x256xf32>
+// CHECK: %[[RSQRT:.*]] = rsqrt %[[V3]] : vector<4x256xf32>
+// CHECK: %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32>
+// CHECK: %[[SUB:.*]] = subf %[[V3]], %[[V1]] : vector<4x256xf32>
+// CHECK: %[[TAN:.*]] = tanh %[[V3]] : vector<4x256xf32>
+// CHECK: vector.transfer_write %[[ADD]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+// CHECK: vector.transfer_write %[[CST0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+// CHECK: vector.transfer_write %[[CST1]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+// CHECK: vector.transfer_write %[[DIV]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+// CHECK: vector.transfer_write %[[EXP]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+// CHECK: vector.transfer_write %[[MUL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+// CHECK: vector.transfer_write %[[RSQRT]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+// CHECK: vector.transfer_write %[[SEL]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+// CHECK: vector.transfer_write %[[SUB]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
+// CHECK: vector.transfer_write %[[TAN]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index 52e96dc44e0b..9e3efcf41664 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -71,7 +71,7 @@ struct TestLinalgTransforms
"Test a fused pass that forwards linalg.copy to vector.transfer"),
llvm::cl::init(false)};
Option<bool> testGenericToVectorPattern{
- *this, "test-contraction-to-vector-patterns",
+ *this, "test-linalg-to-vector-patterns",
llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
"in vector.contract form"),
llvm::cl::init(false)};
@@ -464,14 +464,15 @@ static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
}
-static void applyContractionToVectorPatterns(FuncOp funcOp) {
+static void applyLinalgToVectorPatterns(FuncOp funcOp) {
OwningRewritePatternList patterns;
- patterns.insert<LinalgVectorizationPattern<BatchMatmulOp>,
- LinalgVectorizationPattern<MatmulOp>,
- LinalgVectorizationPattern<MatvecOp>,
- LinalgVectorizationPattern<VecmatOp>,
- LinalgVectorizationPattern<DotOp>,
- LinalgVectorizationPattern<GenericOp>>(funcOp.getContext());
+ patterns.insert<
+ LinalgVectorizationPattern<BatchMatmulOp>,
+ LinalgVectorizationPattern<MatmulOp>,
+ LinalgVectorizationPattern<MatvecOp>,
+ LinalgVectorizationPattern<VecmatOp>, LinalgVectorizationPattern<DotOp>,
+ LinalgVectorizationPattern<FillOp>, LinalgVectorizationPattern<CopyOp>,
+ LinalgVectorizationPattern<GenericOp>>(funcOp.getContext());
applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
@@ -516,7 +517,7 @@ void TestLinalgTransforms::runOnFunction() {
if (testVectorTransferForwardingPatterns)
return applyVectorTransferForwardingPatterns(getFunction());
if (testGenericToVectorPattern)
- return applyContractionToVectorPatterns(getFunction());
+ return applyLinalgToVectorPatterns(getFunction());
if (testAffineMinSCFCanonicalizationPatterns)
return applyAffineMinSCFCanonicalizationPatterns(getFunction());
}