diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py
index f452ba871582..34b62125151c 100644
--- a/vllm/model_executor/models/qwen3_next.py
+++ b/vllm/model_executor/models/qwen3_next.py
@@ -551,10 +551,7 @@ def _forward(
             mixed_qkv_non_spec
         )
 
-        beta = b.sigmoid()
-        # g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
-        g = fused_gdn_gating(self.A_log, a, self.dt_bias)
-        g, beta = map(lambda x: rearrange(x, "l d -> 1 l d"), (g, beta))
+        g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
 
         if spec_sequence_masks is not None:
             if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
@@ -1296,12 +1293,13 @@ def gdn_attention_fake(
 )
 
 
-# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
 @triton.jit
 def fused_gdn_gating_kernel(
     g,
+    beta_output,
     A_log,
     a,
+    b,
     dt_bias,
     seq_len,
     NUM_HEADS: tl.constexpr,
@@ -1315,6 +1313,7 @@ def fused_gdn_gating_kernel(
     mask = head_off < NUM_HEADS
     blk_A_log = tl.load(A_log + head_off, mask=mask)
     blk_a = tl.load(a + off, mask=mask)
+    blk_b = tl.load(b + off, mask=mask)
     blk_bias = tl.load(dt_bias + head_off, mask=mask)
     # If the model is loaded in fp16, without the .float() here, A might be -inf
     x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)
@@ -1323,20 +1322,42 @@ def fused_gdn_gating_kernel(
     )
     blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
     tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
+    # compute beta_output = sigmoid(b)
+    blk_beta = 1.0 / (1.0 + tl.exp(-blk_b.to(tl.float32)))
+    tl.store(beta_output + off, blk_beta.to(beta_output.dtype.element_ty), mask=mask)
 
 
 def fused_gdn_gating(
     A_log: torch.Tensor,
     a: torch.Tensor,
+    b: torch.Tensor,
     dt_bias: torch.Tensor,
     beta: float = 1.0,
     threshold: float = 20.0,
-) -> torch.Tensor:
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Fused computation of g and beta for Gated Delta Net.
+    g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
+    beta_output = b.sigmoid()
+    TODO maybe use torch.compile to replace this triton kernel
+    """
     batch, num_heads = a.shape
     seq_len = 1
     grid = (batch, seq_len, triton.cdiv(num_heads, 8))
-    g = torch.empty_like(a, dtype=torch.float32)
+    g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
+    beta_output = torch.empty(1, batch, num_heads, dtype=torch.float32, device=b.device)
     fused_gdn_gating_kernel[grid](
-        g, A_log, a, dt_bias, seq_len, num_heads, beta, threshold, 8, num_warps=1
+        g,
+        beta_output,
+        A_log,
+        a,
+        b,
+        dt_bias,
+        seq_len,
+        num_heads,
+        beta,
+        threshold,
+        8,
+        num_warps=1,
     )
-    return g
+    return g, beta_output
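
For context, a minimal eager-mode sketch (not part of the patch) of what the fused kernel now returns, following the formulas in the new docstring. The helper name `gdn_gating_reference` and the shape assumptions ((batch, num_heads) for `a`/`b`, (num_heads,) for `A_log`/`dt_bias`) are illustrative only:

```python
import torch
import torch.nn.functional as F


def gdn_gating_reference(
    A_log: torch.Tensor,    # (num_heads,)
    a: torch.Tensor,        # (batch, num_heads)
    b: torch.Tensor,        # (batch, num_heads)
    dt_bias: torch.Tensor,  # (num_heads,)
) -> tuple[torch.Tensor, torch.Tensor]:
    # g = -exp(A_log) * softplus(a + dt_bias), computed in fp32 as in the kernel
    g = -A_log.float().exp() * F.softplus(a.float() + dt_bias.float())
    # beta = sigmoid(b), also in fp32
    beta = b.float().sigmoid()
    # The fused kernel returns tensors shaped (1, batch, num_heads)
    return g.unsqueeze(0), beta.unsqueeze(0)


# Hypothetical numerical check against the Triton kernel (needs a CUDA device):
#   g_ref, beta_ref = gdn_gating_reference(A_log, a, b, dt_bias)
#   g_fused, beta_fused = fused_gdn_gating(A_log, a, b, dt_bias)
#   torch.testing.assert_close(g_fused, g_ref)
#   torch.testing.assert_close(beta_fused, beta_ref)
```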