Skip to content

Commit

Permalink
Clean up
Browse files (browse the repository at this point in the history)
  • Loading branch information
kaixih committed Feb 26, 2024
1 parent 54fa392 commit dd004c2
Showing 1 changed file with 20 additions and 17 deletions.
37 changes: 20 additions & 17 deletions flax/linen/fp8_ops.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -131,10 +131,12 @@ def compute_amax_history(x, amax_history):
return new_history


def qdq_and_return(x, q_dtype, sf_fm32, ah_fm32, compute_dtype):
def qdq_and_return(x, q_dtype, scale, amax_history, compute_dtype):
is_fm32 = scale.dtype == fm32 and amax_history.dtype == fm32
# convert fm32->f32 so we can do math
amax_history = lax.convert_element_type(ah_fm32, jnp.float32)
scale = lax.convert_element_type(sf_fm32, jnp.float32)
if is_fm32:
amax_history = lax.convert_element_type(amax_history, jnp.float32)
scale = lax.convert_element_type(scale, jnp.float32)

dtype_max = get_fp8_max(q_dtype, jnp.float32)
amax_from_history = jnp.max(amax_history, axis=0)
Expand All @@ -145,9 +147,10 @@ def qdq_and_return(x, q_dtype, sf_fm32, ah_fm32, compute_dtype):
new_history = compute_amax_history(x, amax_history)

# convert f32->fm32 so the autodiff system accumulates fp8 meta correctly
new_ah_fm32 = lax.convert_element_type(new_history, fm32)
new_sf_fm32 = lax.convert_element_type(new_scale, fm32)
return qx, new_sf_fm32, new_ah_fm32
if is_fm32:
new_history = lax.convert_element_type(new_history, fm32)
new_scale = lax.convert_element_type(new_scale, fm32)
return qx, new_scale, new_history


@partial(custom_vjp, nondiff_argnums=(0,))
Expand Down Expand Up @@ -275,18 +278,18 @@ def __call__(self, *args, **kwargs):
comp_dtype = k.dtype
x = jnp.asarray(x, comp_dtype)

x_sf_fm32 = lax.convert_element_type(self.input_scale.value, fm32)
x_ah_fm32 = lax.convert_element_type(self.input_amax_history.value, fm32)
k_sf_fm32 = lax.convert_element_type(self.kernel_scale.value, fm32)
k_ah_fm32 = lax.convert_element_type(self.kernel_amax_history.value, fm32)
g_sf_fm32 = lax.convert_element_type(self.output_grad_scale.value, fm32)
g_ah_fm32 = lax.convert_element_type(
self.output_grad_amax_history.value, fm32
x_qdq = in_qdq(
comp_dtype, x, self.input_scale.value, self.input_amax_history.value
)
k_qdq = in_qdq(
comp_dtype, k, self.kernel_scale.value, self.kernel_amax_history.value
)

x_qdq = in_qdq(comp_dtype, x, x_sf_fm32, x_ah_fm32)
k_qdq = in_qdq(comp_dtype, k, k_sf_fm32, k_ah_fm32)
y_qdq = dot_general_with_precision(x_qdq, k_qdq, dimension_numbers) # type: ignore
y = out_qdq(comp_dtype, y_qdq, g_sf_fm32, g_ah_fm32)
y = out_qdq(
comp_dtype,
y_qdq,
self.output_grad_scale.value,
self.output_grad_amax_history.value,
)

return y # type: ignore

0 comments on commit dd004c2

Please sign in to comment.