Allow for fast accumulation selection for FP8 GEMM #3416

Merged: 5 commits, Nov 8, 2023
39 changes: 36 additions & 3 deletions flax/linen/fp8_ops.py
@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from functools import partial

from jax import custom_vjp, lax, random
from jax import custom_jvp, custom_vjp, lax, random
from jax import numpy as jnp

from flax.linen import initializers, module
@@ -120,6 +121,39 @@ def out_qdq_bwd(compute_dtype, res, g):
out_qdq.defvjp(out_qdq_fwd, out_qdq_bwd)


@partial(custom_jvp, nondiff_argnums=(2, 3, 4))
def dot_general_with_precision(
lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None
):
if precision != None or preferred_element_type != None:
warnings.warn(
'The function dot_general_with_precision will set the '
'precision/preferred_element_type and disregard any provided '
'values.'
)
return lax.dot_general(
lhs, rhs, dimension_numbers, precision=lax.Precision.DEFAULT
)


@dot_general_with_precision.defjvp
def dot_general_with_precision_jvp(
dimension_numbers, precision, preferred_element_type, primals, tangents
):
lhs, rhs = primals
lhs_dot, rhs_dot = tangents

out = lax.dot_general(
lhs, rhs, dimension_numbers, precision=lax.Precision.DEFAULT
)
grad_out = lax.dot_general(
lhs_dot, rhs, dimension_numbers, precision=lax.Precision.HIGHEST
) + lax.dot_general(
lhs, rhs_dot, dimension_numbers, precision=lax.Precision.HIGHEST
)
return out, grad_out


class Fp8DotGeneralOp(module.Module):
amax_history_length: int = 1024

@@ -162,7 +196,6 @@ def __call__(self, *args, **kwargs):
x = args[0]
k = args[1]
dimension_numbers = args[2]
precision = kwargs['precision']

# Use the `k.dtype` since it aligns with the `dtype` of its layers,
# namely, the computation data type.
@@ -175,7 +208,7 @@
k_qdq = in_qdq(
comp_dtype, k, self.kernel_scale.value, self.kernel_amax_history.value
)
y_qdq = lax.dot_general(x_qdq, k_qdq, dimension_numbers, precision) # type: ignore
y_qdq = dot_general_with_precision(x_qdq, k_qdq, dimension_numbers) # type: ignore
y = out_qdq(
comp_dtype,
y_qdq,
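As a usage sketch (not part of this diff): the new `custom_jvp` routes the primal `dot_general` through `lax.Precision.DEFAULT`, which on FP8-capable hardware permits fast (reduced-precision) accumulation, while the tangent products stay at `lax.Precision.HIGHEST`. The snippet below assumes `dot_general_with_precision` and `Fp8DotGeneralOp` are importable from `flax.linen.fp8_ops` and that `flax.linen.Dense` exposes a `dot_general_cls` argument; it is an illustration of the intended behavior, not code from the PR.

```python
# Usage sketch (not part of this diff). Assumes `dot_general_with_precision`
# and `Fp8DotGeneralOp` are importable from `flax.linen.fp8_ops`, and that
# `flax.linen.Dense` accepts a `dot_general_cls` argument.
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.linen import fp8_ops

x = jnp.ones((4, 8), jnp.float32)
k = jnp.ones((8, 16), jnp.float32)
dims = (((1,), (0,)), ((), ()))  # ordinary matmul contraction

# 1) The primal product runs at lax.Precision.DEFAULT (fast accumulation),
#    while the tangent products run at lax.Precision.HIGHEST.
y, y_dot = jax.jvp(
    lambda a, b: fp8_ops.dot_general_with_precision(a, b, dims),
    (x, k),
    (jnp.ones_like(x), jnp.ones_like(k)),
)

# 2) Wiring the FP8 op into a layer (hypothetical wiring; relies on Dense
#    exposing `dot_general_cls`).
model = nn.Dense(features=16, dot_general_cls=fp8_ops.Fp8DotGeneralOp)
variables = model.init(jax.random.PRNGKey(0), x)
out = model.apply(variables, x)
```

Because the precision is now fixed inside `dot_general_with_precision`, the `kwargs['precision']` read removed from `Fp8DotGeneralOp.__call__` is no longer needed; a caller-supplied `precision` or `preferred_element_type` only triggers the warning shown above and is otherwise ignored.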