Skip to content

Commit ec0fbf7

Browse files
committed
Simplify code
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent f73282e commit ec0fbf7

File tree

2 files changed

+6
-16
lines changed

2 files changed

+6
-16
lines changed

vllm/model_executor/layers/pooler.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,7 @@ def from_config_with_defaults(
6060
softmax: bool,
6161
step_tag_id: Optional[int] = None,
6262
returned_token_ids: Optional[List[int]] = None,
63-
) -> Optional["Pooler"]:
64-
if pooler_config is None:
65-
return None
63+
) -> "Pooler":
6664
return cls(
6765
pooling_type=PoolingType[pooler_config.pooling_type]
6866
if pooler_config.pooling_type is not None else pooling_type,

vllm/model_executor/models/adapters.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,6 @@
99
_T = TypeVar("_T", bound=type[nn.Module])
1010

1111

12-
def _is_paramless(module: nn.Module):
13-
# NOTE: all([]) returns True
14-
return all(False for _ in module.parameters())
15-
16-
1712
def as_embedding_model(cls: _T) -> _T:
1813
"""Subclass an existing vLLM model to support embeddings."""
1914
# Avoid modifying existing embedding models
@@ -40,24 +35,21 @@ def __init__(
4035
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
4136

4237
# These are not used in embedding models
43-
if hasattr(self, "lm_head"):
44-
del self.lm_head
45-
if hasattr(self, "logits_processor"):
46-
del self.logits_processor
38+
for attr in ("lm_head", "logits_processor"):
39+
if hasattr(self, attr):
40+
delattr(self, attr)
4741

4842
pooler_config = vllm_config.model_config.pooler_config
4943
assert pooler_config is not None
5044

5145
# If the model already defines a pooler instance, don't overwrite it
5246
if not getattr(self, "_pooler", None):
53-
pooler = Pooler.from_config_with_defaults(
47+
self._pooler = Pooler.from_config_with_defaults(
5448
pooler_config,
5549
pooling_type=PoolingType.LAST,
5650
normalize=True,
5751
softmax=False,
5852
)
59-
assert pooler is not None
60-
self._pooler = pooler
6153

6254
def pooler(
6355
self,
@@ -77,7 +69,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
7769
if hasattr(self, "model") and hasattr(self.model, "load_weights"):
7870
# Whether only `self.model` contains parameters
7971
model_is_only_param = all(
80-
name == "model" or _is_paramless(child)
72+
name == "model" or next(child.parameters(), None) is None
8173
for name, child in self.named_children())
8274

8375
if model_is_only_param:

0 commit comments

Comments
 (0)