Conversation

@mengwei805 (Collaborator) commented Mar 28, 2025

### What this PR does / why we need it?

Spec decode MultiStepWorker now fully supports TP1DraftModelRunner: the draft_model_runner can run with multi-step prepare directly on the NPU, and the draft_model_runner can use MLA.

  1. Before this PR, MultiStepWorker would never step into the branch that prepares on the NPU; it only took the branch that prepares on the CPU (line 52 of vllm_ascend/patch/patch_multi_step_worker.py). Although this does not affect the correctness of speculative decoding, and the performance of the two branches is basically the same in the current version, this PR enables the NPU-prepare branch. In general, there are two main changes in patch_multi_step_worker.py: first, the is_cuda_like() check is removed and the TP1DraftModelRunner rewritten in vllm_ascend is used; second, supports_gpu_multi_step() is made to return True on NPU devices where the outer MultiStepWorker can work correctly. (A sketch of the branch selection being patched appears after this list.)

  2. Before this PR, TP1DraftModelRunner only supported ordinary Attention on the NPU, not MLA. The relevant adaptation is in vllm_ascend/worker/draft_model_runner.py. Although I don't know why the input_positions of model_input.attn_metadata in vllm-ascend needs to be set in execute_model, it is done that way in model_runner.py, so I made the corresponding change here as well. Otherwise, when attn_backend is MLA, it reports that input_positions cannot be found.

  3. I commented out two lines at line 118 of draft_model_runner.py to support the K > 1 scenario:

```
# lora_mapping=model_input.lora_mapping,
# lora_requests=model_input.lora_requests,
```

I added comments there; in the future, when vllm-ascend supports the LoRA feature, these lines can be restored.
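
For context on change 1 above, here is a minimal sketch of the branch selection being patched. It assumes the vLLM v0.7.3 layout of MultiStepWorker.sampler_output and is not the actual patch; the helper name use_npu_multi_step_prepare is purely illustrative.

```
# Minimal sketch, not the actual patch: the condition under which the patched
# MultiStepWorker runs the draft model with multi-step prepare on the NPU.
# Upstream vLLM also requires current_platform.is_cuda_alike() here; the patch
# drops that check and imports vllm_ascend's rewritten TP1DraftModelRunner.
from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner


def use_npu_multi_step_prepare(model_runner, execute_model_req) -> bool:
    """Illustrative helper (hypothetical name) mirroring the patched check."""
    return (isinstance(model_runner, TP1DraftModelRunner)
            and model_runner.supports_gpu_multi_step(execute_model_req))
```

When this condition holds, MultiStepWorker sets num_steps on the expanded request and calls execute_model once, so every step is prepared on the NPU; otherwise it falls back to the existing loop that prepares each step on the CPU.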

### Does this PR introduce any user-facing change?

This PR has no effect on users.

### How was this patch tested?

It has been tested.

@mengwei805 mengwei805 force-pushed the v0.7.3-tp1draftmodelrunner branch 2 times, most recently from ef72b5d to bace95c Compare March 28, 2025 09:43
@mengwei805 mengwei805 changed the title from 【v0.7.3】spec decode MultiStepWorker support TP1DraftModelRunner fully to [v0.7.3]spec decode MultiStepWorker support TP1DraftModelRunner fully on Mar 28, 2025
@wangxiyuan (Collaborator)

Can you describe in the commit message what is changed in patch_multi_step_worker.py? Thanks.

@MengqingCao (Collaborator)

> Can you describe in the commit message what is changed in patch_multi_step_worker.py? Thanks.

@mengwei805 LGTM after this is added. I think adding related comments and updating the PR message will make it clearer, thx!

@mengwei805 mengwei805 force-pushed the v0.7.3-tp1draftmodelrunner branch from bace95c to 6f7b978 Compare March 31, 2025 11:45
@mengwei805 (Collaborator, Author) commented Mar 31, 2025

> Can you describe in the commit message what is changed in patch_multi_step_worker.py? Thanks.

Of course; I'm sorry I was so brief before. I have added the relevant information to the PR message and the commit message.
In addition, regarding the modification at line 162 of vllm_ascend/worker/draft_model_runner.py in this PR,

```
# TODO: Add support for ASCEND when outer multi_step_worker
# could work correct.
if self.attn_backend.get_name() not in ("FLASH_ATTN", "TRITON_MLA"):
    return False
```

I suggest that vllm-ascend's attn_backend distinguish between ordinary Attention and MLA, to make it easier to extend related features later. Currently, get_name() of both attention types returns ASCEND.
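
As a concrete illustration of that suggestion, here is a hedged sketch of how the check quoted above could look once the Ascend backends report distinct names; ASCEND_MLA is a hypothetical name that does not exist today, and currently get_name() returns ASCEND for both backends.

```
# Hedged sketch only; "ASCEND_MLA" is a hypothetical backend name.
# Today both the ordinary attention backend and the MLA backend on the NPU
# return "ASCEND" from get_name(), so this check cannot tell them apart yet.
SUPPORTED_DRAFT_BACKENDS = ("ASCEND", "ASCEND_MLA")


def backend_supports_npu_multi_step(attn_backend) -> bool:
    """Illustrative replacement for the FLASH_ATTN / TRITON_MLA name check."""
    return attn_backend.get_name() in SUPPORTED_DRAFT_BACKENDS
```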

@mengwei805 mengwei805 force-pushed the v0.7.3-tp1draftmodelrunner branch from 6f7b978 to e835d88 Compare March 31, 2025 14:23
@mengwei805 (Collaborator, Author)

> Can you describe in the commit message what is changed in patch_multi_step_worker.py? Thanks.

> @mengwei805 LGTM after this is added. I think adding related comments and updating the PR message will make it clearer, thx!

Of course; I'm sorry I was so brief before. I have added the relevant information to the PR message and the commit message.

@mengwei805 mengwei805 force-pushed the v0.7.3-tp1draftmodelrunner branch from e835d88 to 5632984 Compare April 1, 2025 03:18
spec decode MultiStepWorker support TP1DraftModelRunner fully: run the
draft_model_runner with multi-step prepare on the NPU directly
and support draft_model_runner using MLA.

Signed-off-by: mengwei805 <mengwei25@huawei.com>
@mengwei805 mengwei805 force-pushed the v0.7.3-tp1draftmodelrunner branch from 5632984 to 958a701 Compare April 1, 2025 03:23
@MengqingCao (Collaborator)

lgtm now, thx

@wangxiyuan (Collaborator)

It's very clear. Thanks

@wangxiyuan wangxiyuan merged commit fd9494f into vllm-project:v0.7.3-dev Apr 1, 2025
13 checks passed
wangxiyuan pushed a commit that referenced this pull request Apr 17, 2025
### What this PR does / why we need it?
Backport: #252
This supports speculative decoding on Ascend, including speculating with
a draft model, by matching n-grams in the prompt, using MLP speculators,
and using EAGLE-based draft models.

Backport: #423
Spec decode MultiStepWorker now fully supports TP1DraftModelRunner: the
draft_model_runner can run with multi-step prepare directly on the NPU,
and the draft_model_runner can use MLA.

1. Before this PR, `MultiStepWorker` would never step into the branch
that prepares on the NPU; it only took the branch that prepares on the
CPU (`line 52` of `vllm_ascend/patch/patch_multi_step_worker.py`).
Although this does not affect the correctness of speculative decoding,
and the performance of the two branches is basically the same in the
current version, this PR enables the NPU-prepare branch. In general,
there are two main changes in `patch_multi_step_worker.py`: first, the
`is_cuda_like()` check is removed and the `TP1DraftModelRunner`
rewritten in vllm_ascend is used; second, `supports_gpu_multi_step()`
is made to return True on NPU devices where the outer MultiStepWorker
can work correctly.

2. Before this PR, `TP1DraftModelRunner` only supported ordinary
Attention on the NPU, not MLA. The relevant adaptation is in
`vllm_ascend/worker/draft_model_runner.py`. Although I don't know why
the `input_positions` of `model_input.attn_metadata` in vllm-ascend
needs to be set in `execute_model`, it is done that way in
`model_runner.py`, so I made the corresponding change here as well.
Otherwise, when attn_backend is MLA, it reports that input_positions
cannot be found.

3. I commented out two lines at `line 118` of `draft_model_runner.py`
to support the K > 1 scenario:
  ```
  # lora_mapping=model_input.lora_mapping,
  # lora_requests=model_input.lora_requests,
  ```
I added comments there; in the future, when vllm-ascend supports the
LoRA feature, these lines can be restored.

TODO:
- [ ] revert the patch when the related issues are addressed in vllm

### How was this patch tested?
CI passed with the newly added tests.
- e2e test for medusa proposer:
tests/singlecard/spec_decode/e2e/test_medusa_correctness.py
- e2e test for mlp proposer:
tests/singlecard/spec_decode/e2e/test_mlp_correctness.py
- e2e test for n-gram proposer:
tests/singlecard/spec_decode/e2e/test_ngram_correctness.py

Tests for patched files:
- tests/singlecard/spec_decode/test_dynamic_spec_decode.py
- tests/singlecard/spec_decode/test_multi_step_worker.py
- tests/singlecard/spec_decode/test_ngram_worker.py
- tests/singlecard/spec_decode/test_spec_decode_worker.py

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Co-authored-by: mengwei805 <mengwei25@huawei.com>
ttanzhiqiang pushed a commit to ttanzhiqiang/vllm-ascend that referenced this pull request Apr 27, 2025
Angazenn pushed a commit to Angazenn/vllm-ascend that referenced this pull request Oct 21, 2025