@@ -9,6 +9,11 @@
 _T = TypeVar("_T", bound=type[nn.Module])
 
 
+def _is_paramless(module: nn.Module):
+    # NOTE: all([]) returns True
+    return all(False for _ in module.parameters())
+
+
 def as_embedding_model(cls: _T) -> _T:
     """Subclass an existing vLLM model to support embeddings."""
     # Avoid modifying existing embedding models
@@ -69,16 +74,21 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
             # If `*ForCausalLM` defines `load_weights` on the inner model
             # and there are no other inner modules with parameters,
             # we support loading from both `*Model` and `*ForCausalLM`
-            if (hasattr(self, "model") and hasattr(self.model, "load_weights")
-                    and all(name == "model" or all(False
-                                                   for _ in child.parameters())
-                            for name, child in self.named_children())):
-                mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
-                weights = mapper.apply(weights)
-
-                self.model.load_weights(weights)
+            if hasattr(self, "model") and hasattr(self.model, "load_weights"):
+                # Whether only `self.model` contains parameters
+                model_is_only_param = all(
+                    name == "model" or _is_paramless(child)
+                    for name, child in self.named_children())
+
+                if model_is_only_param:
+                    mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
+                    weights = mapper.apply(weights)
+
+                    self.model.load_weights(weights)
+                    return
+
             # For most other models
-            elif hasattr(cls, "load_weights"):
+            if hasattr(cls, "load_weights"):
                 cls.load_weights(self, weights)  # type: ignore
             # Fallback
             else:
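For context: the new `_is_paramless` helper relies on `all()` returning True for an empty iterable, so a module whose `parameters()` yields nothing counts as parameterless. A minimal standalone sketch of that behavior (the asserts are illustrative and not part of the PR):

    import torch.nn as nn

    def _is_paramless(module: nn.Module):
        # all() over an empty generator returns True, so a module
        # that yields no parameters is treated as parameterless
        return all(False for _ in module.parameters())

    assert _is_paramless(nn.ReLU())             # no parameters
    assert not _is_paramless(nn.Linear(4, 4))   # has weight and bias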
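The `WeightsMapper(orig_to_new_prefix={"model.": ""})` call strips the leading `model.` prefix from checkpoint names, so a `*ForCausalLM` checkpoint can be fed to the inner `*Model`'s `load_weights`. A rough standalone stand-in for that remapping (illustrative only; vLLM's WeightsMapper handles more cases):

    from collections.abc import Iterable

    import torch

    def strip_model_prefix(
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> Iterable[tuple[str, torch.Tensor]]:
        # Rename "model.foo.weight" -> "foo.weight"; other names pass through
        for name, tensor in weights:
            yield name.removeprefix("model."), tensor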