[Distributed] Add send and recv helpers #5719

Merged: 14 commits, Jun 23, 2024
57 changes: 53 additions & 4 deletions tests/distributed/test_comm_ops.py
@@ -8,12 +8,13 @@
import ray
import torch

from vllm.distributed import (broadcast_tensor_dict,
from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
is_pipeline_model_parallel_first_rank,
Member: nit: update the test

Collaborator (author): Thanks, forgot to rerun tests

is_pipeline_model_parallel_last_rank,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)

from ..utils import (init_test_distributed_environment,
multi_process_tensor_parallel)
from ..utils import init_test_distributed_environment, multi_process_parallel


@ray.remote(num_gpus=1, max_calls=1)
@@ -105,6 +106,46 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
assert torch.allclose(recv_dict["f"], test_dict["f"])


@ray.remote(num_gpus=1, max_calls=1)
def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
distributed_init_port: str):
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)

test_dict = {
# device tensor
"a": torch.arange(8, dtype=torch.float32, device="cuda"),
# CPU tensor
"b": torch.arange(16, dtype=torch.int8, device="cpu"),
"c": "test",
"d": [1, 2, 3],
"e": {
"a": 1,
"b": 2
},
# empty tensor
"f": torch.tensor([], dtype=torch.float32, device="cuda"),
}

if not is_pipeline_model_parallel_first_rank():
recv_dict = get_pp_group().recv_tensor_dict()

if not is_pipeline_model_parallel_last_rank():
get_pp_group().send_tensor_dict(test_dict)

if not is_pipeline_model_parallel_first_rank():
assert len(recv_dict) == len(test_dict)
assert torch.allclose(recv_dict["a"], test_dict["a"])
assert torch.allclose(recv_dict["b"], test_dict["b"])
assert recv_dict["c"] == test_dict["c"]
assert recv_dict["d"] == test_dict["d"]
assert recv_dict["e"] == test_dict["e"]
assert torch.allclose(recv_dict["f"], test_dict["f"])


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("tp_size", [2])
@@ -113,4 +154,12 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
broadcast_tensor_dict_test_worker
])
def test_multi_process_tensor_parallel(tp_size, test_target):
multi_process_tensor_parallel(tp_size, 1, test_target)
multi_process_parallel(tp_size, 1, test_target)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("pp_size", [2])
@pytest.mark.parametrize("test_target", [send_recv_tensor_dict_test_worker])
def test_multi_process_pipeline_parallel(pp_size, test_target):
multi_process_parallel(1, pp_size, test_target)
5 changes: 2 additions & 3 deletions tests/distributed/test_custom_all_reduce.py
@@ -12,8 +12,7 @@
get_tp_group, graph_capture)

from ..utils import (ensure_model_parallel_initialized,
init_test_distributed_environment,
multi_process_tensor_parallel)
init_test_distributed_environment, multi_process_parallel)

random.seed(42)
test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)]
@@ -113,4 +112,4 @@ def test_custom_allreduce(tp_size, pipeline_parallel_size, test_target):
world_size = tp_size * pipeline_parallel_size
if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs to run the test.")
multi_process_tensor_parallel(tp_size, pipeline_parallel_size, test_target)
multi_process_parallel(tp_size, pipeline_parallel_size, test_target)
2 changes: 1 addition & 1 deletion tests/utils.py
@@ -129,7 +129,7 @@ def init_test_distributed_environment(
ensure_model_parallel_initialized(tp_size, pp_size)


def multi_process_tensor_parallel(
def multi_process_parallel(
tp_size: int,
pp_size: int,
test_target,
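The body of `multi_process_parallel` is collapsed in this diff. For orientation, here is a hedged sketch of what such a helper is assumed to do: launch one Ray worker per rank (tp_size * pp_size in total) and wait for all of them. The port helper is a stand-in written for this sketch, not code from the PR.

```python
import socket

import ray


def get_open_port() -> int:
    # Stand-in helper for this sketch: ask the OS for a free TCP port.
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def multi_process_parallel(tp_size: int, pp_size: int, test_target) -> None:
    # Sketch only: spawn one Ray worker per rank and join them.
    ray.init()
    distributed_init_port = str(get_open_port())
    refs = [
        test_target.remote(tp_size, pp_size, rank, distributed_init_port)
        for rank in range(tp_size * pp_size)
    ]
    ray.get(refs)  # re-raises any assertion error from the workers
    ray.shutdown()
```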
186 changes: 186 additions & 0 deletions vllm/distributed/parallel_state.py
@@ -28,7 +28,11 @@
from unittest.mock import patch

import torch
import torch.distributed
from torch.distributed import Backend, ProcessGroup
from torch.distributed.distributed_c10d import (_get_pg_default_device,
_object_to_tensor,
_tensor_to_object)

import vllm.envs as envs
from vllm.logger import init_logger
@@ -342,6 +346,62 @@ def broadcast_object_list(self,
group=self.device_group)
return obj_list

def send_object(self,
obj: Any,
dst: int,
group: Optional[ProcessGroup] = None) -> None:
Member: We know we are sending a Python object here, so it is reasonable to just use cpu_group.

Collaborator (author): Done.

"""Send the input object list to the destination rank."""
assert dst < self.world_size, f"Invalid dst rank ({dst})"

assert dst != self.rank, (
"Invalid destination rank. Destination rank is the same "
"as the current rank.")

current_device = _get_pg_default_device(group)
# Serialize object to tensor.
object_tensor, size_tensor = _object_to_tensor(obj, current_device,
Member: Safe to directly use "cpu" here.

Collaborator (author): Done.

Member: Safe to just use pickle.dumps and torch.frombuffer here.

Collaborator (author): Done, applied these simplifications.

group)

# Send object size
torch.distributed.send(size_tensor, dst=dst, group=group)
Member: Our API uses the relative rank inside the group, but PyTorch uses global ranks:

    src (int) – Source rank from which to broadcast object_list. Source rank is based on global process group (regardless of group argument)

Collaborator (author): Changed this for {send,recv}, {send,recv}_object, and {send,recv}_tensor_dict. Also added a new test for {send,recv}.


# Send object
torch.distributed.send(object_tensor, dst=dst, group=group)

return None
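To make the rank-convention point above concrete, here is a tiny illustration of the group-local to global mapping that the coordinator's `ranks` list provides (the values are made up for the example):

```python
# Group-local vs. global ranks: vLLM's API addresses peers by their index
# inside the group, while torch.distributed.send/recv expect global ranks.
ranks = [4, 5, 6, 7]           # example: the global ranks forming this PP group
local_dst = 1                  # destination addressed relative to the group
global_dst = ranks[local_dst]  # what torch.distributed actually needs
assert global_dst == 5
```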

def recv_object(self,
src: int,
group: Optional[ProcessGroup] = None) -> Any:
"""Receive the input object list from the source rank."""
assert src < self.world_size, f"Invalid src rank ({src})"

assert src != self.rank, (
"Invalid source rank. Source rank is the same as the current rank."
)

current_device = _get_pg_default_device(group)

size_tensor = torch.empty(1, dtype=torch.long, device=current_device)

# Receive object size
rank_size = torch.distributed.recv(size_tensor, src=src, group=group)

# Tensor to receive serialized objects into.
object_tensor = torch.empty( # type: ignore[call-overload]
size_tensor.item(), # type: ignore[arg-type]
dtype=torch.uint8,
device=current_device)

rank_object = torch.distributed.recv(object_tensor,
src=src,
group=group)
assert (rank_size == rank_object
), "Mismatch in return ranks for object sizes and objects."
# Deserialize objects using their stored sizes.

return _tensor_to_object(object_tensor, size_tensor.item(), group)
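The review thread above suggests replacing the private `_object_to_tensor`/`_tensor_to_object` helpers with plain `pickle` plus `torch.frombuffer`. A minimal sketch of that approach, assuming a CPU (gloo) group and group-global ranks; the function names are illustrative, not the final vLLM API:

```python
import pickle
from typing import Any

import torch
import torch.distributed
from torch.distributed import ProcessGroup


def send_object_via_pickle(obj: Any, dst: int, group: ProcessGroup) -> None:
    # Serialize on CPU; bytearray makes the buffer writable so
    # torch.frombuffer does not warn about a read-only buffer.
    buf = bytearray(pickle.dumps(obj))
    object_tensor = torch.frombuffer(buf, dtype=torch.uint8)
    size_tensor = torch.tensor([object_tensor.numel()], dtype=torch.long)
    torch.distributed.send(size_tensor, dst=dst, group=group)
    torch.distributed.send(object_tensor, dst=dst, group=group)


def recv_object_via_pickle(src: int, group: ProcessGroup) -> Any:
    size_tensor = torch.empty(1, dtype=torch.long)
    torch.distributed.recv(size_tensor, src=src, group=group)
    object_tensor = torch.empty(size_tensor.item(), dtype=torch.uint8)
    torch.distributed.recv(object_tensor, src=src, group=group)
    return pickle.loads(object_tensor.numpy().tobytes())
```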

def broadcast_tensor_dict(
self,
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
@@ -433,6 +493,90 @@ def broadcast_tensor_dict(
async_handle.wait()
return tensor_dict

def send_tensor_dict(
self,
tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
dst: Optional[int] = None
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
"""Send the input tensor dictionary.
NOTE: `dst` is the local rank of the destination rank.
"""
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return tensor_dict

group = self.device_group
metadata_group = self.cpu_group

if dst is None:
dst = self.next_rank
assert dst < self.world_size, f"Invalid dst rank ({dst})"
dst = self.ranks[dst]

metadata_list: List[Tuple[Any, Any]] = []
assert isinstance(
tensor_dict,
dict), f"Expecting a dictionary, got {type(tensor_dict)}"
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
# `metadata_list` lives in CPU memory.
# `send_object` has serialization & deserialization,
# all happening on CPU. Therefore, we can use the CPU group.
self.send_object(metadata_list, dst=dst, group=metadata_group)
for tensor in tensor_list:
if tensor.numel() == 0:
# Skip sending empty tensors.
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.send(tensor, dst=dst, group=metadata_group)
else:
# use group for GPU tensors
torch.distributed.send(tensor, dst=dst, group=group)
return None
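`send_tensor_dict` and `recv_tensor_dict` rely on `_split_tensor_dict` and `TensorMetadata`, which are not shown in this diff. A hedged sketch of what they are assumed to look like, based on how the fields are consumed above (the real field layout may differ):

```python
from collections import namedtuple
from typing import Any, Dict, List, Tuple

import torch

# Assumed shape: recv_tensor_dict reads .size, .dtype and .device from it.
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])


def _split_tensor_dict(
    tensor_dict: Dict[Any, Any]
) -> Tuple[List[Tuple[Any, Any]], List[torch.Tensor]]:
    """Replace tensors with metadata so only small Python objects go through
    the CPU object channel; the tensor payloads are sent separately."""
    metadata_list: List[Tuple[Any, Any]] = []
    tensor_list: List[torch.Tensor] = []
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            metadata_list.append(
                (key, TensorMetadata(value.device, value.dtype, value.size())))
            tensor_list.append(value)
        else:
            metadata_list.append((key, value))
    return metadata_list, tensor_list
```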

def recv_tensor_dict(
self,
src: Optional[int] = None
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
"""Recv the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
"""
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return None

group = self.device_group
metadata_group = self.cpu_group

if src is None:
src = self.prev_rank
assert src < self.world_size, f"Invalid src rank ({src})"
src = self.ranks[src]

recv_metadata_list = self.recv_object(src=src, group=metadata_group)
tensor_dict = {}
for key, value in recv_metadata_list:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size,
dtype=value.dtype,
device=value.device)
if tensor.numel() == 0:
# Skip receiving empty tensors.
tensor_dict[key] = tensor
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.recv(tensor,
src=src,
group=metadata_group)
else:
# use group for GPU tensors
torch.distributed.recv(tensor, src=src, group=group)
tensor_dict[key] = tensor
else:
tensor_dict[key] = value
return tensor_dict

def barrier(self):
"""Barrier synchronization among the group.
NOTE: don't use `device_group` here! `barrier` in NCCL is
@@ -442,6 +586,26 @@ def barrier(self):
"""
torch.distributed.barrier(group=self.cpu_group)

def send(self, tensor: torch.Tensor, dst: int) -> None:
"""Sends a tensor to the destination rank in a non-blocking way"""
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.send(tensor)
else:
torch.distributed.isend(tensor, dst, self.device_group)

def recv(self, size: torch.Size, dtype: torch.dtype,
src: int) -> torch.Tensor:
"""Receives a tensor from the src rank."""
tensor = torch.empty(size, dtype=dtype, device=self.device)
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.recv(tensor)
else:
req = torch.distributed.irecv(tensor, src, self.device_group)
req.wait()
return tensor
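A hedged usage sketch of the new point-to-point helpers for relaying an activation between adjacent pipeline stages. It assumes an initialized pipeline-parallel group; whether `dst`/`src` are group-local or global ranks is exactly the convention question raised in the review thread above, so the `prev_rank`/`next_rank` arguments here are illustrative only.

```python
import torch

from vllm.distributed import (get_pp_group,
                              is_pipeline_model_parallel_first_rank,
                              is_pipeline_model_parallel_last_rank)


def relay_activation(shape: torch.Size, dtype: torch.dtype) -> torch.Tensor:
    """Receive from the previous stage (if any), then forward to the next."""
    pp_group = get_pp_group()
    if is_pipeline_model_parallel_first_rank():
        # First stage produces the activation itself (dummy data here).
        hidden = torch.randn(shape, dtype=dtype, device="cuda")
    else:
        hidden = pp_group.recv(shape, dtype, src=pp_group.prev_rank)
    # ... this stage's layers would transform `hidden` here ...
    if not is_pipeline_model_parallel_last_rank():
        pp_group.send(hidden, dst=pp_group.next_rank)
    return hidden
```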

def destroy(self):
if self.device_group is not None:
torch.distributed.destroy_process_group(self.device_group)
@@ -684,6 +848,28 @@ def get_tensor_model_parallel_rank():
return get_tp_group().rank_in_group


def get_pipeline_model_parallel_world_size():
"""Return world size for the pipeline model parallel group."""
return get_pp_group().world_size


def get_pipeline_model_parallel_rank():
"""Return my rank for the pipeline model parallel group."""
return get_pp_group().rank_in_group

Member: These two are not used.

Collaborator (author): Do you mean they are not used currently? I'm planning to use them in the next PRs.

Member: Yes, they are not used. In the future I will remove legacy usage like get_tensor_model_parallel_rank, as pointed out in #5293 (comment). The basic idea is that "users" of parallel_state can assemble the functionality they want, rather than keep adding new helper functions in parallel_state.

Collaborator (author): Understood, basically everything is done through GroupCoordinator. Removed all the extra helper functions.
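A small illustration of the reviewer's point: callers can assemble the same checks from the group object itself instead of relying on new module-level helpers (a sketch, not the code this PR finally ships):

```python
from vllm.distributed import get_pp_group

pp_group = get_pp_group()
is_first_pp_rank = pp_group.rank_in_group == 0
is_last_pp_rank = pp_group.rank_in_group == pp_group.world_size - 1
```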


def is_pipeline_model_parallel_first_rank():
"""Return True if the rank is the first rank in the
pipeline model parallel group."""
return get_pp_group().rank_in_group == 0


def is_pipeline_model_parallel_last_rank():
"""Return True if the rank is the last rank in the
pipeline model parallel group."""
return get_pp_group().rank_in_group == get_pp_group().world_size - 1


Member: Add them as GroupCoordinator.is_first_rank, so that the TP group can also use them in the future.

Collaborator (author): Done.
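A hedged sketch of the properties the reviewer asks for here; the exact names and the surrounding class in the merged code may differ.

```python
class GroupCoordinator:  # simplified stand-in, not the full vLLM class
    def __init__(self, rank_in_group: int, world_size: int):
        self.rank_in_group = rank_in_group
        self.world_size = world_size

    @property
    def is_first_rank(self) -> bool:
        """True if this process holds the first rank of the group."""
        return self.rank_in_group == 0

    @property
    def is_last_rank(self) -> bool:
        """True if this process holds the last rank of the group."""
        return self.rank_in_group == self.world_size - 1


# Usage sketch: the PP group and the TP group can share the same property.
coord = GroupCoordinator(rank_in_group=0, world_size=2)
assert coord.is_first_rank and not coord.is_last_rank
```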

def destroy_model_parallel():
"""Set the groups to none and destroy them."""
global _TP