Skip to content

Commit 6c97336

Browse files
authored
[Bugfix] Fix aclgraph not enabled by default (#2590)
### What this PR does / why we need it? As vllm will set `cudagraph_mode` to `NONE` before `check_and_update_config` in post init of `VllmConfig` (https://github.com/vllm-project/vllm/blob/5da4f5d857933329aaca779e3a81f1385c84e34a/vllm/config/__init__.py#L3630), we always have `cudagraph_mode` isn't `None`, thus we must remove this check and add it when the related adaption in vllm is done. part of #2577, will add the e2e test on applying reply after the CI refactor is done ### How was this patch tested? CI passed with existing test. - vLLM version: v0.10.1.1 - vLLM main: vllm-project/vllm@f48a9af Signed-off-by: MengqingCao <cmq0113@163.com>
1 parent cf96366 commit 6c97336

File tree

3 files changed

+22
-20
lines changed

3 files changed

+22
-20
lines changed

tests/ut/test_platform.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from datetime import timedelta
44
from unittest.mock import MagicMock, patch
55

6+
import pytest
67
import torch
78
from torch.distributed import ProcessGroup
89
from torch.distributed.distributed_c10d import PrefixStore
@@ -318,6 +319,8 @@ def test_check_and_update_config_unsupported_compilation_level(
318319
CUDAGraphMode.NONE,
319320
)
320321

322+
@pytest.mark.skip(
323+
"Revert me when vllm support setting cudagraph_mode on oot platform")
321324
@patch("vllm_ascend.utils.is_310p", return_value=False)
322325
@patch("vllm_ascend.ascend_config.check_ascend_config")
323326
@patch("vllm_ascend.ascend_config.init_ascend_config")

vllm_ascend/compilation/acl_graph.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,10 @@
1313
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
1414
from vllm.config import CUDAGraphMode, VllmConfig
1515
from vllm.forward_context import BatchDescriptor, get_forward_context
16-
from vllm.logger import init_logger
16+
from vllm.logger import logger
1717
from vllm.platforms import current_platform
1818
from vllm.utils import weak_ref_tensors
1919

20-
logger = init_logger(__name__)
21-
2220

2321
@dataclasses.dataclass
2422
class ACLGraphEntry:
@@ -182,5 +180,6 @@ def __call__(self, *args, **kwargs):
182180
f"during replay. Expected {entry.input_addresses}, "
183181
f"got {new_input_addresses}")
184182

183+
logger.info_once("Replaying aclgraph")
185184
entry.aclgraph.replay()
186185
return entry.output

vllm_ascend/platform.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -146,23 +146,23 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
146146

147147
compilation_config.cudagraph_num_of_warmups = 1
148148

149-
if compilation_config.cudagraph_mode is None:
150-
# if cudagraph_mode is not explicitly set by users, set default value
151-
if compilation_config.level == CompilationLevel.PIECEWISE:
152-
compilation_config.cudagraph_mode = \
153-
CUDAGraphMode.PIECEWISE
154-
elif compilation_config.level not in [
155-
CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE
156-
]:
157-
logger.warning(
158-
"NPU does not support %s compilation level. Setting CUDAGraphMode to NONE",
159-
compilation_config.level)
160-
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
161-
else:
162-
logger.warning(
163-
"compilation_config.level = CompilationLevel.NO_COMPILATION is set, Setting CUDAGraphMode to NONE"
164-
)
165-
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
149+
# TODO: make vllm support oot platform to set `compilation_config.cudagraph_mode`
150+
# if cudagraph_mode is not explicitly set by users, set default value
151+
if compilation_config.level == CompilationLevel.PIECEWISE:
152+
compilation_config.cudagraph_mode = \
153+
CUDAGraphMode.PIECEWISE
154+
elif compilation_config.level not in [
155+
CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE
156+
]:
157+
logger.warning(
158+
"NPU does not support %s compilation level. Setting CUDAGraphMode to NONE",
159+
compilation_config.level)
160+
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
161+
else:
162+
logger.warning(
163+
"compilation_config.level = CompilationLevel.NO_COMPILATION is set, Setting CUDAGraphMode to NONE"
164+
)
165+
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
166166

167167
# set CUDAGraphMode to None when torchair is enabled, no mather what compilation_config.level is.
168168
if ascend_config.torchair_graph_config.enabled:

0 commit comments

Comments
 (0)