2 changes: 1 addition & 1 deletion setup.py
@@ -95,7 +95,7 @@ def _read_requirements(filename: str) -> List[str]:
         "Topic :: Scientific/Engineering :: Artificial Intelligence",
         "Topic :: Scientific/Engineering :: Information Analysis",
     ],
-    packages=find_packages(exclude=("docs", "examples", "tests*", "patch")),
+    packages=find_packages(exclude=("docs", "examples", "tests*")),
    python_requires=">=3.9",
    install_requires=get_requirements(),
    extras_require={},
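For context on the packaging change above, here is a minimal sketch of how find_packages() exclusion behaves; the source layout it assumes is hypothetical. Exclude entries are fnmatch patterns applied to the full dotted package name, so a bare "patch" entry only filters a top-level package literally named patch, while "tests*" also covers dotted names such as tests.unit.

from setuptools import find_packages

# Hypothetical checkout containing vllm_ascend/, vllm_ascend/patch/, docs/,
# examples/ and tests/.
packages = find_packages(exclude=("docs", "examples", "tests*"))

# Expected to include "vllm_ascend.patch" (given an __init__.py), so the
# patch modules ship with the built wheel.
print(packages)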
6 changes: 3 additions & 3 deletions vllm_ascend/patch/patch_commnicator.py
@@ -19,11 +19,11 @@
 # https://github.com/vllm-project/vllm/pull/11324.
 
 import torch
-from vllm.distributed.parallel_state import GroupCoordinator
+import vllm
 from vllm.utils import resolve_obj_by_qualname
 
 
-class GroupCoordinatorPatch(GroupCoordinator):
+class GroupCoordinatorPatch(vllm.distributed.parallel_state.GroupCoordinator):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -66,4 +66,4 @@ def all_gather(self, input_, dim=-1):
         return self.communicator.all_gather(input_, dim)
 
 
-GroupCoordinator = GroupCoordinatorPatch
+vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch
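The rewrite above switches from rebinding a locally imported name to assigning the subclass back onto vllm.distributed.parallel_state itself. Below is a minimal, self-contained sketch of that pattern, using a hypothetical module named lib in place of vllm; only an assignment on the defining module is visible to other code that looks the class up there.

import types

lib = types.ModuleType("lib")

class Original:
    def who(self):
        return "original"

lib.Original = Original

class LocalPatch(Original):
    def who(self):
        return "patched"

# Rebinding a local alias (what `GroupCoordinator = GroupCoordinatorPatch`
# used to do) does not change what other importers see on the module:
GroupAlias = LocalPatch
assert lib.Original().who() == "original"

# Assigning onto the defining module (the new approach) takes effect globally:
lib.Original = LocalPatch
assert lib.Original().who() == "patched"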
3 changes: 1 addition & 2 deletions vllm_ascend/platform.py
@@ -96,9 +96,8 @@ def mem_get_info(cls) -> Tuple[int, int]:
 
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
-        # Register ops and patch when setup.
+        # Register ops when setup.
         from vllm_ascend import ops  # noqa: F401
-        from vllm_ascend import patch  # noqa: F401
 
         parallel_config = vllm_config.parallel_config
         if parallel_config.worker_cls == "auto":
2 changes: 2 additions & 0 deletions vllm_ascend/worker.py
@@ -457,6 +457,8 @@ def init_worker_distributed_environment(
         backend: str = "hccl") -> None:
     """Initialize the distributed environment."""
     set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
+    # register communicator patch before init dist env
+    from vllm_ascend import patch  # noqa: F401
 
     init_distributed_environment(parallel_config.world_size, rank,
                                  distributed_init_method, local_rank, backend)
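The new import ensures the module-level patch in vllm_ascend.patch runs before init_distributed_environment() builds its group coordinators. A minimal sketch of the ordering requirement follows, with a hypothetical lib module and init_environment() helper standing in for the real components: instances created before the class is replaced keep the stock behavior, so the patch has to be installed first.

import types

lib = types.ModuleType("lib")

class Coordinator:
    def all_gather(self):
        return "stock"

lib.Coordinator = Coordinator

def init_environment():
    # Stands in for init_distributed_environment(): it instantiates whatever
    # class the module currently exposes.
    return lib.Coordinator()

early = init_environment()           # created before the patch is installed

class CoordinatorPatch(Coordinator):
    def all_gather(self):
        return "patched"

lib.Coordinator = CoordinatorPatch   # what importing vllm_ascend.patch effectively does

late = init_environment()            # created after the patch is installed

assert early.all_gather() == "stock"
assert late.all_gather() == "patched"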