Support vLLM single and multi-host TPUs on GKE #7613
@@ -70,6 +70,19 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
             worker_module_name = "vllm.worker.tpu_worker"
             worker_class_name = "TPUWorker"
 
+            # GKE does not fetch environment information from metadata server
+            # and instead sets these from within the Ray process. Therefore we
+            # need to override the Ray environment variables manually.
+            override_env = {}
+            if "TPU_CHIPS_PER_HOST_BOUNDS" in os.environ:
+                override_env.update({
+                    "TPU_CHIPS_PER_HOST_BOUNDS":
+                    os.environ["TPU_CHIPS_PER_HOST_BOUNDS"]
+                })
+            if "TPU_HOST_BOUNDS" in os.environ:
+                override_env.update(
+                    {"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]})
+
             worker = ray.remote(
                 num_cpus=0,
                 resources={"TPU": 1},

Inline review thread on this hunk:
QQ: can we use the …
I think …
@@ -80,6 +93,8 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
                 worker_class_name=worker_class_name,
                 trust_remote_code=self.model_config.trust_remote_code,
             )
+            if override_env:
+                worker.override_env_vars.remote(override_env)
 
             worker_ip = ray.get(worker.get_node_ip.remote())
             if worker_ip == driver_ip and self.driver_dummy_worker is None:
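Why the extra remote call: os.environ is per-process, so setting TPU_CHIPS_PER_HOST_BOUNDS and TPU_HOST_BOUNDS on the driver alone would never reach the Ray worker processes on the other GKE hosts. As a rough sketch of what the worker-side hook could look like (the method name override_env_vars comes from the diff above; the wrapper class and its shape here are illustrative, not the exact vLLM implementation):

import os
from typing import Dict


class RayWorkerWrapperSketch:
    """Illustrative stand-in for the Ray worker wrapper (hypothetical)."""

    def override_env_vars(self, env_vars: Dict[str, str]) -> None:
        # Copy the driver-provided values (e.g. TPU_CHIPS_PER_HOST_BOUNDS,
        # TPU_HOST_BOUNDS) into this worker process's environment so the TPU
        # runtime sees the same topology GKE configured on the pod.
        os.environ.update(env_vars)

Because the variables must be in place before the TPU runtime is initialized on each host, the driver pushes them with a .remote() call right after the worker actor is created rather than relying on Ray to propagate its own environment.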
@@ -118,8 +133,10 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
             "VLLM_TRACE_FUNCTION":
             str(envs.VLLM_TRACE_FUNCTION),
         }, ) for _ in worker_node_and_gpu_ids]
-        self._run_workers("update_environment_variables",
-                          all_args=all_args_to_update_environment_variables)
+        self._run_workers(
+            "update_environment_variables",
+            all_args=all_args_to_update_environment_variables,
+        )
 
         if len(node_workers) == 1:
             # in single node case, we don't need to get the IP address.
@@ -145,9 +162,11 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
         self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
 
         self._run_workers("init_device")
-        self._run_workers("load_model",
-                          max_concurrent_workers=self.parallel_config.
-                          max_parallel_loading_workers)
+        self._run_workers(
+            "load_model",
+            max_concurrent_workers=self.parallel_config.
+            max_parallel_loading_workers,
+        )
 
     def _driver_execute_model(
         self,
@@ -190,10 +209,10 @@ def _run_workers(
                 "max_concurrent_workers is not supported yet.")
 
         count = len(self.workers)
-        all_worker_args = repeat(args, count) if all_args is None \
-            else islice(all_args, 1, None)
-        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
-            else islice(all_kwargs, 1, None)
+        all_worker_args = (repeat(args, count) if all_args is None else islice(
+            all_args, 1, None))
+        all_worker_kwargs = (repeat(kwargs, count) if all_kwargs is None else
+                             islice(all_kwargs, 1, None))
 
         # Start the ray workers first.
         ray_worker_outputs = [
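This hunk is behavior-preserving formatting, but the driver/worker split it encodes is easy to miss: with no per-worker arguments, the same args/kwargs are broadcast to every remote worker via repeat; with an explicit list, index 0 is reserved for the driver worker and islice(all_args, 1, None) feeds the remaining entries to the remote workers. A standalone illustration of that itertools pattern (not vLLM code, just the mechanism):

from itertools import islice, repeat

workers = ["remote-worker-1", "remote-worker-2", "remote-worker-3"]

# Broadcast case: every remote worker gets the same arguments.
broadcast_args = list(repeat(("load_model",), len(workers)))

# Per-worker case: entry 0 belongs to the driver worker, so the remote
# workers receive entries 1..N via islice.
all_args = [("driver-args",), ("worker-1-args",), ("worker-2-args",),
            ("worker-3-args",)]
per_worker_args = islice(all_args, 1, None)

for worker, args in zip(workers, per_worker_args):
    print(worker, args)
# remote-worker-1 ('worker-1-args',)
# remote-worker-2 ('worker-2-args',)
# remote-worker-3 ('worker-3-args',)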
@@ -241,9 +260,11 @@ def initialize_cache(self, num_gpu_blocks: int,
                     num_cpu_blocks)
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks
-        self._run_workers("initialize_cache",
-                          num_gpu_blocks=num_gpu_blocks,
-                          num_cpu_blocks=num_cpu_blocks)
+        self._run_workers(
+            "initialize_cache",
+            num_gpu_blocks=num_gpu_blocks,
+            num_cpu_blocks=num_cpu_blocks,
+        )
 
     def execute_model(
         self,
@@ -253,7 +274,8 @@ def execute_model(
             self.parallel_worker_tasks = self._run_workers(
                 "start_worker_execution_loop",
                 async_run_remote_workers_only=True,
-                **self.extra_execute_model_run_workers_kwargs)
+                **self.extra_execute_model_run_workers_kwargs,
+            )
 
         # Only the driver worker returns the sampling results.
         return self._driver_execute_model(execute_model_req)
Review comment:
Please remove Ray Serve from the dependency. vLLM does not need to be used with Ray Serve, and we'd like to minimize the dependencies.

Reply:
Can we still install ray[default]? The reason is that the GCS endpoint needs to run through the Ray dashboard, which does not get installed if you just do pip install ray. That endpoint is needed for other Ray workers to join the cluster for multi-host inference.

Reply:
Oh, how is ray[default] different from just ray?

Reply:
According to https://docs.ray.io/en/latest/ray-overview/installation.html, ray[default] includes the Ray dashboard, while ray is just the Ray core libraries. The reason for including the dashboard (in addition to debuggability) is that the GCS and Ray Job endpoints are exposed through the dashboard. Without it, the other Ray nodes are not able to join the cluster. For example, the Kubernetes operator initializes Ray worker nodes by having them ping the GCS endpoint.