# Type variable ranging over nn.Module *classes* (note the bound is
# type[nn.Module], not nn.Module): it lets class-transforming helpers such
# as as_embedding_model() declare that they return the same class type they
# were given, preserving the caller's static type.
_T = TypeVar("_T", bound=type[nn.Module])
1010
1111
12- def _is_paramless (module : nn .Module ):
13- # NOTE: all([]) returns True
14- return all (False for _ in module .parameters ())
15-
16-
1712def as_embedding_model (cls : _T ) -> _T :
1813 """Subclass an existing vLLM model to support embeddings."""
1914 # Avoid modifying existing embedding models
@@ -40,24 +35,21 @@ def __init__(
4035 super ().__init__ (vllm_config = vllm_config , prefix = prefix , ** kwargs )
4136
4237 # These are not used in embedding models
43- if hasattr (self , "lm_head" ):
44- del self .lm_head
45- if hasattr (self , "logits_processor" ):
46- del self .logits_processor
38+ for attr in ("lm_head" , "logits_processor" ):
39+ if hasattr (self , attr ):
40+ delattr (self , attr )
4741
4842 pooler_config = vllm_config .model_config .pooler_config
4943 assert pooler_config is not None
5044
5145 # If the model already defines a pooler instance, don't overwrite it
5246 if not getattr (self , "_pooler" , None ):
53- pooler = Pooler .from_config_with_defaults (
47+ self . _pooler = Pooler .from_config_with_defaults (
5448 pooler_config ,
5549 pooling_type = PoolingType .LAST ,
5650 normalize = True ,
5751 softmax = False ,
5852 )
59- assert pooler is not None
60- self ._pooler = pooler
6153
6254 def pooler (
6355 self ,
@@ -77,7 +69,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
7769 if hasattr (self , "model" ) and hasattr (self .model , "load_weights" ):
7870 # Whether only `self.model` contains parameters
7971 model_is_only_param = all (
80- name == "model" or _is_paramless (child )
72+ name == "model" or next (child . parameters (), None ) is None
8173 for name , child in self .named_children ())
8274
8375 if model_is_only_param :
0 commit comments