175 changes: 70 additions & 105 deletions vllm/model_executor/layers/linear.py

Large diffs are not rendered by default.
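The linear.py diff itself is collapsed above, so the sketch below is only an inference from the call sites in this PR: with `disable_tp=True`, a column-parallel layer keeps its full output dimension on every rank instead of sharding it, which is what lets callers drop the dedicated MergedReplicatedLinear path. `column_shard_size` is an illustrative helper, not a vLLM API.

```python
# Minimal sketch (an assumption, not the collapsed linear.py implementation).

def column_shard_size(output_size: int, tp_size: int, disable_tp: bool) -> int:
    """Output features one rank of a column-parallel layer would hold."""
    effective_tp_size = 1 if disable_tp else tp_size
    assert output_size % effective_tp_size == 0
    return output_size // effective_tp_size

assert column_shard_size(4096, 8, disable_tp=False) == 512   # sharded across TP ranks
assert column_shard_size(4096, 8, disable_tp=True) == 4096   # replicated on every rank
```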

23 changes: 19 additions & 4 deletions vllm/model_executor/model_loader/bitsandbytes_loader.py
@@ -69,6 +69,7 @@ def __init__(self, load_config: LoadConfig):
# Store all module names (from transformers) that support
# BNB quantization.
self.target_modules: list[str] = []
self.tp_disabled_modules: list[str] = []
# Store the mapping of expert parameters for MoE models.
self.expert_params_mapping: list[tuple[str, str, int, str]] = []
# mapping weight names from transformers to vllm.
@@ -322,14 +323,24 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
quant_state_dict) -> Generator:
from bitsandbytes.functional import quantize_4bit

tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
global_tp_size = get_tensor_model_parallel_world_size()
global_tp_rank = get_tensor_model_parallel_rank()

for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in self._hf_weight_iter(hf_weights_files, use_safetensors):

# override tp_size and tp_rank if the module has disabled TP
if any(tp_disabled_module in mapped_weight_name
for tp_disabled_module in self.tp_disabled_modules):
tp_size = 1
tp_rank = 0
else:
tp_size = global_tp_size
tp_rank = global_tp_rank

if any(target_module in mapped_weight_name
for target_module in self.target_modules
) and mapped_weight_name.endswith(".weight"):
@@ -418,12 +429,16 @@ def _get_bnb_target_modules(self, model: nn.Module) -> None:
# Map vllm's names to transformers's names.
rep_name, sub_modules = modules_info
for sub_name in sub_modules:
self.target_modules.append(
name.replace(rep_name, sub_name))
new_name = name.replace(rep_name, sub_name)
self.target_modules.append(new_name)
if module.disable_tp:
self.tp_disabled_modules.append(new_name)
# Add original module name even if the module has stacked map,
# in case model has a mixture of disk-merged and disk-split
# weights with same last name.
self.target_modules.append(name)
if module.disable_tp:
self.tp_disabled_modules.append(name)
elif isinstance(module, FusedMoE) and hasattr(
module.quant_method, "quant_config"):
# TODO: support FusedMoE with prequant and 8bit.
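A standalone sketch of the override added above, assuming the substring matching shown in the diff: any weight whose mapped name contains an entry of `tp_disabled_modules` is quantized and loaded with `tp_size=1`, `tp_rank=0`, i.e. unsharded. `effective_tp` is a hypothetical helper name, not part of the loader.

```python
# Mirrors the per-weight TP override in _unquantized_generator.

def effective_tp(mapped_weight_name: str,
                 tp_disabled_modules: list[str],
                 global_tp_size: int,
                 global_tp_rank: int) -> tuple[int, int]:
    if any(module in mapped_weight_name for module in tp_disabled_modules):
        # TP is disabled for this module: every rank loads the full weight.
        return 1, 0
    return global_tp_size, global_tp_rank

# A weight belonging to a TP-disabled module stays unsharded...
assert effective_tp("model.layers.0.self_attn.fused_qkv_a_proj.weight",
                    ["fused_qkv_a_proj"], 8, 3) == (1, 0)
# ...while everything else keeps the global TP layout.
assert effective_tp("model.layers.0.mlp.gate_up_proj.weight",
                    ["fused_qkv_a_proj"], 8, 3) == (8, 3)
```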
6 changes: 3 additions & 3 deletions vllm/model_executor/models/deepseek_v2.py
@@ -43,7 +43,6 @@
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
MergedReplicatedLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -435,12 +434,13 @@ def __init__(
self.max_position_embeddings = max_position_embeddings

if self.q_lora_rank is not None:
self.fused_qkv_a_proj = MergedReplicatedLinear(
self.fused_qkv_a_proj = MergedColumnParallelLinear(
self.hidden_size,
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.fused_qkv_a_proj")
prefix=f"{prefix}.fused_qkv_a_proj",
disable_tp=True)
else:
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
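A quick shape check of what `disable_tp=True` means for the fused projection above. The dimensions are DeepSeek-V2's published MLA values and are only illustrative: every TP rank keeps the full fused output width, matching the behaviour of the removed MergedReplicatedLinear.

```python
# Worked example, assuming DeepSeek-V2's MLA dimensions.
q_lora_rank = 1536
kv_lora_rank = 512
qk_rope_head_dim = 64

output_sizes = [q_lora_rank, kv_lora_rank + qk_rope_head_dim]  # [1536, 576]
fused_output = sum(output_sizes)                               # 2112

tp_size, disable_tp = 8, True
per_rank_output = fused_output // (1 if disable_tp else tp_size)
assert per_rank_output == 2112  # full width on every rank, not 2112 // 8
```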
128 changes: 50 additions & 78 deletions vllm/model_executor/models/glm4_1v.py
@@ -51,14 +51,10 @@
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.layernorm import RMSNorm
# yapf: disable
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
MergedReplicatedLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
# yapf: enable
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -174,20 +170,22 @@ def __init__(
use_data_parallel: bool = False,
):
super().__init__()
cls_gate_up = (MergedReplicatedLinear
if use_data_parallel else MergedColumnParallelLinear)
self.gate_up_proj = cls_gate_up(input_size=in_features,
output_sizes=[hidden_features] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
cls_down = (ReplicatedLinear
if use_data_parallel else RowParallelLinear)
self.down_proj = cls_down(hidden_features,
in_features,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj")
self.gate_up_proj = MergedColumnParallelLinear(
input_size=in_features,
output_sizes=[hidden_features] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
disable_tp=use_data_parallel,
)
self.down_proj = RowParallelLinear(
hidden_features,
in_features,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
disable_tp=use_data_parallel,
)
self.act_fn = SiluAndMul()

def forward(self, x: torch.Tensor):
@@ -234,48 +232,32 @@ def __init__(
# Per attention head and per partition values.
self.tp_size = (1 if use_data_parallel else
get_tensor_model_parallel_world_size())
self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
self.tp_rank = (0 if use_data_parallel else
parallel_state.get_tensor_model_parallel_rank())
self.hidden_size_per_attention_head = dist_utils.divide(
projection_size, num_heads)
self.num_attention_heads_per_partition = dist_utils.divide(
num_heads, self.tp_size)

if use_data_parallel:
self.qkv = ReplicatedLinear(
input_size=embed_dim,
output_size=3 * projection_size,
bias=False,
quant_config=quant_config,
# Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
prefix=f"{prefix}.qkv_proj"
if quant_config else f"{prefix}.qkv",
)
self.proj = ReplicatedLinear(
input_size=projection_size,
output_size=embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.proj",
bias=False,
)
else:
self.qkv = QKVParallelLinear(
hidden_size=embed_dim,
head_size=self.hidden_size_per_attention_head,
total_num_heads=num_heads,
total_num_kv_heads=num_heads,
bias=False,
quant_config=quant_config,
# Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
prefix=f"{prefix}.qkv_proj"
if quant_config else f"{prefix}.qkv",
)
self.proj = RowParallelLinear(
input_size=projection_size,
output_size=embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.proj",
bias=False,
)
self.qkv = QKVParallelLinear(
hidden_size=embed_dim,
head_size=self.hidden_size_per_attention_head,
total_num_heads=num_heads,
total_num_kv_heads=num_heads,
bias=False,
quant_config=quant_config,
# Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
prefix=f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv",
disable_tp=use_data_parallel,
)
self.proj = RowParallelLinear(
input_size=projection_size,
output_size=embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.proj",
bias=False,
disable_tp=use_data_parallel,
)

# Detect attention implementation.
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
@@ -494,41 +476,31 @@ def __init__(
) -> None:
super().__init__()
self.hidden_size = d_model
if use_data_parallel:
self.proj = ReplicatedLinear(
input_size=self.hidden_size,
output_size=self.hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.proj",
)
else:
self.proj = ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
bias=bias,
gather_output=True,
quant_config=quant_config,
prefix=f"{prefix}.proj",
)
self.proj = ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
bias=bias,
gather_output=True,
quant_config=quant_config,
prefix=f"{prefix}.proj",
disable_tp=use_data_parallel,
)
self.post_projection_norm = nn.LayerNorm(self.hidden_size)
cls_gate_up = (MergedReplicatedLinear
if use_data_parallel else MergedColumnParallelLinear)
self.gate_up_proj = cls_gate_up(
self.gate_up_proj = MergedColumnParallelLinear(
input_size=self.hidden_size,
output_sizes=[context_dim] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
disable_tp=use_data_parallel,
)
cls_down = (ReplicatedLinear
if use_data_parallel else RowParallelLinear)
self.down_proj = cls_down(
self.down_proj = RowParallelLinear(
context_dim,
self.hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
disable_tp=use_data_parallel,
)
self.act_fn = SiluAndMul()
self.extra_activation_func = nn.GELU()
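Note that the attention change above also guards `tp_rank` under data parallelism, where previously only `tp_size` was guarded. The pure-Python sketch below mirrors that bookkeeping; `vision_attn_partition` is an illustrative helper and the head count is just an example, not a specific model config.

```python
# Sketch of the head partitioning the vision attention above relies on.

def vision_attn_partition(num_heads: int, world_size: int, world_rank: int,
                          use_data_parallel: bool) -> tuple[int, int, int]:
    tp_size = 1 if use_data_parallel else world_size
    tp_rank = 0 if use_data_parallel else world_rank
    assert num_heads % tp_size == 0
    heads_per_partition = num_heads // tp_size
    return tp_size, tp_rank, heads_per_partition

# A 16-head vision tower on 4 GPUs, seen from rank 2:
assert vision_attn_partition(16, 4, 2, use_data_parallel=False) == (4, 2, 4)
assert vision_attn_partition(16, 4, 2, use_data_parallel=True) == (1, 0, 16)
```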
62 changes: 25 additions & 37 deletions vllm/model_executor/models/qwen2_5_vl.py
@@ -48,7 +48,6 @@
# yapf: disable
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
MergedReplicatedLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
@@ -178,22 +177,20 @@ def __init__(self,
prefix: str = "",
use_data_parallel: bool = False):
super().__init__()
cls_gate_up_proj = (MergedReplicatedLinear if use_data_parallel else
MergedColumnParallelLinear)
self.gate_up_proj = cls_gate_up_proj(
self.gate_up_proj = MergedColumnParallelLinear(
input_size=in_features,
output_sizes=[hidden_features] * 2, # [gate_proj, up_proj]
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")

cls_down_proj = (ReplicatedLinear
if use_data_parallel else RowParallelLinear)
self.down_proj = cls_down_proj(hidden_features,
in_features,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj")
prefix=f"{prefix}.gate_up_proj",
disable_tp=use_data_parallel)

self.down_proj = RowParallelLinear(hidden_features,
in_features,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
disable_tp=use_data_parallel)
self.act_fn = act_fn

def forward(self, x: torch.Tensor):
@@ -243,30 +240,21 @@ def __init__(
self.num_attention_heads_per_partition = dist_utils.divide(
num_heads, self.tp_size)

if use_data_parallel:
self.qkv = ReplicatedLinear(embed_dim,
self.hidden_size_per_attention_head *
3 * num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv")

else:
self.qkv = QKVParallelLinear(
hidden_size=embed_dim,
head_size=self.hidden_size_per_attention_head,
total_num_heads=num_heads,
total_num_kv_heads=num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv")

cls_proj = (ReplicatedLinear
if use_data_parallel else RowParallelLinear)
self.proj = cls_proj(input_size=projection_size,
output_size=embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.proj")
self.qkv = QKVParallelLinear(
hidden_size=embed_dim,
head_size=self.hidden_size_per_attention_head,
total_num_heads=num_heads,
total_num_kv_heads=num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv",
disable_tp=use_data_parallel)

self.proj = RowParallelLinear(input_size=projection_size,
output_size=embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.proj",
disable_tp=use_data_parallel)

# Detect attention implementation.
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
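One sanity check on the consolidation above: the removed data-parallel branch sized its ReplicatedLinear as `hidden_size_per_attention_head * 3 * num_heads`, while QKVParallelLinear sizes the fused weight as `(num_heads + 2 * num_kv_heads) * head_size`. For a ViT with as many KV heads as query heads the two agree, so `QKVParallelLinear(..., disable_tp=use_data_parallel)` is a drop-in replacement. The concrete numbers below are only an example.

```python
head_dim = 80
num_heads = 16
num_kv_heads = num_heads  # ViT-style attention has no grouped KV heads

old_dp_output_size = head_dim * 3 * num_heads                       # old ReplicatedLinear sizing
fused_qkv_output_size = (num_heads + 2 * num_kv_heads) * head_dim   # QKVParallelLinear sizing
assert old_dp_output_size == fused_qkv_output_size == 3840
```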