[Feat] support fp8 quantization in update weights #24488
The changes to fp8.py add a `_wrap_parameter_or_copy` helper and use it in `process_weights_after_loading`:

```diff
@@ -65,6 +65,25 @@ def _is_col_major(x: torch.Tensor) -> bool:
     return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m
 
 
+def _wrap_parameter_or_copy(layer: torch.nn.Module, name: str,
+                            weight: torch.Tensor):
+    layer_weight = getattr(layer, name)
+    if isinstance(layer_weight, Parameter):
+        # If it is already a Parameter, assume it has the right shape and
+        # copy the data in place so the pointer stays unchanged for CUDA Graphs.
+        layer_weight.copy_(weight)
+    else:
+        # torch.compile() cannot use Parameter subclasses, so rewrap the
+        # tensor as a plain Parameter to stay compatible with torch.compile.
+        param = Parameter(weight, requires_grad=False)
+        if hasattr(layer_weight, "weight_loader"):
+            # Keep the weight_loader attribute so the weight can still be
+            # loaded correctly during a later weight update.
+            param.weight_loader = layer_weight.weight_loader
+        setattr(layer, name, param)
+
+
 class Fp8Config(QuantizationConfig):
     """Config class for FP8."""
 
@@ -387,10 +406,9 @@ def process_weights_after_loading(self, layer: Module) -> None:
 
             weight = self._maybe_pad_weight(weight)
 
-            # Torch.compile cannot use Parameter subclasses.
-            layer.weight = Parameter(weight, requires_grad=False)
-            layer.weight_scale_inv = Parameter(weight_scale_inv,
-                                               requires_grad=False)
+            _wrap_parameter_or_copy(layer, "weight", weight)
+            _wrap_parameter_or_copy(layer, "weight_scale_inv",
+                                    weight_scale_inv)
 
         # If checkpoint not serialized fp8, quantize the weights.
         elif not self.quant_config.is_checkpoint_fp8_serialized:
@@ -740,13 +758,13 @@ def process_weights_after_loading(self, layer: Module) -> None:
             w2_weight = layer.w2_weight
             w2_weight_scale_inv = layer.w2_weight_scale_inv
 
-            # torch.compile() cannot use Parameter subclasses.
-            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
-            layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv,
-                                                   requires_grad=False)
-            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
-            layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
-                                                  requires_grad=False)
+            _wrap_parameter_or_copy(layer, "w13_weight", w13_weight)
+            _wrap_parameter_or_copy(layer, "w13_weight_scale_inv",
+                                    w13_weight_scale_inv)
+            _wrap_parameter_or_copy(layer, "w2_weight", w2_weight)
+            _wrap_parameter_or_copy(layer, "w2_weight_scale_inv",
+                                    w2_weight_scale_inv)
 
             if self.rocm_aiter_moe_enabled:
                 # reshaping weights is required for aiter moe kernel.
                 shuffled_w13, shuffled_w2 = shuffle_weights(
```
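For readers less familiar with the motivation for the copy branch: rebinding a module attribute to a freshly constructed Parameter allocates new storage, while an in-place copy keeps the original data pointer, which is what a captured CUDA Graph continues to reference. Below is a minimal, standalone sketch of that pointer behaviour, not part of this diff; it runs on CPU purely to illustrate the difference.

```python
import torch
from torch.nn import Parameter


class ToyLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = Parameter(torch.zeros(4, 4), requires_grad=False)


layer = ToyLayer()
old_ptr = layer.weight.data_ptr()

# Rebinding to a new Parameter allocates new storage, so anything that
# captured the old pointer (e.g. a CUDA Graph) would go stale.
layer.weight = Parameter(torch.ones(4, 4), requires_grad=False)
print(layer.weight.data_ptr() == old_ptr)  # False

# An in-place copy keeps the original storage, and therefore the pointer.
layer = ToyLayer()
old_ptr = layer.weight.data_ptr()
layer.weight.copy_(torch.ones(4, 4))
print(layer.weight.data_ptr() == old_ptr)  # True
```

The copy branch assumes the updated weight already has the same shape as the existing parameter, which is why the helper only falls back to rewrapping when the attribute is not a plain Parameter.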
The second file adds a guard in `process_weights_after_loading` so the kv-cache scale attributes exist before they are processed:

```diff
@@ -48,6 +48,13 @@ def apply(self, layer: torch.nn.Module) -> torch.Tensor:
             f"{self.__class__.__name__}.apply should not be called.")
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # A weight update may miss these attributes; create them if not present.
+        if not hasattr(layer, "q_scale"):
+            assert not hasattr(layer, "k_scale")
+            assert not hasattr(layer, "v_scale")
+            assert not hasattr(layer, "prob_scale")
+            self.create_weights(layer)
+
         # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
         # regardless whether the kv-scale is available in the checkpoint.
         # No need to process kv scales after loading if we are going to
```
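The guard follows a "create on first use" pattern: if a layer arriving through the weight-update path never had its scale attributes registered, register them before the usual post-load processing runs, and do nothing on later calls. A standalone sketch of the same pattern; the `create_default_scales` helper and its initial value of 1.0 are illustrative assumptions, not the actual `create_weights` implementation.

```python
import torch
from torch.nn import Parameter


def create_default_scales(layer: torch.nn.Module) -> None:
    # Toy stand-in for create_weights(): register placeholder scales.
    for name in ("q_scale", "k_scale", "v_scale", "prob_scale"):
        layer.register_parameter(
            name, Parameter(torch.tensor(1.0), requires_grad=False))


def process_after_update(layer: torch.nn.Module) -> None:
    # Layers reached through the weight-update path may never have gone
    # through weight creation, so the scale attributes can be missing.
    if not hasattr(layer, "q_scale"):
        # Either all scales exist or none do; a partial set would be a bug.
        assert not hasattr(layer, "k_scale")
        assert not hasattr(layer, "v_scale")
        assert not hasattr(layer, "prob_scale")
        create_default_scales(layer)
    # ... normal post-loading processing of the scales would follow here.


layer = torch.nn.Module()
process_after_update(layer)  # creates the scales on the first call
process_after_update(layer)  # no-op on subsequent calls
print(float(layer.q_scale))  # 1.0
```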
As of torch 2.8 (and torch 2.7 using this PR), torch.compile supports Parameter subclasses. Given this, all that should be required is updating the new (possibly padded) weight; a new Parameter does not need to be created.
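One way to read this suggestion, sketched below rather than quoted from the reviewer: keep the existing parameter object, including its subclass type and attached attributes such as `weight_loader`, and only copy or rebind its data. The helper name and the shape check are hypothetical, assuming torch.compile accepts Parameter subclasses as described above.

```python
import torch


def update_weight_in_place(layer: torch.nn.Module, name: str,
                           new_weight: torch.Tensor) -> None:
    """Update an existing parameter without constructing a new Parameter.

    Keeping the original object preserves its (sub)class and attributes
    such as weight_loader, so no rewrapping into a plain Parameter is needed.
    """
    param = getattr(layer, name)
    if param.shape == new_weight.shape:
        # Same shape (no padding applied): copy in place so the data
        # pointer also stays stable for CUDA Graphs.
        param.data.copy_(new_weight)
    else:
        # Shape changed (e.g. the weight was padded): rebind the storage
        # but keep the same parameter object and its attributes.
        param.data = new_weight
```

Which branch applies depends on whether `_maybe_pad_weight` actually changed the shape of the weight.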
Agreed. In fact, the second branch in this code never gets triggered. All that is needed is to clean up fp8.py by removing the `param = Parameter(...)` statements that drop the weight loaders.
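To see why those `param = Parameter(...)` statements are the part worth cleaning up: constructing a new Parameter starts from a bare object, so any attribute attached to the old one, such as `weight_loader`, is silently lost unless it is copied over explicitly. A minimal illustration; the no-op loader is just a placeholder.

```python
import torch
from torch.nn import Parameter

original = Parameter(torch.zeros(2, 2), requires_grad=False)
original.weight_loader = lambda param, loaded: param.copy_(loaded)  # placeholder

# Rewrapping into a fresh Parameter drops the attribute...
rewrapped = Parameter(original.data, requires_grad=False)
print(hasattr(rewrapped, "weight_loader"))  # False

# ...so a later weight update that relies on weight_loader would fail
# unless the attribute is carried over, as _wrap_parameter_or_copy does.
rewrapped.weight_loader = original.weight_loader
print(hasattr(rewrapped, "weight_loader"))  # True
```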