[Perf] Decouple torch op from GDA to leverage torch.compile #27871
Conversation
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
💡 Codex Review
Here are some automated review suggestions for this pull request.
```diff
         hidden_states: torch.Tensor,
         positions: torch.Tensor,
         output: torch.Tensor,
-    ) -> None:
-        return torch.ops.vllm.kda_attention(
-            hidden_states,
-            output,
+    ) -> torch.Tensor:
+        num_tokens = hidden_states.size(0)
+        q = self.q_proj(hidden_states)[0]
+        k = self.k_proj(hidden_states)[0]
+        v = self.v_proj(hidden_states)[0]
+
+        beta = self.b_proj(hidden_states)[0].float().sigmoid()
+        g1 = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
+        g1 = fused_kda_gate(g1, self.A_log, self.head_dim, g_bias=self.dt_bias)
+        beta = beta.unsqueeze(0)
+        g1 = g1.unsqueeze(0)
+
+        g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
+        g2 = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
+
+        core_attn_out = torch.zeros(
+            (1, num_tokens, self.local_num_heads, self.head_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+        torch.ops.vllm.kda_attention(
+            q,
+            k,
+            v,
+            g1,
+            g2,
+            beta,
+            core_attn_out,
+            self.prefix,
+        )
+        core_attn_out = self.o_norm(core_attn_out, g2)
+        core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
+
+        return self.o_proj(core_attn_out)[0]
```
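For context, the pattern this change relies on: keep the stateful attention kernel behind a registered custom op that mutates a pre-allocated output buffer, while the projections and gating stay as ordinary tensor ops that torch.compile can trace and optimize. Below is a minimal, self-contained sketch using the stock `torch.library.custom_op` API; the `demo::` namespace, the toy softmax kernel, and the shapes are invented for illustration and are not vLLM's actual registration code.

```python
import torch

# Register an opaque, mutating custom op: torch.compile treats this call as a
# black box and only needs the fake impl below to propagate shapes/dtypes.
@torch.library.custom_op("demo::kda_attention", mutates_args=("core_attn_out",))
def kda_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    core_attn_out: torch.Tensor,
) -> None:
    # Stand-in for the real recurrent/attention kernel.
    core_attn_out.copy_(torch.softmax(q @ k.transpose(-1, -2), dim=-1) @ v)

@kda_attention.register_fake
def _(q, k, v, core_attn_out) -> None:
    # Shape-only "fake" implementation so tracing never runs the kernel.
    return None

@torch.compile
def forward(x: torch.Tensor, w_q, w_k, w_v) -> torch.Tensor:
    # Plain tensor ops: the compiler can fuse and optimize these projections...
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    out = torch.empty_like(v)
    # ...while the registered op stays opaque and dispatches to the kernel.
    torch.ops.demo.kda_attention(q, k, v, out)
    return out

x = torch.randn(4, 8)
w = [torch.randn(8, 8) for _ in range(3)]
print(forward(x, *w).shape)  # torch.Size([4, 8])
```

The fake (meta) implementation is what lets torch.compile capture the surrounding module code in one graph without executing the kernel at trace time.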
Do not return a new tensor without mutating the provided output
The new KimiDeltaAttention.forward now builds and returns a fresh tensor but never writes to the output argument. Callers such as KimiDecoderLayer.forward still allocate an attn_output tensor and invoke self.self_attn(..., output=attn_output) without capturing a return value, relying on the method to mutate output in place. With this change attn_output remains uninitialised and the subsequent hidden_states = attn_output propagates garbage, breaking the layer’s output entirely. Either continue to fill output in place or update callers to use the returned tensor.
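To make the contract concrete, here is a toy, self-contained sketch of the two ways to resolve this; `TinyAttn` and its `proj` layer are invented stand-ins, not vLLM code.

```python
import torch

class TinyAttn(torch.nn.Module):
    """Toy stand-in for KimiDeltaAttention: computes a result, writes it into a
    caller-provided buffer (the old contract), and also returns it (the new one)."""

    def __init__(self, hidden: int) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(hidden, hidden)

    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
        result = self.proj(hidden_states)
        output.copy_(result)  # option 1: keep filling `output` in place for existing callers
        return result         # option 2: let callers use the return value instead

attn = TinyAttn(8)
hidden_states = torch.randn(4, 8)

# Old-style caller (like KimiDecoderLayer.forward today): pre-allocates a buffer
# and ignores the return value.
attn_output = torch.empty_like(hidden_states)
attn(hidden_states, output=attn_output)

# New-style caller: captures the returned tensor directly.
returned = attn(hidden_states, output=torch.empty_like(hidden_states))
assert torch.allclose(returned, attn_output)
```

Either direction works; the bug only arises when the method stops writing to `output` while callers still read the pre-allocated buffer.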
youkaichao left a comment:
looks good to me. cc @zhiyuan1i for verification.
LGTM!
I'll check it later.
Seems there is no change to the actual computations, so maybe we could skip the benchmark.
Purpose
Decouple the torch custom op from GDA so that the surrounding projections and gating run as ordinary tensor ops that torch.compile can optimize; only the core attention kernel remains behind torch.ops.vllm.kda_attention.
Perf Test
Prefill
Total Token throughput (tok/s):
Decode
Total Token throughput (tok/s):
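For reference, prefill- and decode-heavy throughput of this kind can be reproduced with the offline API along the following lines (a sketch only: the model name, prompt lengths, and batch size below are assumptions, since the exact benchmark setup is not stated here).

```python
import time

from vllm import LLM, SamplingParams

# Assumed checkpoint; swap in the model actually being benchmarked.
llm = LLM(model="moonshotai/Kimi-Linear-48B-A3B-Instruct")

def throughput(prompts: list[str], max_tokens: int) -> float:
    """Return total token throughput (prompt + generated tokens per second)."""
    params = SamplingParams(max_tokens=max_tokens, ignore_eos=True)
    start = time.perf_counter()
    outputs = llm.generate(prompts, params)
    elapsed = time.perf_counter() - start
    total = sum(len(o.prompt_token_ids) + len(o.outputs[0].token_ids) for o in outputs)
    return total / elapsed

# Prefill-heavy: long prompts, a single generated token.
print("prefill tok/s:", throughput(["word " * 4096] * 32, max_tokens=1))

# Decode-heavy: short prompts, long generations.
print("decode tok/s:", throughput(["word " * 16] * 32, max_tokens=1024))
```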
Essential Elements of an Effective PR Description Checklist
`supported_models.md` and `examples` for a new model.