Skip to content

Commit

Permalink
Plumb dot dimension numbers into TPU matmul op.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 689442832
  • Loading branch information
Google-ML-Automation committed Nov 6, 2024
1 parent dc33a28 commit 749c5e4
Showing 1 changed file with 85 additions and 17 deletions.
102 changes: 85 additions & 17 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1564,6 +1564,80 @@ def _proxy_fun(val, *, shape, broadcast_dimensions):
lowering_rules[lax.broadcast_in_dim_p] = _broadcast_in_dim_lowering_rule


def jax_dot_dims_to_tpu_dot_dot_dims(dimension_numbers, lhs_shape, rhs_shape):
  """Converts JAX dot dimension numbers to TPU dot dimension numbers.

  JAX dot dimension numbers are given as a tuple of tuples of sequences of ints
  of the form ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims,
  rhs_batch_dims)).

  TPU dot dimension numbers are given as an MLIR attribute of the form
  #tpu.dot_dimension_numbers<lhs_contracting_dims, rhs_contracting_dims,
  lhs_non_contracting_dims, rhs_non_contracting_dims, output_dim_order,
  lhs_batch_dims, rhs_batch_dims>.

  The contracting and batch dims are passed through unchanged. The
  non-contracting dims are derived as every remaining dimension of each
  operand, in ascending order. `output_dim_order` is a flattened list of
  (operand, dim) pairs, where operand is 0 for the lhs and 1 for the rhs, laid
  out as: batch dims (taken from the lhs), then lhs non-contracting dims, then
  rhs non-contracting dims.

  Args:
    dimension_numbers: ((lhs_contracting_dims, rhs_contracting_dims),
      (lhs_batch_dims, rhs_batch_dims)).
    lhs_shape: Shape of the lhs operand; only its rank is used.
    rhs_shape: Shape of the rhs operand; only its rank is used.

  Returns:
    The attribute produced by parsing the `#tpu.dot_dimension_numbers<...>`
    text with `ir.Attribute.parse`.
  """
  (contracting_dims, batch_dims) = dimension_numbers
  lhs_contracting_dims, rhs_contracting_dims = contracting_dims
  lhs_batch_dims, rhs_batch_dims = batch_dims

  # Whatever is neither contracted nor batched survives into the output.
  lhs_non_contracting_dims = sorted(
      set(range(len(lhs_shape)))
      - set(lhs_contracting_dims)
      - set(lhs_batch_dims)
  )
  rhs_non_contracting_dims = sorted(
      set(range(len(rhs_shape)))
      - set(rhs_contracting_dims)
      - set(rhs_batch_dims)
  )

  # Note: we assume that the output dimensions are ordered as batch dims,
  # lhs_non_contracting_dims, rhs_non_contracting_dims.
  # Every output dimension is described by an (operand, dim) pair. Batch dims
  # must carry the operand tag as well; otherwise the flattened list is
  # misaligned with the pairs emitted for the non-contracting dims.
  output_pairs = (
      [(0, dim) for dim in lhs_batch_dims]
      + [(0, dim) for dim in lhs_non_contracting_dims]
      + [(1, dim) for dim in rhs_non_contracting_dims]
  )
  output_dim_order = [entry for pair in output_pairs for entry in pair]

  def format_dims(dims):
    return "[" + ", ".join(str(d) for d in dims) + "]"

  # Field order must match the #tpu.dot_dimension_numbers parameter order.
  fields = (
      lhs_contracting_dims,
      rhs_contracting_dims,
      lhs_non_contracting_dims,
      rhs_non_contracting_dims,
      output_dim_order,
      lhs_batch_dims,
      rhs_batch_dims,
  )
  tpu_dim_numbers_str = (
      "#tpu.dot_dimension_numbers<"
      + ", ".join(format_dims(field) for field in fields)
      + ">"
  )

  return ir.Attribute.parse(tpu_dim_numbers_str)


def _dot_general_lowering_rule(
ctx: LoweringRuleContext, x, y, dimension_numbers, precision, **_
):
Expand All @@ -1589,7 +1663,7 @@ def _dot_general_lowering_rule(
raise NotImplementedError(
f"Only 2D tensors supported in dot; received: {ctx.avals_in}"
)
lhs_aval, _ = ctx.avals_in
lhs_aval, rhs_aval = ctx.avals_in
# This is really a matrix-vector product. It only looks like matrix-matrix.
if lhs_dims == (1,) and rhs_dims == (1,) and ctx.avals_in[1].shape[0] == 1:
if ctx.avals_in[0].shape != ctx.avals_in[1].shape:
Expand All @@ -1615,19 +1689,10 @@ def _dot_general_lowering_rule(
)
return vector.shape_cast(out_type, red)

# TODO(mvoz): Plumb these into dot dimension numbers on the matmul op!
if lhs_dims == (1,):
transpose_lhs = False
elif lhs_dims == (0,):
transpose_lhs = True
else:
raise NotImplementedError
if rhs_dims == (0,):
transpose_rhs = False
elif rhs_dims == (1,):
transpose_rhs = True
else:
raise NotImplementedError
tpu_dot_dims = jax_dot_dims_to_tpu_dot_dot_dims(
dimension_numbers, lhs_aval.shape, rhs_aval.shape
)

if precision is not None:
if precision[0] != precision[1]:
raise NotImplementedError("Per-operand dot precision unsupported")
Expand All @@ -1644,9 +1709,12 @@ def _dot_general_lowering_rule(
out_type, ir.DenseElementsAttr.get_splat(out_type, val)
)
return tpu.matmul(
out_type, x, y, out_tile,
transpose_lhs=transpose_lhs, transpose_rhs=transpose_rhs,
precision=precision_attr
out_type,
x,
y,
out_tile,
dimension_numbers=tpu_dot_dims,
precision=precision_attr,
)


Expand Down

0 comments on commit 749c5e4

Please sign in to comment.