
[Spec Decode] feat: support LoRA with speculative decoding #11966

Open · wants to merge 12 commits into main from feat/spec-decode-lora

Conversation

@llsj14 (Contributor) commented Jan 12, 2025

Summary

  • This PR adds support for using LoRA adapters together with speculative decoding.

Implementation

  • Two problems had to be solved to apply LoRA in spec decode; a sketch of both adjustments follows the error logs below.
  1. The LoRA adapter is designed for the target model, so applying the same adapter to the draft model can fail with shape mismatches like the one below, because the adapter's weight shapes match the target model's dimensions rather than the draft model's. Until the API is extended to inject a separate adapter for the draft model, the LoRA adapter is applied only to the target model and temporarily disabled for the draft model.
ERROR 01-12 03:08:10 engine.py:135]     result = super().activate_adapter(lora_id)
ERROR 01-12 03:08:10 engine.py:135]   File "/vllm/vllm/lora/models.py", line 405, in activate_adapter
ERROR 01-12 03:08:10 engine.py:135]     module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
ERROR 01-12 03:08:10 engine.py:135]   File "/vllm/vllm/lora/layers.py", line 215, in set_lora
ERROR 01-12 03:08:10 engine.py:135]     self.lora_b_stacked[index,
ERROR 01-12 03:08:10 engine.py:135] RuntimeError: The size of tensor a (768) must match the size of tensor b (4096) at non-singleton dimension 0
  2. As shown in the Llama model implementation, when LoRA is enabled the vocabulary size of both the target and draft models is padded further for kernel compatibility. The vocabulary size used by the SpecDecodeWorker therefore has to be adjusted to match; without this adjustment, the following error occurs:
ERROR 01-16 13:26:52 engine.py:137]   File "/vllm/vllm/spec_decode/spec_decode_worker.py", line 779, in _run_speculative_decoding_step
ERROR 01-16 13:26:52 engine.py:137]     proposal_scores = self.scorer.score_proposals(
ERROR 01-16 13:26:52 engine.py:137]   File "/usr/lib/python3.10/contextlib.py", line 79, in inner
ERROR 01-16 13:26:52 engine.py:137]     return func(*args, **kwds)
ERROR 01-16 13:26:52 engine.py:137]   File "/vllm/vllm/spec_decode/batch_expansion.py", line 86, in score_proposals
ERROR 01-16 13:26:52 engine.py:137]     contracted = self._contract_batch_all_spec(
ERROR 01-16 13:26:52 engine.py:137]   File "/vllm/vllm/spec_decode/batch_expansion.py", line 244, in _contract_batch_all_spec
ERROR 01-16 13:26:52 engine.py:137]     target_probs = target_sampler_output.sampled_token_probs.reshape(
ERROR 01-16 13:26:52 engine.py:137] RuntimeError: shape '[1, 4, 32000]' is invalid for input of size 129024
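The numbers in the second error line up with the LoRA vocabulary padding: the scorer tries to reshape to [1, 4, 32000] (128,000 elements), but the sampler output holds 129,024 = 4 × 32,256 values, i.e. the base Llama-2 vocabulary plus 256 extra LoRA vocabulary slots. Below is a minimal sketch of the two adjustments described above (helper names are hypothetical, not the exact code in this PR):

import copy

# (1) Until a per-model adapter can be injected, drop the LoRA request when
#     building the draft model's inputs so only the target model applies LoRA.
def make_draft_seq_group_metadata(seq_group_metadata):
    draft_metadata = copy.copy(seq_group_metadata)
    draft_metadata.lora_request = None  # LoRA temporarily disabled for the draft model
    return draft_metadata

# (2) When LoRA is enabled, score and verify with the padded vocabulary size
#     (e.g. 32000 + 256 = 32256 for Llama-2) rather than the raw config value.
def spec_decode_vocab_size(base_vocab_size: int, lora_enabled: bool,
                           lora_extra_vocab_size: int = 256) -> int:
    return base_vocab_size + (lora_extra_vocab_size if lora_enabled else 0)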

Test

  • Running Server
python -m vllm.entrypoints.openai.api_server \
     --model meta-llama/Llama-2-7b-hf \
     --port 8080 \
     --disable-custom-all-reduce \
     --swap-space 0 \
     --gpu-memory-utilization 0.9 \
     --enable-lora \
     --lora-modules sql-lora=$HOME/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/ \
     --speculative_model JackFram/llama-68m \
     --num_speculative_tokens 3
  • Request w/ LoRA Adapter
curl http://localhost:8080/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "sql-lora",
        "prompt": "San Francisco is a",
        "max_tokens": 32,
                "top_k": 1,
                "top_p": 1.0,
        "temperature": 1.0
    }' | jq
  • Response
"text": " city in California, United States, on the tip of a peninsula between the Pacific Ocean and San Francisco Bay. San Francisco is a leading financial center and"
  • Request w/o LoRA adapter
curl http://localhost:8080/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "/mnt/lvm/checkpoints/hugginface/Llama-2-7b-hf",
        "prompt": "San Francisco is a",
        "max_tokens": 32,
                "top_k": 1,
                "top_p": 1.0,
        "temperature": 1.0
    }' | jq
  • Response
"text": " city of many neighborhoods, each with its own distinct personality. San Francisco is a city of many neighborhoods, each with its own distinct personality."


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@llsj14 llsj14 changed the title feat: support LoRA with speculative decoding [Spec Decode] feat: support LoRA with speculative decoding Jan 12, 2025

mergify bot commented Jan 16, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @llsj14.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 16, 2025
@llsj14 llsj14 requested a review from njhill as a code owner January 16, 2025 13:34
@llsj14 llsj14 force-pushed the feat/spec-decode-lora branch from 0d00454 to 4224f0f Compare January 16, 2025 13:38
@mergify mergify bot removed the needs-rebase label Jan 16, 2025
@LiuXiaoxuanPKU LiuXiaoxuanPKU self-assigned this Jan 17, 2025
@LiuXiaoxuanPKU (Collaborator) left a comment

Thanks for the contribution! Overall LGTM.

For the test, can you change the prompt so that it can pass and does not have any numerical issue?

@llsj14 (Contributor Author) commented Jan 20, 2025

For the test, can you change the prompt so that it can pass and does not have any numerical issue?

Yes, I tried different prompts, but they produced slightly different results even with greedy decoding. I wanted to compute the LoRA contribution in FP32 to rule out numerical issues, but it seems LoRA weights only support FP16 or BF16. I'm not sure whether the differences are caused by my code, another problem, or numerical precision. I will investigate further.
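To illustrate the kind of check involved (my own sketch, not code from this PR; LoRA arguments omitted for brevity), the comparison is a greedy run of the same prompt with and without the speculative config:

from vllm import LLM, SamplingParams

def greedy_text(llm: LLM, prompt: str) -> str:
    # Greedy sampling, so any remaining difference is numerical rather than random.
    out = llm.generate([prompt], SamplingParams(temperature=0.0, max_tokens=32))
    return out[0].outputs[0].text

baseline = LLM(model="meta-llama/Llama-2-7b-hf")
spec = LLM(model="meta-llama/Llama-2-7b-hf",
           speculative_model="JackFram/llama-68m",
           num_speculative_tokens=3)

prompt = "San Francisco is a"
print(greedy_text(baseline, prompt))
print(greedy_text(spec, prompt))  # occasionally differs slightly, as noted above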

@@ -367,7 +367,7 @@ def _create_single_target_seq_group_metadata(
             block_tables={
                 target_seq_id: seq_group_metadata.block_tables[seq_id],
             },
-            lora_request=None,
+            lora_request=seq_group_metadata.lora_request,
llsj14 (Contributor Author) commented:
@LiuXiaoxuanPKU
I found the reason why the test failed: the scorers for both batch expansion and MQA scoring set lora_request to None. I fixed this and the test now passes.

@llsj14 (Contributor Author) commented Jan 26, 2025
In the previous test results, only the prefill stage utilized the LoRA request, while the decode stage did not apply the LoRA operation.

@@ -57,7 +57,7 @@ def score_proposals(
             block_tables={
                 target_seq_id: seq_group_metadata.block_tables[seq_id],
             },
-            lora_request=None,
+            lora_request=seq_group_metadata.lora_request,
llsj14 (Contributor Author) commented:

Same fix as for batch expansion.

del engine

# run speculative decoding with mqa scorer.
engine_args = EngineArgs(model=MODEL_PATH,
llsj14 (Contributor Author) commented:
I added a test specifically for MQA scoring.
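Roughly, the shape of such a test (a hedged sketch using the public LLM API; the actual test added in this PR constructs EngineArgs directly, and the adapter path here is a placeholder):

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

def test_spec_decode_with_lora_mqa_scorer():
    # Same spec-decode setup as the server command above; the engine picks the
    # MQA scorer automatically when it is supported for this configuration.
    llm = LLM(
        model="meta-llama/Llama-2-7b-hf",
        enable_lora=True,
        max_loras=1,
        speculative_model="JackFram/llama-68m",
        num_speculative_tokens=3,
    )
    sampling = SamplingParams(temperature=0.0, max_tokens=32)
    lora = LoRARequest("sql-lora", 1, "/path/to/sql-lora")  # placeholder adapter path
    outputs = llm.generate(["San Francisco is a"], sampling, lora_request=lora)
    assert outputs[0].outputs[0].text  # completes without LoRA/shape errors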

@llsj14 llsj14 force-pushed the feat/spec-decode-lora branch from 7cf636f to aaca3a5 Compare January 26, 2025 05:13
@mergify mergify bot added the documentation Improvements or additions to documentation label Jan 26, 2025

mergify bot commented Feb 5, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @llsj14.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 5, 2025
@llsj14 llsj14 force-pushed the feat/spec-decode-lora branch from e4c599a to 6a9c8a0 Compare February 6, 2025 15:30
@mergify mergify bot removed the needs-rebase label Feb 6, 2025
@llsj14 (Contributor Author) commented Feb 6, 2025

@LiuXiaoxuanPKU @sroy745
I rebased my code and resolved conflicts. Could you check this PR when you have time?

@llsj14 llsj14 force-pushed the feat/spec-decode-lora branch from 1b75337 to 60811d4 Compare February 22, 2025 08:53
@llsj14 (Contributor Author) commented Feb 22, 2025

Test error log (a 504 Gateway Time-out from the Hugging Face Hub while fetching distilbert/distilgpt2):

[2025-02-22T11:46:06Z] metrics/test_metrics.py::test_async_engine_log_metrics_regression[True-4-half-distilbert/distilgpt2] ERROR 02-22 03:46:06 config.py:102] Error retrieving file list: 504 Server Error: Gateway Time-out for url: https://huggingface.co/api/models/distilbert/distilgpt2/tree/main?recursive=True&expand=False, retrying 1 of 2
[2025-02-22T11:46:18Z] ERROR 02-22 03:46:18 config.py:100] Error retrieving file list: 504 Server Error: Gateway Time-out for url: https://huggingface.co/api/models/distilbert/distilgpt2/tree/main?recursive=True&expand=False
[2025-02-22T11:46:18Z] FAILED

...

[2025-02-22T11:50:47Z] metrics/test_metrics.py:181:
[2025-02-22T11:50:47Z] _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
[2025-02-22T11:50:47Z] /usr/local/lib/python3.12/dist-packages/vllm/engine/async_llm_engine.py:639: in from_engine_args
[2025-02-22T11:50:47Z]     engine_config = engine_args.create_engine_config(usage_context)
[2025-02-22T11:50:47Z] /usr/local/lib/python3.12/dist-packages/vllm/engine/arg_utils.py:1144: in create_engine_config
[2025-02-22T11:50:47Z]     model_config = self.create_model_config()
[2025-02-22T11:50:47Z] /usr/local/lib/python3.12/dist-packages/vllm/engine/arg_utils.py:1064: in create_model_config
[2025-02-22T11:50:47Z]     return ModelConfig(
[2025-02-22T11:50:47Z] /usr/local/lib/python3.12/dist-packages/vllm/config.py:314: in __init__
[2025-02-22T11:50:47Z]     hf_config = get_config(self.model, trust_remote_code, revision,
[2025-02-22T11:50:47Z] /usr/local/lib/python3.12/dist-packages/vllm/transformers_utils/config.py:256: in get_config
[2025-02-22T11:50:47Z]     if is_gguf or file_or_path_exists(
[2025-02-22T11:50:47Z] /usr/local/lib/python3.12/dist-packages/vllm/transformers_utils/config.py:179: in file_or_path_exists
[2025-02-22T11:50:47Z]     return file_exists(str(model),
[2025-02-22T11:50:47Z] /usr/local/lib/python3.12/dist-packages/vllm/transformers_utils/config.py:154: in file_exists
[2025-02-22T11:50:47Z]     file_list = list_repo_files(repo_id,
[2025-02-22T11:50:47Z] /usr/local/lib/python3.12/dist-packages/vllm/transformers_utils/config.py:143: in list_repo_files
[2025-02-22T11:50:47Z]     return with_retry(lookup_files, "Error retrieving file list")
[2025-02-22T11:50:47Z] /usr/local/lib/python3.12/dist-packages/vllm/transformers_utils/config.py:97: in with_retry
[2025-02-22T11:50:47Z]     return func()
[2025-02-22T11:50:47Z] /usr/local/lib/python3.12/dist-packages/vllm/transformers_utils/config.py:133: in lookup_files
[2025-02-22T11:50:47Z]     return hf_list_repo_files(repo_id,
[2025-02-22T11:50:47Z] /usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_validators.py:114: in _inner_fn
[2025-02-22T11:50:47Z]     return fn(*args, **kwargs)
[2025-02-22T11:50:47Z] /usr/local/lib/python3.12/dist-packages/huggingface_hub/hf_api.py:2945: in list_repo_files
[2025-02-22T11:50:47Z]     for f in self.list_repo_tree(
[2025-02-22T11:50:47Z] /usr/local/lib/python3.12/dist-packages/huggingface_hub/hf_api.py:3080: in list_repo_tree
[2025-02-22T11:50:47Z]     for path_info in paginate(path=tree_url, headers=headers, params={"recursive": recursive, "expand": expand}):
[2025-02-22T11:50:47Z] /usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_pagination.py:37: in paginate
[2025-02-22T11:50:47Z]     hf_raise_for_status(r)
[2025-02-22T11:50:47Z] _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

Labels: documentation, speculative-decoding