 
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.distributed import parallel_state
+from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
@@ -207,11 +207,12 @@ def __init__(
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
-        world_size = parallel_state.get_tensor_model_parallel_world_size()
+        self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
+        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads)
         self.num_attention_heads_per_partition = dist_utils.divide(
-            num_heads, world_size)
+            num_heads, self.tp_size)
 
         self.qkv = ColumnParallelLinear(input_size=embed_dim,
                                         output_size=3 * projection_size,
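
Note: a quick sanity check of the head-partition arithmetic above, with purely illustrative numbers (not read from the Qwen2.5-VL config):

# Illustrative values only: 1280-dim projection, 16 heads, TP world size 4.
projection_size, num_heads, tp_size = 1280, 16, 4

# dist_utils.divide(a, b) asserts a % b == 0 and returns a // b.
hidden_size_per_attention_head = projection_size // num_heads    # 80
num_attention_heads_per_partition = num_heads // tp_size         # 4 heads per rank
assert projection_size % num_heads == 0 and num_heads % tp_size == 0
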
@@ -231,6 +232,29 @@ def __init__(
                 f"Qwen2.5-VL does not support {self.attn_backend} backend now."
             )
 
+    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
+        # [s, b, 3 * head * head_dim]
+        seq_len, bs, _ = qkv.shape
+        if self.tp_size > 1:
+            qkv = tensor_model_parallel_all_gather(qkv)
+
+        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
+        q, k, v = qkv.chunk(3, dim=2)
+
+        # 3 * [s, b, head * head_dim]
+        if self.tp_size > 1:
+            splitter = partial(dist_utils.split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+            v = splitter(v)[self.tp_rank]
+
+        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
+        new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
+                     self.hidden_size_per_attention_head)
+        q, k, v = (x.view(*new_shape) for x in (q, k, v))
+        return q, k, v
+
     def forward(
         self,
         x: torch.Tensor,
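
For reference, a minimal single-process sketch of the shape flow in split_qkv, with tensor_model_parallel_all_gather and dist_utils.split_tensor_along_last_dim replaced by plain torch stand-ins; the sizes and tp_rank are made up for the example:

import torch

tp_size, tp_rank = 2, 0          # illustrative TP layout
heads, head_dim = 16, 80         # illustrative head geometry
seq_len, bs = 4, 1
heads_per_rank = heads // tp_size

# Full (already all-gathered) fused projection: [s, b, 3 * head * head_dim].
qkv = torch.randn(seq_len, bs, 3 * heads * head_dim)

# 1) Split the fused last dim into the q, k, v blocks.
q, k, v = qkv.chunk(3, dim=2)                     # 3 x [s, b, head * head_dim]

# 2) Each rank keeps its slice of the heads (stand-in for
#    split_tensor_along_last_dim(..., num_partitions=tp_size)[tp_rank]).
q, k, v = (x.chunk(tp_size, dim=-1)[tp_rank] for x in (q, k, v))

# 3) Expose the per-rank head dimension explicitly.
new_shape = (seq_len, bs, heads_per_rank, head_dim)
q, k, v = (x.contiguous().view(*new_shape) for x in (q, k, v))
assert q.shape == (seq_len, bs, heads_per_rank, head_dim)
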
@@ -240,15 +264,8 @@ def forward(
         # [s, b, c] --> [s, b, head * 3 * head_dim]
         x, _ = self.qkv(x)
 
-        # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
-        new_x_shape = x.size()[:-1] + (
-            self.num_attention_heads_per_partition,
-            3 * self.hidden_size_per_attention_head,
-        )
-        x = x.view(*new_x_shape)
-
-        # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
-        q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
+        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
+        q, k, v = self.split_qkv(x)
         batch_size = q.shape[1]
 
         q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
@@ -665,24 +682,6 @@ def load_weights(self, weights: Iterable[Tuple[str,
                     weight_loader(param, loaded_weight, shard_id)
                     break
             else:
-                if name.endswith("qkv.weight"):
-                    visual_num_heads = self.num_heads
-                    visual_embed_dim = self.hidden_size
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(3, visual_num_heads,
-                                                       head_size,
-                                                       visual_embed_dim)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
-                elif name.endswith("qkv.bias"):
-                    visual_num_heads = self.num_heads
-                    visual_embed_dim = self.hidden_size
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(3, visual_num_heads,
-                                                       head_size)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1)
-
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
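
One way to read the removed block above: it reordered the checkpoint's fused qkv weight/bias from block order [q | k | v] into a per-head-interleaved order, which the old per-head split in forward() expected; the new split_qkv chunks the block layout directly, so the load-time reorder appears to be no longer needed. A small equivalence sketch on a bias-shaped tensor, with made-up sizes:

import torch

heads, head_dim = 2, 3                                  # made-up sizes
qkv_block = torch.arange(3 * heads * head_dim)          # checkpoint order: [q | k | v]

# Old path: reorder per head (what the removed qkv.bias branch did) ...
interleaved = qkv_block.view(3, heads, head_dim).transpose(0, 1).reshape(-1)
# ... then split each head's 3 * head_dim slice into q, k, v (old forward()).
q_old, k_old, v_old = interleaved.view(heads, 3 * head_dim).chunk(3, dim=-1)

# New path: keep the block layout and chunk it into q, k, v directly.
q_new, k_new, v_new = (x.view(heads, head_dim) for x in qkv_block.chunk(3))

assert torch.equal(q_old, q_new)
assert torch.equal(k_old, k_new)
assert torch.equal(v_old, v_new)
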