[Distributed] Add send and recv helpers #5719
Changes from 6 commits

@@ -28,7 +28,11 @@
from unittest.mock import patch

import torch
import torch.distributed
from torch.distributed import Backend, ProcessGroup
from torch.distributed.distributed_c10d import (_get_pg_default_device,
                                                _object_to_tensor,
                                                _tensor_to_object)

import vllm.envs as envs
from vllm.logger import init_logger

@@ -342,6 +346,62 @@ def broadcast_object_list(self,
                                                group=self.device_group)
        return obj_list

    def send_object(self,
                    obj: Any,
                    dst: int,
                    group: Optional[ProcessGroup] = None) -> None:

Review comment: we know we are sending python object here, so it is reasonable to just use …
Reply: Done

        """Send the input object to the destination rank."""
        assert dst < self.world_size, f"Invalid dst rank ({dst})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj

        current_device = _get_pg_default_device(group)
        # Serialize object to tensor.
        object_tensor, size_tensor = _object_to_tensor(obj, current_device,
                                                       group)

Review comment: safe to directly use "cpu" here.
Reply: Done
Review comment: safe to just use …
Reply: Done these simplifications

        # Send object size
        torch.distributed.send(size_tensor, dst=dst, group=group)

        # Send object
        torch.distributed.send(object_tensor, dst=dst, group=group)

        return None

    def recv_object(self,
                    src: int,
                    group: Optional[ProcessGroup] = None) -> Any:
        """Receive an object from the source rank."""
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return None

        current_device = _get_pg_default_device(group)

        size_tensor = torch.empty(1, dtype=torch.long, device=current_device)

        # Receive object size
        rank_size = torch.distributed.recv(size_tensor, src=src, group=group)

        # Tensor to receive serialized objects into.
        object_tensor = torch.empty(  # type: ignore[call-overload]
            size_tensor.item(),  # type: ignore[arg-type]
            dtype=torch.uint8,
            device=current_device)

        rank_object = torch.distributed.recv(object_tensor,
                                             src=src,
                                             group=group)
        assert (rank_size == rank_object
                ), "Mismatch in return ranks for object sizes and objects."
        # Deserialize the object using its stored size.
        return _tensor_to_object(object_tensor, size_tensor.item(), group)
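
To make the size-then-payload protocol concrete, here is a rough usage sketch (not part of the diff). `pp_group` stands in for an instance of the group class these methods are defined on, and a two-rank group is assumed so that local and global ranks coincide; the CPU group is used, as `send_tensor_dict` below does for its metadata.

    # Hypothetical two-rank sketch: rank 0 sends a Python object, rank 1
    # receives it over the CPU (gloo) group.
    metadata = {"seq_len": 128, "is_prompt": True}

    if pp_group.rank_in_group == 0:
        # Serializes `metadata`, then sends the size tensor followed by the
        # uint8 payload tensor.
        pp_group.send_object(metadata, dst=1, group=pp_group.cpu_group)
    elif pp_group.rank_in_group == 1:
        # Receives the size first, allocates a buffer of that size, then
        # receives and deserializes the payload.
        received = pp_group.recv_object(src=0, group=pp_group.cpu_group)
        assert received == metadata
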
    def broadcast_tensor_dict(
        self,
        tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,

@@ -433,6 +493,90 @@ def broadcast_tensor_dict(
            async_handle.wait()
        return tensor_dict

    def send_tensor_dict(
        self,
        tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
        dst: Optional[int] = None
    ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
        """Send the input tensor dictionary.
        NOTE: `dst` is the local rank of the destination rank.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return tensor_dict

        group = self.device_group
        metadata_group = self.cpu_group

        if dst is None:
            dst = self.next_rank
        assert dst < self.world_size, f"Invalid dst rank ({dst})"
        dst = self.ranks[dst]

        metadata_list: List[Tuple[Any, Any]] = []
        assert isinstance(
            tensor_dict,
            dict), f"Expecting a dictionary, got {type(tensor_dict)}"
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `send_object` has serialization & deserialization,
        # all happening on CPU. Therefore, we can use the CPU group.
        self.send_object(metadata_list, dst=dst, group=metadata_group)
        for tensor in tensor_list:
            if tensor.numel() == 0:
                # Skip sending empty tensors.
                continue
            if tensor.is_cpu:
                # use metadata_group for CPU tensors
                torch.distributed.send(tensor, dst=dst, group=metadata_group)
            else:
                # use group for GPU tensors
                torch.distributed.send(tensor, dst=dst, group=group)
        return None

    def recv_tensor_dict(
        self,
        src: Optional[int] = None
    ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
        """Recv the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return None

        group = self.device_group
        metadata_group = self.cpu_group

        if src is None:
            src = self.prev_rank
        assert src < self.world_size, f"Invalid src rank ({src})"
        src = self.ranks[src]

        recv_metadata_list = self.recv_object(src=src, group=metadata_group)
        tensor_dict = {}
        for key, value in recv_metadata_list:
            if isinstance(value, TensorMetadata):
                tensor = torch.empty(value.size,
                                     dtype=value.dtype,
                                     device=value.device)
                if tensor.numel() == 0:
                    # Skip receiving empty tensors.
                    tensor_dict[key] = tensor
                    continue
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
                    torch.distributed.recv(tensor,
                                           src=src,
                                           group=metadata_group)
                else:
                    # use group for GPU tensors
                    torch.distributed.recv(tensor, src=src, group=group)
                tensor_dict[key] = tensor
            else:
                tensor_dict[key] = value
        return tensor_dict
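
As a hedged illustration (again not from the PR), adjacent pipeline stages could hand off intermediate tensors like this, relying on the `next_rank`/`prev_rank` defaults. `pp_group` and the tensor names are assumptions:

    # Hypothetical hand-off between adjacent pipeline stages. Metadata
    # (shapes, dtypes, non-tensor values) goes over the CPU group via
    # send_object; tensor payloads go over the device group.
    import torch

    if pp_group.rank_in_group < pp_group.world_size - 1:
        pp_group.send_tensor_dict({
            "hidden_states": torch.randn(8, 4096, device="cuda"),
            "residual": torch.randn(8, 4096, device="cuda"),
        })  # dst defaults to self.next_rank

    if pp_group.rank_in_group > 0:
        tensors = pp_group.recv_tensor_dict()  # src defaults to self.prev_rank
        hidden_states = tensors["hidden_states"]
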
    def barrier(self):
        """Barrier synchronization among the group.
        NOTE: don't use `device_group` here! `barrier` in NCCL is

@@ -442,6 +586,26 @@ def barrier(self):
        """
        torch.distributed.barrier(group=self.cpu_group)

    def send(self, tensor: torch.Tensor, dst: int) -> None:
        """Sends a tensor to the destination rank in a non-blocking way"""
        pynccl_comm = self.pynccl_comm
        if pynccl_comm is not None and not pynccl_comm.disabled:
            pynccl_comm.send(tensor)
        else:
            torch.distributed.isend(tensor, dst, self.device_group)

    def recv(self, size: torch.Size, dtype: torch.dtype,
             src: int) -> torch.Tensor:
        """Receives a tensor from the src rank."""
        tensor = torch.empty(size, dtype=dtype, device=self.device)
        pynccl_comm = self.pynccl_comm
        if pynccl_comm is not None and not pynccl_comm.disabled:
            pynccl_comm.recv(tensor)
        else:
            req = torch.distributed.irecv(tensor, src, self.device_group)
            req.wait()
        return tensor
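
A minimal sketch of the raw tensor path, under the same two-rank assumptions as above. Unlike `send_tensor_dict`, the receiver must already know the shape and dtype, because `recv` allocates the buffer itself:

    # Hypothetical point-to-point tensor exchange. send() goes through the
    # pynccl communicator when it is enabled, otherwise it falls back to
    # torch.distributed.isend on the device group.
    import torch

    if pp_group.rank_in_group == 0:
        activations = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
        pp_group.send(activations, dst=1)
    elif pp_group.rank_in_group == 1:
        activations = pp_group.recv(size=torch.Size([8, 4096]),
                                    dtype=torch.float16,
                                    src=0)
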
    def destroy(self):
        if self.device_group is not None:
            torch.distributed.destroy_process_group(self.device_group)

@@ -684,6 +848,28 @@ def get_tensor_model_parallel_rank():
    return get_tp_group().rank_in_group


def get_pipeline_model_parallel_world_size():
    """Return world size for the pipeline model parallel group."""
    return get_pp_group().world_size


def get_pipeline_model_parallel_rank():
    """Return my rank for the pipeline model parallel group."""
    return get_pp_group().rank_in_group

Review comment: these two are not used
Reply: Do you mean they are not used currently? Planning to use them in next PRs
Reply: yes, they are not used. in the future i will remove legacy usage like …
Reply: Understood, basically everything is done through …

def is_pipeline_model_parallel_first_rank():
    """Return True if the rank is the first rank in the
    pipeline model parallel group."""
    return get_pp_group().rank_in_group == 0


def is_pipeline_model_parallel_last_rank():
    """Return True if the rank is the last rank in the
    pipeline model parallel group."""
    return get_pp_group().rank_in_group == get_pp_group().world_size - 1

Review comment: add them in the …
Reply: Done
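
For orientation only, a hedged sketch of how these predicates and the tensor-dict helpers might combine in a pipeline step, assuming `get_pp_group()` returns the group object that carries the new send/recv helpers; `run_stage` and `sampler` are made-up stand-ins, not vLLM APIs:

    # Hypothetical pipeline step (not part of this diff).
    def run_stage(x):
        return x  # placeholder for a stage's forward pass

    def sampler(hidden):
        return hidden  # placeholder for final sampling / output

    def pipeline_step(inputs):
        pp_group = get_pp_group()

        if is_pipeline_model_parallel_first_rank():
            hidden = run_stage(inputs)  # first stage starts from the inputs
        else:
            recvd = pp_group.recv_tensor_dict()  # later stages receive from prev
            hidden = run_stage(recvd["hidden_states"])

        if is_pipeline_model_parallel_last_rank():
            return sampler(hidden)  # only the last stage produces output
        pp_group.send_tensor_dict({"hidden_states": hidden})  # hand off
        return None
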
def destroy_model_parallel():
    """Set the groups to none and destroy them."""
    global _TP

Review comment: nit: update the test
Reply: Thanks, forgot to rerun tests