Skip to content

Commit 4ba1bad

Browse files
author
w00800020
committed
Add basic wrapper for torchair multistream utilities
Helps to unify the code paths where multistream is turned on or off. Signed-off-by: w00800020 <weijinyi3@huawei.com>
1 parent f0485e2 commit 4ba1bad

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

vllm_ascend/utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,28 @@
1717
# Adapted from vllm-project/vllm/vllm/worker/worker.py
1818
#
1919

20+
import contextlib
2021
import math
2122
from typing import TYPE_CHECKING
2223

2324
import torch
25+
import torchair # type: ignore # noqa: F401
2426
from packaging.version import InvalidVersion, Version
2527
from vllm.logger import logger
2628

2729
import vllm_ascend.envs as envs
2830

31+
try:
32+
from torchair.scope import \
33+
npu_stream_switch as _npu_stream_switch # type: ignore
34+
from torchair.scope import \
35+
npu_wait_tensor as _npu_wait_tensor # type: ignore
36+
except ImportError:
37+
from torchair.ops import \
38+
NpuStreamSwitch as _npu_stream_switch # type: ignore
39+
from torchair.ops import \
40+
npu_wait_tensor as _npu_wait_tensor # type: ignore
41+
2942
if TYPE_CHECKING:
3043
from vllm.config import VllmConfig
3144
else:
@@ -173,3 +186,14 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
173186

174187
def dispose_tensor(x: torch.Tensor) -> None:
    """Shrink ``x`` in place to an empty tensor.

    ``x`` is re-pointed (via ``Tensor.set_``) at a freshly created tensor of
    shape ``(0,)`` that has the same device and dtype as ``x``. The object
    identity of ``x`` is kept; only what it refers to changes.

    Args:
        x: Tensor to empty in place.
    """
    # NOTE(review): presumably this lets the previous storage be released
    # once nothing else references it -- confirm against callers.
    x.set_(torch.empty((0, ), device=x.device, dtype=x.dtype))
189+
190+
191+
def npu_stream_switch(tag: str, priority: int = 0, enabled: bool = True):
    """Return a stream-switch context, or a no-op context when disabled.

    Lets callers write a single ``with npu_stream_switch(...):`` block
    regardless of whether multistream is turned on.

    Args:
        tag: Forwarded unchanged to the underlying torchair stream switch.
        priority: Forwarded unchanged to the underlying torchair stream
            switch.
        enabled: When false, torchair is not called at all and a
            ``contextlib.nullcontext()`` is returned instead.

    Returns:
        The result of ``_npu_stream_switch(tag, priority)`` when ``enabled``
        is true, otherwise a ``contextlib.nullcontext()`` instance.
    """
    if not enabled:
        # Multistream is off: hand back a context manager that does nothing.
        return contextlib.nullcontext()
    return _npu_stream_switch(tag, priority)
194+
195+
196+
def npu_wait_tensor(self: torch.Tensor,
                    dependency: torch.Tensor,
                    enabled: bool = True):
    """Apply the torchair wait-tensor op, or pass ``self`` through unchanged.

    Args:
        self: Tensor handed to the underlying op as its first argument. This
            is a plain module-level function; the parameter is merely named
            ``self``.
        dependency: Tensor handed to the underlying op as its second
            argument.
        enabled: When false, torchair is not called and ``self`` is returned
            as-is.

    Returns:
        The result of ``_npu_wait_tensor(self, dependency)`` when ``enabled``
        is true, otherwise the very same ``self`` object.
    """
    if not enabled:
        # Multistream is off: no dependency to wait on, return the input.
        return self
    return _npu_wait_tensor(self, dependency)

0 commit comments

Comments
 (0)