Skip to content

Commit 82cabf5

Browse files
authored
[Misc] Delete unused LoRA modules (#13151)
1 parent 314cfad commit 82cabf5

File tree

3 files changed

+20
-8
lines changed

3 files changed

+20
-8
lines changed

tests/lora/test_lora_manager.py

Lines changed: 12 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -606,27 +606,33 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
606606

607607
assert isinstance(model.get_submodule("gate_up_proj"),
608608
MergedColumnParallelLinearWithLoRA)
609+
# Verify packed lora is correct
610+
model_lora_clone = model_lora.clone(1)
611+
model_lora_clone1 = model_lora1.clone(1)
609612
assert manager.add_adapter(model_lora)
610613
assert manager.add_adapter(model_lora1)
611614

615+
assert model_lora.get_lora("gate_proj") is None
616+
assert model_lora.get_lora("up_proj") is None
617+
assert model_lora1.get_lora("up_proj") is None
612618
packed_lora = model_lora.get_lora("gate_up_proj")
613619
assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)
614620

615621
torch.testing.assert_close(packed_lora.lora_a[0],
616-
model_lora.get_lora("gate_proj").lora_a)
622+
model_lora_clone.get_lora("gate_proj").lora_a)
617623
torch.testing.assert_close(packed_lora.lora_b[0],
618-
model_lora.get_lora("gate_proj").lora_b)
624+
model_lora_clone.get_lora("gate_proj").lora_b)
619625
torch.testing.assert_close(packed_lora.lora_a[1],
620-
model_lora.get_lora("up_proj").lora_a)
626+
model_lora_clone.get_lora("up_proj").lora_a)
621627
torch.testing.assert_close(packed_lora.lora_b[1],
622-
model_lora.get_lora("up_proj").lora_b)
628+
model_lora_clone.get_lora("up_proj").lora_b)
623629

624630
packed_lora1 = model_lora1.get_lora("gate_up_proj")
625631
assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights)
626632

627633
assert packed_lora1.lora_a[0] is None
628634
assert packed_lora1.lora_b[0] is None
629635
torch.testing.assert_close(packed_lora1.lora_a[1],
630-
model_lora1.get_lora("up_proj").lora_a)
636+
model_lora_clone1.get_lora("up_proj").lora_a)
631637
torch.testing.assert_close(packed_lora1.lora_b[1],
632-
model_lora1.get_lora("up_proj").lora_b)
638+
model_lora_clone1.get_lora("up_proj").lora_b)

vllm/lora/models.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,8 @@
55
import os
66
import re
77
from dataclasses import dataclass, field
8-
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
8+
from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Type,
9+
Union)
910

1011
import safetensors.torch
1112
import torch
@@ -619,12 +620,14 @@ def _register_packed_modules(self, module_full_name: str) -> None:
619620
def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
620621
for module_name, new_module_names in self.packed_modules.items():
621622
replacement_loras: List[Optional[LoRALayerWeights]] = []
623+
replaced_module: Set[str] = set()
622624
has_replacement = False
623625
for r in new_module_names:
624626
lora = lora_model.get_lora(r)
625627
replacement_loras.append(lora)
626628
if lora:
627629
has_replacement = True
630+
replaced_module.add(r)
628631
if not has_replacement:
629632
continue
630633
for i in range(len(replacement_loras)):
@@ -633,6 +636,9 @@ def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
633636
replacement_loras[i] = None
634637
lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
635638
replacement_loras)
639+
# Remove the modules that have been replaced.
640+
for module in replaced_module:
641+
lora_model.loras.pop(module, None)
636642

637643
def deactivate_adapter(self, adapter_id: int) -> bool:
638644
return deactivate_adapter(adapter_id, self._active_adapters,

vllm/lora/punica_wrapper/punica_base.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -147,7 +147,7 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int,
147147
dtype=torch.long,
148148
device=device)
149149

150-
# 5 is the number of indicies tensors.
150+
# 5 is the number of indices tensors.
151151
# base_indices, sampler_indices, sampler_indices_padded,
152152
# embeddings_indices,long_lora_indices
153153
self.indices_len: List[Optional[int]] = [None] * 5

0 commit comments

Comments (0)