
Conversation

@ZJY0516 (Contributor) commented Oct 31, 2025

Purpose

Perf Test

vllm serve moonshotai/Kimi-Linear-48B-A3B-Instruct -tp 4 --trust-remote-code

Prefill

--num-prompts 32 \
--random-input-len 2048 \
--random-output-len 1

Total Token throughput (tok/s):

  • This PR: 51507.38
  • main: 47985.19

Decode

--num-prompts 32 \
--random-input-len 100 \
--random-output-len 1024

Total Token throughput (tok/s):

  • This PR: 4921.65
  • main: 3161.05
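For reference, the flag fragments above correspond to vLLM's random-dataset serving benchmark. A rough sketch of the prefill run against the server command from the Purpose section (the vllm bench serve entry point and --dataset-name random are assumed here, not quoted from the PR):

vllm bench serve \
  --model moonshotai/Kimi-Linear-48B-A3B-Instruct \
  --dataset-name random \
  --num-prompts 32 \
  --random-input-len 2048 \
  --random-output-len 1

The decode run swaps in --random-input-len 100 and --random-output-len 1024.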


Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
@ZJY0516 marked this pull request as ready for review on October 31, 2025 08:54

@chatgpt-codex-connector (bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 259 to +295 (diff excerpt: the first few lines show the old signature and body that wrote into output, followed by the new implementation that returns a tensor)

        hidden_states: torch.Tensor,
        positions: torch.Tensor,
        output: torch.Tensor,
    ) -> None:
        return torch.ops.vllm.kda_attention(
            hidden_states,
            output,
    ) -> torch.Tensor:
        num_tokens = hidden_states.size(0)
        q = self.q_proj(hidden_states)[0]
        k = self.k_proj(hidden_states)[0]
        v = self.v_proj(hidden_states)[0]

        beta = self.b_proj(hidden_states)[0].float().sigmoid()
        g1 = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
        g1 = fused_kda_gate(g1, self.A_log, self.head_dim, g_bias=self.dt_bias)
        beta = beta.unsqueeze(0)
        g1 = g1.unsqueeze(0)

        g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
        g2 = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)

        core_attn_out = torch.zeros(
            (1, num_tokens, self.local_num_heads, self.head_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        torch.ops.vllm.kda_attention(
            q,
            k,
            v,
            g1,
            g2,
            beta,
            core_attn_out,
            self.prefix,
        )
        core_attn_out = self.o_norm(core_attn_out, g2)
        core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")

        return self.o_proj(core_attn_out)[0]


P0: Do not return a new tensor without mutating the provided output

The new KimiDeltaAttention.forward now builds and returns a fresh tensor but never writes to the output argument. Callers such as KimiDecoderLayer.forward still allocate an attn_output tensor and invoke self.self_attn(..., output=attn_output) without capturing a return value, relying on the method to mutate output in place. With this change attn_output remains uninitialised and the subsequent hidden_states = attn_output propagates garbage, breaking the layer’s output entirely. Either continue to fill output in place or update callers to use the returned tensor.
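To make the contract mismatch concrete, here is a minimal, self-contained sketch of the two calling conventions being contrasted; the function names and the toy math are illustrative only, not the actual vLLM code:

import torch

def attn_in_place(hidden_states: torch.Tensor, output: torch.Tensor) -> None:
    # Old-style contract: the callee writes its result into the caller's buffer.
    output.copy_(hidden_states * 2.0)  # stand-in for the real attention math

def attn_returning(hidden_states: torch.Tensor) -> torch.Tensor:
    # New-style contract: the callee returns a fresh tensor that the caller must capture.
    return hidden_states * 2.0  # stand-in for the real attention math

x = torch.randn(4, 8)
buf = torch.empty_like(x)
attn_in_place(x, buf)    # buf now holds the result
y = attn_returning(x)    # the return value must be assigned and used
# Mixing the two (calling attn_returning but still reading buf afterwards)
# leaves buf uninitialized, which is the failure mode described above.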


@youkaichao (Member) left a comment


looks good to me. cc @zhiyuan1i for verification.

@zhiyuan1i (Contributor) commented

LGTM!
Have you checked any benchmarks, such as gsm8k?

@ZJY0516 (Contributor, Author) commented Oct 31, 2025

> LGTM!
> Have you checked any benchmarks, such as gsm8k?

I'll check it later.

@zhiyuan1i (Contributor) commented

> LGTM!
> Have you checked any benchmarks, such as gsm8k?
>
> I'll check it later.

It seems there is no change to the real computations, so we can probably skip the benchmark.
Thank you for your work.
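For reference, if the gsm8k check discussed above were run, a typical lm-evaluation-harness invocation for this model would look roughly like the following (a sketch; the flags and model_args are assumptions, not taken from this thread):

lm_eval --model vllm \
  --model_args pretrained=moonshotai/Kimi-Linear-48B-A3B-Instruct,tensor_parallel_size=4,trust_remote_code=True \
  --tasks gsm8k \
  --num_fewshot 5 \
  --batch_size auto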

@youkaichao enabled auto-merge (squash) on October 31, 2025 11:31
@github-actions (bot) added the ready label (ONLY add when PR is ready to merge / full CI is needed) on Oct 31, 2025
@youkaichao disabled auto-merge on October 31, 2025 13:35
@youkaichao merged commit 3857eb8 into vllm-project:main on Oct 31, 2025
47 of 50 checks passed
@ZJY0516 deleted the kimi-opt branch on November 3, 2025 03:27
