|
23 | 23 | # limitations under the License. |
24 | 24 | """Inference-only ErineMoE model compatible with HuggingFace weights.""" |
25 | 25 |
|
26 | | -from collections.abc import Iterable |
| 26 | +import typing |
| 27 | +from collections.abc import Callable, Iterable |
27 | 28 | from itertools import islice |
28 | 29 | from typing import Any |
29 | 30 |
|
@@ -139,10 +140,10 @@ def __init__( |
139 | 140 |
|
140 | 141 | # Load balancing settings. |
141 | 142 | vllm_config = get_current_vllm_config() |
142 | | - parallel_config = vllm_config.parallel_config |
| 143 | + eplb_config = vllm_config.parallel_config.eplb_config |
143 | 144 | self.enable_eplb = enable_eplb |
144 | 145 |
|
145 | | - self.n_redundant_experts = parallel_config.num_redundant_experts |
| 146 | + self.n_redundant_experts = eplb_config.num_redundant_experts |
146 | 147 | self.n_logical_experts = self.n_routed_experts |
147 | 148 | self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts |
148 | 149 | self.n_local_physical_experts = self.n_physical_experts // self.ep_size |
@@ -426,8 +427,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): |
426 | 427 | self.vocab_size = config.vocab_size |
427 | 428 | self.config = config |
428 | 429 | parallel_config = vllm_config.parallel_config |
| 430 | + eplb_config = parallel_config.eplb_config |
429 | 431 | enable_eplb = parallel_config.enable_eplb |
430 | | - self.num_redundant_experts = parallel_config.num_redundant_experts |
| 432 | + |
| 433 | + self.num_redundant_experts = eplb_config.num_redundant_experts |
431 | 434 |
|
432 | 435 | if get_pp_group().is_first_rank: |
433 | 436 | self.embed_tokens = VocabParallelEmbedding( |
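
Both hunks above make the same relocation: num_redundant_experts now lives on the EPLB sub-config reached via parallel_config.eplb_config instead of directly on the parallel config. A minimal sketch of the new access path, using simplified stand-ins for the real config classes (only the attribute names come from the diff; the dataclass shapes below are assumptions, not vLLM's actual definitions):

import dataclasses

@dataclasses.dataclass
class EplbConfig:  # hypothetical stand-in for vLLM's EPLB config class
    num_redundant_experts: int = 0

@dataclasses.dataclass
class ParallelConfig:  # hypothetical stand-in; the real class has many more fields
    enable_eplb: bool = False
    eplb_config: EplbConfig = dataclasses.field(default_factory=EplbConfig)

parallel_config = ParallelConfig(enable_eplb=True,
                                 eplb_config=EplbConfig(num_redundant_experts=2))
# Old path (removed by this diff): parallel_config.num_redundant_experts
# New path:
assert parallel_config.eplb_config.num_redundant_experts == 2
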
@@ -570,20 +573,27 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: |
570 | 573 |
|
571 | 574 | # Skip loading extra bias for GPTQ models. |
572 | 575 | if ( |
573 | | - name.endswith(".bias") or name.endswith("_bias") |
574 | | - ) and name not in params_dict: |
| 576 | + name_mapped.endswith(".bias") or name_mapped.endswith("_bias") |
| 577 | + ) and name_mapped not in params_dict: |
575 | 578 | continue |
576 | | - param = params_dict[name] |
577 | | - |
578 | | - weight_loader = param.weight_loader |
579 | | - weight_loader( |
| 579 | + param = params_dict[name_mapped] |
| 580 | + # Ask the weight loader to report whether loading succeeded,
| 581 | + # since otherwise we may skip experts that have other
| 582 | + # available replicas.
| 583 | + weight_loader = typing.cast( |
| 584 | + Callable[..., bool], param.weight_loader |
| 585 | + ) |
| 586 | + success = weight_loader( |
580 | 587 | param, |
581 | 588 | loaded_weight, |
582 | | - name, |
| 589 | + name_mapped, |
583 | 590 | shard_id=shard_id, |
584 | 591 | expert_id=expert_id, |
| 592 | + return_success=True, |
585 | 593 | ) |
586 | | - break |
| 594 | + if success: |
| 595 | + name = name_mapped |
| 596 | + break |
587 | 597 | else: |
588 | 598 | if is_expert_weight: |
589 | 599 | # We've checked that this is an expert weight |
|
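
The load_weights hunk replaces the unconditional break with a success check: the weight loader is cast to Callable[..., bool] and called with return_success=True, so an expert weight that no local replica accepts can fall through to the next mapping instead of being silently dropped. A minimal sketch of that loop, assuming a simplified shape for expert_params_mapping and params_dict (the return_success keyword and the typing.cast come from the diff; the helper itself is hypothetical):

import typing
from collections.abc import Callable

def try_load_expert_weight(name, loaded_weight, params_dict, expert_params_mapping):
    # Assumed entry layout: (param_name, weight_name, expert_id, shard_id).
    for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
        if weight_name not in name:
            continue
        name_mapped = name.replace(weight_name, param_name)
        if name_mapped not in params_dict:
            # In the diff only GPTQ-style extra biases are expected to be
            # absent; the sketch simply tries the next mapping.
            continue
        param = params_dict[name_mapped]
        # Ask the loader to report success so that an expert whose
        # replica lives on another rank does not end the search early.
        weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
        if weight_loader(param, loaded_weight, name_mapped,
                         shard_id=shard_id, expert_id=expert_id,
                         return_success=True):
            return name_mapped  # loaded locally under the mapped name
    return None  # no local replica accepted this weight

The diff itself keeps the for/else form so the else branch can distinguish an expert weight with no local replica from an ordinary weight.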