Commit 5a86416

[VLM] Add PP support and fix GPTQ inference for Ovis models (vllm-project#18958)
Signed-off-by: isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Parent: f49239c

5 files changed: 145 additions, 91 deletions

docs/models/supported_models.md

Lines changed: 1 addition & 1 deletion
@@ -538,7 +538,7 @@ Specified using `--task generate`.
 | `MllamaForConditionalGeneration` | Llama 3.2 | T + I<sup>+</sup> | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | |
 | `MolmoForCausalLM` | Molmo | T + I<sup>+</sup> | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `NVLM_D_Model` | NVLM-D 1.0 | T + I<sup>+</sup> | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | ✅︎ |
-| `Ovis` | Ovis2, Ovis1.6 | T + I<sup>+</sup> | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | | ✅︎ |
+| `Ovis` | Ovis2, Ovis1.6 | T + I<sup>+</sup> | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | ✅︎ |
 | `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + I<sup>E</sup> | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | ⚠️ |
 | `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + I<sup>E+</sup> | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | ✅︎ |
 | `Phi4MMForCausalLM` | Phi-4-multimodal | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
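
With the pipeline-parallel column now checked for `Ovis`, these checkpoints can be split across pipeline stages like the other entries in this table. A minimal offline-inference sketch, not part of this commit; the checkpoint name, PP degree, and `trust_remote_code` flag are illustrative assumptions:

# Sketch: load an Ovis2 checkpoint with pipeline parallelism enabled.
# Needs as many GPUs as pipeline_parallel_size; a GPTQ-quantized Ovis
# checkpoint should also load now that quantized inference is fixed.
from vllm import LLM, SamplingParams

llm = LLM(
    model="AIDC-AI/Ovis2-1B",      # any supported Ovis2 / Ovis1.6 checkpoint
    trust_remote_code=True,         # Ovis ships custom config/processor code
    pipeline_parallel_size=2,       # split the decoder across two stages
)

outputs = llm.generate(
    ["Describe what pipeline parallelism does in one sentence."],
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)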

tests/distributed/test_pipeline_parallel.py

Lines changed: 1 addition & 0 deletions
@@ -227,6 +227,7 @@ def iter_params(self, model_id: str):
     "llava-hf/llava-onevision-qwen2-0.5b-ov-hf": PPTestSettings.fast(),
     "openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(),
     "allenai/Molmo-7B-D-0924": PPTestSettings.fast(),
+    "AIDC-AI/Ovis2-1B": PPTestSettings.fast(),
     "microsoft/Phi-3.5-vision-instruct": PPTestSettings.fast(),
     "mistralai/Pixtral-12B-2409": PPTestSettings.fast(load_format="dummy"),
     "Qwen/Qwen-VL-Chat": PPTestSettings.fast(),

vllm/model_executor/models/aimv2.py

Lines changed: 120 additions & 74 deletions
@@ -2,16 +2,23 @@
 
 # A modified implementation of the AIMv2 Transformer
 # inserted here also the image tokenizer used by Ovis2
+from collections.abc import Iterable
 from typing import Optional
 
 import torch
 import torch.nn as nn
-from torch.nn import functional as F
 
+from vllm.attention.layer import MultiHeadAttention
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.utils import divide
+from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs.ovis import AIMv2Config
 
 
@@ -24,29 +31,27 @@ def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
         in_features = config.hidden_size
         bias = config.use_bias
 
-        # TODO(Isotr0py): investigate if we can add TP to visual tokenizer
-        self.fc1 = ReplicatedLinear(in_features,
-                                    hidden_features,
-                                    bias=bias,
-                                    quant_config=quant_config,
-                                    prefix=f"{prefix}.fc1")
-        self.fc2 = ReplicatedLinear(hidden_features,
-                                    in_features,
-                                    bias=bias,
-                                    quant_config=quant_config,
-                                    prefix=f"{prefix}.fc2")
-        self.fc3 = ReplicatedLinear(in_features,
-                                    hidden_features,
-                                    bias=bias,
-                                    quant_config=quant_config,
-                                    prefix=f"{prefix}.fc3")
+        self.fc13 = MergedColumnParallelLinear(
+            in_features,
+            [hidden_features] * 2,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.fc13",
+        )
+        self.fc2 = RowParallelLinear(
+            input_size=hidden_features,
+            output_size=in_features,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.fc2",
+        )
+        self.act_fn = SiluAndMul()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x_parallel, _ = self.fc1(x)
-        gate, _ = self.fc3(x)
-        x_parallel = F.silu(x_parallel) * gate
-        out, _ = self.fc2(x_parallel)
-        return out
+        x, _ = self.fc13(x)
+        x = self.act_fn(x)
+        x, _ = self.fc2(x)
+        return x
 
 
 class AIMv2PatchEmbed(nn.Module):
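
The hunk above folds the separate `fc1`/`fc3` projections into one `MergedColumnParallelLinear` and replaces the hand-written `F.silu(...) * gate` with vLLM's `SiluAndMul`, which splits its input in half and returns `silu(first_half) * second_half`. A standalone, plain-PyTorch sketch of why the two formulations agree (weights and shapes are made up for illustration):

# Plain-PyTorch sketch: the merged fc13 matmul plus SiluAndMul-style gating
# reproduces the old fc1/fc3 path, because shard 0 of the merged weight is
# fc1 and shard 1 is fc3.
import torch
import torch.nn.functional as F

x = torch.randn(2, 16)
w1 = torch.randn(32, 16)   # stands in for fc1.weight
w3 = torch.randn(32, 16)   # stands in for fc3.weight

old = F.silu(x @ w1.T) * (x @ w3.T)         # original fc1/fc3 gating

merged = x @ torch.cat([w1, w3], dim=0).T    # single fc13 matmul
h1, h3 = merged.chunk(2, dim=-1)             # the split SiluAndMul performs
new = F.silu(h1) * h3

assert torch.allclose(old, new, atol=1e-6)

Merging the two column-parallel projections also keeps the quantized (e.g. GPTQ) weight shards in the layout that vLLM's fused linear layers expect, which is what the weight-loading logic later in this file relies on.
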
@@ -90,39 +95,45 @@ class AIMv2Attention(nn.Module):
     def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
                  prefix: str):
         super().__init__()
-        dim = config.hidden_size
-
-        # TODO(Isotr0py): investigate if we can add TP to visual tokenizer
+        self.config = config
+        self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
-        self.qkv = ReplicatedLinear(dim, dim * 3, bias=config.qkv_bias)
-        # self.qkv = QKVParallelLinear(
-        #     hidden_size=dim,
-        #     head_size=dim // config.num_attention_heads,
-        #     total_num_heads=config.num_attention_heads,
-        #     bias=config.qkv_bias,
-        #     quant_config=quant_config,
-        #     prefix=f"{prefix}.qkv")
-        self.proj = ReplicatedLinear(dim, dim, bias=config.use_bias)
-        # self.proj = RowParallelLinear(input_size=dim,
-        #                               output_size=dim,
-        #                               bias = config.use_bias,
-        #                               quant_config=quant_config,
-        #                               prefix=f"{prefix}.proj")
-
-    def forward(  # todo might implement multiple attn implementations
-            self,
-            x: torch.Tensor,
-            mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        B, N, C = x.shape
-        qkv, _ = self.qkv(x)
+        self.head_dim = self.embed_dim // self.num_heads
+        if self.head_dim * self.num_heads != self.embed_dim:
+            raise ValueError(
+                "embed_dim must be divisible by num_heads "
+                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
+                f" {self.num_heads}).")
+        self.scale = self.head_dim**-0.5
+
+        self.qkv = QKVParallelLinear(
+            hidden_size=self.embed_dim,
+            head_size=self.head_dim,
+            total_num_heads=self.num_heads,
+            bias=config.qkv_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv",
+        )
+
+        self.proj = RowParallelLinear(
+            input_size=self.embed_dim,
+            output_size=self.embed_dim,
+            bias=config.use_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.proj",
+        )
+
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
 
-        qkv = qkv.reshape(B, N, 3, self.num_heads,
-                          C // self.num_heads).permute(2, 0, 3, 1, 4)
+        self.attn = MultiHeadAttention(self.num_heads_per_partition,
+                                       self.head_dim, self.scale)
 
-        q, k, v = qkv.unbind(0)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        qkv, _ = self.qkv(x)
+        q, k, v = qkv.chunk(3, dim=-1)
 
-        x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
-        x = x.transpose(1, 2).contiguous().reshape(B, N, C)
+        x = self.attn(q, k, v)
         x, _ = self.proj(x)
         return x
 
@@ -141,37 +152,40 @@ def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
                                   prefix=f"{prefix}.mlp")
         self.norm_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
-    def forward(self,
-                x: torch.Tensor,
-                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        x = x + self.attn(self.norm_1.forward_native(x), mask)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = x + self.attn(self.norm_1.forward_native(x))
         x = x + self.mlp(self.norm_2.forward_native(x))
         return x
 
 
 class AIMv2Transformer(nn.Module):
 
-    def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
-                 prefix: str):
+    def __init__(
+        self,
+        config: AIMv2Config,
+        quant_config: QuantizationConfig,
+        *,
+        require_post_norm: Optional[bool] = None,
+        prefix: str = "",
+    ):
         super().__init__()
 
         self.blocks = nn.ModuleList([
             AIMv2Block(config, quant_config, prefix=f"{prefix}.blocks.{i}")
             for i in range(config.num_hidden_layers)
         ])
-        self.post_trunk_norm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
+        if require_post_norm:
+            self.post_trunk_norm = RMSNorm(config.hidden_size,
+                                           eps=config.rms_norm_eps)
+        else:
+            self.post_trunk_norm = None
 
-    def forward(
-        self,
-        tokens: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
         # they take the -1 as the ref embeddings, like a clip skip
         for block in self.blocks:
-            tokens = block(tokens, mask)
-        # NO NORM IN THE OG IMPLEMENTATION
-        # tokens = self.post_trunk_norm(tokens)
+            tokens = block(tokens)
+        if self.post_trunk_norm is not None:
+            tokens = self.post_trunk_norm(tokens)
         return tokens
 
 
@@ -180,20 +194,52 @@ class AIMv2Model(torch.nn.Module):
     def __init__(self,
                  config: AIMv2Config,
                  quant_config: QuantizationConfig,
+                 *,
+                 require_post_norm: Optional[bool] = None,
                  prefix: str = ""):
         super().__init__()
         self.preprocessor = AIMv2ViTPreprocessor(config)
         self.trunk = AIMv2Transformer(config,
                                       quant_config=quant_config,
+                                      require_post_norm=require_post_norm,
                                       prefix=f"{prefix}.trunk")
 
-    def forward(
-        self,
-        pixel_values: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
 
         x = self.preprocessor(pixel_values)
-        x = self.trunk(x, mask)
+        x = self.trunk(x)
 
         return x
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".fc13", ".fc1", 0),
+            (".fc13", ".fc3", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+
+        for name, loaded_weight in weights:
+            # post_layernorm is optional in SiglipVisionModel
+            if (name.startswith("trunk.post_trunk_norm")
+                    and self.trunk.post_trunk_norm is None):
+                continue
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
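
The new `load_weights` has to stitch the original checkpoint, which still contains separate `fc1`/`fc3` tensors, into the merged `fc13` parameter; `stacked_params_mapping` rewrites each checkpoint key and tags it with the shard index that the merged layer's weight loader expects. A small dry-run of just that renaming logic, with hypothetical key names:

# Sketch of the key rewriting done in load_weights() above; the checkpoint
# keys below are made up for illustration.
checkpoint_keys = [
    "trunk.blocks.0.mlp.fc1.weight",
    "trunk.blocks.0.mlp.fc3.weight",
    "trunk.blocks.0.mlp.fc2.weight",
]
stacked_params_mapping = [(".fc13", ".fc1", 0), (".fc13", ".fc3", 1)]

for name in checkpoint_keys:
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            print(name.replace(weight_name, param_name), "-> shard", shard_id)
            break
    else:
        print(name, "-> loaded as-is")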

vllm/model_executor/models/clip.py

Lines changed: 0 additions & 5 deletions
@@ -106,7 +106,6 @@ def __init__(
                 f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                 f" {self.num_heads}).")
         self.scale = self.head_dim**-0.5
-        self.dropout = config.attention_dropout
 
         self.qkv_proj = QKVParallelLinear(
             hidden_size=self.embed_dim,
@@ -129,10 +128,6 @@ def __init__(
         self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                        self.head_dim, self.scale)
 
-    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads,
-                           self.head_dim).transpose(1, 2).contiguous()
-
     def forward(
         self,
         hidden_states: torch.Tensor,
