
Conversation

@mengwei805 (Collaborator) commented Apr 23, 2025

### What this PR does / why we need it?

The custom DeepSeek modeling was changed to support graph mode in #585, so I follow it and make the corresponding changes to the custom deepseek_mtp modeling.

Some modifications for k>1 were not carried over in #429; I add them here.

To better cover the MTP feature in the vllm-ascend repository, I added test cases for graph mode (torchair), but they are skipped for now because torchair cannot correctly clean up memory in VllmRunner.

I also added test cases for MTP quantization weights, but the test weights are not ready yet, so these cases are skipped and will be enabled once the quantized test weights are available.

#648 did not completely fix the sampler change issue (#660), so I added the relevant changes here.

### Does this PR introduce _any_ user-facing change?

You can now use the following method to run MTP with DeepSeek V3/R1 float or quantized weights in eager mode:

```python
llm = LLM(
    model="wemaster/deepseek_mtp_main_random_bf16",
    tensor_parallel_size=2,
    speculative_config={
        "num_speculative_tokens": 1,
    },
    enforce_eager=True,
    trust_remote_code=True,
    disable_log_stats=False,
    gpu_memory_utilization=0.8,
    max_model_len=64,
)
```
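
As a minimal usage sketch (the prompt and sampling parameters below are illustrative, not part of this PR), generation then works the same as without speculative decoding:

```python
from vllm import SamplingParams

# Illustrative prompt; keep it short, since max_model_len is 64 above.
prompts = ["The capital of France is"]
sampling_params = SamplingParams(temperature=0.0, max_tokens=32)

# Draft tokens proposed by the MTP head are verified by the target model,
# so the outputs match ordinary decoding.
for output in llm.generate(prompts, sampling_params):
    print(output.outputs[0].text)
```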

Or run MTP with DeepSeek V3/R1 float or quantized weights in graph mode (torchair):

```python
llm = LLM(
    model="wemaster/deepseek_mtp_main_random_bf16",
    tensor_parallel_size=2,
    speculative_config={
        "num_speculative_tokens": 1,
    },
    trust_remote_code=True,
    additional_config={
        'enable_graph_mode': True,
    },
    disable_log_stats=False,
    gpu_memory_utilization=0.8,
    max_model_len=64,
)
```

Additional notes:

1. k>1 is now supported, so you can set `num_speculative_tokens > 1` when there is sufficient spare compute (see the sketch after this list);
2. MTP is not supported in V1 yet; we will add it once vLLM lands "[V1] Support DeepSeek MTP in V1" (vllm-project/vllm#13500).
3. If running MTP fails with a `segmentation fault`, you can follow the v0.7.3 patch ("v0.7.3 Add MTP support for deepseek", #236), file `vllm_ascend/patch/patch_metrics.py`, method `__npu_async_metrics_collector_init__`.
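
A minimal sketch of k>1, assuming `num_speculative_tokens=2` purely as an illustrative value (everything else mirrors the eager-mode example above):

```python
from vllm import LLM

# Sketch only: 2 is an example value for k>1; choose it based on how much
# spare compute is available for the extra draft tokens per step.
llm = LLM(
    model="wemaster/deepseek_mtp_main_random_bf16",
    tensor_parallel_size=2,
    speculative_config={
        "num_speculative_tokens": 2,
    },
    enforce_eager=True,
    trust_remote_code=True,
    gpu_memory_utilization=0.8,
    max_model_len=64,
)
```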

### How was this patch tested?

Local tests passed, and the change is also covered by CI.

@mengwei805 force-pushed the main-mtp-graph branch multiple times between April 23 and April 27, 2025.
Signed-off-by: mengwei805 <mengwei25@huawei.com>
@ganyi1996ppo merged commit 54c0e63 into vllm-project:main Apr 28, 2025
15 checks passed
chopper0126 pushed a commit to chopper0126/vllm-ascend that referenced this pull request Oct 16, 2025
Angazenn pushed a commit to Angazenn/vllm-ascend that referenced this pull request Oct 21, 2025