Skip to content

Commit f95fc5a

Browse files
author
闫鹏全
committed
support aclgraph for vllm v1 engine
1 parent c7f6584 commit f95fc5a

File tree

4 files changed

+165
-58
lines changed

4 files changed

+165
-58
lines changed

vllm_ascend/attention/attention_v1.py

Lines changed: 95 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@
2323
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
2424
AttentionLayer, AttentionType)
2525
from vllm.attention.backends.utils import CommonAttentionState
26-
26+
from vllm.utils import direct_register_custom_op
27+
from vllm.forward_context import ForwardContext, get_forward_context
2728

2829
class AscendAttentionBackend(AttentionBackend):
30+
accept_output_buffer: bool = True
2931

3032
@staticmethod
3133
def get_name() -> str:
@@ -150,6 +152,7 @@ def forward(
150152
kv_cache: torch.Tensor,
151153
attn_metadata: AscendMetadata,
152154
output: Optional[torch.Tensor] = None,
155+
trace_flag: bool = True,
153156
) -> torch.Tensor:
154157
"""Forward pass with Ascend attention.
155158
Args:
@@ -167,59 +170,100 @@ def forward(
167170
shape = [batch_size * seq_len, num_heads, head_size]
168171
"""
169172
num_tokens = query.shape[0]
170-
output = torch.empty(num_tokens,
173+
if output is None:
174+
output = torch.empty(num_tokens,
171175
self.num_heads,
172176
self.head_size,
173177
dtype=query.dtype,
174178
device=query.device)
175-
176-
if attn_metadata is None:
177-
# Profiling run.
178-
return output.view(num_tokens, self.hidden_size)
179-
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
180-
attn_type = self.attn_type
181-
if attn_type != AttentionType.DECODER:
182-
raise NotImplementedError("Encoder self-attention and "
183-
"encoder/decoder cross-attention "
184-
"are not implemented for "
185-
"PallasAttentionBackendImpl")
186-
# View q k v to BSH.
187-
query = query.view(-1, self.num_heads, self.head_size)
188-
key = key.view(-1, self.num_kv_heads, self.head_size)
189-
value = value.view(-1, self.num_kv_heads, self.head_size)
190-
# TODO: Remove this contiguous in the future.
191-
value = value.contiguous()
192-
193-
if hasattr(layer, 'quant_method'):
194-
# TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
195-
pass
196-
else:
197-
if kv_cache.numel() > 0:
198-
key_cache, value_cache = kv_cache[0], kv_cache[1]
199-
num_blocks, block_size, _ = key_cache.shape
200-
key_cache = key_cache.view(num_blocks, block_size,
201-
self.num_kv_heads, self.head_size)
202-
value_cache = value_cache.view(num_blocks, block_size,
203-
self.num_kv_heads,
204-
self.head_size)
205-
slots = attn_metadata.slot_mapping
206-
torch_npu._npu_reshape_and_cache(key=key,
207-
value=value,
208-
key_cache=key_cache,
209-
value_cache=value_cache,
210-
slot_indices=slots)
211-
212-
# use paged attention
213-
torch_npu._npu_paged_attention_splitfuse(
179+
if trace_flag:
180+
torch.ops.vllm.unified_ascend_attention_with_output(
214181
query=query,
215-
key_cache=key_cache,
216-
value_cache=value_cache,
217-
mask=attn_metadata.attn_mask,
218-
block_table=attn_metadata.block_tables,
219-
seq_len=attn_metadata.seq_lens,
220-
context_lens=attn_metadata.context_lens,
221-
num_kv_heads=self.num_kv_heads,
222-
num_heads=self.num_heads,
223-
scale_value=self.scale,
224-
out=output)
182+
key=key,
183+
value=value,
184+
output=output,
185+
layer_name=layer.layer_name
186+
)
187+
else:
188+
num_tokens = query.shape[0]
189+
if attn_metadata is None:
190+
return output.view(num_tokens, self.hidden_size)
191+
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
192+
attn_type = self.attn_type
193+
if attn_type != AttentionType.DECODER:
194+
raise NotImplementedError("Encoder self-attention and "
195+
"encoder/decoder cross-attention "
196+
"are not implemented for "
197+
"PallasAttentionBackendImpl")
198+
# View q k v to BSH.
199+
query = query.view(-1, self.num_heads, self.head_size)
200+
key = key.view(-1, self.num_kv_heads, self.head_size)
201+
value = value.view(-1, self.num_kv_heads, self.head_size)
202+
# TODO: Remove this contiguous in the future.
203+
value = value.contiguous()
204+
205+
if hasattr(layer, 'quant_method'):
206+
# TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
207+
pass
208+
else:
209+
if kv_cache.numel() > 0:
210+
key_cache, value_cache = kv_cache[0], kv_cache[1]
211+
num_blocks, block_size, _ = key_cache.shape
212+
key_cache = key_cache.view(num_blocks, block_size,
213+
self.num_kv_heads, self.head_size)
214+
value_cache = value_cache.view(num_blocks, block_size,
215+
self.num_kv_heads,
216+
self.head_size)
217+
slots = attn_metadata.slot_mapping
218+
torch_npu._npu_reshape_and_cache(key=key,
219+
value=value,
220+
key_cache=key_cache,
221+
value_cache=value_cache,
222+
slot_indices=slots)
223+
# use paged attention
224+
torch_npu._npu_paged_attention_splitfuse(
225+
query=query,
226+
key_cache=key_cache,
227+
value_cache=value_cache,
228+
mask=attn_metadata.attn_mask,
229+
block_table=attn_metadata.block_tables,
230+
seq_len=attn_metadata.seq_lens,
231+
context_lens=attn_metadata.context_lens,
232+
num_kv_heads=self.num_kv_heads,
233+
num_heads=self.num_heads,
234+
scale_value=self.scale,
235+
out=output)
225236
return output.view(num_tokens, self.hidden_size)
237+
238+
239+
def unified_ascend_attention_with_output(
240+
query: torch.Tensor,
241+
key: torch.Tensor,
242+
value: torch.Tensor,
243+
output: torch.Tensor,
244+
layer_name: str,
245+
) -> None:
246+
forward_context: ForwardContext = get_forward_context()
247+
attn_metadata = forward_context.attn_metadata
248+
self = forward_context.no_compile_layers[layer_name]
249+
kv_cache = self.kv_cache[forward_context.virtual_engine]
250+
self.impl.forward(self, query, key, value, kv_cache, attn_metadata, output, trace_flag=False)
251+
return
252+
253+
def unified_attention_with_output_fake(
254+
query: torch.Tensor,
255+
key: torch.Tensor,
256+
value: torch.Tensor,
257+
output: torch.Tensor,
258+
layer_name: str,
259+
) -> None:
260+
return
261+
262+
263+
direct_register_custom_op(
264+
op_name="unified_ascend_attention_with_output",
265+
op_func=unified_ascend_attention_with_output,
266+
mutates_args=["output"],
267+
fake_impl=unified_attention_with_output_fake,
268+
dispatch_key="PrivateUse1",
269+
)

vllm_ascend/platform.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
"Warning: Failed to register custom ops, all custom ops will be disabled"
3737
)
3838

39+
from vllm.config import CompilationLevel, VllmConfig, ModelConfig
40+
from vllm.logger import init_logger
3941
from vllm.platforms import Platform, PlatformEnum
4042

4143
if TYPE_CHECKING:
@@ -104,9 +106,18 @@ def mem_get_info(cls) -> Tuple[int, int]:
104106
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
105107
from vllm.config import CompilationLevel # noqa: E402
106108
compilation_config = vllm_config.compilation_config
107-
if compilation_config and compilation_config.level != CompilationLevel.NO_COMPILATION:
109+
import os
110+
aclgraph_enabled = os.getenv('ENABLE_ACLGRAPH')
111+
112+
if aclgraph_enabled == '1' and compilation_config and compilation_config.level == CompilationLevel.PIECEWISE:
113+
logger.warning(
114+
"Compilation level %s is supported on NPU now, But use_inductor is no support",
115+
compilation_config.level)
116+
compilation_config.use_inductor = False
117+
compilation_config.splitting_ops = ["vllm.unified_ascend_attention_with_output"]
118+
elif compilation_config and compilation_config.level != CompilationLevel.NO_COMPILATION:
108119
logger.warning(
109-
"Compilation level %s is not supported on NPU now, forcing compilation level to NO_COMPILATION",
120+
"ENABLE_ACLGRAPH is not set, Compilation level %s is not supported on NPU now, forcing compilation level to NO_COMPILATION",
110121
compilation_config.level)
111122
compilation_config.level = CompilationLevel.NO_COMPILATION
112123

vllm_ascend/worker/model_runner_v1.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import gc
2121
import os
22+
import time
2223
from typing import TYPE_CHECKING, Dict, List, Optional, Union
2324

2425
import numpy as np
@@ -27,8 +28,8 @@
2728
import torch.nn as nn
2829
from vllm.attention import AttentionType
2930
from vllm.attention.layer import Attention
30-
from vllm.config import VllmConfig
31-
from vllm.distributed.parallel_state import get_pp_group
31+
from vllm.config import VllmConfig, CompilationLevel
32+
from vllm.distributed.parallel_state import get_pp_group, graph_capture
3233
from vllm.forward_context import set_forward_context
3334
from vllm.inputs import INPUT_REGISTRY
3435
from vllm.logger import logger
@@ -42,6 +43,13 @@
4243
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
4344
KVCacheSpec)
4445
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
46+
from vllm.triton_utils import HAS_TRITON
47+
if HAS_TRITON:
48+
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler
49+
else:
50+
INVALID_TOKEN_ID = None
51+
RejectionSampler = None
52+
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
4553
from vllm.v1.utils import bind_kv_cache
4654
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
4755

@@ -171,6 +179,12 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
171179
self.input_positions_cpu = torch.arange(0,
172180
self.max_num_tokens,
173181
device="cpu")
182+
self.use_cuda_graph = (self.vllm_config.compilation_config.level
183+
== CompilationLevel.PIECEWISE
184+
and not self.model_config.enforce_eager)
185+
self.cudagraph_batch_sizes = list(
186+
reversed(
187+
self.vllm_config.compilation_config.cudagraph_capture_sizes))
174188

175189
# NOTE: Pre-construct a mask matrix to improve the efficiency of
176190
# attention mask construction during inference.
@@ -627,7 +641,7 @@ def _dummy_run(self) -> torch.Tensor:
627641
if self.uses_mrope:
628642
positions = self.mrope_positions[:, :self.max_num_tokens]
629643
else:
630-
positions = self.input_positions_cpu[:self.max_num_tokens]
644+
positions = self.positions[:self.max_num_tokens]
631645

632646
if get_pp_group().is_first_rank:
633647
intermediate_tensors = None
@@ -645,7 +659,7 @@ def _dummy_run(self) -> torch.Tensor:
645659

646660
with set_forward_context(None, self.vllm_config):
647661
hidden_states = model(input_ids=input_ids,
648-
positions=positions.to(self.device),
662+
positions=positions,
649663
intermediate_tensors=intermediate_tensors,
650664
inputs_embeds=inputs_embeds)
651665
return hidden_states
@@ -779,3 +793,32 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
779793
f"Unknown attention type: {attn_module.attn_type}")
780794

781795
return kv_cache_spec
796+
797+
798+
def capture_model(self) -> None:
799+
if not self.use_cuda_graph:
800+
logger.warning(
801+
"Skipping NPU graph capture. Please add "
802+
"-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE)
803+
return
804+
805+
start_time = time.perf_counter()
806+
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
807+
808+
# Trigger CUDA graph capture for specific shapes.
809+
# Capture the large shapes first so that the smaller shapes
810+
# can reuse the memory pool allocated for the large shapes.
811+
with graph_capture(device=self.device):
812+
for num_tokens in reversed(self.cudagraph_batch_sizes):
813+
for _ in range(self.vllm_config.compilation_config.
814+
cudagraph_num_of_warmups):
815+
self._dummy_run(num_tokens)
816+
self._dummy_run(num_tokens)
817+
818+
end_time = time.perf_counter()
819+
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
820+
elapsed_time = end_time - start_time
821+
cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
822+
# This usually takes 5~20 seconds.
823+
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
824+
elapsed_time, cuda_graph_size / (1 << 30))

vllm_ascend/worker/worker_v1.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,17 @@ def load_model(self) -> None:
158158
self.model_runner.load_model()
159159

160160
def compile_or_warm_up_model(self) -> None:
161+
warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
161162
if not self.model_config.enforce_eager:
162-
logger.warning("Graph capture is not supported on NPU.")
163+
warmup_sizes = [
164+
x for x in warmup_sizes if x not in
165+
self.vllm_config.compilation_config.cudagraph_capture_sizes
166+
]
167+
for size in sorted(warmup_sizes, reverse=True):
168+
logger.info("Compile and warming up model for size %d", size)
169+
self.model_runner._dummy_run(size)
170+
if not self.model_config.enforce_eager:
171+
self.model_runner.capture_model()
163172
# Reset the seed to ensure that the random state is not affected by
164173
# the model initialization and profiling.
165174
set_random_seed(self.model_config.seed)

0 commit comments

Comments
 (0)