Added support of NANOO fp8. #3993

Merged · 4 commits · Jul 2, 2024
Changes from all commits
1 change: 1 addition & 0 deletions flax/linen/__init__.py
@@ -87,6 +87,7 @@
from .batch_apply import BatchApply as BatchApply
from .combinators import Sequential as Sequential
from .fp8_ops import Fp8DotGeneralOp as Fp8DotGeneralOp
from .fp8_ops import NANOOFp8DotGeneralOp as NANOOFp8DotGeneralOp
from .initializers import (
ones_init as ones_init,
ones as ones,
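The new export makes the NANOO variant selectable anywhere the OCP op already was. A minimal usage sketch (the layer size, input shape, and PRNG keys below are illustrative assumptions, not taken from the PR):

import jax
from jax import numpy as jnp, random
from flax import linen as nn

# Select the NANOO (fnuz) fp8 dot_general instead of the default OCP one by
# passing it as dot_general_cls, the same way the tests in this PR do for
# nn.Fp8DotGeneralOp.
dense = nn.DenseGeneral(features=64, dot_general_cls=nn.NANOOFp8DotGeneralOp)
x = random.uniform(random.PRNGKey(0), (16, 32), dtype=jnp.float32)
variables = jax.jit(dense.init)(random.PRNGKey(1), x)
y = jax.jit(dense.apply)(variables, x)  # shape (16, 64)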
40 changes: 25 additions & 15 deletions flax/linen/fp8_ops.py
@@ -17,6 +17,9 @@
import warnings
from functools import partial

from typing import Any
DType = Any

import jax
from jax import custom_jvp, custom_vjp, lax, random
from jax import numpy as jnp
@@ -117,10 +120,10 @@ def __repr__(self) -> str:
fm32 = fp8_meta_dtype_wrapper(jnp.float32)

def get_fp8_max(fp8_dtype, out_dtype):
assert fp8_dtype in (jnp.float8_e4m3fn, jnp.float8_e5m2)
assert fp8_dtype in (jnp.float8_e4m3fn, jnp.float8_e5m2,
jnp.float8_e4m3fnuz, jnp.float8_e5m2fnuz)
return jnp.finfo(fp8_dtype).max.astype(out_dtype)


def quantize(x, q_dtype, scale, compute_dtype):
# Explicitly cast the max values to the compute dtype to avoid unnecessary
# casting to FP32 during the subsequent math operations."
@@ -181,22 +184,22 @@ def qdq_and_return(x, q_dtype, scale, amax_history, compute_dtype):
return qx, new_scale, new_history


@partial(custom_vjp, nondiff_argnums=(0,))
def in_qdq(compute_dtype, inp, scale, amax_history):
@partial(custom_vjp, nondiff_argnums=(0, 1))
def in_qdq(compute_dtype, q_dtype, inp, scale, amax_history):
qin, _, _ = qdq_and_return(
inp, jnp.float8_e4m3fn, scale, amax_history, compute_dtype
inp, q_dtype, scale, amax_history, compute_dtype
)
return qin


def in_qdq_fwd(compute_dtype, inp, scale, amax_history):
def in_qdq_fwd(compute_dtype, q_dtype, inp, scale, amax_history):
qin, new_scale, new_history = qdq_and_return(
inp, jnp.float8_e4m3fn, scale, amax_history, compute_dtype
inp, q_dtype, scale, amax_history, compute_dtype
)
return qin, (new_scale, new_history)


def in_qdq_bwd(compute_dtype, res, g):
def in_qdq_bwd(compute_dtype, q_dtype, res, g):
new_scale, new_history = res
q_g = g
return q_g, new_scale, new_history
@@ -205,19 +208,19 @@ def in_qdq_bwd(compute_dtype, res, g):
in_qdq.defvjp(in_qdq_fwd, in_qdq_bwd)


@partial(custom_vjp, nondiff_argnums=(0,))
def out_qdq(compute_dtype, out, scale, amax_history):
@partial(custom_vjp, nondiff_argnums=(0, 1))
def out_qdq(compute_dtype, q_dtype, out, scale, amax_history):
return out


def out_qdq_fwd(compute_dtype, out, scale, amax_history):
def out_qdq_fwd(compute_dtype, q_dtype, out, scale, amax_history):
return out, (scale, amax_history)


def out_qdq_bwd(compute_dtype, res, g):
def out_qdq_bwd(compute_dtype, q_dtype, res, g):
scale, amax_history = res
q_g, new_scale, new_history = qdq_and_return(
g, jnp.float8_e5m2, scale, amax_history, compute_dtype
g, q_dtype, scale, amax_history, compute_dtype
)
return q_g, new_scale, new_history

@@ -260,6 +263,8 @@ def dot_general_with_precision_jvp(

class Fp8DotGeneralOp(module.Module):
amax_history_length: int = 1024
e4m3_dtype: DType = jnp.float8_e4m3fn
e5m2_dtype: DType = jnp.float8_e5m2

def setup(self) -> None:
scale_args = (
@@ -307,17 +312,22 @@ def __call__(self, *args, **kwargs):
x = jnp.asarray(x, comp_dtype)

x_qdq = in_qdq(
comp_dtype, x, self.input_scale.value, self.input_amax_history.value
comp_dtype, self.e4m3_dtype, x, self.input_scale.value, self.input_amax_history.value
)
k_qdq = in_qdq(
comp_dtype, k, self.kernel_scale.value, self.kernel_amax_history.value
comp_dtype, self.e4m3_dtype, k, self.kernel_scale.value, self.kernel_amax_history.value
)
y_qdq = dot_general_with_precision(x_qdq, k_qdq, dimension_numbers) # type: ignore
y = out_qdq(
comp_dtype,
self.e5m2_dtype,
y_qdq,
self.output_grad_scale.value,
self.output_grad_amax_history.value,
)

return y # type: ignore

class NANOOFp8DotGeneralOp(Fp8DotGeneralOp):
e4m3_dtype: DType = jnp.float8_e4m3fnuz
e5m2_dtype: DType = jnp.float8_e5m2fnuz
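
The functional difference between the two genres is the set of fp8 dtypes and, in particular, their representable maxima, which the expected scale factors in the tests below are computed from (2/448 for OCP vs. 2/240 for NANOO e4m3). A quick sketch to inspect the maxima that get_fp8_max now has to cover, using only jnp.finfo:

from jax import numpy as jnp

# Representable maxima of the four fp8 dtypes accepted by get_fp8_max.
for dt in (jnp.float8_e4m3fn, jnp.float8_e5m2,          # OCP
           jnp.float8_e4m3fnuz, jnp.float8_e5m2fnuz):   # NANOO
  print(jnp.dtype(dt).name, float(jnp.finfo(dt).max))
# float8_e4m3fn 448.0
# float8_e5m2 57344.0
# float8_e4m3fnuz 240.0
# float8_e5m2fnuz 57344.0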
69 changes: 53 additions & 16 deletions tests/linen/linen_test.py
@@ -1243,9 +1243,21 @@ def test_hashable(self):
self.assertNotEqual(hash(id1), hash(id1c))
self.assertNotEqual(hash(id1), hash(id1dc))

def get_fp8_dtypes(fp8_genre):
assert fp8_genre in ('OCP', 'NANOO')
if fp8_genre == 'OCP':
e4m3_dtype = jnp.float8_e4m3fn
e5m2_dtype = jnp.float8_e5m2
else: # fp8_genre == 'NANOO'
e4m3_dtype = jnp.float8_e4m3fnuz
e5m2_dtype = jnp.float8_e5m2fnuz
return e4m3_dtype, e5m2_dtype

class Fp8Test(parameterized.TestCase):
def test_fp8_dot_general_injection(self):
@parameterized.parameters(
{'fp8_genre': 'OCP'}, {'fp8_genre': 'NANOO'}
)
def test_fp8_dot_general_injection(self, fp8_genre):
# Used to cast the inputs to be representable in FP8, so that the difference
# of the results from the original gemm and fp8 gemm is small.
cast_to_representable = functools.partial(
@@ -1254,18 +1266,23 @@ def test_fp8_dot_general_injection(self):
compute_dtype=jnp.float32,
)

e4m3_dtype, e5m2_dtype = get_fp8_dtypes(fp8_genre)

init_key, random_key = random.split(random.PRNGKey(seed=123), 2)
x = cast_to_representable(
random.uniform(random_key, (16, 32)), jnp.float8_e4m3fn
random.uniform(random_key, (16, 32)), e4m3_dtype
)
dy = cast_to_representable(
random.uniform(random_key, (16, 64)), jnp.float8_e5m2
random.uniform(random_key, (16, 64)), e5m2_dtype
)

def run(fp8_injection, expected_shapes):
p = nn.DenseGeneral(features=64, name='dense')
if fp8_injection:
p.dot_general_cls = nn.Fp8DotGeneralOp
if fp8_genre == 'OCP':
p.dot_general_cls = nn.Fp8DotGeneralOp
else:
p.dot_general_cls = nn.NANOOFp8DotGeneralOp

init_fn = jax.jit(p.init_with_output)
y, initial_vars = init_fn(init_key, x)
@@ -1284,10 +1301,14 @@ def _train(variables, x):
expected_shapes_original = {
'params': {'kernel': (32, 64), 'bias': (64,)},
}
if fp8_genre == 'OCP':
fp8_op_name = 'Fp8DotGeneralOp_0'
else:
fp8_op_name = 'NANOOFp8DotGeneralOp_0'
expected_shapes_new = {
'params': {'kernel': (32, 64), 'bias': (64,)},
fp8_ops.OVERWRITE_WITH_GRADIENT: {
'Fp8DotGeneralOp_0': {
fp8_op_name: {
'input_amax_history': (1024,),
'kernel_amax_history': (1024,),
'output_grad_amax_history': (1024,),
@@ -1307,11 +1328,20 @@ def _train(variables, x):
np.testing.assert_allclose(dw1, dw2, atol=1e-04)
np.testing.assert_allclose(dx1, dx2, atol=1e-04)

def test_fp8_train_state(self):
@parameterized.parameters(
{'fp8_genre': 'OCP'}, {'fp8_genre': 'NANOO'}
)
def test_fp8_train_state(self, fp8_genre):
key, init_key, random_key = random.split(random.PRNGKey(seed=123), 3)
x = random.uniform(random_key, (16, 16), dtype=jnp.float32)
if fp8_genre == 'OCP':
fp8_dot_op = nn.Fp8DotGeneralOp
fp8_op_name = 'Fp8DotGeneralOp_0'
else:
fp8_dot_op = nn.NANOOFp8DotGeneralOp
fp8_op_name = 'NANOOFp8DotGeneralOp_0'
dense = nn.DenseGeneral(
features=32, use_bias=True, dot_general_cls=nn.Fp8DotGeneralOp
features=32, use_bias=True, dot_general_cls=fp8_dot_op
)

init_fn = jax.jit(dense.init)
@@ -1340,8 +1370,9 @@ def loss_fn(vars):
scale_x, amax_history_x = jnp.ones(()), jnp.zeros((1024,))
scale_k, amax_history_k = jnp.ones(()), jnp.zeros((1024,))
scale_g, amax_history_g = jnp.ones(()), jnp.zeros((1024,))
e4m3_max = jnp.finfo(jnp.float8_e4m3fn).max.astype(jnp.float32)
e5m2_max = jnp.finfo(jnp.float8_e5m2).max.astype(jnp.float32)
e4m3_dtype, e5m2_dtype = get_fp8_dtypes(fp8_genre)
e4m3_max = jnp.finfo(e4m3_dtype).max.astype(jnp.float32)
e5m2_max = jnp.finfo(e5m2_dtype).max.astype(jnp.float32)

for _ in range(5):
key, random_key = random.split(key, 2)
@@ -1364,7 +1395,7 @@ def loss_fn(vars):

rtol, atol = 0.001, 0.001
fp8_vars = state.params[fp8_ops.OVERWRITE_WITH_GRADIENT][
'Fp8DotGeneralOp_0'
fp8_op_name
]
np.testing.assert_allclose(
fp8_vars['input_amax_history'],
@@ -1389,12 +1420,19 @@ def loss_fn(vars):
np.testing.assert_allclose(fp8_vars['kernel_scale'][0], scale_k)
np.testing.assert_allclose(fp8_vars['output_grad_scale'][0], scale_g)

@parameterized.parameters([True, False])
def test_fp8_meta_dtype(self, use_jit):
@parameterized.parameters(
{'fp8_genre': 'OCP', 'use_jit': True},
{'fp8_genre': 'OCP', 'use_jit': False},
{'fp8_genre': 'NANOO', 'use_jit': True},
{'fp8_genre': 'NANOO', 'use_jit': False}
)
def test_fp8_meta_dtype(self, fp8_genre, use_jit):
if not use_jit and not fp8_ops.CAN_USE_EARRAY:
self.skipTest("TODO: requires newer jax that has earray")
f32 = jnp.dtype('float32')
fm32 = fp8_ops.fm32
e4m3_dtype, _ = get_fp8_dtypes(fp8_genre)
e4m3_max = 448 if fp8_genre == 'OCP' else 240

# Create a scan loop with reused ah_f32 and sf_f32. So, the autograd will
# accumulate the grads of them. We expect the max op (rather than add op)
@@ -1404,7 +1442,7 @@ def outer(x, ah_f32, sf_f32):
sf_fm32 = jax.lax.convert_element_type(sf_f32, fm32)
array_x = jnp.array([x], f32)
def body_fun(carry, _):
carry = fp8_ops.in_qdq(f32, carry, sf_fm32, ah_fm32)
carry = fp8_ops.in_qdq(f32, e4m3_dtype, carry, sf_fm32, ah_fm32)
return carry, None
array_x, _ = jax.lax.scan(body_fun, array_x, None, length=3)
return array_x[0]
@@ -1421,12 +1459,11 @@ def body_fun(carry, _):
# 2nd iteration
grads, new_ah, new_sf = outer_fn(3., new_ah, new_sf)
np.testing.assert_allclose(new_ah, [3., 0., 2.])
np.testing.assert_allclose(new_sf, [2. / 448])
np.testing.assert_allclose(new_sf, [2. / e4m3_max])
# 3rd iteration
grads, new_ah, new_sf = outer_fn(4., new_ah, new_sf)
np.testing.assert_allclose(new_ah, [4., 2., 3.])
np.testing.assert_allclose(new_sf, [3. / 448])

np.testing.assert_allclose(new_sf, [3. / e4m3_max])

if __name__ == '__main__':
absltest.main()
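
As the injection test asserts through expected_shapes_new, the chosen op is also visible in the variable tree: its scales and amax histories live under the fp8_ops.OVERWRITE_WITH_GRADIENT collection, keyed by the op's module name ('Fp8DotGeneralOp_0' vs. 'NANOOFp8DotGeneralOp_0'). A small inspection sketch (the model and input shape are illustrative assumptions):

import jax
from jax import numpy as jnp, random
from flax import linen as nn
from flax.linen import fp8_ops

model = nn.DenseGeneral(features=64, dot_general_cls=nn.NANOOFp8DotGeneralOp)
variables = jax.jit(model.init)(random.PRNGKey(0), jnp.ones((16, 32), jnp.float32))
# Prints a dict keyed by 'NANOOFp8DotGeneralOp_0' with (1024,) amax histories
# and (1,) scales, mirroring expected_shapes_new in the test above.
print(jax.tree_util.tree_map(jnp.shape, variables[fp8_ops.OVERWRITE_WITH_GRADIENT]))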