Commit c18f88c

[Kernel] Fuse computation of g and beta for Gated Delta Net (#28095)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
1 parent 6fd0df8 commit c18f88c
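
For context: before this change the model code computed beta = b.sigmoid() in eager mode, called the Triton kernel only for g, and then rearranged both tensors to add a leading unit dimension. This commit folds the sigmoid into the kernel and allocates both outputs directly in the final (1, batch, num_heads) shape. As a reading aid, here is a minimal eager-mode sketch of the math the fused kernel implements, following the formulas in the kernel docstring below (the function name is hypothetical; this is not part of the commit):

    import torch
    import torch.nn.functional as F

    def gdn_gating_reference(
        A_log: torch.Tensor,    # (num_heads,) log decay rate
        a: torch.Tensor,        # (batch, num_heads) pre-activation for g
        b: torch.Tensor,        # (batch, num_heads) pre-activation for beta
        dt_bias: torch.Tensor,  # (num_heads,) bias added before the softplus
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # g = -exp(A_log) * softplus(a + dt_bias), computed in fp32 for stability
        g = -A_log.float().exp() * F.softplus(a.float() + dt_bias.float())
        # beta = sigmoid(b)
        beta = b.float().sigmoid()
        # Match the fused kernel's output shape: (1, batch, num_heads)
        return g.unsqueeze(0), beta.unsqueeze(0)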

File tree: 1 file changed (+30, -9 lines)

vllm/model_executor/models/qwen3_next.py

Lines changed: 30 additions & 9 deletions
@@ -551,10 +551,7 @@ def _forward(
             mixed_qkv_non_spec
         )
 
-        beta = b.sigmoid()
-        # g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
-        g = fused_gdn_gating(self.A_log, a, self.dt_bias)
-        g, beta = map(lambda x: rearrange(x, "l d -> 1 l d"), (g, beta))
+        g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
 
         if spec_sequence_masks is not None:
             if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
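
The call-site simplification works because the fused kernel now returns both gates already shaped (1, seq_len, num_heads), so the separate sigmoid and the rearrange step disappear. The lifted shape is the same one the old map produced; a tiny check of that equivalence (hypothetical values, not part of the commit):

    import torch
    from einops import rearrange

    x = torch.randn(4, 16)  # (l, d), as in the old call site
    assert torch.equal(rearrange(x, "l d -> 1 l d"), x.unsqueeze(0))
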
@@ -1289,12 +1286,13 @@ def gdn_attention_fake(
 )
 
 
-# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
 @triton.jit
 def fused_gdn_gating_kernel(
     g,
+    beta_output,
     A_log,
     a,
+    b,
     dt_bias,
     seq_len,
     NUM_HEADS: tl.constexpr,
@@ -1308,6 +1306,7 @@ def fused_gdn_gating_kernel(
     mask = head_off < NUM_HEADS
     blk_A_log = tl.load(A_log + head_off, mask=mask)
     blk_a = tl.load(a + off, mask=mask)
+    blk_b = tl.load(b + off, mask=mask)
     blk_bias = tl.load(dt_bias + head_off, mask=mask)
     # If the model is loaded in fp16, without the .float() here, A might be -inf
     x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)
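
Between this hunk and the next (context not shown in the diff), the kernel derives softplus_x from x using the beta and threshold parameters. Assuming it mirrors torch.nn.functional.softplus, the guard reverts to the identity for large inputs to avoid overflow in exp(); a sketch under that assumption:

    import torch

    def softplus_guarded(
        x: torch.Tensor, beta: float = 1.0, threshold: float = 20.0
    ) -> torch.Tensor:
        # For beta * x > threshold, softplus(x) is numerically ~= x, so use x
        # directly rather than risking overflow in exp(beta * x).
        return torch.where(x * beta > threshold, x, torch.log1p(torch.exp(beta * x)) / beta)
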
@@ -1316,20 +1315,42 @@ def fused_gdn_gating_kernel(
     )
     blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
     tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
+    # compute beta_output = sigmoid(b)
+    blk_beta = 1.0 / (1.0 + tl.exp(-blk_b.to(tl.float32)))
+    tl.store(beta_output + off, blk_beta.to(beta_output.dtype.element_ty), mask=mask)
 
 
 def fused_gdn_gating(
     A_log: torch.Tensor,
     a: torch.Tensor,
+    b: torch.Tensor,
     dt_bias: torch.Tensor,
     beta: float = 1.0,
     threshold: float = 20.0,
-) -> torch.Tensor:
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Fused computation of g and beta for Gated Delta Net.
+    g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
+    beta_output = b.sigmoid()
+    TODO maybe use torch.compile to replace this triton kernel
+    """
     batch, num_heads = a.shape
     seq_len = 1
     grid = (batch, seq_len, triton.cdiv(num_heads, 8))
-    g = torch.empty_like(a, dtype=torch.float32)
+    g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
+    beta_output = torch.empty(1, batch, num_heads, dtype=torch.float32, device=b.device)
     fused_gdn_gating_kernel[grid](
-        g, A_log, a, dt_bias, seq_len, num_heads, beta, threshold, 8, num_warps=1
+        g,
+        beta_output,
+        A_log,
+        a,
+        b,
+        dt_bias,
+        seq_len,
+        num_heads,
+        beta,
+        threshold,
+        8,
+        num_warps=1,
     )
-    return g
+    return g, beta_output
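
A quick smoke test for the fused path checks it against the formulas in the docstring (a hypothetical sketch, not part of the commit; assumes a CUDA device and the gdn_gating_reference helper sketched above):

    import torch

    batch, num_heads = 4, 16
    a = torch.randn(batch, num_heads, device="cuda", dtype=torch.float16)
    b = torch.randn(batch, num_heads, device="cuda", dtype=torch.float16)
    A_log = torch.randn(num_heads, device="cuda", dtype=torch.float32)
    dt_bias = torch.randn(num_heads, device="cuda", dtype=torch.float32)

    g, beta = fused_gdn_gating(A_log, a, b, dt_bias)
    g_ref, beta_ref = gdn_gating_reference(A_log, a, b, dt_bias)

    assert g.shape == beta.shape == (1, batch, num_heads)
    torch.testing.assert_close(g, g_ref, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(beta, beta_ref, rtol=1e-3, atol=1e-3)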
