# NOTE(review): reconstructed from a whitespace-collapsed patch to
# flax/linen/fp8_ops.py. Imports added by the patch:
import warnings
from functools import partial

from jax import custom_jvp, lax


@partial(custom_jvp, nondiff_argnums=(2, 3, 4))
def dot_general_with_precision(
    lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None
):
  """``lax.dot_general`` with precision pinned for the fp8 fake-quant path.

  The forward product always runs at ``lax.Precision.DEFAULT``; the custom
  JVP below runs the tangent products at ``lax.Precision.HIGHEST``. Any
  caller-supplied ``precision`` / ``preferred_element_type`` is ignored (a
  warning is emitted), which is why both appear in ``nondiff_argnums``.

  Args:
    lhs: left operand array.
    rhs: right operand array.
    dimension_numbers: ``lax.dot_general`` dimension-numbers tuple.
    precision: ignored; accepted only for signature compatibility.
    preferred_element_type: ignored; accepted for signature compatibility.

  Returns:
    The dot-general product of ``lhs`` and ``rhs``.
  """
  if precision is not None or preferred_element_type is not None:
    warnings.warn(
        'The function dot_general_with_precision will set the '
        'precision/preferred_element_type and disregard any provided '
        'values.'
    )
  return lax.dot_general(
      lhs, rhs, dimension_numbers, precision=lax.Precision.DEFAULT
  )


@dot_general_with_precision.defjvp
def dot_general_with_precision_jvp(
    dimension_numbers, precision, preferred_element_type, primals, tangents
):
  """JVP rule: d(lhs . rhs) = lhs_dot . rhs + lhs . rhs_dot.

  Tangent products use HIGHEST precision so the gradient is not degraded
  by the reduced-precision forward computation.
  """
  lhs, rhs = primals
  lhs_dot, rhs_dot = tangents

  out = lax.dot_general(
      lhs, rhs, dimension_numbers, precision=lax.Precision.DEFAULT
  )
  # Product rule: BOTH tangent contributions must be summed. The collapsed
  # patch dropped the '+' between the two calls, which discarded the
  # rhs-tangent term (and is a syntax error as written).
  grad_out = lax.dot_general(
      lhs_dot, rhs, dimension_numbers, precision=lax.Precision.HIGHEST
  ) + lax.dot_general(
      lhs, rhs_dot, dimension_numbers, precision=lax.Precision.HIGHEST
  )
  return out, grad_out
@@ -175,7 +208,7 @@ def __call__(self, *args, **kwargs): k_qdq = in_qdq( comp_dtype, k, self.kernel_scale.value, self.kernel_amax_history.value ) - y_qdq = lax.dot_general(x_qdq, k_qdq, dimension_numbers, precision) # type: ignore + y_qdq = dot_general_with_precision(x_qdq, k_qdq, dimension_numbers) # type: ignore y = out_qdq( comp_dtype, y_qdq,