[Models] Support Qwen model with PP (vllm-project#6974)
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Signed-off-by: Alvant <alvasian@yandex.ru>
andoorve authored and Alvant committed Oct 26, 2024
1 parent 854de51 commit 0fb98ae
Showing 3 changed files with 46 additions and 11 deletions.
2 changes: 1 addition & 1 deletion docs/source/serving/distributed_serving.rst
@@ -50,7 +50,7 @@ You can also additionally specify :code:`--pipeline-parallel-size` to enable pip
 $     --pipeline-parallel-size 2
 
 .. note::
-    Pipeline parallel is a beta feature. It is only supported for online serving as well as LLaMa, GPT2, and Mixtral style models.
+    Pipeline parallel is a beta feature. It is only supported for online serving as well as LLaMa, GPT2, Mixtral, Qwen, Qwen2, and Nemotron style models.
 
 Multi-Node Inference and Serving
 --------------------------------
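For context, pipeline parallelism is enabled when the server is launched. A minimal sketch of starting the OpenAI-compatible server with PP for a Qwen model; the model name is only an example, and Qwen's original HF checkpoints also need --trust-remote-code:

# Illustrative launch of vLLM's OpenAI-compatible server with pipeline
# parallelism across 2 GPUs; "Qwen/Qwen-7B-Chat" is an example model name.
import subprocess

subprocess.run([
    "python", "-m", "vllm.entrypoints.openai.api_server",
    "--model", "Qwen/Qwen-7B-Chat",
    "--trust-remote-code",              # Qwen's HF repo ships custom code
    "--pipeline-parallel-size", "2",
], check=True)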
1 change: 1 addition & 0 deletions vllm/config.py
@@ -42,6 +42,7 @@
     "NemotronForCausalLM",
     "Qwen2ForCausalLM",
     "Qwen2MoeForCausalLM",
+    "QWenLMHeadModel",
 ]


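The hunk does not show the list's declaration, but this appears to be vLLM's allowlist of architectures permitted with pipeline parallelism; adding "QWenLMHeadModel" is what actually unlocks PP for the original Qwen models. A minimal sketch of how such a gate can work (the list and function names here are illustrative, not necessarily vLLM's exact code):

# Illustrative allowlist gate; names are assumptions, not vLLM's exact code.
_PP_SUPPORTED_MODELS = [
    "NemotronForCausalLM",
    "Qwen2ForCausalLM",
    "Qwen2MoeForCausalLM",
    "QWenLMHeadModel",
]

def check_pp_supported(architectures: list, pp_size: int) -> None:
    """Raise if pipeline parallelism is requested for an unsupported arch."""
    if pp_size > 1 and not all(arch in _PP_SUPPORTED_MODELS
                               for arch in architectures):
        raise NotImplementedError(
            f"Pipeline parallelism is not supported for {architectures}.")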
54 changes: 44 additions & 10 deletions vllm/model_executor/models/qwen.py
@@ -12,7 +12,7 @@
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -30,6 +30,8 @@
 from vllm.sequence import IntermediateTensors, SamplerOutput
 from vllm.utils import print_warning_once
 
+from .utils import is_pp_missing_parameter, make_layers
+
 
 class QWenMLP(nn.Module):
 
@@ -186,6 +188,7 @@ def __init__(
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -195,10 +198,10 @@ def __init__(
             config.vocab_size,
             config.hidden_size,
         )
-        self.h = nn.ModuleList([
-            QWenBlock(config, cache_config, quant_config)
-            for _ in range(config.num_hidden_layers)
-        ])
+        self.start_layer, self.end_layer, self.h = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: QWenBlock(config, cache_config, quant_config),
+            prefix=f"{prefix}.h")
         self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
     def forward(
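make_layers, imported from .utils above, replaces the plain nn.ModuleList: it partitions the layer stack across pipeline ranks and records the [start_layer, end_layer) slice this rank owns. A simplified sketch of the idea, assuming an even split and placeholder modules for non-local layers (vLLM's real helper derives the rank and world size from the PP group itself):

import torch.nn as nn

class PPMissingLayer(nn.Module):
    """Placeholder for a layer hosted on another pipeline rank."""

def make_layers_sketch(num_layers, layer_fn, rank, world_size, prefix=""):
    """Evenly partition [0, num_layers) across pipeline ranks."""
    per_rank = num_layers // world_size
    start = rank * per_rank
    end = num_layers if rank == world_size - 1 else start + per_rank
    layers = nn.ModuleList(
        [PPMissingLayer() for _ in range(start)] +
        [layer_fn(prefix=f"{prefix}.{i}") for i in range(start, end)] +
        [PPMissingLayer() for _ in range(end, num_layers)])
    return start, end, layers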
@@ -207,18 +210,29 @@ def forward(
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors],
     ) -> torch.Tensor:
-        hidden_states = self.wte(input_ids)
-        residual = None
-        for i in range(len(self.h)):
+        if get_pp_group().is_first_rank:
+            hidden_states = self.wte(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+        for i in range(self.start_layer, self.end_layer):
             layer = self.h[i]
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                kv_caches[i],
+                kv_caches[i - self.start_layer],
                 attn_metadata,
                 residual,
             )
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
         hidden_states, _ = self.ln_f(hidden_states, residual)
         return hidden_states

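In the new QWenModel.forward, the first pipeline rank computes the token embedding, later ranks unpack the activations forwarded by the previous stage, and every rank except the last returns IntermediateTensors instead of applying the final norm. Note the kv_caches indexing: each rank allocates caches only for its local layers, so the global layer index i is rebased by start_layer. A worked illustration with assumed sizes:

# Worked example of the kv_caches rebase above, with assumed sizes:
# 32 hidden layers split across pipeline_parallel_size = 2 ranks.
start_layer, end_layer = 16, 32                 # slice owned by the second rank
num_local_kv_caches = end_layer - start_layer   # this rank allocates 16 caches
for i in range(start_layer, end_layer):
    local_idx = i - start_layer                 # maps global 16..31 to local 0..15
    assert 0 <= local_idx < num_local_kv_caches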
@@ -250,9 +264,23 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         attn_metadata)
+                                         attn_metadata, intermediate_tensors)
         return hidden_states
 
+    def make_empty_intermediate_tensors(
+            self, batch_size: int, dtype: torch.dtype,
+            device: torch.device) -> IntermediateTensors:
+        return IntermediateTensors({
+            "hidden_states":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+            "residual":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+        })
+
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
         logits = self.logits_processor(self.lm_head, hidden_states,
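make_empty_intermediate_tensors gives non-first ranks correctly shaped, zero-filled buffers into which the previous stage's hidden_states and residual are received before forward runs. A schematic sketch of that pattern, with the communication step abstracted into a caller-supplied recv_from_prev_rank callable (not a real vLLM API; vLLM's model runner owns the actual send/recv):

def run_non_first_stage(model, input_ids, positions, kv_caches,
                        attn_metadata, batch_size, dtype, device,
                        recv_from_prev_rank):
    """Schematic driver for a pipeline rank other than the first."""
    # Allocate correctly shaped placeholders for the incoming activations.
    buffers = model.make_empty_intermediate_tensors(batch_size, dtype, device)
    # recv_from_prev_rank stands in for vLLM's real communication; it is
    # assumed to fill buffers["hidden_states"] and buffers["residual"].
    recv_from_prev_rank(buffers)
    return model(input_ids, positions, kv_caches, attn_metadata, buffers)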
@@ -284,6 +312,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip layers on other devices.
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -301,6 +332,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                         "Only text inputs are allowed. Images won't be handled "
                         "until Qwen-VL models are fully supported.")
                     continue
+                # Skip layers on other devices.
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
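Both load_weights paths gain the same guard: is_pp_missing_parameter lets each rank skip checkpoint weights that belong to layers it does not host, so every rank loads only its own slice. A simplified sketch of the check, assuming non-local layers were replaced by a placeholder module as in the make_layers sketch above:

import torch.nn as nn

class PPMissingLayer(nn.Module):
    """Placeholder for a layer hosted on another pipeline rank."""

def is_pp_missing_parameter_sketch(name: str, model: nn.Module) -> bool:
    """True if `name` falls under a placeholder (non-local) layer."""
    for module_name, module in model.named_modules():
        if isinstance(module, PPMissingLayer) and name.startswith(module_name):
            return True
    return False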