[Bugfix][VLM] Add fallback to SDPA for ViT model running on CPU backend #8061

Merged: 12 commits, Sep 3, 2024
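The change is the same in each affected vision encoder (BLIP, CLIP, InternViT, SigLIP): guard the xformers import, use the tensor-parallel xformers attention only when the head count divides evenly across the TP ranks, and otherwise fall back to an SDPA (or eager) implementation that also works on the CPU backend. The sketch below illustrates that selection logic; the function name, the `tp_size` argument, and the toy shapes are illustrative stand-ins, not vLLM's actual API.

```python
# Minimal sketch of the fallback pattern this PR applies per ViT encoder layer.
# Illustrative only: the helper name, tp_size argument, and shapes are not
# vLLM's real API.
import torch
import torch.nn.functional as F

try:
    from xformers import ops as xops  # typically unavailable on the CPU backend
    USE_XFORMERS_OPS = True
except ImportError:
    USE_XFORMERS_OPS = False


def vit_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                  scale: float, tp_size: int = 1) -> torch.Tensor:
    """q, k, v have shape (batch, seq, heads, head_dim)."""
    num_heads = q.shape[2]
    if USE_XFORMERS_OPS and num_heads % tp_size == 0:
        # Tensor-parallel GPU path: xformers memory-efficient attention.
        return xops.memory_efficient_attention_forward(q, k, v, scale=scale)
    # Fallback path: PyTorch SDPA expects (batch, heads, seq, head_dim).
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), scale=scale)
    return out.transpose(1, 2)


if __name__ == "__main__":
    q = k = v = torch.randn(1, 16, 8, 64)
    print(vit_attention(q, k, v, scale=64 ** -0.5).shape)  # (1, 16, 8, 64)
```

On the fallback path the unsharded attention classes from `transformers` (e.g. `CLIPSdpaAttention`, `BlipAttention`) are reused, which is why the parallel variants in the diffs below now return an `(output, None)` tuple and the encoder layers unpack `hidden_states, _`.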
vllm/model_executor/models/blip.py (25 changes: 20 additions, 5 deletions)
@@ -7,7 +7,7 @@
import torch.nn as nn
from PIL import Image
from transformers import Blip2VisionConfig, BlipVisionConfig
from xformers import ops as xops
from transformers.models.blip.modeling_blip import BlipAttention

from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -21,6 +21,12 @@
repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False


def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0
@@ -156,7 +162,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
return embeddings


class BlipAttention(nn.Module):
class BlipParallelAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(
@@ -224,7 +230,7 @@ def forward(
out = out.view(bsz, tgt_len, -1)
attn_output, _ = self.projection(out)

return attn_output
return attn_output, None


class BlipMLP(nn.Module):
@@ -261,7 +267,7 @@ def __init__(self,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()

self.self_attn = BlipAttention(config, quant_config=quant_config)
# fallback to sdpa attention if tp unavailable
num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = BlipParallelAttention(config,
quant_config=quant_config)
else:
# Blip doesn't have SDPA attention implemented in transformers
# use eager attention instead for cpu backend
self.self_attn = BlipAttention(config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp = BlipMLP(config, quant_config=quant_config)
@@ -272,7 +287,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states

hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states)
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states

residual = hidden_states
vllm/model_executor/models/clip.py (28 changes: 22 additions, 6 deletions)
@@ -7,7 +7,7 @@
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig
from xformers import ops as xops
from transformers.models.clip.modeling_clip import CLIPSdpaAttention

from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -22,6 +22,12 @@
repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False


def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0
@@ -162,7 +168,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
return embeddings


class CLIPAttention(nn.Module):
class CLIPParallelAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(
@@ -231,7 +237,7 @@ def forward(
out = out.view(bsz, tgt_len, -1)
attn_output, _ = self.out_proj(out)

return attn_output
return attn_output, None


class CLIPMLP(nn.Module):
@@ -266,7 +272,13 @@ def __init__(self,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()

self.self_attn = CLIPAttention(config, quant_config=quant_config)
num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = CLIPParallelAttention(config,
quant_config=quant_config)
else:
self.self_attn = CLIPSdpaAttention(config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp = CLIPMLP(config, quant_config=quant_config)
@@ -278,7 +290,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states

hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states)
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states

residual = hidden_states
@@ -365,6 +377,10 @@ def __init__(self,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None):
super().__init__()
tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0

self.vision_model = CLIPVisionTransformer(
config=config,
quant_config=quant_config,
@@ -386,7 +402,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
] if self.shard_weight else []
params_dict = dict(self.named_parameters())
layer_count = len(self.vision_model.encoder.layers)

vllm/model_executor/models/intern_vit.py (79 changes: 70 additions, 9 deletions)
@@ -10,7 +10,6 @@
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from xformers import ops as xops

from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
@@ -21,6 +20,12 @@
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False

NORM2FN = {
'rms_norm': RMSNorm,
'layer_norm': nn.LayerNorm,
@@ -81,7 +86,7 @@ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
return embeddings


class InternAttention(nn.Module):
class InternParallelAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(
@@ -140,18 +145,67 @@ def forward(self, x):
k = self.k_norm.forward_native(k.flatten(-2,
-1)).view(B_, N_, H_, D_)

x = xops.memory_efficient_attention_forward(
q,
k,
v,
scale=self.scale,
)
x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale)
x = x.view(B, N, -1)

x, _ = self.proj(x)
return x


class InternSdpaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(self, config: PretrainedConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f'embed_dim must be divisible by num_heads '
f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
f' {self.num_heads}).')

self.scale = self.head_dim**-0.5
self.qkv = nn.Linear(self.embed_dim,
3 * self.embed_dim,
bias=config.qkv_bias)

self.qk_normalization = config.qk_normalization

if self.qk_normalization:
self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)

self.proj = nn.Linear(self.embed_dim, self.embed_dim)

def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x)
q, k, v = qkv.chunk(3, dim=-1)

q = q.view(B, N, self.num_heads, self.head_dim)
k = k.view(B, N, self.num_heads, self.head_dim)
v = v.view(B, N, self.num_heads, self.head_dim)

if self.qk_normalization:
B_, N_, H_, D_ = q.shape
q = self.q_norm.forward_native(q.flatten(-2,
-1)).view(B_, N_, H_, D_)
k = self.k_norm.forward_native(k.flatten(-2,
-1)).view(B_, N_, H_, D_)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
x = x.transpose(1, 2).view(B, N, -1)

x = self.proj(x)
return x


class InternMLP(nn.Module):

def __init__(self,
@@ -187,7 +241,14 @@ def __init__(self,
self.intermediate_size = config.intermediate_size
self.norm_type = config.norm_type

self.attn = InternAttention(config, quant_config=quant_config)
# fallback to sdpa attention if tp unavailable
tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.attn = InternParallelAttention(config,
quant_config=quant_config)
else:
self.attn = InternSdpaAttention(config)
self.mlp = InternMLP(config, quant_config=quant_config)
self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
eps=config.layer_norm_eps)
vllm/model_executor/models/paligemma.py (24 changes: 13 additions, 11 deletions)
@@ -307,17 +307,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
use_default_weight_loading = False
for (param_name, shard_name, shard_id) in stacked_params_mapping:
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
if "vision" not in name or self.vision_tower.shard_weight:
for (param_name, shard_name,
shard_id) in stacked_params_mapping:
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# lm_head is not used in vllm as it is tied with
# embed_token. To prevent errors, skip loading
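Because the fallback attention classes keep the checkpoint's separate q_proj/k_proj/v_proj weights, the loaders above only fold them into the fused qkv_proj when the vision tower actually shards its weights (`self.shard_weight` / `self.vision_tower.shard_weight`). A simplified, self-contained schematic of that gate is below; `remap_name` and the sample checkpoint name are illustrative, not vLLM functions.

```python
# Schematic of the weight-name gating shown above; remap_name is an
# illustrative helper, not a vLLM function.
from typing import Optional, Tuple


def remap_name(name: str, shard_weight: bool) -> Tuple[str, Optional[str]]:
    # (fused param, checkpoint shard, shard id), used only when sharding.
    stacked_params_mapping = [
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
    ] if shard_weight else []
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in name:
            return name.replace(shard_name, param_name), shard_id
    return name, None  # load the weight under its original name


print(remap_name("vision_tower.encoder.layers.0.self_attn.q_proj.weight", True))
print(remap_name("vision_tower.encoder.layers.0.self_attn.q_proj.weight", False))
```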
vllm/model_executor/models/siglip.py (27 changes: 22 additions, 5 deletions)
@@ -9,7 +9,7 @@
from PIL import Image
from torch import nn
from transformers import SiglipVisionConfig
from xformers import ops as xops
from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention

from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -26,6 +26,12 @@
repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False


def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
# Since interpolation is applied, the image size need not be divisible
@@ -219,7 +225,7 @@ def forward(self,
return embeddings


class SiglipAttention(nn.Module):
class SiglipParallelAttention(nn.Module):

def __init__(
self,
@@ -282,7 +288,7 @@ def forward(
out = out.view(batch_size, q_len, -1)
attn_output, _ = self.out_proj(out)

return attn_output
return attn_output, None


class SiglipMLP(nn.Module):
@@ -327,7 +333,14 @@ def __init__(
super().__init__()
self.embed_dim = config.hidden_size

self.self_attn = SiglipAttention(config, quant_config=quant_config)
num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = SiglipParallelAttention(config,
quant_config=quant_config)
else:
self.self_attn = SiglipSdpaAttention(config)

self.layer_norm1 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
self.mlp = SiglipMLP(
@@ -344,7 +357,7 @@ def forward(
residual = hidden_states

hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states)
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states

residual = hidden_states
@@ -476,6 +489,10 @@ def __init__(
num_hidden_layers_override: Optional[int] = None,
):
super().__init__()
num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0

self.vision_model = SiglipVisionTransformer(
config,
quant_config,