Conversation

@varun-sundar-rabindranath (Contributor) commented Aug 14, 2025

Purpose

Integrate GPTOSS with DeepEPHTPrepareFinalize

Commands:
server command: VLLM_ALL2ALL_BACKEND="deepep_high_throughput" vllm serve openai/gpt-oss-20b --port 9010 --data-parallel-size 2 --enable-expert-parallel --no-enable-prefix-caching
lm_eval command: lm_eval --model local-completions --tasks gsm8k --model_args model=openai/gpt-oss-20b,base_url=http://127.0.0.1:9010/v1/completions,num_concurrent=30,max_retries=3 --limit 100

Issue: The server sometimes hangs / reports an IMA (illegal memory access). When the server runs through, the lm_eval outputs are good. They look like:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.31|±  |0.0465|
|     |       |strict-match    |     5|exact_match|↑  | 0.24|±  |0.0429|

and match the results from main with TP.

Debugging:
This PR uses trtllm_fp4_block_scale_routed_moe from flashinfer. I narrowed the issue down to the flashinfer kernel.

  1. One issue is that link needs to be 256; otherwise it'll blow up at link because mPtrExpertCounts is not big enough.
     I am still debugging this.

Test Plan

Test Result

(Optional) Documentation Update


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Varun Sundar Rabindranath added 8 commits August 8, 2025 23:12
@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@varun-sundar-rabindranath varun-sundar-rabindranath marked this pull request as draft August 14, 2025 14:16
@mergify mergify bot added the gpt-oss Related to GPT-OSS models label Aug 14, 2025
@gemini-code-assist bot left a comment

Code Review

This pull request integrates GPT-OSS with DeepEPHT, introducing a new path for Mixture-of-Experts (MoE) layers using flashinfer kernels, specifically for data and expert parallelism. It adds support for mxfp8 quantization and a new trtllm_moe layer. The changes are quite extensive. My review has identified some leftover debugging code and comments in vllm/model_executor/layers/quantization/mxfp4.py which should be removed before this pull request is merged. Given that the pull request description indicates debugging is still in progress, these findings serve as a reminder for cleanup.

Comment on lines +586 to +594
if False:
    # TODO(varun) : remove before landing
    return self._route_and_experts_example(
        layer, x, router_logits, top_k, renormalize, use_grouped_topk,
        topk_group, num_expert_group, global_num_experts, expert_map,
        custom_routing_function, scoring_func, e_score_correction_bias,
        apply_router_weight_on_input, activation, enable_eplb,
        expert_load_view, logical_to_physical_map,
        logical_replica_count)

critical

This block of code is disabled with if False: and contains a TODO to remove it before landing. This debugging code should be removed from the final version of the pull request.

Comment on lines +395 to +407
        else:
            #pass

            if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
                    or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
                # B200 code ??
                # Quant config shouldn't be None !!
                return TrtLlmGenExperts(moe)
            else:
                # H100 code ??
                # you use matmul_ogs kernel here!
                raise NotImplementedError(
                    "Mxfp4 does not support non-batched experts format for EP")

high

This else block contains leftover debugging comments and a #pass statement. These should be removed for production code to improve clarity and maintainability.

        else:
            if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
                    or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
                return TrtLlmGenExperts(moe)
            else:
                raise NotImplementedError(
                    "Mxfp4 does not support non-batched experts format for EP")

@zyongye (Member) commented Aug 14, 2025

Also, a side note: please use the evaluation strategy in the recipe instead of lm_eval for this model.

@mergify

mergify bot commented Aug 14, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @varun-sundar-rabindranath.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 14, 2025
logical_replica_count: Optional[torch.Tensor] = None
) -> torch.Tensor:

topk_weights, topk_ids = FusedMoE.select_experts(
Contributor

topk_ids and topk_weights need to be local, and non-local experts' ids should be -1.

Or use global topk_ids and topk_weights, and provide local_expert_offset and local_num_experts.
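
To make the contract concrete, here is a minimal sketch (illustrative only, not code from this PR) of the first option, assuming each rank owns a contiguous block of local_num_experts experts starting at local_expert_offset:

import torch

def global_to_local_topk_ids(topk_ids: torch.Tensor,
                             local_expert_offset: int,
                             local_num_experts: int) -> torch.Tensor:
    # Shift global expert ids into this rank's local id space and mark every
    # expert that lives on another rank with -1 so the kernel skips it.
    local_ids = topk_ids - local_expert_offset
    on_this_rank = (local_ids >= 0) & (local_ids < local_num_experts)
    return torch.where(on_this_rank, local_ids,
                       torch.full_like(local_ids, -1))

# Example: 8 experts over 2 ranks (4 each). Rank 1 owns global experts 4..7,
# so global id 2 maps to -1 and global id 5 maps to local id 1.
topk_ids = torch.tensor([[2, 5], [4, 7]])
print(global_to_local_topk_ids(topk_ids, local_expert_offset=4,
                               local_num_experts=4))
# tensor([[-1,  1],
#         [ 0,  3]])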

Contributor Author

The topk_ids and topk_weights get processed into the second form ("use global topk_ids and topk_weights, and provide local_expert_offset and local_num_experts") in the all2alls. I verified that this is correct.

None,
"tile_tokens_dim":
self._get_tile_tokens_dim(x_quant, topk, local_num_experts),
"routing_method_type":
Contributor

routing_method_type is hardcoded to renormalize. Maybe add an assertion above to make sure it's not being called with a different routing method; see the sketch below.
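
A minimal sketch of such a guard (the enum name and call site are assumptions here, not this PR's code):

from enum import Enum

class RoutingMethodType(Enum):
    # Stand-in for the real flashinfer routing-method enum.
    Renormalize = 1
    DeepSeekV3 = 2

def assert_renormalize_routing(routing_method_type: RoutingMethodType) -> None:
    # Fail loudly if a caller configures a routing method this kernel path
    # does not implement, instead of silently hardcoding "renormalize".
    if routing_method_type is not RoutingMethodType.Renormalize:
        raise NotImplementedError(
            f"Routing method {routing_method_type} is not supported by the "
            "trtllm_fp4_block_scale_routed_moe path; only renormalize is.")

assert_renormalize_routing(RoutingMethodType.Renormalize)  # passes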

@varun-sundar-rabindranath (Contributor Author)

Update on debugging:

  • I got VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 working reliably by explicitly adding the --enforce-eager option (full command sketched below).
  • VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 still locks up. I have narrowed it down to mPermuteGemm1.run, but as far as I can tell all the setup is correct and the permute indices inside are also computed correctly. Looking into this further.
  • Hunch: I think the interaction with torch.compile might have something to do with this.
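
For reference, the working BF16 configuration is the server command from the description with the flashinfer BF16 path enabled and eager mode forced (the env-var value of 1 is an assumption here):
VLLM_ALL2ALL_BACKEND="deepep_high_throughput" VLLM_USE_FLASHINFER_MOE_MXFP4_BF16=1 vllm serve openai/gpt-oss-20b --port 9010 --data-parallel-size 2 --enable-expert-parallel --no-enable-prefix-caching --enforce-eager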

"MX-FP8 quantization. Please install it with" \
"`pip install flashinfer`") from err

return mxfp8_quantize(x, is_sf_swizzled_layout=False)
Contributor Author

@IwakuraRein is this the right way to quantize bf16 activations to fp8? Thanks.

Contributor

Yes, when you are using the mxfp4 x mxfp8 path.
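
For context, a self-contained sketch of the helper being discussed (the flashinfer import location is assumed; the diff excerpt above only shows the tail of it):

import torch

def quantize_activations_to_mxfp8(x: torch.Tensor):
    """Quantize bf16 activations to MX-FP8 for the mxfp4 x mxfp8 MoE path."""
    try:
        from flashinfer import mxfp8_quantize  # import location assumed
    except ImportError as err:
        raise ImportError(
            "flashinfer is required for MX-FP8 quantization. "
            "Please install it with `pip install flashinfer`.") from err
    # is_sf_swizzled_layout=False keeps the scale factors in a linear
    # (non-swizzled) layout, matching the call in the diff above.
    return mxfp8_quantize(x, is_sf_swizzled_layout=False)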

@mgoin changed the title from "[Kernel] DP + EP : GPTOSS + Deepep-HightThroughput" to "[Kernel] DP + EP : GPTOSS + DeepEP-HighThroughput" on Aug 18, 2025
@IwakuraRein (Contributor)

  • Hunch: I think the interaction with torch.compile might have something to do with this.

If torch.compile is using garbage values to initialize routing_logits and/or topk_ids, then there might be an illegal memory access.


Labels

gpt-oss (Related to GPT-OSS models), needs-rebase
