 from lightllm.models.qwen2_vl.qwen2_visual import PatchEmbed, VisionRotaryEmbedding
 from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
+from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton
 
 # adapted from
 # https://github.com/huggingface/transformers/blob/
 # be37d34f44ff1bc928e59ffb8a30adecab8835a8/src
 # /transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py#L30C1-L31C1
-class Qwen2_5_VLVisionConfig(PretrainedConfig):
-    model_type = "qwen2_5_vl"
-
-    def __init__(
-        self,
-        depth=32,
-        hidden_size=3584,
-        hidden_act="silu",
-        intermediate_size=3420,
-        num_heads=16,
-        in_channels=3,
-        patch_size=14,
-        spatial_merge_size=2,
-        temporal_patch_size=2,
-        tokens_per_second=4,
-        window_size=112,
-        out_hidden_size=3584,
-        fullatt_block_indexes=[7, 15, 23, 31],
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-
-        self.depth = depth
-        self.hidden_size = hidden_size
-        self.hidden_act = hidden_act
-        self.intermediate_size = intermediate_size
-        self.num_heads = num_heads
-        self.in_channels = in_channels
-        self.patch_size = patch_size
-        self.spatial_merge_size = spatial_merge_size
-        self.temporal_patch_size = temporal_patch_size
-        self.tokens_per_second = tokens_per_second
-        self.window_size = window_size
-        self.fullatt_block_indexes = fullatt_block_indexes
-        self.out_hidden_size = out_hidden_size
-
-
 class Qwen2RMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         """
@@ -104,54 +68,46 @@ def forward(self, hidden_state):
         return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
 
 
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb_vision(
-    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    orig_q_dtype = q.dtype
-    orig_k_dtype = k.dtype
-    q, k = q.float(), k.float()
-    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    q_embed = q_embed.to(orig_q_dtype)
-    k_embed = k_embed.to(orig_k_dtype)
-    return q_embed, k_embed
-
-
 class Qwen2_5_VLVisionFlashAttention(nn.Module):
     def __init__(self, dim: int, num_heads: int = 16) -> None:
         super().__init__()
         self.num_heads = num_heads
         self.head_dim = dim // num_heads
         self.qkv = nn.Linear(dim, dim * 3, bias=True)
         self.proj = nn.Linear(dim, dim)
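+        # Use the rotary kernel from vLLM's flash-attn package when it is installed;
+        # otherwise fall back to lightllm's Triton implementation (see except branch below).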
+        try:
+            from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
+
+            self.has_vllm = True
+            self.apply_rotary_emb = apply_rotary_emb
+        except ImportError:
+            print("Failed to import apply_rotary_emb from vllm, falling back to the Triton rotary kernel.")
+            self.has_vllm = False
+            self.apply_rotary_emb = apply_rotary_pos_emb_triton
+
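+    # Apply rotary position embeddings in float32 and cast back to the input dtype;
+    # freqs holds the per-token rotation angles, cos/sin are derived here on the fly.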
+    def apply_rotary_pos_emb_vision(self, t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+        t_ = t.float()
+        cos = freqs.cos()
+        sin = freqs.sin()
+        output = self.apply_rotary_emb(t_, cos, sin).type_as(t)
+        return output
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         cu_seqlens: torch.Tensor,
+        max_seqlen: int = 0,
         rotary_pos_emb: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        if position_embeddings is None:
-            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
-        else:
-            cos, sin = position_embeddings
-        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
+        # if position_embeddings is None:
+        #     position_embeddings = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
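+        # The rotary helpers expect a leading batch dimension, so q/k are unsqueezed to
+        # (1, seq_len, num_heads, head_dim) and squeezed back afterwards.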
+        q = self.apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb)
+        k = self.apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb)
+        q = q.squeeze(0)
+        k = k.squeeze(0)
 
-        cu_seqlens = cu_seqlens.to(q.device, torch.int32)
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)
         flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen)
         attn_output = attn_output.reshape(seq_length, -1)
@@ -183,14 +139,14 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         cu_seqlens: torch.Tensor,
+        max_seqlen: int = 0,
         rotary_pos_emb: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         hidden_states = hidden_states + self.attn(
             self.norm1(hidden_states),
             cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
             rotary_pos_emb=rotary_pos_emb,
-            position_embeddings=position_embeddings,
         )
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
         return hidden_states
@@ -215,7 +171,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class Qwen2_5VLTransformer(nn.Module):
     def __init__(
         self,
-        weight_dir,
+        kvargs,
         depth=32,
         hidden_size=3584,
         hidden_act="silu",
@@ -232,7 +188,13 @@ def __init__(
         **kwargs,
     ):
         super().__init__()
-
+        self.weight_dir = kvargs["weight_dir"]
+        self.data_type = kvargs.get("data_type", "bfloat16")
+        # self.mode = [m.replace("int4weight", "w4a16").replace("int8weight", "w8a16") for m in kvargs.get("mode", [])]
+        # self.weight_dict = kvargs.get("weight_dict", None)
+        # self.quant_type = kvargs.get("quant_type", None)
+        # self.quant_cfg_path = kvargs.get("quant_cfg", None)
+        # self.max_batch_size = kvargs.get("max_batch_size", 1)
         self.depth = depth
         self.hidden_size = hidden_size
         self.hidden_act = hidden_act
@@ -279,46 +241,42 @@ def __init__(
 
         self.gradient_checkpointing = False
 
-        processor_config_path = os.path.join(weight_dir, "preprocessor_config.json")
+        processor_config_path = os.path.join(self.weight_dir, "preprocessor_config.json")
         with open(processor_config_path, "r") as f:
             processor_config_dict = json.load(f)
         self.processor = Qwen2VLImageProcessor(**processor_config_dict)
 
-        self.device = self.get_device()
-        self.dtype = self.get_dtype()
-
-    def get_dtype(self) -> torch.dtype:
-        return self.blocks[0].mlp.down_proj.weight.dtype
-
-    def get_device(self) -> torch.device:
-        return self.blocks[0].mlp.down_proj.weight.device
+        self._init_datatype()
+        self.load_model(kvargs["weight_dir"])
+        self.cuda()
+
+    def _init_datatype(self):
+        if isinstance(self.data_type, torch.dtype):
+            return
+        if self.data_type in ["fp16", "float16"]:
+            self.data_type = torch.float16
+        elif self.data_type in ["bf16", "bfloat16"]:
+            self.data_type = torch.bfloat16
+        elif self.data_type in ["fp32", "float32"]:
+            self.data_type = torch.float32
+        else:
+            raise ValueError(f"Unsupported datatype {self.data_type}!")
+        return
 
     def rot_pos_emb(self, grid_thw):
         pos_ids = []
-        for t, h, w in grid_thw:
+        s = self.spatial_merge_size
+        for _, h, w in grid_thw:
+            pos_shape = (h // s, s, w // s, s)
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
-            hpos_ids = hpos_ids.reshape(
-                h // self.spatial_merge_size,
-                self.spatial_merge_size,
-                w // self.spatial_merge_size,
-                self.spatial_merge_size,
-            )
-            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
-            hpos_ids = hpos_ids.flatten()
-
             wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
-            wpos_ids = wpos_ids.reshape(
-                h // self.spatial_merge_size,
-                self.spatial_merge_size,
-                w // self.spatial_merge_size,
-                self.spatial_merge_size,
-            )
-            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
-            wpos_ids = wpos_ids.flatten()
-            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
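+            # NOTE: the temporal repeat (`.repeat(t, 1)`) of the previous implementation is
+            # dropped here, which assumes t == 1 (single-frame image inputs).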
+            hpos_ids = hpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
+            wpos_ids = wpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
+
+            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1))
         pos_ids = torch.cat(pos_ids, dim=0)
         max_grid_size = grid_thw[:, 1:].max()
-        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size).type(torch.float32)
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb
 
@@ -365,14 +323,22 @@ def get_window_index(self, grid_thw):
 
     def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
         hidden_states = self.patch_embed(hidden_states)
-        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        rotary_pos_emb = self.rot_pos_emb(grid_thw).to("cuda", non_blocking=True)
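+        # Compute the cumulative sequence lengths (and their maximum) once per forward pass
+        # and reuse them in every block, instead of recomputing them inside each attention layer.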
+        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+            dim=0, dtype=torch.int32
+        )
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        cu_seqlens = cu_seqlens.to("cuda", non_blocking=True)
+
         window_index, cu_window_seqlens = self.get_window_index(grid_thw)
         cu_window_seqlens = torch.tensor(
             cu_window_seqlens,
             device=hidden_states.device,
             dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
         )
         cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+        max_window_seqlen = (cu_window_seqlens[1:] - cu_window_seqlens[:-1]).max().item()
 
         seq_len, _ = hidden_states.size()
         hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
@@ -381,40 +347,21 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.
         rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
         rotary_pos_emb = rotary_pos_emb[window_index, :, :]
         rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
-        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
-        position_embeddings = (emb.cos(), emb.sin())
-
-        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
-            dim=0,
-            # Select dtype based on the following factors:
-            # - FA2 requires that cu_seqlens_q must have dtype int32
-            # - torch.onnx.export requires that cu_seqlens_q must have same
-            #   dtype as grid_thw
-            # See https://github.com/huggingface/transformers/pull/34852
-            # for more information
-            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
-        )
-        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
 
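+        # Blocks listed in fullatt_block_indexes attend over the full sequence; all other
+        # blocks use windowed attention via cu_window_seqlens.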
         for layer_num, blk in enumerate(self.blocks):
             if layer_num in self.fullatt_block_indexes:
                 cu_seqlens_now = cu_seqlens
+                max_seqlen_now = max_seqlen
             else:
                 cu_seqlens_now = cu_window_seqlens
-            if self.gradient_checkpointing and self.training:
-                hidden_states = self._gradient_checkpointing_func(
-                    blk.__call__,
-                    hidden_states,
-                    cu_seqlens_now,
-                    None,
-                    position_embeddings,
-                )
-            else:
-                hidden_states = blk(
-                    hidden_states,
-                    cu_seqlens=cu_seqlens_now,
-                    position_embeddings=position_embeddings,
-                )
+                max_seqlen_now = max_window_seqlen
+
+            hidden_states = blk(
+                hidden_states,
+                cu_seqlens=cu_seqlens_now,
+                max_seqlen=max_seqlen_now,
+                rotary_pos_emb=rotary_pos_emb,
+            )
 
         hidden_states = self.merger(hidden_states)
         reverse_indices = torch.argsort(window_index)
@@ -428,19 +375,15 @@ def load_image(self, img: List[ImageItem]):
             image_data = read_shm(get_shm_name_data(img.uuid))
             image_data = Image.open(BytesIO(image_data))
             image_data = resize_image(image_data)
-            image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt")
-            pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16)
-            image_grid_thw = image_inputs["image_grid_thw"]
+            pixel_values, image_grid_thw = self.processor.preprocess(image_data)
         elif isinstance(img, dict):
            image_data = read_shm(get_shm_name_data(img["uuid"]))
             image_data = Image.open(BytesIO(image_data))
             image_data = resize_image(image_data)
-            image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt")
-            pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16)
-            image_grid_thw = image_inputs["image_grid_thw"]
+            pixel_values, image_grid_thw = self.processor.preprocess(image_data)
         else:
             raise Exception("Unsupport input types: {} for {}".format(type(img), img))
-        return pixel_values.to(dtype=self.get_dtype()), image_grid_thw
+        return pixel_values.to(dtype=self.data_type), image_grid_thw
 
     def load_model(self, weight_dir):
 