 from vllm.model_executor.models.utils import maybe_prefix
 from vllm.multimodal import MULTIMODAL_REGISTRY
 
-MIN_PAD_SIZE = 64
-MAX_PAD_SIZE = 128
+MIN_PAD_SIZE = 64  # head sizes above this value get padded
+MAX_PAD_SIZE = 128  # head sizes are padded up to this value
 
 
 class AscendQwen2_5_VisionAttention(Qwen2_5_VisionAttention):
@@ -66,6 +66,7 @@ def __init__(
         self.embed_dim = embed_dim
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads)
+        self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head
         if self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE:
             self.hidden_size_per_attention_head = MAX_PAD_SIZE
 
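For reference (not part of the diff): a minimal standalone sketch of the padding decision above. Head sizes strictly between MIN_PAD_SIZE and MAX_PAD_SIZE are rounded up to MAX_PAD_SIZE, while the original size is kept for the attention scale; the 1280-wide / 16-head example values below are illustrative assumptions, not taken from the PR. This is also why the forward() hunk below switches scale_value to origin_hidden_size_per_attention_head.

MIN_PAD_SIZE = 64   # head sizes above this value get padded
MAX_PAD_SIZE = 128  # head sizes are padded up to this value

def pick_head_sizes(projection_size: int, num_heads: int) -> tuple[int, int]:
    """Return (padded_head_size, original_head_size), mirroring the diff's logic."""
    head_size = projection_size // num_heads
    origin_head_size = head_size
    if MIN_PAD_SIZE < head_size < MAX_PAD_SIZE:
        head_size = MAX_PAD_SIZE  # pad up, but remember the original for scaling
    return head_size, origin_head_size

# e.g. 1280 // 16 = 80-dim heads get padded to 128; the softmax scale still uses 80
padded, origin = pick_head_sizes(projection_size=1280, num_heads=16)
scale_value = origin ** -0.5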
@@ -101,7 +102,7 @@ def forward(
             key=k,
             value=v,
             seq_len=cu_seqlens,
-            scale_value=self.hidden_size_per_attention_head**-0.5,
+            scale_value=self.origin_hidden_size_per_attention_head**-0.5,
             num_heads=self.num_attention_heads_per_partition,
             num_kv_heads=self.num_attention_heads_per_partition,
             out=context_layer)
@@ -164,6 +165,7 @@ def __init__(
         super().__init__(vision_config, norm_eps, quant_config, prefix)
         norm_layer = partial(RMSNorm, eps=norm_eps)
         self.interleaved = interleaved
+        self.enable_pad = False
         self.patch_embed = AscendQwen2_5_VisionPatchEmbed(
             patch_size=vision_config.patch_size,
             temporal_patch_size=vision_config.temporal_patch_size,
@@ -187,6 +189,7 @@ def __init__(
             self.hidden_size, self.num_heads)
 
         if self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE:
+            self.enable_pad = True
             self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head
             self.half_origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head // 2
             self.half_pad_hidden_size_per_attention_head = (
@@ -196,10 +199,11 @@ def __init__(
     def cal_cos_sin(self, rotary_pos_emb):
         cos = rotary_pos_emb.cos()  # [seqlen, rotary_dim / 2]
         sin = rotary_pos_emb.sin()
-        cos = torch.nn.functional.pad(
-            cos, (0, self.half_pad_hidden_size_per_attention_head))
-        sin = torch.nn.functional.pad(
-            sin, (0, self.half_pad_hidden_size_per_attention_head))
+        if self.enable_pad:
+            cos = torch.nn.functional.pad(
+                cos, (0, self.half_pad_hidden_size_per_attention_head))
+            sin = torch.nn.functional.pad(
+                sin, (0, self.half_pad_hidden_size_per_attention_head))
 
         if not self.interleaved:
             cos_new = torch.cat((cos, cos), dim=-1)
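A self-contained sketch (not the diff's code) of the guarded rotary-table padding above. Here half_pad stands in for half_pad_hidden_size_per_attention_head, assumed to equal (padded head size - original head size) // 2; the interleaved branch is omitted.

import torch
import torch.nn.functional as F

def cal_cos_sin_sketch(rotary_pos_emb: torch.Tensor, enable_pad: bool,
                       half_pad: int) -> tuple[torch.Tensor, torch.Tensor]:
    cos = rotary_pos_emb.cos()  # [seqlen, rotary_dim / 2]
    sin = rotary_pos_emb.sin()
    if enable_pad:
        # append half_pad zeros to the last dim so the rotary tables line up
        # with the padded head size; skipped when padding was never enabled
        cos = F.pad(cos, (0, half_pad))
        sin = F.pad(sin, (0, half_pad))
    # non-interleaved layout only
    cos_new = torch.cat((cos, cos), dim=-1)
    sin_new = torch.cat((sin, sin), dim=-1)
    return cos_new, sin_new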
@@ -285,11 +289,11 @@ def load_weights(self, weights: Iterable[Tuple[str,
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
-                if ("attn.proj.weight" in name):
+                if ("attn.proj.weight" in name) and self.enable_pad:
                     param.data = self.pad_proj_weight(param.data)
-                if ("attn.qkv.weight" in name):
+                if ("attn.qkv.weight" in name) and self.enable_pad:
                     param.data = self.pad_qkv_weight(param.data)
-                if ("attn.qkv.bias" in name):
+                if ("attn.qkv.bias" in name) and self.enable_pad:
                     param.data = self.pad_qkv_bias(param.data)
             loaded_params.add(name)
         return loaded_params
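The pad_proj_weight / pad_qkv_weight / pad_qkv_bias helpers are defined outside this hunk, so the sketch below is hypothetical: it shows one way the fused QKV bias could be padded per head, assuming (from half_pad_hidden_size_per_attention_head above) that the two rotary halves of each head are zero-padded independently. The layout, default sizes, and helper name are assumptions, not the PR's implementation.

import torch
import torch.nn.functional as F

def pad_qkv_bias_sketch(bias: torch.Tensor, num_heads: int,
                        origin_head: int = 80, padded_head: int = 128) -> torch.Tensor:
    """Illustrative only: pad a fused [q; k; v] bias from origin_head to padded_head per head."""
    half_origin = origin_head // 2
    half_pad = (padded_head - origin_head) // 2
    # [3 * num_heads * origin_head] -> [3, num_heads, 2 halves, origin_head // 2]
    b = bias.reshape(3, num_heads, 2, half_origin)
    # append half_pad zeros to each half so every head ends up padded_head wide
    b = F.pad(b, (0, half_pad))
    return b.reshape(-1)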