
Commit 3baaf62

[0.9.1] Optimize RoPE in Qwen2 (#1782)
### What this PR does / why we need it?
Optimize RoPE by extracting the index_select of cos & sin out of the layers and into the model, which removes (layer_num - 1) * 2 Gather ops in each prefill/decode stage.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Pass cos and sin and set skip_index_select=True to enable this optimization, as shown in the following code:
`q, k = self.rotary_emb(positions, q, k, cos=cos, sin=sin, skip_index_select=True)`

**Performance results:**

| Metric | Origin | Optimized |
| --- | --- | --- |
| Successful requests | 400 | 400 |
| Benchmark duration (s) | 243.10 | 237.42 |
| Total input tokens | 1200000 | 1200000 |
| Total generated tokens | 60000 | 60000 |
| Request throughput (req/s) | 1.65 | 1.68 |
| Output token throughput (tok/s) | 246.81 | 252.72 |
| Total token throughput (tok/s) | 5183.02 | 5307.03 |

Signed-off-by: David9857 <985700846@qq.com>
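For intuition, here is a minimal, hedged sketch of the hoisting described above, written in plain PyTorch with made-up sizes (`layer_num`, `max_position`, `rotary_dim`, and the dummy `cos_sin_cache` below are illustrative stand-ins, not the real model state): before this change every decoder layer gathered its own cos/sin rows from the rotary cache, whereas now the model gathers them once per forward pass and shares the result with all layers.

```python
import torch

# Illustrative sizes only; real values come from the model config.
layer_num, max_position, rotary_dim = 4, 64, 128
cos_sin_cache = torch.randn(max_position, rotary_dim)  # stand-in for the rotary cache
positions = torch.arange(16)

# Before: each layer performs its own gather into the cache,
# repeated layer_num times per prefill/decode step.
per_layer = [cos_sin_cache.index_select(0, positions) for _ in range(layer_num)]

# After: one gather at the model level, reused by every layer.
shared = cos_sin_cache.index_select(0, positions)

# All layers would have computed exactly the same tensor.
assert all(torch.equal(g, shared) for g in per_layer)
```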
1 parent 507dce5 commit 3baaf62

3 files changed (+129 lines, -18 lines)

vllm_ascend/models/qwen2.py

Lines changed: 124 additions & 11 deletions
@@ -1,11 +1,12 @@
 from collections.abc import Iterable
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 import torch.nn.functional as F
 import vllm.envs as envs
 from torch import nn
 from transformers import Qwen2Config
+from vllm.attention import AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@@ -14,11 +15,14 @@
                               tensor_model_parallel_all_reduce,
                               tensor_model_parallel_reduce_scatter)
 from vllm.forward_context import get_forward_context
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
-from vllm.model_executor.models.qwen2 import Qwen2DecoderLayer, Qwen2Model
+from vllm.model_executor.models.qwen2 import (Qwen2Attention, Qwen2MLP,
+                                              Qwen2Model)
 from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                               PPMissingLayer, maybe_prefix)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -48,7 +52,59 @@ def maybe_pad_and_reduce_scatter(
     return hidden_states
 
 
-class CustomQwen2DecoderLayer(Qwen2DecoderLayer):
+class CustomQwen2Attention(Qwen2Attention):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        rope_theta: float = 10000,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        rope_scaling: Optional[tuple] = None,
+        prefix: str = "",
+        attn_type: str = AttentionType.DECODER,
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
+    ) -> None:
+        super().__init__(
+            hidden_size=hidden_size,
+            num_heads=num_heads,
+            num_kv_heads=num_kv_heads,
+            max_position=max_position,
+            rope_theta=rope_theta,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            rope_scaling=rope_scaling,
+            prefix=prefix,
+            attn_type=attn_type,
+            dual_chunk_attention_config=dual_chunk_attention_config)
+
+    def forward(self,
+                positions: torch.Tensor,
+                hidden_states: torch.Tensor,
+                cos: Optional[torch.Tensor] = None,
+                sin: Optional[torch.Tensor] = None) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        if type(self.rotary_emb) is RotaryEmbedding:
+            # We optimized RotaryEmbedding by moving index_select of cos & sin outside.
+            # If cos & sin are provided, set is_cos_sin_cached to True to skip index_select.
+            q, k = self.rotary_emb(positions,
+                                   q,
+                                   k,
+                                   cos=cos,
+                                   sin=sin,
+                                   is_cos_sin_cached=True)
+        else:
+            q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class CustomQwen2DecoderLayer(nn.Module):
 
     def __init__(
         self,
@@ -57,10 +113,49 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
-        super().__init__(config=config,
-                         cache_config=cache_config,
-                         quant_config=quant_config,
-                         prefix=prefix)
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        # Requires transformers > 4.32.0
+        rope_theta = getattr(config, "rope_theta", 1000000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        dual_chunk_attention_config = getattr(config,
+                                              "dual_chunk_attention_config",
+                                              None)
+
+        # By default, Qwen2 uses causal attention as it is a decoder-only model.
+        # You can override the HF config with `is_causal=False` to enable
+        # bidirectional attention, which is used in some embedding models
+        # (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
+        if getattr(config, "is_causal", True):
+            attn_type = AttentionType.DECODER
+        else:
+            attn_type = AttentionType.ENCODER_ONLY
+
+        self.self_attn = CustomQwen2Attention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            max_position=config.max_position_embeddings,
+            num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            rope_scaling=rope_scaling,
+            prefix=f"{prefix}.self_attn",
+            attn_type=attn_type,
+            dual_chunk_attention_config=dual_chunk_attention_config,
+        )
+        self.mlp = Qwen2MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
+
         self.tp_rank = get_tensor_model_parallel_rank()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.self_attn.o_proj.reduce_results = False
@@ -73,6 +168,8 @@ def forward(
         residual: Optional[torch.Tensor],
         flashcomm_v1_enabled: bool,
         pad_size: int,
+        cos: Optional[torch.Tensor] = None,
+        sin: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
         if residual is None:
@@ -89,10 +186,10 @@ def forward(
         if flashcomm_v1_enabled:
             hidden_states = all_gather_and_maybe_unpad(
                 hidden_states, pad_size)
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-        )
+        hidden_states = self.self_attn(positions=positions,
+                                       hidden_states=hidden_states,
+                                       cos=cos,
+                                       sin=sin)
         if flashcomm_v1_enabled:
             hidden_states = maybe_pad_and_reduce_scatter(
                 hidden_states, pad_size)
@@ -133,6 +230,7 @@ def __init__(
             prefix=prefix,
             decoder_layer_type=decoder_layer_type)
         self.tp_size = get_tensor_model_parallel_world_size()
+        self.cos_sin_cache = self.layers[0].self_attn.rotary_emb.cos_sin_cache
 
     def forward(
         self,
@@ -163,13 +261,28 @@ def forward(
         num_tokens = hidden_states.size(0)
         pad_size = (self.tp_size -
                     (num_tokens % self.tp_size)) % self.tp_size
+
+        # Generate cos and sin outside layers to avoid repeated calculation.
+        cos, sin = None, None
+        if type(self.layers[0].self_attn.rotary_emb) is RotaryEmbedding:
+            cos_sin = self.cos_sin_cache.index_select(0, positions)
+            last_dim = cos_sin.size()[-1]
+            cos, sin = cos_sin.reshape(-1, 2,
+                                       last_dim // 2).repeat(1, 1,
+                                                             2).chunk(2,
+                                                                      dim=-2)
+            cos, sin = cos.view(1, -1, 1, last_dim).contiguous(), sin.view(
+                1, -1, 1, last_dim).contiguous()
+
         for layer in self.layers[self.start_layer:self.end_layer]:
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 residual,
                 flashcomm_v1_enabled,
                 pad_size,
+                cos=cos,
+                sin=sin,
             )
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
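The reshape chain that `CustomQwen2Model.forward` uses to turn the gathered cache rows into the `[1, num_tokens, 1, head_size]` layout expected by the fused NPU kernel is fairly dense, so here is a hedged, CPU-only sketch that simply replays it on a dummy cache (sizes are illustrative, not the real model state) to make the intermediate shapes visible.

```python
import torch

# Dummy stand-ins; in the model, the cache is self.cos_sin_cache and
# rotary_dim equals the head size (128) on the fused-kernel path.
num_tokens, max_position, rotary_dim = 5, 64, 128
cos_sin_cache = torch.randn(max_position, rotary_dim)
positions = torch.arange(num_tokens)

cos_sin = cos_sin_cache.index_select(0, positions)  # [num_tokens, rotary_dim]
last_dim = cos_sin.size()[-1]
# Split each cached row into its cos half and sin half, tile each half to the
# full rotary_dim, then separate cos from sin.
cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat(1, 1, 2).chunk(2, dim=-2)
cos = cos.view(1, -1, 1, last_dim).contiguous()  # [1, num_tokens, 1, rotary_dim]
sin = sin.view(1, -1, 1, last_dim).contiguous()

assert cos.shape == sin.shape == (1, num_tokens, 1, rotary_dim)
```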

vllm_ascend/models/qwen3.py

Lines changed: 1 addition & 1 deletion
@@ -170,7 +170,7 @@ def forward(
                                    k,
                                    cos=cos,
                                    sin=sin,
-                                   skip_index_select=True)
+                                   is_cos_sin_cached=True)
         attn_output = self.attn(q, k, v)
         pad_size = 0
         if self.enable_fc == 2:

vllm_ascend/ops/rotary_embedding.py

Lines changed: 4 additions & 6 deletions
@@ -39,7 +39,7 @@ def rope_forward_oot(
     cos: torch.Tensor = None,
     sin: torch.Tensor = None,
     is_neox_style_override: Optional[bool] = None,
-    skip_index_select: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
+    is_cos_sin_cached: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
     import torch_npu
     query_shape, key_shape = query.shape, key.shape
     if self.cos_sin_cache.device != query.device:
@@ -64,16 +64,14 @@ def rope_forward_oot(
         raise NotImplementedError(
             "Batched rotary embedding is currently not supported on NPU.")
     else:
-        if skip_index_select and neox_style and self.head_size == self.rotary_dim:
-            # TODO: Remove the contiguous in the future.
-            # BSNH
+        if is_cos_sin_cached and neox_style and self.head_size == self.rotary_dim and self.head_size == 128:
+            # If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation.
+            # This method requires head_size and rotary_dim equal 128 and neox_style is True
             query = query.contiguous().view(1, query.shape[0], -1,
                                             self.head_size)
             key = key.contiguous().view(1, key.shape[0], -1, self.head_size)
-            # requires head_size=128 and neox_style=True
             torch_npu.npu_apply_rotary_pos_emb(query, key, cos, sin)
         else:
-            # TODO: Remove the contiguous in the future.
             query = query.contiguous().view(query.shape[0], -1)
             key = key.contiguous().view(key.shape[0], -1)
             torch_npu._npu_rotary_embedding(
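To summarize the new branch condition, the sketch below (a hypothetical helper name, not part of vllm-ascend) restates when `rope_forward_oot` takes the fused `torch_npu.npu_apply_rotary_pos_emb` path versus falling back to the generic `_npu_rotary_embedding` kernel.

```python
# Hypothetical helper mirroring the condition added above; vllm-ascend does not
# expose such a function, this is only a readable restatement of the branch.
def uses_fused_rope_kernel(is_cos_sin_cached: bool, neox_style: bool,
                           head_size: int, rotary_dim: int) -> bool:
    return (is_cos_sin_cached and neox_style
            and head_size == rotary_dim and head_size == 128)

assert uses_fused_rope_kernel(True, True, 128, 128)       # cos/sin precomputed, 128-dim neox rope
assert not uses_fused_rope_kernel(False, True, 128, 128)  # cos/sin not passed in
assert not uses_fused_rope_kernel(True, True, 64, 64)     # head_size != 128
```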
