|
1 | 1 | # SPDX-License-Identifier: Apache-2.0 |
2 | 2 |
|
| 3 | +import sys |
| 4 | +import types |
3 | 5 | from importlib.util import find_spec |
4 | 6 |
|
5 | 7 | from vllm.logger import init_logger |
6 | | -from vllm.platforms import current_platform |
7 | 8 |
|
8 | 9 | logger = init_logger(__name__) |
9 | 10 |
|
10 | 11 | HAS_TRITON = ( |
11 | 12 | find_spec("triton") is not None |
12 | | - and not current_platform.is_xpu() # Not compatible |
| 13 | + or find_spec("pytorch-triton-xpu") is not None # Not compatible |
13 | 14 | ) |
14 | 15 |
|
15 | 16 | if not HAS_TRITON: |
16 | 17 | logger.info("Triton not installed or not compatible; certain GPU-related" |
17 | 18 | " functions will not be available.") |
| 19 | + |
| 20 | + class TritonPlaceholder(types.ModuleType): |
| 21 | + |
| 22 | + def __init__(self): |
| 23 | + super().__init__("triton") |
| 24 | + self.jit = self._dummy_decorator("jit") |
| 25 | + self.autotune = self._dummy_decorator("autotune") |
| 26 | + self.heuristics = self._dummy_decorator("heuristics") |
| 27 | + self.language = TritonLanguagePlaceholder() |
| 28 | + logger.warning_once( |
| 29 | + "Triton is not installed. Using dummy decorators. " |
| 30 | + "Install it via `pip install triton` to enable kernel" |
| 31 | + "compilation.") |
| 32 | + |
| 33 | + def _dummy_decorator(self, name): |
| 34 | + |
| 35 | + def decorator(func=None, **kwargs): |
| 36 | + if func is None: |
| 37 | + return lambda f: f |
| 38 | + return func |
| 39 | + |
| 40 | + return decorator |
| 41 | + |
| 42 | + class TritonLanguagePlaceholder(types.ModuleType): |
| 43 | + |
| 44 | + def __init__(self): |
| 45 | + super().__init__("triton.language") |
| 46 | + self.constexpr = None |
| 47 | + self.dtype = None |
| 48 | + |
| 49 | + sys.modules['triton'] = TritonPlaceholder() |
| 50 | + sys.modules['triton.language'] = TritonLanguagePlaceholder() |
| 51 | + |
| 52 | +if 'triton' in sys.modules: |
| 53 | + logger.info("Triton module has been replaced with a placeholder.") |
0 commit comments