docs/models/supported_models.md (1 addition, 1 deletion)
@@ -538,7 +538,7 @@ Specified using `--task generate`.
| `MllamaForConditionalGeneration` | Llama 3.2 | T + I<sup>+</sup> | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | |
| `MolmoForCausalLM` | Molmo | T + I<sup>+</sup> | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `NVLM_D_Model` | NVLM-D 1.0 | T + I<sup>+</sup> | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | ✅︎ |
| `Ovis` | Ovis2, Ovis1.6 | T + I<sup>+</sup> | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | | ✅︎ |
| `Ovis` | Ovis2, Ovis1.6 | T + I<sup>+</sup> | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | ✅︎ |
| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + I<sup>E</sup> | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | ⚠️ |
| `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + I<sup>E+</sup> | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | ✅︎ |
| `Phi4MMForCausalLM` | Phi-4-multimodal | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
tests/distributed/test_pipeline_parallel.py (1 addition, 0 deletions)
@@ -227,6 +227,7 @@ def iter_params(self, model_id: str):
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf": PPTestSettings.fast(),
"openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(),
"allenai/Molmo-7B-D-0924": PPTestSettings.fast(),
"AIDC-AI/Ovis2-1B": PPTestSettings.fast(),
"microsoft/Phi-3.5-vision-instruct": PPTestSettings.fast(),
"mistralai/Pixtral-12B-2409": PPTestSettings.fast(load_format="dummy"),
"Qwen/Qwen-VL-Chat": PPTestSettings.fast(),
vllm/model_executor/models/aimv2.py (120 additions, 74 deletions)
@@ -2,16 +2,23 @@

# A modified implementation of the AIMv2 Transformer,
# which also includes the image tokenizer used by Ovis2
from collections.abc import Iterable
from typing import Optional

import torch
import torch.nn as nn
from torch.nn import functional as F

from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.utils import divide
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.transformers_utils.configs.ovis import AIMv2Config


@@ -24,29 +31,27 @@ def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
in_features = config.hidden_size
bias = config.use_bias

# TODO(Isotr0py): investigate if we can add TP to visual tokenizer
self.fc1 = ReplicatedLinear(in_features,
hidden_features,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.fc1")
self.fc2 = ReplicatedLinear(hidden_features,
in_features,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.fc2")
self.fc3 = ReplicatedLinear(in_features,
hidden_features,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.fc3")
self.fc13 = MergedColumnParallelLinear(
in_features,
[hidden_features] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.fc13",
)
self.fc2 = RowParallelLinear(
input_size=hidden_features,
output_size=in_features,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
)
self.act_fn = SiluAndMul()

def forward(self, x: torch.Tensor) -> torch.Tensor:
x_parallel, _ = self.fc1(x)
gate, _ = self.fc3(x)
x_parallel = F.silu(x_parallel) * gate
out, _ = self.fc2(x_parallel)
return out
x, _ = self.fc13(x)
x = self.act_fn(x)
x, _ = self.fc2(x)
return x


class AIMv2PatchEmbed(nn.Module):
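
Note: the MLP rewrite above fuses the separate fc1 (up) and fc3 (gate) projections into a single column-parallel fc13 layer, with SiluAndMul applying the gating after the fused matmul. As a quick sanity check, here is a minimal plain-PyTorch sketch (made-up sizes, no vLLM layers or tensor parallelism) showing that the fused split-and-gate path matches the original F.silu(fc1(x)) * fc3(x) formulation:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
in_features, hidden_features = 8, 16
fc1 = torch.nn.Linear(in_features, hidden_features, bias=False)
fc3 = torch.nn.Linear(in_features, hidden_features, bias=False)
fc2 = torch.nn.Linear(hidden_features, in_features, bias=False)
x = torch.randn(2, in_features)

# Original formulation: silu(fc1(x)) * fc3(x), then the down projection.
ref = fc2(F.silu(fc1(x)) * fc3(x))

# Fused formulation: stack the fc1 and fc3 weights into one matrix
# (shard 0 / shard 1, matching the load_weights mapping later in this diff),
# run a single matmul, then split and gate -- which is what SiluAndMul does.
w13 = torch.cat([fc1.weight, fc3.weight], dim=0)   # [2 * hidden, in]
h = x @ w13.t()                                    # [batch, 2 * hidden]
h1, h3 = h.chunk(2, dim=-1)
out = fc2(F.silu(h1) * h3)

assert torch.allclose(ref, out, atol=1e-6)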
@@ -90,39 +95,45 @@ class AIMv2Attention(nn.Module):
def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
prefix: str):
super().__init__()
dim = config.hidden_size

# TODO(Isotr0py): investigate if we can add TP to visual tokenizer
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.qkv = ReplicatedLinear(dim, dim * 3, bias=config.qkv_bias)
# self.qkv = QKVParallelLinear(
# hidden_size=dim,
# head_size=dim // config.num_attention_heads,
# total_num_heads=config.num_attention_heads,
# bias=config.qkv_bias,
# quant_config=quant_config,
# prefix=f"{prefix}.qkv")
self.proj = ReplicatedLinear(dim, dim, bias=config.use_bias)
# self.proj = RowParallelLinear(input_size=dim,
# output_size=dim,
# bias = config.use_bias,
# quant_config=quant_config,
# prefix=f"{prefix}.proj")

def forward( # todo might implement multiple attn implementations
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None) -> torch.Tensor:
B, N, C = x.shape
qkv, _ = self.qkv(x)
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads}).")
self.scale = self.head_dim**-0.5

self.qkv = QKVParallelLinear(
hidden_size=self.embed_dim,
head_size=self.head_dim,
total_num_heads=self.num_heads,
bias=config.qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv",
)

self.proj = RowParallelLinear(
input_size=self.embed_dim,
output_size=self.embed_dim,
bias=config.use_bias,
quant_config=quant_config,
prefix=f"{prefix}.proj",
)

self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

qkv = qkv.reshape(B, N, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4)
self.attn = MultiHeadAttention(self.num_heads_per_partition,
self.head_dim, self.scale)

q, k, v = qkv.unbind(0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
qkv, _ = self.qkv(x)
q, k, v = qkv.chunk(3, dim=-1)

x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
x = x.transpose(1, 2).contiguous().reshape(B, N, C)
x = self.attn(q, k, v)
x, _ = self.proj(x)
return x
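
Note: with the QKVParallelLinear / MultiHeadAttention combination above, each tensor-parallel rank projects and attends over only num_heads / tp_size heads, and the per-head reshaping plus scaled-dot-product attention now happen inside vLLM's MultiHeadAttention rather than in this module. Below is a rough single-rank sketch of the computation it stands in for (plain PyTorch, hypothetical sizes; the real layer additionally handles the head sharding across ranks):

import torch
import torch.nn.functional as F

batch, seq_len, embed_dim, num_heads = 2, 5, 32, 4
head_dim = embed_dim // num_heads

qkv_proj = torch.nn.Linear(embed_dim, 3 * embed_dim, bias=True)
x = torch.randn(batch, seq_len, embed_dim)

qkv = qkv_proj(x)                # [B, N, 3 * embed_dim]
q, k, v = qkv.chunk(3, dim=-1)   # same split as the new forward() above

def split_heads(t):
    # [B, N, embed_dim] -> [B, num_heads, N, head_dim]
    return t.view(batch, seq_len, num_heads, head_dim).transpose(1, 2)

q, k, v = (split_heads(t) for t in (q, k, v))
# SDPA's default scaling is 1 / sqrt(head_dim), i.e. the same `scale`
# computed in __init__ above.
out = F.scaled_dot_product_attention(q, k, v)
out = out.transpose(1, 2).reshape(batch, seq_len, embed_dim)  # [B, N, C]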

@@ -141,37 +152,40 @@ def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
prefix=f"{prefix}.mlp")
self.norm_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def forward(self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None) -> torch.Tensor:
x = x + self.attn(self.norm_1.forward_native(x), mask)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.attn(self.norm_1.forward_native(x))
x = x + self.mlp(self.norm_2.forward_native(x))
return x


class AIMv2Transformer(nn.Module):

def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
prefix: str):
def __init__(
self,
config: AIMv2Config,
quant_config: QuantizationConfig,
*,
require_post_norm: Optional[bool] = None,
prefix: str = "",
):
super().__init__()

self.blocks = nn.ModuleList([
AIMv2Block(config, quant_config, prefix=f"{prefix}.blocks.{i}")
for i in range(config.num_hidden_layers)
])
self.post_trunk_norm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
if require_post_norm:
self.post_trunk_norm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
else:
self.post_trunk_norm = None

def forward(
self,
tokens: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
# they take the -1 as the ref embeddings, like a clip skip
for block in self.blocks:
tokens = block(tokens, mask)
# NO NORM IN THE OG IMPLEMENTATION
# tokens = self.post_trunk_norm(tokens)
tokens = block(tokens)
if self.post_trunk_norm is not None:
tokens = self.post_trunk_norm(tokens)
return tokens


@@ -180,20 +194,52 @@ class AIMv2Model(torch.nn.Module):
def __init__(self,
config: AIMv2Config,
quant_config: QuantizationConfig,
*,
require_post_norm: Optional[bool] = None,
prefix: str = ""):
super().__init__()
self.preprocessor = AIMv2ViTPreprocessor(config)
self.trunk = AIMv2Transformer(config,
quant_config=quant_config,
require_post_norm=require_post_norm,
prefix=f"{prefix}.trunk")

def forward(
self,
pixel_values: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:

x = self.preprocessor(pixel_values)
x = self.trunk(x, mask)
x = self.trunk(x)

return x

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".fc13", ".fc1", 0),
(".fc13", ".fc3", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()

for name, loaded_weight in weights:
# post_trunk_norm is optional in AIMv2Model
if (name.startswith("trunk.post_trunk_norm")
and self.trunk.post_trunk_norm is None):
continue

for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)

param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
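
Note: the stacked_params_mapping above follows the usual vLLM weight-loading pattern: checkpoint tensors still named fc1 / fc3 are renamed to the fused fc13 parameter and written into shard 0 and shard 1 respectively by that parameter's weight_loader. A hypothetical, TP-free sketch of what the stacking amounts to for one MLP (names and shapes invented for illustration):

import torch

hidden_size, intermediate = 8, 16
checkpoint = {
    "trunk.blocks.0.mlp.fc1.weight": torch.randn(intermediate, hidden_size),
    "trunk.blocks.0.mlp.fc3.weight": torch.randn(intermediate, hidden_size),
    "trunk.blocks.0.mlp.fc2.weight": torch.randn(hidden_size, intermediate),
}
params = {
    "trunk.blocks.0.mlp.fc13.weight": torch.empty(2 * intermediate, hidden_size),
    "trunk.blocks.0.mlp.fc2.weight": torch.empty(hidden_size, intermediate),
}
stacked_params_mapping = [(".fc13", ".fc1", 0), (".fc13", ".fc3", 1)]

for name, weight in checkpoint.items():
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        target = params[name.replace(weight_name, param_name)]
        # Shard 0 fills the first half of dim 0, shard 1 the second half;
        # with tensor parallelism each rank would copy only its own slice.
        rows = slice(shard_id * intermediate, (shard_id + 1) * intermediate)
        target[rows].copy_(weight)
        break
    else:
        params[name].copy_(weight)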
vllm/model_executor/models/clip.py (0 additions, 5 deletions)
@@ -106,7 +106,6 @@ def __init__(
f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads}).")
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout

self.qkv_proj = QKVParallelLinear(
hidden_size=self.embed_dim,
@@ -129,10 +128,6 @@ def __init__(
self.attn = MultiHeadAttention(self.num_heads_per_partition,
self.head_dim, self.scale)

def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
self.head_dim).transpose(1, 2).contiguous()

def forward(
self,
hidden_states: torch.Tensor,