diff --git a/.buildkite/run-tpu-test.sh b/.buildkite/run-tpu-test.sh
index 335ffd83fcd7a..6989c94d46a89 100644
--- a/.buildkite/run-tpu-test.sh
+++ b/.buildkite/run-tpu-test.sh
@@ -12,4 +12,4 @@ remove_docker_container
 # For HF_TOKEN.
 source /etc/environment
 # Run a simple end-to-end example.
-docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py"
+docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py"
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 9f449ff650b90..235db72eee4b9 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -173,6 +173,7 @@ steps:
   - vllm/
   commands:
   - pytest -v -s ./compile/test_full_graph.py
+  - pytest -v -s ./compile/test_wrapper.py
 
 
 - label: Vision Language Models Test # 42min
diff --git a/tests/compile/test_wrapper.py b/tests/compile/test_wrapper.py
new file mode 100644
index 0000000000000..cef516ade27eb
--- /dev/null
+++ b/tests/compile/test_wrapper.py
@@ -0,0 +1,59 @@
+from typing import Optional
+
+import torch
+
+from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther
+
+
+class MyMod(torch.nn.Module):
+
+    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
+        if cache is not None:
+            return x + cache
+        return x * 2
+
+
+class MyWrapper(TorchCompileWrapperWithCustomDispacther):
+
+    def __init__(self, model):
+        self.model = model
+        compiled_callable = torch.compile(self.forward, backend="eager")
+        super().__init__(compiled_callable)
+
+    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
+        # this is the function to be compiled
+        return self.model(x, cache)
+
+    def __call__(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
+        # let torch.compile compile twice
+        if len(self.compiled_codes) == 2:
+            dispatch_id = 0 if cache is None else 1
+            with self.dispatch_to_code(dispatch_id):
+                return self.forward(x, cache)
+        else:
+            return self.compiled_callable(x, cache)
+
+
+def test_torch_compile_wrapper():
+    mod = MyMod()
+    wrappers = []
+    for i in range(3):
+        torch._dynamo.reset()
+        wrapper = MyWrapper(mod)
+        wrappers.append(wrapper)
+        x = torch.tensor([1])
+        wrapper(x, None)  # profile run, compile
+        # create a cache tensor
+        cache = torch.tensor([2])
+        wrapper(x, cache)  # warm up with cache, recompile
+
+        # for new input, dispatch to the compiled code directly
+        new_x = torch.tensor([3])
+        assert wrapper(new_x,
+                       None).item() == 6  # dispatch to the first compiled code
+        assert wrapper(
+            new_x, cache).item() == 5  # dispatch to the second compiled code
+
+    for wrapper in wrappers:
+        # make sure they have independent compiled codes
+        assert len(wrapper.compiled_codes) == 2
diff --git a/tests/tpu/__init__.py b/tests/tpu/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/tpu/test_custom_dispatcher.py b/tests/tpu/test_custom_dispatcher.py
new file mode 100644
index 0000000000000..7f3fb595321ad
--- /dev/null
+++ b/tests/tpu/test_custom_dispatcher.py
@@ -0,0 +1,9 @@
+from ..utils import compare_two_settings
+
+
+def test_custom_dispatcher():
+    compare_two_settings("google/gemma-2b",
+                         arg1=["--enforce-eager"],
+                         arg2=["--enforce-eager"],
+                         env1={"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": "0"},
+                         env2={})
diff --git a/vllm/compilation/__init__.py b/vllm/compilation/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py
new file mode 100644
index 0000000000000..c3d863299dd06
--- /dev/null
+++ b/vllm/compilation/wrapper.py
@@ -0,0 +1,81 @@
+import os
+import sys
+from abc import abstractmethod
+from contextlib import contextmanager
+from types import CodeType
+from typing import Callable, List
+
+import torch
+
+import vllm.envs as envs
+
+
+class TorchCompileWrapperWithCustomDispacther:
+    """
+    A wrapper class for torch.compile, with custom dispatch logic.
+    Subclasses should:
+    1. Implement the forward method.
+    2. Implement the dispatch logic in the __call__ method.
+       It can use `self.compiled_codes` to access the compiled bytecode,
+       and `with self.dispatch_to_code(index):` to dispatch to
+       the compiled code.
+    3. Implement the `__init__` method to determine how to call
+       `torch.compile` over the forward method.
+    """
+
+    def __init__(self, compiled_callable: Callable):
+        self.compiled_callable = compiled_callable
+        self.original_code_object = self.__class__.forward.__code__
+        self.compiled_codes: List[CodeType] = []
+        torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
+
+        # read the env var to determine whether to use the custom dispatcher;
+        # subclasses can use this to switch between the custom dispatcher
+        # and the default Dynamo guard mechanism.
+        self.use_custom_dispatcher: bool = \
+            envs.VLLM_DYNAMO_USE_CUSTOM_DISPATCHER
+
+    def __call__(self, *args, **kwargs):
+        """Implement the dispatch logic here, beyond the torch.compile level.
+        NOTE: this function can have additional arguments beyond the forward
+        method, for directly dispatching to the compiled code.
+        """
+        return self.compiled_callable(*args, **kwargs)
+
+    @abstractmethod
+    def forward(self, *args, **kwargs):
+        ...
+
+    def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
+        """Hook to save the compiled bytecode for direct execution."""
+        if old_code is not self.original_code_object:
+            return
+        # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
+        frame = sys._getframe()
+        while True:
+            frame = frame.f_back
+            code_name = frame.f_code.co_name
+            file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
+            if code_name == "_compile" and file_name == "convert_frame.py":
+                break
+        frame = frame.f_locals["frame"]
+        assert frame.f_code == old_code
+
+        if frame.f_locals["self"] is not self:
+            return
+
+        self.compiled_codes.append(new_code)
+
+    @contextmanager
+    def dispatch_to_code(self, index: int):
+        """Context manager to dispatch to the compiled code.
+        Why does this work? Because Dynamo guarantees that the compiled
+        bytecode has exactly the same arguments, cell variables, and free
+        variables as the original code. Therefore we can directly switch
+        the code object in the function and call it.
+
+        See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
+ """ # noqa + self.__class__.forward.__code__ = self.compiled_codes[index] + yield + self.__class__.forward.__code__ = self.original_code_object diff --git a/vllm/envs.py b/vllm/envs.py index 4faafd9daf304..5906984163295 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -196,6 +196,10 @@ def get_default_config_root(): # Internal flag to enable Dynamo graph capture "VLLM_TEST_DYNAMO_GRAPH_CAPTURE": lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")), + "VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": + lambda: + (os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in + ("true", "1")), # local rank of the process in the distributed setting, used to determine # the GPU device id diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 01daa64b5a32f..a7ceb84effe91 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -10,6 +10,7 @@ import torch_xla.runtime as xr from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, ParallelConfig, SchedulerConfig) from vllm.logger import init_logger @@ -144,11 +145,7 @@ def load_model(self) -> None: ) model = model.eval() xm.wait_device_ops() - model = ModelWrapper(model) - self.model = torch.compile(model, - backend="openxla", - fullgraph=True, - dynamic=False) + self.model = ModelWrapper(model) def _dummy_run( self, @@ -235,8 +232,15 @@ def _dummy_run( torch._dynamo.mark_dynamic(t, 0) torch._dynamo.mark_dynamic(p, 0) # Dummy run. - self.model(token_ids, position_ids, attn_metadata, input_lens, t, p, - num_samples, kv_caches) + self.model(token_ids, + position_ids, + attn_metadata, + input_lens, + t, + p, + num_samples, + kv_caches, + is_prompt=is_prompt) def warmup_model( self, @@ -530,7 +534,7 @@ def _execute_model(*args): if getattr(arg, "context_lens", None) is not None: arg.context_lens = arg.context_lens.to(self.device) new_args.append(arg) - return self.model(*new_args) + return self.model(*new_args, is_prompt=is_prompt) num_prefills = model_input.attn_metadata.num_prefills is_prompt = num_prefills > 0 @@ -601,11 +605,32 @@ def _execute_model(*args): return [SamplerOutput(sampler_outputs)] -class ModelWrapper(nn.Module): +class ModelWrapper(TorchCompileWrapperWithCustomDispacther): def __init__(self, model: nn.Module): - super().__init__() self.model = model + compiled_callable = torch.compile(self.forward, + backend="openxla", + fullgraph=True, + dynamic=False) + super().__init__(compiled_callable) + + def __call__(self, *args, is_prompt: bool, **kwargs): + if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher: + # not fully compiled yet, or not using the custom dispatcher, + # let PyTorch handle it + return self.compiled_callable(*args, **kwargs) + # the 3 compiled codes are: + # 0: for profiling + # 1: for prompt + # 2: for decode + # dispatch to the compiled code directly, skip PyTorch + if is_prompt: + with self.dispatch_to_code(1): + return self.forward(*args, **kwargs) + else: + with self.dispatch_to_code(2): + return self.forward(*args, **kwargs) def forward( self,