
Commit 6fe137e

Add monkey patch

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
1 parent c59375c commit 6fe137e

File tree

3 files changed: +84 -0 lines changed

vllm_ascend/patch/__init__.py
vllm_ascend/patch/patch_commnicator.py
vllm_ascend/platform.py

vllm_ascend/patch/__init__.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from vllm_ascend.patch import patch_commnicator  # noqa
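
The import above exists only for its side effect: importing vllm_ascend.patch executes patch_commnicator at module load, and the # noqa marker keeps linters from flagging the apparently unused import. A minimal, runnable sketch of the same import-time monkey-patch pattern, using json.dumps purely as a stand-in target (not something this commit touches):

import json

_original_dumps = json.dumps

def logging_dumps(obj, **kwargs):
    # Behavior swapped in by the "patch": log, then defer to the original.
    print("dumps called")
    return _original_dumps(obj, **kwargs)

# Rebinding the module attribute affects every caller that resolves the name
# through the module object (json.dumps(...)); callers that bound the function
# earlier via `from json import dumps` keep the original. That is why such
# patches must be imported as early as possible.
json.dumps = logging_dumps

print(json.dumps({"a": 1}))  # prints "dumps called", then {"a": 1}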
vllm_ascend/patch/patch_commnicator.py

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file is used to monkey patch the communicator in vllm to support
# Ascend. Remove this file once vllm supports this via
# https://github.com/vllm-project/vllm/pull/11324.

from vllm import platforms
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.utils import resolve_obj_by_qualname


class GroupCoordinatorPatch(GroupCoordinator):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Resolve the platform-specific device communicator class and
        # delegate all collective ops to an instance of it.
        device_comm_cls = resolve_obj_by_qualname(
            platforms.current_platform.get_device_communicator_cls())
        self.communicator = device_comm_cls(group=self.device_group,
                                            unique_name=self.unique_name)

    def all_reduce(self, input_):
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_

        return self.communicator.all_reduce(input_)

    def gather(self, input_, dst=0, dim=-1):
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()

        return self.communicator.gather(input_, dst, dim)

    def all_gather(self, input_, dim=-1):
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
        return self.communicator.all_gather(input_, dim)


# Rebind the original name to the patched class so importers of this module
# see the Ascend-aware coordinator.
GroupCoordinator = GroupCoordinatorPatch
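
GroupCoordinatorPatch delegates every collective to whatever class get_device_communicator_cls() names, so the patch implicitly defines a contract: construction from (group, unique_name) plus all_reduce, gather, and all_gather. Below is a hedged sketch of that contract written against plain torch.distributed, assuming the process group spans all ranks so group-local and global ranks coincide; the class name is hypothetical, and the real Ascend implementation would sit on top of HCCL instead:

import torch
import torch.distributed as dist


class DeviceCommunicatorSketch:
    """Hypothetical communicator matching the interface the patch calls."""

    def __init__(self, group, unique_name: str = ""):
        self.group = group
        self.unique_name = unique_name

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        # Sum-reduce across all ranks, in place.
        dist.all_reduce(input_, group=self.group)
        return input_

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        # Collect every rank's shard and concatenate along `dim`.
        world_size = dist.get_world_size(self.group)
        shards = [torch.empty_like(input_) for _ in range(world_size)]
        dist.all_gather(shards, input_, group=self.group)
        return torch.cat(shards, dim=dim)

    def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1):
        # Concatenated result lands only on `dst`; other ranks get None.
        world_size = dist.get_world_size(self.group)
        if dist.get_rank(self.group) == dst:
            shards = [torch.empty_like(input_) for _ in range(world_size)]
        else:
            shards = None
        dist.gather(input_, shards, dst=dst, group=self.group)
        return torch.cat(shards, dim=dim) if shards is not None else None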

vllm_ascend/platform.py

Lines changed: 2 additions & 0 deletions
@@ -89,7 +89,9 @@ def mem_get_info(cls) -> Tuple[int, int]:
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         # Register ops when setup.
+        # Register patch
         from vllm_ascend import ops  # noqa: F401
+        from vllm_ascend import patch  # noqa: F401

         parallel_config = vllm_config.parallel_config
         if parallel_config.worker_cls == "auto":