vllm/model_executor/models/qwen3_next.py (30 additions, 9 deletions)
@@ -551,10 +551,7 @@ def _forward(
             mixed_qkv_non_spec
         )
 
-        beta = b.sigmoid()
-        # g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
-        g = fused_gdn_gating(self.A_log, a, self.dt_bias)
-        g, beta = map(lambda x: rearrange(x, "l d -> 1 l d"), (g, beta))
+        g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
 
         if spec_sequence_masks is not None:
            if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
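
For context, here is a standalone sketch (not part of the diff) of the unfused sequence that the single fused_gdn_gating call above replaces, reconstructed from the removed lines. The (num_tokens, num_heads) shapes for a and b, and the (num_heads,) shapes for A_log and dt_bias, are assumptions based on the wrapper further down.

import torch
import torch.nn.functional as F
from einops import rearrange


def unfused_gdn_gating_reference(A_log, a, b, dt_bias):
    # beta and g exactly as computed by the removed lines in _forward.
    beta = b.sigmoid()
    g = -A_log.float().exp() * F.softplus(a.float() + dt_bias)
    # The removed rearrange added the leading singleton dim that the fused
    # kernel now produces directly:
    # (num_tokens, num_heads) -> (1, num_tokens, num_heads).
    g, beta = map(lambda x: rearrange(x, "l d -> 1 l d"), (g, beta))
    return g, beta
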
@@ -1296,12 +1293,13 @@ def gdn_attention_fake(
 )
 
 
+# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
 @triton.jit
 def fused_gdn_gating_kernel(
     g,
+    beta_output,
     A_log,
     a,
+    b,
     dt_bias,
     seq_len,
     NUM_HEADS: tl.constexpr,
@@ -1315,6 +1313,7 @@ def fused_gdn_gating_kernel(
     mask = head_off < NUM_HEADS
     blk_A_log = tl.load(A_log + head_off, mask=mask)
     blk_a = tl.load(a + off, mask=mask)
+    blk_b = tl.load(b + off, mask=mask)
     blk_bias = tl.load(dt_bias + head_off, mask=mask)
     # If the model is loaded in fp16, without the .float() here, A might be -inf
     x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)
@@ -1323,20 +1322,42 @@ def fused_gdn_gating_kernel(
     )
     blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
     tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
+    # compute beta_output = sigmoid(b)
+    blk_beta = 1.0 / (1.0 + tl.exp(-blk_b.to(tl.float32)))
+    tl.store(beta_output + off, blk_beta.to(beta_output.dtype.element_ty), mask=mask)
 
 
 def fused_gdn_gating(
     A_log: torch.Tensor,
     a: torch.Tensor,
+    b: torch.Tensor,
     dt_bias: torch.Tensor,
     beta: float = 1.0,
     threshold: float = 20.0,
-) -> torch.Tensor:
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Fused computation of g and beta for Gated Delta Net.
+    g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
+    beta_output = b.sigmoid()
+    TODO maybe use torch.compile to replace this triton kernel
+    """
     batch, num_heads = a.shape
     seq_len = 1
     grid = (batch, seq_len, triton.cdiv(num_heads, 8))
-    g = torch.empty_like(a, dtype=torch.float32)
+    g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
+    beta_output = torch.empty(1, batch, num_heads, dtype=torch.float32, device=b.device)
     fused_gdn_gating_kernel[grid](
-        g, A_log, a, dt_bias, seq_len, num_heads, beta, threshold, 8, num_warps=1
+        g,
+        beta_output,
+        A_log,
+        a,
+        b,
+        dt_bias,
+        seq_len,
+        num_heads,
+        beta,
+        threshold,
+        8,
+        num_warps=1,
     )
-    return g
+    return g, beta_output
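
As a follow-up, a minimal sanity-check sketch (hypothetical, not part of this PR) that compares the fused kernel against the eager formulas quoted in the new docstring. It assumes a CUDA device, made-up tensor sizes, and the module path from the diff header; check_fused_gdn_gating is an illustrative helper name, not an existing test.

import torch
import torch.nn.functional as F

from vllm.model_executor.models.qwen3_next import fused_gdn_gating


def check_fused_gdn_gating(num_tokens: int = 64, num_heads: int = 16) -> None:
    # Hypothetical check: shapes and dtypes are assumptions for illustration.
    device = "cuda"
    A_log = torch.randn(num_heads, device=device)
    dt_bias = torch.randn(num_heads, device=device)
    a = torch.randn(num_tokens, num_heads, device=device, dtype=torch.bfloat16)
    b = torch.randn(num_tokens, num_heads, device=device, dtype=torch.bfloat16)

    g, beta = fused_gdn_gating(A_log, a, b, dt_bias)
    # The wrapper allocates both outputs as (1, batch, num_heads) in float32.
    assert g.shape == beta.shape == (1, num_tokens, num_heads)

    # Eager reference, following the formulas in the fused_gdn_gating docstring.
    ref_g = (-A_log.float().exp() * F.softplus(a.float() + dt_bias)).unsqueeze(0)
    ref_beta = b.float().sigmoid().unsqueeze(0)
    torch.testing.assert_close(g, ref_g, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(beta, ref_beta, rtol=1e-3, atol=1e-3)
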