diff --git a/torchao/float8/config.py b/torchao/float8/config.py
index eb28dcbd8e..3b6b05c0e8 100644
--- a/torchao/float8/config.py
+++ b/torchao/float8/config.py
@@ -132,6 +132,12 @@ class Float8LinearConfig:
     # configuration, this field may move to per-tensor configs.
     delayed_scaling_config: DelayedScalingConfig = DelayedScalingConfig()
 
+    # If True, fp8_weight will always be re-computed in backward.
+    # If False, fp8_weight from forward may be saved for backward.
+    # It's recommended to enable this flag when using FSDP.
+    # Otherwise, the entire fp8_weight, instead of the sharded weight may be saved.
+    force_recompute_fp8_weight_in_bwd: bool = False
+
     # If True, use 'fnuz' float8 types for calculations.
     # Currently, ROCm only supports fnuz variants.
diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py
index cb0ff7afb0..acf3872b27 100644
--- a/torchao/float8/float8_linear.py
+++ b/torchao/float8/float8_linear.py
@@ -14,6 +14,8 @@
 
 import torch
 
+import torch.utils.checkpoint as checkpoint
+
 from torchao.float8.config import Float8LinearConfig, ScalingType
 
 from torchao.float8.float8_scaling_utils import (
@@ -29,11 +31,17 @@
 from torchao.float8.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
+    hp_tensor_and_scale_to_float8,
     LinearMMConfig,
     ScaledMMConfig,
 )
 
-from torchao.float8.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_amax
+from torchao.float8.float8_utils import (
+    e4m3_dtype,
+    e5m2_dtype,
+    tensor_to_amax,
+    tensor_to_scale,
+)
 
 from torchao.float8.fsdp_utils import (
     WeightWithDelayedFloat8CastTensor,
@@ -53,8 +61,9 @@ class manual_float8_matmul(torch.autograd.Function):
     def forward(
         ctx,
         input_fp8,
-        weight_fp8_t,
+        weight_fp8,
     ):
+        weight_fp8_t = weight_fp8.t()
         ctx.save_for_backward(input_fp8, weight_fp8_t)
         # the reshapes are needed in order to make the shapes compatible with
         # torch.mm
@@ -95,7 +104,7 @@ def backward(ctx, grad_output_fp8):
             input_fp8_reshaped,
         )
 
-        return grad_input, grad_weight.t()
+        return grad_input, grad_weight
 
 
 class Float8Linear(torch.nn.Linear):
@@ -180,6 +189,8 @@ def __init__(self, *args, **kwargs):
         # would be initialized in every iteration.
         self.enable_pre_and_post_forward = self.config.enable_pre_and_post_forward
 
+        self.force_recompute_fp8_weight_in_bwd = self.config.force_recompute_fp8_weight_in_bwd
+
     def create_buffers(self):
         # Default values for history buffers, see above TODO
         history_len = self.config.delayed_scaling_config.history_len
@@ -299,52 +310,38 @@ def cast_input_to_float8(
 
         return input_fp8
 
-    def cast_weight_to_float8(
-        self, weight: torch.Tensor, is_amax_initialized: bool
-    ) -> torch.Tensor:
+
+    def get_weight_scale(self, weight: torch.Tensor) -> torch.Tensor:
         if self.scaling_type_weight is ScalingType.DELAYED:
-            if isinstance(self.weight, Float8Tensor):  # cast by FSDP
-                weight_fp8 = self.weight
-            else:
-                scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
-                _maybe_initialize_amaxes_scales_for_float8_cast(
-                    weight,
-                    self.fp8_amax_weight,
-                    self.fp8_amax_history_weight,
-                    self.fp8_scale_weight,
-                    scale_fn_name,
-                    e4m3_dtype,
-                    is_amax_initialized,
-                    reduce_amax=False,
-                )
-
-                weight_fp8 = hp_tensor_to_float8_delayed(
-                    weight,
-                    self.fp8_scale_weight,
-                    e4m3_dtype,
-                    self.fp8_amax_weight,
-                    linear_mm_config=self.linear_mm_config,
-                    gemm_input_role=GemmInputRole.WEIGHT,
-                )
+            scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
+            _maybe_initialize_amaxes_scales_for_float8_cast(
+                weight,
+                self.fp8_amax_weight,
+                self.fp8_amax_history_weight,
+                self.fp8_scale_weight,
+                scale_fn_name,
+                e4m3_dtype,
+                self.is_amax_initialized,
+                reduce_amax=True,
+            )
+            self.fp8_amax_weight.fill_(tensor_to_amax(weight))
+            return self.fp8_scale_weight
         elif self.scaling_type_weight is ScalingType.DYNAMIC:
-            if isinstance(self.weight, Float8Tensor):  # cast by FSDP
-                weight_fp8 = self.weight
-            else:
-                weight_fp8 = hp_tensor_to_float8_dynamic(
-                    self.weight,
-                    e4m3_dtype,
-                    self.linear_mm_config,
-                    gemm_input_role=GemmInputRole.WEIGHT,
-                )
+            return tensor_to_scale(weight, e4m3_dtype)
         else:
             assert self.scaling_type_weight is ScalingType.STATIC
-            weight_fp8 = hp_tensor_to_float8_static(
-                self.weight,
-                self.fp8_static_scale_weight,
-                e4m3_dtype,
-                self.linear_mm_config,
-                gemm_input_role=GemmInputRole.WEIGHT,
-            )
+            return self.fp8_static_scale_weight
+
+    def cast_weight_to_float8(
+        self, weight: torch.Tensor, weight_scale: torch.Tensor, is_amax_initialized: bool,
+    ) -> torch.Tensor:
+        weight_fp8 = hp_tensor_and_scale_to_float8(
+            weight,
+            weight_scale,
+            e4m3_dtype,
+            self.linear_mm_config,
+            gemm_input_role=GemmInputRole.WEIGHT,
+        )
         return weight_fp8
 
     def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
@@ -390,15 +387,32 @@ def float8_post_forward(self):
        # amaxes and scales
        self.is_amax_initialized = True
        self.amax_and_scale_synced = False
+
+    def cast_weight_and_matmul(self, input_fp8, weight_scale):
+        weight_fp8 = self.cast_weight_to_float8(self.weight, weight_scale, self.is_amax_initialized)
+        output = manual_float8_matmul.apply(input_fp8, weight_fp8)
+        return output
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.has_any_delayed_scaling:
             self.float8_pre_forward(input)
 
         input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized)
-        weight_fp8 = self.cast_weight_to_float8(self.weight, self.is_amax_initialized)
 
-        output = manual_float8_matmul.apply(input_fp8, weight_fp8.t())
+        if isinstance(self.weight, Float8Tensor):  # cast by FSDP
+            weight_fp8 = self.weight
+            if self.force_recompute_fp8_weight_in_bwd:
+                output = checkpoint.checkpoint(manual_float8_matmul.apply, input_fp8, weight_fp8)
+            else:
+                output = manual_float8_matmul.apply(input_fp8, weight_fp8)
+        else:
+            weight_scale = self.get_weight_scale(self.weight)
+            # We save weight_scale and only recompute weight_fp8
+            if self.force_recompute_fp8_weight_in_bwd:
+                output = checkpoint.checkpoint(self.cast_weight_and_matmul, input_fp8, weight_scale)
+            else:
+                output = self.cast_weight_and_matmul(input_fp8, weight_scale)
+
        # Cast grad_output to float8_e5m2 during backward
        output = self.cast_output_to_float8_in_bw(output)
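For reference, a minimal sketch of how the new flag could be enabled from user code. It assumes the usual torchao.float8 conversion entry point (convert_to_float8_training); exact import paths and the toy model are illustrative, not part of this diff.

# Sketch: enabling force_recompute_fp8_weight_in_bwd via Float8LinearConfig.
# Assumption: convert_to_float8_training is the torchao.float8 module-swap
# entry point; import paths may differ slightly across torchao versions.
import torch
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# Recompute the fp8 weight in backward so that, under FSDP, the sharded weight
# (rather than the full unsharded fp8 weight) is what gets saved for backward.
config = Float8LinearConfig(force_recompute_fp8_weight_in_bwd=True)

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024, bias=False))
convert_to_float8_training(model, config=config)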