|
23 | 23 | from jax._src.lib import mosaic_gpu_dialect as mgpu
|
24 | 24 | from jax._src.lib.mlir import ir
|
25 | 25 | from jax._src.lib.mlir.dialects import arith
|
26 |
| -from jax._src.lib.mlir.dialects import scf |
27 | 26 | from jax._src.lib.mlir.dialects import memref
|
| 27 | +from jax._src.lib.mlir.dialects import scf |
28 | 28 | from jax._src.lib.mlir.dialects import vector
|
| 29 | +import math |
| 30 | +import numpy as np |
29 | 31 |
|
30 | 32 | from . import fragmented_array as fa
|
31 | 33 | from . import inference_utils
|
32 | 34 | from . import layouts as layouts_lib
|
| 35 | +from . import utils |
33 | 36 |
|
34 | 37 | # mypy: ignore-errors
|
35 | 38 |
|
@@ -192,22 +195,38 @@ def is_array(v: ir.Value) -> bool:
|
192 | 195 |
|
193 | 196 |
|
194 | 197 | for op in [
|
195 |
| - arith.AddIOp, arith.AddFOp, |
| 198 | + arith.AddIOp, |
| 199 | + arith.AddFOp, |
196 | 200 | arith.AndIOp,
|
197 | 201 | arith.BitcastOp,
|
198 | 202 | arith.CmpFOp,
|
199 | 203 | arith.CmpIOp,
|
200 |
| - arith.ExtFOp, arith.ExtSIOp, arith.ExtUIOp, |
| 204 | + arith.ExtFOp, |
| 205 | + arith.ExtSIOp, |
| 206 | + arith.ExtUIOp, |
| 207 | + arith.FPToSIOp, |
| 208 | + arith.FPToUIOp, |
201 | 209 | arith.MaximumFOp,
|
202 |
| - arith.MaxUIOp, arith.MaxSIOp, |
| 210 | + arith.MaxUIOp, |
| 211 | + arith.MaxSIOp, |
203 | 212 | arith.MinimumFOp,
|
204 |
| - arith.MinUIOp, arith.MinSIOp, |
205 |
| - arith.MulIOp, arith.MulFOp, |
| 213 | + arith.MinUIOp, |
| 214 | + arith.MinSIOp, |
| 215 | + arith.MulIOp, |
| 216 | + arith.MulFOp, |
206 | 217 | arith.OrIOp,
|
207 |
| - arith.FloorDivSIOp, arith.DivUIOp, arith.DivFOp, |
208 |
| - arith.RemUIOp, arith.RemSIOp, arith.RemFOp, |
209 |
| - arith.SubIOp, arith.SubFOp, |
210 |
| - arith.TruncFOp, arith.TruncIOp, |
| 218 | + arith.FloorDivSIOp, |
| 219 | + arith.DivUIOp, |
| 220 | + arith.DivFOp, |
| 221 | + arith.RemUIOp, |
| 222 | + arith.RemSIOp, |
| 223 | + arith.RemFOp, |
| 224 | + arith.SIToFPOp, |
| 225 | + arith.UIToFPOp, |
| 226 | + arith.SubIOp, |
| 227 | + arith.SubFOp, |
| 228 | + arith.TruncFOp, |
| 229 | + arith.TruncIOp, |
211 | 230 | arith.XOrIOp,
|
212 | 231 | vector.LoadOp,
|
213 | 232 | vector.StoreOp,
|
@@ -488,11 +507,36 @@ def inference_step(op: ir.Operation):
|
488 | 507 | # propagated. However, it is possible for some operations to remain
|
489 | 508 | # unannotated---for example, if there were no annotations on any operation in
|
490 | 509 | # the module at the start of this function. We annotate all the remaining ops
|
491 |
| - # that should be annotated with a strided fragmented layout. |
| 510 | + # that should be annotated with a strided fragmented layout, whose vector size |
| 511 | + # is derived from the narrowest type and vector size used in the program. We |
| 512 | + # make sure to derive a single vector size in order to avoid relayouts at |
| 513 | + # lowering time. |
| 514 | + default_vector_size = math.inf |
| 515 | + |
| 516 | + def update_default_vector_size(op: ir.OpView): |
| 517 | + nonlocal default_vector_size |
| 518 | + for v in list(op.operands) + list(op.results): |
| 519 | + if ir.VectorType.isinstance(v.type): |
| 520 | + max_vec_size_for_v = ( |
| 521 | + np.prod(cast(ir.ShapedType, v.type).shape) // fa.WARPGROUP_SIZE |
| 522 | + ) |
| 523 | + desired_vec_size = 8 // utils.bytewidth(v.type.element_type) |
| 524 | + default_vector_size = min( |
| 525 | + default_vector_size, max_vec_size_for_v, desired_vec_size |
| 526 | + ) |
| 527 | + |
| 528 | + for op in module.body: |
| 529 | + traverse_op(op, update_default_vector_size) |
| 530 | + |
| 531 | + if default_vector_size is None: # Nothing to annotate. |
| 532 | + return |
| 533 | + |
492 | 534 | def to_default_layout(ty: ir.Type) -> ir.Attribute | None:
|
493 | 535 | if not ir.VectorType.isinstance(ty):
|
494 | 536 | return None
|
495 |
| - layout = fa.WGStridedFragLayout.from_shaped_type(ty) |
| 537 | + layout = fa.WGStridedFragLayout( |
| 538 | + shape=cast(ir.ShapedType, ty).shape, vec_size=default_vector_size |
| 539 | + ) |
496 | 540 | return layouts_lib.to_strided_fragmented_layout_attr(layout)
|
497 | 541 |
|
498 | 542 | def set_default_layout(op: ir.OpView):
|
|
0 commit comments