[BugFix] fix wrong output when using lora and num_scheduler_steps=8 (vllm-project#11161)

FIX issues vllm-project#9688, vllm-project#11086, vllm-project#12487

---------

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: weilong.yu <weilong.yu@shopee.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
3 people authored and Isotr0py committed Feb 2, 2025
1 parent 66453e2 commit f91599c
Showing 2 changed files with 4 additions and 3 deletions.
4 changes: 4 additions & 0 deletions vllm/worker/model_runner.py
@@ -1346,6 +1346,10 @@ def _dummy_run(self,
 
         self.execute_model(model_input, kv_caches, intermediate_tensors)
         torch.cuda.synchronize()
+        if self.lora_config:
+            # Remove dummy loras.
+            assert self.lora_manager is not None
+            self.remove_all_loras()
         return
 
     def remove_all_loras(self):
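For context, the wrong-output symptom reported in vllm-project#9688 involves combining a LoRA adapter with multi-step scheduling (num_scheduler_steps=8); the hunk above removes the dummy LoRAs loaded for memory profiling once the dummy run finishes. A minimal sketch of the affected configuration, assuming a LoRA-compatible base model and a local adapter path (both placeholders, not part of this commit):

# Sketch of the LoRA + multi-step scheduling setup the fix targets.
# Model name and adapter path are placeholders.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    model="meta-llama/Llama-2-7b-hf",  # placeholder base model
    enable_lora=True,
    num_scheduler_steps=8,             # multi-step scheduling from the bug report
)

outputs = llm.generate(
    ["What is the capital of France?"],
    SamplingParams(temperature=0.0, max_tokens=32),
    lora_request=LoRARequest("demo_adapter", 1, "/path/to/lora_adapter"),  # placeholder path
)
print(outputs[0].outputs[0].text)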
3 changes: 0 additions & 3 deletions vllm/worker/worker.py
@@ -264,10 +264,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
             f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.")
 
         logger.info(msg)
 
-        # Final cleanup
-        if self.model_runner.lora_manager:
-            self.model_runner.remove_all_loras()
         gc.collect()
 
         return num_gpu_blocks, num_cpu_blocks
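With the dummy-LoRA cleanup now performed inside _dummy_run itself (see the model_runner.py hunk above), the separate removal in Worker.determine_num_available_blocks is redundant, which is why these three lines are deleted here.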
