
Commit 1449a8c

fix: vllm serve on Apple silicon
Right now commands like `vllm serve TinyLlama/TinyLlama-1.1B-Chat-v1.0` on Apple silicon fail with Triton errors like these.

```
$ vllm serve TinyLlama/TinyLlama-1.1B-Chat-v1.0
INFO 04-30 09:33:49 [importing.py:17] Triton not installed or not compatible; certain GPU-related functions will not be available.
WARNING 04-30 09:33:49 [importing.py:28] Triton is not installed. Using dummy decorators. Install it via `pip install triton` to enable kernelcompilation.
INFO 04-30 09:33:49 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 04-30 09:33:50 [__init__.py:239] Automatically detected platform cpu.
Traceback (most recent call last):
  File "/Users/dxia/src/github.com/vllm-project/vllm/.venv/bin/vllm", line 5, in <module>
    from vllm.entrypoints.cli.main import main
  File "/Users/dxia/src/github.com/vllm-project/vllm/vllm/entrypoints/cli/main.py", line 7, in <module>
    import vllm.entrypoints.cli.benchmark.main
  File "/Users/dxia/src/github.com/vllm-project/vllm/vllm/entrypoints/cli/benchmark/main.py", line 6, in <module>
    import vllm.entrypoints.cli.benchmark.throughput
  File "/Users/dxia/src/github.com/vllm-project/vllm/vllm/entrypoints/cli/benchmark/throughput.py", line 4, in <module>
    from vllm.benchmarks.throughput import add_cli_args, main
  File "/Users/dxia/src/github.com/vllm-project/vllm/vllm/benchmarks/throughput.py", line 18, in <module>
    from vllm.benchmarks.datasets import (AIMODataset, BurstGPTDataset,
  File "/Users/dxia/src/github.com/vllm-project/vllm/vllm/benchmarks/datasets.py", line 34, in <module>
    from vllm.lora.utils import get_adapter_absolute_path
  File "/Users/dxia/src/github.com/vllm-project/vllm/vllm/lora/utils.py", line 15, in <module>
    from vllm.lora.fully_sharded_layers import (
  File "/Users/dxia/src/github.com/vllm-project/vllm/vllm/lora/fully_sharded_layers.py", line 14, in <module>
    from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
  File "/Users/dxia/src/github.com/vllm-project/vllm/vllm/lora/layers.py", line 29, in <module>
    from vllm.model_executor.layers.logits_processor import LogitsProcessor
  File "/Users/dxia/src/github.com/vllm-project/vllm/vllm/model_executor/layers/logits_processor.py", line 13, in <module>
    from vllm.model_executor.layers.vocab_parallel_embedding import (
  File "/Users/dxia/src/github.com/vllm-project/vllm/vllm/model_executor/layers/vocab_parallel_embedding.py", line 139, in <module>
    @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/dxia/src/github.com/vllm-project/vllm/.venv/lib/python3.12/site-packages/torch/__init__.py", line 2543, in fn
    return compile(
           ^^^^^^^^
  File "/Users/dxia/src/github.com/vllm-project/vllm/.venv/lib/python3.12/site-packages/torch/__init__.py", line 2572, in compile
    return torch._dynamo.optimize(
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/dxia/src/github.com/vllm-project/vllm/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 944, in optimize
    return _optimize(rebuild_ctx, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/dxia/src/github.com/vllm-project/vllm/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 998, in _optimize
    backend = get_compiler_fn(backend)
              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/dxia/src/github.com/vllm-project/vllm/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 878, in get_compiler_fn
    from .repro.after_dynamo import wrap_backend_debug
  File "/Users/dxia/src/github.com/vllm-project/vllm/.venv/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py", line 35, in <module>
    from torch._dynamo.debug_utils import (
  File "/Users/dxia/src/github.com/vllm-project/vllm/.venv/lib/python3.12/site-packages/torch/_dynamo/debug_utils.py", line 44, in <module>
    from torch._dynamo.testing import rand_strided
  File "/Users/dxia/src/github.com/vllm-project/vllm/.venv/lib/python3.12/site-packages/torch/_dynamo/testing.py", line 33, in <module>
    from torch._dynamo.backends.debugging import aot_eager
  File "/Users/dxia/src/github.com/vllm-project/vllm/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/debugging.py", line 35, in <module>
    from functorch.compile import min_cut_rematerialization_partition
  File "/Users/dxia/src/github.com/vllm-project/vllm/.venv/lib/python3.12/site-packages/functorch/compile/__init__.py", line 2, in <module>
    from torch._functorch.aot_autograd import (
  File "/Users/dxia/src/github.com/vllm-project/vllm/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 26, in <module>
    from torch._inductor.output_code import OutputCode
  File "/Users/dxia/src/github.com/vllm-project/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/output_code.py", line 52, in <module>
    from .runtime.autotune_cache import AutotuneCacheBundler
  File "/Users/dxia/src/github.com/vllm-project/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/autotune_cache.py", line 23, in <module>
    from .triton_compat import Config
  File "/Users/dxia/src/github.com/vllm-project/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_compat.py", line 16, in <module>
    from triton import Config
ImportError: cannot import name 'Config' from 'triton' (unknown location)
```

We cannot install `triton` on Apple silicon because there are no [available distributions][1]. This change adds more placeholders for Triton modules and classes that are imported when calling `vllm serve`.

[1]: https://pypi.org/project/triton/#files

Signed-off-by: David Xia <david@davidxia.com>


vllm/triton_utils/importing.py

Lines changed: 29 additions & 1 deletion
```diff
@@ -2,6 +2,7 @@
 
 import sys
 import types
+from abc import ABC
 from importlib.util import find_spec
 
 from vllm.logger import init_logger
@@ -25,6 +26,8 @@ def __init__(self):
         self.autotune = self._dummy_decorator("autotune")
         self.heuristics = self._dummy_decorator("heuristics")
         self.language = TritonLanguagePlaceholder()
+        self.Config = self._dummy_decorator("Config")
+        self.__version__ = ""
         logger.warning_once(
             "Triton is not installed. Using dummy decorators. "
             "Install it via `pip install triton` to enable kernel"
@@ -43,11 +46,36 @@ class TritonLanguagePlaceholder(types.ModuleType):
 
     def __init__(self):
         super().__init__("triton.language")
-        self.constexpr = None
+        self.constexpr = lambda x: x
         self.dtype = None
+        self.extra = None
+        self.math = None
+        self.tensor = None
+
+class TritonCompilerPlaceholder(types.ModuleType):
+
+    def __init__(self):
+        super().__init__("triton.compiler")
+        self.CompiledKernel = ABC
+
+class TritonRuntimeAutotunerPlaceholder(types.ModuleType):
+
+    def __init__(self):
+        super().__init__("triton.runtime.autotuner")
+        self.OutOfResources = ABC
+
+class TritonRuntimeJitPlaceholder(types.ModuleType):
+
+    def __init__(self):
+        super().__init__("triton.runtime.jit")
+        self.KernelInterface = ABC
 
 sys.modules['triton'] = TritonPlaceholder()
 sys.modules['triton.language'] = TritonLanguagePlaceholder()
+sys.modules['triton.compiler'] = TritonCompilerPlaceholder()
+sys.modules[
+    'triton.runtime.autotuner'] = TritonRuntimeAutotunerPlaceholder()
+sys.modules['triton.runtime.jit'] = TritonRuntimeJitPlaceholder()
 
 if 'triton' in sys.modules:
     logger.info("Triton module has been replaced with a placeholder.")
```
