
Commit 02ed8a1

[Misc] Qwen2.5-VL Optimization (#13155)
1 parent 2092a6f commit 02ed8a1

File tree

2 files changed: 47 additions & 51 deletions


vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 25 additions & 36 deletions
@@ -45,6 +45,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -271,8 +272,13 @@ def forward(
         q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
                    for x in (q, k, v))
         if rotary_pos_emb is not None:
-            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
-            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
+            use_flash_attn = self.attn_backend == _Backend.FLASH_ATTN
+            q = apply_rotary_pos_emb_vision(q,
+                                            rotary_pos_emb,
+                                            use_flash_attn=use_flash_attn)
+            k = apply_rotary_pos_emb_vision(k,
+                                            rotary_pos_emb,
+                                            use_flash_attn=use_flash_attn)

         if self.attn_backend == _Backend.FLASH_ATTN:
             # from vllm_flash_attn.flash_attn_interface import (
@@ -296,20 +302,23 @@ def forward(
                                       "(b s) ... -> b s ...",
                                       b=batch_size)
         elif self.attn_backend == _Backend.TORCH_SDPA:
-            seq_length = q.size(1)
-            q, k, v = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
-            attention_mask = torch.zeros([1, seq_length, seq_length],
-                                         device=q.device,
-                                         dtype=torch.bool)
+            # Execute attention entry by entry for speed & less VRAM.
+            outputs = []
             for i in range(1, len(cu_seqlens)):
-                attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
-                               cu_seqlens[i - 1]:cu_seqlens[i]] = True
-            output = F.scaled_dot_product_attention(q,
-                                                    k,
-                                                    v,
-                                                    attention_mask,
-                                                    dropout_p=0.0)
-            context_layer = rearrange(output, "b h s d -> b s h d ")
+                start_idx = cu_seqlens[i - 1]
+                end_idx = cu_seqlens[i]
+                q_i = q[:, start_idx:end_idx]
+                k_i = k[:, start_idx:end_idx]
+                v_i = v[:, start_idx:end_idx]
+                q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d")
+                                 for x in [q_i, k_i, v_i])
+                output_i = F.scaled_dot_product_attention(q_i,
+                                                          k_i,
+                                                          v_i,
+                                                          dropout_p=0.0)
+                output_i = rearrange(output_i, "b h s d -> b s h d ")
+                outputs.append(output_i)
+            context_layer = torch.cat(outputs, dim=1)
         elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
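
Note: the TORCH_SDPA path above now runs scaled_dot_product_attention once per packed sequence instead of building a full seq_len x seq_len block-diagonal boolean mask. A minimal standalone sketch of why the two are equivalent; the shapes and cu_seqlens values below are illustrative only, not taken from the model code.

# Per-sequence SDPA vs. one SDPA call with a block-diagonal mask.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_heads, head_dim = 2, 8
cu_seqlens = torch.tensor([0, 3, 7, 12])  # three sequences packed together
seq_len = int(cu_seqlens[-1])

# q, k, v packed as (batch=1, seq, heads, head_dim), as in the vision path.
q, k, v = (torch.randn(1, seq_len, num_heads, head_dim) for _ in range(3))

# Old approach: one call, masked so tokens only attend within their block.
mask = torch.zeros(1, seq_len, seq_len, dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    s, e = int(cu_seqlens[i - 1]), int(cu_seqlens[i])
    mask[..., s:e, s:e] = True
qh, kh, vh = (x.transpose(1, 2) for x in (q, k, v))  # (b, h, s, d)
ref = F.scaled_dot_product_attention(qh, kh, vh, mask).transpose(1, 2)

# New approach: slice each sequence and attend within it only.
outputs = []
for i in range(1, len(cu_seqlens)):
    s, e = int(cu_seqlens[i - 1]), int(cu_seqlens[i])
    q_i, k_i, v_i = (x[:, s:e].transpose(1, 2) for x in (q, k, v))
    out_i = F.scaled_dot_product_attention(q_i, k_i, v_i)
    outputs.append(out_i.transpose(1, 2))
new = torch.cat(outputs, dim=1)

# Same result, but the per-sequence version never materializes the
# (seq_len, seq_len) mask or the cross-sequence attention scores.
assert torch.allclose(ref, new, atol=1e-5)
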
@@ -327,25 +336,6 @@ def forward(
         return output


-class Qwen2RMSNorm(nn.Module):
-
-    def __init__(self, hidden_size, eps=1e-6):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance +
-                                                    self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
-
-    def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-
-
 class Qwen2_5_VisionBlock(nn.Module):

     def __init__(
@@ -516,8 +506,7 @@ def __init__(
             hidden_size=self.hidden_size,
         )

-        # NOTE: We use torch native RMSNorm here for precision purposes.
-        norm_layer = partial(Qwen2RMSNorm, eps=norm_eps)
+        norm_layer = partial(RMSNorm, eps=norm_eps)
         head_dim = self.hidden_size // self.num_heads
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
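
Note: the hand-written Qwen2RMSNorm (and its "precision purposes" comment) is dropped in favor of vLLM's shared RMSNorm layer. A quick CPU parity sketch against the removed reference formula; it assumes RMSNorm is constructed as RMSNorm(hidden_size, eps=...) with its default all-ones weight, as the partial(...) call above suggests.

# Compare the removed Qwen2RMSNorm math against vLLM's RMSNorm layer.
import torch
from vllm.model_executor.layers.layernorm import RMSNorm

hidden_size, eps = 64, 1e-6
x = torch.randn(2, 5, hidden_size)

# Reference from the removed class: compute in fp32, scale by the
# reciprocal RMS, cast back, multiply by an all-ones weight.
variance = x.float().pow(2).mean(-1, keepdim=True)
ref = (x.float() * torch.rsqrt(variance + eps)).to(x.dtype)

norm = RMSNorm(hidden_size, eps=eps)  # weight is initialized to ones
out = norm(x)  # CPU path, so no custom CUDA kernel is involved

assert torch.allclose(ref, out, atol=1e-5)
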

vllm/model_executor/models/qwen2_vl.py

Lines changed: 22 additions & 15 deletions
@@ -226,11 +226,15 @@ def apply_rotary_emb_torch(x: torch.Tensor,


 def apply_rotary_pos_emb_vision(t: torch.Tensor,
-                                freqs: torch.Tensor) -> torch.Tensor:
+                                freqs: torch.Tensor,
+                                use_flash_attn=False) -> torch.Tensor:
     t_ = t.float()
     cos = freqs.cos()
     sin = freqs.sin()
-    output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
+    apply_rotary_emb = apply_rotary_emb_torch
+    if use_flash_attn:
+        from flash_attn.layers.rotary import apply_rotary_emb
+    output = apply_rotary_emb(t_, cos, sin).type_as(t)
     return output

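
Note: apply_rotary_pos_emb_vision now takes a use_flash_attn flag and swaps in flash_attn's fused apply_rotary_emb when the FLASH_ATTN backend is active. A small sketch of that dispatch; the ImportError guard is an addition here for environments without flash_attn (the diff itself imports it unconditionally once the flag is set), and both implementations are assumed to share the (t, cos, sin) calling convention used above.

# Pick the rotary implementation the same way the patched helper does,
# but fall back gracefully when flash_attn is not installed.
def pick_rotary_impl(use_flash_attn: bool):
    if use_flash_attn:
        try:
            from flash_attn.layers.rotary import apply_rotary_emb
            return apply_rotary_emb  # fused kernel from flash_attn
        except ImportError:
            pass  # assumption: fall back if flash_attn is unavailable
    from vllm.model_executor.models.qwen2_vl import apply_rotary_emb_torch
    return apply_rotary_emb_torch  # pure-PyTorch fallback
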

@@ -336,20 +340,23 @@ def forward(
                                       "(b s) ... -> b s ...",
                                       b=batch_size)
         elif self.attn_backend == _Backend.TORCH_SDPA:
-            seq_length = q.size(1)
-            q, k, v = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
-            attention_mask = torch.zeros([1, seq_length, seq_length],
-                                         device=q.device,
-                                         dtype=torch.bool)
+            # Execute attention entry by entry for speed & less VRAM.
+            outputs = []
             for i in range(1, len(cu_seqlens)):
-                attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
-                               cu_seqlens[i - 1]:cu_seqlens[i]] = True
-            output = F.scaled_dot_product_attention(q,
-                                                    k,
-                                                    v,
-                                                    attention_mask,
-                                                    dropout_p=0.0)
-            context_layer = rearrange(output, "b h s d -> b s h d ")
+                start_idx = cu_seqlens[i - 1]
+                end_idx = cu_seqlens[i]
+                q_i = q[:, start_idx:end_idx]
+                k_i = k[:, start_idx:end_idx]
+                v_i = v[:, start_idx:end_idx]
+                q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d")
+                                 for x in [q_i, k_i, v_i])
+                output_i = F.scaled_dot_product_attention(q_i,
+                                                          k_i,
+                                                          v_i,
+                                                          dropout_p=0.0)
+                output_i = rearrange(output_i, "b h s d -> b s h d ")
+                outputs.append(output_i)
+            context_layer = torch.cat(outputs, dim=1)
         elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
