Skip to content

Commit

Permalink
use fmax32
Browse files Browse the repository at this point in the history
  • Loading branch information
kaixih committed Jul 2, 2024
1 parent 16f2866 commit 798cfe7
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions flax/linen/fp8_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,10 @@ def compute_amax_history(x, amax_history):


def qdq_and_return(x, q_dtype, scale, amax_history, compute_dtype):
is_fm32 = scale.dtype == fm32 and amax_history.dtype == fm32
# convert fm32->f32 so we can do math
if is_fm32:
is_fmax32 = (scale.dtype == fp32_max_grad and
amax_history.dtype == fp32_max_grad)
# convert fmax32->f32 so we can do math
if is_fmax32:
amax_history = lax.convert_element_type(amax_history, jnp.float32)
scale = lax.convert_element_type(scale, jnp.float32)

Expand All @@ -178,10 +179,10 @@ def qdq_and_return(x, q_dtype, scale, amax_history, compute_dtype):

new_history = compute_amax_history(x, amax_history)

# convert f32->fm32 so the autodiff system accumulates fp8 meta correctly
if is_fm32:
new_history = lax.convert_element_type(new_history, fm32)
new_scale = lax.convert_element_type(new_scale, fm32)
# convert f32->fmax32 so the autodiff system accumulates fp8 meta correctly
if is_fmax32:
new_history = lax.convert_element_type(new_history, fp32_max_grad)
new_scale = lax.convert_element_type(new_scale, fp32_max_grad)
return qx, new_scale, new_history


Expand Down

0 comments on commit 798cfe7

Please sign in to comment.