
Commit 3253457

[AclGraph] Adapt aclgraph into new graph dispatcher arch

Signed-off-by: MengqingCao <cmq0113@163.com>

1 parent 8fd5399 commit 3253457

File tree: 5 files changed, +471 -336 lines changed


vllm_ascend/ascend_forward_context.py
Lines changed: 24 additions & 18 deletions

@@ -4,10 +4,11 @@
 from typing import Any, Optional

 import torch
-from vllm.config import VllmConfig
+from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.distributed import (get_dp_group, get_ep_group,
                               get_tensor_model_parallel_world_size)
-from vllm.forward_context import get_forward_context, set_forward_context
+from vllm.forward_context import (BatchDescriptor, get_forward_context,
+                                  set_forward_context)

 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.distributed.moe_comm_method import MoECommMethod
@@ -48,26 +49,31 @@ def _get_fused_moe_state(ep_size: int, with_prefill: bool,

 @contextmanager
 def set_ascend_forward_context(
-    attn_metadata: Any,
-    vllm_config: VllmConfig,
-    virtual_engine: int = 0,
-    num_tokens: Optional[int] = None,
-    num_tokens_across_dp: Optional[torch.Tensor] = None,
-    with_prefill: bool = True,
-    in_profile_run: bool = False,
-    reserved_mc2_mask: Optional[torch.Tensor] = None,
-    moe_comm_method: Optional[MoECommMethod] = None,
-    num_actual_tokens: Optional[int] = None,
-):
+        attn_metadata: Any,
+        vllm_config: VllmConfig,
+        virtual_engine: int = 0,
+        num_tokens: Optional[int] = None,
+        num_tokens_across_dp: Optional[torch.Tensor] = None,
+        with_prefill: bool = True,
+        in_profile_run: bool = False,
+        reserved_mc2_mask: Optional[torch.Tensor] = None,
+        moe_comm_method: Optional[MoECommMethod] = None,
+        num_actual_tokens: Optional[int] = None,
+        aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+        batch_descriptor: Optional[BatchDescriptor] = None):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     We add some additional param into forward_context.
     """
-    with set_forward_context(attn_metadata,
-                             vllm_config,
-                             virtual_engine=virtual_engine,
-                             num_tokens=num_tokens,
-                             num_tokens_across_dp=num_tokens_across_dp):
+    with set_forward_context(
+            attn_metadata,
+            vllm_config,
+            virtual_engine=virtual_engine,
+            num_tokens=num_tokens,
+            num_tokens_across_dp=num_tokens_across_dp,
+            cudagraph_runtime_mode=aclgraph_runtime_mode,
+            batch_descriptor=batch_descriptor,
+    ):
         forward_context = get_forward_context()
         forward_context.moe_comm_method = moe_comm_method
         forward_context.with_prefill = with_prefill
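
For orientation, here is a minimal sketch of how a model runner could thread the two new arguments into this context manager. The runner object, its attributes (runner.vllm_config, runner.model, runner.input_ids), and the hard-coded FULL mode are illustrative assumptions, not part of this commit; in practice the mode and batch descriptor would come from the graph dispatcher.

# Illustrative sketch only: `runner` and the dispatch decision are assumed.
from vllm.config import CUDAGraphMode
from vllm.forward_context import BatchDescriptor

from vllm_ascend.ascend_forward_context import set_ascend_forward_context


def run_decode_step(runner, attn_metadata, num_tokens: int):
    batch_descriptor = BatchDescriptor(num_tokens=num_tokens,
                                       uniform_decode=True)
    with set_ascend_forward_context(
            attn_metadata,
            runner.vllm_config,
            num_tokens=num_tokens,
            with_prefill=False,
            aclgraph_runtime_mode=CUDAGraphMode.FULL,  # assumed dispatcher result
            batch_descriptor=batch_descriptor,
    ):
        # Inside the context, forward_context.cudagraph_runtime_mode and
        # forward_context.batch_descriptor are what the aclgraph wrapper reads.
        return runner.model(runner.input_ids[:num_tokens])
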
Lines changed: 186 additions & 0 deletions

@@ -0,0 +1,186 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import dataclasses
+from contextlib import ExitStack
+from typing import Any, Callable, Optional
+from unittest.mock import patch
+
+import torch
+import vllm.envs as envs
+from vllm.compilation.counter import compilation_counter
+from vllm.compilation.cuda_graph import CUDAGraphOptions
+from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
+from vllm.config import CUDAGraphMode, VllmConfig
+from vllm.forward_context import BatchDescriptor, get_forward_context
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+from vllm.utils import weak_ref_tensors
+
+logger = init_logger(__name__)
+
+
+@dataclasses.dataclass
+class ACLGraphEntry:
+    batch_descriptor: BatchDescriptor
+    aclgraph: Optional[torch.npu.NPUGraph] = None
+    output: Optional[Any] = None
+
+    # for aclgraph debugging, track the input addresses
+    # during capture, and check if they are the same during replay
+    input_addresses: Optional[list[int]] = None
+
+
+class ACLGraphWrapper:
+    """Wraps a runnable to add ACL graph capturing and replaying ability, and
+    provides attribute access to the underlying `runnable` via `__getattr__`.
+
+    The workflow of this wrapper in the aclgraph dispatching is as follows:
+    1. At initialization, a runtime mode is assigned to the wrapper (FULL or
+    PIECEWISE).
+    2. At runtime, the wrapper receives a runtime_mode and a
+    batch_descriptor (key) from the forward context and blindly trusts them
+    for aclgraph dispatching.
+    3. If runtime_mode is NONE or runtime_mode does not match the mode of the
+    wrapper, just call the runnable directly.
+    4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper,
+    the wrapper will perform aclgraph capture (if the key does not exist,
+    create a new entry and cache it) or replay (if the key exists in the
+    cache).
+
+    Note: ACLGraphWrapper does not store persistent buffers or copy any
+    runtime inputs into those buffers for replay. We assume implementing them
+    is done outside of the wrapper. That is because we do not make any
+    assumption on the dynamic shape (batch size) of the runtime inputs, as a
+    trade-off for staying orthogonal to compilation logic. Nevertheless,
+    tracing and checking the input addresses to be consistent during replay is
+    guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
+    """
+
+    def __init__(self,
+                 runnable: Callable,
+                 vllm_config: VllmConfig,
+                 runtime_mode: CUDAGraphMode,
+                 graph_pool: Any = None,
+                 cudagraph_options: Optional[CUDAGraphOptions] = None):
+        self.runnable = runnable
+        self.vllm_config = vllm_config
+        self.graph_pool = graph_pool
+        self.runtime_mode = runtime_mode
+        self.compilation_config = vllm_config.compilation_config
+
+        self.first_run_finished = False
+        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
+
+        # assert runtime_mode is not NONE (no aclgraph); otherwise, we don't
+        # need to initialize an ACLGraphWrapper.
+        assert self.runtime_mode != CUDAGraphMode.NONE
+        if self.graph_pool is None:
+            self.graph_pool = current_platform.get_global_graph_pool()
+
+        if cudagraph_options is None:
+            cudagraph_options = CUDAGraphOptions()
+        self.aclgraph_options = cudagraph_options
+        # the entries for different batch descriptors that we need to capture
+        # aclgraphs for.
+        self.concrete_aclgraph_entries: dict[BatchDescriptor, ACLGraphEntry]\
+            = {}
+
+    def __getattr__(self, key: str):
+        # allow accessing the attributes of the runnable.
+        if hasattr(self.runnable, key):
+            return getattr(self.runnable, key)
+        raise AttributeError(f"Attribute {key} does not exist in the runnable"
+                             f" of aclgraph wrapper: {self.runnable}")
+
+    def unwrap(self) -> Callable:
+        # in case we need to access the original runnable.
+        return self.runnable
+
+    def __call__(self, *args, **kwargs):
+        forward_context = get_forward_context()
+        batch_descriptor = forward_context.batch_descriptor
+        aclgraph_runtime_mode = forward_context.cudagraph_runtime_mode
+
+        if aclgraph_runtime_mode == CUDAGraphMode.NONE or \
+            aclgraph_runtime_mode != self.runtime_mode:
+            # CUDAGraphMode.NONE could mean the profile run, a warmup run, or
+            # running without aclgraphs.
+            # We do not trigger capture/replay if the runtime mode does not
+            # match. This enables properly dispatching to the correct
+            # ACLGraphWrapper when nesting multiple instances with different
+            # runtime modes.
+            return self.runnable(*args, **kwargs)
+
+        if batch_descriptor not in self.concrete_aclgraph_entries:
+            # create a new entry for this batch descriptor
+            self.concrete_aclgraph_entries[batch_descriptor] = \
+                ACLGraphEntry(batch_descriptor=batch_descriptor)
+
+        entry = self.concrete_aclgraph_entries[batch_descriptor]
+
+        if entry.aclgraph is None:
+            if self.aclgraph_options.debug_log_enable:
+                # Since we capture aclgraphs for many different shapes and
+                # capturing is fast, we don't need to log it for every
+                # shape. E.g. we only log it for the first subgraph in
+                # piecewise mode.
+                logger.debug("Capturing an aclgraph on (%s,%s)",
+                             self.runtime_mode.name, entry.batch_descriptor)
+            # validate that aclgraph capturing is legal at this point.
+            validate_cudagraph_capturing_enabled()
+
+            input_addresses = [
+                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
+            ]
+            entry.input_addresses = input_addresses
+            aclgraph = torch.npu.NPUGraph()
+
+            with ExitStack() as stack:
+                if self.aclgraph_options.gc_disable:
+                    # during every model forward for piecewise aclgraph
+                    # mode, we will capture many pieces of aclgraphs
+                    # (roughly one per layer). running gc again and again
+                    # across layers will make the aclgraph capture very slow.
+                    # therefore, we only run gc for the first graph,
+                    # and disable gc for the rest of the graphs.
+                    stack.enter_context(patch("gc.collect", lambda: None))
+                    stack.enter_context(
+                        patch("torch.npu.empty_cache", lambda: None))
+
+                # mind-exploding: carefully manage the reference and memory.
+                with torch.npu.graph(aclgraph, pool=self.graph_pool):
+                    # `output` is managed by pytorch's aclgraph pool
+                    output = self.runnable(*args, **kwargs)
+                    if self.aclgraph_options.weak_ref_output:
+                        # by converting it to weak ref,
+                        # the original `output` will immediately be released
+                        # to save memory. It is only safe to do this for
+                        # the last graph in piecewise aclgraph mode, because
+                        # the output of the last graph will not be used by
+                        # any other acl graph.
+                        output = weak_ref_tensors(output)
+
+            # here we always use weak ref for the output
+            # to save memory
+            entry.output = weak_ref_tensors(output)
+            entry.aclgraph = aclgraph
+
+            compilation_counter.num_cudagraph_captured += 1
+
+            # important: we need to return the output, rather than
+            # the weak ref of the output, so that pytorch can correctly
+            # manage the memory during acl graph capture
+            return output
+
+        if self.is_debugging_mode:
+            # check if the input addresses are the same
+            new_input_addresses = [
+                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
+            ]
+            assert new_input_addresses == entry.input_addresses, (
+                f"Input addresses for aclgraphs are different "
+                f"during replay. Expected {entry.input_addresses}, "
+                f"got {new_input_addresses}")
+
+        entry.aclgraph.replay()
+        return entry.output
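
To make the capture/replay contract concrete, below is a hedged usage sketch for the ACLGraphWrapper class added above. The names model_forward, vllm_config, static_input, and batches are placeholders for this sketch, not API from the commit, and the caller is assumed to reuse a fixed input buffer so that replays see the same tensor addresses, as the class docstring requires.

# Illustrative sketch only; parameter names are placeholders.
from vllm.config import CUDAGraphMode
from vllm.forward_context import BatchDescriptor

from vllm_ascend.ascend_forward_context import set_ascend_forward_context


def run_with_aclgraph(model_forward, vllm_config, static_input, batches):
    # Wrap the runnable once; CUDAGraphMode.NONE is rejected by the wrapper.
    wrapped = ACLGraphWrapper(model_forward,
                              vllm_config,
                              runtime_mode=CUDAGraphMode.FULL)
    descriptor = BatchDescriptor(num_tokens=static_input.shape[0],
                                 uniform_decode=True)

    outputs = []
    for batch in batches:
        # The wrapper keeps no persistent buffers, so the caller refills the
        # same tensor (same data_ptr) before every call.
        static_input.copy_(batch)
        with set_ascend_forward_context(
                attn_metadata=None,
                vllm_config=vllm_config,
                num_tokens=static_input.shape[0],
                aclgraph_runtime_mode=CUDAGraphMode.FULL,
                batch_descriptor=descriptor,
        ):
            # The first call with this descriptor captures an NPUGraph;
            # subsequent calls replay it and return the cached output.
            outputs.append(wrapped(static_input))
    return outputs
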
