Skip to content

Commit

Permalink
Merge pull request #3440 from chiamp:fix_fp8
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 576953704
  • Loading branch information
Flax Authors committed Oct 26, 2023
2 parents 738078c + e7e33e1 commit 9064108
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions flax/linen/fp8_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,13 @@ def __call__(self, *args, **kwargs) -> jnp.ndarray:
k_qdq = in_qdq(
comp_dtype, k, self.kernel_scale.value, self.kernel_amax_history.value
)
y_qdq = lax.dot_general(x_qdq, k_qdq, dimension_numbers, precision)
y_qdq = lax.dot_general(x_qdq, k_qdq, dimension_numbers, precision) # type: ignore
y = out_qdq(
comp_dtype,
y_qdq,
self.output_grad_scale.value,
self.output_grad_amax_history.value
)

return y
return y # type: ignore

0 comments on commit 9064108

Please sign in to comment.