Skip to content

Commit

Permalink
Merge pull request #79 from kaixih:fix_fp8_qdq
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 651130032
  • Loading branch information
pax authors committed Jul 10, 2024
2 parents c41477c + b73b416 commit dafadc1
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions praxis/layers/injection/fp8_nvidia_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,21 +51,24 @@ def setup(self) -> None:
OVERWRITE_WITH_GRADIENT = (
base_layer.WeightHParamsCollection.OVERWRITE_WITH_GRADIENT
)
+ DISALLOW_BFLOAT16_CONVERSION = (
+ base_layer.WeightHParamsCollection.DISALLOW_BFLOAT16_CONVERSION
+ )
scale_args = {
'shape': [1],
'init': base_layer.WeightInit.Constant(1.0),
'dtype': jnp.float32,
'mesh_shape': self.mesh_shape,
'tensor_split_dims_mapping': None,
- 'collections': [OVERWRITE_WITH_GRADIENT],
+ 'collections': [OVERWRITE_WITH_GRADIENT, DISALLOW_BFLOAT16_CONVERSION],
}
amax_history_args = {
'shape': [self.amax_history_length],
'init': base_layer.WeightInit.Constant(0.0),
'dtype': jnp.float32,
'mesh_shape': self.mesh_shape,
'tensor_split_dims_mapping': None,
- 'collections': [OVERWRITE_WITH_GRADIENT],
+ 'collections': [OVERWRITE_WITH_GRADIENT, DISALLOW_BFLOAT16_CONVERSION],
}
self.create_variable(
'input_amax_history', base_layer.WeightHParams(**amax_history_args)
Expand Down Expand Up @@ -98,16 +101,25 @@ def __call__(self, equation: str, *args: pytypes.JTensor) -> pytypes.JTensor:
theta = self.theta

x_qdq = fp8_ops.in_qdq(
- comp_dtype, x, theta.input_scale, theta.input_amax_history
+ comp_dtype,
+ jnp.float8_e4m3fn,
+ x,
+ theta.input_scale,
+ theta.input_amax_history,
)
k_qdq = fp8_ops.in_qdq(
- comp_dtype, k, theta.kernel_scale, theta.kernel_amax_history
+ comp_dtype,
+ jnp.float8_e4m3fn,
+ k,
+ theta.kernel_scale,
+ theta.kernel_amax_history,
)
y_qdq = jnp.einsum(
equation, x_qdq, k_qdq, _dot_general=fp8_ops.dot_general_with_precision
)
y = fp8_ops.out_qdq(
comp_dtype,
+ jnp.float8_e5m2,
y_qdq,
theta.output_grad_scale,
theta.output_grad_amax_history,
Expand Down

0 comments on commit dafadc1

Please sign in to comment.