Commit 3d3da84

Revert "[PERF] Decouple projections from GDN custom op (vllm-project#27512)"
This reverts commit 5fd8f02.
1 parent 2d977a7 commit 3d3da84

File tree

3 files changed: +53 -204 lines changed

vllm/config/compilation.py

Lines changed: 1 addition & 1 deletion

@@ -462,7 +462,7 @@ class CompilationConfig:
         "vllm::short_conv",
         "vllm::linear_attention",
         "vllm::plamo2_mamba_mixer",
-        "vllm::gdn_attention_core",
+        "vllm::gdn_attention",
         "vllm::kda_attention",
         "vllm::sparse_attn_indexer",
     ]
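
For context, the splitting-ops entry is expected to mirror the name under which the GDN custom op is registered in qwen3_next.py; this revert switches both back to "vllm::gdn_attention" together. A minimal consistency check, as a sketch (the helper below is illustrative and not part of vLLM):

import torch

def check_splitting_ops_registered(splitting_ops: list[str]) -> None:
    """Sanity-check that each 'namespace::op' entry resolves to a registered custom op."""
    for qualified_name in splitting_ops:
        namespace, op_name = qualified_name.split("::")
        # torch.ops.<namespace>.<op_name> only resolves once the op has been registered.
        if not hasattr(getattr(torch.ops, namespace), op_name):
            raise ValueError(f"{qualified_name} is listed as a splitting op but is not registered")

# e.g. check_splitting_ops_registered(["vllm::gdn_attention"]) after the model module has been imported.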

vllm/model_executor/layers/layernorm.py

Lines changed: 0 additions & 102 deletions

@@ -12,7 +12,6 @@
     rms_norm_batch_invariant,
     vllm_is_batch_invariant,
 )
-from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn
 from vllm.platforms import current_platform
 from vllm.utils.torch_utils import direct_register_custom_op
 
@@ -370,107 +369,6 @@ def forward_cuda(
         return self.forward_native(x, residual)
 
 
-@CustomOp.register("rms_norm_gated")
-class RMSNormGated(CustomOp):
-    """RMS Normalization with optional gating.
-
-    This is a native PyTorch implementation that supports:
-    - Standard RMS normalization
-    - Group RMS normalization
-    - Optional gating with SiLU activation
-    """
-
-    def __init__(
-        self,
-        hidden_size: int,
-        eps: float = 1e-5,
-        group_size: int | None = None,
-        norm_before_gate: bool = False,
-        device: torch.device | None = None,
-        dtype: torch.dtype | None = None,
-    ):
-        """Initialize RMSNormGated.
-
-        Args:
-            hidden_size: Size of the hidden dimension
-            eps: Epsilon for numerical stability
-            group_size: If not None, do GroupNorm with each group
-                having group_size elements.
-                group_size=None is equivalent to group_size=hidden_size
-                (i.e. there's only 1 group).
-            norm_before_gate: If True and z is provided: out = norm(x) * silu(z)
-                If False and z is provided: out = norm(x * silu(z))
-            device: Device to create parameters on
-            dtype: Data type for parameters
-        """
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
-        self.register_parameter("bias", None)
-        self.group_size = group_size
-        self.norm_before_gate = norm_before_gate
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        torch.nn.init.ones_(self.weight)
-
-    def forward_native(
-        self, x: torch.Tensor, z: torch.Tensor | None = None
-    ) -> torch.Tensor:
-        """
-        Native PyTorch implementation of RMS normalization with gating.
-
-        Args:
-            x: Input tensor
-            z: Optional gating tensor
-
-        Returns:
-            Normalized (and optionally gated) tensor
-
-            If z is not None:
-            - norm_before_gate=True: out = norm(x) * silu(z)
-            - norm_before_gate=False: out = norm(x * silu(z))
-        """
-        # Apply gating before normalization if needed
-        if z is not None and not self.norm_before_gate:
-            x = x * F.silu(z)
-
-        # RMS Normalization
-        if self.group_size is None:
-            # Standard RMS norm across the last dimension
-            variance = x.pow(2).mean(dim=-1, keepdim=True)
-            x_normed = x * torch.rsqrt(variance + self.eps)
-            out = x_normed * self.weight
-        else:
-            # Group RMS norm
-            from einops import rearrange
-
-            x_group = rearrange(x, "... (g d) -> ... g d", d=self.group_size)
-            variance = x_group.pow(2).mean(dim=-1, keepdim=True)
-            x_normed = x_group * torch.rsqrt(variance + self.eps)
-            out = rearrange(x_normed, "... g d -> ... (g d)") * self.weight
-
-        # Apply gating after normalization if needed
-        if z is not None and self.norm_before_gate:
-            out = out * F.silu(z)
-
-        return out
-
-    def forward_cuda(
-        self, x: torch.Tensor, z: torch.Tensor | None = None
-    ) -> torch.Tensor:
-        return rmsnorm_fn(
-            x,
-            self.weight,
-            self.bias,
-            z=z,
-            eps=self.eps,
-            group_size=self.group_size,
-            norm_before_gate=self.norm_before_gate,
-        )
-
-
 class LayerNorm(nn.Module):
     """
     Layer Normalization.
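
For reference, the deleted forward_native path is an RMS norm with optional SiLU gating. The sketch below reproduces the same math standalone, assuming plain PyTorch and the ungrouped case; the function name is illustrative, not part of vLLM:

import torch
import torch.nn.functional as F

def gated_rms_norm(
    x: torch.Tensor,
    weight: torch.Tensor,
    z: torch.Tensor | None = None,
    eps: float = 1e-5,
    norm_before_gate: bool = False,
) -> torch.Tensor:
    """out = norm(x) * silu(z) if norm_before_gate else norm(x * silu(z))."""
    if z is not None and not norm_before_gate:
        x = x * F.silu(z)  # gate the input before normalizing
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    out = x * torch.rsqrt(variance + eps) * weight
    if z is not None and norm_before_gate:
        out = out * F.silu(z)  # gate the normalized output
    return out

# Example: 4 tokens, hidden size 8, gating applied after normalization.
x, z, w = torch.randn(4, 8), torch.randn(4, 8), torch.ones(8)
y = gated_rms_norm(x, w, z, norm_before_gate=True)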

vllm/model_executor/models/qwen3_next.py

Lines changed: 52 additions & 101 deletions

@@ -30,14 +30,12 @@
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fla.ops import (
+    RMSNormGated,
     chunk_gated_delta_rule,
     fused_recurrent_gated_delta_rule,
 )
 from vllm.model_executor.layers.fused_moe import SharedFusedMoE
-from vllm.model_executor.layers.layernorm import (
-    GemmaRMSNorm as Qwen3NextRMSNorm,
-)
-from vllm.model_executor.layers.layernorm import RMSNormGated
+from vllm.model_executor.layers.layernorm import GemmaRMSNorm as Qwen3NextRMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
@@ -438,66 +436,17 @@ def forward(
         hidden_states: torch.Tensor,
         output: torch.Tensor,
     ):
-        """
-        Forward pass with three parts:
-        1. Input projection
-        2. Core attention (custom op)
-        3. Output projection
-        """
-        num_tokens = hidden_states.size(0)
-
-        # ============================================================
-        # Part 1: Input Projection
-        # ============================================================
-        projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
-        projected_states_ba, _ = self.in_proj_ba(hidden_states)
-        query, key, value, z, b, a = self.fix_query_key_value_ordering(
-            projected_states_qkvz, projected_states_ba
-        )
-        query, key, value = map(
-            lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value)
-        )
-        mixed_qkv = torch.cat((query, key, value), dim=-1)
-
-        # ============================================================
-        # Part 2: Core Attention (Custom Op)
-        # ============================================================
-        core_attn_out = torch.zeros(
-            (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
-            dtype=hidden_states.dtype,
-            device=hidden_states.device,
-        )
-
-        torch.ops.vllm.gdn_attention_core(
-            mixed_qkv,
-            b,
-            a,
-            core_attn_out,
+        return torch.ops.vllm.gdn_attention(
+            hidden_states,
+            output,
             self.prefix,
         )
 
-        # ============================================================
-        # Part 3: Output Projection
-        # ============================================================
-        z_shape_og = z.shape
-        # Reshape input data into 2D tensor
-        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
-        z = z.reshape(-1, z.shape[-1])
-        core_attn_out = self.norm(core_attn_out, z)
-        core_attn_out = core_attn_out.reshape(z_shape_og)
-        core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
-        output[:num_tokens], _ = self.out_proj(core_attn_out)
-
-    def _forward_core(
+    def _forward(
         self,
-        mixed_qkv: torch.Tensor,
-        b: torch.Tensor,
-        a: torch.Tensor,
-        core_attn_out: torch.Tensor,
+        hidden_states: torch.Tensor,
+        output: torch.Tensor,
     ):
-        """
-        Core attention computation (called by custom op).
-        """
         forward_context = get_forward_context()
         attn_metadata: AttentionMetadata = forward_context.attn_metadata
 
@@ -522,11 +471,18 @@ def _forward_core(
         num_actual_tokens = attn_metadata.num_actual_tokens
         num_accepted_tokens = attn_metadata.num_accepted_tokens
 
-        mixed_qkv = mixed_qkv[:num_actual_tokens]
-        b = b[:num_actual_tokens]
-        a = a[:num_actual_tokens]
+        # 1. Set up dimensions for reshapes later
+        projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states[:num_actual_tokens])
+        projected_states_ba, _ = self.in_proj_ba(hidden_states[:num_actual_tokens])
+        query, key, value, z, b, a = self.fix_query_key_value_ordering(
+            projected_states_qkvz, projected_states_ba
+        )
+        query, key, value = map(
+            lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value)
+        )
+        mixed_qkv = torch.cat((query, key, value), dim=-1)
 
-        # 1. Convolution sequence transformation
+        # 2. Convolution sequence transformation
         conv_weights = self.conv1d.weight.view(
             self.conv1d.weight.size(0), self.conv1d.weight.size(2)
         )
@@ -542,7 +498,7 @@ def _forward_core(
         mixed_qkv_spec = None
         mixed_qkv_non_spec = mixed_qkv
 
-        # 1.1: Process the multi-query part
+        # 2.1: process the multi-query part
         if spec_sequence_masks is not None:
             mixed_qkv_spec = causal_conv1d_update(
                 mixed_qkv_spec,
@@ -559,7 +515,7 @@ def _forward_core(
                 validate_data=False,
             )
 
-        # 1.2: Process the remaining part
+        # 2.2: process the remaining part
         if attn_metadata.num_prefills > 0:
             mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
             # - "cache_indices" updates the conv_state cache in positions
@@ -617,9 +573,9 @@ def _forward_core(
             g_non_spec = g
             beta_non_spec = beta
 
-        # 2. Recurrent attention
+        # 3. Recurrent attention
 
-        # 2.1: Process the multi-query part
+        # 3.1: process the multi-query part
         if spec_sequence_masks is not None:
             core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
                 q=query_spec,
@@ -637,7 +593,7 @@ def _forward_core(
         else:
             core_attn_out_spec, last_recurrent_state = None, None
 
-        # 2.2: Process the remaining part
+        # 3.2: process the remaining part
         if attn_metadata.num_prefills > 0:
             initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
             initial_state[~has_initial_state, ...] = 0
@@ -680,20 +636,30 @@ def _forward_core(
         else:
             core_attn_out_non_spec, last_recurrent_state = None, None
 
-        # 3. Merge core attention output
+        # Merge core attention output
         if spec_sequence_masks is not None and core_attn_out_non_spec is not None:
-            merged_out = torch.empty(
+            core_attn_out = torch.empty(
                 (1, num_actual_tokens, *core_attn_out_spec.shape[2:]),
                 dtype=core_attn_out_non_spec.dtype,
                 device=core_attn_out_non_spec.device,
             )
-            merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
-            merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
-            core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
+            core_attn_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
+            core_attn_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
+
         elif spec_sequence_masks is not None:
-            core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
+            core_attn_out = core_attn_out_spec
         else:
-            core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
+            core_attn_out = core_attn_out_non_spec
+
+        z_shape_og = z.shape
+        # reshape input data into 2D tensor
+        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
+        z = z.reshape(-1, z.shape[-1])
+        core_attn_out = self.norm(core_attn_out, z)
+        core_attn_out = core_attn_out.reshape(z_shape_og)
+        core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
+
+        output[:num_actual_tokens], _ = self.out_proj(core_attn_out)
 
 
 class Qwen3NextAttention(nn.Module):
@@ -1304,44 +1270,29 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         return self.model.get_expert_mapping()
 
 
-def gdn_attention_core(
-    mixed_qkv: torch.Tensor,
-    b: torch.Tensor,
-    a: torch.Tensor,
-    core_attn_out: torch.Tensor,
+def gdn_attention(
+    hidden_states: torch.Tensor,
+    output: torch.Tensor,
     layer_name: str,
 ) -> None:
-    """
-    Custom op for the core attention computation.
-    Only handles the convolution + recurrent attention part.
-    Input/output projections are handled outside this op.
-    """
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
-    self._forward_core(
-        mixed_qkv=mixed_qkv,
-        b=b,
-        a=a,
-        core_attn_out=core_attn_out,
-    )
+    self._forward(hidden_states=hidden_states, output=output)
 
 
-def gdn_attention_core_fake(
-    mixed_qkv: torch.Tensor,
-    b: torch.Tensor,
-    a: torch.Tensor,
-    core_attn_out: torch.Tensor,
+def gdn_attention_fake(
+    hidden_states: torch.Tensor,
+    output: torch.Tensor,
     layer_name: str,
 ) -> None:
-    """Fake implementation for torch.compile."""
     return
 
 
 direct_register_custom_op(
-    op_name="gdn_attention_core",
-    op_func=gdn_attention_core,
-    mutates_args=["core_attn_out"],
-    fake_impl=gdn_attention_core_fake,
+    op_name="gdn_attention",
+    op_func=gdn_attention,
+    mutates_args=["output"],
+    fake_impl=gdn_attention_fake,
 )
 
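
The registration above follows the usual mutating-custom-op pattern: the op writes into a preallocated output buffer and ships a no-op fake implementation so torch.compile can trace it without running the kernel. A plain-PyTorch analogue of that pattern, as a sketch (the "my_lib::toy_attention" op and its body are made up for illustration; this is not vLLM's direct_register_custom_op helper):

import torch

# Register an op that mutates its "output" argument instead of returning a tensor.
@torch.library.custom_op("my_lib::toy_attention", mutates_args=("output",))
def toy_attention(hidden_states: torch.Tensor, output: torch.Tensor) -> None:
    # Stand-in for the real projection + convolution + recurrent attention body.
    output.copy_(hidden_states * 2.0)

@toy_attention.register_fake
def _(hidden_states: torch.Tensor, output: torch.Tensor) -> None:
    # The fake impl only describes metadata effects; no computation happens here.
    return

x = torch.randn(3, 4)
out = torch.empty_like(x)
torch.ops.my_lib.toy_attention(x, out)  # eager call through the dispatcher; `out` now holds the result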