aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNirvedh Meshram <nirvedh@nod-labs.com>2022-09-29 11:14:18 -0700
committerNirvedh Meshram <nirvedh@nod-labs.com>2022-09-29 19:34:42 -0700
commitd6de6dde824596f5daa223f790d8e2206f2e5dce (patch)
treede2287aa2cad1e7fc94648cd1257878b4b232d7f
parent682c95672bec7cfe95d2674377b526b4576293a3 (diff)
downloadllvm-project-d6de6dde824596f5daa223f790d8e2206f2e5dce.tar.gz
[mlir][spirv] Handle dynamic/static cases differntly for kernel capability
Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D134908
-rw-r--r--mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp12
-rw-r--r--mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp39
-rw-r--r--mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir2
-rw-r--r--mlir/test/Conversion/MemRefToSPIRV/alloc.mlir24
-rw-r--r--mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir30
5 files changed, 71 insertions, 36 deletions
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 7f31488ad2b5..cfb0e5532687 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -335,8 +335,10 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
.getPointeeType();
Type dstType;
if (typeConverter.allows(spirv::Capability::Kernel)) {
- // For OpenCL Kernel, pointer will be directly pointing to the element.
- dstType = pointeeType;
+ if (auto arrayType = pointeeType.dyn_cast<spirv::ArrayType>())
+ dstType = arrayType.getElementType();
+ else
+ dstType = pointeeType;
} else {
// For Vulkan we need to extract element from wrapping struct and array.
Type structElemType =
@@ -464,8 +466,10 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
.getPointeeType();
Type dstType;
if (typeConverter.allows(spirv::Capability::Kernel)) {
- // For OpenCL Kernel, pointer will be directly pointing to the element.
- dstType = pointeeType;
+ if (auto arrayType = pointeeType.dyn_cast<spirv::ArrayType>())
+ dstType = arrayType.getElementType();
+ else
+ dstType = pointeeType;
} else {
// For Vulkan we need to extract element from wrapping struct and array.
Type structElemType =
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 3e2d0e9db575..083b997aeaef 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -338,15 +338,16 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
return nullptr;
}
- // For OpenCL Kernel we can just emit a pointer pointing to the element.
- if (targetEnv.allows(spirv::Capability::Kernel))
- return spirv::PointerType::get(arrayElemType, storageClass);
- // For Vulkan we need extra wrapping struct and array to satisfy interface
- // needs.
if (!type.hasStaticShape()) {
+ // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
+ // to the element.
+ if (targetEnv.allows(spirv::Capability::Kernel))
+ return spirv::PointerType::get(arrayElemType, storageClass);
int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
+ // For Vulkan we need extra wrapping struct and array to satisfy interface
+ // needs.
return wrapInStructAndGetPointer(arrayType, storageClass);
}
@@ -354,7 +355,8 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
auto arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
-
+ if (targetEnv.allows(spirv::Capability::Kernel))
+ return spirv::PointerType::get(arrayType, storageClass);
return wrapInStructAndGetPointer(arrayType, storageClass);
}
@@ -403,15 +405,16 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
return nullptr;
}
- // For OpenCL Kernel we can just emit a pointer pointing to the element.
- if (targetEnv.allows(spirv::Capability::Kernel))
- return spirv::PointerType::get(arrayElemType, storageClass);
- // For Vulkan we need extra wrapping struct and array to satisfy interface
- // needs.
if (!type.hasStaticShape()) {
+ // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
+ // to the element.
+ if (targetEnv.allows(spirv::Capability::Kernel))
+ return spirv::PointerType::get(arrayElemType, storageClass);
int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
+ // For Vulkan we need extra wrapping struct and array to satisfy interface
+ // needs.
return wrapInStructAndGetPointer(arrayType, storageClass);
}
@@ -425,7 +428,8 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
auto arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
-
+ if (targetEnv.allows(spirv::Capability::Kernel))
+ return spirv::PointerType::get(arrayType, storageClass);
return wrapInStructAndGetPointer(arrayType, storageClass);
}
@@ -776,15 +780,20 @@ Value mlir::spirv::getOpenCLElementPtr(SPIRVTypeConverter &typeConverter,
auto indexType = typeConverter.getIndexType();
SmallVector<Value, 2> linearizedIndices;
- auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
-
Value linearIndex;
if (baseType.getRank() == 0) {
- linearIndex = zero;
+ linearIndex = spirv::ConstantOp::getZero(indexType, loc, builder);
} else {
linearIndex =
linearizeIndex(indices, strides, offset, indexType, loc, builder);
}
+ Type pointeeType =
+ basePtr.getType().cast<spirv::PointerType>().getPointeeType();
+ if (pointeeType.isa<spirv::ArrayType>()) {
+ linearizedIndices.push_back(linearIndex);
+ return builder.create<spirv::AccessChainOp>(loc, basePtr,
+ linearizedIndices);
+ }
return builder.create<spirv::PtrAccessChainOp>(loc, basePtr, linearIndex,
linearizedIndices);
}
diff --git a/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir b/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir
index 35cd5ea26b26..e442ac6c8803 100644
--- a/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir
@@ -9,7 +9,7 @@ module attributes {
// CHECK: spirv.func
// CHECK-SAME: {{%.*}}: f32
// CHECK-NOT: spirv.interface_var_abi
- // CHECK-SAME: {{%.*}}: !spirv.ptr<f32, CrossWorkgroup>
+ // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.array<12 x f32>, CrossWorkgroup>
// CHECK-NOT: spirv.interface_var_abi
// CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]> : vector<3xi32>>
gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32, #spirv.storage_class<CrossWorkgroup>>) kernel
diff --git a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
index d1979e21a436..0e39ee15ea98 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
@@ -155,3 +155,27 @@ module attributes {
return
}
}
+
+// -----
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [Kernel], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+ }
+{
+ func.func @alloc_dealloc_workgroup_mem(%arg0 : index, %arg1 : index) {
+ %0 = memref.alloc() : memref<4x5xf32, #spirv.storage_class<Workgroup>>
+ %1 = memref.load %0[%arg0, %arg1] : memref<4x5xf32, #spirv.storage_class<Workgroup>>
+ memref.store %1, %0[%arg0, %arg1] : memref<4x5xf32, #spirv.storage_class<Workgroup>>
+ memref.dealloc %0 : memref<4x5xf32, #spirv.storage_class<Workgroup>>
+ return
+ }
+}
+// CHECK: spirv.GlobalVariable @[[VAR:.+]] : !spirv.ptr<!spirv.array<20 x f32>, Workgroup>
+// CHECK: func @alloc_dealloc_workgroup_mem
+// CHECK-NOT: memref.alloc
+// CHECK: %[[PTR:.+]] = spirv.mlir.addressof @[[VAR]]
+// CHECK: %[[LOADPTR:.+]] = spirv.AccessChain %[[PTR]]
+// CHECK: %[[VAL:.+]] = spirv.Load "Workgroup" %[[LOADPTR]] : f32
+// CHECK: %[[STOREPTR:.+]] = spirv.AccessChain %[[PTR]]
+// CHECK: spirv.Store "Workgroup" %[[STOREPTR]], %[[VAL]] : f32
+// CHECK-NOT: memref.dealloc
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index e3ddb7398055..67556f00c2f2 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -121,16 +121,16 @@ module attributes {
// CHECK-LABEL: @load_store_zero_rank_float
func.func @load_store_zero_rank_float(%arg0: memref<f32, #spirv.storage_class<CrossWorkgroup>>, %arg1: memref<f32, #spirv.storage_class<CrossWorkgroup>>) {
- // CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
- // CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
+ // CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<1 x f32>, CrossWorkgroup>
+ // CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<1 x f32>, CrossWorkgroup>
// CHECK: [[ZERO1:%.*]] = spirv.Constant 0 : i32
- // CHECK: spirv.PtrAccessChain [[ARG0]][
+ // CHECK: spirv.AccessChain [[ARG0]][
// CHECK-SAME: [[ZERO1]]
// CHECK-SAME: ] :
// CHECK: spirv.Load "CrossWorkgroup" %{{.*}} : f32
%0 = memref.load %arg0[] : memref<f32, #spirv.storage_class<CrossWorkgroup>>
// CHECK: [[ZERO2:%.*]] = spirv.Constant 0 : i32
- // CHECK: spirv.PtrAccessChain [[ARG1]][
+ // CHECK: spirv.AccessChain [[ARG1]][
// CHECK-SAME: [[ZERO2]]
// CHECK-SAME: ] :
// CHECK: spirv.Store "CrossWorkgroup" %{{.*}} : f32
@@ -140,16 +140,16 @@ func.func @load_store_zero_rank_float(%arg0: memref<f32, #spirv.storage_class<Cr
// CHECK-LABEL: @load_store_zero_rank_int
func.func @load_store_zero_rank_int(%arg0: memref<i32, #spirv.storage_class<CrossWorkgroup>>, %arg1: memref<i32, #spirv.storage_class<CrossWorkgroup>>) {
- // CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<i32, CrossWorkgroup>
- // CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<i32, CrossWorkgroup>
+ // CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<1 x i32>, CrossWorkgroup>
+ // CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<1 x i32>, CrossWorkgroup>
// CHECK: [[ZERO1:%.*]] = spirv.Constant 0 : i32
- // CHECK: spirv.PtrAccessChain [[ARG0]][
+ // CHECK: spirv.AccessChain [[ARG0]][
// CHECK-SAME: [[ZERO1]]
// CHECK-SAME: ] :
// CHECK: spirv.Load "CrossWorkgroup" %{{.*}} : i32
%0 = memref.load %arg0[] : memref<i32, #spirv.storage_class<CrossWorkgroup>>
// CHECK: [[ZERO2:%.*]] = spirv.Constant 0 : i32
- // CHECK: spirv.PtrAccessChain [[ARG1]][
+ // CHECK: spirv.AccessChain [[ARG1]][
// CHECK-SAME: [[ZERO2]]
// CHECK-SAME: ] :
// CHECK: spirv.Store "CrossWorkgroup" %{{.*}} : i32
@@ -173,14 +173,13 @@ func.func @load_store_unknown_dim(%i: index, %source: memref<?xi32, #spirv.stora
// CHECK-LABEL: func @load_i1
// CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %[[IDX:.+]]: index)
func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %i : index) -> i1 {
- // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<i8, CrossWorkgroup>
+ // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<4 x i8>, CrossWorkgroup>
// CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
// CHECK: %[[ZERO_0:.+]] = spirv.Constant 0 : i32
- // CHECK: %[[ZERO_1:.+]] = spirv.Constant 0 : i32
// CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
// CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[IDX_CAST]] : i32
- // CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO_1]], %[[MUL]] : i32
- // CHECK: %[[ADDR:.+]] = spirv.PtrAccessChain %[[SRC_CAST]][%[[ADD]]]
+ // CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO_0]], %[[MUL]] : i32
+ // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ADD]]]
// CHECK: %[[VAL:.+]] = spirv.Load "CrossWorkgroup" %[[ADDR]] : i8
// CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8
// CHECK: %[[BOOL:.+]] = spirv.IEqual %[[VAL]], %[[ONE_I8]] : i8
@@ -194,14 +193,13 @@ func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %i
// CHECK-SAME: %[[IDX:.+]]: index
func.func @store_i1(%dst: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %i: index) {
%true = arith.constant true
- // CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<i8, CrossWorkgroup>
+ // CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<4 x i8>, CrossWorkgroup>
// CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
// CHECK: %[[ZERO_0:.+]] = spirv.Constant 0 : i32
- // CHECK: %[[ZERO_1:.+]] = spirv.Constant 0 : i32
// CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
// CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[IDX_CAST]] : i32
- // CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO_1]], %[[MUL]] : i32
- // CHECK: %[[ADDR:.+]] = spirv.PtrAccessChain %[[DST_CAST]][%[[ADD]]]
+ // CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO_0]], %[[MUL]] : i32
+ // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[DST_CAST]][%[[ADD]]]
// CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8
// CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8
// CHECK: %[[RES:.+]] = spirv.Select %{{.+}}, %[[ONE_I8]], %[[ZERO_I8]] : i1, i8