From 35c785916a9afedd3516b81dd123edbe2d81c196 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Sun, 17 Nov 2024 23:57:20 -0800
Subject: [PATCH] [4/N][torch.compile] clean up set_torch_compile_backend
 (#10401)

Signed-off-by: youkaichao
---
 vllm/compilation/backends.py | 16 ++--------------
 vllm/compilation/wrapper.py  | 11 +++--------
 vllm/config.py               | 31 ++++++++++++++++++++++++++++++-
 vllm/platforms/tpu.py        |  7 +++----
 vllm/plugins/__init__.py     | 14 +-------------
 vllm/utils.py                |  9 +++++++++
 vllm/worker/model_runner.py  |  3 +--
 7 files changed, 49 insertions(+), 42 deletions(-)

diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index 22c613931f082..0cf1e3a95fcba 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -2,15 +2,14 @@
 import dataclasses
 import operator
 from contextlib import ExitStack
-from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple,
-                    Union)
+from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
 from unittest.mock import patch
 
 import torch
 import torch.fx as fx
 
 import vllm.envs as envs
-from vllm.config import CompilationConfig, CompilationLevel
+from vllm.config import CompilationConfig
 from vllm.logger import init_logger
 from vllm.utils import combine_fx_passes, weak_ref_tensors
 
@@ -684,14 +683,3 @@ def __call__(self, *args) -> Any:
 
             entry.cudagraph.replay()
         return entry.output
-
-
-def select_default_backend(level: int) -> Union[str, Callable]:
-    if level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
-        backend_str = "eager"
-        return backend_str
-    assert level == CompilationLevel.PIECEWISE
-
-    from vllm.plugins import get_current_vllm_config
-    compilation_config = get_current_vllm_config().compilation_config
-    return VllmBackend(compilation_config)
diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py
index 2a1aecc11ce26..0143d0301ca1a 100644
--- a/vllm/compilation/wrapper.py
+++ b/vllm/compilation/wrapper.py
@@ -32,14 +32,9 @@ def __init__(self,
         # default compilation settings
         # compiling the forward method
 
-        # choose the compile backend
-
-        # if the user has set the backend, use it
-        from vllm.plugins import get_torch_compile_backend
-        backend = get_torch_compile_backend()
-        if backend is None:
-            from vllm.compilation.backends import select_default_backend
-            backend = select_default_backend(compilation_level)
+        from vllm.plugins import get_current_vllm_config
+        backend = get_current_vllm_config(
+        ).compilation_config.init_backend()
 
         compiled_callable = torch.compile(
             self.forward,
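[Note, not part of the patch] The wrapper.py change above amounts to: ask the compilation config for a backend and hand it to torch.compile, instead of consulting the module-level plugin state that select_default_backend() relied on. A minimal self-contained sketch, using a stand-in forward function and a directly constructed config rather than the wrapper's get_current_vllm_config() lookup (init_backend() is added in the vllm/config.py hunk below):

    import torch
    from vllm.config import CompilationConfig, CompilationLevel

    def forward_fn(x: torch.Tensor) -> torch.Tensor:
        # stand-in for a model's forward method
        return x * 2

    cfg = CompilationConfig(level=CompilationLevel.DYNAMO_ONCE)
    compiled = torch.compile(forward_fn, backend=cfg.init_backend())  # resolves to "eager"
    print(compiled(torch.ones(2)))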
diff --git a/vllm/config.py b/vllm/config.py
index 7e37edbe594b1..14017bbdb3cf2 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -22,7 +22,7 @@
     get_hf_text_config, get_pooling_config,
     get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
 from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
-                        identity, print_warning_once)
+                        identity, print_warning_once, resolve_obj_by_qualname)
 
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
@@ -2072,6 +2072,13 @@ class CompilationConfig(BaseModel):
     - 1: dynamo as is.
     - 2: dynamo once.
     - 3: piecewise compilation.
+    - backend: the backend for compilation. It needs to be a string:
+      - "" (empty string): use the default backend.
+      - "eager"/"openxla"/...: use the specified backend registered in PyTorch.
+      - "full.module.name": a qualified name which can be used to import the backend function.
+      We use a string to avoid serialization issues when using compilation in a distributed setting.
+      When the compilation level is 1 or 2, the backend is used for compilation directly (it sees the whole graph).
+      When the compilation level is 3, the backend is used for piecewise compilation (it sees a part of the graph).
     - custom_ops: fine-grained control over which custom ops to enable/disable.
       Use 'all' to enable all, 'none' to disable all.
       Also specify a list of custom op names to enable (prefixed with a '+'),
@@ -2139,6 +2146,7 @@ class CompilationConfig(BaseModel):
     certain small batchsizes, where inductor is good at optimizing.
     """ # noqa
     level: int = 0
+    backend: str = ""
     custom_ops: List[str] = Field(default_factory=list)
 
     use_inductor: bool = True
@@ -2182,6 +2190,27 @@ def model_post_init(self, __context: Any) -> None:
                 func = __import__(module).__dict__[func_name]
             self.inductor_compile_config[k] = func
 
+    def init_backend(self) -> Union[str, Callable]:
+        if self.level == CompilationLevel.NO_COMPILATION:
+            raise ValueError("No compilation level is set.")
+
+        from torch._dynamo.backends.registry import list_backends
+        torch_backends = list_backends(exclude_tags=tuple())
+        if self.level in [
+                CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE
+        ]:
+            if self.backend == "":
+                return "eager"
+            if self.backend in torch_backends:
+                return self.backend
+            return resolve_obj_by_qualname(self.backend)
+
+        # TODO: pass the user-specified backend to piecewise compilation
+        # (merge it with the use_inductor config).
+        assert self.level == CompilationLevel.PIECEWISE
+        from vllm.compilation.backends import VllmBackend
+        return VllmBackend(self)
+
     def init_during_runtime(self):
         """To complete the initialization of config, we need to know
         the compile context, which is only available
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index c2e22bfc09f22..643db835c85ff 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -3,8 +3,6 @@
 
 import torch
 
-from vllm.plugins import set_torch_compile_backend
-
 from .interface import Platform, PlatformEnum
 
 if TYPE_CHECKING:
@@ -12,8 +10,6 @@
 else:
     VllmConfig = None
 
-set_torch_compile_backend("openxla")
-
 
 class TpuPlatform(Platform):
     _enum = PlatformEnum.TPU
@@ -38,3 +34,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             compilation_config.level = CompilationLevel.DYNAMO_ONCE
         assert compilation_config.level < CompilationLevel.PIECEWISE,\
             "TPU does not support Inductor."
+
+        if compilation_config.backend == "":
+            compilation_config.backend = "openxla"
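[Note, not part of the patch] Resolution order implemented by init_backend() for levels 1 and 2, sketched with a directly constructed config; "os.path.join" is only a stand-in for a user's "full.module.name" backend:

    from vllm.config import CompilationConfig, CompilationLevel

    cfg = CompilationConfig(level=CompilationLevel.DYNAMO_ONCE, backend="")
    assert cfg.init_backend() == "eager"        # "" -> default backend

    cfg.backend = "cudagraphs"                  # a backend registered in PyTorch
    assert cfg.init_backend() == "cudagraphs"   # returned as-is

    cfg.backend = "os.path.join"                # qualified name, not in the registry
    assert callable(cfg.init_backend())         # imported via resolve_obj_by_qualname

At level 3 (PIECEWISE) the user-specified string is ignored for now (see the TODO above) and a VllmBackend wrapping the config is returned instead.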
diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py
index c20b9ec891d5d..a0c73a752b5e8 100644
--- a/vllm/plugins/__init__.py
+++ b/vllm/plugins/__init__.py
@@ -1,6 +1,6 @@
 import logging
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Callable, Optional, Union
+from typing import TYPE_CHECKING, Optional
 
 import vllm.envs as envs
 
@@ -50,18 +50,6 @@ def load_general_plugins():
             logger.exception("Failed to load plugin %s", plugin.name)
 
 
-_torch_compile_backend: Optional[Union[Callable, str]] = None
-
-
-def set_torch_compile_backend(backend: Union[Callable, str]):
-    global _torch_compile_backend
-    _torch_compile_backend = backend
-
-
-def get_torch_compile_backend() -> Optional[Union[Callable, str]]:
-    return _torch_compile_backend
-
-
 _compilation_config: Optional[CompilationConfig] = None
 
 
diff --git a/vllm/utils.py b/vllm/utils.py
index 111460a29de47..5d0514cd9d168 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -1600,3 +1600,12 @@ def direct_register_custom_op(
     my_lib.impl(op_name, op_func, "CUDA")
     if fake_impl is not None:
         my_lib._register_fake(op_name, fake_impl)
+
+
+def resolve_obj_by_qualname(qualname: str) -> Any:
+    """
+    Resolve an object by its fully qualified name.
+    """
+    module_name, obj_name = qualname.rsplit(".", 1)
+    module = importlib.import_module(module_name)
+    return getattr(module, obj_name)
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index fd89f95445565..fb5813651680b 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1143,8 +1143,7 @@ def load_model(self) -> None:
 
         if self.vllm_config.compilation_config.level ==\
             CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
-            from vllm.plugins import get_torch_compile_backend
-            backend = get_torch_compile_backend() or "eager"
+            backend = self.vllm_config.compilation_config.init_backend()
             self.model = torch.compile(
                 self.model,
                 fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
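[Note, not part of the patch] The new vllm.utils.resolve_obj_by_qualname helper in isolation, checked against a stdlib name:

    import os.path
    from vllm.utils import resolve_obj_by_qualname

    assert resolve_obj_by_qualname("os.path.join") is os.path.join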