diff --git a/flax/linen/fp8_ops.py b/flax/linen/fp8_ops.py index 3fb667b8d..490647186 100644 --- a/flax/linen/fp8_ops.py +++ b/flax/linen/fp8_ops.py @@ -164,9 +164,10 @@ def compute_amax_history(x, amax_history): def qdq_and_return(x, q_dtype, scale, amax_history, compute_dtype): - is_fm32 = scale.dtype == fm32 and amax_history.dtype == fm32 - # convert fm32->f32 so we can do math - if is_fm32: + is_fmax32 = (scale.dtype == fp32_max_grad and + amax_history.dtype == fp32_max_grad) + # convert fmax32->f32 so we can do math + if is_fmax32: amax_history = lax.convert_element_type(amax_history, jnp.float32) scale = lax.convert_element_type(scale, jnp.float32) @@ -178,10 +179,10 @@ def qdq_and_return(x, q_dtype, scale, amax_history, compute_dtype): new_history = compute_amax_history(x, amax_history) - # convert f32->fm32 so the autodiff system accumulates fp8 meta correctly - if is_fm32: - new_history = lax.convert_element_type(new_history, fm32) - new_scale = lax.convert_element_type(new_scale, fm32) + # convert f32->fmax32 so the autodiff system accumulates fp8 meta correctly + if is_fmax32: + new_history = lax.convert_element_type(new_history, fp32_max_grad) + new_scale = lax.convert_element_type(new_scale, fp32_max_grad) return qx, new_scale, new_history