Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NVIDIA] Fix FP8 QDQ calls in praxis #79

Merged
merged 1 commit into from
Jul 10, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions praxis/layers/injection/fp8_nvidia_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,21 +51,24 @@ def setup(self) -> None:
OVERWRITE_WITH_GRADIENT = (
base_layer.WeightHParamsCollection.OVERWRITE_WITH_GRADIENT
)
DISALLOW_BFLOAT16_CONVERSION = (
base_layer.WeightHParamsCollection.DISALLOW_BFLOAT16_CONVERSION
)
scale_args = {
'shape': [1],
'init': base_layer.WeightInit.Constant(1.0),
'dtype': jnp.float32,
'mesh_shape': self.mesh_shape,
'tensor_split_dims_mapping': None,
'collections': [OVERWRITE_WITH_GRADIENT],
'collections': [OVERWRITE_WITH_GRADIENT, DISALLOW_BFLOAT16_CONVERSION],
}
amax_history_args = {
'shape': [self.amax_history_length],
'init': base_layer.WeightInit.Constant(0.0),
'dtype': jnp.float32,
'mesh_shape': self.mesh_shape,
'tensor_split_dims_mapping': None,
'collections': [OVERWRITE_WITH_GRADIENT],
'collections': [OVERWRITE_WITH_GRADIENT, DISALLOW_BFLOAT16_CONVERSION],
}
self.create_variable(
'input_amax_history', base_layer.WeightHParams(**amax_history_args)
Expand Down Expand Up @@ -98,19 +101,19 @@ def __call__(self, equation: str, *args: pytypes.JTensor) -> pytypes.JTensor:
theta = self.theta

x_qdq = fp8_ops.in_qdq(
comp_dtype, x, theta.input_scale, theta.input_amax_history
comp_dtype, jnp.float8_e4m3fn, x, theta.input_scale,
theta.input_amax_history
)
k_qdq = fp8_ops.in_qdq(
comp_dtype, k, theta.kernel_scale, theta.kernel_amax_history
comp_dtype, jnp.float8_e4m3fn, k, theta.kernel_scale,
theta.kernel_amax_history
)
y_qdq = jnp.einsum(
equation, x_qdq, k_qdq, _dot_general=fp8_ops.dot_general_with_precision
)
y = fp8_ops.out_qdq(
comp_dtype,
y_qdq,
theta.output_grad_scale,
theta.output_grad_amax_history,
comp_dtype, jnp.float8_e5m2, y_qdq, theta.output_grad_scale,
theta.output_grad_amax_history
)

return y