diff --git a/python/ray/dag/compiled_dag_node.py b/python/ray/dag/compiled_dag_node.py index f22185fd1d72..eb48a4f8cc8a 100644 --- a/python/ray/dag/compiled_dag_node.py +++ b/python/ray/dag/compiled_dag_node.py @@ -9,6 +9,7 @@ from typing import NamedTuple from ray.experimental.channel.cached_channel import CachedChannel +from ray.experimental.channel.gpu_communicator import GPUCommunicator import ray from ray.exceptions import RayTaskError, RayChannelError from ray.util.annotations import PublicAPI @@ -640,6 +641,11 @@ def __init__( # Type hints specified by the user for DAG (intermediate) outputs. self._type_hints = [] + # This is set to true when a type hint of `transport="nccl"` is used + self._use_default_nccl_group = False + # This is set to the specified custom NCCL group + # if there exists a type hint of `transport=nccl_group` + self._custom_nccl_group: Optional[GPUCommunicator] = None # Uniquely identifies the NCCL communicator that will be used within # this DAG, if any. self._nccl_group_id: Optional[str] = None @@ -806,6 +812,33 @@ def _preprocess(self) -> None: if dag_node.type_hint.requires_nccl(): # Add all writers to the NCCL group. nccl_actors.add(actor_handle) + custom_nccl_group = dag_node.type_hint.get_custom_nccl_group() + mixed_nccl_group_error_message = ( + "Accelerated DAGs do not support mixed usage of " + "type hints of default NCCL group " + '(i.e., TorchTensor(transport="nccl")) ' + "and custom NCCL group " + "(i.e., TorchTensor(transport=nccl_group)). " + "Please check all the TorchTensor type hints and " + "make sure only one type of NCCL transport is specified." + ) + if custom_nccl_group is None: + if self._custom_nccl_group is not None: + raise ValueError(mixed_nccl_group_error_message) + self._use_default_nccl_group = True + else: + if self._use_default_nccl_group: + raise ValueError(mixed_nccl_group_error_message) + if self._custom_nccl_group is not None: + if self._custom_nccl_group != custom_nccl_group: + raise ValueError( + "Accelerated DAGs currently only support " + "a single custom NCCL group, but multiple " + "have been specified. Check all the " + "TorchTensor(transport=nccl_group) type hints " + "to make sure only one NCCL group is used."
+ ) + self._custom_nccl_group = custom_nccl_group elif isinstance(dag_node, InputNode): if dag_node.type_hint.requires_nccl(): raise ValueError( @@ -916,7 +949,7 @@ def _preprocess(self) -> None: if None in nccl_actors: raise ValueError("Driver cannot participate in the NCCL group.") if nccl_actors and self._nccl_group_id is None: - self._nccl_group_id = _init_nccl_group(nccl_actors) + self._nccl_group_id = _init_nccl_group(nccl_actors, self._custom_nccl_group) if direct_input: self._input_num_positional_args = 1 diff --git a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py index 1331bc086497..edb089440d8d 100644 --- a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py +++ b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py @@ -3,6 +3,13 @@ import os import re import sys +from typing import List, Optional, Tuple +from ray.experimental.channel.gpu_communicator import ( + GPUCommunicator, + TorchTensorAllocator, +) +from ray.experimental.channel.nccl_group import _NcclGroup +import socket import torch import time @@ -33,6 +40,11 @@ class TorchTensorWorker: def __init__(self): self.device = torch_utils.get_devices()[0] + def init_distributed(self, world_size, rank): + torch.distributed.init_process_group( + backend="nccl", world_size=world_size, rank=rank + ) + def send(self, shape, dtype, value: int, send_tensor=True): if not send_tensor: return 1 @@ -291,6 +303,316 @@ def test_torch_tensor_nccl_dynamic(ray_start_regular): compiled_dag.teardown() +@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) +def test_torch_tensor_custom_comm(ray_start_regular): + if not USE_GPU: + pytest.skip("NCCL tests require GPUs") + + assert ( + sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) > 1 + ), "This test requires at least 2 GPUs" + + actor_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1) + + sender = actor_cls.remote() + receiver = actor_cls.remote() + + class TestNcclGroup(GPUCommunicator): + """ + A custom NCCL group for testing. This is a simple wrapper around `_NcclGroup`. + """ + + def __init__(self, world_size, comm_id, actor_handles): + self._world_size = world_size + self._comm_id = comm_id + self._actor_handles = actor_handles + self._inner = None + + def initialize(self, rank: int) -> None: + self._inner = _NcclGroup( + self._world_size, + self._comm_id, + rank, + self._actor_handles, + torch.cuda.current_stream().cuda_stream, + ) + + def get_rank(self, actor: ray.actor.ActorHandle) -> int: + # Implement this without forwarding to `_inner` to allow the method + # to be called before initialization. + actor_ids = [a._ray_actor_id for a in self._actor_handles] + try: + rank = actor_ids.index(actor._ray_actor_id) + except ValueError: + raise ValueError("Actor is not in the NCCL group.") + return rank + + def get_world_size(self) -> int: + # Implement this without forwarding to `_inner` to allow the method + # to be called before initialization. 
+ return self._world_size + + def get_self_rank(self) -> Optional[int]: + if self._inner is None: + return None + return self._inner.get_self_rank() + + def get_actor_handles(self) -> List["ray.actor.ActorHandle"]: + return self._actor_handles + + def send(self, value: "torch.Tensor", peer_rank: int) -> None: + return self._inner.send(value, peer_rank) + + def recv( + self, + shape: Tuple[int], + dtype: "torch.dtype", + peer_rank: int, + allocator: Optional[TorchTensorAllocator] = None, + ) -> "torch.Tensor": + return self._inner.recv(shape, dtype, peer_rank, allocator=allocator) + + def destroy(self) -> None: + return self._inner.destroy() + + from cupy.cuda import nccl + + comm_id = nccl.get_unique_id() + nccl_group = TestNcclGroup(2, comm_id, [sender, receiver]) + with InputNode() as inp: + dag = sender.send_with_tuple_args.bind(inp) + dag = dag.with_type_hint(TorchTensorType(transport=nccl_group)) + dag = receiver.recv.bind(dag) + + compiled_dag = dag.experimental_compile() + for i in range(3): + i += 1 + shape = (i * 10,) + dtype = torch.float16 + args = (shape, dtype, i) + ref = compiled_dag.execute(args) + result = ray.get(ref) + assert result == (i, shape, dtype) + + compiled_dag.teardown() + + +@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) +def test_torch_tensor_custom_comm_invalid(ray_start_regular): + if not USE_GPU: + pytest.skip("NCCL tests require GPUs") + + assert ( + sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) > 1 + ), "This test requires at least 2 GPUs" + + actor_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1) + + actor1 = actor_cls.remote() + actor2 = actor_cls.remote() + + class MockNcclGroup(GPUCommunicator): + """ + A mock NCCL group for testing. Send and recv are not implemented. 
+ """ + + def __init__(self, world_size, actor_handles): + self._world_size = world_size + self._actor_handles = actor_handles + self._rank = None + + def initialize(self, rank: int) -> None: + expected_rank = self.get_rank(ray.get_runtime_context().current_actor) + assert ( + rank == expected_rank + ), f"NCCL actor's rank {rank} does not match expected rank {expected_rank}" + self._rank = rank + self._device = torch_utils.get_devices()[0] + + def get_rank(self, actor: ray.actor.ActorHandle) -> int: + actor_ids = [a._ray_actor_id for a in self._actor_handles] + try: + rank = actor_ids.index(actor._ray_actor_id) + except ValueError: + raise ValueError("Actor is not in the NCCL group.") + return rank + + def get_world_size(self) -> int: + return self._world_size + + def get_self_rank(self) -> Optional[int]: + return self._rank + + def get_actor_handles(self) -> List["ray.actor.ActorHandle"]: + return self._actor_handles + + def send(self, value: "torch.Tensor", peer_rank: int) -> None: + return None + + def recv( + self, + shape: Tuple[int], + dtype: "torch.dtype", + peer_rank: int, + allocator: Optional[TorchTensorAllocator] = None, + ) -> "torch.Tensor": + return None + + def destroy(self) -> None: + pass + + nccl_group = MockNcclGroup(2, [actor1, actor2]) + + # Mixed usage of NCCL groups should throw an error + # Case 1: custom NCCL group first, then default NCCL group + with InputNode() as inp: + dag = actor1.send_with_tuple_args.bind(inp) + dag = dag.with_type_hint(TorchTensorType(transport=nccl_group)) + dag = actor2.recv.bind(dag) + dag = actor2.send_with_tuple_args.bind(dag) + dag = dag.with_type_hint(TorchTensorType(transport="nccl")) + dag = actor1.recv.bind(dag) + with pytest.raises( + ValueError, + match=r"Accelerated DAGs do not support mixed usage of type hints.*", + ): + dag.experimental_compile() + + # Case 2: default NCCL group first, then custom NCCL group + with InputNode() as inp: + dag = actor1.send_with_tuple_args.bind(inp) + dag = dag.with_type_hint(TorchTensorType(transport="nccl")) + dag = actor2.recv.bind(dag) + dag = actor2.send_with_tuple_args.bind(dag) + dag = dag.with_type_hint(TorchTensorType(transport=nccl_group)) + dag = actor1.recv.bind(dag) + with pytest.raises( + ValueError, + match=r"Accelerated DAGs do not support mixed usage of type hints.*", + ): + dag.experimental_compile() + + nccl_group2 = MockNcclGroup(2, [actor1, actor2]) + + # Using two different custom NCCL groups are currently not supported + with InputNode() as inp: + dag = actor1.send_with_tuple_args.bind(inp) + dag = dag.with_type_hint(TorchTensorType(transport=nccl_group)) + dag = actor2.recv.bind(dag) + dag = actor2.send_with_tuple_args.bind(dag) + dag = dag.with_type_hint(TorchTensorType(transport=nccl_group2)) + dag = actor1.recv.bind(dag) + with pytest.raises( + ValueError, + match=( + "Accelerated DAGs currently only support " + "a single custom NCCL group, but multiple " + "have been specified." 
+ ), + ): + dag.experimental_compile() + + +@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) +def test_torch_tensor_custom_comm_inited(ray_start_regular): + if not USE_GPU: + pytest.skip("NCCL tests require GPUs") + + assert ( + sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) > 1 + ), "This test requires at least 2 GPUs" + runtime_env = { + "env_vars": { + "MASTER_ADDR": socket.gethostbyname(socket.gethostname()), + "MASTER_PORT": "8888", + } + } + actor_cls = TorchTensorWorker.options( + num_cpus=0, num_gpus=1, runtime_env=runtime_env + ) + + sender = actor_cls.remote() + receiver = actor_cls.remote() + + # Simulates that the distributed environment (e.g., torch.distributed) + # have already been set up + refs = [ + sender.init_distributed.remote(2, 0), + receiver.init_distributed.remote(2, 1), + ] + ray.wait(refs) + + class InitedNcclGroup(GPUCommunicator): + """ + A custom NCCL group based on existing torch.distributed setup. + """ + + def __init__(self, world_size, actor_handles): + self._world_size = world_size + self._actor_handles = actor_handles + self._rank = None + + def initialize(self, rank: int) -> None: + expected_rank = self.get_rank(ray.get_runtime_context().current_actor) + assert ( + rank == expected_rank + ), f"NCCL actor's rank {rank} does not match expected rank {expected_rank}" + self._rank = rank + self._device = torch_utils.get_devices()[0] + + def get_rank(self, actor: ray.actor.ActorHandle) -> int: + actor_ids = [a._ray_actor_id for a in self._actor_handles] + try: + rank = actor_ids.index(actor._ray_actor_id) + except ValueError: + raise ValueError("Actor is not in the NCCL group.") + return rank + + def get_world_size(self) -> int: + return self._world_size + + def get_self_rank(self) -> Optional[int]: + return self._rank + + def get_actor_handles(self) -> List["ray.actor.ActorHandle"]: + return self._actor_handles + + def send(self, value: "torch.Tensor", peer_rank: int) -> None: + torch.distributed.send(value, peer_rank) + + def recv( + self, + shape: Tuple[int], + dtype: "torch.dtype", + peer_rank: int, + allocator: Optional[TorchTensorAllocator] = None, + ) -> "torch.Tensor": + tensor = torch.empty(torch.Size(shape), dtype=dtype, device=self._device) + torch.distributed.recv(tensor, peer_rank) + return tensor + + def destroy(self) -> None: + pass + + nccl_group = InitedNcclGroup(2, [sender, receiver]) + with InputNode() as inp: + dag = sender.send_with_tuple_args.bind(inp) + dag = dag.with_type_hint(TorchTensorType(transport=nccl_group)) + dag = receiver.recv.bind(dag) + + compiled_dag = dag.experimental_compile() + for i in range(3): + i += 1 + shape = (i * 10,) + dtype = torch.float16 + args = (shape, dtype, i) + ref = compiled_dag.execute(args) + result = ray.get(ref) + assert result == (i, shape, dtype) + + compiled_dag.teardown() + + @pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) def test_torch_tensor_nccl_wrong_shape(ray_start_regular): if not USE_GPU: diff --git a/python/ray/experimental/channel/__init__.py b/python/ray/experimental/channel/__init__.py index bcca146bd8f6..03e3be2d59e1 100644 --- a/python/ray/experimental/channel/__init__.py +++ b/python/ray/experimental/channel/__init__.py @@ -10,6 +10,7 @@ SynchronousWriter, WriterInterface, ) +from ray.experimental.channel.gpu_communicator import GPUCommunicator from ray.experimental.channel.intra_process_channel import IntraProcessChannel from ray.experimental.channel.shared_memory_channel import Channel, CompositeChannel 
from ray.experimental.channel.torch_tensor_nccl_channel import TorchTensorNcclChannel @@ -19,6 +20,7 @@ "AwaitableBackgroundWriter", "CachedChannel", "Channel", + "GPUCommunicator", "ReaderInterface", "SynchronousReader", "SynchronousWriter", diff --git a/python/ray/experimental/channel/common.py b/python/ray/experimental/channel/common.py index f400788f3398..928d4e2a339d 100644 --- a/python/ray/experimental/channel/common.py +++ b/python/ray/experimental/channel/common.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import ray -from ray.experimental.channel.nccl_group import _NcclGroup +from ray.experimental.channel.gpu_communicator import GPUCommunicator from ray.experimental.channel.serialization_context import _SerializationContext from ray.util.annotations import DeveloperAPI, PublicAPI @@ -100,6 +100,14 @@ def requires_nccl(self) -> bool: # By default, channels do not require NCCL. return False + def get_custom_nccl_group(self) -> Optional[GPUCommunicator]: + """ + Return the custom NCCL group if one is specified. + """ + if self._contains_type is not None: + return self._contains_type.get_custom_nccl_group() + return None + def set_nccl_group_id(self, group_id: str) -> None: raise NotImplementedError @@ -112,7 +120,7 @@ class ChannelContext: def __init__(self): # Used for the torch.Tensor NCCL transport. - self.nccl_groups: Dict[str, "_NcclGroup"] = {} + self.nccl_groups: Dict[str, "GPUCommunicator"] = {} @staticmethod def get_current() -> "ChannelContext": diff --git a/python/ray/experimental/channel/conftest.py b/python/ray/experimental/channel/conftest.py index 6f53f95c90b3..8886a2cfecd0 100644 --- a/python/ray/experimental/channel/conftest.py +++ b/python/ray/experimental/channel/conftest.py @@ -1,11 +1,13 @@ import asyncio from collections import defaultdict +from typing import Optional, Tuple from unittest import mock import torch import ray import ray.experimental.channel as ray_channel +from ray.experimental.channel.gpu_communicator import TorchTensorAllocator @ray.remote(num_cpus=0) @@ -74,13 +76,24 @@ def send(self, tensor: torch.Tensor, peer_rank: int): ray.get(barrier.wait.remote(self.num_ops[barrier_key], tensor)) self.num_ops[barrier_key] += 1 - def recv(self, buf: torch.Tensor, peer_rank: int): + def recv( + self, + shape: Tuple[int], + dtype: torch.dtype, + peer_rank: int, + allocator: Optional[TorchTensorAllocator] = None, + ): # "Receive" the tensor from the barrier actor. barrier_key = f"barrier-{peer_rank}-{self.get_self_rank()}" barrier = ray.get_actor(name=barrier_key) received_tensor = ray.get(barrier.wait.remote(self.num_ops[barrier_key])) + assert ( + allocator is not None + ), "torch tensor allocator is required for MockNcclGroup" + buf = allocator(shape, dtype) buf[:] = received_tensor[:] self.num_ops[barrier_key] += 1 + return buf def start_nccl_mock(): diff --git a/python/ray/experimental/channel/gpu_communicator.py b/python/ray/experimental/channel/gpu_communicator.py new file mode 100644 index 000000000000..e6bc2fccdb2d --- /dev/null +++ b/python/ray/experimental/channel/gpu_communicator.py @@ -0,0 +1,115 @@ +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Callable, List, Optional, Tuple + +import ray +from ray.util.annotations import DeveloperAPI + +if TYPE_CHECKING: + import torch + + +# Signature for a torch.Tensor allocator is: +# (shape: Tuple[int], dtype: torch.dtype) -> torch.Tensor. 
+TorchTensorAllocator = Callable[[Tuple[int], "torch.dtype"], "torch.Tensor"] + + +@DeveloperAPI +class GPUCommunicator(ABC): + """ + Communicator for a group of aDAG actors on NVIDIA GPUs. + + The aDAG execution leverages this internally to support communication + between actors in the group. + """ + + @abstractmethod + def initialize(self, rank: int) -> None: + """ + Initialize the communicator from the actor. + + This is called once by aDAG on each actor to initialize the communicator, + before any other methods. + + Args: + rank: The rank of this actor in the group. + """ + raise NotImplementedError + + @abstractmethod + def get_actor_handles(self) -> List["ray.actor.ActorHandle"]: + """ + Get handles of all actors for this communicator group. + """ + raise NotImplementedError + + @abstractmethod + def get_rank(self, actor: ray.actor.ActorHandle) -> int: + """ + Return the given actor's rank in the group. + + Args: + actor: The actor handle to look up. + """ + raise NotImplementedError + + @abstractmethod + def get_self_rank(self) -> Optional[int]: + """ + Return this actor's rank. + """ + raise NotImplementedError + + def get_world_size(self) -> int: + """ + Return the number of ranks in the group. + """ + raise NotImplementedError + + @abstractmethod + def send(self, value: "torch.Tensor", peer_rank: int) -> None: + """ + Send a torch.Tensor to a peer. + + This returns when the send kernel has been queued, but the kernel may + not have completed. Therefore, the caller should ensure that there are + no concurrent writes to the sent `value` until the send has finished. + + Args: + value: The torch.Tensor to send. It should already be on this + actor's default device. + peer_rank: The rank of the actor to send to. + """ + raise NotImplementedError + + @abstractmethod + def recv( + self, + shape: Tuple[int], + dtype: "torch.dtype", + peer_rank: int, + allocator: Optional[TorchTensorAllocator] = None, + ) -> "torch.Tensor": + """ + Receive a torch.Tensor from a peer and synchronize. + + After this call returns, the receive buffer is safe to read from + any stream. A RayChannelError will be raised if an error occurred (e.g., + the remote actor died), in which case the buffer is not safe to read. + + Args: + shape: The shape of the tensor to receive. + dtype: The dtype of the tensor to receive. + peer_rank: The rank of the actor to receive from. + allocator: A function to allocate the tensor to receive into. + """ + raise NotImplementedError + + @abstractmethod + def destroy(self) -> None: + """ + Destroy the GPU communicator. + + Any destruction and cleanup for the GPU communicator should be + done here. Implement as a noop if nothing is needed. + """ + raise NotImplementedError diff --git a/python/ray/experimental/channel/nccl_group.py b/python/ray/experimental/channel/nccl_group.py index 84200cf87f0a..753c05ed1d74 100644 --- a/python/ray/experimental/channel/nccl_group.py +++ b/python/ray/experimental/channel/nccl_group.py @@ -1,9 +1,13 @@ import logging from types import ModuleType -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List, Optional, Tuple import ray from ray.exceptions import RayChannelError +from ray.experimental.channel.gpu_communicator import ( + GPUCommunicator, + TorchTensorAllocator, +) if TYPE_CHECKING: import cupy as cp @@ -16,9 +20,10 @@ logger = logging.getLogger(__name__) -class _NcclGroup: +class _NcclGroup(GPUCommunicator): """ - Represents an actor's NCCL communicator. + Represents an actor's NCCL communicator.
This is the default NCCL communicator + to be used in aDAG if a custom communicator is not provided. This class is not thread-safe. """ @@ -62,6 +67,7 @@ def __init__( cuda_stream: A raw CUDA stream to dispatch NCCL ops to. If rank is specified, then this must be specified too. """ + self._world_size = world_size self._rank: Optional[int] = rank self.nccl_util: Optional[ModuleType] = None self._actor_handles = actor_handles @@ -100,7 +106,11 @@ def __init__( self._closed = False - def _get_actor_handles(self) -> List["ray.actor.ActorHandle"]: + def initialize(self, rank: int) -> None: + # No additional initialization is needed. + pass + + def get_actor_handles(self) -> List["ray.actor.ActorHandle"]: return self._actor_handles def get_rank(self, actor: ray.actor.ActorHandle) -> int: @@ -123,7 +133,13 @@ def get_self_rank(self) -> Optional[int]: """ return self._rank - def send(self, value: "torch.Tensor", peer_rank: int): + def get_world_size(self) -> int: + """ + Return the number of ranks in the NCCL communicator. + """ + return self._world_size + + def send(self, value: "torch.Tensor", peer_rank: int) -> None: """ Send a torch.Tensor to a peer. @@ -151,7 +167,13 @@ def send(self, value: "torch.Tensor", peer_rank: int): self._cuda_stream.ptr, ) - def recv(self, buf: "torch.Tensor", peer_rank: int): + def recv( + self, + shape: Tuple[int], + dtype: "torch.dtype", + peer_rank: int, + allocator: Optional[TorchTensorAllocator] = None, + ) -> "torch.Tensor": """ Receive a torch.Tensor from a peer and synchronize the current stream. @@ -165,6 +187,8 @@ def recv(self, buf: "torch.Tensor", peer_rank: int): """ if self._closed: raise RayChannelError("NCCL group has been destroyed.") + assert allocator is not None, "NCCL group requires a tensor allocator" + buf = allocator(shape, dtype) self._comm.recv( self.nccl_util.get_tensor_ptr(buf), buf.numel(), @@ -180,8 +204,9 @@ def recv(self, buf: "torch.Tensor", peer_rank: int): self._cuda_stream.synchronize() if self._closed: raise RayChannelError("NCCL group has been destroyed.") + return buf - def destroy(self): + def destroy(self) -> None: """ Destroy the NCCL group. """ diff --git a/python/ray/experimental/channel/torch_tensor_nccl_channel.py b/python/ray/experimental/channel/torch_tensor_nccl_channel.py index b4de7b1b862c..ad95c69c7c84 100644 --- a/python/ray/experimental/channel/torch_tensor_nccl_channel.py +++ b/python/ray/experimental/channel/torch_tensor_nccl_channel.py @@ -2,12 +2,16 @@ import logging import uuid from types import ModuleType -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union import ray import ray.util.serialization from ray.experimental.channel import ChannelContext from ray.experimental.channel.common import ChannelInterface +from ray.experimental.channel.gpu_communicator import ( + GPUCommunicator, + TorchTensorAllocator, +) from ray.experimental.channel.nccl_group import _NcclGroup from ray.experimental.channel.shared_memory_channel import SharedMemoryType from ray.experimental.channel.torch_tensor_type import TENSOR_METADATA_SIZE_BYTES @@ -26,11 +30,6 @@ logger = logging.getLogger(__name__) -# Signature for a torch.Tensor allocator is: -# (shape: Tuple[int], dtype: torch.dtype) -> torch.Tensor.
-TorchTensorAllocator = Callable[[Tuple[int], "torch.dtype"], "torch.Tensor"] - - class NestedTorchTensorNcclChannel(ChannelInterface): def __init__( self, @@ -236,7 +235,7 @@ def __init__( ctx = ChannelContext.get_current() assert self._typ.nccl_group_id is not None, "No NCCL group specified." self._nccl_group_id: str = self._typ.nccl_group_id - self._nccl_group: "_NcclGroup" = ctx.nccl_groups[self._typ.nccl_group_id] + self._nccl_group: "GPUCommunicator" = ctx.nccl_groups[self._typ.nccl_group_id] assert ( self._nccl_group is not None ), "ChannelContext.nccl_group is not initialized." @@ -379,11 +378,6 @@ def write( for rank in self._reader_ranks: self._nccl_group.send(tensor, rank) - def _read_single_tensor(self, typ: "TorchTensorType") -> "torch.Tensor": - buf = self._torch_tensor_allocator(typ._shape, typ._dtype) - self._nccl_group.recv(buf, self._writer_rank) - return buf - def read( self, timeout: Optional[float] = None ) -> Union["torch.Tensor", List["torch.Tensor"]]: @@ -393,11 +387,22 @@ def read( meta = self._typ if not isinstance(meta, list): - return self._read_single_tensor(meta) + return self._nccl_group.recv( + meta._shape, + meta._dtype, + self._writer_rank, + self._torch_tensor_allocator, + ) bufs: List["torch.Tensor"] = [] for typ in meta: - bufs.append(self._read_single_tensor(typ)) + buf = self._nccl_group.recv( + typ._shape, + typ._dtype, + self._writer_rank, + self._torch_tensor_allocator, + ) + bufs.append(buf) # TODO: Sync CUDA stream after receiving all tensors, instead of after # each tensor. return bufs @@ -420,7 +425,15 @@ def has_static_type(self) -> bool: ) -def _do_init_nccl_group(self, group_id, world_size, comm_id, rank, actor_handles): +def _do_init_nccl_group( + self, + group_id, + world_size, + comm_id, + rank, + actor_handles, + custom_nccl_group: Optional[GPUCommunicator] = None, +): import torch assert ( @@ -428,13 +441,17 @@ def _do_init_nccl_group(self, group_id, world_size, comm_id, rank, actor_handles ), "Actors participating in NCCL group must have at least one GPU assigned" ctx = ChannelContext.get_current() - ctx.nccl_groups[group_id] = _NcclGroup( - world_size, - comm_id, - rank, - actor_handles, - torch.cuda.current_stream().cuda_stream, - ) + if custom_nccl_group is not None: + custom_nccl_group.initialize(rank) + ctx.nccl_groups[group_id] = custom_nccl_group + else: + ctx.nccl_groups[group_id] = _NcclGroup( + world_size, + comm_id, + rank, + actor_handles, + torch.cuda.current_stream().cuda_stream, + ) def _do_destroy_nccl_group(self, group_id): @@ -456,9 +473,50 @@ def _do_get_unique_nccl_id(self) -> bool: return nccl.get_unique_id() +def _get_ranks( + actors: List[ray.actor.ActorHandle], custom_nccl_group: Optional[GPUCommunicator] +) -> List[int]: + """ + Get sorted ranks for the NCCL group to use. If custom_nccl_group is specified, + return all ranks from it, otherwise, return list(range(len(actors))). + + Args: + actors: A list of actors that participate in the NCCL group. + custom_nccl_group: The custom NCCL group to use. + """ + if custom_nccl_group is None: + return list(range(len(actors))) + + assert len(actors) == custom_nccl_group.get_world_size(), ( + "The world size of the custom NCCL group does not match the number " + "of actors." 
+ ) + ranks = set() + for actor in actors: + rank = custom_nccl_group.get_rank(actor) + assert rank not in ranks, "Duplicate rank in custom NCCL group" + ranks.add(rank) + assert custom_nccl_group.get_world_size() == len(actors), ( + "The world size of the custom NCCL group " + f"({custom_nccl_group.get_world_size()}) " + "does not match the number of actors " + f"({len(actors)})." + ) + return sorted(ranks) + + def _init_nccl_group( actors: List[ray.actor.ActorHandle], + custom_nccl_group: Optional[GPUCommunicator] = None, ) -> str: + """ + Initialize a NCCL group with the given actors. If a custom NCCL group is + provided, it will be used; otherwise, a new NCCL group will be created. + + Args: + actors: A list of actors that participate in the NCCL group. + custom_nccl_group: A custom NCCL group to initialize. + """ ctx = ChannelContext.get_current() has_gpus = ray.get( @@ -468,8 +526,9 @@ def _init_nccl_group( if not has_gpu: raise ValueError( f"Actor {actor} returns a tensor with type hint " - 'TorchTensor(transport="nccl") but actor does not have a ' - "GPU assigned by Ray." + 'TorchTensor(transport="nccl") or ' + "TorchTensor(transport=nccl_group_handle) " + "but actor does not have a GPU assigned by Ray." ) actor_ids = {actor._ray_actor_id for actor in actors} @@ -482,9 +541,13 @@ def _init_nccl_group( # Used to uniquely identify this NCCL group. group_id = str(uuid.uuid4()) - logger.info(f"Creating NCCL group {group_id} on actors: {actors}") + if custom_nccl_group is not None: + logger.info(f"Initializing custom NCCL group {group_id} on actors: {actors}") + else: + logger.info(f"Creating NCCL group {group_id} on actors: {actors}") world_size = len(actors) + ranks = _get_ranks(actors, custom_nccl_group) init_tasks = [ actor.__ray_call__.remote( _do_init_nccl_group, @@ -493,8 +556,9 @@ nccl_comm_id, rank, actors, + custom_nccl_group, ) - for rank, actor in enumerate(actors) + for rank, actor in zip(ranks, actors) ] try: ray.get(init_tasks, timeout=30) @@ -504,25 +568,31 @@ ) ray.get(init_tasks) - logger.info("NCCL group created.") + logger.info("NCCL group initialized.") - ctx.nccl_groups[group_id] = _NcclGroup( - world_size, - nccl_comm_id, - rank=None, - actor_handles=actors, - cuda_stream=None, - ) + if custom_nccl_group is not None: + ctx.nccl_groups[group_id] = custom_nccl_group + else: + ctx.nccl_groups[group_id] = _NcclGroup( + world_size, + nccl_comm_id, + rank=None, + actor_handles=actors, + cuda_stream=None, + ) return group_id def _destroy_nccl_group(group_id: str) -> None: + """ + Destroy the NCCL group with the given ID.
+ """ ctx = ChannelContext.get_current() if group_id not in ctx.nccl_groups: return group = ctx.nccl_groups[group_id] - actors = group._get_actor_handles() + actors = group.get_actor_handles() destroy_tasks = [ actor.__ray_call__.remote( _do_destroy_nccl_group, diff --git a/python/ray/experimental/channel/torch_tensor_type.py b/python/ray/experimental/channel/torch_tensor_type.py index f573d4bc9931..c37977728b43 100644 --- a/python/ray/experimental/channel/torch_tensor_type.py +++ b/python/ray/experimental/channel/torch_tensor_type.py @@ -3,13 +3,15 @@ import ray from ray.experimental.channel import ChannelContext, ChannelOutputType +from ray.experimental.channel.gpu_communicator import ( + GPUCommunicator, + TorchTensorAllocator, +) from ray.util.annotations import PublicAPI if TYPE_CHECKING: import torch - from ray.experimental.channel.torch_tensor_nccl_channel import TorchTensorAllocator - logger = logging.getLogger(__name__) # 100KB to store metadata and/or exceptions. @@ -28,7 +30,7 @@ def __init__( self, _shape: Union[int, Tuple[int], str] = AUTO, _dtype: "torch.dtype" = AUTO, - transport: Optional[str] = AUTO, + transport: Optional[Union[str, GPUCommunicator]] = AUTO, _direct_return: Optional[bool] = False, ): """ @@ -73,6 +75,11 @@ def __init__( self._dtype = _dtype self._direct_return = _direct_return + self._custom_nccl_group: Optional[GPUCommunicator] = None + if isinstance(transport, GPUCommunicator): + self._custom_nccl_group = transport + transport = self.NCCL + if transport not in [self.AUTO, self.NCCL]: raise ValueError( "`transport` must be TorchTensorType.AUTO or TorchTensorType.NCCL" @@ -170,6 +177,12 @@ def create_channel( def requires_nccl(self) -> bool: return self.transport == self.NCCL + def get_custom_nccl_group(self) -> Optional[GPUCommunicator]: + """ + Return the custom NCCL group if one is specified. + """ + return self._custom_nccl_group + def set_nccl_group_id(self, group_id: str) -> None: self._nccl_group_id = group_id