
Conversation

@ApsarasX (Collaborator) commented Jun 23, 2025

What this PR does / why we need it?

Fix the shape of the npu_moe_init_routing input parameters to support aclgraph mode on qwen3-moe

In addition to this PR, resolving the gatherv3 error might be necessary. See related PR #1297 #1446

Thanks to @yiz-liu for providing the idea
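
As a rough illustration of the idea (a hypothetical sketch, not the actual diff — the helper name and the exact call signature of torch_npu.npu_moe_init_routing used here are assumptions), the routing inputs are built from the static, graph-friendly token count instead of a data-dependent one:

# Hypothetical sketch: build the npu_moe_init_routing inputs with shapes derived
# from the static (padded) token count so the op can be captured in an ACL graph.
import torch
import torch_npu


def init_routing_static(hidden_states, topk_ids, top_k):
    # hidden_states: [num_tokens, hidden_size]; num_tokens is the padded,
    # graph-friendly size, not a value that changes from batch to batch.
    num_tokens = hidden_states.shape[0]
    row_idx = (torch.arange(num_tokens * top_k,
                            dtype=torch.int32,
                            device=hidden_states.device)
               .view(top_k, num_tokens)
               .transpose(0, 1)
               .contiguous())
    # active_num is likewise the static token count rather than the dynamic
    # number of scheduled tokens, which is what breaks aclgraph capture.
    return torch_npu.npu_moe_init_routing(hidden_states,
                                          row_idx=row_idx,
                                          expert_idx=topk_ids.to(torch.int32),
                                          active_num=num_tokens)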

Does this PR introduce any user-facing change?

No

How was this patch tested?

Tested on Qwen3-30B-A3B

Closes: #1368

ApsarasX force-pushed the community-support-qwen3moe-aclgraph branch from 9fcfc91 to 61eafae on June 23, 2025 17:52
ApsarasX changed the title from "[Bugfix] Support Qwen3-MOE TP on aclgraph mode" to "[Bugfix] Support Qwen3-MOE on aclgraph mode" on Jun 23, 2025
ApsarasX force-pushed the community-support-qwen3moe-aclgraph branch from 61eafae to 3da51df on June 23, 2025 17:58
codecov bot commented Jun 23, 2025

Codecov Report

❌ Patch coverage is 40.00000% with 6 lines in your changes missing coverage. Please review.
✅ Project coverage is 52.34%. Comparing base (c30ddb8) to head (25f1182).
⚠️ Report is 613 commits behind head on main.

Files with missing lines              Patch %   Lines
vllm_ascend/ops/common_fused_moe.py   44.44%    5 Missing ⚠️
vllm_ascend/ops/fused_moe.py          0.00%     1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1381       +/-   ##
===========================================
+ Coverage   27.39%   52.34%   +24.95%     
===========================================
  Files          56       78       +22     
  Lines        6191     9641     +3450     
===========================================
+ Hits         1696     5047     +3351     
- Misses       4495     4594       +99     
Flag        Coverage Δ
unittests   52.34% <40.00%> (+24.95%) ⬆️

Flags with carried forward coverage won't be shown.


@yiz-liu (Collaborator) commented Jun 24, 2025

@ApsarasX Have you evaluated various combinations of max_model_len and max_num_batched_tokens? I anticipate there may be corner cases, as the devil is in the details.

@leo-pony (Collaborator) commented Jun 25, 2025

@ApsarasX Thanks for your PR. I tried Qwen/Qwen3-30B-A3B on the main branch, and the issue also exists there. With this PR applied, the issue is resolved and Qwen3-30B-A3B runs fine. Additional information follows:

Run mode:

vllm serve  Qwen/Qwen3-30B-A3B --tensor-parallel-size 4

Partial run log:

Downloading Model from https://www.modelscope.cn to directory: /shared/cache/modelscope/models/Qwen/Qwen3-30B-A3B
2025-06-25 02:43:00,659 - modelscope - WARNING - Using branch: master as version is unstable, use with caution
INFO 06-25 02:43:00 [serving_completion.py:66] Using default completion sampling params from model: {'temperature': 0.6, 'top_k': 20, 'top_p': 0.95}
INFO 06-25 02:43:00 [api_server.py:1349] Starting vLLM API server 0 on http://0.0.0.0:8000
INFO 06-25 02:43:00 [launcher.py:29] Available routes are:
INFO 06-25 02:43:00 [launcher.py:37] Route: /openapi.json, Methods: GET, HEAD
INFO 06-25 02:43:00 [launcher.py:37] Route: /docs, Methods: GET, HEAD
INFO 06-25 02:43:00 [launcher.py:37] Route: /docs/oauth2-redirect, Methods: GET, HEAD
INFO 06-25 02:43:00 [launcher.py:37] Route: /redoc, Methods: GET, HEAD
INFO 06-25 02:43:00 [launcher.py:37] Route: /health, Methods: GET
INFO 06-25 02:43:00 [launcher.py:37] Route: /load, Methods: GET
INFO 06-25 02:43:00 [launcher.py:37] Route: /ping, Methods: POST
INFO 06-25 02:43:00 [launcher.py:37] Route: /ping, Methods: GET
INFO 06-25 02:43:00 [launcher.py:37] Route: /tokenize, Methods: POST
INFO 06-25 02:43:00 [launcher.py:37] Route: /detokenize, Methods: POST
INFO 06-25 02:43:00 [launcher.py:37] Route: /v1/models, Methods: GET
INFO 06-25 02:43:00 [launcher.py:37] Route: /version, Methods: GET
INFO 06-25 02:43:00 [launcher.py:37] Route: /v1/chat/completions, Methods: POST
INFO 06-25 02:43:00 [launcher.py:37] Route: /v1/completions, Methods: POST
INFO 06-25 02:43:00 [launcher.py:37] Route: /v1/embeddings, Methods: POST
INFO 06-25 02:43:00 [launcher.py:37] Route: /pooling, Methods: POST
INFO 06-25 02:43:00 [launcher.py:37] Route: /classify, Methods: POST
INFO 06-25 02:43:00 [launcher.py:37] Route: /score, Methods: POST
INFO 06-25 02:43:00 [launcher.py:37] Route: /v1/score, Methods: POST
INFO 06-25 02:43:00 [launcher.py:37] Route: /v1/audio/transcriptions, Methods: POST
INFO 06-25 02:43:00 [launcher.py:37] Route: /rerank, Methods: POST
INFO 06-25 02:43:00 [launcher.py:37] Route: /v1/rerank, Methods: POST
INFO 06-25 02:43:00 [launcher.py:37] Route: /v2/rerank, Methods: POST
INFO 06-25 02:43:00 [launcher.py:37] Route: /invocations, Methods: POST
INFO 06-25 02:43:00 [launcher.py:37] Route: /metrics, Methods: GET
INFO:     Started server process [172434]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO 06-25 02:57:47 [chat_utils.py:420] Detected the chat template content format to be 'string'. You can set `--chat-template-content-format` to override this.
INFO 06-25 02:57:47 [logger.py:43] Received request chatcmpl-791dda955ef9418f8d6eca0adebe9df6: prompt: '<|im_start|>user\nGive me a short introduction to large language models.<|im_end|>\n<|im_start|>assistant\n', params: SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.6, top_p=0.95, top_k=20, min_p=0.0, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=4096, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, guided_decoding=None, extra_args=None), prompt_token_ids: None, prompt_embeds shape: None, lora_request: None, prompt_adapter_request: None.
INFO 06-25 02:57:47 [async_llm.py:271] Added request chatcmpl-791dda955ef9418f8d6eca0adebe9df6.
INFO 06-25 02:57:51 [loggers.py:118] Engine 000: Avg prompt throughput: 1.8 tokens/s, Avg generation throughput: 17.2 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.1%, Prefix cache hit rate: 0.0%
INFO:     127.0.0.1:47620 - "POST /v1/chat/completions HTTP/1.1" 200 OK
INFO 06-25 02:58:01 [loggers.py:118] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 13.4 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
INFO 06-25 02:58:11 [loggers.py:118] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%

Send request:

curl http://localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
  "model": "Qwen/Qwen3-30B-A3B",
  "messages": [
    {"role": "user", "content": "Give me a short introduction to large language models."}
  ],
  "temperature": 0.6,
  "top_p": 0.95,
  "top_k": 20,
  "max_tokens": 4096
}'

Response:

{"id":"chatcmpl-791dda955ef9418f8d6eca0adebe9df6","object":"chat.completion","created":1750820267,
"model":"Qwen/Qwen3-30B-A3B","choices":[{"index":0,"message":{"role":"assistant","reasoning_content":null,"content":
"<think>\nOkay, the user wants a short introduction to large language models. Let me start by defining what they are. I should mention that they're AI systems trained on vast text data. Then, I need to explain their capabilities, like generating text, answering questions, and translating. Maybe mention the underlying technology, like neural networks, specifically transformers. It's important to note the scale—huge parameters and data. Also, touch on applications in various fields such as customer service, content creation, and research. Don't forget to include some examples like GPT or BERT. But keep it concise. Wait, the user said \"short,\" so I need to be brief. Avoid technical jargon but still be informative. Maybe end with a note on their impact and ongoing development. Let me check if I covered all key points without getting too detailed. Yeah, that should work.\n</think>\n\n
Large language models (LLMs) are advanced AI systems trained on vast amounts of text data to understand and generate human-like language. They use deep learning techniques, particularly transformer architectures, to process and analyze patterns in text. LLMs can perform tasks like answering questions, writing essays, coding, translating languages, and more. Their power comes from their scale—millions or billions of parameters—that allows them to capture complex linguistic structures and context. Examples include models like GPT, BERT, and others. While they offer significant benefits, they also raise ethical and technical challenges, such as bias, accuracy, and responsible use.",
"tool_calls":[]},"logprobs":null,"finish_reason":"stop","stop_reason":null}],
"usage":{"prompt_tokens":18,"total_tokens":324,"completion_tokens":306,"prompt_tokens_details":null},"prompt_logprobs":null,"kv_transfer_params":null}

The original error for Qwen/Qwen3-30B-A3B is as follows, the same as in #1368:

(VllmWorker rank=2 pid=166267) Downloading Model from https://www.modelscope.cn to directory: /shared/cache/modelscope/models/Qwen/Qwen3-30B-A3B
Loading safetensors checkpoint shards: 100% Completed | 16/16 [04:17<00:00, 16.12s/it]
(VllmWorker rank=0 pid=166265)
(VllmWorker rank=0 pid=166265) INFO 06-25 02:18:15 [default_loader.py:272] Loading weights took 258.15 seconds
(VllmWorker rank=1 pid=166266) INFO 06-25 02:18:17 [model_runner_v1.py:1848] Loading model weights took 14.2466 GB
(VllmWorker rank=3 pid=166268) INFO 06-25 02:18:32 [backends.py:472] Dynamo bytecode transform time: 15.40 s
(VllmWorker rank=1 pid=166266) INFO 06-25 02:18:33 [backends.py:462] Using cache directory: /root/.cache/vllm/torch_compile_cache/0920c08a9a/rank_1_0 for vLLM's torch.compile
(VllmWorker rank=2 pid=166267) INFO 06-25 02:18:36 [backends.py:173] Compiling a graph for general shape takes 1.81 s
(VllmWorker rank=3 pid=166268) [rank3]:E0625 02:18:38.441000 166268 site-packages/torch/_guards.py:283] [0/0] Error while creating guard:
(VllmWorker rank=3 pid=166268) [rank3]:E0625 02:18:38.441000 166268 site-packages/torch/_guards.py:283] [0/0] Name: ''
(VllmWorker rank=3 pid=166268) [rank3]:E0625 02:18:38.447000 166268 site-packages/torch/_guards.py:285] [0/0]   File "/usr/local/python3.10.17/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 463, in init_ambient_guards
(VllmWorker rank=3 pid=166268) [rank3]:E0625 02:18:38.447000 166268 site-packages/torch/_guards.py:285] [0/0]     self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527] WorkerProc hit an exception.
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527] Traceback (most recent call last):
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]   File "/usr/local/python3.10.17/lib/python3.10/site-packages/vllm/v1/executor/multiproc_executor.py", line 522, in worker_busy_loop
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]     output = func(*args, **kwargs)
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]   File "/shared/mnj/vllm-ascend/vllm-ascend/vllm_ascend/worker/worker_v1.py", line 144, in determine_available_memory
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]     self.model_runner.profile_run()
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]   File "/shared/mnj/vllm-ascend/vllm-ascend/vllm_ascend/worker/model_runner_v1.py", line 1815, in profile_run
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]     hidden_states = self._dummy_run(self.max_num_tokens)
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]   File "/usr/local/python3.10.17/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]     return func(*args, **kwargs)
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]   File "/shared/mnj/vllm-ascend/vllm-ascend/vllm_ascend/worker/model_runner_v1.py", line 1774, in _dummy_run
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]     hidden_states = model(
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]   File "/usr/local/python3.10.17/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]     return self._call_impl(*args, **kwargs)
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]   File "/usr/local/python3.10.17/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]     return forward_call(*args, **kwargs)
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]   File "/usr/local/python3.10.17/lib/python3.10/site-packages/vllm/model_executor/models/qwen3_moe.py", line 519, in forward
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]     hidden_states = self.model(input_ids, positions, intermediate_tensors,
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]   File "/usr/local/python3.10.17/lib/python3.10/site-packages/vllm/compilation/decorators.py", line 239, in __call__
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]     output = self.compiled_callable(*args, **kwargs)
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]   File "/usr/local/python3.10.17/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]     return fn(*args, **kwargs)
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]   File "/usr/local/python3.10.17/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1269, in __call__
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]     return self._torchdynamo_orig_callable(
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]   File "/usr/local/python3.10.17/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]     return _compile(
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]   File "/usr/local/python3.10.17/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]     guarded_code = compile_inner(code, one_graph, hooks, transform)
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]   File "/usr/local/python3.10.17/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]     return _compile_inner(code, one_graph, hooks, transform)
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]   File "/usr/local/python3.10.17/lib/python3.10/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]     return function(*args, **kwargs)
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]   File "/usr/local/python3.10.17/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 796, in _compile_inner
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]     check_fn = CheckFunctionManager(
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]   File "/usr/local/python3.10.17/lib/python3.10/site-packages/torch/_dynamo/guards.py", line 2261, in __init__
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]     guard.create(builder)
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]   File "/usr/local/python3.10.17/lib/python3.10/site-packages/torch/_guards.py", line 281, in create
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]     return self.create_fn(builder, self)
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]   File "/usr/local/python3.10.17/lib/python3.10/site-packages/torch/_dynamo/guards.py", line 1836, in SHAPE_ENV
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]     guards = output_graph.shape_env.produce_guards(
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]   File "/usr/local/python3.10.17/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 4178, in produce_guards
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]     raise ConstraintViolationError(
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527] torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['input_ids'].size()[0])! For more information, run with TORCH_LOGS="+dynamic".
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]   - Not all values of RelaxedUnspecConstraint(L['input_ids'].size()[0]) are valid because L['input_ids'].size()[0] was inferred to be a constant (2048).
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527] You can suppress this exception and fall back to eager by setting:
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]     import torch._dynamo
(VllmWorker rank=3 pid=166268) ERROR 06-25 02:18:38 [multiproc_executor.py:527]     torch._dynamo.config.suppress_errors = True

@Yikun (Collaborator) commented Jun 25, 2025

@ApsarasX is it ready to merge? Any idea about #1368 (comment)?

@yiz-liu (Collaborator) commented Jun 26, 2025

@ApsarasX is it ready to merge? Any idea about #1368 (comment)?

I don't think they are the same error; it should be unrelated.

@yiz-liu (Collaborator) commented Jun 26, 2025

The num_scheduled_tokens value in _dummy_run does not always match max_num_batched_tokens, so this PR may fail in certain edge cases. However, since these scenarios are relatively rare, I recommend merging now to cover the majority of use cases—and I will address the remaining cases afterward.

@ApsarasX , could you please rebase and add me as a co-author? Thank you.
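
To make the concern above concrete (a hypothetical guard, not code from this PR; the parameter names are illustrative), a mismatch like the following is what could bite in those edge cases:

# Hypothetical guard: if the routing tensors were sized from
# max_num_batched_tokens but the graph was captured with a different
# num_scheduled_tokens in _dummy_run, the shapes no longer line up.
def check_routing_capture_shape(num_scheduled_tokens: int,
                                max_num_batched_tokens: int) -> None:
    if num_scheduled_tokens != max_num_batched_tokens:
        raise ValueError(
            f"aclgraph captured for {num_scheduled_tokens} tokens, but the MoE "
            f"routing tensors were sized for {max_num_batched_tokens} tokens")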

ApsarasX force-pushed the community-support-qwen3moe-aclgraph branch from 3da51df to 477d5d1 on June 26, 2025 05:14
@ApsarasX (Collaborator, Author)

The num_scheduled_tokens value in _dummy_run does not always match max_num_batched_tokens, so this PR may fail in certain edge cases. However, since these scenarios are relatively rare, I recommend merging now to cover the majority of use cases—and I will address the remaining cases afterward.

@ApsarasX , could you please rebase and add me as a co-author? Thank you.

I have added you as a co-author. Could you please handle these corner cases in a follow-up?

@yiz-liu (Collaborator) commented Jun 26, 2025

The num_scheduled_tokens value in _dummy_run does not always match max_num_batched_tokens, so this PR may fail in certain edge cases. However, since these scenarios are relatively rare, I recommend merging now to cover the majority of use cases—and I will address the remaining cases afterward.
@ApsarasX , could you please rebase and add me as a co-author? Thank you.

I have added you as a co-author. Could you please handle these corner cases in a follow-up?

Yeah, already on my schedule.

ApsarasX added the ready (read for review) label on Jun 26, 2025
@ApsarasX (Collaborator, Author)

@ApsarasX is it ready to merge? Any idea about #1368 (comment)?

PR ready, please merge

@ApsarasX (Collaborator, Author)

@Yikun Please review

@wangxiyuan (Collaborator)

Please add an e2e test for qwen3-moe as well.

Yikun added the no-test label on Jul 1, 2025
@Yikun (Collaborator) commented Jul 3, 2025

You can add the model test in https://github.com/vllm-project/vllm-ascend/blob/main/tests/e2e/singlecard/test_aclgraph.py#L32 by running the reduced-layer model vllm-ascend/Qwen3-30B-A3B-Puring.
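
A rough sketch of what such an e2e test could look like (hypothetical; the fixtures and helpers actually used in tests/e2e/singlecard/test_aclgraph.py may differ): run the model with aclgraph enabled and again with enforce_eager, then compare the greedy outputs.

# Hypothetical e2e sketch, not the test added by this PR: compare aclgraph
# output against eager-mode output for the reduced-layer Qwen3-MoE model.
from vllm import LLM, SamplingParams

MODEL = "vllm-ascend/Qwen3-30B-A3B-Puring"


def test_qwen3_moe_aclgraph_matches_eager():
    prompts = ["Give me a short introduction to large language models."]
    sampling = SamplingParams(temperature=0.0, max_tokens=32)

    # Default run with aclgraph (graph capture) enabled.
    graph_llm = LLM(model=MODEL, max_model_len=1024)
    graph_out = [o.outputs[0].text for o in graph_llm.generate(prompts, sampling)]
    del graph_llm

    # Reference run in eager mode.
    eager_llm = LLM(model=MODEL, max_model_len=1024, enforce_eager=True)
    eager_out = [o.outputs[0].text for o in eager_llm.generate(prompts, sampling)]
    del eager_llm

    assert graph_out == eager_out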

ApsarasX and others added 2 commits July 6, 2025 10:36
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
Signed-off-by: ApsarasX <apsarax@outlook.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Yikun force-pushed the community-support-qwen3moe-aclgraph branch from 477d5d1 to 25f1182 on July 6, 2025 02:39
@Yikun (Collaborator) commented Jul 6, 2025

Did a double confirm on #1631 and added an e2e test for the qwen aclgraph case. LGTM.

Thanks all @ApsarasX @yiz-liu @leo-pony @wangxiyuan

@Yikun Yikun merged commit c58accc into vllm-project:main Jul 6, 2025
20 checks passed
chopper0126 pushed a commit to chopper0126/vllm-ascend that referenced this pull request Oct 16, 2025
Angazenn pushed a commit to Angazenn/vllm-ascend that referenced this pull request Oct 21, 2025