 
 
 def kda_attention(
-    hidden_states: torch.Tensor,
-    output: torch.Tensor,
+    q_proj_states: torch.Tensor,
+    k_proj_states: torch.Tensor,
+    v_proj_states: torch.Tensor,
+    g1: torch.Tensor,
+    g2: torch.Tensor,
+    beta: torch.Tensor,
+    core_attn_out: torch.Tensor,
     layer_name: str,
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
-    self._forward(hidden_states=hidden_states, output=output)
+    self._forward(
+        q_proj_states=q_proj_states,
+        k_proj_states=k_proj_states,
+        v_proj_states=v_proj_states,
+        g1=g1,
+        g2=g2,
+        beta=beta,
+        core_attn_out=core_attn_out,
+    )
 
 
 def kda_attention_fake(
-    hidden_states: torch.Tensor,
-    output: torch.Tensor,
+    q_proj_states: torch.Tensor,
+    k_proj_states: torch.Tensor,
+    v_proj_states: torch.Tensor,
+    g1: torch.Tensor,
+    g2: torch.Tensor,
+    beta: torch.Tensor,
+    core_attn_out: torch.Tensor,
     layer_name: str,
 ) -> None:
     return
@@ -60,7 +78,7 @@ def kda_attention_fake(
 direct_register_custom_op(
     op_name="kda_attention",
     op_func=kda_attention,
-    mutates_args=["output"],
+    mutates_args=["core_attn_out"],
     fake_impl=kda_attention_fake,
 )
 
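The registration above declares core_attn_out as the mutated argument, so the fake implementation can be a bare return: every output lives in a buffer the caller already allocated. A minimal stand-alone sketch of the same mutate-in-place pattern, written against the plain torch.library API rather than vLLM's direct_register_custom_op helper; the op name, shapes, and the stand-in kernel body are illustrative only.

import torch
from torch.library import custom_op

# Illustrative op: writes its result into a caller-provided buffer, mirroring
# the mutates_args=["core_attn_out"] contract above.
@custom_op("demo::kda_attention", mutates_args=("core_attn_out",))
def kda_attention_demo(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, core_attn_out: torch.Tensor
) -> None:
    core_attn_out.copy_(q + k + v)  # stand-in for the real kernel

@kda_attention_demo.register_fake
def _(q, k, v, core_attn_out) -> None:
    return  # nothing to allocate: the output buffer belongs to the caller

q, k, v = (torch.randn(4, 8) for _ in range(3))
out = torch.empty(4, 8)
torch.ops.demo.kda_attention(q, k, v, out)  # fills `out` in place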
@@ -241,37 +259,56 @@ def forward(
         hidden_states: torch.Tensor,
         positions: torch.Tensor,
         output: torch.Tensor,
-    ) -> None:
-        return torch.ops.vllm.kda_attention(
-            hidden_states,
-            output,
+    ) -> torch.Tensor:
+        num_tokens = hidden_states.size(0)
+        q = self.q_proj(hidden_states)[0]
+        k = self.k_proj(hidden_states)[0]
+        v = self.v_proj(hidden_states)[0]
+
+        beta = self.b_proj(hidden_states)[0].float().sigmoid()
+        g1 = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
+        g1 = fused_kda_gate(g1, self.A_log, self.head_dim, g_bias=self.dt_bias)
+        beta = beta.unsqueeze(0)
+        g1 = g1.unsqueeze(0)
+
+        g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
+        g2 = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
+
+        core_attn_out = torch.zeros(
+            (1, num_tokens, self.local_num_heads, self.head_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+        torch.ops.vllm.kda_attention(
+            q,
+            k,
+            v,
+            g1,
+            g2,
+            beta,
+            core_attn_out,
             self.prefix,
         )
+        core_attn_out = self.o_norm(core_attn_out, g2)
+        core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
+
+        return self.o_proj(core_attn_out)[0]
 
     def _forward(
         self,
-        hidden_states: torch.Tensor,
-        output: torch.Tensor,
+        q_proj_states: torch.Tensor,
+        k_proj_states: torch.Tensor,
+        v_proj_states: torch.Tensor,
+        g1: torch.Tensor,
+        g2: torch.Tensor,
+        beta: torch.Tensor,
+        core_attn_out: torch.Tensor,
     ) -> None:
         forward_context = get_forward_context()
         attn_metadata: AttentionMetadata = forward_context.attn_metadata
 
         if attn_metadata is None:
-            # V1 profile run
-            # Mimic the memory allocation in the real run
-            q = torch.empty_like(hidden_states)
-            k = torch.empty_like(hidden_states)
-            v = torch.empty_like(hidden_states)
-            g = hidden_states.new_empty(
-                hidden_states.size(0),
-                self.local_num_heads,
-                self.head_dim,
-                dtype=torch.float32,
-            )
-            beta = torch.empty(
-                hidden_states.size(0), self.local_num_heads, dtype=torch.float32
-            )
-            core_attn_out = torch.empty_like(hidden_states)
+            # V1 profile run
             return
 
         assert isinstance(attn_metadata, dict)
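With the projections, gates, beta, and the core_attn_out buffer now produced by the caller in forward(), the profile run (attn_metadata is None) has nothing left to mimic and simply returns. A hedged illustration of that contract with made-up shapes; because the caller zero-initializes the buffer, the early return still leaves a well-defined result.

import torch
from typing import Optional

def forward_like(core_attn_out: torch.Tensor, attn_metadata: Optional[dict]) -> None:
    if attn_metadata is None:
        # Profile run: nothing to compute, the buffer was zero-filled by the caller.
        return
    core_attn_out[:] = 1.0  # stand-in for the real attention output

buf = torch.zeros(1, 6, 2, 8)    # (1, num_tokens, num_heads, head_dim), as allocated above
forward_like(buf, None)          # buffer stays all zeros
forward_like(buf, {"layer": 0})  # buffer written in place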
@@ -288,10 +325,6 @@ def _forward(
         conv_state_k = conv_state_k.transpose(-1, -2)
         conv_state_v = conv_state_v.transpose(-1, -2)
 
-        q_proj_states = self.q_proj(hidden_states)[0]
-        k_proj_states = self.k_proj(hidden_states)[0]
-        v_proj_states = self.v_proj(hidden_states)[0]
-
         q_conv_weights = self.q_conv1d.weight.view(
             self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)
         )
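The conv weights are flattened before being handed to the causal-conv kernel. A small sketch of that reshape, assuming (as the depthwise, groups == channels case implies) a Conv1d weight of shape (channels, 1, kernel_size); the layer here is a stand-in, not the module used in the layer.

import torch
import torch.nn as nn

conv = nn.Conv1d(16, 16, kernel_size=4, groups=16, bias=False)  # depthwise conv
print(conv.weight.shape)  # torch.Size([16, 1, 4])

# Same view as q_conv_weights above: drop the singleton dim to (channels, kernel_size).
w = conv.weight.view(conv.weight.size(0), conv.weight.size(2))
print(w.shape)  # torch.Size([16, 4])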
@@ -374,14 +407,6 @@ def _forward(
             lambda x: rearrange(x, "n (h d) -> 1 n h d", d=self.head_dim), (q, k, v)
         )
 
-        beta = self.b_proj(hidden_states)[0].float().sigmoid()
-
-        g = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
-        g = fused_kda_gate(g, self.A_log, self.head_dim, g_bias=self.dt_bias)
-
-        beta = beta.unsqueeze(0)
-        g = g.unsqueeze(0)
-
         if attn_metadata.num_prefills > 0:
             zero_idx = non_spec_state_indices_tensor[~has_initial_state]
             recurrent_state[zero_idx] = 0
@@ -393,7 +418,7 @@ def _forward(
                 q=q,
                 k=k,
                 v=v,
-                g=g,
+                g=g1,
                 beta=beta,
                 initial_state=initial_state,
                 output_final_state=True,
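In the prefill branch above, state slots belonging to sequences without an initial state are zeroed before the chunked kernel (now fed g1 instead of the old locally computed g) writes into them. A small illustration of that masked reset, with made-up shapes.

import torch

recurrent_state = torch.randn(8, 2, 4, 4)      # one state slot per cached sequence
state_indices = torch.tensor([5, 2, 7])        # slot used by each sequence in the batch
has_initial_state = torch.tensor([True, False, False])

zero_idx = state_indices[~has_initial_state]   # tensor([2, 7])
recurrent_state[zero_idx] = 0                  # start those sequences from a clean state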
@@ -410,17 +435,12 @@ def _forward(
                 q=q,
                 k=k,
                 v=v,
-                g=g,
+                g=g1,
                 beta=beta,
                 initial_state=recurrent_state,
                 use_qk_l2norm_in_kernel=True,
                 cu_seqlens=non_spec_query_start_loc,
                 ssm_state_indices=non_spec_state_indices_tensor,
             )
-
-        g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
-        g = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
-        core_attn_out = self.o_norm(core_attn_out_non_spec, g)
-        core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
-
-        output[:] = self.o_proj(core_attn_out)[0]
+        assert core_attn_out_non_spec.shape == core_attn_out.shape
+        core_attn_out[:] = core_attn_out_non_spec
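o_norm and o_proj now run in forward(), so _forward only copies the kernel result into the caller-owned buffer. A hedged sketch of the layouts this relies on, with illustrative dims: the result keeps a leading batch dim of 1, is written in place so the mutation is visible through the custom op, and is later folded back to (num_tokens, hidden) with einops.

import torch
from einops import rearrange

num_tokens, num_heads, head_dim = 6, 2, 4
core_attn_out = torch.zeros(1, num_tokens, num_heads, head_dim)  # caller-owned buffer
kernel_out = torch.randn(1, num_tokens, num_heads, head_dim)     # e.g. core_attn_out_non_spec

assert kernel_out.shape == core_attn_out.shape
core_attn_out[:] = kernel_out                    # in-place write, as in the hunk above

flat = rearrange(core_attn_out, "1 n h d -> n (h d)")
print(flat.shape)  # torch.Size([6, 8])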