Skip to content

Commit

Permalink
[Mosaic TPU] Add support for true divide in bf16 on TPUv6
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707877453
  • Loading branch information
apaszke authored and Google-ML-Automation committed Dec 19, 2024
1 parent ad00ec1 commit 761ae91
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
6 changes: 2 additions & 4 deletions jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -292,10 +292,8 @@ LogicalResult canonicalize_elementwise(int hardware_generation_,
return failure();
}
auto element_type = ty.getElementType();
// PowFOp and DivFOp do not seem to be supported in bf16 on later
// hardware.
bool needs_cast = hardware_generation_ <= 5 || isa<math::PowFOp>(op) ||
isa<arith::DivFOp>(op);
// PowFOp does not seem to be supported in bf16 on later hardware.
bool needs_cast = hardware_generation_ <= 5 || isa<math::PowFOp>(op);
if (needs_cast && element_type.isBF16()) {
auto target_f32 =
builder.create<arith::ExtFOp>(op.getLoc(), target_f32_ty, operand)
Expand Down
13 changes: 11 additions & 2 deletions tests/pallas/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1106,18 +1106,27 @@ def kernel(x_ref, y_ref, out_ref):
@parameterized.parameters(
    ("int32", "float32"),
    ("float32", "float32"),
    ("bfloat16", "bfloat16"),
)
def test_true_divide(self, dtype, out_dtype):
    """Checks jnp.true_divide inside a Pallas kernel against the jnp reference.

    Args:
      dtype: input element type for both operands.
      out_dtype: element type of the kernel's output buffer.
    """
    if jtu.test_device_matches(["tpu"]):
        # bf16 true-divide is only lowered natively on TPUv6+ (see the
        # corresponding canonicalize_mosaic.cc change in this commit).
        if out_dtype == "bfloat16" and not jtu.is_device_tpu_at_least(6):
            self.skipTest("bfloat16 is not supported on older TPU generations")
        if not jtu.if_cloud_tpu_at_least(2024, 12, 21):
            self.skipTest("Requires libtpu built after 2024-12-21")

    @functools.partial(
        self.pallas_call,
        out_shape=jax.ShapeDtypeStruct((8, 8), out_dtype),
    )
    def kernel(x_ref, y_ref, o_ref):
        o_ref[...] = jnp.true_divide(x_ref[...], y_ref[...])

    x = jnp.array([1, 3, -4, -6, 2, 5, 4, -7]).astype(dtype)
    y = jnp.array([3, 1, -4, -5, 2, -2, 2, 4]).astype(dtype)
    # Broadcast the 1-D sample values to the (8, 8) kernel shape: each row of
    # x is a constant, each row of y repeats the divisor vector.
    x = jnp.repeat(x, 8, axis=0).reshape(8, 8)
    y = jnp.tile(y, 8).reshape(8, 8)
    # bf16 has ~8 bits of mantissa, so a much looser tolerance is needed.
    rtol = 8e-3 if dtype == "bfloat16" else 1e-6
    np.testing.assert_allclose(jnp.true_divide(x, y), kernel(x, y), rtol=rtol)

@parameterized.parameters("float16", "bfloat16")
def test_true_divide_unsupported(self, dtype):
Expand Down

0 comments on commit 761ae91

Please sign in to comment.