
Commit b9528e6

[0.7.3] optimize qwen2_vl and qwen2_5_vl (#702)
### What this PR does / why we need it?
Optimize qwen2_vl and qwen2_5_vl.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Tested this PR with a 1080p picture at tp=1, bs=1 on Qwen2-VL and Qwen2.5-VL: every FA op's duration drops from 11 ms to 9 ms, roughly a 22% perf boost.

---------

Signed-off-by: zouyida2052 <zouyida@huawei.com>
Signed-off-by: zouyida2052 <zouyida2002@gmail.com>
Co-authored-by: zouyida2052 <zouyida@huawei.com>
1 parent 1791113 commit b9528e6

File tree

3 files changed: +187 -35 lines changed


vllm_ascend/models/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -4,7 +4,7 @@
 def register_model():
     ModelRegistry.register_model(
         "Qwen2VLForConditionalGeneration",
-        "vllm_ascend.models.qwen2_vl:CustomQwen2VLForConditionalGeneration")
+        "vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration")
 
     ModelRegistry.register_model(
         "Qwen2_5_VLForConditionalGeneration",
```

vllm_ascend/models/qwen2_5_vl.py

Lines changed: 14 additions & 10 deletions
```diff
@@ -42,8 +42,8 @@
 from vllm.model_executor.models.utils import maybe_prefix
 from vllm.multimodal import MULTIMODAL_REGISTRY
 
-MIN_PAD_SIZE = 64
-MAX_PAD_SIZE = 128
+MIN_PAD_SIZE = 64  # min_size to pad weight
+MAX_PAD_SIZE = 128  # max_size to pad weight
 
 
 class AscendQwen2_5_VisionAttention(Qwen2_5_VisionAttention):
```
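These constants encode the padding rule used throughout the file: a per-head hidden size strictly between 64 and 128 is padded up to 128 so the NPU fused-attention op works on an aligned head dimension (Qwen2.5-VL's vision tower, 1280 hidden with 16 heads, gives head size 80). A minimal sketch of that rule; the half-pad formula is my reconstruction, since its expression is truncated in a later hunk:

```python
# Sketch of the head-size padding rule, assuming the constants mean the same
# as in the diff above; the half-pad formula is an assumption (its expression
# is truncated in the later hunk).
MIN_PAD_SIZE = 64   # only pad when the head dim is strictly above this
MAX_PAD_SIZE = 128  # pad the head dim up to this value

def padded_head_size(head_size: int) -> tuple[int, int]:
    """Return (possibly padded head size, half-pad amount for cos/sin tables)."""
    if MIN_PAD_SIZE < head_size < MAX_PAD_SIZE:
        return MAX_PAD_SIZE, (MAX_PAD_SIZE - head_size) // 2
    return head_size, 0

# Qwen2.5-VL vision tower: 1280 hidden / 16 heads = 80 -> padded to 128.
assert padded_head_size(80) == (128, 24)
assert padded_head_size(64) == (64, 0)   # already aligned, no padding
```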
```diff
@@ -66,6 +66,7 @@ def __init__(
         self.embed_dim = embed_dim
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads)
+        self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head
         if self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE:
             self.hidden_size_per_attention_head = MAX_PAD_SIZE
```

```diff
@@ -101,7 +102,7 @@ def forward(
             key=k,
             value=v,
             seq_len=cu_seqlens,
-            scale_value=self.hidden_size_per_attention_head**-0.5,
+            scale_value=self.origin_hidden_size_per_attention_head**-0.5,
             num_heads=self.num_attention_heads_per_partition,
             num_kv_heads=self.num_attention_heads_per_partition,
             out=context_layer)
```
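The scale change matters because q, k and v are zero-padded along the head dimension: padding does not change any q·k dot product, so the softmax scale must still come from the original head size rather than the padded one, which is exactly what this hunk restores. A small sketch of that invariant (80 and 128 are the assumed original/padded sizes):

```python
# Sketch: zero-padding the head dim leaves q @ k^T unchanged, so the attention
# scale must use the ORIGINAL head dim, not the padded one.
import torch
import torch.nn.functional as F

orig, padded = 80, 128                     # assumed sizes (see MIN/MAX_PAD_SIZE)
q = torch.randn(2, 4, orig)                # [heads, seq, head_dim]
k = torch.randn(2, 4, orig)

q_pad = F.pad(q, (0, padded - orig))       # zeros appended on the last dim
k_pad = F.pad(k, (0, padded - orig))

scores = q @ k.transpose(-1, -2)
scores_pad = q_pad @ k_pad.transpose(-1, -2)
assert torch.allclose(scores, scores_pad)  # identical logits before scaling

right = scores_pad * orig ** -0.5          # scale by the true head dim (this PR)
wrong = scores_pad * padded ** -0.5        # otherwise the softmax is over-smoothed
```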
```diff
@@ -164,6 +165,7 @@ def __init__(
         super().__init__(vision_config, norm_eps, quant_config, prefix)
         norm_layer = partial(RMSNorm, eps=norm_eps)
         self.interleaved = interleaved
+        self.enable_pad = False
         self.patch_embed = AscendQwen2_5_VisionPatchEmbed(
             patch_size=vision_config.patch_size,
             temporal_patch_size=vision_config.temporal_patch_size,
```
```diff
@@ -187,6 +189,7 @@ def __init__(
             self.hidden_size, self.num_heads)
 
         if self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE:
+            self.enable_pad = True
             self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head
             self.half_origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head // 2
             self.half_pad_hidden_size_per_attention_head = (
```
```diff
@@ -196,10 +199,11 @@
     def cal_cos_sin(self, rotary_pos_emb):
         cos = rotary_pos_emb.cos()  # [seqlen, rotary_dim / 2]
         sin = rotary_pos_emb.sin()
-        cos = torch.nn.functional.pad(
-            cos, (0, self.half_pad_hidden_size_per_attention_head))
-        sin = torch.nn.functional.pad(
-            sin, (0, self.half_pad_hidden_size_per_attention_head))
+        if self.enable_pad:
+            cos = torch.nn.functional.pad(
+                cos, (0, self.half_pad_hidden_size_per_attention_head))
+            sin = torch.nn.functional.pad(
+                sin, (0, self.half_pad_hidden_size_per_attention_head))
 
         if not self.interleaved:
             cos_new = torch.cat((cos, cos), dim=-1)
```
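Making the cos/sin padding conditional mirrors the head-dim padding: the rotary tables have shape [seqlen, rotary_dim / 2], and when the head dim is padded the half-table is zero-padded by the half-pad amount so that cat((cos, cos)) lines up with the padded head. With the assumed 80 → 128 sizes, 2 * (40 + 24) = 128. A simplified, self-contained sketch of the conditional path:

```python
# Simplified sketch of cal_cos_sin with conditional padding; enable_pad and
# half_pad stand in for the attributes used in the diff above.
import torch
import torch.nn.functional as F

def cal_cos_sin(rotary_pos_emb: torch.Tensor, enable_pad: bool, half_pad: int):
    cos = rotary_pos_emb.cos()   # [seqlen, rotary_dim / 2]
    sin = rotary_pos_emb.sin()
    if enable_pad:               # only pad when the head dim itself was padded
        cos = F.pad(cos, (0, half_pad))
        sin = F.pad(sin, (0, half_pad))
    # non-interleaved layout: duplicate the half-table to cover the full head dim
    return torch.cat((cos, cos), dim=-1), torch.cat((sin, sin), dim=-1)

cos, sin = cal_cos_sin(torch.randn(6, 40), enable_pad=True, half_pad=24)
assert cos.shape == (6, 128)     # matches the padded head dim
```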
```diff
@@ -285,11 +289,11 @@ def load_weights(self, weights: Iterable[Tuple[str,
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
-                if ("attn.proj.weight" in name):
+                if ("attn.proj.weight" in name) and self.enable_pad:
                     param.data = self.pad_proj_weight(param.data)
-                if ("attn.qkv.weight" in name):
+                if ("attn.qkv.weight" in name) and self.enable_pad:
                     param.data = self.pad_qkv_weight(param.data)
-                if ("attn.qkv.bias" in name):
+                if ("attn.qkv.bias" in name) and self.enable_pad:
                     param.data = self.pad_qkv_bias(param.data)
             loaded_params.add(name)
         return loaded_params
```
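Gating the weight padding on `enable_pad` keeps the one-time, load-time padding of the attention projection and QKV parameters to the case where the head dim was actually enlarged. The `pad_*` helpers themselves are defined elsewhere in `qwen2_5_vl.py` and are not part of this hunk; the sketch below is a hypothetical stand-in for a `pad_qkv_bias`-style helper, assuming a packed `[3, num_heads, head_dim]` bias layout and the 80 → 128 sizes:

```python
# Hypothetical stand-in for a pad_qkv_bias-style helper (the real helper is not
# shown in this hunk); assumes the bias is packed as [3, num_heads, head_dim].
import torch
import torch.nn.functional as F

def pad_qkv_bias_sketch(bias: torch.Tensor,
                        num_heads: int = 16,
                        orig_head: int = 80,
                        padded_head: int = 128) -> torch.Tensor:
    bias = bias.reshape(3, num_heads, orig_head)       # split per head
    bias = F.pad(bias, (0, padded_head - orig_head))   # zero-pad each head slice
    return bias.reshape(-1)                            # flatten back for the kernel

padded = pad_qkv_bias_sketch(torch.randn(3 * 16 * 80))
assert padded.shape == (3 * 16 * 128,)
```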
