                              is_regex_target_modules,
                              parse_fine_tuned_lora_name, replace_submodule)
 from vllm.model_executor.models import SupportsLoRA, supports_multimodal
+from vllm.model_executor.models.interfaces import is_pooling_model
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
 from vllm.utils import is_pin_memory_available
@@ -104,6 +105,9 @@ def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
104105 """Get LoRA for a given module by name"""
105106 return self .loras .get (module_name , None )
106107
108+ def check_lora_name (self , lora_name : str ) -> bool :
109+ return lora_name in self .loras
110+
107111 # (yard1): TODO see if we can derive target_embedding_padding automatically
108112 @classmethod
109113 def from_lora_tensors (
@@ -335,6 +339,7 @@ def __init__(
         # Used for long context lora.
         self.scaling_factor_to_offset: Dict[float, int] = {}
         super().__init__(model)
+
         self.supported_lora_modules = get_supported_lora_modules(self.model)
         assert self.supported_lora_modules, "No supported LoRA modules found in"
         f"{self.model.__class__.__name__}."
@@ -350,6 +355,7 @@ def __init__(
             # In case the model only supports LoRA for
             # text modules (e.g. ChatGLM)
             and hasattr(self.model, "get_mm_mapping"))
+        self.is_pooling_model = is_pooling_model(self.model)
         self.packed_modules: Dict[str, List[str]] = {}
         self.modules: Dict[str, BaseLayerWithLoRA] = {}
         # Dict instead of a Set for compatibility with LRUCache.
@@ -389,7 +395,7 @@ def activate_adapter(
                      lora_model.id, index)
         self.lora_index_to_id[index] = lora_model.id
         for module_name, module in self.modules.items():
-            module_lora = lora_model.get_lora(module_name)
+            module_lora = self._get_lora_layer_weights(lora_model, module_name)
             if module_lora:
                 module_lora.optimize()
                 # Bias is not explicitly enabled with the flag enable_lora_bias.
@@ -626,7 +632,7 @@ def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
             replaced_module: Set[str] = set()
             has_replacement = False
             for r in new_module_names:
-                lora = lora_model.get_lora(r)
+                lora = self._get_lora_layer_weights(lora_model, r)
                 replacement_loras.append(lora)
                 if lora:
                     has_replacement = True
@@ -637,12 +643,34 @@ def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
                 if replacement_loras[i]:
                     continue
                 replacement_loras[i] = None
+            # HACK: Temporary workaround for pooling models.
+            if self.is_pooling_model and not lora_model.check_lora_name(
+                    module_name):
+                replaced_module_name = module_name.replace("model.", "")
+                if lora_model.check_lora_name(replaced_module_name):
+                    module_name = replaced_module_name
             lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                 replacement_loras)
             # Remove the modules that have been replaced.
             for module in replaced_module:
                 lora_model.loras.pop(module, None)
 
+    def _get_lora_layer_weights(
+            self, lora_model: LoRAModel,
+            module_name: str) -> Optional[LoRALayerWeights]:
+        org_module_name = module_name
+        if self.is_pooling_model and not lora_model.check_lora_name(
+                module_name):
+            # If it's a pooling model and the layer name is not found,
+            # remove the prefix 'model.' and search again.
+            module_name = module_name.replace("model.", "")
+            if lora_model.check_lora_name(module_name):
+                org_module_name = module_name
+                logger.info_once(
+                    "For the pooling model, successfully loaded the LoRA "
+                    "weights after removing the prefix 'model.'.")
+        return lora_model.get_lora(org_module_name)
+
     def deactivate_adapter(self, adapter_id: int) -> bool:
         return deactivate_adapter(adapter_id, self._active_adapters,
                                   self._deactivate_adapter)
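
The diff above adds a lookup fallback: for pooling models whose LoRA checkpoints store layer names without the leading "model." prefix, the weights are looked up again after stripping that prefix. Below is a minimal, self-contained sketch of that fallback logic; ToyLoRAModel, get_lora_with_prefix_fallback, and the module names are hypothetical stand-ins for illustration, not vLLM APIs.

# Sketch of the "model." prefix fallback used for pooling models.
# ToyLoRAModel and the names below are hypothetical stand-ins.
from typing import Dict, Optional


class ToyLoRAModel:
    """Stand-in for LoRAModel: maps module names to LoRA weights."""

    def __init__(self, loras: Dict[str, str]):
        self.loras = loras

    def check_lora_name(self, lora_name: str) -> bool:
        return lora_name in self.loras

    def get_lora(self, module_name: str) -> Optional[str]:
        return self.loras.get(module_name, None)


def get_lora_with_prefix_fallback(lora_model: ToyLoRAModel,
                                  module_name: str,
                                  is_pooling_model: bool) -> Optional[str]:
    # Mirrors _get_lora_layer_weights: try the registered module name
    # first; for pooling models, retry without the leading "model.".
    if is_pooling_model and not lora_model.check_lora_name(module_name):
        stripped = module_name.replace("model.", "")
        if lora_model.check_lora_name(stripped):
            module_name = stripped
    return lora_model.get_lora(module_name)


# The adapter checkpoint stored its keys without the "model." prefix,
# while the runtime module is named "model.layers.0.self_attn.qkv_proj".
adapter = ToyLoRAModel({"layers.0.self_attn.qkv_proj": "qkv_lora"})
assert get_lora_with_prefix_fallback(
    adapter, "model.layers.0.self_attn.qkv_proj",
    is_pooling_model=True) == "qkv_lora"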