Commit db7c8ca

[Misc] Embedding model support LoRA (#14935)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
1 parent f863ffc commit db7c8ca

vllm/lora/models.py

Lines changed: 30 additions & 2 deletions
@@ -30,6 +30,7 @@
                              is_regex_target_modules,
                              parse_fine_tuned_lora_name, replace_submodule)
 from vllm.model_executor.models import SupportsLoRA, supports_multimodal
+from vllm.model_executor.models.interfaces import is_pooling_model
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
 from vllm.utils import is_pin_memory_available
@@ -104,6 +105,9 @@ def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
         """Get LoRA for a given module by name"""
         return self.loras.get(module_name, None)
 
+    def check_lora_name(self, lora_name: str) -> bool:
+        return lora_name in self.loras
+
     # (yard1): TODO see if we can derive target_embedding_padding automatically
     @classmethod
     def from_lora_tensors(
@@ -335,6 +339,7 @@ def __init__(
         # Used for long context lora.
         self.scaling_factor_to_offset: Dict[float, int] = {}
         super().__init__(model)
+
         self.supported_lora_modules = get_supported_lora_modules(self.model)
         assert self.supported_lora_modules, "No supported LoRA modules found in"
         f"{self.model.__class__.__name__}."
@@ -350,6 +355,7 @@ def __init__(
             # In case the model only supports LoRA for
             # text modules (e.g. ChatGLM)
             and hasattr(self.model, "get_mm_mapping"))
+        self.is_pooling_model = is_pooling_model(self.model)
         self.packed_modules: Dict[str, List[str]] = {}
         self.modules: Dict[str, BaseLayerWithLoRA] = {}
         # Dict instead of a Set for compatibility with LRUCache.
@@ -389,7 +395,7 @@ def activate_adapter(
                      lora_model.id, index)
         self.lora_index_to_id[index] = lora_model.id
         for module_name, module in self.modules.items():
-            module_lora = lora_model.get_lora(module_name)
+            module_lora = self._get_lora_layer_weights(lora_model, module_name)
             if module_lora:
                 module_lora.optimize()
                 # Bias is not explicitly enabled with the flag enable_lora_bias.
@@ -626,7 +632,7 @@ def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
             replaced_module: Set[str] = set()
             has_replacement = False
             for r in new_module_names:
-                lora = lora_model.get_lora(r)
+                lora = self._get_lora_layer_weights(lora_model, r)
                 replacement_loras.append(lora)
                 if lora:
                     has_replacement = True
@@ -637,12 +643,34 @@ def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
                 if replacement_loras[i]:
                     continue
                 replacement_loras[i] = None
+            # HACK: temporary workaround for pooling models.
+            if self.is_pooling_model and not lora_model.check_lora_name(
+                    module_name):
+                replaced_module_name = module_name.replace("model.", "")
+                if lora_model.check_lora_name(replaced_module_name):
+                    module_name = replaced_module_name
             lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                 replacement_loras)
         # Remove the modules that have been replaced.
         for module in replaced_module:
             lora_model.loras.pop(module, None)
 
+    def _get_lora_layer_weights(
+            self, lora_model: LoRAModel,
+            module_name: str) -> Optional[LoRALayerWeights]:
+        org_module_name = module_name
+        if self.is_pooling_model and not lora_model.check_lora_name(
+                module_name):
+            # If it's a pooling model and the layer name is not found,
+            # remove the prefix 'model.' and search again.
+            module_name = module_name.replace("model.", "")
+            if lora_model.check_lora_name(module_name):
+                org_module_name = module_name
+                logger.info_once(
+                    "For the pooling model, successfully loaded the LoRA "
+                    "weights after removing the prefix 'model.'.")
+        return lora_model.get_lora(org_module_name)
+
     def deactivate_adapter(self, adapter_id: int) -> bool:
         return deactivate_adapter(adapter_id, self._active_adapters,
                                   self._deactivate_adapter)
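
For context, the fallback that the new _get_lora_layer_weights helper implements can be illustrated with a small standalone sketch. This is plain Python, not vLLM's API; the module names and LoRA keys below are made up for illustration. The idea, per the diff's own comment, is that a pooling model may register its submodules under a "model." prefix while the adapter's LoRA weights are keyed without it, so the lookup retries after stripping the prefix.

from typing import Dict, Optional

# Illustrative LoRA weights keyed the way an embedding-model adapter
# might store them, i.e. without the leading "model." prefix.
loras: Dict[str, str] = {
    "layers.0.self_attn.qkv_proj": "<LoRALayerWeights A>",
    "layers.0.mlp.gate_up_proj": "<LoRALayerWeights B>",
}


def get_lora_layer_weights(module_name: str,
                           is_pooling_model: bool) -> Optional[str]:
    """Look up LoRA weights, retrying without the 'model.' prefix."""
    org_module_name = module_name
    if is_pooling_model and module_name not in loras:
        # Same fallback as the helper in the diff: strip the wrapper
        # prefix and search again.
        stripped = module_name.replace("model.", "")
        if stripped in loras:
            org_module_name = stripped
    return loras.get(org_module_name)


# The registered module name carries the "model." prefix, but the
# adapter keys do not; the fallback still resolves the weights.
print(get_lora_layer_weights("model.layers.0.self_attn.qkv_proj", True))
# -> <LoRALayerWeights A>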

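A hedged sketch of what this change enables: running an embedding (pooling) model with a LoRA adapter through vLLM's offline LLM API. The model name and adapter path are placeholders, and the exact arguments (task="embed", enable_lora=True, passing lora_request to LLM.embed) are assumptions based on vLLM's offline pooling API around this release and may differ across versions.

from vllm import LLM
from vllm.lora.request import LoRARequest

# Placeholder model; substitute an embedding model with a LoRA adapter
# trained for it.
llm = LLM(
    model="intfloat/e5-mistral-7b-instruct",
    task="embed",
    enable_lora=True,
)

# LoRARequest(adapter name, integer id, local path to the adapter).
lora = LoRARequest("my_embedding_adapter", 1, "/path/to/lora_adapter")

outputs = llm.embed(
    ["What is the capital of France?"],
    lora_request=lora,
)

# Each output carries the pooled embedding vector.
print(len(outputs[0].outputs.embedding))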