Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 20 additions & 65 deletions vllm/model_executor/models/idefics2_vision_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
Expand Down Expand Up @@ -139,37 +138,22 @@ def __init__(
assert self.num_heads % tp_size == 0
self.num_heads_per_partition = self.num_heads // tp_size

if use_data_parallel:
self.q_size = self.num_heads * self.head_dim
self.qkv_proj = ReplicatedLinear(
self.embed_dim,
3 * self.q_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.out_proj = ReplicatedLinear(
self.embed_dim,
self.embed_dim,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
else:
self.qkv_proj = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.num_heads,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.out_proj = RowParallelLinear(
self.embed_dim,
self.embed_dim,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.qkv_proj = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.num_heads,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
disable_tp=use_data_parallel,
)
self.out_proj = RowParallelLinear(
self.embed_dim,
self.embed_dim,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
disable_tp=use_data_parallel,
)
# Use unified MultiHeadAttention with Flash Attention support
self.attn = MultiHeadAttention(self.num_heads_per_partition,
self.head_dim, self.scale)
Expand Down Expand Up @@ -201,23 +185,21 @@ def __init__(
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
cls_fc1 = (ReplicatedLinear
if use_data_parallel else ColumnParallelLinear)
self.fc1 = cls_fc1(
self.fc1 = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
disable_tp=use_data_parallel,
)
cls_fc2 = (ReplicatedLinear
if use_data_parallel else RowParallelLinear)
self.fc2 = cls_fc2(
self.fc2 = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
disable_tp=use_data_parallel,
)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -389,30 +371,6 @@ def forward(
last_hidden_state = self.post_layernorm(encoder_outputs)
return last_hidden_state

def _consolidate_qkv_weights(
self, weights: Iterable[tuple[str, torch.Tensor]]
) -> Iterable[tuple[str, torch.Tensor]]:
qkv_idx_mappings = {
".self_attn.q_proj": 0,
".self_attn.k_proj": 1,
".self_attn.v_proj": 2,
}
qkv_weights = {}
for name, loaded_weight in weights:
for weight_name, idx in qkv_idx_mappings.items():
if weight_name not in name:
continue
new_name = name.replace(weight_name, ".self_attn.qkv_proj")
if new_name not in qkv_weights:
qkv_weights[new_name] = [None] * 3
qkv_weights[new_name][idx] = loaded_weight
break
else:
yield name, loaded_weight
for key, weight in qkv_weights.items():
qkv_weight = torch.cat(weight, dim=0)
yield key, qkv_weight

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
Expand All @@ -425,9 +383,6 @@ def load_weights(self, weights: Iterable[tuple[str,
loaded_params: set[str] = set()
layer_count = len(self.encoder.layers)

if self.use_data_parallel:
weights = self._consolidate_qkv_weights(weights)

for name, loaded_weight in weights:
# skip pooling header
if name.startswith("head."):
Expand Down
32 changes: 13 additions & 19 deletions vllm/model_executor/models/mllama4.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,22 +106,21 @@ def __init__(
use_data_parallel: bool = False,
):
super().__init__()
cls_fc1 = (ReplicatedLinear
if use_data_parallel else ColumnParallelLinear)
self.fc1 = cls_fc1(
self.fc1 = ColumnParallelLinear(
input_size=input_size,
output_size=intermediate_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
disable_tp=use_data_parallel,
)
cls_fc2 = ReplicatedLinear if use_data_parallel else RowParallelLinear
self.fc2 = cls_fc2(
self.fc2 = RowParallelLinear(
input_size=intermediate_size,
output_size=output_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
disable_tp=use_data_parallel,
)
self.activation_fn = nn.GELU()
self.output_activation = output_activation
Expand Down Expand Up @@ -419,20 +418,15 @@ def __init__(
kernel_size = (kernel_size, kernel_size)
self.unfold = torch.nn.Unfold(kernel_size=kernel_size,
stride=config.patch_size)
params = {
"input_size":
config.num_channels * kernel_size[0] * kernel_size[1],
"output_size": config.hidden_size,
"bias": False,
"quant_config": quant_config,
"prefix": f"{prefix}.linear",
}
if use_data_parallel:
cls = ReplicatedLinear
else:
cls = ColumnParallelLinear
params["gather_output"] = True
self.linear = cls(**params)
self.linear = ColumnParallelLinear(
input_size=config.num_channels * kernel_size[0] * kernel_size[1],
output_size=config.hidden_size,
bias=False,
gather_output=True,
quant_config=quant_config,
prefix=f"{prefix}.linear",
disable_tp=use_data_parallel,
)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.unfold(hidden_states)
Expand Down
43 changes: 21 additions & 22 deletions vllm/model_executor/models/qwen2_5_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
# yapf: enable
from vllm.model_executor.layers.quantization import QuantizationConfig
Expand Down Expand Up @@ -510,32 +509,32 @@ def __init__(
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.ln_q = norm_layer(context_dim)

cls_fc1 = (ReplicatedLinear
if use_data_parallel else ColumnParallelLinear)
cls_fc2 = (ReplicatedLinear
if use_data_parallel else RowParallelLinear)
self.mlp = nn.ModuleList([
cls_fc1(self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp.0"),
self.mlp = nn.Sequential(
ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp.0",
return_bias=False,
disable_tp=use_data_parallel,
),
nn.GELU(),
cls_fc2(self.hidden_size,
d_model,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp.2"),
])
RowParallelLinear(
self.hidden_size,
d_model,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp.2",
return_bias=False,
disable_tp=use_data_parallel,
),
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.ln_q(x)
x = x.view(-1, self.hidden_size)

mlp_fc1, mlp_act, mlp_fc2 = self.mlp
x_parallel, _ = mlp_fc1(x)
x_parallel = mlp_act(x_parallel)
out, _ = mlp_fc2(x_parallel)
out = self.mlp(x)
return out


Expand Down