116 changes: 68 additions & 48 deletions vllm/model_executor/layers/kda.py
@@ -40,18 +40,36 @@


def kda_attention(
hidden_states: torch.Tensor,
output: torch.Tensor,
q_proj_states: torch.Tensor,
k_proj_states: torch.Tensor,
v_proj_states: torch.Tensor,
g1: torch.Tensor,
g2: torch.Tensor,
beta: torch.Tensor,
core_attn_out: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self._forward(hidden_states=hidden_states, output=output)
self._forward(
q_proj_states=q_proj_states,
k_proj_states=k_proj_states,
v_proj_states=v_proj_states,
g1=g1,
g2=g2,
beta=beta,
core_attn_out=core_attn_out,
)


def kda_attention_fake(
hidden_states: torch.Tensor,
output: torch.Tensor,
q_proj_states: torch.Tensor,
k_proj_states: torch.Tensor,
v_proj_states: torch.Tensor,
g1: torch.Tensor,
g2: torch.Tensor,
beta: torch.Tensor,
core_attn_out: torch.Tensor,
layer_name: str,
) -> None:
return
@@ -60,7 +78,7 @@ def kda_attention_fake(
direct_register_custom_op(
op_name="kda_attention",
op_func=kda_attention,
mutates_args=["output"],
mutates_args=["core_attn_out"],
fake_impl=kda_attention_fake,
)
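
Note on the registration change: the custom op now declares core_attn_out, the caller-allocated per-head attention buffer, as its mutated argument instead of output, so the op writes its result in place and the projections happen outside of it. Below is a minimal, self-contained sketch of the same register-and-mutate pattern using the public torch.library.custom_op API rather than vLLM's direct_register_custom_op helper; the op and argument names are illustrative only.

import torch
from torch.library import custom_op


# The op fills a caller-provided buffer instead of returning a tensor,
# so that buffer must be listed in mutates_args.
@custom_op("demo::scaled_copy", mutates_args=["out"])
def scaled_copy(x: torch.Tensor, scale: float, out: torch.Tensor) -> None:
    out.copy_(x * scale)


@scaled_copy.register_fake
def _(x: torch.Tensor, scale: float, out: torch.Tensor) -> None:
    # Fake impl for tracing/compilation: no work, mirroring kda_attention_fake above.
    return


x = torch.randn(4, 8)
buf = torch.empty_like(x)                 # caller-allocated, like core_attn_out
torch.ops.demo.scaled_copy(x, 2.0, buf)   # buf now holds 2 * x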

@@ -241,37 +259,56 @@ def forward(
hidden_states: torch.Tensor,
positions: torch.Tensor,
output: torch.Tensor,
) -> None:
return torch.ops.vllm.kda_attention(
hidden_states,
output,
) -> torch.Tensor:
num_tokens = hidden_states.size(0)
q = self.q_proj(hidden_states)[0]
k = self.k_proj(hidden_states)[0]
v = self.v_proj(hidden_states)[0]

beta = self.b_proj(hidden_states)[0].float().sigmoid()
g1 = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
g1 = fused_kda_gate(g1, self.A_log, self.head_dim, g_bias=self.dt_bias)
beta = beta.unsqueeze(0)
g1 = g1.unsqueeze(0)

g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
g2 = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)

core_attn_out = torch.zeros(
(1, num_tokens, self.local_num_heads, self.head_dim),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
torch.ops.vllm.kda_attention(
q,
k,
v,
g1,
g2,
beta,
core_attn_out,
self.prefix,
)
core_attn_out = self.o_norm(core_attn_out, g2)
core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")

return self.o_proj(core_attn_out)[0]
Comment on lines 259 to +295

P0: Do not return new tensor without mutating provided output

The new KimiDeltaAttention.forward now builds and returns a fresh tensor but never writes to the output argument. Callers such as KimiDecoderLayer.forward still allocate an attn_output tensor and invoke self.self_attn(..., output=attn_output) without capturing a return value, relying on the method to mutate output in place. With this change attn_output remains uninitialised and the subsequent hidden_states = attn_output propagates garbage, breaking the layer’s output entirely. Either continue to fill output in place or update callers to use the returned tensor.
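
The sketch below illustrates the caller contract described in this comment, using illustrative names rather than the exact KimiDecoderLayer code: the caller pre-allocates the buffer, drops the return value, and depends on forward() filling the buffer in place.

import torch

def decoder_layer_forward(self_attn, hidden_states: torch.Tensor) -> torch.Tensor:
    attn_output = torch.empty_like(hidden_states)
    self_attn(hidden_states, output=attn_output)   # return value is dropped
    return attn_output   # garbage if forward() only returns a new tensor

# Two consistent ways to resolve the mismatch:
#   1) keep the in-place contract inside forward():  output[:] = self.o_proj(core_attn_out)[0]
#   2) or update callers to use the returned tensor:  attn_output = self_attn(...)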


def _forward(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
q_proj_states: torch.Tensor,
k_proj_states: torch.Tensor,
v_proj_states: torch.Tensor,
g1: torch.Tensor,
g2: torch.Tensor,
beta: torch.Tensor,
core_attn_out: torch.Tensor,
) -> None:
forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata

if attn_metadata is None:
# V1 profile run
# Mimic the memory allocation in the real run
q = torch.empty_like(hidden_states)
k = torch.empty_like(hidden_states)
v = torch.empty_like(hidden_states)
g = hidden_states.new_empty(
hidden_states.size(0),
self.local_num_heads,
self.head_dim,
dtype=torch.float32,
)
beta = torch.empty(
hidden_states.size(0), self.local_num_heads, dtype=torch.float32
)
core_attn_out = torch.empty_like(hidden_states)
# V1 profile run
return

assert isinstance(attn_metadata, dict)
@@ -288,10 +325,6 @@ def _forward(
conv_state_k = conv_state_k.transpose(-1, -2)
conv_state_v = conv_state_v.transpose(-1, -2)

q_proj_states = self.q_proj(hidden_states)[0]
k_proj_states = self.k_proj(hidden_states)[0]
v_proj_states = self.v_proj(hidden_states)[0]

q_conv_weights = self.q_conv1d.weight.view(
self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)
)
@@ -374,14 +407,6 @@ def _forward(
lambda x: rearrange(x, "n (h d) -> 1 n h d", d=self.head_dim), (q, k, v)
)

beta = self.b_proj(hidden_states)[0].float().sigmoid()

g = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
g = fused_kda_gate(g, self.A_log, self.head_dim, g_bias=self.dt_bias)

beta = beta.unsqueeze(0)
g = g.unsqueeze(0)

if attn_metadata.num_prefills > 0:
zero_idx = non_spec_state_indices_tensor[~has_initial_state]
recurrent_state[zero_idx] = 0
@@ -393,7 +418,7 @@ def _forward(
q=q,
k=k,
v=v,
g=g,
g=g1,
beta=beta,
initial_state=initial_state,
output_final_state=True,
@@ -410,17 +435,12 @@ def _forward(
q=q,
k=k,
v=v,
g=g,
g=g1,
beta=beta,
initial_state=recurrent_state,
use_qk_l2norm_in_kernel=True,
cu_seqlens=non_spec_query_start_loc,
ssm_state_indices=non_spec_state_indices_tensor,
)

g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
g = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
core_attn_out = self.o_norm(core_attn_out_non_spec, g)
core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")

output[:] = self.o_proj(core_attn_out)[0]
assert core_attn_out_non_spec.shape == core_attn_out.shape
core_attn_out[:] = core_attn_out_non_spec
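
Why the last hunk uses slice assignment: core_attn_out is the buffer handed in by forward(), so the kernel result has to be copied into it in place for the caller to see it; rebinding the local name would not. A tiny illustration (variable names are illustrative):

import torch

buf = torch.zeros(2, 3)                                    # caller-provided buffer
src = torch.arange(6, dtype=torch.float32).reshape(2, 3)   # freshly computed result

buf[:] = src    # in-place copy: whoever holds `buf` sees the new values
# buf = src     # would only rebind the local name; the caller would still see zeros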