Commit d4c7508

momo609 and wangxiaoxin-sherie authored
[Perf] Move attention update stream out of loop to optimize performance (#3848)
### What this PR does / why we need it?

In the `update_*attn_params` functions, the `torch.npu.stream(update_stream)` context manager was previously located inside the for-loop that updates the parameters for each layer. This resulted in redundant stream initiations for every layer, adding unnecessary overhead.

This commit refactors the code by moving the stream context manager to wrap the entire for-loop, so the update stream is initiated only once per function call rather than once per layer. The change saves about 90 us in each decode step.

Update stream entered in every layer:

<img width="1720" height="383" alt="image" src="https://github.com/user-attachments/assets/70e4cb69-5bc1-4180-a67d-c99132134be6" />

Update stream entered once, outside the loop:

<img width="1269" height="175" alt="image" src="https://github.com/user-attachments/assets/0e290edb-b0ce-48fe-b032-1b924ade6ae5" />

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.11.0
- vLLM main: vllm-project/vllm@83f478b

Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
1 parent d0cc9c1 commit d4c7508

File tree

1 file changed (+90, -88 lines)


vllm_ascend/compilation/acl_graph.py

Lines changed: 90 additions & 88 deletions
@@ -194,26 +194,25 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
-    for key, param, handle, event in zip(
-            forward_context.attn_metadata,
-            graph_params.attn_params[runtime_shape],
-            graph_params.handles[runtime_shape],
-            graph_params.events[runtime_shape],
-    ):
-        (
-            query,
-            key_cache,
-            value_cache,
-            num_kv_heads,
-            num_heads,
-            scale,
-            block_table,
-            seq_lens,
-            output,
-        ) = param
-        seq_lens = forward_context.attn_metadata[key].seq_lens
-
-        with torch.npu.stream(update_stream):
+    with torch.npu.stream(update_stream):
+        for key, param, handle, event in zip(
+                forward_context.attn_metadata,
+                graph_params.attn_params[runtime_shape],
+                graph_params.handles[runtime_shape],
+                graph_params.events[runtime_shape],
+        ):
+            (
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                num_heads,
+                scale,
+                block_table,
+                seq_lens,
+                output,
+            ) = param
+            seq_lens = forward_context.attn_metadata[key].seq_lens
             torch.npu.graph_task_update_begin(update_stream, handle)
             torch_npu._npu_paged_attention(
                 query=query,
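
The three helpers below apply the same hoisting. As a minimal, self-contained sketch of the pattern (assuming an Ascend build of PyTorch with `torch_npu` and an available NPU device; `layers` and `update_layer_params` are hypothetical stand-ins for the per-layer graph parameters and the `graph_task_update_begin`/`end` work, not the actual vllm-ascend code):

```python
# Minimal sketch of the refactor, not the actual vllm-ascend code.
import torch
import torch_npu  # noqa: F401  # registers the torch.npu device backend

update_stream = torch.npu.Stream()

def update_layer_params(layer):
    """Hypothetical stand-in for the per-layer graph_task_update work."""

def update_params_old(layers):
    # Before: the stream context is re-entered for every layer, so each
    # iteration pays the host-side cost of switching to update_stream.
    for layer in layers:
        with torch.npu.stream(update_stream):
            update_layer_params(layer)

def update_params_new(layers):
    # After: the stream context wraps the whole loop, so the switch to
    # update_stream happens once per call instead of once per layer.
    with torch.npu.stream(update_stream):
        for layer in layers:
            update_layer_params(layer)
```

Entering the stream context is a host-side operation that is cheap but not free; repeated across every decoder layer it accounts for the roughly 90 us per decode step reported above.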
@@ -236,30 +235,32 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
-    for key, param, handle, event in zip(
-            forward_context.attn_metadata,
-            graph_params.attn_params[runtime_shape],
-            graph_params.handles[runtime_shape],
-            graph_params.events[runtime_shape],
-    ):
-        (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
-         spec_attn_mask, sparse_mode, scale, block_table, block_size,
-         seq_lens_list, actual_seq_lengths, attn_output, softmax_lse) = param
-        seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
-        if speculative_config and speculative_config.method == "deepseek_mtp":
-            actual_seq_lengths = forward_context.attn_metadata[
-                key].decode.actual_seq_lengths_q
-            spec_multiple = speculative_config.num_speculative_tokens + 1
-            seq_lens_list = seq_lens_list + [0] * (
-                runtime_shape // spec_multiple - len(seq_lens_list))
-            actual_seq_lengths = [
-                spec_multiple * (i + 1)
-                for i in range(runtime_shape // spec_multiple)
-            ]
-        else:
-            seq_lens_list = seq_lens_list + [0] * (runtime_shape -
-                                                   len(seq_lens_list))
-        with torch.npu.stream(update_stream):
+    with torch.npu.stream(update_stream):
+        for key, param, handle, event in zip(
+                forward_context.attn_metadata,
+                graph_params.attn_params[runtime_shape],
+                graph_params.handles[runtime_shape],
+                graph_params.events[runtime_shape],
+        ):
+            (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
+             spec_attn_mask, sparse_mode, scale, block_table, block_size,
+             seq_lens_list, actual_seq_lengths, attn_output,
+             softmax_lse) = param
+            seq_lens_list = forward_context.attn_metadata[
+                key].decode.seq_lens_list
+            if speculative_config and speculative_config.method == "deepseek_mtp":
+                actual_seq_lengths = forward_context.attn_metadata[
+                    key].decode.actual_seq_lengths_q
+                spec_multiple = speculative_config.num_speculative_tokens + 1
+                seq_lens_list = seq_lens_list + [0] * (
+                    runtime_shape // spec_multiple - len(seq_lens_list))
+                actual_seq_lengths = [
+                    spec_multiple * (i + 1)
+                    for i in range(runtime_shape // spec_multiple)
+                ]
+            else:
+                seq_lens_list = seq_lens_list + [0] * (runtime_shape -
+                                                       len(seq_lens_list))
             torch.npu.graph_task_update_begin(update_stream, handle)
 
             torch_npu.npu_fused_infer_attention_score.out(
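
For readers unfamiliar with the `deepseek_mtp` branch above, the list arithmetic can be checked in isolation. A small sketch with made-up values for `runtime_shape`, `num_speculative_tokens`, and `seq_lens_list` (plain Python, not a call into the real attention metadata):

```python
# Sketch of the deepseek_mtp padding logic from update_mla_attn_params,
# using illustrative values rather than real metadata objects.
runtime_shape = 16          # token batch size the ACL graph was captured with
num_speculative_tokens = 1  # e.g. one MTP draft token per request
spec_multiple = num_speculative_tokens + 1

seq_lens_list = [37, 5, 120]  # decode KV lengths for 3 live requests

# Pad the per-request lengths up to the number of request slots in the
# captured graph (runtime_shape // spec_multiple of them).
seq_lens_list = seq_lens_list + [0] * (
    runtime_shape // spec_multiple - len(seq_lens_list))

# Each slot contributes spec_multiple query tokens, so the running totals
# are spec_multiple, 2 * spec_multiple, ... up to runtime_shape.
actual_seq_lengths = [
    spec_multiple * (i + 1) for i in range(runtime_shape // spec_multiple)
]

print(seq_lens_list)       # [37, 5, 120, 0, 0, 0, 0, 0]
print(actual_seq_lengths)  # [2, 4, 6, 8, 10, 12, 14, 16]
```

With MTP each request slot carries `spec_multiple` query tokens (one verified token plus the draft tokens), so a graph captured for `runtime_shape` tokens holds `runtime_shape // spec_multiple` request slots, and both lists are padded to that slot count.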
@@ -291,26 +292,27 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
    # for each layer's attention op in the graph.
-    for key, param, handle, event in zip(
-            forward_context.attn_metadata,
-            graph_params.attn_params[runtime_shape],
-            graph_params.handles[runtime_shape],
-            graph_params.events[runtime_shape],
-    ):
-        (q_nope, k_nope, value, num_heads, num_kv_heads, scale, block_table,
-         block_size, actual_seq_lengths_kv, attn_output, softmax_lse, cp_rank,
-         dcp_rank, dcp_size) = param
-        actual_seq_lengths_kv = forward_context.attn_metadata[
-            key].decode_meta.num_computed_tokens_of_pcp_dcp[:, cp_rank,
-                                                             dcp_rank]
-        pad_length = runtime_shape - len(actual_seq_lengths_kv)
-        pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype)
-        actual_seq_lengths_kv = np.concatenate(
-            [actual_seq_lengths_kv, pad_tensor])
-        if dcp_size > 1:
-            num_heads = num_heads * dcp_size
-
-        with torch.npu.stream(update_stream):
+    with torch.npu.stream(update_stream):
+        for key, param, handle, event in zip(
+                forward_context.attn_metadata,
+                graph_params.attn_params[runtime_shape],
+                graph_params.handles[runtime_shape],
+                graph_params.events[runtime_shape],
+        ):
+            (q_nope, k_nope, value, num_heads, num_kv_heads, scale,
+             block_table, block_size, actual_seq_lengths_kv, attn_output,
+             softmax_lse, cp_rank, dcp_rank, dcp_size) = param
+            actual_seq_lengths_kv = forward_context.attn_metadata[
+                key].decode_meta.num_computed_tokens_of_pcp_dcp[:, cp_rank,
+                                                                 dcp_rank]
+            pad_length = runtime_shape - len(actual_seq_lengths_kv)
+            pad_tensor = np.zeros(pad_length,
+                                  dtype=actual_seq_lengths_kv.dtype)
+            actual_seq_lengths_kv = np.concatenate(
+                [actual_seq_lengths_kv, pad_tensor])
+            if dcp_size > 1:
+                num_heads = num_heads * dcp_size
+
             torch.npu.graph_task_update_begin(update_stream, handle)
 
             torch_npu.npu_fused_infer_attention_score.out(
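
The context-parallel variant does the same padding on a NumPy array and additionally rescales the head count. A standalone sketch with made-up values; the `[num_requests, pcp, dcp]` layout of `num_computed_tokens_of_pcp_dcp` is an assumption based on the indexing in the diff:

```python
# Sketch of the per-rank padding in update_attn_dcp_pcp_params, with
# illustrative values; the metadata array layout is assumed here.
import numpy as np

runtime_shape = 8
cp_rank, dcp_rank, dcp_size = 0, 1, 2
num_heads = 16

# Computed-token counts for 3 live requests, per (cp_rank, dcp_rank) pair.
num_computed_tokens_of_pcp_dcp = np.arange(3 * 2 * 2,
                                           dtype=np.int32).reshape(3, 2, 2)

# Slice out this rank's lengths, then zero-pad to the captured graph size.
actual_seq_lengths_kv = num_computed_tokens_of_pcp_dcp[:, cp_rank, dcp_rank]
pad_length = runtime_shape - len(actual_seq_lengths_kv)
pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype)
actual_seq_lengths_kv = np.concatenate([actual_seq_lengths_kv, pad_tensor])

# As in the diff: when decode context parallelism is enabled, the head
# count passed to the kernel is multiplied by the DCP world size.
if dcp_size > 1:
    num_heads = num_heads * dcp_size

print(actual_seq_lengths_kv)  # [1 5 9 0 0 0 0 0]
print(num_heads)              # 32
```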
@@ -340,30 +342,30 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
-    for key, param, handle, event in zip(
-            forward_context.attn_metadata,
-            graph_params.attn_params[runtime_shape],
-            graph_params.handles[runtime_shape],
-            graph_params.events[runtime_shape],
-    ):
-        (q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads, scale,
-         num_kv_heads, attn_output, softmax_lse) = param
-
-        decode_meta = forward_context.attn_metadata[key].decode
-        seq_len = decode_meta.cp_seq_len
-
-        if speculative_config and speculative_config.method == "deepseek_mtp":
-            spec_multiple = speculative_config.num_speculative_tokens + 1
-            seq_len = seq_len + [0] * (runtime_shape // spec_multiple -
-                                       len(seq_len))
-        else:
-            pad_length = runtime_shape - len(seq_len)
-            pad_tensor = torch.zeros(pad_length,
-                                     dtype=seq_len.dtype,
-                                     device=seq_len.device)
-            seq_len = torch.cat([seq_len, pad_tensor], dim=0)
-
-        with torch.npu.stream(update_stream):
+    with torch.npu.stream(update_stream):
+        for key, param, handle, event in zip(
+                forward_context.attn_metadata,
+                graph_params.attn_params[runtime_shape],
+                graph_params.handles[runtime_shape],
+                graph_params.events[runtime_shape],
+        ):
+            (q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads,
+             scale, num_kv_heads, attn_output, softmax_lse) = param
+
+            decode_meta = forward_context.attn_metadata[key].decode
+            seq_len = decode_meta.cp_seq_len
+
+            if speculative_config and speculative_config.method == "deepseek_mtp":
+                spec_multiple = speculative_config.num_speculative_tokens + 1
+                seq_len = seq_len + [0] * (runtime_shape // spec_multiple -
+                                           len(seq_len))
+            else:
+                pad_length = runtime_shape - len(seq_len)
+                pad_tensor = torch.zeros(pad_length,
+                                         dtype=seq_len.dtype,
+                                         device=seq_len.device)
+                seq_len = torch.cat([seq_len, pad_tensor], dim=0)
+
             torch.npu.graph_task_update_begin(update_stream, handle)
 
             torch_npu.atb.npu_multi_head_latent_attention(
