|
17 | 17 | # Adapted from vllm-project/vllm/vllm/worker/worker.py |
18 | 18 | # |
19 | 19 |
|
| 20 | +import contextlib |
20 | 21 | import math |
21 | 22 | from typing import TYPE_CHECKING |
22 | 23 |
|
23 | 24 | import torch |
| 25 | +import torchair # type: ignore # noqa: F401 |
24 | 26 | from packaging.version import InvalidVersion, Version |
25 | 27 | from vllm.logger import logger |
26 | 28 |
|
27 | 29 | import vllm_ascend.envs as envs |
28 | 30 |
|
| 31 | +try: |
| 32 | + from torchair.scope import \ |
| 33 | + npu_stream_switch as _npu_stream_switch # type: ignore |
| 34 | + from torchair.scope import \ |
| 35 | + npu_wait_tensor as _npu_wait_tensor # type: ignore |
| 36 | +except ImportError: |
| 37 | + from torchair.ops import \ |
| 38 | + NpuStreamSwitch as _npu_stream_switch # type: ignore |
| 39 | + from torchair.ops import \ |
| 40 | + npu_wait_tensor as _npu_wait_tensor # type: ignore |
| 41 | + |
29 | 42 | if TYPE_CHECKING: |
30 | 43 | from vllm.config import VllmConfig |
31 | 44 | else: |
@@ -173,3 +186,14 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: |
173 | 186 |
|
def dispose_tensor(x: torch.Tensor):
    """Free *x*'s backing storage eagerly.

    Re-points ``x`` at a fresh zero-element tensor (same device and dtype),
    dropping the reference to the original storage so it can be reclaimed
    without waiting for ``x`` itself to go out of scope.
    """
    placeholder = torch.empty((0, ), device=x.device, dtype=x.dtype)
    x.set_(placeholder)
def npu_stream_switch(tag: str, priority: int = 0, enabled: bool = True):
    """Return a context manager that switches to the NPU stream *tag*.

    When ``enabled`` is False the switch is skipped entirely and a no-op
    ``contextlib.nullcontext`` is returned, so call sites can keep a single
    ``with`` statement regardless of whether multi-stream mode is active.
    """
    if not enabled:
        return contextlib.nullcontext()
    return _npu_stream_switch(tag, priority)
def npu_wait_tensor(self: torch.Tensor,
                    dependency: torch.Tensor,
                    enabled: bool = True):
    """Make *self* wait on *dependency* across NPU streams.

    When ``enabled`` is False no dependency edge is recorded and ``self``
    is returned unchanged, mirroring ``npu_stream_switch``'s opt-out path.
    """
    if not enabled:
        return self
    return _npu_wait_tensor(self, dependency)
0 commit comments