Skip to content

Commit

Permalink
[Bugfix][VLM] Add fallback to SDPA for ViT model running on CPU backe…
Browse files Browse the repository at this point in the history
  • Loading branch information
Isotr0py authored and Jeffwan committed Sep 19, 2024
1 parent 74e6769 commit 47e5dd4
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 44 deletions.
25 changes: 20 additions & 5 deletions vllm/model_executor/models/blip.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch.nn as nn
from PIL import Image
from transformers import Blip2VisionConfig, BlipVisionConfig
from xformers import ops as xops
from transformers.models.blip.modeling_blip import BlipAttention

from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
Expand All @@ -21,6 +21,12 @@
repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False


def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0
Expand Down Expand Up @@ -156,7 +162,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
return embeddings


class BlipAttention(nn.Module):
class BlipParallelAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(
Expand Down Expand Up @@ -224,7 +230,7 @@ def forward(
out = out.view(bsz, tgt_len, -1)
attn_output, _ = self.projection(out)

return attn_output
return attn_output, None


class BlipMLP(nn.Module):
Expand Down Expand Up @@ -261,7 +267,16 @@ def __init__(self,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()

self.self_attn = BlipAttention(config, quant_config=quant_config)
# fallback to sdpa attention if tp unavailable
num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = BlipParallelAttention(config,
quant_config=quant_config)
else:
# Blip doesn't have SDPA attention implemented in transformers
# use eager attention instead for cpu backend
self.self_attn = BlipAttention(config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp = BlipMLP(config, quant_config=quant_config)
Expand All @@ -272,7 +287,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states

hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states)
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states

residual = hidden_states
Expand Down
28 changes: 22 additions & 6 deletions vllm/model_executor/models/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig
from xformers import ops as xops
from transformers.models.clip.modeling_clip import CLIPSdpaAttention

from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
Expand All @@ -22,6 +22,12 @@
repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False


def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0
Expand Down Expand Up @@ -162,7 +168,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
return embeddings


class CLIPAttention(nn.Module):
class CLIPParallelAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(
Expand Down Expand Up @@ -231,7 +237,7 @@ def forward(
out = out.view(bsz, tgt_len, -1)
attn_output, _ = self.out_proj(out)

return attn_output
return attn_output, None


class CLIPMLP(nn.Module):
Expand Down Expand Up @@ -266,7 +272,13 @@ def __init__(self,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()

self.self_attn = CLIPAttention(config, quant_config=quant_config)
num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = CLIPParallelAttention(config,
quant_config=quant_config)
else:
self.self_attn = CLIPSdpaAttention(config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp = CLIPMLP(config, quant_config=quant_config)
Expand All @@ -278,7 +290,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states

hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states)
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states

residual = hidden_states
Expand Down Expand Up @@ -365,6 +377,10 @@ def __init__(self,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None):
super().__init__()
tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0

self.vision_model = CLIPVisionTransformer(
config=config,
quant_config=quant_config,
Expand All @@ -386,7 +402,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
] if self.shard_weight else []
params_dict = dict(self.named_parameters())
layer_count = len(self.vision_model.encoder.layers)

Expand Down
79 changes: 70 additions & 9 deletions vllm/model_executor/models/intern_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from xformers import ops as xops

from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
Expand All @@ -21,6 +20,12 @@
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False

NORM2FN = {
'rms_norm': RMSNorm,
'layer_norm': nn.LayerNorm,
Expand Down Expand Up @@ -81,7 +86,7 @@ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
return embeddings


class InternAttention(nn.Module):
class InternParallelAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(
Expand Down Expand Up @@ -140,18 +145,67 @@ def forward(self, x):
k = self.k_norm.forward_native(k.flatten(-2,
-1)).view(B_, N_, H_, D_)

x = xops.memory_efficient_attention_forward(
q,
k,
v,
scale=self.scale,
)
x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale)
x = x.view(B, N, -1)

x, _ = self.proj(x)
return x


class InternSdpaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(self, config: PretrainedConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f'embed_dim must be divisible by num_heads '
f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
f' {self.num_heads}).')

self.scale = self.head_dim**-0.5
self.qkv = nn.Linear(self.embed_dim,
3 * self.embed_dim,
bias=config.qkv_bias)

self.qk_normalization = config.qk_normalization

if self.qk_normalization:
self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)

self.proj = nn.Linear(self.embed_dim, self.embed_dim)

def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x)
q, k, v = qkv.chunk(3, dim=-1)

q = q.view(B, N, self.num_heads, self.head_dim)
k = k.view(B, N, self.num_heads, self.head_dim)
v = v.view(B, N, self.num_heads, self.head_dim)

if self.qk_normalization:
B_, N_, H_, D_ = q.shape
q = self.q_norm.forward_native(q.flatten(-2,
-1)).view(B_, N_, H_, D_)
k = self.k_norm.forward_native(k.flatten(-2,
-1)).view(B_, N_, H_, D_)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
x = x.transpose(1, 2).view(B, N, -1)

x = self.proj(x)
return x


class InternMLP(nn.Module):

def __init__(self,
Expand Down Expand Up @@ -187,7 +241,14 @@ def __init__(self,
self.intermediate_size = config.intermediate_size
self.norm_type = config.norm_type

self.attn = InternAttention(config, quant_config=quant_config)
# fallback to sdpa attention if tp unavailable
tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.attn = InternParallelAttention(config,
quant_config=quant_config)
else:
self.attn = InternSdpaAttention(config)
self.mlp = InternMLP(config, quant_config=quant_config)
self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
eps=config.layer_norm_eps)
Expand Down
42 changes: 23 additions & 19 deletions vllm/model_executor/models/paligemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,26 +307,30 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
use_default_weight_loading = False
for (param_name, shard_name, shard_id) in stacked_params_mapping:
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
if "vision" not in name or self.vision_tower.shard_weight:
for (param_name, shard_name,
shard_id) in stacked_params_mapping:
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# lm_head is not used in vllm as it is tied with
# embed_token. To prevent errors, skip loading
# lm_head.weight.
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
use_default_weight_loading = True
else:
# lm_head is not used in vllm as it is tied with
# embed_token. To prevent errors, skip loading
# lm_head.weight.
if "lm_head.weight" in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
use_default_weight_loading = True

if use_default_weight_loading:
Expand Down
27 changes: 22 additions & 5 deletions vllm/model_executor/models/siglip.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from PIL import Image
from torch import nn
from transformers import SiglipVisionConfig
from xformers import ops as xops
from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention

from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
Expand All @@ -26,6 +26,12 @@
repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False


def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
# Since interpolation is applied, the image size need not be divisible
Expand Down Expand Up @@ -219,7 +225,7 @@ def forward(self,
return embeddings


class SiglipAttention(nn.Module):
class SiglipParallelAttention(nn.Module):

def __init__(
self,
Expand Down Expand Up @@ -282,7 +288,7 @@ def forward(
out = out.view(batch_size, q_len, -1)
attn_output, _ = self.out_proj(out)

return attn_output
return attn_output, None


class SiglipMLP(nn.Module):
Expand Down Expand Up @@ -327,7 +333,14 @@ def __init__(
super().__init__()
self.embed_dim = config.hidden_size

self.self_attn = SiglipAttention(config, quant_config=quant_config)
num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = SiglipParallelAttention(config,
quant_config=quant_config)
else:
self.self_attn = SiglipSdpaAttention(config)

self.layer_norm1 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
self.mlp = SiglipMLP(
Expand All @@ -344,7 +357,7 @@ def forward(
residual = hidden_states

hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states)
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states

residual = hidden_states
Expand Down Expand Up @@ -476,6 +489,10 @@ def __init__(
num_hidden_layers_override: Optional[int] = None,
):
super().__init__()
num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0

self.vision_model = SiglipVisionTransformer(
config,
quant_config,
Expand Down

0 comments on commit 47e5dd4

Please sign in to comment.