DP/EP Support for gpt-oss with deepep-ht comm kernel on SM100 #23608
Conversation
Code Review
This pull request adds Data Parallelism and Expert Parallelism support for GPT-OSS models, particularly with the deepep-ht communication kernel and mxfp4 quantization. The changes introduce a new kernel wrapper, TrtLlmGenExperts, to leverage flashinfer's trtllm_fp4_block_scale_routed_moe kernel. While the overall structure for integrating this new path seems correct, I've identified two critical issues in the new TrtLlmGenExperts implementation that will prevent expert parallelism from functioning correctly. These issues need to be addressed to ensure the correctness and functionality of the feature.
        return True

    def supports_expert_map(self) -> bool:
        return False
The supports_expert_map method should return True. Currently, it returns False, which will cause a ValueError in FusedMoEModularKernel when expert parallelism is enabled (ep_size > 1), as an expert_map will be provided. This prevents the expert parallelism feature from working.
Suggested change:
-        return False
+        return True
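For context, a rough illustration of the failure mode; the names below are mine for the sketch, not vLLM's exact code. With ep_size > 1 an expert_map is passed down, and an experts implementation that reports no expert-map support is rejected up front:

from typing import Optional

import torch

def check_expert_map_support(expert_map: Optional[torch.Tensor],
                             supports_expert_map: bool) -> None:
    # Mirrors the kind of guard the comment above refers to: an expert_map
    # is only usable if the experts implementation opts in.
    if expert_map is not None and not supports_expert_map:
        raise ValueError("expert_map was provided but this experts "
                         "implementation does not support expert maps")

# With expert parallelism enabled an expert_map exists, so returning False
# from supports_expert_map() fails early:
try:
    check_expert_map_support(torch.arange(4), supports_expert_map=False)
except ValueError as e:
    print(e)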
This should be True.
    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: Optional[torch.Tensor],
        w1_scale: Optional[torch.Tensor],
        w2_scale: Optional[torch.Tensor],
        w1_zp: Optional[torch.Tensor],
        w2_zp: Optional[torch.Tensor],
        a1q_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
        apply_router_weight_on_input: bool,
    ):
        topk = topk_ids.size(-1)
        local_num_experts = w1.size(0)
        intermediate_size = w2.size(1)
        # First global expert id owned by this EP rank.
        local_expert_offset = self.moe.ep_rank * local_num_experts

        x_quant = hidden_states
        x_scale = a1q_scale
        if x_scale is not None:
            x_scale = x_scale.view(torch.float8_e4m3fn).reshape(
                *x_quant.shape[:-1], -1)

        # Pack each top-k entry into one int32: expert id in the high 16 bits,
        # the bf16 routing weight's bit pattern in the low 16 bits.
        packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
            torch.bfloat16).view(torch.int16)

        assert w1_scale is not None
        assert w2_scale is not None
        kwargs = {
            "topk_ids": packed_tensor,
            "routing_bias": None,
            "hidden_states": x_quant,
            "hidden_states_scale": x_scale,
            "gemm1_weights": w1,
            "gemm1_weights_scale": w1_scale,
            "gemm1_bias": self.w13_bias,
            "gemm1_alpha": self.gemm1_alpha,
            "gemm1_beta": self.gemm1_beta,
            "gemm1_clamp_limit": self.gemm1_clamp_limit,
            "gemm2_weights": w2,
            "gemm2_weights_scale": w2_scale,
            "gemm2_bias": self.w2_bias,
            "output1_scale_scalar": None,
            "output1_scale_gate_scalar": None,
            "output2_scale_scalar": None,
            "num_experts": global_num_experts,
            "top_k": topk,
            "n_group": None,
            "topk_group": None,
            "intermediate_size": intermediate_size,
            "local_expert_offset": local_expert_offset,
            "local_num_experts": local_num_experts,
            "routed_scaling_factor": None,
            "tile_tokens_dim": self._get_tile_tokens_dim(x_quant, topk,
                                                         local_num_experts),
            "routing_method_type": 1,
            "do_finalize": True,
            "output": output,
        }

        from flashinfer import trtllm_fp4_block_scale_routed_moe
        trtllm_fp4_block_scale_routed_moe(**kwargs)
        return output
There appears to be a mismatch in the expert ID format. The topk_ids received by this method are local expert IDs, as they are processed by a prepare_finalize implementation (e.g., DeepEPHTPrepareAndFinalize) which maps global IDs to local ones. However, the trtllm_fp4_block_scale_routed_moe kernel seems to expect global expert IDs, especially given the presence of the local_expert_offset parameter, which is typically used to identify the range of global expert IDs managed by the current rank. Passing local IDs to a kernel expecting global IDs will result in incorrect expert selection and computation. Please ensure the kernel receives expert IDs in the expected format.
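To make the mismatch concrete, here is a minimal, hypothetical sketch (not the PR's code) of what remapping back to global IDs before packing could look like, assuming slots not routed to this rank are marked with -1 and that the kernel really does expect global IDs:

import torch

# Assumed EP layout, for illustration only.
ep_rank, local_num_experts = 2, 32
local_expert_offset = ep_rank * local_num_experts

# Local ids as produced by a prepare/finalize step; -1 = not routed here.
topk_ids = torch.tensor([[0, 5, -1, 17]], dtype=torch.int32)

# Shift valid entries back into this rank's global id range.
global_topk_ids = torch.where(topk_ids < 0, topk_ids,
                              topk_ids + local_expert_offset)
print(global_topk_ids)  # tensor([[64, 69, -1, 81]], dtype=torch.int32)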
Thanks @zyongye! LGTM!
     # Note: init_prepare_finalize should only be called by
     # prepare_communication_buffer_for_model.
-    def init_prepare_finalize(self):
+    def init_prepare_finalize(self, layer: Any):
The layer argument should be typed as torch.nn.Module.
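Concretely, the nit amounts to this kind of annotation (hypothetical fragment, class name is a stand-in):

import torch

class FusedMoEMethodBaseSketch:  # stand-in name, not vLLM's class

    def init_prepare_finalize(self, layer: torch.nn.Module) -> None:
        ...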
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        moe: FusedMoEConfig,
        layer: Any,
ditto
        prepare_finalize: FusedMoEPrepareAndFinalize,
        # TODO(bnell): Remove. Every layer should have an moe config object.
        moe: FusedMoEConfig,
        layer: Any,
ditto
class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):

    def __init__(self, moe: FusedMoEConfig, layer: Any):
Can you pass the individual components here instead of the entire layer? Using the layer makes it harder to test.
done
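The shape of that refactor, as a hedged sketch only (the real signature is in the later diff that threads through gemm1_alpha, gemm1_beta, gemm1_clamp_limit, w13_bias, w2_bias and max_capture_size): the wrapper takes just the tensors it needs, so a unit test can build it without constructing a full FusedMoE layer.

import torch

class TrtLlmGenExpertsSketch:  # illustrative stand-in, not the PR's class

    def __init__(
        self,
        gemm1_alpha: torch.Tensor,
        gemm1_beta: torch.Tensor,
        gemm1_clamp_limit: torch.Tensor,
        w13_bias: torch.Tensor,
        w2_bias: torch.Tensor,
        max_capture_size: int,
    ) -> None:
        # Only the components the kernel call needs are stored; no layer.
        self.gemm1_alpha = gemm1_alpha
        self.gemm1_beta = gemm1_beta
        self.gemm1_clamp_limit = gemm1_clamp_limit
        self.w13_bias = w13_bias
        self.w2_bias = w2_bias
        self.max_capture_size = max_capture_size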
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        if (prepare_finalize.activation_format ==
                mk.FusedMoEActivationFormat.BatchedExperts):
            raise NotImplementedError(
If we want a graceful fallback instead of an error you could overload maybe_make_prepare_finalize and make it return None for the unhandled cases.
What can we fall back to? I don't know of any other kernel that has batched mxfp4.
The fallback would basically be disabling the all2all communication for this layer and using the non-batched kernels but maybe erroring out would be better.
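For illustration, a minimal sketch of the graceful-fallback idea using stand-in types (not vLLM's actual API): returning None instead of raising would let the layer keep the non-batched path, at the cost of silently skipping all2all.

from enum import Enum, auto
from typing import Optional

class ActivationFormat(Enum):
    STANDARD = auto()
    BATCHED_EXPERTS = auto()

def maybe_make_prepare_finalize(fmt: ActivationFormat) -> Optional[object]:
    # No batched mxfp4 kernel exists, so signal "no all2all for this layer"
    # rather than raising; the caller then keeps the non-batched kernels.
    if fmt == ActivationFormat.BATCHED_EXPERTS:
        return None
    return object()  # placeholder for a real PrepareAndFinalize instance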
@bnellnm I updated the change. Could you please take a look?
Force-pushed from 0ff3447 to 1df5aa2.
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        moe: FusedMoEConfig,
        layer: Any,
nit: all these layer args should be torch.nn.Module also.
bnellnm left a comment
LGTM. All the layer arguments could be torch.nn.Module instead of Any
Done with the nit fix. I will run some performance benchmarks tonight after the B200 is freed.
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
        gemm1_alpha,
        gemm1_beta,
        gemm1_clamp_limit,
        w13_bias,
        w2_bias,
        max_capture_size,
Missing types
An update from my testing: I found the code says the DeepEP high-throughput kernels are not CUDA Graph compatible, while the low-latency kernels are. However, the low-latency kernels need BatchedExperts, which the mxfp4 MoE does not support. So for now the naive comm kernel is best for me, and #23964 may be better once it is merged.
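For reference, this is roughly how I'd pin the comm backend in that situation; the VLLM_ALL2ALL_BACKEND environment variable and its accepted values are an assumption to verify against vllm/envs.py for your version.

import os

# Must be set before vLLM reads its env config (or export it in the shell).
# "naive" sidesteps both the CUDA-graph limitation of the DeepEP
# high-throughput kernels and the BatchedExperts requirement of the
# low-latency kernels mentioned above.
os.environ["VLLM_ALL2ALL_BACKEND"] = "naive"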
Rebased and cleaned up version for #22907 since @varun-sundar-rabindranath is out.
Model: 120B
Benchmark on Random Dataset for DEP 4
python benchmark_serving.py --model "openai/gpt-oss-120b" --dataset-name random --ignore-eos --num-prompts 2048 --random-input-len 1000 --random-output-len 1000 --port 8000 --backend vllm

How to run
Needs this PR from flashinfer, as well as the kernel fix, to work correctly. (cc @IwakuraRein)
Edit:
This PR is working on ToT Flashinfer