diff --git a/vllm/config.py b/vllm/config.py
index de5d0402a1bc7..e065744592378 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -40,6 +40,8 @@
     "GPT2LMHeadModel",
     "MixtralForCausalLM",
     "NemotronForCausalLM",
+    "Qwen2ForCausalLM",
+    "Qwen2MoeForCausalLM",
 ]
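This hunk extends what appears to be vLLM's pipeline-parallel allowlist in `vllm/config.py` (the surrounding entries match the `_PP_SUPPORTED_MODELS` list at this revision), so engine initialization stops rejecting `pipeline_parallel_size > 1` for the Qwen2 architectures. A minimal sketch of how such a gate behaves; the `verify_pp_support` helper below is a simplified stand-in, not vLLM's actual check:

```python
# Sketch only: the list name follows vllm/config.py at this revision;
# the verification function is a simplified stand-in for vLLM's check.
_PP_SUPPORTED_MODELS = [
    "GPT2LMHeadModel",
    "MixtralForCausalLM",
    "NemotronForCausalLM",
    "Qwen2ForCausalLM",     # newly allowed by this diff
    "Qwen2MoeForCausalLM",  # newly allowed by this diff
]


def verify_pp_support(architectures, pp_size):
    """Reject pipeline parallelism for architectures not yet ported."""
    if pp_size > 1 and not any(arch in _PP_SUPPORTED_MODELS
                               for arch in architectures):
        raise NotImplementedError(
            f"Pipeline parallelism is not supported for {architectures}.")


verify_pp_support(["Qwen2ForCausalLM"], pp_size=2)  # no longer raises
```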
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index 35fd6f37589a0..99fdd993943be 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -30,7 +30,7 @@
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -49,6 +49,7 @@
 from vllm.sequence import IntermediateTensors, SamplerOutput
 
 from .interfaces import SupportsLoRA
+from .utils import is_pp_missing_parameter, make_layers
 
 
 class Qwen2MLP(nn.Module):
@@ -227,6 +228,7 @@ def __init__(
         config: Qwen2Config,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -237,10 +239,14 @@ def __init__(
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList([
-            Qwen2DecoderLayer(config, cache_config, quant_config)
-            for _ in range(config.num_hidden_layers)
-        ])
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: Qwen2DecoderLayer(config=config,
+                                             cache_config=cache_config,
+                                             quant_config=quant_config),
+            prefix=f"{prefix}.layers",
+        )
+
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
@@ -255,20 +261,30 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if inputs_embeds is not None:
-            hidden_states = inputs_embeds
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.embed_tokens(input_ids)
+            residual = None
         else:
-            hidden_states = self.embed_tokens(input_ids)
-        residual = None
-        for i in range(len(self.layers)):
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                kv_caches[i],
+                kv_caches[i - self.start_layer],
                 attn_metadata,
                 residual,
             )
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
@@ -351,6 +367,20 @@ def compute_logits(self, hidden_states: torch.Tensor,
                                        sampling_metadata)
         return logits
 
+    def make_empty_intermediate_tensors(
+            self, batch_size: int, dtype: torch.dtype,
+            device: torch.device) -> IntermediateTensors:
+        return IntermediateTensors({
+            "hidden_states":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+            "residual":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+        })
+
     def sample(
         self,
         logits: torch.Tensor,
@@ -381,6 +411,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -393,7 +425,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 name = maybe_remap_kv_scale_name(name, params_dict)
                 if name is None:
                     continue
-
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
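The heart of the qwen2.py change: `make_layers` gives each pipeline rank a contiguous `[start_layer, end_layer)` slice of the decoder stack. The first rank embeds tokens, later ranks receive `hidden_states`/`residual` through `IntermediateTensors`, and only the last rank applies the final norm; `make_empty_intermediate_tensors` supplies correctly shaped receive buffers for that handoff. Note also `kv_caches[i - self.start_layer]`: each rank allocates KV caches only for its own layers, so the global layer index must be rebased to a local slot. A toy illustration of the slicing arithmetic (an even split with the remainder going to earlier ranks is assumed here; vLLM's actual partitioning lives in `models/utils.py`):

```python
def partition_layers(num_layers: int, pp_rank: int, pp_size: int):
    # Contiguous slice per rank; earlier ranks absorb any remainder.
    base, extra = divmod(num_layers, pp_size)
    start = pp_rank * base + min(pp_rank, extra)
    end = start + base + (1 if pp_rank < extra else 0)
    return start, end


# Qwen2-7B has 28 decoder layers; across 4 pipeline stages each rank
# owns a 7-layer slice, and global layer i uses local kv-cache slot
# i - start (hence kv_caches[i - self.start_layer] in the diff).
print([partition_layers(28, r, 4) for r in range(4)])
# [(0, 7), (7, 14), (14, 21), (21, 28)]
```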
diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py
index 2cc2f1440d147..b895788206d10 100644
--- a/vllm/model_executor/models/qwen2_moe.py
+++ b/vllm/model_executor/models/qwen2_moe.py
@@ -31,7 +31,8 @@
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
-from vllm.distributed import (get_tensor_model_parallel_world_size,
+from vllm.distributed import (get_pp_group,
+                              get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -52,6 +53,8 @@
 from vllm.sequence import IntermediateTensors, SamplerOutput
 from vllm.utils import print_warning_once
 
+from .utils import is_pp_missing_parameter, make_layers
+
 
 class Qwen2MoeMLP(nn.Module):
@@ -315,6 +318,7 @@ def __init__(
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
@@ -324,13 +328,15 @@ def __init__(
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList([
-            Qwen2MoeDecoderLayer(config,
-                                 layer_idx,
-                                 cache_config,
-                                 quant_config=quant_config)
-            for layer_idx in range(config.num_hidden_layers)
-        ])
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: Qwen2MoeDecoderLayer(config=config,
+                                                layer_idx=int(
+                                                    prefix.split(".")[-1]),
+                                                cache_config=cache_config,
+                                                quant_config=quant_config),
+            prefix=f"{prefix}.layers",
+        )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
@@ -339,14 +345,25 @@ def forward(
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
-        residual = None
-        for i in range(len(self.layers)):
+        if get_pp_group().is_first_rank:
+            hidden_states = self.embed_tokens(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(positions, hidden_states,
-                                            kv_caches[i], attn_metadata,
-                                            residual)
+                                            kv_caches[i - self.start_layer],
+                                            attn_metadata, residual)
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
@@ -380,7 +397,7 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata)
+                                   attn_metadata, intermediate_tensors)
         return hidden_states
 
     def compute_logits(self, hidden_states: torch.Tensor,
@@ -389,6 +406,20 @@ def compute_logits(self, hidden_states: torch.Tensor,
                                        sampling_metadata)
         return logits
 
+    def make_empty_intermediate_tensors(
+            self, batch_size: int, dtype: torch.dtype,
+            device: torch.device) -> IntermediateTensors:
+        return IntermediateTensors({
+            "hidden_states":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+            "residual":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+        })
+
     def sample(
         self,
         logits: Optional[torch.Tensor],
@@ -435,6 +466,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip layers on other devices.
+                if is_pp_missing_parameter(name, self):
+                    continue
                 if name not in params_dict:
                     continue
@@ -448,6 +482,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     if weight_name not in name:
                         continue
                     name = name.replace(weight_name, param_name)
+                    # Skip layers on other devices.
+                    if is_pp_missing_parameter(name, self):
+                        continue
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(param,
@@ -460,6 +497,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip layers on other devices.
+                if is_pp_missing_parameter(name, self):
+                    continue
                 # Remapping the name of FP8 kv-scale.
                 if name.endswith("kv_scale"):
                     remapped_kv_scale_name = name.replace(
@@ -474,7 +514,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                             continue
                     else:
                         name = remapped_kv_scale_name
-
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
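Both `load_weights` methods now consult `is_pp_missing_parameter` so each rank silently skips checkpoint tensors owned by layers on other pipeline stages; the MoE variant needs the check in three places because the stacked, expert, and fallback paths each index `params_dict` directly and would otherwise raise `KeyError`. A simplified stand-in for that check, assuming the usual `model.layers.<idx>.` parameter naming (the real helper in `models/utils.py` takes the model instance and also handles first/last-rank modules such as embeddings and `lm_head`):

```python
import re


def pp_missing_parameter(name: str, start_layer: int,
                         end_layer: int) -> bool:
    # Simplified stand-in for .utils.is_pp_missing_parameter: a weight
    # whose layer index falls outside this rank's slice lives on another
    # pipeline stage, so the loader must skip it.
    match = re.match(r"model\.layers\.(\d+)\.", name)
    if match is None:
        # Embeddings / final norm / lm_head are not layer-sliced here;
        # the real helper resolves their placement from the model.
        return False
    layer_idx = int(match.group(1))
    return not start_layer <= layer_idx < end_layer


# A rank owning layers [14, 28) skips layer 3 but loads layer 20.
assert pp_missing_parameter("model.layers.3.mlp.gate_up_proj.weight", 14, 28)
assert not pp_missing_parameter(
    "model.layers.20.self_attn.qkv_proj.weight", 14, 28)
```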