diff --git a/flax/linen/__init__.py b/flax/linen/__init__.py
index cd29659c1c..f01ed92880 100644
--- a/flax/linen/__init__.py
+++ b/flax/linen/__init__.py
@@ -87,6 +87,7 @@
 from .batch_apply import BatchApply as BatchApply
 from .combinators import Sequential as Sequential
 from .fp8_ops import Fp8DotGeneralOp as Fp8DotGeneralOp
+from .fp8_ops import NANOOFp8DotGeneralOp as NANOOFp8DotGeneralOp
 from .initializers import (
   ones_init as ones_init,
   ones as ones,
diff --git a/flax/linen/fp8_ops.py b/flax/linen/fp8_ops.py
index 1cfc248a7b..6fd94e9b87 100644
--- a/flax/linen/fp8_ops.py
+++ b/flax/linen/fp8_ops.py
@@ -17,6 +17,9 @@
 import warnings
 from functools import partial
+from typing import Any
+DType = Any
+
 
 import jax
 from jax import custom_jvp, custom_vjp, lax, random
 from jax import numpy as jnp
@@ -117,10 +120,10 @@ def __repr__(self) -> str:
 fm32 = fp8_meta_dtype_wrapper(jnp.float32)
 
 def get_fp8_max(fp8_dtype, out_dtype):
-  assert fp8_dtype in (jnp.float8_e4m3fn, jnp.float8_e5m2)
+  assert fp8_dtype in (jnp.float8_e4m3fn, jnp.float8_e5m2,
+                       jnp.float8_e4m3fnuz, jnp.float8_e5m2fnuz)
   return jnp.finfo(fp8_dtype).max.astype(out_dtype)
 
-
 def quantize(x, q_dtype, scale, compute_dtype):
   # Explicitly cast the max values to the compute dtype to avoid unnecessary
   # casting to FP32 during the subsequent math operations."
@@ -181,22 +184,22 @@ def qdq_and_return(x, q_dtype, scale, amax_history, compute_dtype):
   return qx, new_scale, new_history
 
 
-@partial(custom_vjp, nondiff_argnums=(0,))
-def in_qdq(compute_dtype, inp, scale, amax_history):
+@partial(custom_vjp, nondiff_argnums=(0, 1))
+def in_qdq(compute_dtype, q_dtype, inp, scale, amax_history):
   qin, _, _ = qdq_and_return(
-    inp, jnp.float8_e4m3fn, scale, amax_history, compute_dtype
+    inp, q_dtype, scale, amax_history, compute_dtype
   )
   return qin
 
 
-def in_qdq_fwd(compute_dtype, inp, scale, amax_history):
+def in_qdq_fwd(compute_dtype, q_dtype, inp, scale, amax_history):
   qin, new_scale, new_history = qdq_and_return(
-    inp, jnp.float8_e4m3fn, scale, amax_history, compute_dtype
+    inp, q_dtype, scale, amax_history, compute_dtype
   )
   return qin, (new_scale, new_history)
 
 
-def in_qdq_bwd(compute_dtype, res, g):
+def in_qdq_bwd(compute_dtype, q_dtype, res, g):
   new_scale, new_history = res
   q_g = g
   return q_g, new_scale, new_history
@@ -205,19 +208,19 @@ def in_qdq_bwd(compute_dtype, res, g):
 in_qdq.defvjp(in_qdq_fwd, in_qdq_bwd)
 
 
-@partial(custom_vjp, nondiff_argnums=(0,))
-def out_qdq(compute_dtype, out, scale, amax_history):
+@partial(custom_vjp, nondiff_argnums=(0, 1))
+def out_qdq(compute_dtype, q_dtype, out, scale, amax_history):
   return out
 
 
-def out_qdq_fwd(compute_dtype, out, scale, amax_history):
+def out_qdq_fwd(compute_dtype, q_dtype, out, scale, amax_history):
   return out, (scale, amax_history)
 
 
-def out_qdq_bwd(compute_dtype, res, g):
+def out_qdq_bwd(compute_dtype, q_dtype, res, g):
   scale, amax_history = res
   q_g, new_scale, new_history = qdq_and_return(
-    g, jnp.float8_e5m2, scale, amax_history, compute_dtype
+    g, q_dtype, scale, amax_history, compute_dtype
   )
   return q_g, new_scale, new_history
 
@@ -260,6 +263,8 @@ def dot_general_with_precision_jvp(
 
 class Fp8DotGeneralOp(module.Module):
   amax_history_length: int = 1024
+  e4m3_dtype: DType = jnp.float8_e4m3fn
+  e5m2_dtype: DType = jnp.float8_e5m2
 
   def setup(self) -> None:
     scale_args = (
@@ -307,17 +312,22 @@ def __call__(self, *args, **kwargs):
     x = jnp.asarray(x, comp_dtype)
 
     x_qdq = in_qdq(
-      comp_dtype, x, self.input_scale.value, self.input_amax_history.value
+      comp_dtype, self.e4m3_dtype, x, self.input_scale.value, self.input_amax_history.value
     )
     k_qdq = in_qdq(
-      comp_dtype, k, self.kernel_scale.value, self.kernel_amax_history.value
+      comp_dtype, self.e4m3_dtype, k, self.kernel_scale.value, self.kernel_amax_history.value
     )
     y_qdq = dot_general_with_precision(x_qdq, k_qdq, dimension_numbers) # type: ignore
     y = out_qdq(
       comp_dtype,
+      self.e5m2_dtype,
       y_qdq,
       self.output_grad_scale.value,
       self.output_grad_amax_history.value,
     )
 
     return y # type: ignore
+
+
+class NANOOFp8DotGeneralOp(Fp8DotGeneralOp):
+  e4m3_dtype: DType = jnp.float8_e4m3fnuz
+  e5m2_dtype: DType = jnp.float8_e5m2fnuz
diff --git a/tests/linen/linen_test.py b/tests/linen/linen_test.py
index 771ebb8983..f49afcdd66 100644
--- a/tests/linen/linen_test.py
+++ b/tests/linen/linen_test.py
@@ -1243,9 +1243,21 @@ def test_hashable(self):
     self.assertNotEqual(hash(id1), hash(id1c))
     self.assertNotEqual(hash(id1), hash(id1dc))
 
+def get_fp8_dtypes(fp8_genre):
+  assert fp8_genre in ('OCP', 'NANOO')
+  if fp8_genre == 'OCP':
+    e4m3_dtype = jnp.float8_e4m3fn
+    e5m2_dtype = jnp.float8_e5m2
+  else:  # fp8_genre == 'NANOO'
+    e4m3_dtype = jnp.float8_e4m3fnuz
+    e5m2_dtype = jnp.float8_e5m2fnuz
+  return e4m3_dtype, e5m2_dtype
 
 class Fp8Test(parameterized.TestCase):
-  def test_fp8_dot_general_injection(self):
+  @parameterized.parameters(
+    {'fp8_genre': 'OCP'}, {'fp8_genre': 'NANOO'}
+  )
+  def test_fp8_dot_general_injection(self, fp8_genre):
     # Used to cast the inputs to be representable in FP8, so that the difference
     # of the results from the original gemm and fp8 gemm is small.
     cast_to_representable = functools.partial(
@@ -1254,18 +1266,23 @@ def test_fp8_dot_general_injection(self):
       compute_dtype=jnp.float32,
     )
 
+    e4m3_dtype, e5m2_dtype = get_fp8_dtypes(fp8_genre)
+
     init_key, random_key = random.split(random.PRNGKey(seed=123), 2)
     x = cast_to_representable(
-      random.uniform(random_key, (16, 32)), jnp.float8_e4m3fn
+      random.uniform(random_key, (16, 32)), e4m3_dtype
     )
     dy = cast_to_representable(
-      random.uniform(random_key, (16, 64)), jnp.float8_e5m2
+      random.uniform(random_key, (16, 64)), e5m2_dtype
    )
 
     def run(fp8_injection, expected_shapes):
       p = nn.DenseGeneral(features=64, name='dense')
       if fp8_injection:
-        p.dot_general_cls = nn.Fp8DotGeneralOp
+        if fp8_genre == 'OCP':
+          p.dot_general_cls = nn.Fp8DotGeneralOp
+        else:
+          p.dot_general_cls = nn.NANOOFp8DotGeneralOp
 
       init_fn = jax.jit(p.init_with_output)
       y, initial_vars = init_fn(init_key, x)
@@ -1284,10 +1301,14 @@ def _train(variables, x):
     expected_shapes_original = {
       'params': {'kernel': (32, 64), 'bias': (64,)},
     }
+    if fp8_genre == 'OCP':
+      fp8_op_name = 'Fp8DotGeneralOp_0'
+    else:
+      fp8_op_name = 'NANOOFp8DotGeneralOp_0'
     expected_shapes_new = {
       'params': {'kernel': (32, 64), 'bias': (64,)},
       fp8_ops.OVERWRITE_WITH_GRADIENT: {
-        'Fp8DotGeneralOp_0': {
+        fp8_op_name: {
           'input_amax_history': (1024,),
           'kernel_amax_history': (1024,),
           'output_grad_amax_history': (1024,),
@@ -1307,11 +1328,20 @@ def _train(variables, x):
     np.testing.assert_allclose(dw1, dw2, atol=1e-04)
     np.testing.assert_allclose(dx1, dx2, atol=1e-04)
 
-  def test_fp8_train_state(self):
+  @parameterized.parameters(
+    {'fp8_genre': 'OCP'}, {'fp8_genre': 'NANOO'}
+  )
+  def test_fp8_train_state(self, fp8_genre):
     key, init_key, random_key = random.split(random.PRNGKey(seed=123), 3)
     x = random.uniform(random_key, (16, 16), dtype=jnp.float32)
+    if fp8_genre == 'OCP':
+      fp8_dot_op = nn.Fp8DotGeneralOp
+      fp8_op_name = 'Fp8DotGeneralOp_0'
+    else:
+      fp8_dot_op = nn.NANOOFp8DotGeneralOp
+      fp8_op_name = 'NANOOFp8DotGeneralOp_0'
     dense = nn.DenseGeneral(
-      features=32, use_bias=True, dot_general_cls=nn.Fp8DotGeneralOp
+      features=32, use_bias=True, dot_general_cls=fp8_dot_op
     )
 
     init_fn = jax.jit(dense.init)
@@ -1340,8 +1370,9 @@ def loss_fn(vars):
     scale_x, amax_history_x = jnp.ones(()), jnp.zeros((1024,))
     scale_k, amax_history_k = jnp.ones(()), jnp.zeros((1024,))
     scale_g, amax_history_g = jnp.ones(()), jnp.zeros((1024,))
-    e4m3_max = jnp.finfo(jnp.float8_e4m3fn).max.astype(jnp.float32)
-    e5m2_max = jnp.finfo(jnp.float8_e5m2).max.astype(jnp.float32)
+    e4m3_dtype, e5m2_dtype = get_fp8_dtypes(fp8_genre)
+    e4m3_max = jnp.finfo(e4m3_dtype).max.astype(jnp.float32)
+    e5m2_max = jnp.finfo(e5m2_dtype).max.astype(jnp.float32)
 
     for _ in range(5):
       key, random_key = random.split(key, 2)
@@ -1364,7 +1395,7 @@ def loss_fn(vars):
 
     rtol, atol = 0.001, 0.001
     fp8_vars = state.params[fp8_ops.OVERWRITE_WITH_GRADIENT][
-      'Fp8DotGeneralOp_0'
+      fp8_op_name
     ]
     np.testing.assert_allclose(
       fp8_vars['input_amax_history'],
@@ -1389,12 +1420,19 @@ def loss_fn(vars):
     np.testing.assert_allclose(fp8_vars['kernel_scale'][0], scale_k)
     np.testing.assert_allclose(fp8_vars['output_grad_scale'][0], scale_g)
 
-  @parameterized.parameters([True, False])
-  def test_fp8_meta_dtype(self, use_jit):
+  @parameterized.parameters(
+    {'fp8_genre': 'OCP', 'use_jit': True},
+    {'fp8_genre': 'OCP', 'use_jit': False},
+    {'fp8_genre': 'NANOO', 'use_jit': True},
+    {'fp8_genre': 'NANOO', 'use_jit': False}
+  )
+  def test_fp8_meta_dtype(self, fp8_genre, use_jit):
     if not use_jit and not fp8_ops.CAN_USE_EARRAY:
       self.skipTest("TODO: requires newer jax that has earray")
     f32 = jnp.dtype('float32')
     fm32 = fp8_ops.fm32
+    e4m3_dtype, _ = get_fp8_dtypes(fp8_genre)
+    e4m3_max = 448 if fp8_genre == 'OCP' else 240
 
     # Create a scan loop with reused ah_f32 and sf_f32. So, the autograd will
     # accumulate the grads of them. We expect the max op (rather than add op)
@@ -1404,7 +1442,7 @@ def outer(x, ah_f32, sf_f32):
       sf_fm32 = jax.lax.convert_element_type(sf_f32, fm32)
       array_x = jnp.array([x], f32)
       def body_fun(carry, _):
-        carry = fp8_ops.in_qdq(f32, carry, sf_fm32, ah_fm32)
+        carry = fp8_ops.in_qdq(f32, e4m3_dtype, carry, sf_fm32, ah_fm32)
        return carry, None
       array_x, _ = jax.lax.scan(body_fun, array_x, None, length=3)
       return array_x[0]
@@ -1421,12 +1459,11 @@ def body_fun(carry, _):
     # 2nd iteration
     grads, new_ah, new_sf = outer_fn(3., new_ah, new_sf)
    np.testing.assert_allclose(new_ah, [3., 0., 2.])
-    np.testing.assert_allclose(new_sf, [2. / 448])
+    np.testing.assert_allclose(new_sf, [2. / e4m3_max])
 
     # 3rd iteration
     grads, new_ah, new_sf = outer_fn(4., new_ah, new_sf)
     np.testing.assert_allclose(new_ah, [4., 2., 3.])
-    np.testing.assert_allclose(new_sf, [3. / 448])
-
+    np.testing.assert_allclose(new_sf, [3. / e4m3_max])
 if __name__ == '__main__':
   absltest.main()
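Usage sketch (not part of the diff): how a caller selects the new op, mirroring what test_fp8_dot_general_injection exercises. The choice between Fp8DotGeneralOp (OCP formats float8_e4m3fn/float8_e5m2) and NANOOFp8DotGeneralOp (float8_e4m3fnuz/float8_e5m2fnuz) depends on which fp8 flavor the target hardware supports; everything below uses only the public flax.linen API touched by this change.

from jax import numpy as jnp, random
import flax.linen as nn

# Inject the NANOO fp8 dot_general into a DenseGeneral layer; swap in
# nn.Fp8DotGeneralOp to target the OCP fp8 formats instead.
model = nn.DenseGeneral(features=64, dot_general_cls=nn.NANOOFp8DotGeneralOp)

x = random.uniform(random.PRNGKey(0), (16, 32), dtype=jnp.float32)
variables = model.init(random.PRNGKey(1), x)
# 'params' holds kernel/bias; the fp8 scales and amax histories live in the
# fp8_ops.OVERWRITE_WITH_GRADIENT collection under 'NANOOFp8DotGeneralOp_0'.
y = model.apply(variables, x)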