From e08e665e96dceb9da47e12defd6ffe57693c908d Mon Sep 17 00:00:00 2001 From: Muralidhar Andoorveedu Date: Thu, 20 Jun 2024 17:43:23 +0000 Subject: [PATCH 01/14] Add send and recv helpers Signed-off-by: Muralidhar Andoorveedu --- tests/distributed/test_comm_ops.py | 56 +++++++- tests/distributed/test_custom_all_reduce.py | 5 +- tests/utils.py | 2 +- vllm/distributed/communication_op.py | 29 +++- vllm/distributed/object_list_ops.py | 123 ++++++++++++++++ vllm/distributed/parallel_state.py | 152 ++++++++++++++++++++ 6 files changed, 359 insertions(+), 8 deletions(-) create mode 100644 vllm/distributed/object_list_ops.py diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index 53654dc40d10d..86e1ea8161e91 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -9,11 +9,13 @@ import torch from vllm.distributed import (broadcast_tensor_dict, + is_pipeline_model_parallel_first_rank, + is_pipeline_model_parallel_last_rank, + recv_tensor_dict, send_tensor_dict, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) -from ..utils import (init_test_distributed_environment, - multi_process_tensor_parallel) +from ..utils import init_test_distributed_environment, multi_process_parallel @ray.remote(num_gpus=1, max_calls=1) @@ -105,6 +107,46 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int, assert torch.allclose(recv_dict["f"], test_dict["f"]) +@ray.remote(num_gpus=1, max_calls=1) +def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int, + distributed_init_port: str): + del os.environ["CUDA_VISIBLE_DEVICES"] + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + init_test_distributed_environment(tp_size, pp_size, rank, + distributed_init_port) + + test_dict = { + # device tensor + "a": torch.arange(8, dtype=torch.float32, device="cuda"), + # CPU tensor + "b": torch.arange(16, dtype=torch.int8, device="cpu"), + "c": "test", + "d": [1, 2, 3], + "e": { + "a": 1, + "b": 2 + }, + # empty tensor + "f": torch.tensor([], dtype=torch.float32, device="cuda"), + } + + if not is_pipeline_model_parallel_first_rank(): + recv_dict = recv_tensor_dict() + + if not is_pipeline_model_parallel_last_rank(): + send_tensor_dict(test_dict) + + if not is_pipeline_model_parallel_first_rank(): + assert len(recv_dict) == len(test_dict) + assert torch.allclose(recv_dict["a"], test_dict["a"]) + assert torch.allclose(recv_dict["b"], test_dict["b"]) + assert recv_dict["c"] == test_dict["c"] + assert recv_dict["d"] == test_dict["d"] + assert recv_dict["e"] == test_dict["e"] + assert torch.allclose(recv_dict["f"], test_dict["f"]) + + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test.") @pytest.mark.parametrize("tp_size", [2]) @@ -113,4 +155,12 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int, broadcast_tensor_dict_test_worker ]) def test_multi_process_tensor_parallel(tp_size, test_target): - multi_process_tensor_parallel(tp_size, 1, test_target) + multi_process_parallel(tp_size, 1, test_target) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +@pytest.mark.parametrize("pp_size", [2]) +@pytest.mark.parametrize("test_target", [send_recv_tensor_dict_test_worker]) +def test_multi_process_pipeline_parallel(pp_size, test_target): + multi_process_parallel(1, pp_size, test_target) diff --git a/tests/distributed/test_custom_all_reduce.py 
b/tests/distributed/test_custom_all_reduce.py index 9a39160b8a462..3c281a45fcaf1 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -12,8 +12,7 @@ get_tp_group, graph_capture) from ..utils import (ensure_model_parallel_initialized, - init_test_distributed_environment, - multi_process_tensor_parallel) + init_test_distributed_environment, multi_process_parallel) random.seed(42) test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)] @@ -113,4 +112,4 @@ def test_custom_allreduce(tp_size, pipeline_parallel_size, test_target): world_size = tp_size * pipeline_parallel_size if world_size > torch.cuda.device_count(): pytest.skip("Not enough GPUs to run the test.") - multi_process_tensor_parallel(tp_size, pipeline_parallel_size, test_target) + multi_process_parallel(tp_size, pipeline_parallel_size, test_target) diff --git a/tests/utils.py b/tests/utils.py index bc30515c83100..174efca4af532 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -129,7 +129,7 @@ def init_test_distributed_environment( ensure_model_parallel_initialized(tp_size, pp_size) -def multi_process_tensor_parallel( +def multi_process_parallel( tp_size: int, pp_size: int, test_target, diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 32394a07b00b9..aa3ad951d3a1e 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -3,7 +3,7 @@ import torch import torch.distributed -from .parallel_state import get_tp_group +from .parallel_state import get_pp_group, get_tp_group def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: @@ -30,3 +30,30 @@ def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor, if not torch.distributed.is_initialized(): return tensor_dict return get_tp_group().broadcast_tensor_dict(tensor_dict, src) + + +def send_tensor_dict(tensors: Dict[str, torch.Tensor], + dst: Optional[int] = None) -> None: + """ + Send the tensors to the next pipeline model parallel rank. + Args: + tensors (Dict[torch.Tensor]): Dict of tensors to send. + """ + if dst is None: + dst = get_pp_group().next_rank + get_pp_group().send_tensor_dict(tensors, dst) + + +def recv_tensor_dict( + src: Optional[int] = None +) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]: + """ + Receive tensors from the previous pipeline model parallel rank assuming all + tensors are the same size. + Returns: + Dict[torch.Tensor]: Dict of received tensors. + """ + if src is None: + src = get_pp_group().prev_rank + tensors = get_pp_group().recv_tensor_dict(src) + return tensors diff --git a/vllm/distributed/object_list_ops.py b/vllm/distributed/object_list_ops.py new file mode 100644 index 0000000000000..309ff3196eaf5 --- /dev/null +++ b/vllm/distributed/object_list_ops.py @@ -0,0 +1,123 @@ +""" +This file is necessary until new version of torch.distributed is released with +https://github.com/pytorch/pytorch/commit/b96b1e8cff029bb0a73283e6e7f6cc240313f1dc +""" +import torch +import torch.distributed as dist +from torch.distributed.distributed_c10d import (_get_pg_default_device, + _object_to_tensor, + _tensor_to_object) + + +def send_object_list(object_list, dst, group=None, device=None): + """ + Sends picklable objects in ``object_list`` synchronously. + + Similar to :func:`send`, but Python objects can be passed in. + Note that all objects in ``object_list`` must be picklable in order to be + sent. + + Args: + object_list (List[Any]): List of input objects to sent. 
+ Each object must be picklable. Receiver must provide lists of + equal sizes. + dst (int): Destination rank to send ``object_list`` to. + Destination rank is based on global process group + (regardless of ``group`` argument) + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Default is ``None``. + device (``torch.device``, optional): If not None, the objects are + serialized and converted to tensors which are moved to the + ``device`` before sending. Default is ``None``. + + Returns: + ``None``. + """ + if dist.get_rank() == dst: + raise ValueError( + "Invalid destination rank: destination rank should not be the " + "same as the rank of the current process.") + + # Current device selection. + # To preserve backwards compatibility, ``device`` is default to ``None`` + # in which case we run current logic of device selection, i.e. + # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the + # case it is not ``None`` we move the size and object tensors to be + # sent to this device. + current_device = device or _get_pg_default_device(group) + # Serialize object_list elements to tensors on src rank. + tensor_list, size_list = zip( + * + [_object_to_tensor(obj, current_device, group) for obj in object_list]) + object_sizes_tensor = torch.cat(size_list) + + # Send object sizes + dist.send(object_sizes_tensor, dst=dst, group=group) + + # Concatenate and send serialized object tensors + # Note: torch.cat will do an extra memory copy to the current device, + # if the tensor_list has only one element, we can skip the copy. + if len(tensor_list) == 1: # type: ignore[possibly-undefined] + object_tensor = tensor_list[0] + else: + object_tensor = torch.cat(tensor_list) + + dist.send(object_tensor, dst=dst, group=group) + + +def recv_object_list(object_list, src=None, group=None, device=None): + """ + Receives picklable objects in ``object_list`` synchronously. + + Similar to :func:`recv`, but can receive Python objects. + + Args: + object_list (List[Any]): List of objects to receive into. + Must provide a list of sizes equal to the size of the list + being sent. + src (int, optional): Source rank from which to recv ``object_list``. + Source rank is based on global process group + (regardless of ``group`` argument) + Will receive from any rank if set to None. Default is ``None``. + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Default is ``None``. + device (``torch.device``, optional): If not None, receives on + this device. Default is ``None``. + + Returns: + Sender rank. -1 if rank is not part of the group. If rank is part + of the group, ``object_list`` will contain the sent objects from + ``src`` rank. + """ + + # Current device selection. + # To preserve backwards compatibility, ``device`` is default to ``None`` + # in which case we run current logic of device selection, i.e. + # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the + # case it is not ``None`` we move the size and object tensors to be + # received to this device. + current_device = device or _get_pg_default_device(group) + object_sizes_tensor = torch.empty(len(object_list), + dtype=torch.long, + device=current_device) + + # Receive object sizes + rank_sizes = dist.recv(object_sizes_tensor, src=src, group=group) + + # Tensor to receive serialized objects into. 
+ object_tensor = torch.empty( # type: ignore[call-overload] + torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] + dtype=torch.uint8, + device=current_device) + + rank_objects = dist.recv(object_tensor, src=src, group=group) + assert (rank_sizes == rank_objects + ), "Mismatch in return ranks for object sizes and objects." + # Deserialize objects using their stored sizes. + offset = 0 + for i, obj_size in enumerate(object_sizes_tensor): + obj_view = object_tensor[offset:offset + obj_size] + obj_view = obj_view.type(torch.uint8) + offset += obj_size + object_list[i] = _tensor_to_object(obj_view, obj_size, group) + return rank_objects diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 02b0dcbcb6b24..33dfa8717429b 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -33,6 +33,8 @@ import vllm.envs as envs from vllm.logger import init_logger +from .object_list_ops import recv_object_list, send_object_list + @dataclass class GraphCaptureContext: @@ -342,6 +344,38 @@ def broadcast_object_list(self, group=self.device_group) return obj_list + def send_object_list(self, + obj_list: List[Any], + dst: int, + group: Optional[ProcessGroup] = None): + """Send the input object list to the destination rank.""" + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj_list + # Send. + send_object_list(obj_list, + dst=self.ranks[dst], + group=self.device_group) + return obj_list + + def recv_object_list(self, + obj_list: List[Any], + src: int, + group: Optional[ProcessGroup] = None): + """Receive the input object list from the source rank.""" + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj_list + # Receive. + recv_object_list(obj_list, + src=self.ranks[src], + group=self.device_group) + return obj_list + def broadcast_tensor_dict( self, tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, @@ -433,6 +467,82 @@ def broadcast_tensor_dict( async_handle.wait() return tensor_dict + def send_tensor_dict( + self, tensor_dict: Dict[Any, Union[torch.Tensor, Any]], + dst: int) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]: + """Send the input tensor dictionary. + NOTE: `dst` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + assert dst < self.world_size, f"Invalid dst rank ({dst})" + dst = self.ranks[dst] + + metadata_list: List[Tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, + dict), f"Expecting a dictionary, got {type(tensor_dict)}" + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + send_object_list([metadata_list], dst=dst, group=metadata_group) + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. 
+ continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.send(tensor, dst=dst, group=metadata_group) + else: + # use group for GPU tensors + torch.distributed.send(tensor, dst=dst, group=group) + return None + + def recv_tensor_dict( + self, src: int) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]: + """Recv the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return None + + group = self.device_group + metadata_group = self.cpu_group + assert src < self.world_size, f"Invalid src rank ({src})" + src = self.ranks[src] + + recv_metadata_list = [None] + recv_object_list(recv_metadata_list, src=src, group=metadata_group) + assert recv_metadata_list[0] is not None + tensor_dict = {} + for key, value in recv_metadata_list[0]: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, + dtype=value.dtype, + device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + tensor_dict[key] = tensor + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.recv(tensor, + src=src, + group=metadata_group) + else: + # use group for GPU tensors + torch.distributed.recv(tensor, src=src, group=group) + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + return tensor_dict + def barrier(self): """Barrier synchronization among the group. NOTE: don't use `device_group` here! `barrier` in NCCL is @@ -442,6 +552,26 @@ def barrier(self): """ torch.distributed.barrier(group=self.cpu_group) + def send(self, tensor: torch.Tensor, dst: int) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.send(tensor) + else: + torch.distributed.isend(tensor, dst, self.device_group) + + def recv(self, size: torch.Size, dtype: torch.dtype, + src: int) -> torch.Tensor: + """Receives a tensor from the src rank.""" + tensor = torch.empty(size, dtype=dtype, device=self.device) + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.recv(tensor) + else: + req = torch.distributed.irecv(tensor, src, self.device_group) + req.wait() + return tensor + def destroy(self): if self.device_group is not None: torch.distributed.destroy_process_group(self.device_group) @@ -684,6 +814,28 @@ def get_tensor_model_parallel_rank(): return get_tp_group().rank_in_group +def get_pipeline_model_parallel_world_size(): + """Return world size for the pipeline model parallel group.""" + return get_pp_group().world_size + + +def get_pipeline_model_parallel_rank(): + """Return my rank for the pipeline model parallel group.""" + return get_pp_group().rank_in_group + + +def is_pipeline_model_parallel_first_rank(): + """Return True if the rank is the first rank in the + pipeline model parallel group.""" + return get_pp_group().rank_in_group == 0 + + +def is_pipeline_model_parallel_last_rank(): + """Return True if the rank is the last rank in the + pipeline model parallel group.""" + return get_pp_group().rank_in_group == get_pp_group().world_size - 1 + + def destroy_model_parallel(): """Set the groups to none and destroy them.""" global _TP From ae4a17a20a3efaee3e1a30896d951b19d05fae4d Mon Sep 17 00:00:00 2001 From: Muralidhar Andoorveedu Date: Thu, 20 Jun 2024 19:11:08 +0000 Subject: [PATCH 02/14] Fix comments 
Signed-off-by: Muralidhar Andoorveedu --- vllm/distributed/parallel_state.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 33dfa8717429b..f51fa19709d88 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -488,12 +488,12 @@ def send_tensor_dict( dict), f"Expecting a dictionary, got {type(tensor_dict)}" metadata_list, tensor_list = _split_tensor_dict(tensor_dict) # `metadata_list` lives in CPU memory. - # `broadcast_object_list` has serialization & deserialization, + # `send_object_list` has serialization & deserialization, # all happening on CPU. Therefore, we can use the CPU group. send_object_list([metadata_list], dst=dst, group=metadata_group) for tensor in tensor_list: if tensor.numel() == 0: - # Skip broadcasting empty tensors. + # Skip sending empty tensors. continue if tensor.is_cpu: # use metadata_group for CPU tensors From c34f44090b635faecbfb947402dab5d5a892df1d Mon Sep 17 00:00:00 2001 From: Muralidhar Andoorveedu Date: Thu, 20 Jun 2024 20:20:05 +0000 Subject: [PATCH 03/14] Change send and recv object to send and recv object_lists Signed-off-by: Muralidhar Andoorveedu --- vllm/distributed/object_list_ops.py | 123 ---------------------------- vllm/distributed/parallel_state.py | 78 ++++++++++++------ 2 files changed, 51 insertions(+), 150 deletions(-) delete mode 100644 vllm/distributed/object_list_ops.py diff --git a/vllm/distributed/object_list_ops.py b/vllm/distributed/object_list_ops.py deleted file mode 100644 index 309ff3196eaf5..0000000000000 --- a/vllm/distributed/object_list_ops.py +++ /dev/null @@ -1,123 +0,0 @@ -""" -This file is necessary until new version of torch.distributed is released with -https://github.com/pytorch/pytorch/commit/b96b1e8cff029bb0a73283e6e7f6cc240313f1dc -""" -import torch -import torch.distributed as dist -from torch.distributed.distributed_c10d import (_get_pg_default_device, - _object_to_tensor, - _tensor_to_object) - - -def send_object_list(object_list, dst, group=None, device=None): - """ - Sends picklable objects in ``object_list`` synchronously. - - Similar to :func:`send`, but Python objects can be passed in. - Note that all objects in ``object_list`` must be picklable in order to be - sent. - - Args: - object_list (List[Any]): List of input objects to sent. - Each object must be picklable. Receiver must provide lists of - equal sizes. - dst (int): Destination rank to send ``object_list`` to. - Destination rank is based on global process group - (regardless of ``group`` argument) - group: (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. Default is ``None``. - device (``torch.device``, optional): If not None, the objects are - serialized and converted to tensors which are moved to the - ``device`` before sending. Default is ``None``. - - Returns: - ``None``. - """ - if dist.get_rank() == dst: - raise ValueError( - "Invalid destination rank: destination rank should not be the " - "same as the rank of the current process.") - - # Current device selection. - # To preserve backwards compatibility, ``device`` is default to ``None`` - # in which case we run current logic of device selection, i.e. - # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the - # case it is not ``None`` we move the size and object tensors to be - # sent to this device. 
- current_device = device or _get_pg_default_device(group) - # Serialize object_list elements to tensors on src rank. - tensor_list, size_list = zip( - * - [_object_to_tensor(obj, current_device, group) for obj in object_list]) - object_sizes_tensor = torch.cat(size_list) - - # Send object sizes - dist.send(object_sizes_tensor, dst=dst, group=group) - - # Concatenate and send serialized object tensors - # Note: torch.cat will do an extra memory copy to the current device, - # if the tensor_list has only one element, we can skip the copy. - if len(tensor_list) == 1: # type: ignore[possibly-undefined] - object_tensor = tensor_list[0] - else: - object_tensor = torch.cat(tensor_list) - - dist.send(object_tensor, dst=dst, group=group) - - -def recv_object_list(object_list, src=None, group=None, device=None): - """ - Receives picklable objects in ``object_list`` synchronously. - - Similar to :func:`recv`, but can receive Python objects. - - Args: - object_list (List[Any]): List of objects to receive into. - Must provide a list of sizes equal to the size of the list - being sent. - src (int, optional): Source rank from which to recv ``object_list``. - Source rank is based on global process group - (regardless of ``group`` argument) - Will receive from any rank if set to None. Default is ``None``. - group: (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. Default is ``None``. - device (``torch.device``, optional): If not None, receives on - this device. Default is ``None``. - - Returns: - Sender rank. -1 if rank is not part of the group. If rank is part - of the group, ``object_list`` will contain the sent objects from - ``src`` rank. - """ - - # Current device selection. - # To preserve backwards compatibility, ``device`` is default to ``None`` - # in which case we run current logic of device selection, i.e. - # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the - # case it is not ``None`` we move the size and object tensors to be - # received to this device. - current_device = device or _get_pg_default_device(group) - object_sizes_tensor = torch.empty(len(object_list), - dtype=torch.long, - device=current_device) - - # Receive object sizes - rank_sizes = dist.recv(object_sizes_tensor, src=src, group=group) - - # Tensor to receive serialized objects into. - object_tensor = torch.empty( # type: ignore[call-overload] - torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] - dtype=torch.uint8, - device=current_device) - - rank_objects = dist.recv(object_tensor, src=src, group=group) - assert (rank_sizes == rank_objects - ), "Mismatch in return ranks for object sizes and objects." - # Deserialize objects using their stored sizes. 
- offset = 0 - for i, obj_size in enumerate(object_sizes_tensor): - obj_view = object_tensor[offset:offset + obj_size] - obj_view = obj_view.type(torch.uint8) - offset += obj_size - object_list[i] = _tensor_to_object(obj_view, obj_size, group) - return rank_objects diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index f51fa19709d88..2d9ea6557fa4a 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -28,13 +28,15 @@ from unittest.mock import patch import torch +import torch.distributed from torch.distributed import Backend, ProcessGroup +from torch.distributed.distributed_c10d import (_get_pg_default_device, + _object_to_tensor, + _tensor_to_object) import vllm.envs as envs from vllm.logger import init_logger -from .object_list_ops import recv_object_list, send_object_list - @dataclass class GraphCaptureContext: @@ -344,37 +346,61 @@ def broadcast_object_list(self, group=self.device_group) return obj_list - def send_object_list(self, - obj_list: List[Any], - dst: int, - group: Optional[ProcessGroup] = None): + def send_object(self, + obj: Any, + dst: int, + group: Optional[ProcessGroup] = None) -> None: """Send the input object list to the destination rank.""" assert dst < self.world_size, f"Invalid dst rank ({dst})" # Bypass the function if we are using only 1 GPU. if self.world_size == 1: - return obj_list - # Send. - send_object_list(obj_list, - dst=self.ranks[dst], - group=self.device_group) - return obj_list + return obj + + current_device = _get_pg_default_device(group) + # Serialize object to tensor. + object_tensor, size_tensor = _object_to_tensor(obj, current_device, + group) + + # Send object size + torch.distributed.send(size_tensor, dst=dst, group=group) - def recv_object_list(self, - obj_list: List[Any], - src: int, - group: Optional[ProcessGroup] = None): + # Send object + torch.distributed.send(object_tensor, dst=dst, group=group) + + return None + + def recv_object(self, + src: int, + group: Optional[ProcessGroup] = None) -> Any: """Receive the input object list from the source rank.""" assert src < self.world_size, f"Invalid src rank ({src})" # Bypass the function if we are using only 1 GPU. if self.world_size == 1: - return obj_list - # Receive. - recv_object_list(obj_list, - src=self.ranks[src], - group=self.device_group) - return obj_list + return None + + current_device = _get_pg_default_device(group) + + size_tensor = torch.empty(1, dtype=torch.long, device=current_device) + + # Receive object size + rank_size = torch.distributed.recv(size_tensor, src=src, group=group) + + # Tensor to receive serialized objects into. + object_tensor = torch.empty( # type: ignore[call-overload] + size_tensor.item(), # type: ignore[arg-type] + dtype=torch.uint8, + device=current_device) + + rank_object = torch.distributed.recv(object_tensor, + src=src, + group=group) + assert (rank_size == rank_object + ), "Mismatch in return ranks for object sizes and objects." + # Deserialize objects using their stored sizes. + + return _tensor_to_object(object_tensor, size_tensor.item(), group) def broadcast_tensor_dict( self, @@ -490,7 +516,7 @@ def send_tensor_dict( # `metadata_list` lives in CPU memory. # `send_object_list` has serialization & deserialization, # all happening on CPU. Therefore, we can use the CPU group. 
- send_object_list([metadata_list], dst=dst, group=metadata_group) + self.send_object(metadata_list, dst=dst, group=metadata_group) for tensor in tensor_list: if tensor.numel() == 0: # Skip sending empty tensors. @@ -517,11 +543,9 @@ def recv_tensor_dict( assert src < self.world_size, f"Invalid src rank ({src})" src = self.ranks[src] - recv_metadata_list = [None] - recv_object_list(recv_metadata_list, src=src, group=metadata_group) - assert recv_metadata_list[0] is not None + recv_metadata_list = self.recv_object(src=src, group=metadata_group) tensor_dict = {} - for key, value in recv_metadata_list[0]: + for key, value in recv_metadata_list: if isinstance(value, TensorMetadata): tensor = torch.empty(value.size, dtype=value.dtype, From 8b16ef129b0d6cb77bab5b2795169c00d007dac4 Mon Sep 17 00:00:00 2001 From: Muralidhar Andoorveedu Date: Sat, 22 Jun 2024 00:43:56 +0000 Subject: [PATCH 04/14] Remove module level function Signed-off-by: Muralidhar Andoorveedu --- tests/distributed/test_comm_ops.py | 6 +++--- vllm/distributed/communication_op.py | 27 --------------------------- vllm/distributed/parallel_state.py | 10 ++++++++-- 3 files changed, 11 insertions(+), 32 deletions(-) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index 86e1ea8161e91..b96731d8a1b7d 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -11,7 +11,7 @@ from vllm.distributed import (broadcast_tensor_dict, is_pipeline_model_parallel_first_rank, is_pipeline_model_parallel_last_rank, - recv_tensor_dict, send_tensor_dict, + get_pp_group, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) @@ -132,10 +132,10 @@ def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int, } if not is_pipeline_model_parallel_first_rank(): - recv_dict = recv_tensor_dict() + recv_dict = get_pp_group().recv_tensor_dict() if not is_pipeline_model_parallel_last_rank(): - send_tensor_dict(test_dict) + get_pp_group().send_tensor_dict(test_dict) if not is_pipeline_model_parallel_first_rank(): assert len(recv_dict) == len(test_dict) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index aa3ad951d3a1e..45b0c2cd2fd70 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -30,30 +30,3 @@ def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor, if not torch.distributed.is_initialized(): return tensor_dict return get_tp_group().broadcast_tensor_dict(tensor_dict, src) - - -def send_tensor_dict(tensors: Dict[str, torch.Tensor], - dst: Optional[int] = None) -> None: - """ - Send the tensors to the next pipeline model parallel rank. - Args: - tensors (Dict[torch.Tensor]): Dict of tensors to send. - """ - if dst is None: - dst = get_pp_group().next_rank - get_pp_group().send_tensor_dict(tensors, dst) - - -def recv_tensor_dict( - src: Optional[int] = None -) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]: - """ - Receive tensors from the previous pipeline model parallel rank assuming all - tensors are the same size. - Returns: - Dict[torch.Tensor]: Dict of received tensors. 
- """ - if src is None: - src = get_pp_group().prev_rank - tensors = get_pp_group().recv_tensor_dict(src) - return tensors diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 2d9ea6557fa4a..293a1f1d5a7b8 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -495,7 +495,7 @@ def broadcast_tensor_dict( def send_tensor_dict( self, tensor_dict: Dict[Any, Union[torch.Tensor, Any]], - dst: int) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]: + dst: Optional[int] = None) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]: """Send the input tensor dictionary. NOTE: `dst` is the local rank of the source rank. """ @@ -505,6 +505,9 @@ def send_tensor_dict( group = self.device_group metadata_group = self.cpu_group + + if dst is None: + dst = self.next_rank assert dst < self.world_size, f"Invalid dst rank ({dst})" dst = self.ranks[dst] @@ -530,7 +533,7 @@ def send_tensor_dict( return None def recv_tensor_dict( - self, src: int) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]: + self, src: Optional[int] = None) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]: """Recv the input tensor dictionary. NOTE: `src` is the local rank of the source rank. """ @@ -540,6 +543,9 @@ def recv_tensor_dict( group = self.device_group metadata_group = self.cpu_group + + if src is None: + src = self.prev_rank assert src < self.world_size, f"Invalid src rank ({src})" src = self.ranks[src] From ac1347246ae6fa67969ee789d9b8b2fd16b639a6 Mon Sep 17 00:00:00 2001 From: Muralidhar Andoorveedu Date: Sat, 22 Jun 2024 00:45:48 +0000 Subject: [PATCH 05/14] Format and lint Signed-off-by: Muralidhar Andoorveedu --- tests/distributed/test_comm_ops.py | 3 +-- vllm/distributed/communication_op.py | 2 +- vllm/distributed/parallel_state.py | 10 +++++++--- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index b96731d8a1b7d..cfd839971bd1e 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -11,8 +11,7 @@ from vllm.distributed import (broadcast_tensor_dict, is_pipeline_model_parallel_first_rank, is_pipeline_model_parallel_last_rank, - get_pp_group, - tensor_model_parallel_all_gather, + get_pp_group, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) from ..utils import init_test_distributed_environment, multi_process_parallel diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 45b0c2cd2fd70..32394a07b00b9 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -3,7 +3,7 @@ import torch import torch.distributed -from .parallel_state import get_pp_group, get_tp_group +from .parallel_state import get_tp_group def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 293a1f1d5a7b8..05de19811bc76 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -494,8 +494,10 @@ def broadcast_tensor_dict( return tensor_dict def send_tensor_dict( - self, tensor_dict: Dict[Any, Union[torch.Tensor, Any]], - dst: Optional[int] = None) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]: + self, + tensor_dict: Dict[Any, Union[torch.Tensor, Any]], + dst: Optional[int] = None + ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]: """Send the input tensor dictionary. NOTE: `dst` is the local rank of the source rank. 
""" @@ -533,7 +535,9 @@ def send_tensor_dict( return None def recv_tensor_dict( - self, src: Optional[int] = None) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]: + self, + src: Optional[int] = None + ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]: """Recv the input tensor dictionary. NOTE: `src` is the local rank of the source rank. """ From 68cd3c3b584c51cb9cfb2098b6ece30054b7626a Mon Sep 17 00:00:00 2001 From: Muralidhar Andoorveedu Date: Sat, 22 Jun 2024 00:46:07 +0000 Subject: [PATCH 06/14] Format Signed-off-by: Muralidhar Andoorveedu --- tests/distributed/test_comm_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index cfd839971bd1e..479f87a0d74ae 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -8,10 +8,10 @@ import ray import torch -from vllm.distributed import (broadcast_tensor_dict, +from vllm.distributed import (broadcast_tensor_dict, get_pp_group, is_pipeline_model_parallel_first_rank, is_pipeline_model_parallel_last_rank, - get_pp_group, tensor_model_parallel_all_gather, + tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) from ..utils import init_test_distributed_environment, multi_process_parallel From 1c57591ee9626618e32dcc3de6b7b46f13911a2d Mon Sep 17 00:00:00 2001 From: Muralidhar Andoorveedu Date: Sat, 22 Jun 2024 01:03:47 +0000 Subject: [PATCH 07/14] Remove world size check and add assert for rank. Signed-off-by: Muralidhar Andoorveedu --- vllm/distributed/parallel_state.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 05de19811bc76..eac7eb81312e4 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -353,9 +353,9 @@ def send_object(self, """Send the input object list to the destination rank.""" assert dst < self.world_size, f"Invalid dst rank ({dst})" - # Bypass the function if we are using only 1 GPU. - if self.world_size == 1: - return obj + assert dst != self.rank, ( + "Invalid destination rank. Destination rank is the same " + "as the current rank.") current_device = _get_pg_default_device(group) # Serialize object to tensor. @@ -376,9 +376,9 @@ def recv_object(self, """Receive the input object list from the source rank.""" assert src < self.world_size, f"Invalid src rank ({src})" - # Bypass the function if we are using only 1 GPU. - if self.world_size == 1: - return None + assert src != self.rank, ( + "Invalid source rank. Source rank is the same as the current rank." + ) current_device = _get_pg_default_device(group) From 2c9fa20951f6e2c365e8f299443e3e8f3a296e76 Mon Sep 17 00:00:00 2001 From: Muralidhar Andoorveedu Date: Sat, 22 Jun 2024 02:39:43 +0000 Subject: [PATCH 08/14] Simplifications Signed-off-by: Muralidhar Andoorveedu --- vllm/distributed/parallel_state.py | 60 +++++++++++++++--------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index eac7eb81312e4..2ee62523c4d85 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -20,6 +20,7 @@ steps. 
""" import contextlib +import pickle from collections import namedtuple from contextlib import contextmanager, nullcontext from dataclasses import dataclass @@ -30,9 +31,6 @@ import torch import torch.distributed from torch.distributed import Backend, ProcessGroup -from torch.distributed.distributed_c10d import (_get_pg_default_device, - _object_to_tensor, - _tensor_to_object) import vllm.envs as envs from vllm.logger import init_logger @@ -346,10 +344,7 @@ def broadcast_object_list(self, group=self.device_group) return obj_list - def send_object(self, - obj: Any, - dst: int, - group: Optional[ProcessGroup] = None) -> None: + def send_object(self, obj: Any, dst: int) -> None: """Send the input object list to the destination rank.""" assert dst < self.world_size, f"Invalid dst rank ({dst})" @@ -357,22 +352,27 @@ def send_object(self, "Invalid destination rank. Destination rank is the same " "as the current rank.") - current_device = _get_pg_default_device(group) - # Serialize object to tensor. - object_tensor, size_tensor = _object_to_tensor(obj, current_device, - group) + # Serialize object to tensor and get the size as well + object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) + + size_tensor = torch.tensor([object_tensor.numel()], + dtype=torch.long, + device="cpu") # Send object size - torch.distributed.send(size_tensor, dst=dst, group=group) + + torch.distributed.send(size_tensor, + dst=self.ranks[dst], + group=self.cpu_group) # Send object - torch.distributed.send(object_tensor, dst=dst, group=group) + torch.distributed.send(object_tensor, + dst=self.ranks[dst], + group=self.cpu_group) return None - def recv_object(self, - src: int, - group: Optional[ProcessGroup] = None) -> Any: + def recv_object(self, src: int) -> Any: """Receive the input object list from the source rank.""" assert src < self.world_size, f"Invalid src rank ({src})" @@ -380,27 +380,29 @@ def recv_object(self, "Invalid source rank. Source rank is the same as the current rank." ) - current_device = _get_pg_default_device(group) - - size_tensor = torch.empty(1, dtype=torch.long, device=current_device) + size_tensor = torch.empty(1, dtype=torch.long, device="cpu") # Receive object size - rank_size = torch.distributed.recv(size_tensor, src=src, group=group) + rank_size = torch.distributed.recv(size_tensor, + src=src, + group=self.cpu_group) # Tensor to receive serialized objects into. object_tensor = torch.empty( # type: ignore[call-overload] size_tensor.item(), # type: ignore[arg-type] dtype=torch.uint8, - device=current_device) + device="cpu") rank_object = torch.distributed.recv(object_tensor, src=src, - group=group) - assert (rank_size == rank_object - ), "Mismatch in return ranks for object sizes and objects." - # Deserialize objects using their stored sizes. + group=self.cpu_group) - return _tensor_to_object(object_tensor, size_tensor.item(), group) + assert rank_object == rank_size, ( + "Received object sender rank does not match the size sender rank.") + + obj = pickle.loads(object_tensor.numpy().tobytes()) + + return obj def broadcast_tensor_dict( self, @@ -511,7 +513,6 @@ def send_tensor_dict( if dst is None: dst = self.next_rank assert dst < self.world_size, f"Invalid dst rank ({dst})" - dst = self.ranks[dst] metadata_list: List[Tuple[Any, Any]] = [] assert isinstance( @@ -521,7 +522,7 @@ def send_tensor_dict( # `metadata_list` lives in CPU memory. # `send_object_list` has serialization & deserialization, # all happening on CPU. Therefore, we can use the CPU group. 
- self.send_object(metadata_list, dst=dst, group=metadata_group) + self.send_object(metadata_list, dst=dst) for tensor in tensor_list: if tensor.numel() == 0: # Skip sending empty tensors. @@ -551,9 +552,8 @@ def recv_tensor_dict( if src is None: src = self.prev_rank assert src < self.world_size, f"Invalid src rank ({src})" - src = self.ranks[src] - recv_metadata_list = self.recv_object(src=src, group=metadata_group) + recv_metadata_list = self.recv_object(src=src) tensor_dict = {} for key, value in recv_metadata_list: if isinstance(value, TensorMetadata): From 25ad2bbc6be5861da8be26b5742386df536beb42 Mon Sep 17 00:00:00 2001 From: Muralidhar Andoorveedu Date: Sat, 22 Jun 2024 02:58:48 +0000 Subject: [PATCH 09/14] Streamed send/recv APIs with local ranks Signed-off-by: Muralidhar Andoorveedu --- .../device_communicators/pynccl.py | 14 ++------- vllm/distributed/parallel_state.py | 29 ++++++++++++++----- 2 files changed, 23 insertions(+), 20 deletions(-) diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 83eec264b6f81..7319566545678 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -121,10 +121,7 @@ def all_reduce(self, ncclRedOpTypeEnum.from_torch(op), self.comm, cudaStream_t(stream.cuda_stream)) - def send(self, - tensor: torch.Tensor, - dst: Optional[int] = None, - stream=None): + def send(self, tensor: torch.Tensor, dst: int, stream=None): if self.disabled: return assert tensor.device == self.device, ( @@ -132,16 +129,11 @@ def send(self, f"but the input tensor is on {tensor.device}") if stream is None: stream = self.stream - if dst is None: - dst = (self.rank + 1) % self.world_size self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(), ncclDataTypeEnum.from_torch(tensor.dtype), dst, self.comm, cudaStream_t(stream.cuda_stream)) - def recv(self, - tensor: torch.Tensor, - src: Optional[int] = None, - stream=None): + def recv(self, tensor: torch.Tensor, src: int, stream=None): if self.disabled: return assert tensor.device == self.device, ( @@ -149,8 +141,6 @@ def recv(self, f"but the input tensor is on {tensor.device}") if stream is None: stream = self.stream - if src is None: - src = (self.rank - 1) % self.world_size self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(), ncclDataTypeEnum.from_torch(tensor.dtype), src, self.comm, cudaStream_t(stream.cuda_stream)) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 2ee62523c4d85..7f8ae3ad8c3ee 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -346,6 +346,8 @@ def broadcast_object_list(self, def send_object(self, obj: Any, dst: int) -> None: """Send the input object list to the destination rank.""" + """NOTE: `dst` is the local rank of the destination rank.""" + assert dst < self.world_size, f"Invalid dst rank ({dst})" assert dst != self.rank, ( @@ -374,6 +376,8 @@ def send_object(self, obj: Any, dst: int) -> None: def recv_object(self, src: int) -> Any: """Receive the input object list from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + assert src < self.world_size, f"Invalid src rank ({src})" assert src != self.rank, ( @@ -586,24 +590,33 @@ def barrier(self): """ torch.distributed.barrier(group=self.cpu_group) - def send(self, tensor: torch.Tensor, dst: int) -> None: + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: """Sends a tensor to the 
destination rank in a non-blocking way""" + """NOTE: `dst` is the local rank of the destination rank.""" + if dst is None: + dst = self.next_rank + pynccl_comm = self.pynccl_comm if pynccl_comm is not None and not pynccl_comm.disabled: - pynccl_comm.send(tensor) + pynccl_comm.send(tensor, dst) else: - torch.distributed.isend(tensor, dst, self.device_group) + torch.distributed.send(tensor, self.ranks[dst], self.device_group) - def recv(self, size: torch.Size, dtype: torch.dtype, - src: int) -> torch.Tensor: + def recv(self, + size: torch.Size, + dtype: torch.dtype, + src: Optional[int] = None) -> torch.Tensor: """Receives a tensor from the src rank.""" + """NOTE: `src` is the local rank of the destination rank.""" + if src is None: + src = self.prev_rank + tensor = torch.empty(size, dtype=dtype, device=self.device) pynccl_comm = self.pynccl_comm if pynccl_comm is not None and not pynccl_comm.disabled: - pynccl_comm.recv(tensor) + pynccl_comm.recv(tensor, src) else: - req = torch.distributed.irecv(tensor, src, self.device_group) - req.wait() + torch.distributed.recv(tensor, self.ranks[src], self.device_group) return tensor def destroy(self): From 31ce144feedb6239b36cc1c4435ab42aeb3e6016 Mon Sep 17 00:00:00 2001 From: Muralidhar Andoorveedu Date: Sat, 22 Jun 2024 03:07:06 +0000 Subject: [PATCH 10/14] Refactor send and recv functions and add new test Signed-off-by: Muralidhar Andoorveedu --- tests/distributed/test_comm_ops.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index 479f87a0d74ae..7de562cc90f01 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -146,6 +146,28 @@ def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int, assert torch.allclose(recv_dict["f"], test_dict["f"]) +@ray.remote(num_gpus=1, max_calls=1) +def send_recv_test_worker(tp_size: int, pp_size: int, rank: int, + distributed_init_port: str): + del os.environ["CUDA_VISIBLE_DEVICES"] + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + init_test_distributed_environment(tp_size, pp_size, rank, + distributed_init_port) + + size = 64 + test_tensor = torch.arange(64, dtype=torch.float32, device="cuda") + + if not is_pipeline_model_parallel_first_rank(): + recv_tensor = get_pp_group().recv(size, dtype=torch.float32) + + if not is_pipeline_model_parallel_last_rank(): + get_pp_group().send(test_tensor) + + if not is_pipeline_model_parallel_first_rank(): + assert torch.allclose(test_tensor, recv_tensor) + + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test.") @pytest.mark.parametrize("tp_size", [2]) @@ -160,6 +182,7 @@ def test_multi_process_tensor_parallel(tp_size, test_target): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test.") @pytest.mark.parametrize("pp_size", [2]) -@pytest.mark.parametrize("test_target", [send_recv_tensor_dict_test_worker]) +@pytest.mark.parametrize( + "test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker]) def test_multi_process_pipeline_parallel(pp_size, test_target): multi_process_parallel(1, pp_size, test_target) From 83448fbad6883dce5a5db1a99edcb56dc404fe4b Mon Sep 17 00:00:00 2001 From: Muralidhar Andoorveedu Date: Sat, 22 Jun 2024 03:15:43 +0000 Subject: [PATCH 11/14] Add is_first_rank, is_last_rank, remove helper functions Signed-off-by: Muralidhar Andoorveedu --- 
vllm/distributed/parallel_state.py | 32 ++++++++++-------------------- 1 file changed, 10 insertions(+), 22 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 7f8ae3ad8c3ee..b011bed729995 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -174,6 +174,16 @@ def last_rank(self): """Return the global rank of the last process in the group""" return self.ranks[-1] + @property + def is_first_rank(self): + """Return whether the caller is the first process in the group""" + return self.rank == self.first_rank + + @property + def is_last_rank(self): + """Return whether the caller is the last process in the group""" + return self.rank == self.last_rank + @property def next_rank(self): """Return the global rank of the process that follows the caller""" @@ -861,28 +871,6 @@ def get_tensor_model_parallel_rank(): return get_tp_group().rank_in_group -def get_pipeline_model_parallel_world_size(): - """Return world size for the pipeline model parallel group.""" - return get_pp_group().world_size - - -def get_pipeline_model_parallel_rank(): - """Return my rank for the pipeline model parallel group.""" - return get_pp_group().rank_in_group - - -def is_pipeline_model_parallel_first_rank(): - """Return True if the rank is the first rank in the - pipeline model parallel group.""" - return get_pp_group().rank_in_group == 0 - - -def is_pipeline_model_parallel_last_rank(): - """Return True if the rank is the last rank in the - pipeline model parallel group.""" - return get_pp_group().rank_in_group == get_pp_group().world_size - 1 - - def destroy_model_parallel(): """Set the groups to none and destroy them.""" global _TP From 8b4e1fd3bd56d0bcd65a9004978905dda4bbd7f0 Mon Sep 17 00:00:00 2001 From: Muralidhar Andoorveedu Date: Sat, 22 Jun 2024 03:21:50 +0000 Subject: [PATCH 12/14] Update test Signed-off-by: Muralidhar Andoorveedu --- tests/distributed/test_comm_ops.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index 7de562cc90f01..b0e95347daae3 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -9,8 +9,6 @@ import torch from vllm.distributed import (broadcast_tensor_dict, get_pp_group, - is_pipeline_model_parallel_first_rank, - is_pipeline_model_parallel_last_rank, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) @@ -130,13 +128,13 @@ def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int, "f": torch.tensor([], dtype=torch.float32, device="cuda"), } - if not is_pipeline_model_parallel_first_rank(): + if not get_pp_group().is_first_rank: recv_dict = get_pp_group().recv_tensor_dict() - if not is_pipeline_model_parallel_last_rank(): + if not get_pp_group().is_last_rank: get_pp_group().send_tensor_dict(test_dict) - if not is_pipeline_model_parallel_first_rank(): + if not get_pp_group().is_first_rank: assert len(recv_dict) == len(test_dict) assert torch.allclose(recv_dict["a"], test_dict["a"]) assert torch.allclose(recv_dict["b"], test_dict["b"]) @@ -158,13 +156,13 @@ def send_recv_test_worker(tp_size: int, pp_size: int, rank: int, size = 64 test_tensor = torch.arange(64, dtype=torch.float32, device="cuda") - if not is_pipeline_model_parallel_first_rank(): + if not get_pp_group().is_first_rank: recv_tensor = get_pp_group().recv(size, dtype=torch.float32) - if not is_pipeline_model_parallel_last_rank(): + if not get_pp_group().is_last_rank: 
get_pp_group().send(test_tensor) - if not is_pipeline_model_parallel_first_rank(): + if not get_pp_group().is_first_rank: assert torch.allclose(test_tensor, recv_tensor) From 8cd5a9c4376d339d4a2129ae12f593e7e4dca703 Mon Sep 17 00:00:00 2001 From: Muralidhar Andoorveedu Date: Sat, 22 Jun 2024 03:24:32 +0000 Subject: [PATCH 13/14] Format Signed-off-by: Muralidhar Andoorveedu --- tests/distributed/test_comm_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index b0e95347daae3..bf0f31df02fa5 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -156,7 +156,7 @@ def send_recv_test_worker(tp_size: int, pp_size: int, rank: int, size = 64 test_tensor = torch.arange(64, dtype=torch.float32, device="cuda") - if not get_pp_group().is_first_rank: + if not get_pp_group().is_first_rank: recv_tensor = get_pp_group().recv(size, dtype=torch.float32) if not get_pp_group().is_last_rank: From 9d085951e6e6148ab045ed942e4e210e50d7824d Mon Sep 17 00:00:00 2001 From: Muralidhar Andoorveedu Date: Sat, 22 Jun 2024 22:00:44 +0000 Subject: [PATCH 14/14] chore: Refactor send and recv functions for pynccl tests Signed-off-by: Muralidhar Andoorveedu --- tests/distributed/test_pynccl.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index 964dbc5423e75..e0e424439e3a5 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -168,9 +168,13 @@ def send_recv_worker_fn(): dtype=torch.float32).cuda(pynccl_comm.rank) with pynccl_comm.change_state(enable=True): if pynccl_comm.rank == 0: - pynccl_comm.send(tensor) + pynccl_comm.send(tensor, + dst=(pynccl_comm.rank + 1) % + pynccl_comm.world_size) else: - pynccl_comm.recv(tensor) + pynccl_comm.recv(tensor, + src=(pynccl_comm.rank - 1) % + pynccl_comm.world_size) result = tensor.mean().cpu().item() assert result == 1 @@ -203,9 +207,13 @@ def multiple_send_recv_worker_fn(): device=device) with pynccl_comm.change_state(enable=True): if torch.distributed.get_rank() in [0, 1]: - pynccl_comm.send(tensor) + pynccl_comm.send(tensor, + dst=(pynccl_comm.rank + 1) % + pynccl_comm.world_size) else: - pynccl_comm.recv(tensor) + pynccl_comm.recv(tensor, + src=(pynccl_comm.rank - 1) % + pynccl_comm.world_size) result = tensor.mean().cpu().item() if torch.distributed.get_rank() in [0, 2]: assert result == 1
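The send_object/recv_object helpers introduced in PATCH 08 use a size-then-payload protocol over the CPU group: the object is pickled into a uint8 tensor, a one-element long tensor carrying the byte count goes out first so the receiver can allocate a buffer of exactly the right size, and the serialized payload follows. Below is a minimal, self-contained sketch of that same pattern written against plain torch.distributed (gloo backend, CPU tensors, two processes); it is not vLLM's API, and the port number and example payload are arbitrary choices for illustration.

import os
import pickle

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"  # arbitrary free port for this sketch
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    if rank == 0:
        obj = {"step": 1, "hidden": torch.ones(4)}
        # Pickle the object into a uint8 tensor, mirroring send_object above.
        payload = torch.frombuffer(bytearray(pickle.dumps(obj)),
                                   dtype=torch.uint8)
        size = torch.tensor([payload.numel()], dtype=torch.long)
        dist.send(size, dst=1)     # 1) tell the peer how many bytes follow
        dist.send(payload, dst=1)  # 2) send the pickled bytes themselves
    else:
        size = torch.empty(1, dtype=torch.long)
        dist.recv(size, src=0)
        payload = torch.empty(size.item(), dtype=torch.uint8)
        dist.recv(payload, src=0)
        obj = pickle.loads(payload.numpy().tobytes())
        assert obj["step"] == 1
        assert torch.equal(obj["hidden"], torch.ones(4))

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)

Sending the byte count first is what lets recv_object allocate object_tensor with exactly size_tensor.item() elements before the second recv; keeping this exchange on the CPU group means the metadata round trip never touches the device group used for the actual tensor payloads.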
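At the call sites, the tests above exercise the usual pipeline-stage handoff: every stage except the first receives a tensor dictionary from the previous local rank, and every stage except the last forwards its outputs to the next one, with dst and src defaulting to next_rank and prev_rank. The sketch below shows that intended usage under the assumption that the distributed environment and model-parallel groups have already been initialized (as init_test_distributed_environment does in the tests); run_stage and the contents of outputs are placeholders, while get_pp_group, is_first_rank, is_last_rank, recv_tensor_dict and send_tensor_dict are the APIs added in this series.

from typing import Callable, Dict, Optional

import torch

from vllm.distributed import get_pp_group


def forward_one_step(
    run_stage: Callable[[Optional[Dict[str, torch.Tensor]]],
                        Dict[str, torch.Tensor]]
) -> Optional[Dict[str, torch.Tensor]]:
    pp_group = get_pp_group()

    if pp_group.is_first_rank:
        # The first stage starts from the model inputs; nothing to receive.
        inputs = None
    else:
        # Defaults to receiving from the previous pipeline rank (prev_rank).
        inputs = pp_group.recv_tensor_dict()

    outputs = run_stage(inputs)  # placeholder, e.g. {"hidden_states": ...}

    if not pp_group.is_last_rank:
        # Defaults to sending to the next pipeline rank (next_rank).
        pp_group.send_tensor_dict(outputs)
        return None
    return outputs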