Commit e33f3fc

superbobry authored and Google-ML-Automation committed
[pallas:mosaic_gpu] Added support for reductions to the WG lowering
Note that

* we have no easy way of testing multi-reductions at the moment;
* `reduce_max` assumes the WGMMA_ROW layout, which is not currently supported by the dialect lowering AFAICT.

PiperOrigin-RevId: 736138554
1 parent d89835a commit e33f3fc
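
For context, here is a minimal sketch of the user-level code this lowering path targets. It mirrors the test added in this commit; the import aliases and the GPUCompilerParams/thread_semantics spelling are copied from that test rather than from separately documented API.

import functools
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu


@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((128,), jnp.float32),
    compiler_params=plgpu.GPUCompilerParams(
        thread_semantics=plgpu.ThreadSemantics.Warpgroup
    ),
)
def sum_kernel(x_ref, o_ref):
  # jnp.sum lowers through lax.reduce_sum_p, which now has a Warpgroup rule.
  o_ref[...] = jnp.broadcast_to(jnp.sum(x_ref[...]), o_ref.shape)


x = jnp.arange(128, dtype=jnp.float32)
result = sum_kernel(x)  # every element equals jnp.sum(x)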

5 files changed: +124 -9 lines changed

jax/_src/pallas/mosaic_gpu/lowering.py (+54)

@@ -1543,6 +1543,60 @@ def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
       raise NotImplementedError(f"Unsupported layout {x.layout}")


+def _reduce_lowering_rule_wg(
+    kind: vector_dialect.CombiningKind,
+    acc: object,
+    ctx: LoweringRuleContext,
+    x,
+    *,
+    axes,
+) -> ir.OpView:
+  [x_aval] = ctx.avals_in
+  [out_aval] = ctx.avals_out
+  x = _ensure_ir_value(x, x_aval.dtype)
+  out_type = mgpu_utils.dtype_to_ir_type(out_aval.dtype)
+  if not out_aval.shape:
+    # Special-case: reducing to a scalar.
+    if x_aval.ndim != 1:
+      # TODO(slebedev): Flatten to 1D, since vector.reduction only supports
+      # 1D inputs.
+      raise NotImplementedError("Only 1D inputs are supported")
+    return vector_dialect.ReductionOp(out_type, kind, x)
+  acc = vector_dialect.splat(
+      ir.VectorType.get(out_aval.shape, out_type),
+      _ensure_ir_value(acc, out_aval.dtype),
+  )
+  return vector_dialect.MultiDimReductionOp(kind, x, acc, axes)
+
+
+@register_lowering_rule(lax.reduce_sum_p, mgpu.ThreadSemantics.Warpgroup)
+def _reduce_sum_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes):
+  op = _reduce_lowering_rule_wg(
+      vector_dialect.CombiningKind.ADD, 0, ctx, x, axes=axes
+  )
+  op.attributes["offset"] = ir.IntegerAttr.get(
+      ir.IntegerType.get_signless(32), ctx.module_ctx.smem_used_bytes
+  )
+  return op.result
+
+
+@register_lowering_rule(lax.reduce_max_p, mgpu.ThreadSemantics.Warpgroup)
+def _reduce_max_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes):
+  [x_aval] = ctx.avals_in
+  if jnp.issubdtype(x_aval.dtype, jnp.floating):
+    kind = vector_dialect.CombiningKind.MAXIMUMF
+    acc = float("-inf")
elif jnp.issubdtype(x_aval.dtype, jnp.signedinteger):
1590+
kind = vector_dialect.CombiningKind.MAXSI
1591+
acc = np.iinfo(x_aval.dtype).max
1592+
elif jnp.issubdtype(x_aval.dtype, jnp.unsignedinteger):
1593+
kind = vector_dialect.CombiningKind.MAXUI
1594+
acc = np.iinfo(x_aval.dtype).max
1595+
else:
1596+
raise NotImplementedError(f"Unsupported dtype {x_aval.dtype}")
1597+
return _reduce_lowering_rule_wg(kind, acc, ctx, x, axes=axes).result
1598+
1599+
15461600
@register_lowering_rule(lax.axis_index_p, mgpu.ThreadSemantics.Lane)
15471601
def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable):
15481602
i32 = ir.IntegerType.get_signless(32)
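
The rule above picks between two vector ops: a reduction to a scalar (empty output shape) maps to vector.reduction, while a reduction that keeps some axes maps to vector.multi_reduction seeded with a splatted identity accumulator (0 for add, -inf or the dtype minimum for max). A reference-semantics sketch in plain jax.numpy, written for this note rather than taken from the commit:

import jax.numpy as jnp

x = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)

# Empty output shape: the whole array collapses to one scalar, which is what
# vector.reduction expresses (and why the rule insists on a 1D input for now).
full_sum = jnp.sum(x)                # shape ()

# Non-empty output shape: reduce the selected axes into an accumulator that
# starts at the identity element, mirroring vector.multi_reduction with a
# splatted acc.
acc = jnp.zeros((3,), jnp.float32)   # splat of the identity for "add"
row_sum = acc + jnp.sum(x, axis=0)   # shape (3,)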

jax/experimental/mosaic/gpu/dialect_lowering.py (+36 -6)

@@ -320,6 +320,34 @@ def _vector_splat_op_lowering_rule(
   return [_fragmented_array_to_ir(fragmented_array, out_vec_ty)]


+@_register_lowering(vector.ReductionOp)
+def _vector_reduction_op_lowering_rule(
+    ctx: LoweringContext, op: vector.ReductionOp
+) -> Sequence[ir.Value]:
+  del ctx  # Unused.
+  [layout] = inference_utils.in_layouts(op)
+  () = inference_utils.out_layouts(op)
+  element_type = ir.VectorType(op.vector.type).element_type
+  is_signed = False if ir.IntegerType.isinstance(element_type) else None
+  a = _fragmented_array_from_ir(op.vector, layout, is_signed)
+  match str(op.kind):
+    case "#vector.kind<add>":
+      smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
+      scratch = _slice_smem(
+          ir.MemRefType.get([4], element_type, memory_space=smem),
+          arith.constant(None, op.attributes["offset"]),
+      )
+      result = a.reduce_sum(scratch)
+    case (
+        "#vector.kind<maxsi>" | "#vector.kind<maxui>" | "#vector.kind<maximumf>"
+    ):
+      # TODO(slebedev): Implement this and remove the raise below.
+      raise NotImplementedError(f"Unsupported reduction kind: {op.kind}")
+    case _:
+      raise NotImplementedError(f"Unsupported reduction kind: {op.kind}")
+  return [_fragmented_array_to_ir(result, op.result.type)]
+
+
 def memref_layout_to_swizzle_and_transforms(
     layout: ir.Attribute,
 ) -> tuple[mgpu.SwizzlingMode, tuple[launch_context.MemRefTransform, ...]]:

@@ -713,16 +741,17 @@ def _mgpu_slice_smem_op_lowering_rule(
     ctx: LoweringContext, op: SliceSMEMOp
 ) -> Sequence[ir.Value]:
   del ctx
+  return [_slice_smem(op.result.type, op.offset)]
+
+
+def _slice_smem(result: ir.Type, offset: ir.Value):
   i8 = ir.IntegerType.get_signless(8)
   smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
-
   smem_base = gpu.dynamic_shared_memory(
       ir.MemRefType.get((utils.DYNAMIC,), i8, memory_space=smem)
   )
-
-  offset = arith.index_cast(ir.IndexType.get(), op.offset)
-
-  return [memref.view(op.result.type, smem_base, offset, [])]
+  offset = arith.index_cast(ir.IndexType.get(), offset)
+  return memref.view(result, smem_base, offset, [])


 @_register_lowering(scf.ForOp)

@@ -866,7 +895,8 @@ def _should_lower(op: ir.OpView) -> bool:


 def lower_mgpu_dialect(
-    module: ir.Module, launch_context: launch_context.LaunchContext | None
+    module: ir.Module,
+    launch_context: launch_context.LaunchContext | None,
 ):
   # TODO(apaszke,bchetioui): Make sure the layouts match.
   # TODO(bchetioui): rethink this API. It doesn't make sense to pass in a full
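
In the vector.ReductionOp rule above, the add case hands FragmentedArray.reduce_sum a 4-element SMEM scratch sliced at the byte offset stored in the op's "offset" attribute. My reading (an assumption, not spelled out in the commit) is that the scratch holds one partial sum per warp of the warpgroup; a NumPy sketch of that two-stage reduction:

import numpy as np

NUM_WARPS = 4  # assumed: one scratch slot per warp in a 128-thread warpgroup

x = np.arange(128, dtype=np.float32)
scratch = np.zeros(NUM_WARPS, dtype=np.float32)

# Stage 1: each warp reduces the elements it owns and publishes one partial.
for w, chunk in enumerate(np.split(x, NUM_WARPS)):
  scratch[w] = chunk.sum()

# Stage 2: every thread reads the staged partials back and combines them.
total = scratch.sum()
assert total == x.sum()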

jax/experimental/mosaic/gpu/fragmented_array.py (+1 -1)

@@ -1389,7 +1389,7 @@ def reduce_sum(self, scratch: ir.Value | None = None):
     if isinstance(self.layout, WGSplatFragLayout):
       [reg] = self.registers.flat
       if ir.FloatType.isinstance(self.mlir_dtype):
-        op = arith.mulf
+        op = mulf
       elif ir.IntegerType.isinstance(self.mlir_dtype):
         op = arith.muli
       else:

jax/experimental/mosaic/gpu/layout_inference.py (+6)

@@ -336,6 +336,12 @@ def _infer_splat_op_layout(splat_op: vector.SplatOp) -> OptionalLayouts:

   return [], [layout]

+@partial(_add_layout_inference_rule, vector.ReductionOp)
+def _infer_reduction_op_layout(op: vector.ReductionOp) -> OptionalLayouts:
+  if layout := inference_utils.value_layout(op.vector):
+    return [layout], []
+  return None
+

 @partial(_add_layout_inference_rule, mgpu.WGMMAOp)
 def _infer_wgmma_op_layout(wgmma_op: mgpu.WGMMAOp) -> OptionalLayouts:

tests/pallas/mosaic_gpu_test.py (+27 -2)

@@ -184,6 +184,23 @@ def kernel(x_ref, y_ref, o_ref):
     y = jnp.flip(x).reshape(1, 256)
     np.testing.assert_array_equal(kernel(x, y), x + y[0])

+  @parameterized.product(
+      shape=[(128,)], thread_semantics=[*plgpu.ThreadSemantics]
+  )
+  def test_reduce_sum(self, shape, thread_semantics):
+    @functools.partial(
+        pl.pallas_call,
+        out_shape=jax.ShapeDtypeStruct(shape, jnp.float32),
+        compiler_params=plgpu.GPUCompilerParams(
+            thread_semantics=thread_semantics
+        ),
+    )
+    def kernel(x_ref, o_ref):
+      o_ref[...] = jnp.broadcast_to(_sum_same_dtype(x_ref[...]), o_ref.shape)
+
+    x = jnp.arange(math.prod(shape)).reshape(shape).astype(jnp.float32)
+    np.testing.assert_array_equal(kernel(x), jnp.sum(x))
+
   def test_reshape(self):
     shape1, shape2 = (128,), (2, 16, 4)


@@ -200,10 +217,14 @@ def kernel(x_ref, out_ref):
     x = jnp.arange(math.prod(shape1)).astype(jnp.float32)
     np.testing.assert_array_equal(kernel(x), x.reshape(shape2))

-  def test_add_xy_indexed(self):
+  @parameterized.product(thread_semantics=[*plgpu.ThreadSemantics])
+  def test_add_xy_indexed(self, thread_semantics):
     @functools.partial(
         pl.pallas_call,
         out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
+        compiler_params=plgpu.GPUCompilerParams(
+            thread_semantics=thread_semantics
+        ),
     )
     def kernel(x_ref, y_ref, o_ref):
       idx = _sum_same_dtype(y_ref[...])

@@ -1078,10 +1099,14 @@ def kernel(x_ref, o_ref):

     self.assertIn("acc % 2", output())

-  def test_cond_returning_array(self):
+  @parameterized.parameters([*plgpu.ThreadSemantics])
+  def test_cond_returning_array(self, thread_semantics):
     @functools.partial(
         pl.pallas_call,
         out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
+        compiler_params=plgpu.GPUCompilerParams(
+            thread_semantics=thread_semantics
+        ),
     )
     def kernel(x_ref, o_ref):
       acc = _sum_same_dtype(x_ref[...])
