From 641bf6a339ad3628ae28694d3f49f245bfdf46d4 Mon Sep 17 00:00:00 2001 From: shuw Date: Fri, 27 Oct 2023 13:31:58 -0700 Subject: [PATCH 1/5] Support FP8 GEMM fast accumulation. --- flax/linen/fp8_ops.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/flax/linen/fp8_ops.py b/flax/linen/fp8_ops.py index a1fc39a673..e5ceb8c08b 100644 --- a/flax/linen/fp8_ops.py +++ b/flax/linen/fp8_ops.py @@ -14,7 +14,7 @@ from functools import partial -from jax import custom_vjp, lax, random +from jax import custom_vjp, custom_jvp, lax, random from jax import numpy as jnp from flax.linen import initializers, module @@ -120,6 +120,24 @@ def out_qdq_bwd(compute_dtype, res, g): out_qdq.defvjp(out_qdq_fwd, out_qdq_bwd) +@partial(custom_jvp, nondiff_argnums=(2,)) +def dot_general_with_precision(lhs, rhs, dimension_numbers): + return lax.dot_general(lhs, rhs, dimension_numbers, + precision=lax.Precision.DEFAULT) + +@dot_general_with_precision.defjvp +def dot_general_with_precision_jvp(dimension_numbers, primals, tangents): + lhs, rhs = primals + lhs_dot, rhs_dot = tangents + + out = lax.dot_general(lhs, rhs, dimension_numbers, + precision=lax.Precision.DEFAULT) + grad_out = (lax.dot_general(lhs_dot, rhs, dimension_numbers, + precision=lax.Precision.HIGHEST) + + lax.dot_general(lhs, rhs_dot, dimension_numbers, + precision=lax.Precision.HIGHEST)) + return out, grad_out + class Fp8DotGeneralOp(module.Module): amax_history_length: int = 1024 @@ -162,7 +180,6 @@ def __call__(self, *args, **kwargs): x = args[0] k = args[1] dimension_numbers = args[2] - precision = kwargs['precision'] # Use the `k.dtype` since it aligns with the `dtype` of its layers, # namely, the computation data type. 
@@ -175,7 +192,7 @@ def __call__(self, *args, **kwargs): k_qdq = in_qdq( comp_dtype, k, self.kernel_scale.value, self.kernel_amax_history.value ) - y_qdq = lax.dot_general(x_qdq, k_qdq, dimension_numbers, precision) # type: ignore + y_qdq = dot_general_with_precision(x_qdq, k_qdq, dimension_numbers) # type: ignore y = out_qdq( comp_dtype, y_qdq, From 8d84c500c1babf1997a9546068a153392943a69b Mon Sep 17 00:00:00 2001 From: shuw Date: Thu, 2 Nov 2023 19:25:54 -0700 Subject: [PATCH 2/5] Improve based on review #1 --- flax/linen/fp8_ops.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/flax/linen/fp8_ops.py b/flax/linen/fp8_ops.py index e5ceb8c08b..078b330ae1 100644 --- a/flax/linen/fp8_ops.py +++ b/flax/linen/fp8_ops.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings + from functools import partial -from jax import custom_vjp, custom_jvp, lax, random +from jax import custom_jvp, custom_vjp, lax, random from jax import numpy as jnp from flax.linen import initializers, module @@ -120,13 +122,19 @@ def out_qdq_bwd(compute_dtype, res, g): out_qdq.defvjp(out_qdq_fwd, out_qdq_bwd) -@partial(custom_jvp, nondiff_argnums=(2,)) -def dot_general_with_precision(lhs, rhs, dimension_numbers): +@partial(custom_jvp, nondiff_argnums=(2, 3, 4)) +def dot_general_with_precision(lhs, rhs, dimension_numbers, precision=None, + preferred_element_type=None): + if precision != None or preferred_element_type != None: + warnings.warn("The function dot_general_with_precision will set the " + "precision/preferred_element_type and disregard any provided " + "values.") return lax.dot_general(lhs, rhs, dimension_numbers, precision=lax.Precision.DEFAULT) @dot_general_with_precision.defjvp -def dot_general_with_precision_jvp(dimension_numbers, primals, tangents): +def dot_general_with_precision_jvp(dimension_numbers, precision, + preferred_element_type, primals, tangents): lhs, 
rhs = primals lhs_dot, rhs_dot = tangents From 786a11162c2910ba73e11b4c0b98bdaba1f14ff4 Mon Sep 17 00:00:00 2001 From: shuw Date: Fri, 3 Nov 2023 10:20:48 -0700 Subject: [PATCH 3/5] Fix format --- flax/linen/fp8_ops.py | 53 ++++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 23 deletions(-) diff --git a/flax/linen/fp8_ops.py b/flax/linen/fp8_ops.py index 078b330ae1..012935d0bd 100644 --- a/flax/linen/fp8_ops.py +++ b/flax/linen/fp8_ops.py @@ -13,7 +13,6 @@ # limitations under the License. import warnings - from functools import partial from jax import custom_jvp, custom_vjp, lax, random @@ -123,28 +122,36 @@ def out_qdq_bwd(compute_dtype, res, g): @partial(custom_jvp, nondiff_argnums=(2, 3, 4)) -def dot_general_with_precision(lhs, rhs, dimension_numbers, precision=None, - preferred_element_type=None): +def dot_general_with_precision( + lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None +): if precision != None or preferred_element_type != None: - warnings.warn("The function dot_general_with_precision will set the " - "precision/preferred_element_type and disregard any provided " - "values.") - return lax.dot_general(lhs, rhs, dimension_numbers, - precision=lax.Precision.DEFAULT) - -@dot_general_with_precision.defjvp -def dot_general_with_precision_jvp(dimension_numbers, precision, - preferred_element_type, primals, tangents): - lhs, rhs = primals - lhs_dot, rhs_dot = tangents - - out = lax.dot_general(lhs, rhs, dimension_numbers, - precision=lax.Precision.DEFAULT) - grad_out = (lax.dot_general(lhs_dot, rhs, dimension_numbers, - precision=lax.Precision.HIGHEST) + - lax.dot_general(lhs, rhs_dot, dimension_numbers, - precision=lax.Precision.HIGHEST)) - return out, grad_out + warnings.warn( + 'The function dot_general_with_precision will set the ' + 'precision/preferred_element_type and disregard any provided ' + 'values.' 
+ ) + return lax.dot_general( + lhs, rhs, dimension_numbers, precision=lax.Precision.DEFAULT + ) + + + def dot_general_with_precision_jvp( + dimension_numbers, precision, preferred_element_type, primals, tangents + ): + lhs, rhs = primals + lhs_dot, rhs_dot = tangents + + out = lax.dot_general( + lhs, rhs, dimension_numbers, precision=lax.Precision.DEFAULT + ) + grad_out = lax.dot_general( + lhs_dot, rhs, dimension_numbers, precision=lax.Precision.HIGHEST + ) + lax.dot_general( + lhs, rhs_dot, dimension_numbers, precision=lax.Precision.HIGHEST + ) + return out, grad_out + class Fp8DotGeneralOp(module.Module): amax_history_length: int = 1024 @@ -200,7 +207,7 @@ def __call__(self, *args, **kwargs): k_qdq = in_qdq( comp_dtype, k, self.kernel_scale.value, self.kernel_amax_history.value ) - y_qdq = dot_general_with_precision(x_qdq, k_qdq, dimension_numbers) # type: ignore + y_qdq = dot_general_with_precision(x_qdq, k_qdq, dimension_numbers) # type: ignore y = out_qdq( comp_dtype, y_qdq, From aa9427a304bc5c00a810605580c810aba9f12177 Mon Sep 17 00:00:00 2001 From: shuw Date: Tue, 7 Nov 2023 10:24:47 -0800 Subject: [PATCH 4/5] Add missing decorator --- flax/linen/fp8_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flax/linen/fp8_ops.py b/flax/linen/fp8_ops.py index 012935d0bd..7e2b387648 100644 --- a/flax/linen/fp8_ops.py +++ b/flax/linen/fp8_ops.py @@ -135,7 +135,7 @@ def dot_general_with_precision( lhs, rhs, dimension_numbers, precision=lax.Precision.DEFAULT ) - +@dot_general_with_precision.defjvp def dot_general_with_precision_jvp( dimension_numbers, precision, preferred_element_type, primals, tangents From 09f2ec9af2b90546eff3c3733d70db7f10c85304 Mon Sep 17 00:00:00 2001 From: shuw Date: Tue, 7 Nov 2023 10:26:38 -0800 Subject: [PATCH 5/5] Fix format --- flax/linen/fp8_ops.py | 1 + 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flax/linen/fp8_ops.py b/flax/linen/fp8_ops.py index 7e2b387648..e4b368e72d 100644 --- a/flax/linen/fp8_ops.py +++
b/flax/linen/fp8_ops.py @@ -135,6 +135,7 @@ def dot_general_with_precision( lhs, rhs, dimension_numbers, precision=lax.Precision.DEFAULT ) + @dot_general_with_precision.defjvp def dot_general_with_precision_jvp( dimension_numbers, precision, preferred_element_type, primals, tangents