 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fla.ops import (
+    RMSNormGated,
     chunk_gated_delta_rule,
     fused_recurrent_gated_delta_rule,
 )
 from vllm.model_executor.layers.fused_moe import SharedFusedMoE
-from vllm.model_executor.layers.layernorm import (
-    GemmaRMSNorm as Qwen3NextRMSNorm,
-)
-from vllm.model_executor.layers.layernorm import RMSNormGated
+from vllm.model_executor.layers.layernorm import GemmaRMSNorm as Qwen3NextRMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
@@ -438,66 +436,17 @@ def forward(
         hidden_states: torch.Tensor,
         output: torch.Tensor,
     ):
-        """
-        Forward pass with three parts:
-        1. Input projection
-        2. Core attention (custom op)
-        3. Output projection
-        """
-        num_tokens = hidden_states.size(0)
-
-        # ============================================================
-        # Part 1: Input Projection
-        # ============================================================
-        projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
-        projected_states_ba, _ = self.in_proj_ba(hidden_states)
-        query, key, value, z, b, a = self.fix_query_key_value_ordering(
-            projected_states_qkvz, projected_states_ba
-        )
-        query, key, value = map(
-            lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value)
-        )
-        mixed_qkv = torch.cat((query, key, value), dim=-1)
-
-        # ============================================================
-        # Part 2: Core Attention (Custom Op)
-        # ============================================================
-        core_attn_out = torch.zeros(
-            (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
-            dtype=hidden_states.dtype,
-            device=hidden_states.device,
-        )
-
-        torch.ops.vllm.gdn_attention_core(
-            mixed_qkv,
-            b,
-            a,
-            core_attn_out,
+        return torch.ops.vllm.gdn_attention(
+            hidden_states,
+            output,
             self.prefix,
         )

-        # ============================================================
-        # Part 3: Output Projection
-        # ============================================================
-        z_shape_og = z.shape
-        # Reshape input data into 2D tensor
-        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
-        z = z.reshape(-1, z.shape[-1])
-        core_attn_out = self.norm(core_attn_out, z)
-        core_attn_out = core_attn_out.reshape(z_shape_og)
-        core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
-        output[:num_tokens], _ = self.out_proj(core_attn_out)
-
-    def _forward_core(
+    def _forward(
         self,
-        mixed_qkv: torch.Tensor,
-        b: torch.Tensor,
-        a: torch.Tensor,
-        core_attn_out: torch.Tensor,
+        hidden_states: torch.Tensor,
+        output: torch.Tensor,
     ):
-        """
-        Core attention computation (called by custom op).
-        """
         forward_context = get_forward_context()
         attn_metadata: AttentionMetadata = forward_context.attn_metadata

@@ -522,11 +471,18 @@ def _forward_core(
         num_actual_tokens = attn_metadata.num_actual_tokens
         num_accepted_tokens = attn_metadata.num_accepted_tokens

-        mixed_qkv = mixed_qkv[:num_actual_tokens]
-        b = b[:num_actual_tokens]
-        a = a[:num_actual_tokens]
+        # 1. Input projection
+        projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states[:num_actual_tokens])
+        projected_states_ba, _ = self.in_proj_ba(hidden_states[:num_actual_tokens])
+        query, key, value, z, b, a = self.fix_query_key_value_ordering(
+            projected_states_qkvz, projected_states_ba
+        )
+        query, key, value = map(
+            lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value)
+        )
+        mixed_qkv = torch.cat((query, key, value), dim=-1)

-        # 1. Convolution sequence transformation
+        # 2. Convolution sequence transformation
         conv_weights = self.conv1d.weight.view(
             self.conv1d.weight.size(0), self.conv1d.weight.size(2)
         )
@@ -542,7 +498,7 @@ def _forward_core(
             mixed_qkv_spec = None
             mixed_qkv_non_spec = mixed_qkv

-        # 1.1: Process the multi-query part
+        # 2.1: process the multi-query part
         if spec_sequence_masks is not None:
             mixed_qkv_spec = causal_conv1d_update(
                 mixed_qkv_spec,
@@ -559,7 +515,7 @@ def _forward_core(
                 validate_data=False,
             )

-        # 1.2: Process the remaining part
+        # 2.2: process the remaining part
         if attn_metadata.num_prefills > 0:
             mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
             # - "cache_indices" updates the conv_state cache in positions
@@ -617,9 +573,9 @@ def _forward_core(
             g_non_spec = g
             beta_non_spec = beta

-        # 2. Recurrent attention
+        # 3. Recurrent attention

-        # 2.1: Process the multi-query part
+        # 3.1: process the multi-query part
         if spec_sequence_masks is not None:
             core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
                 q=query_spec,
@@ -637,7 +593,7 @@ def _forward_core(
         else:
             core_attn_out_spec, last_recurrent_state = None, None

-        # 2.2: Process the remaining part
+        # 3.2: process the remaining part
         if attn_metadata.num_prefills > 0:
             initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
             initial_state[~has_initial_state, ...] = 0
@@ -680,20 +636,30 @@ def _forward_core(
         else:
             core_attn_out_non_spec, last_recurrent_state = None, None

-        # 3. Merge core attention output
+        # Merge core attention output
         if spec_sequence_masks is not None and core_attn_out_non_spec is not None:
-            merged_out = torch.empty(
+            core_attn_out = torch.empty(
                 (1, num_actual_tokens, *core_attn_out_spec.shape[2:]),
                 dtype=core_attn_out_non_spec.dtype,
                 device=core_attn_out_non_spec.device,
             )
-            merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
-            merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
-            core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
+            core_attn_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
+            core_attn_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
+
         elif spec_sequence_masks is not None:
-            core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
+            core_attn_out = core_attn_out_spec
         else:
-            core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
+            core_attn_out = core_attn_out_non_spec
+
+        z_shape_og = z.shape
+        # reshape input data into 2D tensor
+        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
+        z = z.reshape(-1, z.shape[-1])
+        core_attn_out = self.norm(core_attn_out, z)
+        core_attn_out = core_attn_out.reshape(z_shape_og)
+        core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
+
+        output[:num_actual_tokens], _ = self.out_proj(core_attn_out)


 class Qwen3NextAttention(nn.Module):
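The merge step above writes the multi-query ("spec") and remaining ("non-spec") results back into token order with index_copy_ along the token dimension before the gated norm and output projection. A minimal standalone sketch of that merge pattern, with made-up shapes and index tensors standing in for spec_token_indx / non_spec_token_indx:

import torch

# Hypothetical sizes; in the layer these come from attn_metadata.
num_tokens, num_heads, head_dim = 6, 2, 4
spec_idx = torch.tensor([1, 4])            # stands in for spec_token_indx
non_spec_idx = torch.tensor([0, 2, 3, 5])  # stands in for non_spec_token_indx

out_spec = torch.randn(1, spec_idx.numel(), num_heads, head_dim)
out_non_spec = torch.randn(1, non_spec_idx.numel(), num_heads, head_dim)

# Allocate the merged buffer and scatter each partial result back to its
# original token positions along dim=1 (the token dimension).
core_attn_out = torch.empty(1, num_tokens, num_heads, head_dim)
core_attn_out.index_copy_(1, spec_idx, out_spec)
core_attn_out.index_copy_(1, non_spec_idx, out_non_spec)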
@@ -1304,44 +1270,29 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         return self.model.get_expert_mapping()


-def gdn_attention_core(
-    mixed_qkv: torch.Tensor,
-    b: torch.Tensor,
-    a: torch.Tensor,
-    core_attn_out: torch.Tensor,
+def gdn_attention(
+    hidden_states: torch.Tensor,
+    output: torch.Tensor,
     layer_name: str,
 ) -> None:
-    """
-    Custom op for the core attention computation.
-    Only handles the convolution + recurrent attention part.
-    Input/output projections are handled outside this op.
-    """
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
-    self._forward_core(
-        mixed_qkv=mixed_qkv,
-        b=b,
-        a=a,
-        core_attn_out=core_attn_out,
-    )
+    self._forward(hidden_states=hidden_states, output=output)


-def gdn_attention_core_fake(
-    mixed_qkv: torch.Tensor,
-    b: torch.Tensor,
-    a: torch.Tensor,
-    core_attn_out: torch.Tensor,
+def gdn_attention_fake(
+    hidden_states: torch.Tensor,
+    output: torch.Tensor,
     layer_name: str,
 ) -> None:
-    """Fake implementation for torch.compile."""
     return


 direct_register_custom_op(
-    op_name="gdn_attention_core",
-    op_func=gdn_attention_core,
-    mutates_args=["core_attn_out"],
-    fake_impl=gdn_attention_core_fake,
+    op_name="gdn_attention",
+    op_func=gdn_attention,
+    mutates_args=["output"],
+    fake_impl=gdn_attention_fake,
 )

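The final hunk renames the custom op and widens it to cover the whole layer body: gdn_attention mutates the preallocated output buffer in place and returns nothing, so its fake implementation for torch.compile is a no-op. A short sketch of the same register-with-fake-impl pattern, using the stock torch.library.custom_op decorator rather than vLLM's direct_register_custom_op helper; the op name, shapes, and the copy_ body below are illustrative only:

import torch

@torch.library.custom_op("demo::gdn_attention", mutates_args=("output",))
def gdn_attention_demo(hidden_states: torch.Tensor, output: torch.Tensor, layer_name: str) -> None:
    # Stand-in for the real conv + gated-delta-rule path: write the result
    # for the actual tokens into the preallocated output buffer.
    output[: hidden_states.size(0)].copy_(hidden_states)

@gdn_attention_demo.register_fake
def _(hidden_states: torch.Tensor, output: torch.Tensor, layer_name: str) -> None:
    # Nothing to infer: the op allocates nothing and returns nothing.
    return

hidden = torch.randn(4, 8)
out = torch.zeros(8, 8)
torch.ops.demo.gdn_attention(hidden, out, "model.layers.0.linear_attn")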