|
27 | 27 | from vllm.sequence import IntermediateTensors |
28 | 28 | from vllm.utils import LayerBlockType |
29 | 29 |
|
30 | | -from .utils import (is_pp_missing_parameter, |
| 30 | +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, |
31 | 31 | make_empty_intermediate_tensors_factory, make_layers, |
32 | 32 | maybe_prefix) |
33 | 33 |
|
@@ -154,6 +154,26 @@ def forward( |
154 | 154 |
|
155 | 155 | return hidden_states |
156 | 156 |
|
| 157 | + def load_weights(self, weights: Iterable[Tuple[str, |
| 158 | + torch.Tensor]]) -> Set[str]: |
| 159 | + params_dict = dict(self.named_parameters()) |
| 160 | + loaded_params: Set[str] = set() |
| 161 | + for name, loaded_weight in weights: |
| 162 | + if "A_log" in name: |
| 163 | + name = name.replace("A_log", "A") |
| 164 | + # Skip loading extra bias for GPTQ models. |
| 165 | + if name.endswith(".bias") and name not in params_dict: |
| 166 | + continue |
| 167 | + if is_pp_missing_parameter(name, self): |
| 168 | + continue |
| 169 | + |
| 170 | + param = params_dict[name] |
| 171 | + weight_loader = getattr(param, "weight_loader", |
| 172 | + default_weight_loader) |
| 173 | + weight_loader(param, loaded_weight) |
| 174 | + loaded_params.add(name) |
| 175 | + return loaded_params |
| 176 | + |
157 | 177 |
|
158 | 178 | class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP, |
159 | 179 | SupportsV0Only): |
@@ -257,20 +277,5 @@ def sample( |
257 | 277 |
|
258 | 278 | def load_weights(self, weights: Iterable[Tuple[str, |
259 | 279 | torch.Tensor]]) -> Set[str]: |
260 | | - params_dict = dict(self.named_parameters()) |
261 | | - loaded_params: Set[str] = set() |
262 | | - for name, loaded_weight in weights: |
263 | | - if "A_log" in name: |
264 | | - name = name.replace("A_log", "A") |
265 | | - # Skip loading extra bias for GPTQ models. |
266 | | - if name.endswith(".bias") and name not in params_dict: |
267 | | - continue |
268 | | - if is_pp_missing_parameter(name, self): |
269 | | - continue |
270 | | - |
271 | | - param = params_dict[name] |
272 | | - weight_loader = getattr(param, "weight_loader", |
273 | | - default_weight_loader) |
274 | | - weight_loader(param, loaded_weight) |
275 | | - loaded_params.add(name) |
276 | | - return loaded_params |
| 280 | + loader = AutoWeightsLoader(self) |
| 281 | + return loader.load_weights(weights) |
0 commit comments