aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
blob: 155247a53806935b3ecb7372ff9bfe1cd150f270 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
// RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s
// RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s
// RUN: mlir-opt %s -test-linalg-transform-patterns=test-contraction-to-vector-patterns | FileCheck %s --check-prefix=VECTOR-CONTRACTION

// Input for the two tiling runs (tile-1d and tile-2d): one 1584x1584 f32
// matmul over strided memrefs. The `__internal_linalg_transform__ = "START"`
// attribute presumably marks this op as the starting point for the test
// patterns (confirm against the test pass).
func @matmul(%lhs: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
             %rhs: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
             %out: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) {
  linalg.matmul {__internal_linalg_transform__ = "START"}
    ins(%lhs, %rhs: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
                    memref<1584x1584xf32, offset: 0, strides: [1584, 1]>)
    outs(%out: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>)
  return
}

// Expected output of both tiling runs, matched in order:
//   1. three vector.transfer_write ops into small static buffers whose shapes
//      (8x16, 16x12, 8x12) match the lhs, rhs and accumulator of the
//      contraction below;
//   2. three linalg.copy ops;
//   3. a vector.contract computing 8x16 * 16x12 -> 8x12, whose last
//      iterator is the reduction;
//   4. a final linalg.copy.
// The vector shapes imply the tile sizes chosen by the test patterns
// (presumably 8x12 output tiles with a 16-wide reduction step; verify).
// CHECK-LABEL:func @matmul
//      CHECK: vector.transfer_write {{.*}} : vector<8x16xf32>, memref<8x16xf32>
//      CHECK: vector.transfer_write {{.*}} : vector<16x12xf32>, memref<16x12xf32>
//      CHECK: vector.transfer_write {{.*}} : vector<8x12xf32>, memref<8x12xf32>
//
//      CHECK: linalg.copy
//      CHECK: linalg.copy
//      CHECK: linalg.copy
//
//      CHECK: vector.contract
// CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"]
// CHECK-SAME:   : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>
//
//      CHECK: linalg.copy

// Contraction-to-vector run: a 1-D dot product is expected to produce a
// vector.contract over 1584-element vectors that reduces into a scalar f32.
// VECTOR-CONTRACTION-LABEL: contraction_dot
// VECTOR-CONTRACTION: vector.contract
// VECTOR-CONTRACTION-SAME: vector<1584xf32>, vector<1584xf32> into f32
func @contraction_dot(%x: memref<1584xf32>, %y: memref<1584xf32>, %acc: memref<f32>) {
  linalg.dot ins(%x, %y: memref<1584xf32>, memref<1584xf32>)
             outs(%acc: memref<f32>)
  return
}

// Contraction-to-vector run: a matrix-vector product is expected to produce
// a vector.contract of a 1584x1584 matrix with a 1584-element vector,
// accumulating into a 1584-element vector.
// VECTOR-CONTRACTION-LABEL: contraction_matvec
// VECTOR-CONTRACTION: vector.contract
// VECTOR-CONTRACTION-SAME: vector<1584x1584xf32>, vector<1584xf32> into vector<1584xf32>
func @contraction_matvec(%mat: memref<1584x1584xf32>, %vec: memref<1584xf32>, %res: memref<1584xf32>) {
  linalg.matvec ins(%mat, %vec: memref<1584x1584xf32>, memref<1584xf32>)
                outs(%res: memref<1584xf32>)
  return
}

// Contraction-to-vector run: an untiled 2-D matmul is expected to produce a
// vector.contract over the whole 1584x1584 operands and accumulator.
// VECTOR-CONTRACTION-LABEL: contraction_matmul
// VECTOR-CONTRACTION: vector.contract
// VECTOR-CONTRACTION-SAME: vector<1584x1584xf32>, vector<1584x1584xf32> into vector<1584x1584xf32>
func @contraction_matmul(%lhs: memref<1584x1584xf32>, %rhs: memref<1584x1584xf32>, %acc: memref<1584x1584xf32>) {
  linalg.matmul ins(%lhs, %rhs: memref<1584x1584xf32>, memref<1584x1584xf32>)
                outs(%acc: memref<1584x1584xf32>)
  return
}

// Contraction-to-vector run: a batched matmul is expected to produce a
// vector.contract on rank-3 (1584x1584x1584) operands and accumulator.
// VECTOR-CONTRACTION-LABEL: contraction_batch_matmul
// VECTOR-CONTRACTION: vector.contract
// VECTOR-CONTRACTION-SAME: vector<1584x1584x1584xf32>, vector<1584x1584x1584xf32> into vector<1584x1584x1584xf32>
func @contraction_batch_matmul(%lhs: memref<1584x1584x1584xf32>,
                               %rhs: memref<1584x1584x1584xf32>,
                               %acc: memref<1584x1584x1584xf32>) {
  linalg.batch_matmul
    ins(%lhs, %rhs: memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>)
    outs(%acc: memref<1584x1584x1584xf32>)
  return
}