 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
-                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -164,23 +163,15 @@ def __init__(
                                               self.tp_size)

         self.scale = self.head_dim**-0.5
-        if use_data_parallel:
-            self.qkv = ReplicatedLinear(
-                self.embed_dim,
-                3 * self.head_dim * self.num_heads,
-                bias=config.qkv_bias,
-                quant_config=quant_config,
-                prefix=f"{prefix}.qkv",
-            )
-        else:
-            self.qkv = QKVParallelLinear(
-                self.embed_dim,
-                self.head_dim,
-                num_dummy_heads + self.num_heads,
-                bias=config.qkv_bias,
-                quant_config=quant_config,
-                prefix=f"{prefix}.qkv",
-            )
+        self.qkv = QKVParallelLinear(
+            self.embed_dim,
+            self.head_dim,
+            num_dummy_heads + self.num_heads,
+            bias=config.qkv_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv",
+            disable_tp=use_data_parallel,
+        )

         self.qk_normalization = config.qk_normalization

@@ -192,20 +183,13 @@ def __init__(
                                   eps=config.layer_norm_eps,
                                   var_hidden_size=self.embed_dim)

-        if use_data_parallel:
-            self.proj = ReplicatedLinear(
-                self.dummy_dim,
-                self.embed_dim,
-                quant_config=quant_config,
-                prefix=f"{prefix}.proj",
-            )
-        else:
-            self.proj = RowParallelLinear(
-                self.dummy_dim,
-                self.embed_dim,
-                quant_config=quant_config,
-                prefix=f"{prefix}.proj",
-            )
+        self.proj = RowParallelLinear(
+            self.dummy_dim,
+            self.embed_dim,
+            quant_config=quant_config,
+            prefix=f"{prefix}.proj",
+            disable_tp=use_data_parallel,
+        )

         self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                        self.head_dim, self.scale)
@@ -236,72 +220,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return out


-class InternSdpaAttention(nn.Module):
-    """Multi-headed attention from 'Attention Is All You Need' paper"""
-
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        *,
-        num_dummy_heads: int = 0,
-    ) -> None:
-        super().__init__()
-
-        self.config = config
-        self.embed_dim = config.hidden_size
-        self.num_heads = config.num_attention_heads
-        self.head_dim = self.embed_dim // self.num_heads
-        if self.head_dim * self.num_heads != self.embed_dim:
-            raise ValueError(
-                f'embed_dim must be divisible by num_heads '
-                f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
-                f' {self.num_heads}).')
-
-        # Additional dummy heads are used to enable TP for common GPU counts.
-        self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
-
-        self.scale = self.head_dim**-0.5
-        self.qkv = nn.Linear(self.embed_dim,
-                             3 * self.dummy_dim,
-                             bias=config.qkv_bias)
-
-        self.qk_normalization = config.qk_normalization
-
-        if self.qk_normalization:
-            self.q_norm = RMSNorm(self.dummy_dim,
-                                  eps=config.layer_norm_eps,
-                                  var_hidden_size=self.embed_dim)
-            self.k_norm = RMSNorm(self.dummy_dim,
-                                  eps=config.layer_norm_eps,
-                                  var_hidden_size=self.embed_dim)
-
-        self.proj = nn.Linear(self.dummy_dim, self.embed_dim)
-
-        # Use unified MultiHeadAttention with automatic backend selection
-        self.attn = MultiHeadAttention(self.num_heads, self.head_dim,
-                                       self.scale)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        B, N, C = x.shape
-        qkv = self.qkv(x)
-        q, k, v = qkv.chunk(3, dim=-1)
-
-        q = q.view(B, N, self.num_heads, self.head_dim)
-        k = k.view(B, N, self.num_heads, self.head_dim)
-        v = v.view(B, N, self.num_heads, self.head_dim)
-
-        if self.qk_normalization:
-            B_, N_, H_, D_ = q.shape
-            q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_)
-            k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_)
-
-        # Use unified MultiHeadAttention with automatic backend selection
-        x = self.attn(q, k, v)
-
-        x = self.proj(x)
-        return x
-
-
 class InternMLP(nn.Module):

     def __init__(
@@ -315,20 +233,18 @@ def __init__(

         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
-        cls_fc1 = (ReplicatedLinear
-                   if use_data_parallel else ColumnParallelLinear)
-        self.fc1 = cls_fc1(config.hidden_size,
-                           config.intermediate_size,
-                           bias=True,
-                           quant_config=quant_config,
-                           prefix=f"{prefix}.fc1")
-        cls_fc2 = (ReplicatedLinear
-                   if use_data_parallel else RowParallelLinear)
-        self.fc2 = cls_fc2(config.intermediate_size,
-                           config.hidden_size,
-                           bias=True,
-                           quant_config=quant_config,
-                           prefix=f"{prefix}.fc2")
+        self.fc1 = ColumnParallelLinear(config.hidden_size,
+                                        config.intermediate_size,
+                                        bias=True,
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.fc1",
+                                        disable_tp=use_data_parallel)
+        self.fc2 = RowParallelLinear(config.intermediate_size,
+                                     config.hidden_size,
+                                     bias=True,
+                                     quant_config=quant_config,
+                                     prefix=f"{prefix}.fc2",
+                                     disable_tp=use_data_parallel)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.fc1(hidden_states)
@@ -385,19 +301,19 @@ def _init_attn(
         use_data_parallel: bool = False,
     ):
         # fallback to sdpa attention if tp unavailable
-        # tp_size = get_tensor_model_parallel_world_size()
         tp_size = (1 if use_data_parallel else
                    get_tensor_model_parallel_world_size())
         num_heads = config.num_attention_heads

-        if (num_heads + num_dummy_heads) % tp_size == 0:
-            return InternParallelAttention(config,
-                                           quant_config=quant_config,
-                                           num_dummy_heads=num_dummy_heads,
-                                           prefix=prefix,
-                                           use_data_parallel=use_data_parallel)
-
-        return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads)
+        # if the number of heads is not divisible by tp_size,
+        # we also disable Attention's TP
+        use_data_parallel = (use_data_parallel
+                             or (num_heads + num_dummy_heads) % tp_size != 0)
+        return InternParallelAttention(config,
+                                       quant_config=quant_config,
+                                       num_dummy_heads=num_dummy_heads,
+                                       prefix=prefix,
+                                       use_data_parallel=use_data_parallel)

     def forward(
         self,
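
For context, a minimal sketch of the pattern this diff adopts: instead of branching between a replicated layer and a tensor-parallel layer, the parallel layer itself accepts a disable_tp flag and keeps its full (unsharded) width when the flag is set. ToyColumnParallelLinear and its shard logic below are illustrative assumptions for this sketch, not vLLM's actual ColumnParallelLinear.

# Hypothetical illustration only; names and sharding logic are assumptions.
import torch
import torch.nn as nn


class ToyColumnParallelLinear(nn.Module):

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 tp_size: int,
                 *,
                 disable_tp: bool = False) -> None:
        super().__init__()
        # disable_tp=True keeps the full output width, so the layer behaves
        # like the ReplicatedLinear branch the diff removes.
        shards = 1 if disable_tp else tp_size
        assert out_features % shards == 0, "output dim must divide evenly"
        self.linear = nn.Linear(in_features, out_features // shards)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


# The caller no longer needs an if/else on use_data_parallel; it simply
# forwards the flag, mirroring disable_tp=use_data_parallel above.
layer = ToyColumnParallelLinear(1024, 4096, tp_size=4, disable_tp=True)
print(layer(torch.randn(2, 1024)).shape)  # -> torch.Size([2, 4096])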