diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py
index 05cb1e0f6ef5..de66ceaeef6f 100644
--- a/vllm/distributed/device_communicators/tpu_communicator.py
+++ b/vllm/distributed/device_communicators/tpu_communicator.py
@@ -22,6 +22,8 @@
     import torch_xla.core.xla_model as xm
     import torch_xla.runtime as xr
     from torch_xla._internal import pjrt
+    from torch_xla.distributed.xla_multiprocessing import (
+        create_optimized_replica_groups)

     if USE_RAY:
         from vllm.executor import ray_utils
@@ -79,9 +81,12 @@ def __init__(self,
             pjrt.initialize_multiprocess(local_rank, local_world_size)
             xr._init_world_size_ordinal()
+        self.groups = create_optimized_replica_groups()

     def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
-        return xm.all_reduce(xm.REDUCE_SUM, input_)
+        # TODO: Remove the explicit groups once the XLA compiler can
+        # auto-reorder the ring order for all-reduce.
+        return xm.all_reduce(xm.REDUCE_SUM, input_, groups=self.groups)

     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         assert dim == -1, "TPUs only support dim=-1 for all-gather."
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 514851694837..fa493fefb8f0 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -119,11 +119,13 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:

 if supports_custom_op():
+    from vllm.platforms import current_platform
     direct_register_custom_op(
         op_name="all_reduce",
         op_func=all_reduce,
         mutates_args=[],
         fake_impl=all_reduce_fake,
+        dispatch_key=current_platform.dispatch_key,
     )
@@ -219,7 +221,8 @@ def __init__(
             self.cpu_group, 1 << 22, 6)

         from vllm.platforms import current_platform
-        self.use_custom_op_call = current_platform.is_cuda_alike()
+        self.use_custom_op_call = (current_platform.is_cuda_alike()
+                                   or current_platform.is_tpu())

     @property
     def first_rank(self):
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index 9add8cee02e5..bd24072f4c1a 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -84,6 +84,12 @@ def __init__(

     def init_device(self):
         os.environ["PJRT_DEVICE"] = "TPU"
+        # Note: the XLA compiler currently (and wrongly) applies its 2D ring
+        # strategy to a 1D ring. The TPU compiler flag
+        # `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary
+        # workaround and will be removed once the XLA compiler bug is fixed.
+        os.environ["LIBTPU_INIT_ARGS"] = (
+            "--xla_tpu_force_1d_allreduce_at_chunk_count=1")
         torch.set_grad_enabled(False)
         torch.set_default_dtype(self.model_config.dtype)
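
For reviewers unfamiliar with the torch_xla API, here is a minimal standalone sketch of what the new all-reduce path looks like. It assumes a TPU host with a torch_xla build that ships `create_optimized_replica_groups` (the same import the patch adds); the `demo_all_reduce` function name is illustrative, not part of the patch.

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.xla_multiprocessing import (
    create_optimized_replica_groups)


def demo_all_reduce() -> torch.Tensor:
    # Illustrative sketch only; must run on a TPU host, one process per chip.
    device = xm.xla_device()
    x = torch.ones(1024, device=device)

    # The explicit groups pin a topology-aware replica ordering; without
    # them the compiler uses the default ring order, which (per the TODO
    # in the patch) it cannot yet reorder automatically.
    groups = create_optimized_replica_groups()
    y = xm.all_reduce(xm.REDUCE_SUM, x, groups=groups)
    xm.mark_step()  # materialize the lazy XLA graph
    return y
```

The only behavioral difference from the old path is the `groups=` argument; once the compiler can reorder the ring itself, the call can revert to `xm.all_reduce(xm.REDUCE_SUM, input_)`.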
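
On the `dispatch_key` change in `parallel_state.py`: a custom op registered only under the CUDA dispatch key is never selected for XLA tensors, so without this change the TPU path would silently bypass the `vllm.all_reduce` custom op. The sketch below illustrates per-backend dispatch with raw `torch.library`; it mirrors what `direct_register_custom_op(dispatch_key=...)` does under the hood, but the `demo::scale` op and namespace are made up for illustration, not vLLM code.

```python
import torch

# Hypothetical op in a made-up "demo" namespace, illustrating why the
# dispatch key matters: an impl registered for one backend is invisible
# to tensors living on another.
lib = torch.library.Library("demo", "FRAGMENT")
lib.define("scale(Tensor x) -> Tensor")


def _scale_impl(x: torch.Tensor) -> torch.Tensor:
    return x * 2


# Registering only under "CUDA" would leave XLA tensors with no kernel;
# registering under the platform's key (e.g. "XLA" on TPU) is what
# dispatch_key=current_platform.dispatch_key achieves in the patch.
lib.impl("scale", _scale_impl, "CPU")
lib.impl("scale", _scale_impl, "XLA")

# torch.ops.demo.scale(torch.ones(2)) now dispatches via the CPU key.
```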
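
One caveat on the `LIBTPU_INIT_ARGS` change in `tpu_worker.py`: the assignment overwrites any flags a user has already exported, and libtpu only reads the variable at load time, so it must be set before the first TPU touch. Below is a hedged sketch of a more defensive variant; this append-if-missing behavior is a suggestion, not what the patch currently does.

```python
import os

# Hypothetical defensive variant: append the workaround flag instead of
# clobbering LIBTPU_INIT_ARGS, so user-supplied flags survive. Must run
# before libtpu is initialized (i.e. before the first torch_xla call).
_FLAG = "--xla_tpu_force_1d_allreduce_at_chunk_count=1"
_existing = os.environ.get("LIBTPU_INIT_ARGS", "")
if _FLAG not in _existing:
    os.environ["LIBTPU_INIT_ARGS"] = f"{_existing} {_FLAG}".strip()
```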