Commit 666cc19

Split up the condition

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

1 parent: 1d1c7b3

File tree

1 file changed: +19 -9 lines changed

vllm/model_executor/models/adapters.py

Lines changed: 19 additions & 9 deletions
@@ -9,6 +9,11 @@
 _T = TypeVar("_T", bound=type[nn.Module])
 
 
+def _is_paramless(module: nn.Module):
+    # NOTE: all([]) returns True
+    return all(False for _ in module.parameters())
+
+
 def as_embedding_model(cls: _T) -> _T:
     """Subclass an existing vLLM model to support embeddings."""
     # Avoid modifying existing embedding models
@@ -69,16 +74,21 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         # If `*ForCausalLM` defines `load_weights` on the inner model
         # and there are no other inner modules with parameters,
         # we support loading from both `*Model` and `*ForCausalLM`
-        if (hasattr(self, "model") and hasattr(self.model, "load_weights")
-                and all(name == "model" or all(False
-                                               for _ in child.parameters())
-                        for name, child in self.named_children())):
-            mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
-            weights = mapper.apply(weights)
-
-            self.model.load_weights(weights)
+        if hasattr(self, "model") and hasattr(self.model, "load_weights"):
+            # Whether only `self.model` contains parameters
+            model_is_only_param = all(
+                name == "model" or _is_paramless(child)
+                for name, child in self.named_children())
+
+            if model_is_only_param:
+                mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
+                weights = mapper.apply(weights)
+
+                self.model.load_weights(weights)
+                return
+
         # For most other models
-        elif hasattr(cls, "load_weights"):
+        if hasattr(cls, "load_weights"):
             cls.load_weights(self, weights)  # type: ignore
         # Fallback
         else:
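
A quick sanity check (not part of the commit) on the `_is_paramless` helper introduced above: the generator in `all(False for _ in module.parameters())` is empty exactly when the module has no parameters, and `all([])` is True, so the helper returns True only for parameterless modules. The modules below are arbitrary examples chosen for illustration.

    import torch.nn as nn

    def _is_paramless(module: nn.Module):
        # NOTE: all([]) returns True
        return all(False for _ in module.parameters())

    print(_is_paramless(nn.ReLU()))        # True: ReLU has no parameters
    print(_is_paramless(nn.Linear(4, 4)))  # False: the first weight yields False

With the condition split this way, the inner-model path now ends in an explicit `return` after `self.model.load_weights(weights)`, which is why the old `elif hasattr(cls, "load_weights")` branch becomes a plain `if`.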
