 from collections.abc import Iterable
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 import torch.nn.functional as F
 import vllm.envs as envs
 from torch import nn
 from transformers import Qwen2Config
+from vllm.attention import AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               tensor_model_parallel_all_reduce,
                               tensor_model_parallel_reduce_scatter)
 from vllm.forward_context import get_forward_context
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
-from vllm.model_executor.models.qwen2 import Qwen2DecoderLayer, Qwen2Model
+from vllm.model_executor.models.qwen2 import (Qwen2Attention, Qwen2MLP,
+                                              Qwen2Model)
 from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                               PPMissingLayer, maybe_prefix)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -48,7 +52,59 @@ def maybe_pad_and_reduce_scatter(
     return hidden_states
 
 
-class CustomQwen2DecoderLayer(Qwen2DecoderLayer):
+class CustomQwen2Attention(Qwen2Attention):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        rope_theta: float = 10000,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        rope_scaling: Optional[tuple] = None,
+        prefix: str = "",
+        attn_type: str = AttentionType.DECODER,
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
+    ) -> None:
+        super().__init__(
+            hidden_size=hidden_size,
+            num_heads=num_heads,
+            num_kv_heads=num_kv_heads,
+            max_position=max_position,
+            rope_theta=rope_theta,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            rope_scaling=rope_scaling,
+            prefix=prefix,
+            attn_type=attn_type,
+            dual_chunk_attention_config=dual_chunk_attention_config)
+
+    def forward(self,
+                positions: torch.Tensor,
+                hidden_states: torch.Tensor,
+                cos: Optional[torch.Tensor] = None,
+                sin: Optional[torch.Tensor] = None) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        if type(self.rotary_emb) is RotaryEmbedding:
+            # We optimize RotaryEmbedding by moving the index_select of cos & sin outside.
+            # If cos & sin are provided, pass is_cos_sin_cached=True to skip the index_select.
+            q, k = self.rotary_emb(positions,
+                                   q,
+                                   k,
+                                   cos=cos,
+                                   sin=sin,
+                                   is_cos_sin_cached=True)
+        else:
+            q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class CustomQwen2DecoderLayer(nn.Module):
 
     def __init__(
         self,
@@ -57,10 +113,49 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
-        super().__init__(config=config,
-                         cache_config=cache_config,
-                         quant_config=quant_config,
-                         prefix=prefix)
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        # Requires transformers > 4.32.0
+        rope_theta = getattr(config, "rope_theta", 1000000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        dual_chunk_attention_config = getattr(config,
+                                              "dual_chunk_attention_config",
+                                              None)
+
+        # By default, Qwen2 uses causal attention as it is a decoder-only model.
+        # You can override the HF config with `is_causal=False` to enable
+        # bidirectional attention, which is used in some embedding models
+        # (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
+        if getattr(config, "is_causal", True):
+            attn_type = AttentionType.DECODER
+        else:
+            attn_type = AttentionType.ENCODER_ONLY
+
+        self.self_attn = CustomQwen2Attention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            max_position=config.max_position_embeddings,
+            num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            rope_scaling=rope_scaling,
+            prefix=f"{prefix}.self_attn",
+            attn_type=attn_type,
+            dual_chunk_attention_config=dual_chunk_attention_config,
+        )
+        self.mlp = Qwen2MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
+
         self.tp_rank = get_tensor_model_parallel_rank()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.self_attn.o_proj.reduce_results = False
@@ -73,6 +168,8 @@ def forward(
         residual: Optional[torch.Tensor],
         flashcomm_v1_enabled: bool,
         pad_size: int,
+        cos: Optional[torch.Tensor] = None,
+        sin: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
         if residual is None:
@@ -89,10 +186,10 @@ def forward(
         if flashcomm_v1_enabled:
             hidden_states = all_gather_and_maybe_unpad(
                 hidden_states, pad_size)
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-        )
+        hidden_states = self.self_attn(positions=positions,
+                                       hidden_states=hidden_states,
+                                       cos=cos,
+                                       sin=sin)
         if flashcomm_v1_enabled:
             hidden_states = maybe_pad_and_reduce_scatter(
                 hidden_states, pad_size)
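
A minimal standalone sketch of the flashcomm_v1 padding idea used around the attention call above (hypothetical helper name, not the vllm-ascend implementation): the token dimension is padded so it splits evenly across tensor-parallel ranks before reduce-scatter, using the same pad_size formula that appears in the model forward further below.

```python
import torch
import torch.nn.functional as F


# Hypothetical illustration only: pad the token dimension so reduce-scatter can
# split it evenly across tp_size ranks; the padding is dropped again later.
def pad_for_reduce_scatter(hidden_states: torch.Tensor,
                           tp_size: int) -> tuple[torch.Tensor, int]:
    num_tokens = hidden_states.size(0)
    # Same formula as in CustomQwen2Model.forward below.
    pad_size = (tp_size - (num_tokens % tp_size)) % tp_size
    if pad_size > 0:
        # F.pad pads dims from the last backwards: (0, 0) leaves the hidden dim
        # alone, (0, pad_size) appends zero rows along the token dim.
        hidden_states = F.pad(hidden_states, (0, 0, 0, pad_size))
    return hidden_states, pad_size


x = torch.randn(9, 16)  # 9 tokens, hidden size 16
padded, pad_size = pad_for_reduce_scatter(x, tp_size=4)
print(padded.shape, pad_size)  # torch.Size([12, 16]) 3
```

Judging by its name and the fact that it is called with the same pad_size, the counterpart helper all_gather_and_maybe_unpad in the hunk above drops this padding again after the all-gather.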
@@ -133,6 +230,7 @@ def __init__(
                          prefix=prefix,
                          decoder_layer_type=decoder_layer_type)
         self.tp_size = get_tensor_model_parallel_world_size()
+        self.cos_sin_cache = self.layers[0].self_attn.rotary_emb.cos_sin_cache
 
     def forward(
         self,
@@ -163,13 +261,28 @@ def forward(
             num_tokens = hidden_states.size(0)
             pad_size = (self.tp_size -
                         (num_tokens % self.tp_size)) % self.tp_size
+
+        # Generate cos and sin outside the layers to avoid repeated calculation.
+        cos, sin = None, None
+        if type(self.layers[0].self_attn.rotary_emb) is RotaryEmbedding:
+            cos_sin = self.cos_sin_cache.index_select(0, positions)
+            last_dim = cos_sin.size()[-1]
+            cos, sin = cos_sin.reshape(-1, 2,
+                                       last_dim // 2).repeat(1, 1,
+                                                             2).chunk(2,
+                                                                      dim=-2)
+            cos, sin = cos.view(1, -1, 1, last_dim).contiguous(), sin.view(
+                1, -1, 1, last_dim).contiguous()
+
         for layer in self.layers[self.start_layer:self.end_layer]:
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 residual,
                 flashcomm_v1_enabled,
                 pad_size,
+                cos=cos,
+                sin=sin,
             )
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
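
To see what the cos/sin precomputation in the hunk above produces, here is a small self-contained sketch with toy values; it assumes the vLLM convention that each row of cos_sin_cache stores the cos half followed by the sin half of the rotary dimensions.

```python
import torch

# Toy cache: rotary_dim = 4, two cached positions.
# Assumed row layout (vLLM convention): [cos_0, cos_1, sin_0, sin_1]
cos_sin_cache = torch.tensor([[1., 2., 3., 4.],
                              [5., 6., 7., 8.]])
positions = torch.tensor([1, 0, 1])

cos_sin = cos_sin_cache.index_select(0, positions)  # (num_tokens, rotary_dim)
last_dim = cos_sin.size(-1)
# Split each row into its cos half and sin half, tile each half to full width,
# then separate the two halves into cos and sin tensors.
cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat(1, 1, 2).chunk(2, dim=-2)
cos = cos.view(1, -1, 1, last_dim).contiguous()  # (1, num_tokens, 1, rotary_dim)
sin = sin.view(1, -1, 1, last_dim).contiguous()

print(cos.squeeze())  # [[5., 6., 5., 6.], [1., 2., 1., 2.], [5., 6., 5., 6.]]
print(sin.squeeze())  # [[7., 8., 7., 8.], [3., 4., 3., 4.], [7., 8., 7., 8.]]
```

The resulting per-token cos and sin tensors are threaded through every decoder layer, so the index_select against the cache happens once per forward pass instead of once per attention layer.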