|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | + |
| 4 | +import dataclasses |
| 5 | +from contextlib import ExitStack |
| 6 | +from typing import Any, Callable, Optional |
| 7 | +from unittest.mock import patch |
| 8 | + |
| 9 | +import torch |
| 10 | +import vllm.envs as envs |
| 11 | +from vllm.compilation.counter import compilation_counter |
| 12 | +from vllm.compilation.cuda_graph import CUDAGraphOptions |
| 13 | +from vllm.compilation.monitor import validate_cudagraph_capturing_enabled |
| 14 | +from vllm.config import CUDAGraphMode, VllmConfig |
| 15 | +from vllm.forward_context import BatchDescriptor, get_forward_context |
| 16 | +from vllm.logger import init_logger |
| 17 | +from vllm.platforms import current_platform |
| 18 | +from vllm.utils import weak_ref_tensors |
| 19 | + |
| 20 | +logger = init_logger(__name__) |
| 21 | + |
| 22 | + |
| 23 | +@dataclasses.dataclass |
| 24 | +class ACLGraphEntry: |
| 25 | + batch_descriptor: BatchDescriptor |
| 26 | + aclgraph: Optional[torch.npu.NPUGraph] = None |
| 27 | + output: Optional[Any] = None |
| 28 | + |
| 29 | + # for aclgraph debugging, track the input addresses |
| 30 | + # during capture, and check if they are the same during replay |
| 31 | + input_addresses: Optional[list[int]] = None |
| 32 | + |
| 33 | + |
| 34 | +class ACLGraphWrapper: |
| 35 | + """Wraps a runnable to add acl graph capturing and replaying ability. And |
| 36 | + provide attribute access to the underlying `runnable` via `__getattr__`. |
| 37 | +
|
| 38 | + The workflow of this wrapper in the aclgraph dispatching is as follows: |
| 39 | + 1. At initialization, a runtime mode is assigned to the wrapper (FULL or |
| 40 | + PIECEWISE). |
| 41 | + 2. At runtime, the wrapper receives a runtime_mode and a |
| 42 | + batch_descriptor(key) from the forward context and blindly trust them |
| 43 | + for aclgraph dispatching. |
| 44 | + 3. If runtime_mode is NONE or runtime_mode does not match the mode of the |
| 45 | + wrapper, just call the runnable directly. |
| 46 | + 4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper, |
| 47 | + the wrapper will perform aclgraph capture(if key does not exist, create |
| 48 | + a new entry and cache it) or replay (if key exists in the cache). |
| 49 | +
|
| 50 | + Note: ACLGraphWrapper does not store persistent buffers or copy any |
| 51 | + runtime inputs into that buffers for replay. We assume implementing them |
| 52 | + is done outside of the wrapper. That is because we do not make any |
| 53 | + assumption on the dynamic shape (batch size) of the runtime inputs, as a |
| 54 | + trade-off for staying orthogonal to compilation logic. Nevertheless, |
| 55 | + tracing and checking the input addresses to be consistent during replay is |
| 56 | + guaranteed when VLLM_LOGGING_LEVEL == "DEBUG". |
| 57 | + """ |
| 58 | + |
| 59 | + def __init__(self, |
| 60 | + runnable: Callable, |
| 61 | + vllm_config: VllmConfig, |
| 62 | + runtime_mode: CUDAGraphMode, |
| 63 | + graph_pool: Any = None, |
| 64 | + cudagraph_options: Optional[CUDAGraphOptions] = None): |
| 65 | + self.runnable = runnable |
| 66 | + self.vllm_config = vllm_config |
| 67 | + self.graph_pool = graph_pool |
| 68 | + self.runtime_mode = runtime_mode |
| 69 | + self.compilation_config = vllm_config.compilation_config |
| 70 | + |
| 71 | + self.first_run_finished = False |
| 72 | + self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" |
| 73 | + |
| 74 | + # assert runtime_mode is not NONE(no aclgraph), otherwise, we don't |
| 75 | + # need to initialize a ACLGraphWrapper. |
| 76 | + assert self.runtime_mode != CUDAGraphMode.NONE |
| 77 | + if self.graph_pool is None: |
| 78 | + self.graph_pool = current_platform.get_global_graph_pool() |
| 79 | + |
| 80 | + if cudagraph_options is None: |
| 81 | + cudagraph_options = CUDAGraphOptions() |
| 82 | + self.aclgraph_options = cudagraph_options |
| 83 | + # the entries for different batch descriptors that we need to capture |
| 84 | + # aclgraphs for. |
| 85 | + self.concrete_aclgraph_entries: dict[BatchDescriptor, ACLGraphEntry]\ |
| 86 | + = {} |
| 87 | + |
| 88 | + def __getattr__(self, key: str): |
| 89 | + # allow accessing the attributes of the runnable. |
| 90 | + if hasattr(self.runnable, key): |
| 91 | + return getattr(self.runnable, key) |
| 92 | + raise AttributeError(f"Attribute {key} not exists in the runnable of " |
| 93 | + f"aclgraph wrapper: {self.runnable}") |
| 94 | + |
| 95 | + def unwrap(self) -> Callable: |
| 96 | + # in case we need to access the original runnable. |
| 97 | + return self.runnable |
| 98 | + |
| 99 | + def __call__(self, *args, **kwargs): |
| 100 | + forward_context = get_forward_context() |
| 101 | + batch_descriptor = forward_context.batch_descriptor |
| 102 | + aclgraph_runtime_mode = forward_context.cudagraph_runtime_mode |
| 103 | + |
| 104 | + if aclgraph_runtime_mode == CUDAGraphMode.NONE or \ |
| 105 | + aclgraph_runtime_mode != self.runtime_mode: |
| 106 | + # CUDAGraphMode.NONE could mean the profile run, a warmup run, or |
| 107 | + # running without aclgraphs. |
| 108 | + # We do not trigger capture/replay if the runtime mode is not |
| 109 | + # matches. This enables properly dispatching to the correct |
| 110 | + # CUDAGraphWrapper when nesting multiple instances with different |
| 111 | + # runtime modes. |
| 112 | + return self.runnable(*args, **kwargs) |
| 113 | + |
| 114 | + if batch_descriptor not in self.concrete_aclgraph_entries: |
| 115 | + # create a new entry for this batch descriptor |
| 116 | + self.concrete_aclgraph_entries[batch_descriptor] = \ |
| 117 | + ACLGraphEntry(batch_descriptor=batch_descriptor) |
| 118 | + |
| 119 | + entry = self.concrete_aclgraph_entries[batch_descriptor] |
| 120 | + |
| 121 | + if entry.aclgraph is None: |
| 122 | + if self.aclgraph_options.debug_log_enable: |
| 123 | + # Since we capture aclgraph for many different shapes and |
| 124 | + # capturing is fast, we don't need to log it for every |
| 125 | + # shape. E.g. we only log it for the first subgraph in |
| 126 | + # piecewise mode. |
| 127 | + logger.debug("Capturing a aclgraph on (%s,%s)", |
| 128 | + self.runtime_mode.name, entry.batch_descriptor) |
| 129 | + # validate that aclgraph capturing is legal at this point. |
| 130 | + validate_cudagraph_capturing_enabled() |
| 131 | + |
| 132 | + input_addresses = [ |
| 133 | + x.data_ptr() for x in args if isinstance(x, torch.Tensor) |
| 134 | + ] |
| 135 | + entry.input_addresses = input_addresses |
| 136 | + aclgraph = torch.npu.NPUGraph() |
| 137 | + |
| 138 | + with ExitStack() as stack: |
| 139 | + if self.aclgraph_options.gc_disable: |
| 140 | + # during every model forward for piecewise aclgraph |
| 141 | + # mode, we will capture many pieces of aclgraphs |
| 142 | + # (roughly one per layer). running gc again and again |
| 143 | + # across layers will make the aclgraph capture very slow. |
| 144 | + # therefore, we only run gc for the first graph, |
| 145 | + # and disable gc for the rest of the graphs. |
| 146 | + stack.enter_context(patch("gc.collect", lambda: None)) |
| 147 | + stack.enter_context( |
| 148 | + patch("torch.npu.empty_cache", lambda: None)) |
| 149 | + |
| 150 | + # mind-exploding: carefully manage the reference and memory. |
| 151 | + with torch.npu.graph(aclgraph, pool=self.graph_pool): |
| 152 | + # `output` is managed by pytorch's aclgraph pool |
| 153 | + output = self.runnable(*args, **kwargs) |
| 154 | + if self.aclgraph_options.weak_ref_output: |
| 155 | + # by converting it to weak ref, |
| 156 | + # the original `output` will immediately be released |
| 157 | + # to save memory. It is only safe to do this for |
| 158 | + # the last graph in piecewise aclgraph mode, because |
| 159 | + # the output of the last graph will not be used by |
| 160 | + # any other acl graph. |
| 161 | + output = weak_ref_tensors(output) |
| 162 | + |
| 163 | + # here we always use weak ref for the output |
| 164 | + # to save memory |
| 165 | + entry.output = weak_ref_tensors(output) |
| 166 | + entry.aclgraph = aclgraph |
| 167 | + |
| 168 | + compilation_counter.num_cudagraph_captured += 1 |
| 169 | + |
| 170 | + # important: we need to return the output, rather than |
| 171 | + # the weak ref of the output, so that pytorch can correctly |
| 172 | + # manage the memory during acl graph capture |
| 173 | + return output |
| 174 | + |
| 175 | + if self.is_debugging_mode: |
| 176 | + # check if the input addresses are the same |
| 177 | + new_input_addresses = [ |
| 178 | + x.data_ptr() for x in args if isinstance(x, torch.Tensor) |
| 179 | + ] |
| 180 | + assert new_input_addresses == entry.input_addresses, ( |
| 181 | + f"Input addresses for aclgraphs are different " |
| 182 | + f"during replay. Expected {entry.input_addresses}, " |
| 183 | + f"got {new_input_addresses}") |
| 184 | + |
| 185 | + entry.aclgraph.replay() |
| 186 | + return entry.output |
0 commit comments