
Commit 7df6263

Author: 闫鹏全 (committed)
Commit message: support aclgraph
Parent: c42e21a

File tree: 5 files changed (+183, -56 lines)


vllm_ascend/attention/attention_v1.py

Lines changed: 95 additions & 51 deletions

@@ -23,9 +23,11 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
-
+from vllm.utils import direct_register_custom_op
+from vllm.forward_context import ForwardContext, get_forward_context
 
 class AscendAttentionBackend(AttentionBackend):
+    accept_output_buffer: bool = True
 
     @staticmethod
     def get_name() -> str:

@@ -150,6 +152,7 @@ def forward(
         kv_cache: torch.Tensor,
         attn_metadata: AscendMetadata,
         output: Optional[torch.Tensor] = None,
+        trace_flag: bool = True,
     ) -> torch.Tensor:
         """Forward pass with Ascend attention.
         Args:

@@ -167,59 +170,100 @@ def forward(
             shape = [batch_size * seq_len, num_heads, head_size]
         """
         num_tokens = query.shape[0]
-        output = torch.empty(num_tokens,
+        if output is None:
+            output = torch.empty(num_tokens,
                              self.num_heads,
                              self.head_size,
                              dtype=query.dtype,
                              device=query.device)
-
-        if attn_metadata is None:
-            # Profiling run.
-            return output.view(num_tokens, self.hidden_size)
-        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
-        attn_type = self.attn_type
-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
-                                      "PallasAttentionBackendImpl")
-        # View q k v to BSH.
-        query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_kv_heads, self.head_size)
-        value = value.view(-1, self.num_kv_heads, self.head_size)
-        # TODO: Remove this contiguous in the future.
-        value = value.contiguous()
-
-        if hasattr(layer, 'quant_method'):
-            # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
-            pass
-        else:
-            if kv_cache.numel() > 0:
-                key_cache, value_cache = kv_cache[0], kv_cache[1]
-                num_blocks, block_size, _ = key_cache.shape
-                key_cache = key_cache.view(num_blocks, block_size,
-                                           self.num_kv_heads, self.head_size)
-                value_cache = value_cache.view(num_blocks, block_size,
-                                               self.num_kv_heads,
-                                               self.head_size)
-                slots = attn_metadata.slot_mapping
-                torch_npu._npu_reshape_and_cache(key=key,
-                                                 value=value,
-                                                 key_cache=key_cache,
-                                                 value_cache=value_cache,
-                                                 slot_indices=slots)
-
-            # use paged attention
-            torch_npu._npu_paged_attention_splitfuse(
+        if trace_flag:
+            torch.ops.vllm.unified_ascend_attention_with_output(
                 query=query,
-                key_cache=key_cache,
-                value_cache=value_cache,
-                mask=attn_metadata.attn_mask,
-                block_table=attn_metadata.block_tables,
-                seq_len=attn_metadata.seq_lens,
-                context_lens=attn_metadata.context_lens,
-                num_kv_heads=self.num_kv_heads,
-                num_heads=self.num_heads,
-                scale_value=self.scale,
-                out=output)
+                key=key,
+                value=value,
+                output=output,
+                layer_name=self.name
+            )
+        else:
+            num_tokens = query.shape[0]
+            if attn_metadata is None:
+                return output.view(num_tokens, self.hidden_size)
+            assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
+            attn_type = self.attn_type
+            if attn_type != AttentionType.DECODER:
+                raise NotImplementedError("Encoder self-attention and "
+                                          "encoder/decoder cross-attention "
+                                          "are not implemented for "
+                                          "PallasAttentionBackendImpl")
+            # View q k v to BSH.
+            query = query.view(-1, self.num_heads, self.head_size)
+            key = key.view(-1, self.num_kv_heads, self.head_size)
+            value = value.view(-1, self.num_kv_heads, self.head_size)
+            # TODO: Remove this contiguous in the future.
+            value = value.contiguous()
+
+            if hasattr(layer, 'quant_method'):
+                # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
+                pass
+            else:
+                if kv_cache.numel() > 0:
+                    key_cache, value_cache = kv_cache[0], kv_cache[1]
+                    num_blocks, block_size, _ = key_cache.shape
+                    key_cache = key_cache.view(num_blocks, block_size,
+                                               self.num_kv_heads, self.head_size)
+                    value_cache = value_cache.view(num_blocks, block_size,
+                                                   self.num_kv_heads,
+                                                   self.head_size)
+                    slots = attn_metadata.slot_mapping
+                    torch_npu._npu_reshape_and_cache(key=key,
+                                                     value=value,
+                                                     key_cache=key_cache,
+                                                     value_cache=value_cache,
+                                                     slot_indices=slots)
+                # use paged attention
+                torch_npu._npu_paged_attention_splitfuse(
+                    query=query,
+                    key_cache=key_cache,
+                    value_cache=value_cache,
+                    mask=attn_metadata.attn_mask,
+                    block_table=attn_metadata.block_tables,
+                    seq_len=attn_metadata.seq_lens,
+                    context_lens=attn_metadata.context_lens,
+                    num_kv_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale_value=self.scale,
+                    out=output)
         return output.view(num_tokens, self.hidden_size)
+
+
+def unified_ascend_attention_with_output(
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        output: torch.Tensor,
+        layer_name: str,
+) -> None:
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    self = forward_context.no_compile_layers[layer_name]
+    kv_cache = self.kv_cache[forward_context.virtual_engine]
+    return self.impl.forward(self, query, key, value, kv_cache, attn_metadata,
+                             trace_flag=False)
+
+
+def unified_attention_with_output_fake(
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        output: torch.Tensor,
+        layer_name: str,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="unified_ascend_attention_with_output",
+    op_func=unified_ascend_attention_with_output,
+    mutates_args=["output"],
+    fake_impl=unified_attention_with_output_fake,
+    dispatch_key="PrivateUse1",
+)
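The key idea in this file is that the attention call is now routed through a registered torch custom op with a fake implementation, so the piecewise compiler can trace shapes through it without running the NPU kernel and can split the captured graph at every attention layer. Below is a minimal, self-contained sketch of the same pattern using plain torch.library (requires PyTorch 2.4+ for torch.library.custom_op); the op name demo::scale_into and all tensors are made up for illustration, and this is the generic mechanism that helpers like direct_register_custom_op wrap, not the vLLM code itself.

import torch

# Toy out-variant op: writes x * scale into a caller-provided buffer,
# mirroring the "mutates `output`, returns nothing" shape of
# unified_ascend_attention_with_output above.
@torch.library.custom_op("demo::scale_into", mutates_args=("out",))
def scale_into(x: torch.Tensor, scale: float, out: torch.Tensor) -> None:
    out.copy_(x * scale)

# Fake (meta) implementation: does no real work; it only lets the compiler
# trace through the op without executing the real kernel.
@scale_into.register_fake
def _(x: torch.Tensor, scale: float, out: torch.Tensor) -> None:
    return None

def f(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    # Inside torch.compile this appears as a single opaque node, which is
    # what makes it usable as a graph splitting point.
    torch.ops.demo.scale_into(x, 2.0, out)
    return out

compiled = torch.compile(f)
print(compiled(torch.ones(4)))  # tensor([2., 2., 2., 2.])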

vllm_ascend/envs.py

Lines changed: 25 additions & 0 deletions

@@ -0,0 +1,25 @@
+import os
+from typing import Any, Callable, Dict
+
+env_variables: Dict[str, Callable[[], Any]] = {
+    # max compile thread num
+    "MAX_JOBS": lambda: os.getenv("MAX_JOBS", None),
+    "CMAKE_BUILD_TYPE": lambda: os.getenv("CMAKE_BUILD_TYPE"),
+    "COMPILE_CUSTOM_KERNELS":
+    lambda: os.getenv("COMPILE_CUSTOM_KERNELS", None),
+    # If set, vllm-ascend will print verbose logs during compilation
+    "VERBOSE": lambda: bool(int(os.getenv('VERBOSE', '0'))),
+    "ASCEND_HOME_PATH": lambda: os.getenv("ASCEND_HOME_PATH", None),
+    "LD_LIBRARY_PATH": lambda: os.getenv("LD_LIBRARY_PATH", None),
+}
+
+
+def __getattr__(name: str):
+    # lazy evaluation of environment variables
+    if name in env_variables:
+        return env_variables[name]()
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
+def __dir__():
+    return list(env_variables.keys())
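Because the new module defines a module-level __getattr__ (PEP 562), each value is read from the process environment at attribute-access time rather than at import time. A small usage sketch, assuming vllm-ascend is installed and importable as vllm_ascend:

import os

import vllm_ascend.envs as envs

os.environ["VERBOSE"] = "1"
print(envs.VERBOSE)   # True  -- evaluated now, not at import time

os.environ["VERBOSE"] = "0"
print(envs.VERBOSE)   # False -- re-evaluated on every access

print(dir(envs))      # the supported variable names, via __dir__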

vllm_ascend/platform.py

Lines changed: 8 additions & 2 deletions

@@ -21,7 +21,7 @@
 import torch
 import torch_npu  # noqa: F401
 import vllm.envs as envs
-from vllm.config import CompilationLevel
+from vllm.config import CompilationLevel, VllmConfig, ModelConfig
 from vllm.logger import init_logger
 from vllm.platforms import Platform, PlatformEnum
 

@@ -92,7 +92,13 @@ def mem_get_info(cls) -> Tuple[int, int]:
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         compilation_config = vllm_config.compilation_config
-        if compilation_config and compilation_config.level != CompilationLevel.NO_COMPILATION:
+        if compilation_config and compilation_config.level == CompilationLevel.PIECEWISE:
+            logger.warning(
+                "Compilation level %s is supported on NPU now, but use_inductor is not supported",
+                compilation_config.level)
+            compilation_config.use_inductor = False
+            compilation_config.splitting_ops = ["vllm.unified_ascend_attention_with_output"]
+        elif compilation_config and compilation_config.level != CompilationLevel.NO_COMPILATION:
             logger.warning(
                 "Compilation level %s is not supported on NPU now, forcing compilation level to NO_COMPILATION",
                 compilation_config.level)
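The effect of the new branch is that a PIECEWISE compilation config is accepted but adjusted: Inductor is turned off and the custom attention op is declared as the splitting op, so the traced model is cut at every attention call and the pieces around attention can be captured and replayed as graphs. A toy sketch of that adjustment with stand-in config classes (illustrative only, not vLLM's real VllmConfig; the enum values are assumptions, only the names come from the diff):

from dataclasses import dataclass, field
from enum import Enum
from typing import List

# Stand-ins for vllm.config types, just to show which branch each level takes.
class CompilationLevel(Enum):
    NO_COMPILATION = 0
    PIECEWISE = 3  # value is an assumption; only the name matters here

@dataclass
class CompilationConfig:
    level: CompilationLevel = CompilationLevel.PIECEWISE
    use_inductor: bool = True
    splitting_ops: List[str] = field(default_factory=list)

def check_and_update_config(cfg: CompilationConfig) -> None:
    if cfg.level == CompilationLevel.PIECEWISE:
        # Piecewise compilation is allowed on NPU, but without Inductor;
        # the registered attention op becomes the graph split point.
        cfg.use_inductor = False
        cfg.splitting_ops = ["vllm.unified_ascend_attention_with_output"]
    elif cfg.level != CompilationLevel.NO_COMPILATION:
        cfg.level = CompilationLevel.NO_COMPILATION  # fall back, as before

cfg = CompilationConfig()
check_and_update_config(cfg)
print(cfg.use_inductor, cfg.splitting_ops)
# False ['vllm.unified_ascend_attention_with_output']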

vllm_ascend/worker/model_runner_v1.py

Lines changed: 45 additions & 2 deletions

@@ -18,6 +18,7 @@
 #
 
 import gc
+import time
 from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 import numpy as np

@@ -26,8 +27,8 @@
 import torch.nn as nn
 from vllm.attention import AttentionType
 from vllm.attention.layer import Attention
-from vllm.config import VllmConfig
-from vllm.distributed.parallel_state import get_pp_group
+from vllm.config import VllmConfig, CompilationLevel
+from vllm.distributed.parallel_state import get_pp_group, graph_capture
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger

@@ -43,6 +44,13 @@
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
+from vllm.triton_utils import HAS_TRITON
+if HAS_TRITON:
+    from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler
+else:
+    INVALID_TOKEN_ID = None
+    RejectionSampler = None
+from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

@@ -209,6 +217,12 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.input_positions_cpu = torch.arange(0,
                                                 self.max_num_tokens,
                                                 device="cpu")
+        self.use_cuda_graph = (self.vllm_config.compilation_config.level
+                               == CompilationLevel.PIECEWISE
+                               and not self.model_config.enforce_eager)
+        self.cudagraph_batch_sizes = list(
+            reversed(
+                self.vllm_config.compilation_config.cudagraph_capture_sizes))
 
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler

@@ -790,3 +804,32 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                 f"Unknown attention type: {attn_module.attn_type}")
 
         return kv_cache_spec
+
+
+    def capture_model(self) -> None:
+        if not self.use_cuda_graph:
+            logger.warning(
+                "Skipping NPU graph capture. Please add "
+                "-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE)
+            return
+
+        start_time = time.perf_counter()
+        start_free_gpu_memory = torch.cuda.mem_get_info()[0]
+
+        # Trigger CUDA graph capture for specific shapes.
+        # Capture the large shapes first so that the smaller shapes
+        # can reuse the memory pool allocated for the large shapes.
+        with graph_capture(device=self.device):
+            for num_tokens in reversed(self.cudagraph_batch_sizes):
+                for _ in range(self.vllm_config.compilation_config.
+                               cudagraph_num_of_warmups):
+                    self._dummy_run(num_tokens)
+                self._dummy_run(num_tokens)
+
+        end_time = time.perf_counter()
+        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
+        elapsed_time = end_time - start_time
+        cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
+        # This usually takes 5~20 seconds.
+        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
+                    elapsed_time, cuda_graph_size / (1 << 30))
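capture_model brackets the capture loop with wall-clock timing and a free-memory snapshot, and reports the drop in free device memory as the size of the captured graphs; the diff calls the torch.cuda APIs directly, which are assumed to map to NPU equivalents in this environment. A generic sketch of that bookkeeping pattern with a hypothetical helper (capture_fn stands in for the graph_capture() block with its _dummy_run calls; not the runner's actual API):

import time
from typing import Callable

import torch

def timed_capture(capture_fn: Callable[[], None]) -> None:
    # Snapshot time and free device memory before capture.
    start_time = time.perf_counter()
    start_free = torch.cuda.mem_get_info()[0]

    capture_fn()

    # Whatever free memory disappeared is attributed to the captured graphs.
    elapsed = time.perf_counter() - start_time
    graph_mem = start_free - torch.cuda.mem_get_info()[0]
    print(f"Graph capturing finished in {elapsed:.0f} secs, "
          f"took {graph_mem / (1 << 30):.2f} GiB")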

vllm_ascend/worker/worker_v1.py

Lines changed: 10 additions & 1 deletion

@@ -201,8 +201,17 @@ def load_model(self) -> None:
         self.model_runner.load_model()
 
     def compile_or_warm_up_model(self) -> None:
+        warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
         if not self.model_config.enforce_eager:
-            logger.warning("Graph capture is not supported on NPU.")
+            warmup_sizes = [
+                x for x in warmup_sizes if x not in
+                self.vllm_config.compilation_config.cudagraph_capture_sizes
+            ]
+        for size in sorted(warmup_sizes, reverse=True):
+            logger.info("Compiling and warming up model for size %d", size)
+            self.model_runner._dummy_run(size)
+        if not self.model_config.enforce_eager:
+            self.model_runner.capture_model()
         # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)
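The warm-up logic above first compiles and warms up only the compile_sizes that will not also be graph-captured, then hands the captured sizes over to capture_model. A toy sketch of the size filtering with assumed example values:

# Illustrative only; the size lists are made-up examples.
compile_sizes = [1, 2, 4, 8, 16]          # compilation_config.compile_sizes
cudagraph_capture_sizes = [1, 2, 4, 8]    # compilation_config.cudagraph_capture_sizes
enforce_eager = False

warmup_sizes = list(compile_sizes)
if not enforce_eager:
    # Sizes that will be graph-captured are warmed up by capture_model()
    # itself, so they are skipped here.
    warmup_sizes = [x for x in warmup_sizes
                    if x not in cudagraph_capture_sizes]

for size in sorted(warmup_sizes, reverse=True):
    print(f"Compiling and warming up model for size {size}")  # -> size 16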
