Commit 3857eb8

[Perf] Decouple torch op from GDA to leverage torch.compile (vllm-project#27871)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
1 parent 933cdea commit 3857eb8
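
The gist of the change, as a minimal standalone sketch (not vLLM code): keep the q/k/v projections and gating in ordinary nn.Module code, where torch.compile can trace and fuse them, and route only the non-traceable recurrent core through a registered custom op that fills a preallocated output tensor in place. The names demo::core_attn and DecoupledLayer are made up for illustration, and the sketch registers the op with torch.library.custom_op (PyTorch >= 2.4) rather than the direct_register_custom_op helper used in the diff below.

import torch
import torch.nn as nn


@torch.library.custom_op("demo::core_attn", mutates_args=("core_attn_out",))
def core_attn(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, core_attn_out: torch.Tensor
) -> None:
    # Stand-in for the recurrent KDA kernel; the body of a custom op is opaque
    # to torch.compile, so only this part stays untraced.
    core_attn_out.copy_(torch.softmax(q @ k.transpose(-1, -2), dim=-1) @ v)


@core_attn.register_fake
def _(q, k, v, core_attn_out) -> None:
    # Fake impl used during tracing/profiling: no computation, mirroring the
    # role of kda_attention_fake in the diff below.
    return


class DecoupledLayer(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.o_proj = nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Projections stay outside the custom op, so torch.compile can see
        # and fuse them with the rest of the model graph.
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
        out = torch.empty_like(hidden_states)
        torch.ops.demo.core_attn(q, k, v, out)  # opaque core, mutates `out` in place
        return self.o_proj(out)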

File tree

  • vllm/model_executor/layers/kda.py

1 file changed: +68, -48 lines

vllm/model_executor/layers/kda.py

Lines changed: 68 additions & 48 deletions
@@ -40,18 +40,36 @@


 def kda_attention(
-    hidden_states: torch.Tensor,
-    output: torch.Tensor,
+    q_proj_states: torch.Tensor,
+    k_proj_states: torch.Tensor,
+    v_proj_states: torch.Tensor,
+    g1: torch.Tensor,
+    g2: torch.Tensor,
+    beta: torch.Tensor,
+    core_attn_out: torch.Tensor,
     layer_name: str,
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
-    self._forward(hidden_states=hidden_states, output=output)
+    self._forward(
+        q_proj_states=q_proj_states,
+        k_proj_states=k_proj_states,
+        v_proj_states=v_proj_states,
+        g1=g1,
+        g2=g2,
+        beta=beta,
+        core_attn_out=core_attn_out,
+    )


 def kda_attention_fake(
-    hidden_states: torch.Tensor,
-    output: torch.Tensor,
+    q_proj_states: torch.Tensor,
+    k_proj_states: torch.Tensor,
+    v_proj_states: torch.Tensor,
+    g1: torch.Tensor,
+    g2: torch.Tensor,
+    beta: torch.Tensor,
+    core_attn_out: torch.Tensor,
     layer_name: str,
 ) -> None:
     return
@@ -60,7 +78,7 @@ def kda_attention_fake(
 direct_register_custom_op(
     op_name="kda_attention",
     op_func=kda_attention,
-    mutates_args=["output"],
+    mutates_args=["core_attn_out"],
     fake_impl=kda_attention_fake,
 )

@@ -241,37 +259,56 @@ def forward(
         hidden_states: torch.Tensor,
         positions: torch.Tensor,
         output: torch.Tensor,
-    ) -> None:
-        return torch.ops.vllm.kda_attention(
-            hidden_states,
-            output,
+    ) -> torch.Tensor:
+        num_tokens = hidden_states.size(0)
+        q = self.q_proj(hidden_states)[0]
+        k = self.k_proj(hidden_states)[0]
+        v = self.v_proj(hidden_states)[0]
+
+        beta = self.b_proj(hidden_states)[0].float().sigmoid()
+        g1 = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
+        g1 = fused_kda_gate(g1, self.A_log, self.head_dim, g_bias=self.dt_bias)
+        beta = beta.unsqueeze(0)
+        g1 = g1.unsqueeze(0)
+
+        g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
+        g2 = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
+
+        core_attn_out = torch.zeros(
+            (1, num_tokens, self.local_num_heads, self.head_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+        torch.ops.vllm.kda_attention(
+            q,
+            k,
+            v,
+            g1,
+            g2,
+            beta,
+            core_attn_out,
             self.prefix,
         )
+        core_attn_out = self.o_norm(core_attn_out, g2)
+        core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
+
+        return self.o_proj(core_attn_out)[0]

     def _forward(
         self,
-        hidden_states: torch.Tensor,
-        output: torch.Tensor,
+        q_proj_states: torch.Tensor,
+        k_proj_states: torch.Tensor,
+        v_proj_states: torch.Tensor,
+        g1: torch.Tensor,
+        g2: torch.Tensor,
+        beta: torch.Tensor,
+        core_attn_out: torch.Tensor,
     ) -> None:
         forward_context = get_forward_context()
         attn_metadata: AttentionMetadata = forward_context.attn_metadata

         if attn_metadata is None:
-            # V1 profile run
-            # Mimic the memory allocation in the real run
-            q = torch.empty_like(hidden_states)
-            k = torch.empty_like(hidden_states)
-            v = torch.empty_like(hidden_states)
-            g = hidden_states.new_empty(
-                hidden_states.size(0),
-                self.local_num_heads,
-                self.head_dim,
-                dtype=torch.float32,
-            )
-            beta = torch.empty(
-                hidden_states.size(0), self.local_num_heads, dtype=torch.float32
-            )
-            core_attn_out = torch.empty_like(hidden_states)
+            # # V1 profile run
             return

         assert isinstance(attn_metadata, dict)
@@ -288,10 +325,6 @@ def _forward(
         conv_state_k = conv_state_k.transpose(-1, -2)
         conv_state_v = conv_state_v.transpose(-1, -2)

-        q_proj_states = self.q_proj(hidden_states)[0]
-        k_proj_states = self.k_proj(hidden_states)[0]
-        v_proj_states = self.v_proj(hidden_states)[0]
-
         q_conv_weights = self.q_conv1d.weight.view(
             self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)
         )
@@ -374,14 +407,6 @@ def _forward(
             lambda x: rearrange(x, "n (h d) -> 1 n h d", d=self.head_dim), (q, k, v)
         )

-        beta = self.b_proj(hidden_states)[0].float().sigmoid()
-
-        g = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
-        g = fused_kda_gate(g, self.A_log, self.head_dim, g_bias=self.dt_bias)
-
-        beta = beta.unsqueeze(0)
-        g = g.unsqueeze(0)
-
         if attn_metadata.num_prefills > 0:
             zero_idx = non_spec_state_indices_tensor[~has_initial_state]
             recurrent_state[zero_idx] = 0
@@ -393,7 +418,7 @@ def _forward(
                 q=q,
                 k=k,
                 v=v,
-                g=g,
+                g=g1,
                 beta=beta,
                 initial_state=initial_state,
                 output_final_state=True,
@@ -410,17 +435,12 @@ def _forward(
                 q=q,
                 k=k,
                 v=v,
-                g=g,
+                g=g1,
                 beta=beta,
                 initial_state=recurrent_state,
                 use_qk_l2norm_in_kernel=True,
                 cu_seqlens=non_spec_query_start_loc,
                 ssm_state_indices=non_spec_state_indices_tensor,
             )
-
-        g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
-        g = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
-        core_attn_out = self.o_norm(core_attn_out_non_spec, g)
-        core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
-
-        output[:] = self.o_proj(core_attn_out)[0]
+        assert core_attn_out_non_spec.shape == core_attn_out.shape
+        core_attn_out[:] = core_attn_out_non_spec
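
After this refactor, only torch.ops.vllm.kda_attention stays opaque to torch.compile; the projections, fused_kda_gate, o_norm, and o_proj now run in forward() as ordinary traceable ops, and the custom op just fills the preallocated core_attn_out buffer. The shape bookkeeping in the new forward() can be checked in isolation; a small sketch with made-up sizes (only the rearrange patterns are taken from the diff):

import torch
from einops import rearrange

num_tokens, num_heads, head_dim = 5, 4, 8
hidden = torch.randn(num_tokens, num_heads * head_dim)

# g2 path: split the flat gate projection into per-head gates, as in the diff.
g2 = rearrange(hidden, "... (h d) -> ... h d", d=head_dim)
assert g2.shape == (num_tokens, num_heads, head_dim)

# core_attn_out buffer allocated by forward() and mutated by the custom op.
core_attn_out = torch.zeros(1, num_tokens, num_heads, head_dim)

# After o_norm, forward() flattens the heads back before o_proj.
flat = rearrange(core_attn_out, "1 n h d -> n (h d)")
assert flat.shape == (num_tokens, num_heads * head_dim)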
