Commit c817c3d

CSWYF3634076 authored and xuebwang-amd committed
[Model][Bugfix]fix ernie45 load failed due to ernie45 eplb code (vllm-project#26684)
Signed-off-by: wangyafeng <wangyafeng@baidu.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 108cd83 commit c817c3d

File tree

1 file changed: +22 -12 lines changed


vllm/model_executor/models/ernie45_moe.py

Lines changed: 22 additions & 12 deletions
@@ -23,7 +23,8 @@
 # limitations under the License.
 """Inference-only ErineMoE model compatible with HuggingFace weights."""
 
-from collections.abc import Iterable
+import typing
+from collections.abc import Callable, Iterable
 from itertools import islice
 from typing import Any

@@ -139,10 +140,10 @@ def __init__(
 
         # Load balancing settings.
         vllm_config = get_current_vllm_config()
-        parallel_config = vllm_config.parallel_config
+        eplb_config = vllm_config.parallel_config.eplb_config
         self.enable_eplb = enable_eplb
 
-        self.n_redundant_experts = parallel_config.num_redundant_experts
+        self.n_redundant_experts = eplb_config.num_redundant_experts
         self.n_logical_experts = self.n_routed_experts
         self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
         self.n_local_physical_experts = self.n_physical_experts // self.ep_size
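
For reference, the unchanged tail of this hunk spells out how the expert counts relate once redundant experts are enabled: physical experts = logical experts + redundant replicas, divided evenly across the expert-parallel ranks. A small worked example in plain Python; the numbers are illustrative and not taken from any real Ernie 4.5 configuration:

# Worked example of the expert-count bookkeeping kept in this hunk.
# All numbers are illustrative, not taken from any real Ernie 4.5 config.
n_routed_experts = 64        # logical (routed) experts defined by the model
num_redundant_experts = 16   # extra replicas added for EPLB load balancing
ep_size = 8                  # expert-parallel world size

n_logical_experts = n_routed_experts
n_physical_experts = n_logical_experts + num_redundant_experts   # 64 + 16 = 80
n_local_physical_experts = n_physical_experts // ep_size         # 80 // 8 = 10 per rank

assert n_physical_experts % ep_size == 0, "physical experts must split evenly across EP ranks"
print(n_logical_experts, n_physical_experts, n_local_physical_experts)  # 64 80 10
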
@@ -426,8 +427,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.vocab_size = config.vocab_size
         self.config = config
         parallel_config = vllm_config.parallel_config
+        eplb_config = parallel_config.eplb_config
         enable_eplb = parallel_config.enable_eplb
-        self.num_redundant_experts = parallel_config.num_redundant_experts
+
+        self.num_redundant_experts = eplb_config.num_redundant_experts
 
         if get_pp_group().is_first_rank:
             self.embed_tokens = VocabParallelEmbedding(
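
Both __init__ hunks make the same repair: num_redundant_experts is no longer read off the parallel config directly but off its nested eplb_config, which is presumably why the old attribute lookup broke Ernie 4.5 loading. A minimal sketch of the new access pattern, using stand-in dataclasses (EPLBConfigStub and ParallelConfigStub are illustrative only, not vLLM's real config classes):

from dataclasses import dataclass, field

# Stand-in config objects for illustration only; the real vLLM classes differ.
@dataclass
class EPLBConfigStub:
    num_redundant_experts: int = 0

@dataclass
class ParallelConfigStub:
    enable_eplb: bool = False
    eplb_config: EPLBConfigStub = field(default_factory=EPLBConfigStub)

parallel_config = ParallelConfigStub(
    enable_eplb=True,
    eplb_config=EPLBConfigStub(num_redundant_experts=2),
)

# Old code read parallel_config.num_redundant_experts directly; after the EPLB
# config refactor the count lives on the nested eplb_config instead.
eplb_config = parallel_config.eplb_config
num_redundant_experts = eplb_config.num_redundant_experts
print(parallel_config.enable_eplb, num_redundant_experts)  # True 2
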
@@ -570,20 +573,27 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
 
                     # Skip loading extra bias for GPTQ models.
                     if (
-                        name.endswith(".bias") or name.endswith("_bias")
-                    ) and name not in params_dict:
+                        name_mapped.endswith(".bias") or name_mapped.endswith("_bias")
+                    ) and name_mapped not in params_dict:
                         continue
-                    param = params_dict[name]
-
-                    weight_loader = param.weight_loader
-                    weight_loader(
+                    param = params_dict[name_mapped]
+                    # We should ask the weight loader to return success or not
+                    # here since otherwise we may skip experts with other
+                    # available replicas.
+                    weight_loader = typing.cast(
+                        Callable[..., bool], param.weight_loader
+                    )
+                    success = weight_loader(
                         param,
                         loaded_weight,
-                        name,
+                        name_mapped,
                         shard_id=shard_id,
                         expert_id=expert_id,
+                        return_success=True,
                     )
-                    break
+                    if success:
+                        name = name_mapped
+                        break
                 else:
                     if is_expert_weight:
                         # We've checked that this is an expert weight
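
The new comment in this hunk carries the reasoning: with EPLB, a checkpoint expert weight may map to a replica that the current rank does not host, so the loader is asked to report success instead of being assumed to have applied the shard, and only on success is the mapped name committed and the mapping loop exited. A hedged sketch of that control flow with a toy loader (toy_weight_loader, LOCAL_EXPERT_IDS, params_dict and try_load are invented for illustration and only loosely model vLLM's real FusedMoE weight loader):

import typing
from collections.abc import Callable

# Toy stand-in for a FusedMoE-style weight loader: it only "accepts" experts
# hosted locally and reports the outcome when return_success=True.
LOCAL_EXPERT_IDS = {0, 1}

def toy_weight_loader(param, loaded_weight, name, *, shard_id, expert_id, return_success=False):
    ok = expert_id in LOCAL_EXPERT_IDS
    if ok:
        param[name] = loaded_weight  # pretend to copy the shard into the parameter
    return ok if return_success else None

params_dict = {"experts.w13_weight": {}}

def try_load(name_mapped, loaded_weight, shard_id, expert_id):
    param = params_dict[name_mapped]
    # Mirror the pattern in the hunk above: cast the loader so type checkers
    # know it returns a bool, then let it report whether this replica applied.
    weight_loader = typing.cast(Callable[..., bool], toy_weight_loader)
    return weight_loader(
        param,
        loaded_weight,
        name_mapped,
        shard_id=shard_id,
        expert_id=expert_id,
        return_success=True,
    )

# Expert 5 is not hosted here, so the caller falls through to the next mapping
# instead of wrongly marking the weight as loaded; expert 1 is hosted and succeeds.
print(try_load("experts.w13_weight", [0.1, 0.2], "w1", expert_id=5))  # False
print(try_load("experts.w13_weight", [0.1, 0.2], "w1", expert_id=1))  # True

Returning the success flag is what lets the surrounding for/else fall through to other expert mappings rather than silently treating a non-local replica as loaded, which matches the rationale stated in the diff comment.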
