@@ -20,6 +20,7 @@
 
 import typing
 from collections.abc import Callable, Iterable
+from itertools import islice
 
 import torch
 from torch import nn
@@ -549,7 +550,7 @@ def get_layer(prefix: str):
         self.start_layer, self.end_layer, self.layers = make_layers(
             len(config.hybrid_override_pattern), get_layer, prefix=f"{prefix}.layers"
         )
-        self.make_empty_intmd_tensors = make_empty_intermediate_tensors_factory(
+        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
             ["hidden_states", "residual"], config.hidden_size
         )
 
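The rename aligns the attribute with the name the rest of the codebase uses. Roughly, the factory returns a callable that builds zero-filled placeholder tensors, which non-first pipeline ranks use as receive buffers; a sketch of the idea, assuming the signature implied by its call site here and relying on the file's existing imports:

    def make_empty_intermediate_tensors_factory(keys, hidden_size):
        def make_empty_intermediate_tensors(batch_size, dtype, device):
            # One zero tensor per requested key ("hidden_states", "residual"),
            # each of shape (batch_size, hidden_size).
            return IntermediateTensors({
                key: torch.zeros((batch_size, hidden_size), dtype=dtype, device=device)
                for key in keys
            })
        return make_empty_intermediate_tensors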
@@ -564,7 +565,7 @@ def forward(
         positions: torch.Tensor,
         intermediate_tensors: IntermediateTensors | None = None,
         inputs_embeds: torch.Tensor | None = None,
-    ) -> torch.Tensor:
+    ) -> torch.Tensor | IntermediateTensors:
         if get_pp_group().is_first_rank:
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
@@ -576,8 +577,7 @@ def forward(
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        residual = None
-        for i, layer in enumerate(self.layers):
+        for layer in islice(self.layers, self.start_layer, self.end_layer):
             hidden_states, residual = layer(
                 positions=positions,
                 hidden_states=hidden_states,
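Two things change in the loop: the unconditional `residual = None` (which overwrote the residual just read from `intermediate_tensors` on non-first ranks) is gone, and iteration is restricted to this rank's slice of layers. `islice(iterable, start, stop)` yields the elements at indices `start` through `stop - 1`, so layers outside `[start_layer, end_layer)`, which live on other pipeline ranks, are never called. A minimal standalone illustration of the slicing semantics:

    from itertools import islice

    layers = ["l0", "l1", "l2", "l3", "l4", "l5"]
    start_layer, end_layer = 2, 4  # this rank owns layers 2 and 3
    print(list(islice(layers, start_layer, end_layer)))  # ['l2', 'l3']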
@@ -633,6 +633,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 if name.endswith(".bias") and name not in params_dict:
                     continue
 
+                if is_pp_missing_parameter(name, self):
+                    continue
+
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
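The `is_pp_missing_parameter(name, self)` guard protects the lookup that follows: with pipeline parallelism, `params_dict` only contains parameters for this rank's layers, so checkpoint weights belonging to other ranks' layers must be skipped instead of raising a `KeyError`. A hypothetical simplified sketch of the idea (not the actual helper):

    import re

    def is_layer_out_of_range(name: str, start_layer: int, end_layer: int) -> bool:
        # A weight is "missing" on this rank if its layer index falls
        # outside the rank's [start_layer, end_layer) slice.
        match = re.search(r"layers\.(\d+)\.", name)
        if match is None:
            return False  # non-layer weights (embeddings, lm_head) handled elsewhere
        return not start_layer <= int(match.group(1)) < end_layer

The same guard is added to the second loading path below for the same reason.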
@@ -678,6 +681,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
             if is_expert_weight:
                 continue
 
+            if is_pp_missing_parameter(name, self):
+                continue
+
             param = params_dict[name]
             weight_loader = getattr(
                 param, "weight_loader", default_weight_loader
@@ -792,7 +798,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.unpadded_vocab_size, config.vocab_size
         )
 
-        self.make_empty_intmd_tensors = self.model.make_empty_intmd_tensors
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors
+        )
 
         # Set MoE hyperparameters
         if self.model.has_moe:
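Re-exporting the factory from the top-level model makes it visible to the engine, which on non-first pipeline ranks allocates empty intermediate tensors as receive buffers before calling `forward`. A hypothetical usage sketch (variable names are illustrative, not the exact runner code):

    intermediate_tensors = None
    if not get_pp_group().is_first_rank:
        # Buffers for hidden_states/residual, filled by recv from the previous rank.
        intermediate_tensors = model.make_empty_intermediate_tensors(
            batch_size=num_tokens, dtype=dtype, device=device
        )
    output = model(positions=positions, intermediate_tensors=intermediate_tensors)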