From 6d10f6c14dabdf1bf32d51c5c7dbe5184bec1274 Mon Sep 17 00:00:00 2001 From: Richard Liu <39319471+richardsliu@users.noreply.github.com> Date: Fri, 30 Aug 2024 00:27:40 -0700 Subject: [PATCH] [TPU] Support single and multi-host TPUs on GKE (#7613) Signed-off-by: Alvant --- requirements-tpu.txt | 2 +- vllm/attention/backends/pallas.py | 5 +++- .../device_communicators/tpu_communicator.py | 27 +++++++++++++++-- vllm/executor/ray_tpu_executor.py | 15 ++++++++++ vllm/executor/ray_utils.py | 29 +++++++++++++++++++ 5 files changed, 74 insertions(+), 4 deletions(-) diff --git a/requirements-tpu.txt b/requirements-tpu.txt index 5eb27b39eb623..4c606cf0a9105 100644 --- a/requirements-tpu.txt +++ b/requirements-tpu.txt @@ -4,4 +4,4 @@ # Dependencies for TPU # Currently, the TPU backend uses a nightly version of PyTorch XLA. # You can install the dependencies in Dockerfile.tpu. -ray +ray[default] diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index ac03b6d8b1ead..c324d62d44d79 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -123,7 +123,10 @@ def __init__( raise NotImplementedError("TPU version must be 4 or higher.") self.megacore_mode = None - tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower() + tpu_env = torch_xla.tpu.get_tpu_env() + tpu_type = tpu_env.get("TYPE") or tpu_env.get("ACCELERATOR_TYPE") + tpu_type = tpu_type.lower() + if "lite" not in tpu_type: if self.num_kv_heads % 2 == 0: self.megacore_mode = "kv_head" diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index 81a141e86206a..765a0f9cb1c87 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -1,3 +1,5 @@ +import os + import torch import torch.distributed as dist from torch.distributed import ProcessGroup @@ -5,11 +7,12 @@ from vllm.platforms import current_platform if current_platform.is_tpu(): - import ray import torch_xla.core.xla_model as xm import torch_xla.runtime as xr from torch_xla._internal import pjrt + from vllm.executor import ray_utils + class TpuCommunicator: @@ -24,9 +27,29 @@ def __init__(self, group: ProcessGroup): # be simply calculated as follows. global_rank = dist.get_rank(group) global_world_size = dist.get_world_size(group) - num_nodes = len(ray.nodes()) + + # Calculate how many TPU nodes are in the current deployment. This + # is the Ray placement group if it is deployed with Ray. Default + # to the number of TPU nodes in the Ray cluster. The number of TPU + # nodes is computed by the total number of TPUs divided by the + # number of TPU accelerators per node, to account for clusters + # with both CPUs and TPUs. + num_nodes = ray_utils.get_num_tpu_nodes() + num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group() + if num_nodes_in_pg > 0: + num_nodes = num_nodes_in_pg + local_world_size = global_world_size // num_nodes local_rank = global_rank % local_world_size + + # Ensure environment variables are set for multihost deployments. + # On GKE, this is needed for libtpu and TPU driver to know which TPU + # chip is actually visible. Otherwise the TPU driver will fail to + # initialize because the number of devices would be different from + # the number of visible worker addresses. + os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank) + os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank) + pjrt.initialize_multiprocess(local_rank, local_world_size) xr._init_world_size_ordinal() diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py index 2a1fd35b65797..8f867b1d647a5 100644 --- a/vllm/executor/ray_tpu_executor.py +++ b/vllm/executor/ray_tpu_executor.py @@ -71,6 +71,19 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", worker_module_name = "vllm.worker.tpu_worker" worker_class_name = "TPUWorker" + # GKE does not fetch environment information from metadata server + # and instead sets these from within the Ray process. Therefore we + # need to override the Ray environment variables manually. + override_env = {} + if "TPU_CHIPS_PER_HOST_BOUNDS" in os.environ: + override_env.update({ + "TPU_CHIPS_PER_HOST_BOUNDS": + os.environ["TPU_CHIPS_PER_HOST_BOUNDS"] + }) + if "TPU_HOST_BOUNDS" in os.environ: + override_env.update( + {"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]}) + worker = ray.remote( num_cpus=0, resources={"TPU": 1}, @@ -81,6 +94,8 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", worker_class_name=worker_class_name, trust_remote_code=self.model_config.trust_remote_code, ) + if override_env: + worker.override_env_vars.remote(override_env) worker_ip = ray.get(worker.get_node_ip.remote()) if worker_ip == driver_ip and self.driver_dummy_worker is None: diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index bfdd0f5cf97b3..59e9854393b6b 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -1,3 +1,4 @@ +import os import time from collections import defaultdict from typing import Dict, List, Optional, Tuple, Union @@ -84,6 +85,9 @@ def execute_model_spmd( return output + def override_env_vars(self, vars: Dict[str, str]): + os.environ.update(vars) + ray_import_err = None except ImportError as e: @@ -291,3 +295,28 @@ def initialize_ray_cluster( _verify_bundles(current_placement_group, parallel_config, device_str) # Set the placement group in the parallel config parallel_config.placement_group = current_placement_group + + +def get_num_tpu_nodes() -> int: + from ray._private.accelerators import TPUAcceleratorManager + cluster_resources = ray.cluster_resources() + total_tpus = int(cluster_resources["TPU"]) + tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators() + assert total_tpus % tpus_per_node == 0 + return total_tpus // tpus_per_node + + +def get_num_nodes_in_placement_group() -> int: + pg_table = ray.util.placement_group_table() + current_pg = ray.util.get_current_placement_group() + num_nodes = 0 + + if current_pg: + nodes_in_pg = set() + for pg_key, pg in pg_table.items(): + if pg_key == current_pg.id.hex(): + for _, node in pg["bundles_to_node_id"].items(): + nodes_in_pg.add(node) + num_nodes = len(nodes_in_pg) + + return num_nodes