Support vLLM single and multi-host TPUs on GKE #7613

Merged: 20 commits, Aug 30, 2024
Changes from 5 commits
requirements-tpu.txt (1 addition, 1 deletion)
@@ -4,4 +4,4 @@
# Dependencies for TPU
# Currently, the TPU backend uses a nightly version of PyTorch XLA.
# You can install the dependencies in Dockerfile.tpu.
ray
ray[default,serve]
Collaborator

Please remove Ray Serve from the dependency. vLLM does not need to be used with Ray Serve, and we'd like to minimize the dependencies.

Contributor Author

Can we still install ray[default]? The reason is that the GCS endpoint needs to be served through the Ray dashboard, which does not get installed if you just do pip install ray. That endpoint is needed for other Ray workers to join the cluster for multi-host inference.

Collaborator

Oh, how is ray[default] different from just ray?

Contributor Author

According to https://docs.ray.io/en/latest/ray-overview/installation.html, ray[default] includes the Ray dashboard while ray is just the Ray core libraries.

The reason for including the dashboard (in addition to debuggability) is that the GCS and Ray Job endpoints are exposed through the dashboard, so without it the other Ray nodes are not able to join the cluster. For example, the Kubernetes operator initializes Ray worker nodes by having them ping the GCS endpoint.
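To illustrate the point above (this snippet is not part of the PR): a minimal sketch of the two endpoints involved, assuming a hypothetical head host `ray-head` and Ray's default ports. `JobSubmissionClient` ships with `ray[default]`; it is not available after a bare `pip install ray`.

```python
# Minimal sketch, not from this PR. Assumes a hypothetical head host "ray-head"
# and Ray's default ports: 6379 for the GCS, 8265 for the dashboard / Job API.
import ray
from ray.job_submission import JobSubmissionClient  # requires ray[default]

# A driver (or, via `ray start --address=ray-head:6379`, a worker node)
# attaches to the existing cluster through the GCS endpoint.
ray.init(address="ray-head:6379")

# Job submission and the dashboard are served over HTTP by the dashboard
# process, which only exists when ray[default] is installed.
client = JobSubmissionClient("http://ray-head:8265")
job_id = client.submit_job(entrypoint="python -c \"import ray; ray.init()\"")
print(client.get_job_status(job_id))
```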

vllm/attention/backends/pallas.py (4 additions, 1 deletion)
@@ -118,7 +118,10 @@ def __init__(
raise NotImplementedError("TPU version must be 4 or higher.")

self.megacore_mode = None
tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower()
tpu_env = torch_xla.tpu.get_tpu_env()
tpu_type = tpu_env.get("TYPE") or tpu_env.get("ACCELERATOR_TYPE")
tpu_type = tpu_type.lower()

if "lite" not in tpu_type:
if self.num_kv_heads % 2 == 0:
self.megacore_mode = "kv_head"
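For context on the change above: a standalone sketch of the key fallback, where the sample dictionaries are assumptions standing in for what `torch_xla.tpu.get_tpu_env()` might return on GCE versus GKE (the exact values are not taken from the PR).

```python
# Standalone sketch of the TYPE / ACCELERATOR_TYPE fallback above.
# The sample dictionaries are assumptions standing in for get_tpu_env() output.
def resolve_tpu_type(tpu_env: dict) -> str:
    tpu_type = tpu_env.get("TYPE") or tpu_env.get("ACCELERATOR_TYPE")
    if tpu_type is None:
        raise RuntimeError("Could not determine the TPU type from the TPU env.")
    return tpu_type.lower()

assert resolve_tpu_type({"TYPE": "V4"}) == "v4"
assert "lite" in resolve_tpu_type({"ACCELERATOR_TYPE": "v5litepod-8"})
```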
vllm/distributed/device_communicators/tpu_communicator.py (27 additions, 1 deletion)
@@ -1,3 +1,5 @@
import os

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
@@ -24,9 +26,33 @@ def __init__(self, group: ProcessGroup):
# be simply calculated as follows.
global_rank = dist.get_rank(group)
global_world_size = dist.get_world_size(group)
num_nodes = len(ray.nodes())

# Calculate how many TPU nodes are in the current placement group.
pg_table = ray.util.placement_group_table()
current_pg = ray.util.get_current_placement_group()

print(f"current pg: {current_pg.id.hex()}")
nodes_in_pg = set()
for pg_key, pg in pg_table.items():
if pg_key == current_pg.id.hex():
for _, node in pg['bundles_to_node_id'].items():
nodes_in_pg.add(node)
print(f"pg nodes: {nodes_in_pg}")
num_nodes = len(nodes_in_pg)

local_world_size = global_world_size // num_nodes
local_rank = global_rank % local_world_size

print(f"global_rank: {global_rank}")
print(f"global_world_size: {global_world_size}")
print(f"num_nodes: {num_nodes}")
print(f"local_world_size: {local_world_size}")
print(f"local_rank: {local_rank}")
richardsliu marked this conversation as resolved.

# Ensure environment variables are set for multihost deployments.
os.environ['CLOUD_TPU_TASK_ID'] = str(global_rank)
os.environ['TPU_VISIBLE_CHIPS'] = str(local_rank)

pjrt.initialize_multiprocess(local_rank, local_world_size)
xr._init_world_size_ordinal()

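A standalone sketch of the rank arithmetic introduced above, using a made-up bundles-to-node mapping of the kind returned by `ray.util.placement_group_table()` for two TPU hosts with four workers each; the node IDs and shapes are assumptions for illustration.

```python
# Standalone sketch of the local-rank derivation above. The mapping below is a
# made-up stand-in for pg_table[pg_id]["bundles_to_node_id"]: 2 hosts, 4 chips each.
bundles_to_node_id = {
    0: "node-a", 1: "node-a", 2: "node-a", 3: "node-a",
    4: "node-b", 5: "node-b", 6: "node-b", 7: "node-b",
}

def local_coords(global_rank: int, global_world_size: int):
    num_nodes = len(set(bundles_to_node_id.values()))   # hosts in the placement group
    local_world_size = global_world_size // num_nodes    # chips per host
    local_rank = global_rank % local_world_size          # chip index on this host
    return local_rank, local_world_size

assert local_coords(0, 8) == (0, 4)   # first chip on the first host
assert local_coords(5, 8) == (1, 4)   # second chip on the second host
```

As in the PR's code, the modulo step assumes that workers on the same host occupy consecutive global ranks.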
vllm/executor/ray_tpu_executor.py (35 additions, 13 deletions)
WoosukKwon marked this conversation as resolved.
@@ -70,6 +70,19 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
worker_module_name = "vllm.worker.tpu_worker"
worker_class_name = "TPUWorker"

# GKE does not fetch environment information from metadata server
# and instead sets these from within the Ray process. Therefore we
# need to override the Ray environment variables manually.
override_env = {}
if "TPU_CHIPS_PER_HOST_BOUNDS" in os.environ:
override_env.update({
"TPU_CHIPS_PER_HOST_BOUNDS":
os.environ["TPU_CHIPS_PER_HOST_BOUNDS"]
})
if "TPU_HOST_BOUNDS" in os.environ:
override_env.update(
{"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]})

worker = ray.remote(
Collaborator

QQ: can we use the runtime_env arg in ray.remote instead?

Contributor Author

I think runtime_env will be overwritten when Ray starts up and tries to initialize the environment from within the Ray process. We are using this to manually override the environment after the Ray task starts up.

num_cpus=0,
resources={"TPU": 1},
@@ -80,6 +93,8 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code,
)
if override_env:
worker.override_env_vars.remote(override_env)

worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None:
@@ -118,8 +133,10 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
"VLLM_TRACE_FUNCTION":
str(envs.VLLM_TRACE_FUNCTION),
}, ) for _ in worker_node_and_gpu_ids]
self._run_workers("update_environment_variables",
all_args=all_args_to_update_environment_variables)
self._run_workers(
"update_environment_variables",
all_args=all_args_to_update_environment_variables,
)

if len(node_workers) == 1:
# in single node case, we don't need to get the IP address.
@@ -145,9 +162,11 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

self._run_workers("init_device")
self._run_workers("load_model",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)
self._run_workers(
"load_model",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers,
)

def _driver_execute_model(
self,
@@ -190,10 +209,10 @@ def _run_workers(
"max_concurrent_workers is not supported yet.")

count = len(self.workers)
all_worker_args = repeat(args, count) if all_args is None \
else islice(all_args, 1, None)
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
else islice(all_kwargs, 1, None)
all_worker_args = (repeat(args, count) if all_args is None else islice(
all_args, 1, None))
all_worker_kwargs = (repeat(kwargs, count) if all_kwargs is None else
islice(all_kwargs, 1, None))

# Start the ray workers first.
ray_worker_outputs = [
@@ -241,9 +260,11 @@ def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks)
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
self._run_workers("initialize_cache",
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)
self._run_workers(
"initialize_cache",
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
)

def execute_model(
self,
@@ -253,7 +274,8 @@ def execute_model(
self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop",
async_run_remote_workers_only=True,
**self.extra_execute_model_run_workers_kwargs)
**self.extra_execute_model_run_workers_kwargs,
)

# Only the driver worker returns the sampling results.
return self._driver_execute_model(execute_model_req)
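Regarding the `runtime_env` question in the review thread above, a minimal sketch (not from this PR) contrasting the two ways of getting an environment variable into a Ray actor; the `EchoEnv` actor and the `TPU_HOST_BOUNDS` value are illustrative assumptions only.

```python
# Minimal sketch, not from this PR: runtime_env vs. overriding after startup.
# The EchoEnv actor and the "1,1,2" value are illustrative assumptions.
import os
import ray

ray.init()

@ray.remote(num_cpus=0)
class EchoEnv:
    def get(self, key: str):
        return os.environ.get(key)

    def override_env_vars(self, vars: dict):
        # Same idea as the override_env_vars helper added in ray_utils.py:
        # mutate the worker's environment after the actor process is running.
        os.environ.update(vars)

# Option 1: runtime_env sets the variable before the actor process starts, so a
# later re-initialization inside the process could still overwrite it.
a = EchoEnv.options(runtime_env={"env_vars": {"TPU_HOST_BOUNDS": "1,1,2"}}).remote()
print(ray.get(a.get.remote("TPU_HOST_BOUNDS")))

# Option 2 (the approach taken in this PR): push the values after the worker is
# up, so they take precedence over whatever was set during initialization.
b = EchoEnv.remote()
ray.get(b.override_env_vars.remote({"TPU_HOST_BOUNDS": "1,1,2"}))
print(ray.get(b.get.remote("TPU_HOST_BOUNDS")))
```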
vllm/executor/ray_utils.py (4 additions)
@@ -1,3 +1,4 @@
import os
from typing import List, Optional, Tuple, Union

from vllm.config import ParallelConfig
@@ -63,6 +64,9 @@ def execute_model_spmd(
return execute_model_req, output
return output

def override_env_vars(self, vars):
os.environ.update(vars)

ray_import_err = None

except ImportError as e: