Skip to content

Commit 77f8001

Browse files
authored
[Model][Bugfix] fix pipeline parallelism support for NemotronH (#27968)
Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
1 parent 300a265 commit 77f8001

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

vllm/model_executor/models/nemotron_h.py

Lines changed: 13 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020

2121
import typing
2222
from collections.abc import Callable, Iterable
23+
from itertools import islice
2324

2425
import torch
2526
from torch import nn
@@ -549,7 +550,7 @@ def get_layer(prefix: str):
549550
self.start_layer, self.end_layer, self.layers = make_layers(
550551
len(config.hybrid_override_pattern), get_layer, prefix=f"{prefix}.layers"
551552
)
552-
self.make_empty_intmd_tensors = make_empty_intermediate_tensors_factory(
553+
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
553554
["hidden_states", "residual"], config.hidden_size
554555
)
555556

@@ -564,7 +565,7 @@ def forward(
564565
positions: torch.Tensor,
565566
intermediate_tensors: IntermediateTensors | None = None,
566567
inputs_embeds: torch.Tensor | None = None,
567-
) -> torch.Tensor:
568+
) -> torch.Tensor | IntermediateTensors:
568569
if get_pp_group().is_first_rank:
569570
if inputs_embeds is not None:
570571
hidden_states = inputs_embeds
@@ -576,8 +577,7 @@ def forward(
576577
hidden_states = intermediate_tensors["hidden_states"]
577578
residual = intermediate_tensors["residual"]
578579

579-
residual = None
580-
for i, layer in enumerate(self.layers):
580+
for layer in islice(self.layers, self.start_layer, self.end_layer):
581581
hidden_states, residual = layer(
582582
positions=positions,
583583
hidden_states=hidden_states,
@@ -633,6 +633,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
633633
if name.endswith(".bias") and name not in params_dict:
634634
continue
635635

636+
if is_pp_missing_parameter(name, self):
637+
continue
638+
636639
param = params_dict[name]
637640
weight_loader = param.weight_loader
638641
weight_loader(param, loaded_weight, shard_id)
@@ -678,6 +681,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
678681
if is_expert_weight:
679682
continue
680683

684+
if is_pp_missing_parameter(name, self):
685+
continue
686+
681687
param = params_dict[name]
682688
weight_loader = getattr(
683689
param, "weight_loader", default_weight_loader
@@ -792,7 +798,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
792798
self.unpadded_vocab_size, config.vocab_size
793799
)
794800

795-
self.make_empty_intmd_tensors = self.model.make_empty_intmd_tensors
801+
self.make_empty_intermediate_tensors = (
802+
self.model.make_empty_intermediate_tensors
803+
)
796804

797805
# Set MoE hyperparameters
798806
if self.model.has_moe:

0 commit comments

Comments (0)