[Model] Pipeline parallel support for Qwen2 (vllm-project#6924)
xuyi authored and kylesayrs committed Aug 17, 2024
1 parent 3173a7b · commit daff04d
Showing 3 changed files with 101 additions and 27 deletions.
vllm/config.py: 2 changes (2 additions & 0 deletions)
@@ -40,6 +40,8 @@
     "GPT2LMHeadModel",
     "MixtralForCausalLM",
     "NemotronForCausalLM",
+    "Qwen2ForCausalLM",
+    "Qwen2MoeForCausalLM",
 ]
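The change to vllm/config.py registers both Qwen2 architectures in the list of models vLLM treats as pipeline-parallel-capable; an engine launched with --pipeline-parallel-size greater than 1 is checked against this list up front. A minimal sketch of how such a registry gate works (the function name and error message below are illustrative, not vLLM's exact code):

# Illustrative sketch only; vLLM performs an equivalent check while the
# engine configuration is validated, so unsupported architectures fail
# fast rather than mid-inference.
_PP_SUPPORTED_MODELS = [
    "GPT2LMHeadModel",
    "MixtralForCausalLM",
    "NemotronForCausalLM",
    "Qwen2ForCausalLM",     # added by this commit
    "Qwen2MoeForCausalLM",  # added by this commit
]

def verify_pp_support(architectures, pipeline_parallel_size):
    # Reject pipeline parallelism for architectures not in the registry.
    if pipeline_parallel_size > 1 and not any(
            arch in _PP_SUPPORTED_MODELS for arch in architectures):
        raise NotImplementedError(
            "Pipeline parallelism is not supported for these "
            f"architectures: {architectures}.")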
vllm/model_executor/models/qwen2.py: 57 changes (45 additions & 12 deletions)
@@ -30,7 +30,7 @@
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -49,6 +49,7 @@
 from vllm.sequence import IntermediateTensors, SamplerOutput
 
 from .interfaces import SupportsLoRA
+from .utils import is_pp_missing_parameter, make_layers
 
 
 class Qwen2MLP(nn.Module):
@@ -227,6 +228,7 @@ def __init__(
         config: Qwen2Config,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -237,10 +239,14 @@ def __init__(
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList([
-            Qwen2DecoderLayer(config, cache_config, quant_config)
-            for _ in range(config.num_hidden_layers)
-        ])
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: Qwen2DecoderLayer(config=config,
+                                             cache_config=cache_config,
+                                             quant_config=quant_config),
+            prefix=f"{prefix}.layers",
+        )
+
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
@@ -255,20 +261,30 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if inputs_embeds is not None:
-            hidden_states = inputs_embeds
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.embed_tokens(input_ids)
+            residual = None
         else:
-            hidden_states = self.embed_tokens(input_ids)
-        residual = None
-        for i in range(len(self.layers)):
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                kv_caches[i],
+                kv_caches[i - self.start_layer],
                 attn_metadata,
                 residual,
             )
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
@@ -351,6 +367,20 @@ def compute_logits(self, hidden_states: torch.Tensor,
                                        sampling_metadata)
         return logits
 
+    def make_empty_intermediate_tensors(
+            self, batch_size: int, dtype: torch.dtype,
+            device: torch.device) -> IntermediateTensors:
+        return IntermediateTensors({
+            "hidden_states":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+            "residual":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+        })
+
     def sample(
         self,
         logits: torch.Tensor,
@@ -381,6 +411,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -393,7 +425,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 name = maybe_remap_kv_scale_name(name, params_dict)
                 if name is None:
                     continue
-
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
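Both files follow the same pattern: the make_layers helper from vllm/model_executor/models/utils.py gives each pipeline rank only its own contiguous slice of decoder layers, which is why the loops above run over range(self.start_layer, self.end_layer) and index kv_caches with i - self.start_layer (each rank allocates caches only for the layers it owns). A simplified sketch of the partitioning idea, assuming an even split and taking the rank explicitly rather than reading it from get_pp_group() as the real helper does:

from typing import Callable, Tuple

import torch.nn as nn


class PPMissingLayer(nn.Module):
    """Stands in for decoder layers owned by other pipeline ranks."""

    def forward(self, *args, **kwargs):
        raise RuntimeError("This layer lives on another pipeline rank.")


def make_layers_sketch(
    num_layers: int,
    layer_fn: Callable[[str], nn.Module],
    rank: int,
    world_size: int,
    prefix: str = "model.layers",
) -> Tuple[int, int, nn.ModuleList]:
    # Give each rank a contiguous slice [start, end); the last rank
    # absorbs any remainder when the split is uneven.
    per_rank = num_layers // world_size
    start = rank * per_rank
    end = num_layers if rank == world_size - 1 else start + per_rank
    layers = nn.ModuleList([
        layer_fn(f"{prefix}.{i}") if start <= i < end else PPMissingLayer()
        for i in range(num_layers)
    ])
    return start, end, layers

With 28 layers and two pipeline ranks, for example, rank 0 builds layers 0-13 and rank 1 builds 14-27; every other index holds a placeholder so parameter names keep their original layer numbering.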
vllm/model_executor/models/qwen2_moe.py: 69 changes (54 additions & 15 deletions)
@@ -31,7 +31,8 @@
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
-from vllm.distributed import (get_tensor_model_parallel_world_size,
+from vllm.distributed import (get_pp_group,
+                              get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -52,6 +53,8 @@
 from vllm.sequence import IntermediateTensors, SamplerOutput
 from vllm.utils import print_warning_once
 
+from .utils import is_pp_missing_parameter, make_layers
+
 
 class Qwen2MoeMLP(nn.Module):
 
@@ -315,6 +318,7 @@ def __init__(
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
@@ -324,13 +328,15 @@ def __init__(
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList([
-            Qwen2MoeDecoderLayer(config,
-                                 layer_idx,
-                                 cache_config,
-                                 quant_config=quant_config)
-            for layer_idx in range(config.num_hidden_layers)
-        ])
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: Qwen2MoeDecoderLayer(config=config,
+                                                layer_idx=int(
+                                                    prefix.split(".")[-1]),
+                                                cache_config=cache_config,
+                                                quant_config=quant_config),
+            prefix=f"{prefix}.layers",
+        )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
@@ -339,14 +345,25 @@ def forward(
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
-        residual = None
-        for i in range(len(self.layers)):
+        if get_pp_group().is_first_rank:
+            hidden_states = self.embed_tokens(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(positions, hidden_states,
-                                            kv_caches[i], attn_metadata,
-                                            residual)
+                                            kv_caches[i - self.start_layer],
+                                            attn_metadata, residual)
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
@@ -380,7 +397,7 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata)
+                                   attn_metadata, intermediate_tensors)
         return hidden_states
 
     def compute_logits(self, hidden_states: torch.Tensor,
@@ -389,6 +406,20 @@ def compute_logits(self, hidden_states: torch.Tensor,
                                        sampling_metadata)
         return logits
 
+    def make_empty_intermediate_tensors(
+            self, batch_size: int, dtype: torch.dtype,
+            device: torch.device) -> IntermediateTensors:
+        return IntermediateTensors({
+            "hidden_states":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+            "residual":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+        })
+
     def sample(
         self,
         logits: Optional[torch.Tensor],
@@ -435,6 +466,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip layers on other devices.
+                if is_pp_missing_parameter(name, self):
+                    continue
                 if name not in params_dict:
                     continue
 
@@ -448,6 +482,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     if weight_name not in name:
                         continue
                     name = name.replace(weight_name, param_name)
+                    # Skip layers on other devices.
+                    if is_pp_missing_parameter(name, self):
+                        continue
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(param,
@@ -460,6 +497,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
+                    # Skip layers on other devices.
+                    if is_pp_missing_parameter(name, self):
+                        continue
                     # Remapping the name of FP8 kv-scale.
                     if name.endswith("kv_scale"):
                         remapped_kv_scale_name = name.replace(
@@ -474,7 +514,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                             continue
                         else:
                             name = remapped_kv_scale_name
-
                     param = params_dict[name]
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)
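The recurring is_pp_missing_parameter guard in load_weights (the "Skip layers on other devices" comments above) skips weights whose layers were never materialized on this rank. A minimal sketch of what such a check does, reusing the PPMissingLayer placeholder idea from the previous sketch; the real helper lives in vllm/model_executor/models/utils.py and may be implemented differently:

import torch.nn as nn


class PPMissingLayer(nn.Module):
    """Stands in for decoder layers owned by other pipeline ranks."""


def is_pp_missing_parameter_sketch(name: str, model: nn.Module) -> bool:
    # Walk the dotted weight name (e.g. "model.layers.12.mlp.down_proj.weight")
    # down the module tree; reaching a placeholder means this rank does
    # not own the weight and should skip loading it.
    module = model
    for part in name.split(".")[:-1]:
        if isinstance(module, PPMissingLayer):
            return True
        module = getattr(module, part, None)
        if module is None:
            return False  # name not found locally; let the caller decide
    return isinstance(module, PPMissingLayer)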
