[Feat] support fp8 quantization in update weights #24488
Conversation
Code Review
This pull request refactors weight processing for FP8 quantization to support weight updates, primarily by introducing a _wrap_parameter_or_copy helper function. This is a good change for compatibility with CUDA graphs. The change in kv_cache.py also improves robustness by ensuring quantization scales are always present. However, I've found a critical issue in Fp8MoEMethod.process_weights_after_loading where a parameter is not correctly unwrapped, leading to a no-op update and incorrect behavior in certain code paths. I've also suggested an improvement in kv_cache.py to make the code more robust by removing some overly strict assertions.
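For context, here is a minimal sketch of what a helper like `_wrap_parameter_or_copy` might look like, based only on the description above; the actual implementation in the PR may differ.

```python
import torch


def _wrap_parameter_or_copy(layer: torch.nn.Module, name: str,
                            tensor: torch.Tensor) -> None:
    """Illustrative only: reuse existing parameter storage when possible."""
    existing = getattr(layer, name, None)
    if isinstance(existing, torch.nn.Parameter) and existing.shape == tensor.shape:
        # In-place copy keeps the same storage, so previously captured
        # CUDA graphs remain valid after a weight update.
        existing.data.copy_(tensor)
    else:
        # First load: register the processed tensor as a non-trainable parameter.
        setattr(layer, name, torch.nn.Parameter(tensor, requires_grad=False))
```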
In the else branch of the conditional starting at line 733, the variables w2_weight and w2_weight_scale_inv are assigned torch.nn.Parameter objects on lines 755-756, instead of their underlying tensor data. Consequently, these calls to _wrap_parameter_or_copy become no-ops due to self-copying, which is likely not the intended behavior and can lead to incorrect weight updates.
This is inconsistent with how w13_weight is handled in the same block, which correctly uses .data. To fix this, you should modify lines 755-756 to extract the tensor data, like so:
# In vllm/model_executor/layers/quantization/fp8.py, lines 755-756
w2_weight = layer.w2_weight.data
w2_weight_scale_inv = layer.w2_weight_scale_inv.data

Since the fix is outside the diff, I'm placing this comment here to highlight this critical issue.
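To make the no-op concrete, using the hypothetical helper sketched earlier (not the PR's actual implementation):

```python
# Before the fix: the variable is the registered nn.Parameter itself, so the
# later call copies the parameter's data onto itself, changing nothing.
w2_weight = layer.w2_weight
_wrap_parameter_or_copy(layer, "w2_weight", w2_weight)

# After the fix (mirroring how w13_weight is handled): operate on the raw
# tensor data, so any processing applied before the call actually takes effect.
w2_weight = layer.w2_weight.data
```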
These assertions could make the code brittle. If another part of the codebase modifies these attributes partially (e.g., removes q_scale but not k_scale), these assertions will fail. The main goal here is to ensure all weights are present if any are missing. Simply checking for q_scale and then creating all weights is sufficient and more robust against unforeseen state changes.
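A minimal sketch of the suggested approach, assuming the scale attribute names used in kv_cache.py and a placeholder default value (both are assumptions, not taken from the PR):

```python
import torch

_SCALE_NAMES = ("q_scale", "k_scale", "v_scale", "prob_scale")  # assumed names


def _ensure_scales(layer: torch.nn.Module) -> None:
    # If q_scale is missing, (re)create the full set of scales instead of
    # asserting that all of them are jointly present or jointly absent.
    if not hasattr(layer, "q_scale"):
        for name in _SCALE_NAMES:
            setattr(layer, name,
                    torch.nn.Parameter(torch.tensor(1.0),  # placeholder default
                                       requires_grad=False))
```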
…dd missing scale attributes Signed-off-by: huangweixiao <huangweixiao@msh.team>
Update?
This PR makes `process_weights_after_loading` reusable for FP8 quantization when updating weights.
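A rough sketch of the usage pattern this enables; the orchestration code below is illustrative, with only `load_weights` and `process_weights_after_loading` assumed from vLLM's existing interfaces.

```python
def update_weights(model, new_state_dict):
    # Overwrite the raw weights in place.
    model.load_weights(new_state_dict.items())
    # Re-run FP8 post-processing; with this PR it copies into the existing
    # parameters instead of re-registering them, keeping CUDA graphs valid.
    for module in model.modules():
        quant_method = getattr(module, "quant_method", None)
        if quant_method is not None:
            quant_method.process_weights_after_loading(module)
```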