[Performance][B200] Fix deepgemm prologue #27897
Changes from all commits
@@ -34,6 +34,7 @@
     is_deep_gemm_e8m0_used,
     is_deep_gemm_supported,
     should_use_deepgemm_for_fp8_linear,
+    transform_sf_into_required_layout,
 )
 from vllm.utils.torch_utils import direct_register_custom_op
@@ -929,6 +930,50 @@ def requant_weight_ue8m0_inplace(
         s_old.copy_(s_requant)
 
 
+def deepgemm_post_process_fp8_weight_block(
+    wq: torch.Tensor, ws: torch.Tensor, quant_block_shape: tuple[int], use_e8m0: bool
+) -> tuple[torch.Tensor, torch.Tensor]:
+    assert wq.dtype == torch.float8_e4m3fn, (
+        "Expected quantized tensor dtype "
+        f"to be torch.float8_e4m3fn, got {wq.dtype} instead."
+    )
+    assert ws.dtype == torch.float32, (
+        f"Expected tensor scales dtype to be torch.float32, got {ws.dtype} instead"
+    )
+
+    if use_e8m0:
+        requant_weight_ue8m0_inplace(wq, ws, block_size=quant_block_shape)
+
+    original_ndim = wq.ndim
+    if wq.ndim == 2:
+        assert ws.ndim == 2
+        wq = wq.unsqueeze(0)
+        ws = ws.unsqueeze(0)
+
+    # From https://github.com/deepseek-ai/DeepGEMM/blob/c9f8b34dcdacc20aa746b786f983492c51072870/csrc/utils/layout.hpp#L46
+    recipe = (1, 128, 128)
+
+    # Ref : https://github.com/deepseek-ai/DeepGEMM/blob/c9f8b34dcdacc20aa746b786f983492c51072870/csrc/apis/gemm.hpp
+    # DeepGemm uses the `transform_sf_into_required_layout` function to
+    # represent scales in the correct format.
+    dg_ws = transform_sf_into_required_layout(
+        sf=ws,
+        mn=wq.size(1),
+        k=wq.size(2),
+        recipe=recipe,
+        num_groups=wq.size(0),
+        # is the scale factors for A in (Refers to the argument A in A @ B).
+        # Weights are B.
+        is_sfa=False,
+    )
+
+    if original_ndim == 2:
+        wq = wq.squeeze(0)
+        dg_ws = dg_ws.squeeze(0)
+
+    return wq, dg_ws
+
+
 def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
     """Pad the weight tensor. This is an optimization on ROCm platform, which
     can benefit from tensors located far enough from one another in memory"""
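For context on the inputs the new helper expects, here is a minimal, self-contained sketch (not part of the PR; the sizes and the ~448 e4m3 max are assumptions for illustration) that builds a 128x128 block-quantized FP8 weight `wq` and its per-block float32 scales `ws`, i.e. the kind of tensors that get passed in as `layer.weight.data` and `layer.weight_scale.data`:

```python
import torch

# Illustrative only: build a 128x128 block-quantized FP8 weight and per-block
# float32 scales, the (wq, ws) pair deepgemm_post_process_fp8_weight_block takes.
N, K, BLOCK = 512, 1024, 128
w = torch.randn(N, K, dtype=torch.float32)

# Partition into [N/128, 128, K/128, 128] blocks and scale each block by its
# max-abs value so it fits the float8_e4m3fn range (max finite value ~448).
blocks = w.view(N // BLOCK, BLOCK, K // BLOCK, BLOCK)
scale = blocks.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12) / 448.0
wq = (blocks / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn).view(N, K)
ws = scale.squeeze(1).squeeze(-1)  # [N/128, K/128], float32; dequant block ~= fp8 block * its scale

print(wq.shape, wq.dtype, ws.shape, ws.dtype)
```

On top of inputs like these, the helper optionally requantizes the scales to UE8M0 and then converts them into the scale layout DeepGEMM requires.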
@@ -1141,11 +1186,15 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
     should_use_deepgemm = should_use_deepgemm_for_fp8_linear(
         layer.orig_dtype, layer.weight
     )
-    if is_deep_gemm_e8m0_used() and should_use_deepgemm:
-        block_sz = tuple(layer.weight_block_size)
-        requant_weight_ue8m0_inplace(
-            layer.weight.data, layer.weight_scale.data, block_sz
+    if should_use_deepgemm:
+        dg_weight, dg_weight_scale = deepgemm_post_process_fp8_weight_block(
+            wq=layer.weight.data,
+            ws=layer.weight_scale.data,
+            quant_block_shape=tuple(layer.weight_block_size),
+            use_e8m0=is_deep_gemm_e8m0_used(),
         )
+        layer.weight = torch.nn.Parameter(dg_weight, requires_grad=False)
+        layer.weight_scale = torch.nn.Parameter(dg_weight_scale, requires_grad=False)
Comment on lines +1196 to +1197
Member

I think we want to preserve the attributes on the original parameter cc @kylesayrs

Contributor

Yes, without the original attributes we won't be able to reload weights. More changes than this will be required to support reloading, so this is fine to land now and rebase later.

Contributor (Author)

Should I just …
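One possible way the attribute concern could be handled (a sketch of an approach, not what this PR currently does; the helper name is hypothetical, and the assumption is that extra attributes such as a `weight_loader` live in the Parameter's `__dict__`):

```python
import torch

def replace_param_preserving_attrs(
    module: torch.nn.Module, name: str, new_data: torch.Tensor
) -> None:
    """Sketch: swap a parameter's tensor while copying over any extra
    attributes (e.g. a weight_loader) that were set on the old Parameter."""
    old_param = getattr(module, name)
    new_param = torch.nn.Parameter(new_data, requires_grad=False)
    # Attributes added via setattr() on a Parameter live in its __dict__.
    for attr, value in vars(old_param).items():
        setattr(new_param, attr, value)
    setattr(module, name, new_param)

# Hypothetical usage inside maybe_post_process_fp8_weight_block:
# replace_param_preserving_attrs(layer, "weight", dg_weight)
# replace_param_preserving_attrs(layer, "weight_scale", dg_weight_scale)
```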
def expert_weight_is_col_major(x: torch.Tensor) -> bool:
The comment doesn't match this line since "is" is ==

Also isn't it the case though that we still want to use UE8M0 on Hopper for cases like DeepSeek terminus?
+1, actually we are using e8m0 on Hopper currently, so this seems like a breaking change to me.

We should carefully test and benchmark before we use this.
IIUC, this is the state of things:

main: let `ws` be a weight scales tensor of shape `[X, 4096]` and datatype `float32`. ~~We cast the scales to UE8M0 but keep the weight scales in `float32`, i.e. each float32 value actually holds UE8M0 content. Look here. i.e. only the first byte of each float32 value will have the actual contents.~~ [EDIT] The stricken-out portion was wrong. We actually cast the weights to ue8m0 and then expand it back to float32 - effectively the scale values can be one of `{2^i where i in [-127, 127]}`. `ws` will be of shape `[X, 4096]` and of datatype `float32`.

This PR: we cast the scales to UE8M0 and then use the `transform_sf_into_required_layout()` function from deepgemm to pack the scales into an int32 tensor, i.e. `ws` will be of shape `[X, 1024]` and of datatype `int32`. Effectively the scale values can be one of `{i where i in [-127, 127]}`.
Okay so Blackwell just has the packing part specifically, understood
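To make the packing step described above concrete, here is a small illustrative sketch. It only shows the byte-level idea (four UE8M0 exponent bytes per int32, taking `[X, 4096]` float32 scales to `[X, 1024]` int32) and is an assumption for illustration, not DeepGEMM's actual `transform_sf_into_required_layout`, which also handles the required reordering and alignment:

```python
import torch

# Illustrative only: pack power-of-two float32 scales [X, 4096] into an int32
# tensor [X, 1024] holding four UE8M0 exponent bytes per element.
X = 8
exps = torch.randint(-10, 10, (X, 4096))
ws = torch.ldexp(torch.ones(X, 4096), exps)  # scales are 2^i, float32

# For a float32 power of two, bits 23..30 hold the biased exponent (i + 127),
# which is exactly the UE8M0 byte.
ue8m0 = ((ws.view(torch.int32) >> 23) & 0xFF).to(torch.uint8)  # [X, 4096]

# Reinterpret every four consecutive bytes as one (little-endian) int32.
packed = ue8m0.contiguous().view(torch.int32)  # [X, 1024], int32

print(ws.shape, ws.dtype, "->", packed.shape, packed.dtype)
```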