Merged
1 change: 1 addition & 0 deletions vllm_ascend/patch/platform/patch_0_9_0/__init__.py
@@ -14,3 +14,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import vllm_ascend.patch.platform.patch_0_9_0.patch_distributed # noqa
116 changes: 116 additions & 0 deletions vllm_ascend/patch/platform/patch_0_9_0/patch_distributed.py
@@ -0,0 +1,116 @@
import torch
Collaborator: this file should be imported in patch_0_9_0/__init__.py

Collaborator: Fixed.

from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import (Backend, PrefixStore,
                                                _get_default_timeout,
                                                is_nccl_available)
from torch.distributed.rendezvous import rendezvous
from vllm.distributed import utils


def stateless_init_torch_distributed_process_group(
        host: str, port: int, rank: int, world_size: int,
        backend: str) -> ProcessGroup:
    """
    A replacement for `torch.distributed.init_process_group` that does not
    pollute the global state. The created ProcessGroup object can be used for
    some operations such as `allreduce`, because it does not depend on the
    global rank. However, some operations such as `broadcast` cannot be used
    because it depends on the global rank.

    # TODO: ask for help from PyTorch team if we need the `broadcast` operation.

    This function is useful when we are not sure about the total number of
    processes in the process group. For example, we may have process
    1, 2, ..., 8 who want to communicate, and process 9 might be the same
    process as process 1, or it might be a different process; process 10
    might be the same process as process 5, or it might be a different process.
    In this case, how can we reliably form a communication channel within
    process 9 and 10, without affecting the communication channel within
    process 1, 2, ..., 8?

    One possible solution is to figure out if process 9 and 10 are the same
    as process 1 and 5 beforehand, and then form a communication channel
    based on the information, adjusting the ranks and world_size etc. However,
    figuring out the information is not always easy, and it will interfere
    with the main communication channel.

    Our solution is to always form a communication channel with process 1, 2,
    ..., 8, and then use this function to form another communication channel
    with process 9 and 10. This way, regardless of whether process 9 and 10
    are the same as process 1 and 5, the main communication channel is
    always formed with process 1, 2, ..., 8, and the additional communication
    channel is formed with process 9 and 10.
    """
    init_method = f"tcp://{host}:{port}"
    backend = Backend(backend)  # it is basically string
    timeout = _get_default_timeout(backend)

    store, rank, world_size = next(
        rendezvous(init_method, rank, world_size, timeout=timeout))
    store.set_timeout(timeout)

    group_rank = rank
    group_size = world_size

    # Use a PrefixStore to avoid accidental overrides of keys used by
    # different systems (e.g. RPC) in case the store is multi-tenant.
    prefix_store = PrefixStore(init_method, store)

    # TODO(Yizhou): The reason we need to set options while vllm does not
    # seems to be related to the version of PyTorch. In the latest version,
    # there is no need to set options. While in the older version, 2.5.1
    # specifically, we need to set options.
    options = ProcessGroup.Options(backend=backend)
    pg: ProcessGroup = ProcessGroup(
        prefix_store,
        group_rank,
        group_size,
        options,
    )
    if backend == "gloo":
        from torch.distributed.distributed_c10d import ProcessGroupGloo
        backend_class = ProcessGroupGloo(prefix_store,
                                         group_rank,
                                         group_size,
                                         timeout=timeout)
        backend_type = ProcessGroup.BackendType.GLOO
        device = torch.device("cpu")
    elif backend == "nccl":
        assert is_nccl_available()
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
                                         backend_options)
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")
    elif backend == "hccl":
        from torch.distributed import is_hccl_available
        assert is_hccl_available()
        from torch_npu._C._distributed_c10d import ProcessGroupHCCL
        backend_options = ProcessGroupHCCL.Options()
        backend_options._timeout = timeout
        backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
                                         backend_options)
        device = torch.device("npu")
        backend_class._set_sequence_number_for_group()
        backend_type = ProcessGroup.BackendType.CUSTOM
        pg._register_backend(device, backend_type, backend_class)
        return pg
    else:
        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")

    # TODO(Yizhou): Like we mentioned above, _set_default_backend is not
    # implemented in the 2.5.1 version of PyTorch. But we need to set it
    # after the latest version is released.
    # pg._set_default_backend(backend_type)
    backend_class._set_sequence_number_for_group()

    pg._register_backend(device, backend_type, backend_class)

    return pg


utils.stateless_init_torch_distributed_process_group = stateless_init_torch_distributed_process_group
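
For context, a minimal usage sketch of the patched helper (not part of the PR): the host and port below are placeholders, every participating rank must run the same call, and torch_npu is assumed to be installed so the "hccl" branch above can construct a ProcessGroupHCCL backend.

# Hypothetical usage sketch; not part of the patch. Run the same call on every
# rank (world_size=2 here) with a reachable, agreed-upon host/port pair.
import torch
import torch_npu  # noqa: F401  # assumed installed; provides the NPU device and HCCL bindings
from vllm.distributed import utils

import vllm_ascend.patch.platform.patch_0_9_0.patch_distributed  # noqa: F401  # applies the patch above

pg = utils.stateless_init_torch_distributed_process_group(
    host="127.0.0.1",  # placeholder rendezvous host
    port=29500,        # placeholder free port
    rank=0,            # this process's rank within the new group
    world_size=2,
    backend="hccl",
)

# Rank-independent collectives can target the returned group explicitly,
# without initializing torch.distributed's default process group.
t = torch.ones(1, device="npu")
torch.distributed.all_reduce(t, group=pg)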
135 changes: 21 additions & 114 deletions vllm_ascend/patch/platform/patch_common/patch_distributed.py
@@ -17,16 +17,14 @@
# Adapted from vllm/model_executor/models/qwen2_vl.py
# This file is a part of the vllm-ascend project.

import torch
import vllm
import vllm.distributed
import vllm.envs as envs
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import (Backend, PrefixStore,
                                                _get_default_timeout,
                                                is_nccl_available)
from torch.distributed.rendezvous import rendezvous
from vllm.config import ParallelConfig
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed.utils import \
    stateless_init_torch_distributed_process_group
from vllm.v1.engine.core import DPEngineCoreProc


def ascend_destroy_model_parallel():
@@ -48,112 +46,6 @@ def ascend_destroy_model_parallel():
    destory_ascend_model_parallel()


def stateless_init_torch_distributed_process_group(
        host: str, port: int, rank: int, world_size: int,
        backend: str) -> ProcessGroup:
    """
    A replacement for `torch.distributed.init_process_group` that does not
    pollute the global state. The created ProcessGroup object can be used for
    some operations such as `allreduce`, because it does not depend on the
    global rank. However, some operations such as `broadcast` cannot be used
    because it depends on the global rank.

    # TODO: ask for help from PyTorch team if we need the `broadcast` operation.

    This function is useful when we are not sure about the total number of
    processes in the process group. For example, we may have process
    1, 2, ..., 8 who want to communicate, and process 9 might be the same
    process as process 1, or it might be a different process; process 10
    might be the same process as process 5, or it might be a different process.
    In this case, how can we reliably form a communication channel within
    process 9 and 10, without affecting the communication channel within
    process 1, 2, ..., 8?

    One possible solution is to figure out if process 9 and 10 are the same
    as process 1 and 5 beforehand, and then form a communication channel
    based on the information, adjusting the ranks and world_size etc. However,
    figuring out the information is not always easy, and it will interfere
    with the main communication channel.

    Our solution is to always form a communication channel with process 1, 2,
    ..., 8, and then use this function to form another communication channel
    with process 9 and 10. This way, regardless of whether process 9 and 10
    are the same as process 1 and 5, the main communication channel is
    always formed with process 1, 2, ..., 8, and the additional communication
    channel is formed with process 9 and 10.
    """
    init_method = f"tcp://{host}:{port}"
    backend = Backend(backend)  # it is basically string
    timeout = _get_default_timeout(backend)

    store, rank, world_size = next(
        rendezvous(init_method, rank, world_size, timeout=timeout))
    store.set_timeout(timeout)

    group_rank = rank
    group_size = world_size

    # Use a PrefixStore to avoid accidental overrides of keys used by
    # different systems (e.g. RPC) in case the store is multi-tenant.
    prefix_store = PrefixStore(init_method, store)

    # TODO(Yizhou): The reason we need to set options while vllm does not
    # seems to be related to the version of PyTorch. In the latest version,
    # there is no need to set options. While in the older version, 2.5.1
    # specifically, we need to set options.
    options = ProcessGroup.Options(backend=backend)
    pg: ProcessGroup = ProcessGroup(
        prefix_store,
        group_rank,
        group_size,
        options,
    )
    if backend == "gloo":
        from torch.distributed.distributed_c10d import ProcessGroupGloo
        backend_class = ProcessGroupGloo(prefix_store,
                                         group_rank,
                                         group_size,
                                         timeout=timeout)
        backend_type = ProcessGroup.BackendType.GLOO
        device = torch.device("cpu")
    elif backend == "nccl":
        assert is_nccl_available()
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
                                         backend_options)
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")
    elif backend == "hccl":
        from torch.distributed import is_hccl_available
        assert is_hccl_available()
        from torch_npu._C._distributed_c10d import ProcessGroupHCCL
        backend_options = ProcessGroupHCCL.Options()
        backend_options._timeout = timeout
        backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
                                         backend_options)
        device = torch.device("npu")
        backend_class._set_sequence_number_for_group()
        backend_type = ProcessGroup.BackendType.CUSTOM
        pg._register_backend(device, backend_type, backend_class)
        return pg
    else:
        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")

    # TODO(Yizhou): Like we mentioned above, _set_default_backend is not
    # implemented in the 2.5.1 version of PyTorch. But we need to set it
    # after the latest version is released.
    # pg._set_default_backend(backend_type)
    backend_class._set_sequence_number_for_group()

    pg._register_backend(device, backend_type, backend_class)

    return pg


def parallel_config_get_dp_port(self) -> int:
"""
We might need to initialize process groups in multiple
Expand All @@ -171,7 +63,7 @@ def parallel_config_get_dp_port(self) -> int:
return port


def ascend_stateless_init_dp_group(self) -> "ProcessGroup":
def stateless_init_dp_group(self) -> "ProcessGroup":
    # TODO(Yizhou): Currently we have to set the backend to gloo
    # because in vllm.config.ParallelConfig.has_unfinished_dp the
    # device is set to cpu. We need to fix this in the future.
@@ -187,6 +79,21 @@ def ascend_stateless_init_dp_group(self) -> "ProcessGroup":
    return dp_group


def _init_data_parallel(self, vllm_config: VllmConfig):
    # Configure NPUs and stateless process group for data parallel.
    dp_rank = vllm_config.parallel_config.data_parallel_rank
    dp_size = vllm_config.parallel_config.data_parallel_size
    local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

    assert dp_size > 1
    assert 0 <= local_dp_rank <= dp_rank < dp_size

    self.local_dp_rank = local_dp_rank
    self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
    self.current_wave = 0


vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel
DPEngineCoreProc._init_data_parallel = _init_data_parallel
ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port
ParallelConfig.stateless_init_dp_group = ascend_stateless_init_dp_group
ParallelConfig.stateless_init_dp_group = stateless_init_dp_group
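
The TODO in stateless_init_dp_group above explains that the DP group must be gloo-backed because vLLM's unfinished-request check all-reduces a CPU tensor. A hedged sketch of that pattern, modeled on vllm.config.ParallelConfig.has_unfinished_dp (the exact vLLM call site may differ between versions):

# Illustrative sketch, not part of the patch: why the stateless DP group needs
# a CPU-capable (gloo) backend.
import torch
import torch.distributed as dist


def has_unfinished_dp(dp_group, has_unfinished: bool) -> bool:
    # Aggregate each DP rank's "still has unfinished requests" flag on CPU.
    tensor = torch.tensor([has_unfinished], dtype=torch.int32, device="cpu")
    dist.all_reduce(tensor, op=dist.ReduceOp.MAX, group=dp_group)
    return bool(tensor.item())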
45 changes: 45 additions & 0 deletions vllm_ascend/platform.py
@@ -17,10 +17,13 @@

import logging
import os
from datetime import timedelta
from typing import TYPE_CHECKING, Optional, Tuple

import torch
import vllm.envs as envs
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import PrefixStore
from vllm.logger import logger
from vllm.platforms import Platform, PlatformEnum

@@ -249,3 +252,45 @@ def get_piecewise_backend_cls(cls) -> str:
        Get piecewise backend class for piecewise graph.
        """
        return "vllm_ascend.compilation.piecewise_backend.NPUPiecewiseBackend"  # noqa

    @classmethod
    def stateless_init_device_torch_dist_pg(

Collaborator (wangxiyuan, Jun 7, 2025): where is this func called?

Collaborator: this function is called by vllm. Please refer to this PR: vllm-project/vllm#18763

Collaborator: Makes sense, it's a change for the main branch.

        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        from torch.distributed import is_hccl_available
        from torch_npu._C._distributed_c10d import ProcessGroupHCCL

        assert is_hccl_available()

        # TODO(Yizhou): The reason we need to set options while vllm does not
        # seems to be related to the version of PyTorch. In the latest version,
        # there is no need to set options. While in the older version, 2.5.1
        # specifically, we need to set options.
        options = ProcessGroup.Options(backend=backend)
        pg: ProcessGroup = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
            options,
        )

        backend_options = ProcessGroupHCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
                                         backend_options)
        device = torch.device("npu")
        # TODO(Yizhou): Like we mentioned above, _set_default_backend is not
        # implemented in the 2.5.1 version of PyTorch. But we need to set it
        # after the latest version is released.
        # pg._set_default_backend(backend_type)
        backend_class._set_sequence_number_for_group()
        backend_type = ProcessGroup.BackendType.CUSTOM

        pg._register_backend(device, backend_type, backend_class)
        return pg
9 changes: 8 additions & 1 deletion vllm_ascend/worker/worker_v1.py
@@ -76,6 +76,13 @@ def __init__(
                         rank=rank,
                         distributed_init_method=distributed_init_method,
                         is_driver_worker=is_driver_worker)

        # NOTE(Yizhou): Since we do not set ASCEND_RT_VISIBLE_DEVICES in
        # vllm_ascend, we need to set the device id manually.
        local_dp_rank = self.vllm_config.parallel_config.data_parallel_rank_local
        world_size = self.vllm_config.parallel_config.world_size
        self.local_rank_across_dp = local_dp_rank * world_size + self.local_rank

Comment on lines +79 to +85

Collaborator: I don't understand why you're doing this. vLLM sets ASCEND_RT_VISIBLE_DEVICES correctly. @wangxiyuan

Collaborator (author): Correct; however, due to unresolved issues with this environment variable, it cannot be configured at runtime, which leads to multiple processes running on a single device. You can try reverting this PR and running vllm/examples/offline_inference/data_parallel.py; it should reproduce the bug.

Collaborator: It's better to file an issue to record this. If I remember correctly, a newer CANN and PyTorch version is required?
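
To make the index arithmetic in lines +79 to +85 concrete, a small worked example with hypothetical values:

# Hypothetical DP=2 x TP=2 deployment on a single 4-NPU host.
# parallel_config.world_size is TP * PP, i.e. the number of workers per DP
# replica, so the second replica's workers are shifted past the first replica's devices.
local_dp_rank = 1   # this host's second data-parallel replica
world_size = 2      # workers per replica (TP * PP)
local_rank = 1      # second worker within that replica

local_rank_across_dp = local_dp_rank * world_size + local_rank
assert local_rank_across_dp == 3   # init_device() will pick torch.device("npu:3")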

        # Try to import mindie_turbo to accelerate vLLM inference.
        try_register_lib(
            "mindie_turbo",
@@ -102,7 +109,7 @@ def wake_up(self, tags: Optional[list[str]] = None) -> None:

    def init_device(self):
        if self.device_config.device.type == "npu":
            self.device = torch.device(f"npu:{self.local_rank}")
            self.device = torch.device(f"npu:{self.local_rank_across_dp}")
            NPUPlatform.set_device(self.device)
            NPUPlatform.empty_cache()
            self.init_npu_memory = NPUPlatform.mem_get_info()[0]