@@ -551,10 +551,7 @@ def _forward(
             mixed_qkv_non_spec
         )
 
-        beta = b.sigmoid()
-        # g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
-        g = fused_gdn_gating(self.A_log, a, self.dt_bias)
-        g, beta = map(lambda x: rearrange(x, "l d -> 1 l d"), (g, beta))
+        g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
 
         if spec_sequence_masks is not None:
             if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
@@ -1289,12 +1286,13 @@ def gdn_attention_fake(
 )
 
 
-# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
 @triton.jit
 def fused_gdn_gating_kernel(
     g,
+    beta_output,
     A_log,
     a,
+    b,
     dt_bias,
     seq_len,
     NUM_HEADS: tl.constexpr,
@@ -1308,6 +1306,7 @@ def fused_gdn_gating_kernel(
     mask = head_off < NUM_HEADS
     blk_A_log = tl.load(A_log + head_off, mask=mask)
     blk_a = tl.load(a + off, mask=mask)
+    blk_b = tl.load(b + off, mask=mask)
     blk_bias = tl.load(dt_bias + head_off, mask=mask)
     # If the model is loaded in fp16, without the .float() here, A might be -inf
     x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)
@@ -1316,20 +1315,42 @@ def fused_gdn_gating_kernel(
     )
     blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
     tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
+    # compute beta_output = sigmoid(b)
+    blk_beta = 1.0 / (1.0 + tl.exp(-blk_b.to(tl.float32)))
+    tl.store(beta_output + off, blk_beta.to(beta_output.dtype.element_ty), mask=mask)
 
 
 def fused_gdn_gating(
     A_log: torch.Tensor,
     a: torch.Tensor,
+    b: torch.Tensor,
     dt_bias: torch.Tensor,
     beta: float = 1.0,
     threshold: float = 20.0,
-) -> torch.Tensor:
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Fused computation of g and beta for Gated Delta Net.
+    g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
+    beta_output = b.sigmoid()
+    TODO maybe use torch.compile to replace this triton kernel
+    """
     batch, num_heads = a.shape
     seq_len = 1
     grid = (batch, seq_len, triton.cdiv(num_heads, 8))
-    g = torch.empty_like(a, dtype=torch.float32)
+    g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
+    beta_output = torch.empty(1, batch, num_heads, dtype=torch.float32, device=b.device)
     fused_gdn_gating_kernel[grid](
-        g, A_log, a, dt_bias, seq_len, num_heads, beta, threshold, 8, num_warps=1
+        g,
+        beta_output,
+        A_log,
+        a,
+        b,
+        dt_bias,
+        seq_len,
+        num_heads,
+        beta,
+        threshold,
+        8,
+        num_warps=1,
     )
-    return g
+    return g, beta_output
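
For readers checking the fused kernel against eager PyTorch, a minimal unfused reference is sketched below. It is not part of the patch: the function name gdn_gating_reference is made up for illustration, and the assumed shapes (A_log and dt_bias of shape (num_heads,), a and b of shape (batch, num_heads)) follow from the wrapper's `batch, num_heads = a.shape` and the (1, batch, num_heads) output buffers.

import torch
import torch.nn.functional as F

def gdn_gating_reference(
    A_log: torch.Tensor,    # (num_heads,)
    a: torch.Tensor,        # (batch, num_heads)
    b: torch.Tensor,        # (batch, num_heads)
    dt_bias: torch.Tensor,  # (num_heads,)
) -> tuple[torch.Tensor, torch.Tensor]:
    # g = -exp(A_log) * softplus(a + dt_bias), computed in fp32 as in the kernel
    g = -A_log.float().exp() * F.softplus(a.float() + dt_bias.float())
    # beta = sigmoid(b), also in fp32
    beta = b.float().sigmoid()
    # The fused kernel writes into (1, batch, num_heads) buffers, so add the
    # leading singleton sequence dimension here as well
    return g.unsqueeze(0), beta.unsqueeze(0)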