
Conversation

Contributor
@varun-sundar-rabindranath commented Mar 19, 2025

In V0 LoRA, we skipped the LoRA kernel launch when all the scheduled requests target the base model.
PR #14685 removed this optimization.
This PR re-introduces the optimization in a way that works for both V0 and V1.

Previously, i.e. before #14685, on V0 we had a boolean flag, no_lora, in punica_gpu.py that tracked whether any of the input requests needed LoRA. This works well for V0, but the V1 architecture uses traced torch.compile graphs to execute the forward pass, and tracing doesn't play well with dynamic control flow. However, tracing treats torch.ops functions as black boxes.

This PR moves the no_lora flag inside the lora_expand and lora_shrink torch operations and triggers an early exit from the operation when the flag is set. A minimal sketch of the pattern is shown below.
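To make the idea concrete, here is a hedged sketch of the pattern (the op name, signature, and the matmul standing in for the shrink kernel are placeholders, not the actual vLLM operators): the data-dependent branch lives inside a custom torch op, which torch.compile traces as an opaque node, so the dynamic control flow never reaches the compiler.

```python
import torch

# Hypothetical op for illustration; the real lora_shrink takes more arguments
# and launches the actual shrink kernel instead of the matmul used here.
@torch.library.custom_op("sketch::lora_shrink", mutates_args=("out",))
def lora_shrink(x: torch.Tensor, lora_a: torch.Tensor, out: torch.Tensor,
                no_lora_flag_cpu: torch.Tensor) -> None:
    # The flag is a CPU tensor, so reading it does not force a GPU sync.
    if bool(no_lora_flag_cpu.item()):
        # All scheduled requests target the base model: skip the kernel launch.
        return
    out.copy_(x @ lora_a)
```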

Benchmarks:
server command :

 VLLM_USE_V1=0  vllm serve meta-llama/Llama-2-7b-chat-hf  --gpu-memory-utilization 0.95  --enable-lora --max-loras 3 --max-cpu-loras 15 --max-lora-rank 64 --lora-modules lora=xtuner/Llama-2-7b-qlora-moss-003-sft --port 9001  --no-enable-prefix-caching

client command :

python3 benchmarks/benchmark_serving.py --model meta-llama/Llama-2-7b-chat-hf --dataset-name random --random-input-len 2048 --random-output-len 512 --request-rate inf --seed ${i} --port 9001 --num-prompts ${num_prompts}

Numbers:
The benchmark_serving.py command was run 4 times for each num_prompts value. All mean_ttft_ms measurements are reported below.

Machine: 1xA100, mean_ttft_ms values

| branch \ num_prompts | 1 | 4 | 8 |
| --- | --- | --- | --- |
| main (400d483) - has #14685 | 148, 170, 170, 176 | 430, 658, 434, 437 | 750, 747, 762, 737 |
| This PR | 155, 165, 162, 170 | 410, 922, 400, 407 | 713, 687, 799, 724 |

Thanks @jeejeelee for flagging this 🙌

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Contributor Author
@varun-sundar-rabindranath Mar 19, 2025

Changes to this file:
Previously,

```python
lora_mapping = LoRAMapping(**dict(index_mapping=[0] * batch_size,
                                  prompt_mapping=[0] * batch_size,
                                  is_prefill=False))
self.set_active_loras(set(), lora_mapping)
```

This was sufficient for V0 to capture cudagraphs with the LoRA kernels because of the is_prefill flag: when it is False, i.e. the decode case, V0 always chooses to run the LoRA kernels, so all captured CUDA graphs record the LoRA kernels.

What changed?
punica_gpu.py now handles the "no LoRA" case based on LoRAMapping::index_mapping and ignores the is_prefill flag. Due to this change, an index_mapping of [0] * batch_size simply translates to no_lora_flag_cpu being set to True, and the captured graphs don't include the LoRA kernels.

Fix: we explicitly add and remove the LoRA adapters (similar to what we do during profile runs). A sketch of this fix is shown after the alternative below.

Alternative solution/hack: in punica_gpu.py, when setting no_lora_flag_cpu, we could do

```python
if envs.VLLM_USE_V0:
    use_cuda_graphs = not is_prefill
    no_lora_flag_cpu = torch.all(token_lora_mapping == -1) and not use_cuda_graphs
```

but that seems hacky and I'd like to avoid V0/V1 checks.

I prefer the implemented fix where we add and remove the LoRA adapters explicitly.
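As a rough illustration of the implemented fix, here is a minimal sketch under assumed names (add_dummy_lora, remove_adapter, capture_model, and the request objects are placeholders, not necessarily vLLM's exact APIs): register dummy LoRA adapters before CUDA graph capture so the captured graphs record the LoRA kernels, then remove them once capture completes.

```python
# Hypothetical sketch; the names below are assumptions, not vLLM's exact API.
def capture_with_lora_kernels(model_runner, dummy_lora_requests, warmup_rank):
    # Register dummy adapters so the captured CUDA graphs include the LoRA kernels.
    for req in dummy_lora_requests:
        model_runner.lora_manager.add_dummy_lora(req, warmup_rank)
    try:
        model_runner.capture_model()  # stand-in for the actual graph-capture call
    finally:
        # Remove the dummy adapters once capture is complete.
        for req in dummy_lora_requests:
            model_runner.lora_manager.remove_adapter(req.lora_int_id)
```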

Collaborator

Thank you for the detailed explanation, I like this.

@varun-sundar-rabindranath marked this pull request as ready for review March 19, 2025 22:46
Contributor Author

@bnellnm @ProExpertProg @youkaichao can you take a look when you find some time please. Thanks 🙌

Member
@youkaichao Mar 22, 2025

It is a good summarization. We leverage point number 3 to deal with complicated attention operations, and it can be used for LoRA, too.

But if we have 2 code paths for LoRA and no-LoRA, would that break cudagraph?

Contributor Author

For cudagraphs we always capture with LoRA, so there is just 1 path in that case.

Collaborator

lgtm

Collaborator

ditto about item() alternatives.

Varun Sundar Rabindranath added 7 commits March 24, 2025 16:53
                                 lora_path='/not/a/real/path')
self.lora_manager.add_dummy_lora(dummy_lora_request,
                                 LORA_WARMUP_RANK)

Collaborator

During the warmup, we added some dummy LoRAs, then removed them. Perhaps we could continue using those dummy LoRAs and remove them after the capture is complete. I think this would reduce redundant code. See: https://github.com/vllm-project/vllm/blob/main/vllm/worker/worker.py#L324

Contributor Author
@varun-sundar-rabindranath Mar 25, 2025

GPUModelRunnerBase::_dummy_run is the one that adds and removes the dummy LoRAs internally. This _dummy_run function is used in 2 places:

  1. GPUModelRunnerBase::profile_run
  2. Worker::_warm_up_model

Updating _dummy_run to act differently based on the caller seems cumbersome.
But I agree with your point on redundant code. I have refactored the adding and removing of the dummy LoRAs in commit c05763e. Please take a look.
Note: I considered using a context manager, but it looks like we already have a lot of indentation in that code and I didn't want to add another level (a sketch of what that alternative could look like is below).
What do you think?
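For reference, a minimal sketch of the context-manager alternative mentioned above. This is not the code in the PR; the manager method names and signatures are assumptions used only for illustration.

```python
from contextlib import contextmanager

@contextmanager
def dummy_loras(lora_manager, dummy_requests, rank):
    # Temporarily register dummy LoRA adapters, e.g. around a dummy run.
    for req in dummy_requests:
        lora_manager.add_dummy_lora(req, rank)  # assumed method name
    try:
        yield
    finally:
        for req in dummy_requests:
            lora_manager.remove_adapter(req.lora_int_id)  # assumed method name

# At a call site this adds exactly the extra indentation level mentioned above:
#     with dummy_loras(self.lora_manager, dummy_requests, LORA_WARMUP_RANK):
#         ...  # run the dummy forward pass
```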

Varun Sundar Rabindranath added 2 commits March 25, 2025 16:21
@jeejeelee jeejeelee added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 26, 2025
@DarkLight1337 DarkLight1337 merged commit 6c663df into vllm-project:main Mar 26, 2025
51 checks passed
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025