
Conversation


@MengqingCao MengqingCao commented Apr 10, 2025

What this PR does / why we need it?

Backport: #252
This adds support for speculative decoding on Ascend, including speculating with a draft model, by matching n-grams in the prompt, using MLP speculators, and using EAGLE-based draft models.

Backport: #423
Spec decode `MultiStepWorker` now fully supports `TP1DraftModelRunner`: the draft model runner can run multi-step prepare directly on the NPU, and it can use MLA.

  1. Before this PR, `MultiStepWorker` would never enter the branch using NPU prepare, only the branch using CPU prepare (line 52 of `vllm_ascend/patch/patch_multi_step_worker.py`). Although this has no effect on the correctness of speculative decoding, and the performance of the two branches is basically the same as of the current version, this PR enables the NPU branch. There are two main changes in `patch_multi_step_worker.py`: first, the `is_cuda_like()` check is removed and the `TP1DraftModelRunner` rewritten in vllm_ascend is used instead; second, `supports_gpu_multi_step()` is made to return `True` on NPU devices whenever the outer `MultiStepWorker` can work correctly.

  2. Before this PR, `TP1DraftModelRunner` only supported Attention on NPU, not MLA. The relevant adaptation is in `vllm_ascend/worker/draft_model_runner.py`. Although I don't know why the `input_positions` of `model_input.attn_metadata` needs to be added in `execute_model`, it is done that way in `model_runner.py`, so I made the corresponding change here as well. Otherwise, when the attention backend is MLA, it reports that `input_positions` cannot be found.

  3. I commented out two lines around line 118 of `draft_model_runner.py` to support the K>1 scenario:

# lora_mapping=model_input.lora_mapping,
# lora_requests=model_input.lora_requests,

These lines are kept as comments; once vllm-ascend supports the LoRA feature, they can be restored.
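The patching approach described in point 1 can be sketched as follows. This is a minimal illustration of the monkey-patch pattern only; the class and method bodies are hypothetical stand-ins, not the actual vllm or vllm-ascend implementation.

```python
# Minimal sketch: replace a method on an upstream class so the NPU code path
# is taken. Names mirror the upstream ones, but the bodies are illustrative.

class TP1DraftModelRunner:  # stand-in for the upstream draft model runner
    def supports_gpu_multi_step(self, execute_model_req):
        # Upstream behaviour: only CUDA-like devices take the device-side
        # multi-step prepare branch.
        return False


def npu_supports_gpu_multi_step(self, execute_model_req):
    # Patched behaviour: let NPU devices enter the multi-step branch too.
    # (The real patch also checks the conditions the outer MultiStepWorker
    # relies on, e.g. a supported attention backend.)
    return True


# Apply the patch at import time, as the vllm_ascend/patch modules do.
TP1DraftModelRunner.supports_gpu_multi_step = npu_supports_gpu_multi_step

runner = TP1DraftModelRunner()
print(runner.supports_gpu_multi_step(None))  # True after patching
```

The same pattern applies to the other patched entry points: import the upstream class, assign a replacement callable to its attribute, and keep the replacement's signature identical so call sites are unaffected.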

TODO:

  • revert the patch when the related issues are addressed in vllm

How was this patch tested?

CI passed with the newly added tests.

  • e2e test for medusa proposer: tests/singlecard/spec_decode/e2e/test_medusa_correctness.py
  • e2e test for mlp proposer: tests/singlecard/spec_decode/e2e/test_mlp_correctness.py
  • e2e test for n-gram proposer: tests/singlecard/spec_decode/e2e/test_ngram_correctness.py

Tests for patched files:

  • tests/singlecard/spec_decode/test_dynamic_spec_decode.py
  • tests/singlecard/spec_decode/test_multi_step_worker.py
  • tests/singlecard/spec_decode/test_ngram_worker.py
  • tests/singlecard/spec_decode/test_spec_decode_worker.py

@Yikun Yikun changed the title [SpecDecode] support spec decode in main [SpecDecode] Add spec decode support Apr 12, 2025
@MengqingCao MengqingCao force-pushed the patch_spec branch 4 times, most recently from 911f9ce to d9676da Compare April 15, 2025 02:13
Co-authored-by: mengwei805 <mengwei25@huawei.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
@MengqingCao
Collaborator Author

There are 3 workflows cancelled since a higher-priority waiting request for limit-npu-4 exists. I think CI passing on (linux-arm64-npu-1, v0.8.4) is enough. WDYT? @wangxiyuan

# See the License for the specific language governing permissions and
# limitations under the License.
#

Collaborator

# What's Patched and how it works:
# File: patch_xxx.py
#   1. vllm.spec_decode.multi_step_worker.MultiStepWorker.sampler_output
#    why:
#    how:
#    related pr (if none, explain why: e.g. 1. refused by vllm; 2. vllm doesn't support it yet; 3. preparing to submit ...)
#    future plan:
#   2. xxxx
# File: patch_yyy.py


# 0.8.4 patch doc:
# platform-0.8.4 + platform-common + worker-0.8.4 + worker-common
# ...

Collaborator Author

Done.

CI in linux-arm64-npu-4, v0.8.4 failed due to the following reason:

Error: Error: failed to run script step: command terminated with non-zero exit code: Error executing in Docker Container: 137
Error: Process completed with exit code 1.
Error: Executing the custom container implementation failed. Please contact your self hosted runner administrator.

It seems unrelated to this PR, and the other 3 workflows are green.

Signed-off-by: MengqingCao <cmq0113@163.com>
@wangxiyuan
Copy link
Collaborator

Merged. I only fixed the merge conflict for the newest commit; no need to wait for CI again.

@wangxiyuan wangxiyuan merged commit 6ee7f5c into vllm-project:main Apr 17, 2025
13 of 14 checks passed
@MengqingCao MengqingCao deleted the patch_spec branch April 18, 2025 00:59
ttanzhiqiang pushed a commit to ttanzhiqiang/vllm-ascend that referenced this pull request Apr 27, 2025
Angazenn pushed a commit to Angazenn/vllm-ascend that referenced this pull request Oct 21, 2025