Skip to content

Commit bc5dfaf

Browse files
committed
Add set_env_var helper and add pattern matcher debug utility
Signed-off-by: Luka Govedic <lgovedic@redhat.com>
1 parent 40c4388 commit bc5dfaf

File tree

3 files changed

+38
-7
lines changed

3 files changed

+38
-7
lines changed

vllm/compilation/pass_manager.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

4+
from contextlib import ExitStack
5+
46
from torch import fx as fx
57

8+
from vllm import envs
69
from vllm.config import VllmConfig
710
from vllm.logger import init_logger
811
from vllm.platforms import current_platform
12+
from vllm.utils import set_env_var
913

1014
if current_platform.is_cuda_alike():
1115
from .fusion import FusionPass
@@ -43,13 +47,20 @@ def __init__(self):
4347
self.passes: list[VllmInductorPass] = []
4448

4549
def __call__(self, graph: fx.Graph):
46-
shape = get_pass_context().runtime_shape
47-
for pass_ in self.passes:
48-
if pass_.is_applicable_for_shape(shape):
49-
pass_(graph)
50-
51-
# always run fix_functionalization last
52-
self.fix_functionalization(graph)
50+
with ExitStack() as stack:
51+
if envs.VLLM_PATTERN_MATCH_DEBUG is not None:
52+
# and get_tensor_model_parallel_rank() == 0:
53+
stack.enter_context(
54+
set_env_var('TORCHINDUCTOR_PATTERN_MATCH_DEBUG',
55+
envs.VLLM_PATTERN_MATCH_DEBUG))
56+
57+
shape = get_pass_context().runtime_shape
58+
for pass_ in self.passes:
59+
if pass_.is_applicable_for_shape(shape):
60+
pass_(graph)
61+
62+
# always run fix_functionalization last
63+
self.fix_functionalization(graph)
5364

5465
def configure(self, config: VllmConfig):
5566
self.pass_config = config.compilation_config.pass_config

vllm/envs.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,9 @@
160160
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
161161
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
162162
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
163+
VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE: bool = True
164+
VLLM_USE_STANDALONE_COMPILE: bool = True
165+
VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None
163166

164167

165168
def get_default_cache_root():
@@ -363,6 +366,10 @@ def get_vllm_port() -> Optional[int]:
363366
"VLLM_USE_STANDALONE_COMPILE":
364367
lambda: os.environ.get("VLLM_USE_STANDALONE_COMPILE", "1") == "1",
365368

369+
# Debug pattern matching inside custom passes
370+
"VLLM_PATTERN_MATCH_DEBUG":
371+
lambda: os.environ.get("VLLM_PATTERN_MATCH_DEBUG", None),
372+
366373
# local rank of the process in the distributed setting, used to determine
367374
# the GPU device id
368375
"LOCAL_RANK":

vllm/utils/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3346,3 +3346,16 @@ def decorate_logs(process_name: Optional[str] = None) -> None:
33463346
pid = os.getpid()
33473347
_add_prefix(sys.stdout, process_name, pid)
33483348
_add_prefix(sys.stderr, process_name, pid)
3349+
3350+
3351+
@contextlib.contextmanager
3352+
def set_env_var(key, value):
3353+
old = os.environ.get(key)
3354+
os.environ[key] = value
3355+
try:
3356+
yield
3357+
finally:
3358+
if old is None:
3359+
del os.environ[key]
3360+
else:
3361+
os.environ[key] = old

0 commit comments

Comments
 (0)