diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py
index b6f461b76ed03..b3e30a0434423 100644
--- a/tests/distributed/test_pynccl.py
+++ b/tests/distributed/test_pynccl.py
@@ -1,15 +1,15 @@
 import multiprocessing
+import os
 
 import pytest
 import torch
 
-import vllm.distributed.device_communicators.pynccl_utils as pynccl_utils
-from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
-from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
-                                                          ncclGetUniqueId)
-from vllm.distributed.parallel_state import (
-    ensure_model_parallel_initialized, get_tensor_model_parallel_cpu_group,
-    init_distributed_environment, with_pynccl_for_all_reduce)
+from vllm.distributed.communication_op import (  # noqa
+    graph_capture_mode, tensor_model_parallel_all_reduce)
+from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
+from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                             init_distributed_environment)
 from vllm.utils import update_environment_variables
@@ -41,6 +41,9 @@ def worker_fn_wrapper(fn):
     # and update the environment variables in the function
     def wrapped_fn(env):
         update_environment_variables(env)
+        local_rank = os.environ['LOCAL_RANK']
+        device = torch.device(f"cuda:{local_rank}")
+        torch.cuda.set_device(device)
         init_distributed_environment()
         fn()
 
@@ -49,11 +52,13 @@ def wrapped_fn(env):
 
 @worker_fn_wrapper
 def worker_fn():
-    comm = NCCLCommunicator()
-    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
-    comm.all_reduce(tensor)
+    pynccl_comm = PyNcclCommunicator()
+    tensor = torch.ones(16, 1024, 1024,
+                        dtype=torch.float32).cuda(pynccl_comm.rank)
+    with pynccl_comm.change_state(enable=True):
+        pynccl_comm.all_reduce(tensor)
     result = tensor.mean().cpu().item()
-    assert result == comm.world_size
+    assert result == pynccl_comm.world_size
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
@@ -70,37 +75,35 @@ def multiple_tp_worker_fn():
         torch.distributed.new_group(ranks=[2, 3], backend="gloo")
     ]
     group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
-    comm = NCCLCommunicator(group=group, device=device)
+    pynccl_comm = PyNcclCommunicator(group=group, device=device)
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    # two groups can communicate independently
-    if torch.distributed.get_rank() in [0, 1]:
-        comm.all_reduce(tensor)
-        comm.all_reduce(tensor)
-        result = tensor.mean().cpu().item()
-        assert result == 4
-    else:
-        comm.all_reduce(tensor)
-        result = tensor.mean().cpu().item()
-        assert result == 2
+    with pynccl_comm.change_state(enable=True):
+        # two groups can communicate independently
+        if torch.distributed.get_rank() in [0, 1]:
+            pynccl_comm.all_reduce(tensor)
+            pynccl_comm.all_reduce(tensor)
+            result = tensor.mean().cpu().item()
+            assert result == 4
+        else:
+            pynccl_comm.all_reduce(tensor)
+            result = tensor.mean().cpu().item()
+            assert result == 2
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 4,
                     reason="Need at least 4 GPUs to run the test.")
 def test_pynccl_multiple_tp():
     # this tests pynccl for multiple tp groups, in a standalone way
-    # i.e. call `comm.all_reduce` directly
+    # i.e. call `pynccl_comm.all_reduce` directly
    distributed_run(multiple_tp_worker_fn, 4)
 
 
 @worker_fn_wrapper
 def multiple_tp_with_vllm_worker_fn():
     device = torch.device(f"cuda:{torch.distributed.get_rank()}")
-    torch.cuda.set_device(torch.distributed.get_rank())
     ensure_model_parallel_initialized(2, 2)
-    pynccl_utils.init_process_group(
-        group=get_tensor_model_parallel_cpu_group())
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    with with_pynccl_for_all_reduce():
+    with graph_capture_mode():
         # two tp groups can communicate independently
         if torch.distributed.get_rank() in [0, 1]:
             tensor = tensor_model_parallel_all_reduce(tensor)
@@ -125,19 +128,21 @@ def test_pynccl_multiple_tp_with_vllm():
 def worker_fn_with_cudagraph():
     with torch.no_grad():
         graph = torch.cuda.CUDAGraph()
-        comm = NCCLCommunicator()
+        pynccl_comm = PyNcclCommunicator()
         # run something in the default stream to initialize torch engine
-        a = torch.ones((4, 4), device=f'cuda:{comm.rank}')
+        a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
         torch.cuda.synchronize()
-        with torch.cuda.graph(graph, stream=comm.stream):
+        with torch.cuda.graph(
+                graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
+                    enable=True):
             # operation during the graph capture is recorded but not executed
             # see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa
-            comm.all_reduce(a)
-            comm.stream.synchronize()
-            assert a.mean().cpu().item() == comm.world_size**0
+            pynccl_comm.all_reduce(a)
+            pynccl_comm.stream.synchronize()
+            assert a.mean().cpu().item() == pynccl_comm.world_size**0
         graph.replay()
-        comm.stream.synchronize()
-        assert a.mean().cpu().item() == comm.world_size**1
+        pynccl_comm.stream.synchronize()
+        assert a.mean().cpu().item() == pynccl_comm.world_size**1
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
@@ -147,7 +152,8 @@ def test_pynccl_with_cudagraph():
 
 
 def test_ncclGetUniqueId():
-    unique_id = ncclGetUniqueId()
+    lib = NCCLLibrary()
+    unique_id = lib.ncclGetUniqueId()
     # `list(unique_id.internal)` is something like this:
     # [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0,
     #  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py
index 80d03129bdb9b..32ab5694e5390 100644
--- a/vllm/distributed/communication_op.py
+++ b/vllm/distributed/communication_op.py
@@ -1,4 +1,5 @@
 from collections import namedtuple
+from contextlib import contextmanager
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -8,7 +9,26 @@
                              get_tensor_model_parallel_group,
                              get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
-                             is_pynccl_enabled_for_all_reduce)
+                             get_tp_pynccl_communicator)
+
+
+@contextmanager
+def graph_capture_mode():
+    # In graph capture, we have to be very careful about the collective
+    # operations. The current status is:
+    #  allreduce \ Mode   |  Eager  |  Graph  |
+    # --------------------------------------------
+    #  custom allreduce   | enabled | enabled |
+    #  PyNccl             | disabled| enabled |
+    #  torch.distributed  | enabled | disabled|
+    #
+    # Note that custom allreduce has a runtime check: if the tensor size
+    # is too large, it will fall back to the next available option.
+    pynccl_comm = get_tp_pynccl_communicator()
+    assert pynccl_comm is not None
+    with pynccl_comm.change_state(enable=True,
+                                  stream=torch.cuda.current_stream()):
+        yield
 
 
 def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
@@ -23,7 +43,6 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
     TLDR: always assume this function modifies its input, but use the
     return value as the output.
     """
-    from vllm.distributed.device_communicators import pynccl_utils
     from vllm.distributed.device_communicators.custom_all_reduce import (
         custom_all_reduce)
 
@@ -33,8 +52,9 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
     out = custom_all_reduce(input_)
     if out is not None:
         return out
-    if is_pynccl_enabled_for_all_reduce():
-        pynccl_utils.all_reduce(input_)
+    pynccl_comm = get_tp_pynccl_communicator()
+    if (pynccl_comm is not None and not pynccl_comm.disabled):
+        pynccl_comm.all_reduce(input_)
     else:
         torch.distributed.all_reduce(input_,
                                      group=get_tensor_model_parallel_group())
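
Note on the new `graph_capture_mode` context manager above: a minimal usage sketch follows, mirroring how `capture` in `model_runner.py` (later in this patch) wraps graph capture. It assumes tensor parallelism has already been initialized with more than one rank, since `graph_capture_mode` asserts that the TP PyNccl communicator exists; the `model` call and its arguments are hypothetical placeholders, not part of the patch.

    import torch
    from vllm.distributed.communication_op import graph_capture_mode

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        with graph_capture_mode():
            # tensor_model_parallel_all_reduce calls issued here are routed
            # through the TP PyNcclCommunicator on the capturing stream
            hidden_states = model(input_ids, positions, kv_caches, attn_metadata)
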
diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py
index 758994352e3de..168d4cc2df8a6 100644
--- a/vllm/distributed/device_communicators/pynccl.py
+++ b/vllm/distributed/device_communicators/pynccl.py
@@ -1,26 +1,4 @@
-# This file is a pure Python wrapper for the NCCL library.
-# The main purpose is to use NCCL combined with CUDA graph.
-# Before writing this script, we tried the following approach:
-# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
-#  often gets stuck when initializing the NCCL communicator.
-# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
-#  contains many other potential cuda APIs, that are not allowed during
-#  capturing the CUDA graph. For further details, please check
-# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
-#
-# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
-# doable, but we often encounter issues related with nccl versions, and need
-# to switch between different versions of NCCL. See
-# https://github.com/NVIDIA/nccl/issues/1234 for more details.
-# A C/C++ binding is not flexible enough to handle this. It requires
-# recompilation of the code every time we want to switch between different
-# versions. This current implementation, with a **pure** Python wrapper, is
-# more flexible. We can easily switch between different versions of NCCL by
-# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
-# variable in the code.
-
-import ctypes
-import platform
+from contextlib import contextmanager
 from typing import Optional, Union
 
 # ===================== import region =====================
@@ -28,217 +6,70 @@
 import torch.distributed as dist
 from torch.distributed import ProcessGroup, ReduceOp
 
+from vllm.distributed.device_communicators.pynccl_wrapper import (
+    NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
+    ncclRedOpTypeEnum, ncclUniqueId)
 from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
 from vllm.logger import init_logger
-from vllm.utils import find_nccl_library, nccl_integrity_check
 
 logger = init_logger(__name__)
 
-so_file = find_nccl_library()
-
-try:
-    # load the library in another process.
-    # if it core dumps, it will not crash the current process
-    nccl_integrity_check(so_file)
-    nccl = ctypes.CDLL(so_file)
-except Exception as e:
-    logger.error(
-        "Failed to load NCCL library from %s ."
-        "It is expected if you are not running on NVIDIA/AMD GPUs."
-        "Otherwise, the nccl library might not exist, be corrupted "
-        "or it does not support the current platform %s."
-        "One solution is to download libnccl2 version 2.18 from "
-        "https://developer.download.nvidia.com/compute/cuda/repos/ "
-        "and extract the libnccl.so.2 file. If you already have the "
-        "library, please set the environment variable VLLM_NCCL_SO_PATH"
-        " to point to the correct nccl library path.", so_file,
-        platform.platform())
-    raise e
-
-# === export types and functions from nccl to Python ===
-# for the original nccl definition, please check
-# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
-
-ncclResult_t = ctypes.c_int
-
-_c_ncclGetErrorString = nccl.ncclGetErrorString
-_c_ncclGetErrorString.restype = ctypes.c_char_p
-_c_ncclGetErrorString.argtypes = [ncclResult_t]
-
-
-def NCCL_CHECK(result: ncclResult_t) -> None:
-    if result != 0:
-        error_str = _c_ncclGetErrorString(result)
-        error_str = error_str.decode("utf-8")
-        raise RuntimeError(f"NCCL error: {error_str}")
-
-
-# equivalent to c declaration:
-# ncclResult_t ncclGetVersion(int *version);
-_c_ncclGetVersion = nccl.ncclGetVersion
-_c_ncclGetVersion.restype = ctypes.c_int
-_c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
-
-
-def ncclGetVersion() -> str:
-    version = ctypes.c_int()
-    NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version)))
-    # something like 21903 --> "2.19.3"
-    version_str = str(version.value)
-    major = version_str[0].lstrip("0")
-    minor = version_str[1:3].lstrip("0")
-    patch = version_str[3:].lstrip("0")
-    return f"{major}.{minor}.{patch}"
-
-
-class NcclUniqueId(ctypes.Structure):
-    _fields_ = [("internal", ctypes.c_byte * 128)]
-
-
-# equivalent to c declaration:
-# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
-_c_ncclGetUniqueId = nccl.ncclGetUniqueId
-_c_ncclGetUniqueId.restype = ctypes.c_int
-_c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)]
-
-
-def ncclGetUniqueId() -> NcclUniqueId:
-    unique_id = NcclUniqueId()
-    NCCL_CHECK(_c_ncclGetUniqueId(ctypes.byref(unique_id)))
-    return unique_id
-
-
-# equivalent to c declaration:
-# ncclResult_t ncclCommInitRank(
-#   ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
-# note that ncclComm_t is a pointer type, so the first argument
-# is a pointer to a pointer
-_c_ncclCommInitRank = nccl.ncclCommInitRank
-_c_ncclCommInitRank.restype = ctypes.c_int
-_c_ncclCommInitRank.argtypes = [
-    ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int
-]
-
-ncclDataType_t = ctypes.c_int
-
-
-class ncclDataTypeEnum:
-    ncclInt8 = 0
-    ncclChar = 0
-    ncclUint8 = 1
-    ncclInt32 = 2
-    ncclInt = 2
-    ncclUint32 = 3
-    ncclInt64 = 4
-    ncclUint64 = 5
-    ncclFloat16 = 6
-    ncclHalf = 6
-    ncclFloat32 = 7
-    ncclFloat = 7
-    ncclFloat64 = 8
-    ncclDouble = 8
-    ncclBfloat16 = 9
-    ncclNumTypes = 10
-
-    @classmethod
-    def from_torch(cls, dtype: torch.dtype) -> int:
-        if dtype == torch.int8:
-            return cls.ncclInt8
-        if dtype == torch.uint8:
-            return cls.ncclUint8
-        if dtype == torch.int32:
-            return cls.ncclInt32
-        if dtype == torch.int64:
-            return cls.ncclInt64
-        if dtype == torch.float16:
-            return cls.ncclFloat16
-        if dtype == torch.float32:
-            return cls.ncclFloat32
-        if dtype == torch.float64:
-            return cls.ncclFloat64
-        if dtype == torch.bfloat16:
-            return cls.ncclBfloat16
-        raise ValueError(f"Unsupported dtype: {dtype}")
-
-
-ncclRedOp_t = ctypes.c_int
-
-
-class ncclRedOpTypeEnum:
-    ncclSum = 0
-    ncclProd = 1
-    ncclMax = 2
-    ncclMin = 3
-    ncclAvg = 4
-    ncclNumOps = 5
-
-    @classmethod
-    def from_torch(cls, op: ReduceOp) -> int:
-        if op == ReduceOp.SUM:
-            return cls.ncclSum
-        if op == ReduceOp.PRODUCT:
-            return cls.ncclProd
-        if op == ReduceOp.MAX:
-            return cls.ncclMax
-        if op == ReduceOp.MIN:
-            return cls.ncclMin
-        if op == ReduceOp.AVG:
-            return cls.ncclAvg
-        raise ValueError(f"Unsupported op: {op}")
-
-
-# equivalent to c declaration:
-# ncclResult_t  ncclAllReduce(
-#   const void* sendbuff, void* recvbuff, size_t count,
-#   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
-#   udaStream_t stream);
-# note that cudaStream_t is a pointer type, so the last argument is a pointer
-_c_ncclAllReduce = nccl.ncclAllReduce
-_c_ncclAllReduce.restype = ctypes.c_int
-_c_ncclAllReduce.argtypes = [
-    ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclRedOp_t,
-    ncclDataType_t, ctypes.c_void_p, ctypes.c_void_p
-]
-
-# be cautious! this is a collective call, it will block until all
-# processes in the communicator have called this function.
-# because Python object destruction can happen in random order,
-# it is better not to call it at all.
-# equivalent to c declaration:
-# ncclResult_t  ncclCommDestroy(ncclComm_t comm);
-_c_ncclCommDestroy = nccl.ncclCommDestroy
-_c_ncclCommDestroy.restype = ctypes.c_int
-_c_ncclCommDestroy.argtypes = [ctypes.c_void_p]
-
-
-class NCCLCommunicator:
+class PyNcclCommunicator:
 
     def __init__(
         self,
         group: Optional[ProcessGroup] = None,
         device: Optional[Union[int, str, torch.device]] = None,
+        library_path: Optional[str] = None,
     ):
         """
         Args:
             group: the process group to work on. If None, it will use the
                 default process group.
-            device: the device to bind the NCCLCommunicator to. If None,
+            device: the device to bind the PyNcclCommunicator to. If None,
                 it will be bind to f"cuda:{local_rank}".
+            library_path: the path to the NCCL library. If None, it will
+                use the default library path.
         It is the caller's responsibility to make sure each communicator
         is bind to a unique device.
         """
         assert dist.is_initialized()
         group = get_cpu_world_group() if group is None else group
         assert dist.get_backend(group) != dist.Backend.NCCL, (
-            "NCCLCommunicator should be attached to a non-NCCL group.")
+            "PyNcclCommunicator should be attached to a non-NCCL group.")
         self.group = group
         # note: this rank is the rank in the group
         self.rank = dist.get_rank(group)
         self.world_size = dist.get_world_size(group)
+
+        # if world_size == 1, no need to create communicator
+        if self.world_size == 1:
+            self.available = False
+            self.disabled = True
+            self.stream = None
+            return
+        try:
+            self.nccl = NCCLLibrary(library_path)
+        except Exception:
+            # disable because of missing NCCL library
+            # e.g. in a non-GPU environment
+            self.available = False
+            self.disabled = True
+            self.stream = None
+            return
+
+        self.available = True
+        self.disabled = False
+
+        logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())
+
         if self.rank == 0:
-            self.unique_id = ncclGetUniqueId()
+            # get the unique id from NCCL
+            self.unique_id = self.nccl.ncclGetUniqueId()
         else:
-            self.unique_id = NcclUniqueId()
+            # construct an empty unique id
+            self.unique_id = ncclUniqueId()
         tensor = torch.ByteTensor(list(self.unique_id.internal))
         ranks = dist.get_process_group_ranks(group)
         # arg `src` in `broadcast` is the global rank
@@ -246,7 +77,6 @@ def __init__(
         byte_list = tensor.tolist()
         for i, byte in enumerate(byte_list):
             self.unique_id.internal[i] = byte
-        self.comm = ctypes.c_void_p()
         if device is None:
             local_rank = get_local_rank()
             device = torch.device(f"cuda:{local_rank}")
@@ -261,15 +91,25 @@ def __init__(
         # `torch.cuda.device` is a context manager that changes the
         # current cuda device to the specified one
         with torch.cuda.device(device):
-            NCCL_CHECK(
-                _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
-                                    self.unique_id, self.rank))
+            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
+                self.world_size, self.unique_id, self.rank)
             self.stream = torch.cuda.Stream()
+
+            # A small all_reduce for warmup.
+            self.all_reduce(torch.zeros(1, device=device))
+            self.stream.synchronize()
+
+        # by default it is disabled, e.g. in profiling models and prefill phase.
+        # to use it, run the code under `with obj.change_state(enable=True)`,
+        # usually when we are using CUDA graph.
+        self.disabled = True
 
     def all_reduce(self,
                    tensor: torch.Tensor,
                    op: ReduceOp = ReduceOp.SUM,
                    stream=None):
+        if self.disabled:
+            return
         # nccl communicator created on a specific device
         # will only work on tensors on the same device
         # otherwise it will cause "illegal memory access"
@@ -278,10 +118,32 @@ def all_reduce(self,
                 f"but the input tensor is on {tensor.device}")
         if stream is None:
             stream = self.stream
-        NCCL_CHECK(
-            _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
-                             ctypes.c_void_p(tensor.data_ptr()),
-                             tensor.numel(),
-                             ncclDataTypeEnum.from_torch(tensor.dtype),
-                             ncclRedOpTypeEnum.from_torch(op), self.comm,
-                             ctypes.c_void_p(stream.cuda_stream)))
+        self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()),
+                                buffer_type(tensor.data_ptr()), tensor.numel(),
+                                ncclDataTypeEnum.from_torch(tensor.dtype),
+                                ncclRedOpTypeEnum.from_torch(op), self.comm,
+                                cudaStream_t(stream.cuda_stream))
+
+    @contextmanager
+    def change_state(self,
+                     enable: Optional[bool] = None,
+                     stream: Optional[torch.cuda.Stream] = None):
+        """
+        A context manager to change the state of the communicator.
+        """
+        if enable is None:
+            # guess a default value when not specified
+            enable = self.available
+
+        if stream is None:
+            stream = self.stream
+
+        old_disable = self.disabled
+        old_stream = self.stream
+
+        self.stream = stream
+        self.disabled = not enable
+        yield
+
+        self.disabled = old_disable
+        self.stream = old_stream
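
For reference, a minimal sketch of the new object-based API that replaces the module-level `pynccl_utils` helpers (removed in the next file). It follows `worker_fn` in the updated test and assumes `init_distributed_environment()` has already been called with a CUDA device set for this rank; tensor shapes are illustrative only.

    import torch
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator

    # attaches to the CPU (gloo) world group by default; the communicator
    # starts out disabled, so collectives are no-ops until explicitly enabled
    pynccl_comm = PyNcclCommunicator()
    tensor = torch.ones(16, 1024, 1024,
                        dtype=torch.float32).cuda(pynccl_comm.rank)
    with pynccl_comm.change_state(enable=True):
        pynccl_comm.all_reduce(tensor)  # runs on pynccl_comm.stream
    assert tensor.mean().cpu().item() == pynccl_comm.world_size
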
diff --git a/vllm/distributed/device_communicators/pynccl_utils.py b/vllm/distributed/device_communicators/pynccl_utils.py
deleted file mode 100644
index 44e4f39217a41..0000000000000
--- a/vllm/distributed/device_communicators/pynccl_utils.py
+++ /dev/null
@@ -1,66 +0,0 @@
-import contextlib
-from typing import Optional
-
-import torch
-from torch.distributed import ProcessGroup, ReduceOp
-
-from vllm.logger import init_logger
-
-logger = init_logger(__name__)
-
-try:
-    from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
-                                                              ncclGetVersion)
-except Exception as e:
-    # in non-NVIDIA environments, we can't import the nccl module
-    # e.g. when running on machines with AMD GPUs
-    logger.info("Failed to import NCCL library: %s", e)
-    logger.info("It is expected if you are not running on NVIDIA GPUs.")
-    pass
-
-comm: Optional["NCCLCommunicator"] = None
-
-
-def is_initialized() -> bool:
-    """Returns whether the NCCL backend is initialized."""
-    return comm is not None
-
-
-@contextlib.contextmanager
-def set_pynccl_stream(stream: torch.cuda.Stream):
-    """Set the cuda stream for communication"""
-    try:
-        assert comm is not None
-        comm.stream = stream
-        yield
-    finally:
-        pass
-
-
-def init_process_group(group: Optional[ProcessGroup] = None) -> None:
-    assert not is_initialized()
-    global comm
-    logger.info("vLLM is using nccl==%s", ncclGetVersion())
-    comm = NCCLCommunicator(group=group)
-
-
-def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
-    """All-reduces the input tensor across the process group."""
-    assert input_.is_cuda, f"{input_} should be a cuda tensor"
-    assert comm is not None
-    comm.all_reduce(input_, op)
-
-
-def destroy_process_group() -> None:
-    global comm
-    comm = None
-
-
-def get_world_size() -> int:
-    """Returns the world size."""
-    assert comm is not None
-    return comm.world_size
-
-
-def get_nccl_backend() -> Optional["NCCLCommunicator"]:
-    return comm
diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py
new file mode 100644
index 0000000000000..43d85674b23d0
--- /dev/null
+++ b/vllm/distributed/device_communicators/pynccl_wrapper.py
@@ -0,0 +1,258 @@
+# This file is a pure Python wrapper for the NCCL library.
+# The main purpose is to use NCCL combined with CUDA graph.
+# Before writing this script, we tried the following approach:
+# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
+#  often gets stuck when initializing the NCCL communicator.
+# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
+#  contains many other potential cuda APIs, that are not allowed during
+#  capturing the CUDA graph. For further details, please check
+# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
+#
+# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
+# doable, but we often encounter issues related to nccl versions, and need
+# to switch between different versions of NCCL. See
+# https://github.com/NVIDIA/nccl/issues/1234 for more details.
+# A C/C++ binding is not flexible enough to handle this. It requires
+# recompilation of the code every time we want to switch between different
+# versions. This current implementation, with a **pure** Python wrapper, is
+# more flexible. We can easily switch between different versions of NCCL by
+# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
+# variable in the code.
+
+import ctypes
+import platform
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch.distributed import ReduceOp
+
+from vllm.logger import init_logger
+from vllm.utils import find_nccl_library, nccl_integrity_check
+
+logger = init_logger(__name__)
+
+# === export types and functions from nccl to Python ===
+# for the original nccl definition, please check
+# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
+
+ncclResult_t = ctypes.c_int
+ncclComm_t = ctypes.c_void_p
+
+
+class ncclUniqueId(ctypes.Structure):
+    _fields_ = [("internal", ctypes.c_byte * 128)]
+
+
+cudaStream_t = ctypes.c_void_p
+buffer_type = ctypes.c_void_p
+
+ncclDataType_t = ctypes.c_int
+
+
+class ncclDataTypeEnum:
+    ncclInt8 = 0
+    ncclChar = 0
+    ncclUint8 = 1
+    ncclInt32 = 2
+    ncclInt = 2
+    ncclUint32 = 3
+    ncclInt64 = 4
+    ncclUint64 = 5
+    ncclFloat16 = 6
+    ncclHalf = 6
+    ncclFloat32 = 7
+    ncclFloat = 7
+    ncclFloat64 = 8
+    ncclDouble = 8
+    ncclBfloat16 = 9
+    ncclNumTypes = 10
+
+    @classmethod
+    def from_torch(cls, dtype: torch.dtype) -> int:
+        if dtype == torch.int8:
+            return cls.ncclInt8
+        if dtype == torch.uint8:
+            return cls.ncclUint8
+        if dtype == torch.int32:
+            return cls.ncclInt32
+        if dtype == torch.int64:
+            return cls.ncclInt64
+        if dtype == torch.float16:
+            return cls.ncclFloat16
+        if dtype == torch.float32:
+            return cls.ncclFloat32
+        if dtype == torch.float64:
+            return cls.ncclFloat64
+        if dtype == torch.bfloat16:
+            return cls.ncclBfloat16
+        raise ValueError(f"Unsupported dtype: {dtype}")
+
+
+ncclRedOp_t = ctypes.c_int
+
+
+class ncclRedOpTypeEnum:
+    ncclSum = 0
+    ncclProd = 1
+    ncclMax = 2
+    ncclMin = 3
+    ncclAvg = 4
+    ncclNumOps = 5
+
+    @classmethod
+    def from_torch(cls, op: ReduceOp) -> int:
+        if op == ReduceOp.SUM:
+            return cls.ncclSum
+        if op == ReduceOp.PRODUCT:
+            return cls.ncclProd
+        if op == ReduceOp.MAX:
+            return cls.ncclMax
+        if op == ReduceOp.MIN:
+            return cls.ncclMin
+        if op == ReduceOp.AVG:
+            return cls.ncclAvg
+        raise ValueError(f"Unsupported op: {op}")
+
+
+@dataclass
+class Function:
+    name: str
+    restype: Any
+    argtypes: List[Any]
+
+
+class NCCLLibrary:
+    exported_functions = [
+        # const char* ncclGetErrorString(ncclResult_t result)
+        Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]),
+        # ncclResult_t ncclGetVersion(int *version);
+        Function("ncclGetVersion", ncclResult_t,
+                 [ctypes.POINTER(ctypes.c_int)]),
+        # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
+        Function("ncclGetUniqueId", ncclResult_t,
+                 [ctypes.POINTER(ncclUniqueId)]),
+        # ncclResult_t ncclCommInitRank(
+        #   ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
+        # note that ncclComm_t is a pointer type, so the first argument
+        # is a pointer to a pointer
+        Function("ncclCommInitRank", ncclResult_t, [
+            ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId,
+            ctypes.c_int
+        ]),
+        # ncclResult_t ncclAllReduce(
+        #   const void* sendbuff, void* recvbuff, size_t count,
+        #   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
+        #   cudaStream_t stream);
+        # note that cudaStream_t is a pointer type, so the last argument
+        # is a pointer
+        Function("ncclAllReduce", ncclResult_t, [
+            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
+            ncclRedOp_t, ncclComm_t, cudaStream_t
+        ]),
+
+        # be cautious! this is a collective call, it will block until all
+        # processes in the communicator have called this function.
+        # because Python object destruction can happen in random order,
+        # it is better not to call it at all.
+        # ncclResult_t ncclCommDestroy(ncclComm_t comm);
+        Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
+    ]
+
+    # class attribute to store the mapping from the path to the library
+    # to avoid loading the same library multiple times
+    path_to_library_cache: Dict[str, Any] = {}
+
+    # class attribute to store the mapping from library path
+    # to the corresponding dictionary
+    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
+
+    def __init__(self, so_file: Optional[str] = None):
+
+        so_file = so_file or find_nccl_library()
+
+        try:
+            # load the library in another process.
+            # if it core dumps, it will not crash the current process
+            nccl_integrity_check(so_file)
+        except Exception as e:
+            logger.error(
+                "Failed to load NCCL library from %s ."
+                "It is expected if you are not running on NVIDIA/AMD GPUs."
+                "Otherwise, the nccl library might not exist, be corrupted "
+                "or it does not support the current platform %s."
+                "One solution is to download libnccl2 version 2.18 from "
+                "https://developer.download.nvidia.com/compute/cuda/repos/ "
+                "and extract the libnccl.so.2 file. If you already have the "
+                "library, please set the environment variable VLLM_NCCL_SO_PATH"
+                " to point to the correct nccl library path.", so_file,
+                platform.platform())
+            raise e
+
+        if so_file not in NCCLLibrary.path_to_dict_mapping:
+            lib = ctypes.CDLL(so_file)
+            NCCLLibrary.path_to_library_cache[so_file] = lib
+        self.lib = NCCLLibrary.path_to_library_cache[so_file]
+
+        if so_file not in NCCLLibrary.path_to_dict_mapping:
+            _funcs = {}
+            for func in NCCLLibrary.exported_functions:
+                f = getattr(self.lib, func.name)
+                f.restype = func.restype
+                f.argtypes = func.argtypes
+                _funcs[func.name] = f
+            NCCLLibrary.path_to_dict_mapping[so_file] = _funcs
+        self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]
+
+    def ncclGetErrorString(self, result: ncclResult_t) -> str:
+        return self._funcs["ncclGetErrorString"](result).decode("utf-8")
+
+    def NCCL_CHECK(self, result: ncclResult_t) -> None:
+        if result != 0:
+            error_str = self.ncclGetErrorString(result)
+            raise RuntimeError(f"NCCL error: {error_str}")
+
+    def ncclGetVersion(self) -> str:
+        version = ctypes.c_int()
+        self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
+        version_str = str(version.value)
+        # something like 21903 --> "2.19.3"
+        major = version_str[0].lstrip("0")
+        minor = version_str[1:3].lstrip("0")
+        patch = version_str[3:].lstrip("0")
+        return f"{major}.{minor}.{patch}"
+
+    def ncclGetUniqueId(self) -> ncclUniqueId:
+        unique_id = ncclUniqueId()
+        self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](
+            ctypes.byref(unique_id)))
+        return unique_id
+
+    def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,
+                         rank: int) -> ncclComm_t:
+        comm = ncclComm_t()
+        self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm),
+                                                        world_size, unique_id,
+                                                        rank))
+        return comm
+
+    def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
+                      count: int, datatype: int, op: int, comm: ncclComm_t,
+                      stream: cudaStream_t) -> None:
+        # `datatype` actually should be `ncclDataType_t`
+        # and `op` should be `ncclRedOp_t`
+        # both are aliases of `ctypes.c_int`
+        # when we pass int to a function, it will be converted to `ctypes.c_int`
+        # by ctypes automatically
+        self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count,
+                                                     datatype, op, comm,
+                                                     stream))
+
+    def ncclCommDestroy(self, comm: ncclComm_t) -> None:
+        self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
+
+
+__all__ = [
+    "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
+    "ncclComm_t", "cudaStream_t", "buffer_type"
+]
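
The wrapper can also be used on its own, as the updated `test_ncclGetUniqueId` does. A short sketch follows (the version string in the comment is only an example value); note that the library handle and the resolved function table are cached per `.so` path at class level, so constructing `NCCLLibrary` repeatedly for the same path does not reload the library.

    from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary

    lib = NCCLLibrary()                # or NCCLLibrary("/path/to/libnccl.so.2")
    version = lib.ncclGetVersion()     # e.g. "2.18.1"
    unique_id = lib.ncclGetUniqueId()  # 128-byte ncclUniqueId, to be broadcast
                                       # to other ranks before ncclCommInitRank
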
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index be5bb4e857caf..5075da11bb1b8 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -3,10 +3,10 @@
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 """Tensor and pipeline parallel groups."""
-import contextlib
-from typing import Optional
+from typing import List, Optional
 
 import torch
+from torch.distributed import ProcessGroup
 
 import vllm.envs as envs
 from vllm.logger import init_logger
@@ -14,10 +14,11 @@
 logger = init_logger(__name__)
 
 # Tensor model parallel group that the current rank belongs to.
-_TP_DEVICE_GROUP = None
-_TP_CPU_GROUP = None
+_TP_DEVICE_GROUP: Optional[ProcessGroup] = None
+_TP_CPU_GROUP: Optional[ProcessGroup] = None
+_TP_PYNCCL_COMMUNICATOR = None
 # Pipeline model parallel group that the current rank belongs to.
-_PIPELINE_MODEL_PARALLEL_GROUP = None
+_PP_DEVICE_GROUP: Optional[ProcessGroup] = None
 
 # when people blindly call `torch.distributed.all_reduce` etc,
 # it will use this group. It is initialized with the `backend`
@@ -41,11 +42,16 @@
 # A list of global ranks for each pipeline group to ease calculation of the
 # source rank when broadcasting from the first or last pipeline stage.
-_PIPELINE_GLOBAL_RANKS = None
+_PP_GLOBAL_RANKS: Optional[List[int]] = None
 
 _LOCAL_RANK = -1
 
 
+def get_tp_pynccl_communicator():
+    global _TP_PYNCCL_COMMUNICATOR
+    return _TP_PYNCCL_COMMUNICATOR
+
+
 def get_local_rank():
     global _LOCAL_RANK
     return _LOCAL_RANK
@@ -80,10 +86,20 @@ def init_distributed_environment(
     # set the local rank
     # local_rank is not available in torch ProcessGroup,
     # see https://github.com/pytorch/pytorch/issues/122816
-    if local_rank == -1 and distributed_init_method == "env://":
-        local_rank = envs.LOCAL_RANK
+    if local_rank == -1:
+        # local rank not set, this usually happens in single-node
+        # setting, where we can use rank as local rank
+        if distributed_init_method == "env://":
+            local_rank = envs.LOCAL_RANK
+        else:
+            local_rank = rank
     global _LOCAL_RANK
     _LOCAL_RANK = local_rank
+    # A small all_reduce for warmup.
+    data = torch.zeros(1)
+    if torch.cuda.is_available():
+        data = data.to(device=f"cuda:{local_rank}")
+    torch.distributed.all_reduce(data)
 
 
 def initialize_model_parallel(
@@ -133,29 +149,36 @@
     rank = torch.distributed.get_rank()
 
     # Build the tensor model-parallel groups.
-    global _TP_DEVICE_GROUP
+    global _TP_DEVICE_GROUP, _TP_CPU_GROUP, _TP_PYNCCL_COMMUNICATOR
     assert _TP_DEVICE_GROUP is None, (
         "tensor model parallel group is already initialized")
     for i in range(num_tensor_model_parallel_groups):
-        ranks = range(i * tensor_model_parallel_size,
-                      (i + 1) * tensor_model_parallel_size)
+        ranks = list(
+            range(i * tensor_model_parallel_size,
+                  (i + 1) * tensor_model_parallel_size))
        group = torch.distributed.new_group(ranks, backend=backend)
        cpu_group = torch.distributed.new_group(ranks, backend="gloo")
        if rank in ranks:
            _TP_DEVICE_GROUP = group
            _TP_CPU_GROUP = cpu_group
 
+    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+    _TP_PYNCCL_COMMUNICATOR = PyNcclCommunicator(
+        group=_TP_CPU_GROUP,
+        device=_LOCAL_RANK,
+    )
+
     # Build the pipeline model-parallel groups.
-    global _PIPELINE_MODEL_PARALLEL_GROUP
-    global _PIPELINE_GLOBAL_RANKS
-    assert _PIPELINE_MODEL_PARALLEL_GROUP is None, (
+    global _PP_DEVICE_GROUP
+    global _PP_GLOBAL_RANKS
+    assert _PP_DEVICE_GROUP is None, (
         "pipeline model parallel group is already initialized")
     for i in range(num_pipeline_model_parallel_groups):
-        ranks = range(i, world_size, num_pipeline_model_parallel_groups)
+        ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
         group = torch.distributed.new_group(ranks, backend=backend)
         if rank in ranks:
-            _PIPELINE_MODEL_PARALLEL_GROUP = group
-            _PIPELINE_GLOBAL_RANKS = ranks
+            _PP_DEVICE_GROUP = group
+            _PP_GLOBAL_RANKS = ranks
 
 
 def ensure_model_parallel_initialized(
@@ -188,8 +211,7 @@ def ensure_model_parallel_initialized(
 
 def model_parallel_is_initialized():
     """Check if tensor and pipeline parallel groups are initialized."""
-    return (_TP_DEVICE_GROUP is not None
-            and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
+    return (_TP_DEVICE_GROUP is not None and _PP_DEVICE_GROUP is not None)
 
 
 def get_cpu_world_group():
@@ -214,9 +236,9 @@ def get_tensor_model_parallel_cpu_group():
 
 def get_pipeline_model_parallel_group():
     """Get the pipeline model parallel group the caller rank belongs to."""
-    assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, (
+    assert _PP_DEVICE_GROUP is not None, (
         "pipeline model parallel group is not initialized")
-    return _PIPELINE_MODEL_PARALLEL_GROUP
+    return _PP_DEVICE_GROUP
 
 
 def get_tensor_model_parallel_world_size():
@@ -253,36 +275,36 @@ def get_tensor_model_parallel_src_rank():
 def get_pipeline_model_parallel_first_rank():
     """Return the global rank of the first process in the pipeline for the
     current tensor parallel group"""
-    assert _PIPELINE_GLOBAL_RANKS is not None, (
+    assert _PP_GLOBAL_RANKS is not None, (
         "Pipeline parallel group is not initialized")
-    return _PIPELINE_GLOBAL_RANKS[0]
+    return _PP_GLOBAL_RANKS[0]
 
 
 def get_pipeline_model_parallel_last_rank():
     """Return the global rank of the last process in the pipeline for the
     current tensor parallel group"""
-    assert _PIPELINE_GLOBAL_RANKS is not None, (
+    assert _PP_GLOBAL_RANKS is not None, (
         "Pipeline parallel group is not initialized")
     last_rank_local = get_pipeline_model_parallel_world_size() - 1
-    return _PIPELINE_GLOBAL_RANKS[last_rank_local]
+    return _PP_GLOBAL_RANKS[last_rank_local]
 
 
 def get_pipeline_model_parallel_next_rank():
     """Return the global rank that follows the caller in the pipeline"""
-    assert _PIPELINE_GLOBAL_RANKS is not None, (
+    assert _PP_GLOBAL_RANKS is not None, (
         "Pipeline parallel group is not initialized")
     rank_in_pipeline = get_pipeline_model_parallel_rank()
     world_size = get_pipeline_model_parallel_world_size()
-    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
+    return _PP_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
 
 
 def get_pipeline_model_parallel_prev_rank():
     """Return the global rank that precedes the caller in the pipeline"""
-    assert _PIPELINE_GLOBAL_RANKS is not None, (
+    assert _PP_GLOBAL_RANKS is not None, (
         "Pipeline parallel group is not initialized")
     rank_in_pipeline = get_pipeline_model_parallel_rank()
     world_size = get_pipeline_model_parallel_world_size()
-    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
+    return _PP_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
 
 
 def destroy_model_parallel():
@@ -295,45 +317,12 @@ def destroy_model_parallel():
     if _TP_CPU_GROUP:
         torch.distributed.destroy_process_group(_TP_CPU_GROUP)
     _TP_CPU_GROUP = None
-    global _PIPELINE_MODEL_PARALLEL_GROUP
-    if _PIPELINE_MODEL_PARALLEL_GROUP:
-        torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP)
-    _PIPELINE_MODEL_PARALLEL_GROUP = None
-    global _PIPELINE_GLOBAL_RANKS
-    _PIPELINE_GLOBAL_RANKS = None
-    from vllm.distributed.device_communicators import pynccl_utils
-
-    # Destroy the pynccl states if any.
-    pynccl_utils.destroy_process_group()
-
-
-# Whether to use pynccl for nccl all reduce.
-# We use pynccl for all reduce when using CUDA graph, because torch.distributed
-# is not well supported by CUDA graph.
-_ENABLE_PYNCCL_FOR_ALL_REDUCE = False
-
-
-@contextlib.contextmanager
-def with_pynccl_for_all_reduce():
-    from vllm.distributed.device_communicators import pynccl_utils
-    """use pynccl instead of torch.distributed for all reduce"""
-    tp_size = get_tensor_model_parallel_world_size()
-    if tp_size == 1:
-        # No-op.
-        # NOTE(woosuk): We don't initialize pynccl when tp_size is 1.
-        yield
-    else:
-        global _ENABLE_PYNCCL_FOR_ALL_REDUCE
-        old = _ENABLE_PYNCCL_FOR_ALL_REDUCE
-        _ENABLE_PYNCCL_FOR_ALL_REDUCE = True
-
-        stream = torch.cuda.current_stream()
-        with pynccl_utils.set_pynccl_stream(stream):
-            yield
-        _ENABLE_PYNCCL_FOR_ALL_REDUCE = old
-
-
-def is_pynccl_enabled_for_all_reduce():
-    """check if pynccl is enabled for all reduce"""
-    global _ENABLE_PYNCCL_FOR_ALL_REDUCE
-    return _ENABLE_PYNCCL_FOR_ALL_REDUCE
+    global _TP_PYNCCL_COMMUNICATOR
+    _TP_PYNCCL_COMMUNICATOR = None
+
+    global _PP_DEVICE_GROUP
+    if _PP_DEVICE_GROUP:
+        torch.distributed.destroy_process_group(_PP_DEVICE_GROUP)
+    _PP_DEVICE_GROUP = None
+    global _PP_GLOBAL_RANKS
+    _PP_GLOBAL_RANKS = None
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index b5e582116297c..3fc76c6142165 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1,4 +1,3 @@
-import contextlib
 import time
 from enum import IntEnum
 from typing import Dict, List, NamedTuple, Optional, Set, Tuple
@@ -12,9 +11,9 @@
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ParallelConfig, SchedulerConfig,
                          VisionLanguageConfig)
-from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
-from vllm.distributed.device_communicators import (custom_all_reduce,
-                                                    pynccl_utils)
+from vllm.distributed import broadcast_tensor_dict
+from vllm.distributed.communication_op import graph_capture_mode
+from vllm.distributed.device_communicators import custom_all_reduce
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
 from vllm.lora.request import LoRARequest
@@ -917,10 +916,6 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
         Since it is used for decoding-only, it assumes there's only 1 token
         per sequence in the batch.
         """
-        # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
-        # deleted before the CUDA graphs.
-        self.pynccl_backend = pynccl_utils.get_nccl_backend()
-
         assert not self.model_config.enforce_eager
         logger.info("Capturing the model for CUDA graphs. This may lead to "
                     "unexpected consequences if the model is not static. To "
@@ -1046,7 +1041,7 @@ def capture(
         # Run the model once without capturing the graph.
         # This is to make sure that the captured graph does not include the
         # kernel launches for initial benchmarking (e.g., Triton autotune).
-        with _maybe_pynccl():
+        with graph_capture_mode():
             self.model(
                 input_ids,
                 positions,
@@ -1061,7 +1056,7 @@ def capture(
         # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
         self._graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(self._graph, pool=memory_pool):  # noqa: SIM117
-            with _maybe_pynccl():
+            with graph_capture_mode():
                 hidden_states = self.model(
                     input_ids,
                     positions,
@@ -1113,16 +1108,6 @@ def __call__(self, *args, **kwargs):
         return self.forward(*args, **kwargs)
 
 
-@contextlib.contextmanager
-def _maybe_pynccl():
-    if pynccl_utils.is_initialized(
-    ) and not custom_all_reduce.is_initialized():
-        with with_pynccl_for_all_reduce():
-            yield
-    else:
-        yield
-
-
 def _get_graph_batch_size(batch_size: int) -> int:
     """Returns the padded batch size given actual batch size.
 
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 43f6b2b443b70..0ca9c2b64cf30 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -11,9 +11,7 @@
                          VisionLanguageConfig)
 from vllm.distributed import (broadcast_tensor_dict,
                               ensure_model_parallel_initialized,
-                              get_tensor_model_parallel_cpu_group,
                               init_distributed_environment)
-from vllm.distributed.device_communicators import pynccl_utils
 from vllm.distributed.device_communicators.custom_all_reduce import (
     init_custom_ar)
 from vllm.lora.request import LoRARequest
@@ -306,29 +304,10 @@ def init_worker_distributed_environment(
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)
 
-    if pynccl_utils.is_initialized():
-        pynccl_world_size = pynccl_utils.get_world_size()
-        if pynccl_world_size != parallel_config.world_size:
-            raise RuntimeError(
-                "pynccl is already initialized but the pynccl world "
-                "size does not match parallel_config.world_size "
-                f"({pynccl_world_size} vs. {parallel_config.world_size}).")
-    elif parallel_config.world_size > 1:
-        # NOTE(woosuk): We don't initialize pynccl process group when world size
-        # is 1.
-        # NOTE(kaichao): By default, pynccl is initialized for tp group.
-        pynccl_utils.init_process_group(
-            group=get_tensor_model_parallel_cpu_group())
-
     # Initialize a custom fast all-reduce implementation.
     if not parallel_config.disable_custom_all_reduce:
         init_custom_ar()
 
-    # A small all_reduce for warmup.
-    torch.distributed.all_reduce(torch.zeros(1).cuda())
-    if pynccl_utils.is_initialized():
-        pynccl_utils.all_reduce(torch.zeros(1).cuda())
-
 
 def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
     # Check if the GPU supports the dtype.
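
Taken together, the per-worker initialization flow after this patch looks roughly like the sketch below, condensed from `init_worker_distributed_environment` in `worker.py`; the rank and parallel sizes are placeholder values and the keyword names are only illustrative. The TP PyNccl communicator is now created inside `initialize_model_parallel`, so no separate `pynccl_utils.init_process_group` call is needed, and the warmup all-reduce happens in `init_distributed_environment` and in the communicator constructor rather than in the worker.

    import os
    from vllm.distributed import (ensure_model_parallel_initialized,
                                  init_distributed_environment)
    from vllm.distributed.parallel_state import get_tp_pynccl_communicator

    rank = int(os.environ.get("RANK", "0"))
    init_distributed_environment(world_size=2, rank=rank,
                                 distributed_init_method="env://")
    ensure_model_parallel_initialized(2, 1)  # tp_size=2, pp_size=1
    pynccl_comm = get_tp_pynccl_communicator()  # set up by initialize_model_parallel
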