Skip to content

Commit 99c9106

Browse files
[Mosaic GPU] Replace WGMMAFragLayout with TiledLayout in the mlir dialect and use it in layout inference.
`WGMMAFragLayout` will be completely removed soon. PiperOrigin-RevId: 735877661
1 parent 67aa997 commit 99c9106

File tree

5 files changed

+112
-47
lines changed

5 files changed

+112
-47
lines changed

jax/experimental/mosaic/gpu/dialect_lowering.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -259,14 +259,15 @@ def _vector_load_op_lowering_rule(
259259
is_signed=is_signed,
260260
vec_size=strided_layout.vec_size,
261261
)
262-
elif layouts.is_wgmma_fragmented_layout(out_layout_attr):
262+
elif layouts.from_layout_attr(out_layout_attr) == fa.TILED_LAYOUT_WGMMA:
263263
layout = ir.MemRefType(vector_load_op.base.type).layout
264264
swizzle, transforms = memref_layout_to_swizzle_and_transforms(layout)
265265
transformed_ref = transform_memref(vector_load_op.base, transforms)
266266
fragmented_array = fa.FragmentedArray.load_tiled(
267267
transformed_ref,
268268
swizzle=swizzle,
269-
is_signed=is_signed
269+
is_signed=is_signed,
270+
layout=fa.TILED_LAYOUT_WGMMA,
270271
)
271272
else:
272273
raise ValueError(
@@ -634,7 +635,10 @@ def _mgpu_wgmma_op_lowering_rule(
634635
*inference_utils.in_layouts(wgmma_op),
635636
*inference_utils.out_layouts(wgmma_op),
636637
)
637-
if not all(map(layouts.is_wgmma_fragmented_layout, fa_layouts)):
638+
is_supported_layout = (
639+
lambda l: layouts.from_tiled_layout_attr(l) == fa.TILED_LAYOUT_WGMMA
640+
)
641+
if not all(map(is_supported_layout, fa_layouts)):
638642
raise ValueError("Layout mismatch")
639643
wgmma_layout = fa_layouts[0]
640644

@@ -667,7 +671,12 @@ def _mgpu_wgmma_op_lowering_rule(
667671

668672
new_acc = wgmma.wgmma(acc, a_operand, b_operand, swizzle=b_swizzle)
669673

670-
return [_fragmented_array_to_ir(new_acc.value, wgmma_op.accumulator.type)]
674+
return [
675+
_fragmented_array_to_ir(
676+
new_acc.value.to_layout(fa.TILED_LAYOUT_WGMMA),
677+
wgmma_op.accumulator.type,
678+
)
679+
]
671680

672681

673682
@_register_lowering(mgpu.ArriveExpectTxOp)

jax/experimental/mosaic/gpu/layout_inference.py

+16-10
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def _choose_representative_layout(
6363
6464
Given the input set of possible layouts, this function extracts a single
6565
representative layout. Currently, this function only works with strided,
66-
splat, and WGMMA fragmented layouts.
66+
splat, and tiled layouts.
6767
6868
Returns:
6969
A single layout that can be used to annotate the operation, or None if the
@@ -86,18 +86,18 @@ def _choose_representative_layout(
8686
)
8787
)
8888

89-
wgmma_layouts: list[fa.WGMMAFragLayout] = list(
89+
tiled_layouts: list[fa.TiledLayout] = list(
9090
map(
9191
layouts_lib.from_layout_attr,
92-
filter(layouts_lib.is_wgmma_fragmented_layout, layouts),
92+
filter(layouts_lib.is_tiled_layout, layouts),
9393
)
9494
)
9595

96-
if len(splat_layouts) + len(strided_layouts) + len(wgmma_layouts) != len(
96+
if len(splat_layouts) + len(strided_layouts) + len(tiled_layouts) != len(
9797
layouts
9898
):
9999
raise ValueError(
100-
f"Expected only strided, splat, and wgmma layouts, got {layouts}"
100+
f"Expected only strided, splat, and tiled layouts, got {layouts}"
101101
)
102102

103103
if len(splat_layouts) > 1:
@@ -112,13 +112,19 @@ def _choose_representative_layout(
112112
"is not supported."
113113
)
114114

115-
if (wgmma_layouts and strided_layouts):
115+
if len(tiled_layouts) > 1:
116116
raise NotImplementedError(
117-
"Mixing strided and WGMMA layouts is not supported."
117+
"Finding a representative layout for several distinct tiled layouts "
118+
"is not supported."
119+
)
120+
121+
if tiled_layouts and strided_layouts:
122+
raise NotImplementedError(
123+
"Mixing strided and tiled layouts is not supported."
118124
)
119125

120-
if wgmma_layouts:
121-
return layouts_lib.to_layout_attr(wgmma_layouts[0])
126+
if tiled_layouts:
127+
return layouts_lib.to_layout_attr(tiled_layouts[0])
122128

123129
if strided_layouts:
124130
[strided_layout] = strided_layouts
@@ -333,7 +339,7 @@ def _infer_splat_op_layout(splat_op: vector.SplatOp) -> OptionalLayouts:
333339

334340
@partial(_add_layout_inference_rule, mgpu.WGMMAOp)
335341
def _infer_wgmma_op_layout(wgmma_op: mgpu.WGMMAOp) -> OptionalLayouts:
336-
layout = layouts_lib.to_layout_attr(fa.WGMMAFragLayout())
342+
layout = layouts_lib.to_layout_attr(fa.TILED_LAYOUT_WGMMA)
337343

338344
if ir.VectorType.isinstance(wgmma_op.a.type):
339345
return [layout, layout], [layout]

jax/experimental/mosaic/gpu/layouts.py

+62-15
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,67 @@ def is_strided_fragmented_layout(attr: ir.Attribute) -> bool:
9494
return bool(_strided_fragmented_layout_attr_pattern.search(str(attr)))
9595

9696

97+
_tiled_layout_attr_pattern = re.compile(
98+
r"^#mosaic_gpu.TiledLayout<\[(?P<tiling>.*)\],"
99+
r" warp_dim\s*=\s*(?P<warp_dim>[-\d]+),"
100+
r" lane_dims\s*=\s*\[(?P<lane_dims>.*)\],"
101+
r" vector_dim\s*=\s*(?P<vector_dim>[-\d]+)>$"
102+
)
103+
104+
105+
def to_tiled_layout_attr(
106+
layout: fa.TiledLayout,
107+
) -> ir.Attribute:
108+
"""Constructs a #mosaic_gpu.TiledLayout attribute from a TiledLayout."""
109+
110+
tile_str = lambda tile: "[" + ", ".join(str(d) for d in tile) + "]"
111+
tiling = "[" + ", ".join(tile_str(tile) for tile in layout.tiling.tiles) + "]"
112+
return ir.Attribute.parse(
113+
f"#mosaic_gpu.TiledLayout<{tiling}, warp_dim={layout.warp_dim},"
114+
f" lane_dims={list(layout.lane_dims)}, vector_dim={layout.vector_dim}>"
115+
)
116+
117+
118+
_list_of_lists_delimiter = re.compile(r"\]\s*,\s*\[")
119+
120+
121+
def from_tiled_layout_attr(
122+
attr: ir.Attribute,
123+
) -> fa.TiledLayout:
124+
"""Constructs a TiledLayout from a #mosaic_gpu.TiledLayout attribute.
125+
126+
Raises:
127+
ValueError: If the attribute is not a #mosaic_gpu.TiledLayout
128+
attribute.
129+
"""
130+
match = _tiled_layout_attr_pattern.fullmatch(str(attr))
131+
if not match:
132+
raise ValueError(
133+
f"Expected a #mosaic_gpu.TiledLayout attribute, got {attr}"
134+
)
135+
136+
tiling_str = match.group("tiling")
137+
tile_strings = []
138+
if len(tiling_str) > 2:
139+
tile_strings = _list_of_lists_delimiter.split(tiling_str[1:-1])
140+
tiles = tuple(tuple(map(int, ts.split(","))) for ts in tile_strings)
141+
return fa.TiledLayout(
142+
tiling=fa.Tiling(tiles),
143+
warp_dim=int(match.group("warp_dim")),
144+
lane_dims=tuple(int(s) for s in match.group("lane_dims").split(",")),
145+
vector_dim=int(match.group("vector_dim"))
146+
)
147+
148+
149+
def is_tiled_layout(attr: ir.Attribute) -> bool:
150+
return bool(_tiled_layout_attr_pattern.search(str(attr)))
151+
152+
97153
def to_layout_attr(
98154
layout: (
99155
fa.WGSplatFragLayout
100156
| fa.WGStridedFragLayout
101-
| fa.WGMMAFragLayout
157+
| fa.TiledLayout
102158
| fa.WGMMARowFragLayout
103159
),
104160
) -> ir.Attribute:
@@ -108,8 +164,8 @@ def to_layout_attr(
108164
return to_splat_fragmented_layout_attr(layout)
109165
case fa.WGStridedFragLayout():
110166
return to_strided_fragmented_layout_attr(layout)
111-
case fa.WGMMAFragLayout():
112-
return ir.Attribute.parse("#mosaic_gpu.WGMMAFragLayout")
167+
case fa.TiledLayout():
168+
return to_tiled_layout_attr(layout)
113169
case fa.WGMMARowFragLayout():
114170
return ir.Attribute.parse("#mosaic_gpu.WGMMARowFragLayout")
115171
case _:
@@ -118,15 +174,6 @@ def to_layout_attr(
118174
)
119175

120176

121-
_wgmma_fragmented_layout_attr_pattern = re.compile(
122-
r"^#mosaic_gpu.WGMMAFragLayout$"
123-
)
124-
125-
126-
def is_wgmma_fragmented_layout(attr: ir.Attribute) -> bool:
127-
return bool(_wgmma_fragmented_layout_attr_pattern.search(str(attr)))
128-
129-
130177
_wgmma_row_fragmented_layout_attr_pattern = re.compile(
131178
r"^#mosaic_gpu.WGMMARowFragLayout$"
132179
)
@@ -141,16 +188,16 @@ def from_layout_attr(
141188
) -> (
142189
fa.WGSplatFragLayout
143190
| fa.WGStridedFragLayout
144-
| fa.WGMMAFragLayout
191+
| fa.TiledLayout
145192
| fa.WGMMARowFragLayout
146193
):
147194
"""Constructs a layout from an MLIR attribute."""
148195
if is_splat_fragmented_layout(attr):
149196
return from_splat_fragmented_layout_attr(attr)
150197
elif is_strided_fragmented_layout(attr):
151198
return from_strided_fragmented_layout_attr(attr)
152-
elif is_wgmma_fragmented_layout(attr):
153-
return fa.WGMMAFragLayout()
199+
elif is_tiled_layout(attr):
200+
return from_tiled_layout_attr(attr)
154201
elif is_wgmma_row_fragmented_layout(attr):
155202
return fa.WGMMARowFragLayout()
156203
else:

jaxlib/mosaic/dialect/gpu/mosaic_gpu.td

+18-15
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,6 @@ def MosaicGPU_WGStridedFragLayout : AttrDef<MosaicGPU_Dialect, "WGStridedFragLay
128128
let assemblyFormat = "`<` $shape`,` $vector_size `>`";
129129
}
130130

131-
132131
def MosaicGPU_WGSplatFragLayout : AttrDef<MosaicGPU_Dialect, "WGSplatFragLayout", []> {
133132
let summary = "Annotates an array that is the result of a splat.";
134133
let description = [{
@@ -143,20 +142,6 @@ def MosaicGPU_WGSplatFragLayout : AttrDef<MosaicGPU_Dialect, "WGSplatFragLayout"
143142
let assemblyFormat = "`<` $shape `>`";
144143
}
145144

146-
def MosaicGPU_WGMMAFragLayout : AttrDef<MosaicGPU_Dialect, "WGMMAFragLayout", []> {
147-
let summary = "2D array that can be tiled by supported WGMMA shapes.";
148-
let description = [{
149-
This layout annotates arrays that are fragmented across all threads in a
150-
warpgroup that is executing a WGMMA operation. The shape of the array is
151-
(m, n) where:
152-
- m % 64 == 0
153-
- n % 8 == 0
154-
}];
155-
156-
let mnemonic = "WGMMAFragLayout";
157-
let assemblyFormat = "";
158-
}
159-
160145
def MosaicGPU_WGMMARowFragLayout : AttrDef<MosaicGPU_Dialect, "WGMMARowFragLayout", []> {
161146
let summary = "1D array that is a row that can be tiled by supported WGMMA shapes.";
162147
let description = [{
@@ -169,6 +154,24 @@ def MosaicGPU_WGMMARowFragLayout : AttrDef<MosaicGPU_Dialect, "WGMMARowFragLayou
169154
let assemblyFormat = "";
170155
}
171156

157+
def MosaicGPU_TiledLayout : AttrDef<MosaicGPU_Dialect, "TiledLayout", []> {
158+
let summary = "A layout derived from a tiling expression.";
159+
let description = [{
160+
See mosaic/gpu/fragmented_array.py -> TiledLayout for more details.
161+
}];
162+
163+
let parameters = (ins
164+
"::mlir::ArrayAttr":$tiling,
165+
"int":$warp_dim,
166+
"::mlir::ArrayAttr":$lane_dims,
167+
"int":$vector_dim
168+
);
169+
let mnemonic = "TiledLayout";
170+
let assemblyFormat = "`<` $tiling `,` `warp_dim` `=` $warp_dim `,` "
171+
"`lane_dims` `=` $lane_dims `,` `vector_dim` `=` $vector_dim `>`";
172+
}
173+
174+
172175
// Note: This duplicates the Dimension enum in mlir/Dialect/GPU/IR/GPUOps.td
173176
// but it was not possible to reuse that definition. Including that file
174177
// pulls in ops definitions that we don't want and they fail to compile.

tests/mosaic/gpu_layout_inference_test.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def body(a, b):
210210
for layout in [
211211
mgpu.WGSplatFragLayout(shape),
212212
mgpu.WGStridedFragLayout(shape, vec_size=4),
213-
mgpu.WGMMAFragLayout(),
213+
mgpu.TILED_LAYOUT_WGMMA,
214214
]
215215
)
216216
def test_infer_layout_from_yield_op_in_layouts_for_for_op(
@@ -278,7 +278,7 @@ def body(lower_bound, upper_bound, step, a, b, c):
278278

279279
mgpu.infer_layout(self.module)
280280

281-
wgmma_layout = layouts.to_layout_attr(mgpu.WGMMAFragLayout())
281+
wgmma_layout = layouts.to_layout_attr(mgpu.TILED_LAYOUT_WGMMA)
282282
self.assertSequenceEqual(yield_op.attributes["in_layouts"], [wgmma_layout])
283283
self.assertSequenceEqual(yield_op.attributes["out_layouts"], [])
284284
self.assertSequenceEqual(for_op.attributes["in_layouts"], [wgmma_layout])
@@ -312,7 +312,7 @@ def body(ref, array):
312312

313313
@parameterized.parameters(
314314
mgpu.WGStridedFragLayout((32, 4), vec_size=1),
315-
mgpu.WGMMAFragLayout(),
315+
mgpu.TILED_LAYOUT_WGMMA,
316316
)
317317
def test_infer_layout_picks_non_splat_layout_over_splat_layout(
318318
self, layout

0 commit comments

Comments
 (0)