
Commit a7686db

[mlir][gpu] Allow distributing to different levels of IDs without failing
Change map_nested_foreach_to_threads to ignore scf.foreach_thread ops that do not map to threads. This allows calling mapNestedForeachToThreadsImpl with different sets of IDs to lower multiple levels of parallelism. It also adds the warp mapping attribute.

Differential Revision: https://reviews.llvm.org/D143298
1 parent 5fd51fc commit a7686db
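For illustration, the payload IR this change lets the transform handle in stages (a condensed sketch of the map_multi_level test added in this commit; constants and loop bodies are elided): two sibling scf.foreach_thread ops inside one gpu.launch, one mapped to threads and one mapped to warps. map_nested_foreach_to_threads now distributes the thread-mapped loop and skips the warp-mapped one instead of failing.

  scf.foreach_thread (%i, %j) in (%c7, %c9) {
    // ... thread-level work ...
  } { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
  scf.foreach_thread (%i) in (%c12) {
    // ... warp-level work ...
  } { mapping = [#gpu.warp<x>] }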

6 files changed: +77 -29 lines changed

mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td

Lines changed: 21 additions & 0 deletions
@@ -43,6 +43,27 @@ def GPUThreadMappingAttr
   }];
 }
 
+def WarpsEnum : I64EnumAttr<"Warps", "threads for loop mapping", [
+    DimX, DimY, DimZ]> {
+  let cppNamespace = "::mlir::gpu";
+}
+
+def GPUWarpMappingAttr : GPU_Attr<"GPUWarpMapping", "warp", [
+  DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ] > {
+  let parameters = (ins
+    EnumParameter<WarpsEnum>:$warp
+  );
+  let assemblyFormat = "`<` params `>`";
+  let description = [{
+    An attribute that allows defining warp parallelism for GPU devices.
+
+    Warps (aka subgroups) are grouped into a grid, where the grid may be
+    described by a 1-, 2-, or 3-dimensional rectangle. This attribute indicates
+    that warp parallelism is desired. It can be consumed by lowering to
+    generate GPU code.
+  }];
+}
+
 def BlocksEnum : I64EnumAttr<"Blocks", "threads for loop mapping", [
     DimX, DimY, DimZ]> {
   let cppNamespace = "::mlir::gpu";
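Given the WarpsEnum cases and the assembly format above, the new attribute is expected to print as one of:

  #gpu.warp<x>
  #gpu.warp<y>
  #gpu.warp<z>

and it is used in the mapping array of an scf.foreach_thread, as the test added further down in this commit shows.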

mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td

Lines changed: 7 additions & 3 deletions
@@ -58,9 +58,13 @@ def MapNestedForeachToThreads :
     If any scf.foreach_thread with tensors is found, the transform definitely
     fails.
 
-    If all the scf.foreach_thread operations contained within the LaunchOp
-    referred to by the `target` PDLOperation lower to GPU properly, the
-    transform succeeds. Otherwise the transform definitely fails.
+    If all the scf.foreach_thread operations with gpu.thread mapping contained
+    within the LaunchOp referred to by the `target` PDLOperation lower to GPU
+    properly, the transform succeeds. Otherwise the transform definitely
+    fails.
+
+    scf.foreach_thread operations with mappings other than gpu.thread are
+    ignored.
 
     The returned handle points to the same LaunchOp operand, consuming it and
     producing a new SSA value to satisfy chaining and linearity of the IR
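For reference, a minimal transform script exercising the documented behavior; it mirrors the sequence used by the map_multi_level test added in this commit (the explanatory comment is added here, and the blockDim values are simply the ones from that test):

  transform.sequence failures(propagate) {
  ^bb1(%arg0: !pdl.operation):
    %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
    // Thread-mapped foreach_thread ops are distributed; warp-mapped ones are left untouched.
    transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [12, 9] }
  }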

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 4 additions & 0 deletions
@@ -42,6 +42,10 @@ int64_t GPUBlockMappingAttr::getMappingId() const {
   return static_cast<int64_t>(getBlock());
 }
 
+int64_t GPUWarpMappingAttr::getMappingId() const {
+  return static_cast<int64_t>(getWarp());
+}
+
 int64_t GPUThreadMappingAttr::getMappingId() const {
   return static_cast<int64_t>(getThread());
 }

mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp

Lines changed: 6 additions & 0 deletions
@@ -509,6 +509,12 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForeachToThreadsImpl(
     const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes) {
   DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
   target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
+    // Ignore cases with different attributes.
+    for (Attribute map : foreachThreadOp.getMapping()->getValue()) {
+      if (!llvm::is_contained(threadMappingAttributes, map)) {
+        return WalkResult::skip();
+      }
+    }
     diag = checkAttributeType(threadMappingAttributes,
                               foreachThreadOp.getMapping(), transformOp);
     if (diag.succeeded()) {

mlir/test/Dialect/GPU/transform-gpu-failing.mlir

Lines changed: 0 additions & 26 deletions
@@ -274,30 +274,4 @@ transform.sequence failures(propagate) {
   transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [32, 32]}
 }
 
-// -----
-
-!type = memref<32x32xf32>
-func.func @saxpy2d_wrong_mapping(%x: !type, %y: !type, %stream : !gpu.async.token) -> !type {
-  %c32 = arith.constant 32 : index
-  %one = arith.constant 1 : index
-  %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
-            threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one)
-  {
-    scf.foreach_thread (%i, %j) in (%c32, %c32) {
-      %4 = memref.load %x[%i, %j] : !type
-      %5 = memref.load %y[%i, %j] : !type
-      %6 = arith.mulf %4, %5 : f32
-      memref.store %6, %y[%i, %j] : !type
-    } { mapping = [#gpu.block<x>, #gpu.block<x>] }
-    gpu.terminator
-  }
-  return %y : !type
-}
-
-transform.sequence failures(propagate) {
-^bb1(%arg0: !pdl.operation):
-  %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
-  // expected-error @below {{mapping must be one of #gpu.thread<x>, #gpu.thread<y>, #gpu.thread<z>}}
-  transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [32, 32]}
-}

mlir/test/Dialect/GPU/transform-gpu.mlir

Lines changed: 39 additions & 0 deletions
@@ -230,3 +230,42 @@ transform.sequence failures(propagate) {
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
   transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [12, 9, 1], syncAfterDistribute = false }
 }
+
+// -----
+
+!type = memref<2 x 32 x f32>
+!type1d = memref<32 x f32>
+
+// CHECK-LABEL: func.func @map_multi_level(
+func.func @map_multi_level(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !gpu.async.token) -> !type {
+  %one = arith.constant 1 : index
+  %c12 = arith.constant 12 : index
+  %c9 = arith.constant 9 : index
+  %c7 = arith.constant 7 : index
+  // check that the thread level got distributed but not the warp level.
+  // CHECK-NOT: {mapping = #gpu.thread
+  // CHECK: {mapping = [#gpu.warp<x>]}
+  %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
+            threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one)
+  {
+    scf.foreach_thread (%i, %j) in (%c7, %c9) {
+      %4 = memref.load %x[%i, %j] : !type
+      %5 = memref.load %y[%i, %j] : !type
+      %6 = math.fma %alpha, %4, %5 : f32
+      memref.store %6, %y[%i, %j] : !type
+    } { mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+    scf.foreach_thread (%i) in (%c12) {
+      %7 = memref.load %t[%i] : !type1d
+      %8 = arith.addf %alpha, %7 : f32
+      memref.store %8, %t[%i] : !type1d
+    } {mapping = [#gpu.warp<x>] }
+    gpu.terminator
+  }
+  return %y : !type
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !pdl.operation):
+  %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
+  transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [12, 9] }
+}
