 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -271,8 +272,13 @@ def forward(
         q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
                    for x in (q, k, v))
         if rotary_pos_emb is not None:
-            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
-            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
+            use_flash_attn = self.attn_backend == _Backend.FLASH_ATTN
+            q = apply_rotary_pos_emb_vision(q,
+                                            rotary_pos_emb,
+                                            use_flash_attn=use_flash_attn)
+            k = apply_rotary_pos_emb_vision(k,
+                                            rotary_pos_emb,
+                                            use_flash_attn=use_flash_attn)

         if self.attn_backend == _Backend.FLASH_ATTN:
             # from vllm_flash_attn.flash_attn_interface import (
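(For readers skimming the hunk above: it threads a `use_flash_attn` flag into `apply_rotary_pos_emb_vision`, whose body lies outside this diff. Below is a minimal torch-only sketch of the rotary-embedding math such a helper applies to `q`/`k`, assuming `(batch, seq, heads, head_dim)` tensors and the rotate-half formulation; `rotate_half` and `apply_rotary_ref` are illustrative names, not the helper's actual implementation.)

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dim in two and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_ref(q: torch.Tensor, k: torch.Tensor,
                     freqs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # q, k: (batch, seq, heads, head_dim); freqs: (seq, head_dim // 2).
    emb = torch.cat((freqs, freqs), dim=-1)      # (seq, head_dim)
    cos = emb.cos()[None, :, None, :]            # broadcast over batch / heads
    sin = emb.sin()[None, :, None, :]
    q_out = q * cos + rotate_half(q) * sin
    k_out = k * cos + rotate_half(k) * sin
    return q_out, k_out
```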
@@ -296,20 +302,23 @@ def forward(
                                       "(b s) ... -> b s ...",
                                       b=batch_size)
         elif self.attn_backend == _Backend.TORCH_SDPA:
-            seq_length = q.size(1)
-            q, k, v = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
-            attention_mask = torch.zeros([1, seq_length, seq_length],
-                                         device=q.device,
-                                         dtype=torch.bool)
+            # Execute attention entry by entry for speed & less VRAM.
+            outputs = []
             for i in range(1, len(cu_seqlens)):
-                attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
-                               cu_seqlens[i - 1]:cu_seqlens[i]] = True
-            output = F.scaled_dot_product_attention(q,
-                                                    k,
-                                                    v,
-                                                    attention_mask,
-                                                    dropout_p=0.0)
-            context_layer = rearrange(output, "b h s d -> b s h d ")
+                start_idx = cu_seqlens[i - 1]
+                end_idx = cu_seqlens[i]
+                q_i = q[:, start_idx:end_idx]
+                k_i = k[:, start_idx:end_idx]
+                v_i = v[:, start_idx:end_idx]
+                q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d")
+                                 for x in [q_i, k_i, v_i])
+                output_i = F.scaled_dot_product_attention(q_i,
+                                                          k_i,
+                                                          v_i,
+                                                          dropout_p=0.0)
+                output_i = rearrange(output_i, "b h s d -> b s h d ")
+                outputs.append(output_i)
+            context_layer = torch.cat(outputs, dim=1)
         elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
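(The TORCH_SDPA rewrite above replaces one large block-diagonal boolean mask with a loop that runs `F.scaled_dot_product_attention` on each `cu_seqlens` window separately, so no `(total_len, total_len)` mask is ever materialized. A self-contained sketch of the same pattern; the shapes and `cu_seqlens` values below are made up for illustration.)

```python
import torch
import torch.nn.functional as F
from einops import rearrange

# Made-up example: two packed sequences of lengths 4 and 2 (6 tokens total),
# 2 heads, head_dim 8, with a leading batch dim of 1 as in the model code.
q = torch.randn(1, 6, 2, 8)
k = torch.randn(1, 6, 2, 8)
v = torch.randn(1, 6, 2, 8)
cu_seqlens = torch.tensor([0, 4, 6])  # cumulative sequence boundaries

outputs = []
for i in range(1, len(cu_seqlens)):
    start_idx, end_idx = cu_seqlens[i - 1], cu_seqlens[i]
    # Slice out one sequence and move heads in front of the sequence dim,
    # since scaled_dot_product_attention expects (b, h, s, d).
    q_i, k_i, v_i = (rearrange(x[:, start_idx:end_idx], "b s h d -> b h s d")
                     for x in (q, k, v))
    out_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
    outputs.append(rearrange(out_i, "b h s d -> b s h d"))

context_layer = torch.cat(outputs, dim=1)  # (b, total_seq, heads, head_dim)
assert context_layer.shape == (1, 6, 2, 8)
```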
@@ -327,25 +336,6 @@ def forward(
         return output


-class Qwen2RMSNorm(nn.Module):
-
-    def __init__(self, hidden_size, eps=1e-6):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance +
-                                                    self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
-
-    def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-
-
 class Qwen2_5_VisionBlock(nn.Module):

     def __init__(
@@ -516,8 +506,7 @@ def __init__(
             hidden_size=self.hidden_size,
         )

-        # NOTE: We use torch native RMSNorm here for precision purposes.
-        norm_layer = partial(Qwen2RMSNorm, eps=norm_eps)
+        norm_layer = partial(RMSNorm, eps=norm_eps)
         head_dim = self.hidden_size // self.num_heads
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

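(The last two hunks replace the module-local `Qwen2RMSNorm` with vLLM's `RMSNorm` from `vllm.model_executor.layers.layernorm`, which is constructed the same way here via `partial(..., eps=norm_eps)`. For reference, the computation the removed class performed is reproduced below as a standalone function; `reference_rms_norm` is a name made up for this sketch, not part of the codebase.)

```python
import torch


def reference_rms_norm(hidden_states: torch.Tensor,
                       weight: torch.Tensor,
                       eps: float = 1e-6) -> torch.Tensor:
    # Same computation as the removed Qwen2RMSNorm.forward: upcast to
    # float32, normalize by the root-mean-square over the last dim,
    # then rescale by a learned weight in the input dtype.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)


# Example: normalize a (2, 1280)-shaped hidden state with unit weights.
x = torch.randn(2, 1280, dtype=torch.float16)
w = torch.ones(1280, dtype=torch.float16)
y = reference_rms_norm(x, w)
```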