 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fla.ops import (
-    RMSNormGated,
     chunk_gated_delta_rule,
     fused_recurrent_gated_delta_rule,
 )
 from vllm.model_executor.layers.fused_moe import SharedFusedMoE
-from vllm.model_executor.layers.layernorm import GemmaRMSNorm as Qwen3NextRMSNorm
+from vllm.model_executor.layers.layernorm import (
+    GemmaRMSNorm as Qwen3NextRMSNorm,
+)
+from vllm.model_executor.layers.layernorm import RMSNormGated
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
@@ -436,17 +438,66 @@ def forward(
         hidden_states: torch.Tensor,
         output: torch.Tensor,
     ):
-        return torch.ops.vllm.gdn_attention(
-            hidden_states,
-            output,
+        """
+        Forward pass with three parts:
+        1. Input projection
+        2. Core attention (custom op)
+        3. Output projection
+        """
+        num_tokens = hidden_states.size(0)
+
+        # ============================================================
+        # Part 1: Input Projection
+        # ============================================================
+        projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
+        projected_states_ba, _ = self.in_proj_ba(hidden_states)
+        query, key, value, z, b, a = self.fix_query_key_value_ordering(
+            projected_states_qkvz, projected_states_ba
+        )
+        query, key, value = map(
+            lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value)
+        )
+        mixed_qkv = torch.cat((query, key, value), dim=-1)
+
+        # ============================================================
+        # Part 2: Core Attention (Custom Op)
+        # ============================================================
+        core_attn_out = torch.empty(
+            (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+
+        torch.ops.vllm.gdn_attention_core(
+            mixed_qkv,
+            b,
+            a,
+            core_attn_out,
             self.prefix,
         )

-    def _forward(
+        # ============================================================
+        # Part 3: Output Projection
+        # ============================================================
+        z_shape_og = z.shape
+        # Reshape input data into 2D tensor
+        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
+        z = z.reshape(-1, z.shape[-1])
+        core_attn_out = self.norm(core_attn_out, z)
+        core_attn_out = core_attn_out.reshape(z_shape_og)
+        core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
+        output[:num_tokens], _ = self.out_proj(core_attn_out)
+
+    def _forward_core(
         self,
-        hidden_states: torch.Tensor,
-        output: torch.Tensor,
+        mixed_qkv: torch.Tensor,
+        b: torch.Tensor,
+        a: torch.Tensor,
+        core_attn_out: torch.Tensor,
     ):
+        """
+        Core attention computation (called by custom op).
+        """
         forward_context = get_forward_context()
         attn_metadata: AttentionMetadata = forward_context.attn_metadata

@@ -471,18 +522,11 @@ def _forward(
         num_actual_tokens = attn_metadata.num_actual_tokens
         num_accepted_tokens = attn_metadata.num_accepted_tokens

-        # 1. Set up dimensions for reshapes later
-        projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states[:num_actual_tokens])
-        projected_states_ba, _ = self.in_proj_ba(hidden_states[:num_actual_tokens])
-        query, key, value, z, b, a = self.fix_query_key_value_ordering(
-            projected_states_qkvz, projected_states_ba
-        )
-        query, key, value = map(
-            lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value)
-        )
-        mixed_qkv = torch.cat((query, key, value), dim=-1)
+        mixed_qkv = mixed_qkv[:num_actual_tokens]
+        b = b[:num_actual_tokens]
+        a = a[:num_actual_tokens]

-        # 2. Convolution sequence transformation
+        # 1. Convolution sequence transformation
         conv_weights = self.conv1d.weight.view(
             self.conv1d.weight.size(0), self.conv1d.weight.size(2)
         )
@@ -498,7 +542,7 @@ def _forward(
             mixed_qkv_spec = None
             mixed_qkv_non_spec = mixed_qkv

-        # 2.1: process the mutli-query part
+        # 1.1: Process the multi-query part
         if spec_sequence_masks is not None:
             mixed_qkv_spec = causal_conv1d_update(
                 mixed_qkv_spec,
@@ -515,7 +559,7 @@ def _forward(
                 validate_data=False,
             )

-        # 2.2: process the remaining part
+        # 1.2: Process the remaining part
         if attn_metadata.num_prefills > 0:
             mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
             # - "cache_indices" updates the conv_state cache in positions
@@ -573,9 +617,9 @@ def _forward(
             g_non_spec = g
             beta_non_spec = beta

-        # 3. Recurrent attention
+        # 2. Recurrent attention

-        # 3.1: process the mutlti-query part
+        # 2.1: Process the multi-query part
         if spec_sequence_masks is not None:
             core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
                 q=query_spec,
@@ -593,7 +637,7 @@ def _forward(
         else:
             core_attn_out_spec, last_recurrent_state = None, None

-        # 3.2: process the remaining part
+        # 2.2: Process the remaining part
         if attn_metadata.num_prefills > 0:
             initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
             initial_state[~has_initial_state, ...] = 0
@@ -636,30 +680,20 @@ def _forward(
         else:
             core_attn_out_non_spec, last_recurrent_state = None, None

-        # Merge core attention output
+        # 3. Merge core attention output
         if spec_sequence_masks is not None and core_attn_out_non_spec is not None:
-            core_attn_out = torch.empty(
+            merged_out = torch.empty(
                 (1, num_actual_tokens, *core_attn_out_spec.shape[2:]),
                 dtype=core_attn_out_non_spec.dtype,
                 device=core_attn_out_non_spec.device,
             )
-            core_attn_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
-            core_attn_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
-
+            merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
+            merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
+            core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
         elif spec_sequence_masks is not None:
-            core_attn_out = core_attn_out_spec
+            core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
         else:
-            core_attn_out = core_attn_out_non_spec
-
-        z_shape_og = z.shape
-        # reshape input data into 2D tensor
-        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
-        z = z.reshape(-1, z.shape[-1])
-        core_attn_out = self.norm(core_attn_out, z)
-        core_attn_out = core_attn_out.reshape(z_shape_og)
-        core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
-
-        output[:num_actual_tokens], _ = self.out_proj(core_attn_out)
+            core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)


 class Qwen3NextAttention(nn.Module):
@@ -1270,29 +1304,44 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         return self.model.get_expert_mapping()


-def gdn_attention(
-    hidden_states: torch.Tensor,
-    output: torch.Tensor,
+def gdn_attention_core(
+    mixed_qkv: torch.Tensor,
+    b: torch.Tensor,
+    a: torch.Tensor,
+    core_attn_out: torch.Tensor,
     layer_name: str,
 ) -> None:
+    """
+    Custom op for the core attention computation.
+    Only handles the convolution + recurrent attention part.
+    Input/output projections are handled outside this op.
+    """
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
-    self._forward(hidden_states=hidden_states, output=output)
+    self._forward_core(
+        mixed_qkv=mixed_qkv,
+        b=b,
+        a=a,
+        core_attn_out=core_attn_out,
+    )


-def gdn_attention_fake(
-    hidden_states: torch.Tensor,
-    output: torch.Tensor,
+def gdn_attention_core_fake(
+    mixed_qkv: torch.Tensor,
+    b: torch.Tensor,
+    a: torch.Tensor,
+    core_attn_out: torch.Tensor,
     layer_name: str,
 ) -> None:
+    """Fake implementation for torch.compile."""
     return


 direct_register_custom_op(
-    op_name="gdn_attention",
-    op_func=gdn_attention,
-    mutates_args=["output"],
-    fake_impl=gdn_attention_fake,
+    op_name="gdn_attention_core",
+    op_func=gdn_attention_core,
+    mutates_args=["core_attn_out"],
+    fake_impl=gdn_attention_core_fake,
 )


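For readers unfamiliar with the out-variant custom-op pattern this diff relies on (a preallocated output tensor listed in mutates_args, plus a fake impl so torch.compile can trace the graph without running the kernel), here is a minimal self-contained sketch of the same idea. It uses PyTorch's public torch.library.custom_op API rather than vLLM's direct_register_custom_op helper, and the "demo" namespace, op name, and shapes are made up for illustration only.

# Sketch only: illustrates the mutates-output custom-op pattern used above,
# not vLLM's actual registration path.
import torch
from torch.library import custom_op


@custom_op("demo::core_attn", mutates_args=("core_attn_out",))
def core_attn(mixed_qkv: torch.Tensor, core_attn_out: torch.Tensor) -> None:
    # Stand-in for the real convolution + gated delta rule kernels:
    # write the result into the caller-provided buffer instead of returning it.
    core_attn_out.copy_(mixed_qkv[..., : core_attn_out.shape[-1]])


@core_attn.register_fake
def _(mixed_qkv: torch.Tensor, core_attn_out: torch.Tensor) -> None:
    # Fake impl: only shapes/dtypes matter, so torch.compile can trace the op.
    return


def forward(hidden_states: torch.Tensor) -> torch.Tensor:
    # Caller allocates the output, mirroring core_attn_out in the diff.
    out = torch.empty(hidden_states.shape[0], 4, dtype=hidden_states.dtype)
    torch.ops.demo.core_attn(hidden_states, out)
    return out


print(forward(torch.randn(2, 8)).shape)  # torch.Size([2, 4])

The design choice mirrors the diff: because the op mutates a preallocated tensor and returns None, the projections and gated RMSNorm can stay outside the op (and thus inside the compiled region), while only the convolution/recurrent core is hidden behind the opaque custom op.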