
Commit 75d8702

bchetioui authored and Google-ML-Automation committed
[Pallas/Mosaic GPU] Add lowerings/layout inference for all the necessary conversion ops when using Warpgroup semantics.
Enable some of the pre-existing Pallas `ops_test`s for testing.

PiperOrigin-RevId: 735293084
1 parent b9fb69d commit 75d8702

File tree

6 files changed: +257, -32 lines changed


jax/_src/pallas/mosaic_gpu/lowering.py

+42 -1

@@ -1162,6 +1162,9 @@ def _convert_element_type_lowering_rule_wg(
   cur_dtype = mgpu_utils.dtype_to_ir_type(x_aval.dtype)
   new_dtype = mgpu_utils.dtype_to_ir_type(new_dtype)

+  if cur_dtype == new_dtype:
+    return x
+
   if 1 < mgpu_utils.bitwidth(cur_dtype) < 8 or 1 < mgpu_utils.bitwidth(new_dtype) < 8:
     raise NotImplementedError("Conversion involving sub-byte types unsupported")

@@ -1170,7 +1173,29 @@ def _convert_element_type_lowering_rule_wg(
   from_integer = ir.IntegerType.isinstance(cur_dtype)
   to_integer = ir.IntegerType.isinstance(new_dtype)
   if from_float and to_float:
-    if ir.FloatType(cur_dtype).width > ir.FloatType(new_dtype).width:
+    cur_ty_width = ir.FloatType(cur_dtype).width
+    new_ty_width = ir.FloatType(new_dtype).width
+    if cur_ty_width == new_ty_width:
+      # There is no instruction to perform conversions between two float types
+      # of the same width. Go through the next-larger standard type.
+      # TODO(bchetioui): support conversions between float types of width 8.
+      # Which larger type to pick will depend on the number of bits in the
+      # smallest exponent.
+      if cur_ty_width != 16:
+        raise NotImplementedError(
+            "Conversion between float types of width other than 16 not"
+            " supported"
+        )
+      larger_ty = ir.F32Type.get()
+      if x_aval.shape:
+        upcast_ty = ir.VectorType.get(x_aval.shape, larger_ty)
+      else:
+        upcast_ty = larger_ty
+
+      def convert(ty, x):
+        return arith_dialect.truncf(ty, arith_dialect.extf(upcast_ty, x))
+
+    elif ir.FloatType(cur_dtype).width > ir.FloatType(new_dtype).width:
       convert = arith_dialect.truncf
     else:
       convert = arith_dialect.extf
@@ -1190,10 +1215,26 @@ def _convert_element_type_lowering_rule_wg(
     else:
       convert = arith_dialect.uitofp
   elif from_float and to_integer:
+    dst_width = mgpu_utils.bitwidth(new_dtype)
+    # We clamp the float value to the min/max integer destination value
+    # in order to match JAX/XLA casting behavior. Note that this differs
+    # from numpy casting behavior.
     if mgpu_utils.is_signed(y_aval.dtype):
+      maxint = 2 ** (dst_width - 1) - 1
+      minint = -(2 ** (dst_width - 1))
       convert = arith_dialect.fptosi
     else:
+      maxint = 2**dst_width - 1
+      minint = 0
       convert = arith_dialect.fptoui
+
+    maxint = _ir_constant(maxint, cur_dtype)
+    minint = _ir_constant(minint, cur_dtype)
+    if x_aval.shape:
+      maxint = vector_dialect.splat(x.type, maxint)
+      minint = vector_dialect.splat(x.type, minint)
+    x = arith_dialect.minimumf(x, maxint)
+    x = arith_dialect.maximumf(x, minint)
   else:
     raise NotImplementedError(f"Unsupported conversion {cur_dtype} -> {new_dtype}")

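For context on the clamping introduced above: JAX/XLA saturates out-of-range float-to-integer casts at the destination type's minimum/maximum, while NumPy handles such casts differently, which is why the lowering bounds the value with arith.minimumf/arith.maximumf before fptosi/fptoui. The following standalone sketch (not part of the commit) mirrors the maxint/minint computation in plain NumPy; it clamps in float64 via np.clip purely to keep the sketch well-defined, whereas the lowering clamps in the source float type.

import numpy as np

def clamp_then_cast(x: np.ndarray, dst_width: int, signed: bool) -> np.ndarray:
  """Clamps to the destination integer range before casting, as the lowering does."""
  if signed:
    maxint = 2 ** (dst_width - 1) - 1
    minint = -(2 ** (dst_width - 1))
    dst = {8: np.int8, 16: np.int16, 32: np.int32, 64: np.int64}[dst_width]
  else:
    maxint = 2 ** dst_width - 1
    minint = 0
    dst = {8: np.uint8, 16: np.uint16, 32: np.uint32, 64: np.uint64}[dst_width]
  # Clamp first, then cast; out-of-range inputs saturate instead of wrapping.
  return np.clip(x.astype(np.float64), minint, maxint).astype(dst)

print(clamp_then_cast(np.array([1e10, -1e10, 3.7], np.float32), 32, signed=True))
# Prints [2147483647, -2147483648, 3] (up to NumPy's array formatting).
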
jax/experimental/mosaic/gpu/dialect_lowering.py

+37 -2

@@ -253,9 +253,9 @@ def _vector_load_op_lowering_rule(

   element_type = vector_load_op.result.type.element_type
   is_signed = False if ir.IntegerType.isinstance(element_type) else None
-
+  strided_layout = layouts.from_strided_fragmented_layout_attr(out_layout_attr)
   fragmented_array = fa.FragmentedArray.load_strided(
-      vector_load_op.base, is_signed=is_signed
+      vector_load_op.base, is_signed=is_signed, vec_size=strided_layout.vec_size
   )
   return [_fragmented_array_to_ir(fragmented_array, vector_load_op.result.type)]

@@ -429,6 +429,41 @@ def _mgpu_async_store_op_lowering_rule(
   return []


+def _conversion_op_lowering_rule(
+    _: LoweringContext,
+    op: ir.OpView,
+    source_is_signed: bool | None,
+    target_is_signed: bool | None,
+) -> Sequence[ir.Value]:
+  [in_layout] = inference_utils.in_layouts(op)
+  [layout] = inference_utils.out_layouts(op)
+  if in_layout != layout:
+    raise ValueError("Layout mismatch")
+
+  target_ty = op.result.type.element_type  # pytype: disable=attribute-error
+  operand = _fragmented_array_from_ir(op.operands[0], layout, source_is_signed)
+  converted = operand.astype(target_ty, is_signed=target_is_signed)
+  return [_fragmented_array_to_ir(converted, op.result.type)]
+
+
+for op, source_is_signed, target_is_signed in [
+    (arith.ExtFOp, None, None),
+    (arith.ExtSIOp, True, True),
+    (arith.ExtUIOp, False, False),
+    (arith.FPToSIOp, None, True),
+    (arith.FPToUIOp, None, False),
+    (arith.SIToFPOp, True, None),
+    (arith.TruncFOp, None, None),
+    (arith.TruncIOp, False, False),
+    (arith.UIToFPOp, False, None),
+]:
+  _lowerings[op.OPERATION_NAME] = functools.partial(
+      _conversion_op_lowering_rule,
+      source_is_signed=source_is_signed,
+      target_is_signed=target_is_signed,
+  )
+
+
 def _binary_op_lowering_rule(
     _: LoweringContext,
     op: Any,

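The registration loop above specializes one generic rule per arith conversion op, baking in the signedness of the source and target element types via functools.partial; a signedness of None means that side is a floating-point type, so signedness does not apply. A toy, self-contained sketch of the same pattern (hypothetical names; the real table is _lowerings and the real rule converts a FragmentedArray with .astype):

import functools

_toy_lowerings = {}  # op name -> lowering callable

def _toy_conversion_rule(ctx, op, *, source_is_signed, target_is_signed):
  # The real rule reads the inferred in/out layouts, wraps the operand in a
  # FragmentedArray, and calls .astype(target_type, is_signed=target_is_signed).
  return f"convert(src_signed={source_is_signed}, dst_signed={target_is_signed})"

for name, src_signed, dst_signed in [
    ("arith.extsi", True, True),    # signed integer extension
    ("arith.fptoui", None, False),  # float -> unsigned integer
]:
  _toy_lowerings[name] = functools.partial(
      _toy_conversion_rule, source_is_signed=src_signed, target_is_signed=dst_signed
  )

print(_toy_lowerings["arith.fptoui"](None, None))  # convert(src_signed=None, dst_signed=False)
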
jax/experimental/mosaic/gpu/fragmented_array.py

+35 -3

@@ -641,13 +641,22 @@ def __init__(
       raise NotImplementedError

   @classmethod
-  def load_strided(cls, ref: ir.Value, *, is_signed: bool | None = None):
+  def load_strided(
+      cls,
+      ref: ir.Value,
+      *,
+      is_signed: bool | None = None,
+      vec_size: int | None = None,
+  ):
     if not ir.MemRefType.isinstance(ref.type):
       raise TypeError(ref.type)

     ref_ty = ir.MemRefType(ref.type)
     shape = tuple(ref_ty.shape)
-    layout = WGStridedFragLayout.from_shaped_type(ref_ty)
+    if vec_size is None:
+      layout = WGStridedFragLayout.from_shaped_type(ref_ty)
+    else:
+      layout = WGStridedFragLayout(shape=shape, vec_size=vec_size)
     vec_ty = ir.VectorType.get((layout.vec_size,), ref_ty.element_type)
     try:
       # Flattening the reference potentially produces simpler PTX but
@@ -1322,7 +1331,30 @@ def upcast_to_bf16(reg, high):
     from_integer = ir.IntegerType.isinstance(cur_dtype)
     to_integer = ir.IntegerType.isinstance(new_dtype)
     if from_float and to_float:
-      if ir.FloatType(cur_dtype).width > ir.FloatType(new_dtype).width:
+      cur_ty_width = ir.FloatType(cur_dtype).width
+      new_ty_width = ir.FloatType(new_dtype).width
+      if cur_ty_width == new_ty_width:
+        # There is no instruction to perform conversions between two float types
+        # of the same width. Go through the next-larger standard type.
+        # TODO(bchetioui): support conversions between float types of width 8.
+        # Which larger type to pick will depend on the number of bits in the
+        # smallest exponent.
+        if cur_ty_width != 16:
+          raise NotImplementedError(
+              "Conversion between float types of width other than 16 not"
+              " supported"
+          )
+        larger_ty = ir.F32Type.get()
+        match self.layout:
+          case WGMMAFragLayout() | WGStridedFragLayout() | TiledLayout():
+            shape = ir.VectorType(self.registers.flat[0].type).shape
+            upcast_ty = ir.VectorType.get(shape, larger_ty)
+          case WGMMARowFragLayout() | WGSplatFragLayout():
+            upcast_ty = larger_ty
+          case _:
+            raise NotImplementedError(f"Unsupported layout {self.layout}")
+        convert = lambda ty, x: arith.truncf(ty, arith.extf(upcast_ty, x))
+      elif ir.FloatType(cur_dtype).width > ir.FloatType(new_dtype).width:
         convert = arith.truncf
       else:
         convert = arith.extf

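To illustrate the same-width float path above: there is no single conversion instruction between two 16-bit float formats, but every bf16 and f16 value is exactly representable in f32, so extending to f32 (arith.extf) and truncating to the target type (arith.truncf) rounds at most once, at the final step. A small sketch of the equivalent NumPy round trip (assumes the ml_dtypes package, which provides the bfloat16 dtype and ships as a JAX dependency):

import numpy as np
import ml_dtypes  # provides the bfloat16 NumPy dtype

x_bf16 = np.array([1.0, 3.140625, 1e38], dtype=ml_dtypes.bfloat16)
x_f32 = x_bf16.astype(np.float32)  # "extf": exact, f32 contains every bf16 value
x_f16 = x_f32.astype(np.float16)   # "truncf": rounds once to the nearest f16
print(x_f16)  # 1.0 and 3.140625 survive; 1e38 overflows the f16 range to inf
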
jax/experimental/mosaic/gpu/layout_inference.py

+56 -12

@@ -23,13 +23,16 @@
 from jax._src.lib import mosaic_gpu_dialect as mgpu
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import arith
-from jax._src.lib.mlir.dialects import scf
 from jax._src.lib.mlir.dialects import memref
+from jax._src.lib.mlir.dialects import scf
 from jax._src.lib.mlir.dialects import vector
+import math
+import numpy as np

 from . import fragmented_array as fa
 from . import inference_utils
 from . import layouts as layouts_lib
+from . import utils

 # mypy: ignore-errors

@@ -192,22 +195,38 @@ def is_array(v: ir.Value) -> bool:


 for op in [
-    arith.AddIOp, arith.AddFOp,
+    arith.AddIOp,
+    arith.AddFOp,
     arith.AndIOp,
     arith.BitcastOp,
     arith.CmpFOp,
     arith.CmpIOp,
-    arith.ExtFOp, arith.ExtSIOp, arith.ExtUIOp,
+    arith.ExtFOp,
+    arith.ExtSIOp,
+    arith.ExtUIOp,
+    arith.FPToSIOp,
+    arith.FPToUIOp,
     arith.MaximumFOp,
-    arith.MaxUIOp, arith.MaxSIOp,
+    arith.MaxUIOp,
+    arith.MaxSIOp,
     arith.MinimumFOp,
-    arith.MinUIOp, arith.MinSIOp,
-    arith.MulIOp, arith.MulFOp,
+    arith.MinUIOp,
+    arith.MinSIOp,
+    arith.MulIOp,
+    arith.MulFOp,
     arith.OrIOp,
-    arith.FloorDivSIOp, arith.DivUIOp, arith.DivFOp,
-    arith.RemUIOp, arith.RemSIOp, arith.RemFOp,
-    arith.SubIOp, arith.SubFOp,
-    arith.TruncFOp, arith.TruncIOp,
+    arith.FloorDivSIOp,
+    arith.DivUIOp,
+    arith.DivFOp,
+    arith.RemUIOp,
+    arith.RemSIOp,
+    arith.RemFOp,
+    arith.SIToFPOp,
+    arith.UIToFPOp,
+    arith.SubIOp,
+    arith.SubFOp,
+    arith.TruncFOp,
+    arith.TruncIOp,
     arith.XOrIOp,
     vector.LoadOp,
     vector.StoreOp,
@@ -488,11 +507,36 @@ def inference_step(op: ir.Operation):
   # propagated. However, it is possible for some operations to remain
   # unannotated---for example, if there were no annotations on any operation in
   # the module at the start of this function. We annotate all the remaining ops
-  # that should be annotated with a strided fragmented layout.
+  # that should be annotated with a strided fragmented layout, whose vector size
+  # is derived from the narrowest type and vector size used in the program. We
+  # make sure to derive a single vector size in order to avoid relayouts at
+  # lowering time.
+  default_vector_size = math.inf
+
+  def update_default_vector_size(op: ir.OpView):
+    nonlocal default_vector_size
+    for v in list(op.operands) + list(op.results):
+      if ir.VectorType.isinstance(v.type):
+        max_vec_size_for_v = (
+            np.prod(cast(ir.ShapedType, v.type).shape) // fa.WARPGROUP_SIZE
+        )
+        desired_vec_size = 8 // utils.bytewidth(v.type.element_type)
+        default_vector_size = min(
+            default_vector_size, max_vec_size_for_v, desired_vec_size
+        )
+
+  for op in module.body:
+    traverse_op(op, update_default_vector_size)
+
+  if default_vector_size is None:  # Nothing to annotate.
+    return
+
   def to_default_layout(ty: ir.Type) -> ir.Attribute | None:
     if not ir.VectorType.isinstance(ty):
       return None
-    layout = fa.WGStridedFragLayout.from_shaped_type(ty)
+    layout = fa.WGStridedFragLayout(
+        shape=cast(ir.ShapedType, ty).shape, vec_size=default_vector_size
+    )
     return layouts_lib.to_strided_fragmented_layout_attr(layout)

   def set_default_layout(op: ir.OpView):

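In plain numbers, the default-layout fallback above now picks a single vec_size for the whole module: each vector value can use at most prod(shape) // WARPGROUP_SIZE elements per thread, and at most 8 // bytewidth(element_type) elements per access (the constant 8 is taken from the diff), and the module-wide default is the minimum over all vector values. A standalone sketch of that arithmetic (hypothetical helper, not the module's API):

import math
import numpy as np

WARPGROUP_SIZE = 128  # threads per warpgroup, as in fragmented_array.py

def default_vector_size(vectors: list[tuple[tuple[int, ...], int]]) -> float:
  """vectors: (shape, element bytewidth) of every vector value in the module."""
  vec_size = math.inf
  for shape, bytewidth in vectors:
    max_vec_size = np.prod(shape) // WARPGROUP_SIZE  # elements available per thread
    desired_vec_size = 8 // bytewidth                # elements per 8-byte access
    vec_size = min(vec_size, max_vec_size, desired_vec_size)
  return vec_size

# A module mixing a (256, 64) f32 vector and a (256, 64) f16 vector gets
# min(16384 // 128, 8 // 4, 8 // 2) = 2, so both values share one strided
# layout and no relayout is needed at lowering time.
print(default_vector_size([((256, 64), 4), ((256, 64), 2)]))  # -> 2
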
tests/mosaic/gpu_dialect_test.py

+38

@@ -18,6 +18,7 @@

 from absl.testing import parameterized
 import jax
+from jax import numpy as jnp
 from jax._src import config
 from jax._src import test_util as jtu
 from jax._src.interpreters import mlir as mlir_interpreter
@@ -824,6 +825,43 @@ def body():
         )
     )

+  @parameterized.parameters(
+      (arith.ExtFOp, jnp.bfloat16, jnp.float32),
+      (arith.ExtSIOp, jnp.int16, jnp.int32),
+      (arith.ExtUIOp, jnp.int16, jnp.uint32),
+      (arith.FPToSIOp, jnp.float32, jnp.int32),
+      (arith.FPToUIOp, jnp.float32, jnp.uint32),
+      (arith.SIToFPOp, jnp.int16, jnp.float32),
+      (arith.TruncFOp, jnp.float32, jnp.float16),
+      (arith.TruncIOp, jnp.int32, jnp.int16),
+      (arith.UIToFPOp, jnp.uint32, jnp.float32),
+  )
+  def test_lower_conversion_op_lowers_to_same_op(self, op, in_dtype, out_dtype):
+    shape = (4, 32)
+
+    with ir.InsertionPoint(self.module.body):
+      scalar_in_ty = mgpu_utils.dtype_to_ir_type(in_dtype)
+      scalar_out_ty = mgpu_utils.dtype_to_ir_type(out_dtype)
+      in_ty = ir.VectorType.get(shape, scalar_in_ty)
+      out_ty = ir.VectorType.get(shape, scalar_out_ty)
+      if ir.IntegerType.isinstance(scalar_in_ty):
+        zero = ir.IntegerAttr.get(scalar_in_ty, 0)
+      else:
+        zero = ir.FloatAttr.get(scalar_in_ty, 0)
+      splat_zero = arith.ConstantOp(
+          in_ty, ir.DenseElementsAttr.get_splat(in_ty, zero)
+      )
+      op(out_ty, splat_zero)
+
+    mgpu.infer_layout(self.module)
+    mgpu.lower_mgpu_dialect(self.module, None)
+
+    conversion_ops = find_if(self.module, lambda o: isinstance(o, op))
+    # This is a splat, so we expect a single conversion op involving a scalar
+    # after lowering.
+    self.assertLen(conversion_ops, 1)
+    self.assertEqual(conversion_ops[0].result.type, scalar_out_ty)
+

 if __name__ == "__main__":
   parameterized.absltest.main(testLoader=jtu.JaxTestLoader())
