 # SPDX-License-Identifier: Apache-2.0
 """
-a simple demonstration to show how to control
-the placement of the vLLM workers with Ray.
-The key is to set VLLM_RAY_PER_WORKER_GPUS and
-VLLM_RAY_BUNDLE_INDICES properly.
+a simple demonstration of how to co-locate
+vLLM workers with training actors on the same GPUs,
+for RLHF-like applications.
+The key points:
+- Control the placement of the vLLM workers with Ray, by setting
+  VLLM_RAY_PER_WORKER_GPUS and VLLM_RAY_BUNDLE_INDICES properly.
+- Use CUDA IPC to pass tensors, since NCCL does not work when
+  multiple processes share the same GPU.
 """
 import os
 
 import ray
+import torch
 from ray.util.placement_group import placement_group
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 
@@ -19,7 +24,33 @@ class MyWorker(Worker):
 
     def report_device_id(self) -> str:
         from vllm.platforms import current_platform
-        return current_platform.get_device_uuid(self.device.index)
+        self.device_uuid = current_platform.get_device_uuid(self.device.index)
+        return self.device_uuid
+
+    def update_weights_from_ipc_handles(self, ipc_handles):
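+        # ipc_handles maps a GPU's device UUID to a dict of
+        # {weight name: IPC handle} produced by the training actor
+        # on that GPU (see RayTrainingActor.get_weight_ipc_handles)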
+        handles = ipc_handles[self.device_uuid]
+        device_id = self.device.index
+        weights = []
+        for name, handle in handles.items():
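+            # each handle is the (rebuild function, args) pair that
+            # reduce_tensor produced on the training side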
+            func, args = handle
+            list_args = list(args)
+            # the key is to change the device id to the current device id
+            # in case two processes have different CUDA_VISIBLE_DEVICES
+            list_args[6] = device_id
+            tensor = func(*list_args)
+            weights.append((name, tensor))
+        self.model_runner.model.load_weights(weights=weights)
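+        # make sure the weight update has finished before the RPC returns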
+        torch.cuda.synchronize()
+
+    def check_weights_changed(self):
+        """
+        Check whether all the weights have been updated to zero.
+        """
+        weights_updated = True
+        for name, p in self.model_runner.model.named_parameters():
+            weights_updated = weights_updated and torch.allclose(
+                p, torch.zeros_like(p))
+        return weights_updated
 
 
 class MyLLM(LLM):
@@ -40,12 +71,32 @@ def __init__(self, *args, bundle_indices: list, **kwargs):
 
 class RayTrainingActor:
 
-    def report_device_id(self) -> str:
+    def __init__(self):
+        # ray will set CUDA_VISIBLE_DEVICES to the assigned GPUs
+        from transformers import AutoModelForCausalLM
+        self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
+        self.model.to("cuda:0")
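+        # zero out all the weights, so that the inference engines can
+        # later verify (via check_weights_changed) that the update arrived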
+        for name, p in self.model.named_parameters():
+            p.data.zero_()
+        torch.cuda.synchronize()
         # the argument for get_device_uuid is the index
         # of the GPU in the visible devices.
-        # ray will set CUDA_VISIBLE_DEVICES to the assigned GPUs
         from vllm.platforms import current_platform
-        return current_platform.get_device_uuid(0)
+        self.device_uuid = current_platform.get_device_uuid(0)
+
+    def report_device_id(self) -> str:
+        return self.device_uuid
+
+    def get_weight_ipc_handles(self):
+        from torch.multiprocessing.reductions import reduce_tensor
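+        # reduce_tensor is what torch.multiprocessing uses to serialize
+        # CUDA tensors: it returns a (rebuild function, args) pair whose
+        # args carry a CUDA IPC handle, so the receiving process maps the
+        # same GPU memory instead of copying it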
+        data = {}
+        for name, p in self.model.named_parameters():
+            # the training actor might only have a subset of the weights
+            # and might need to all-gather the weights from all the actors.
+            # for demonstration, here we assume all training actors have
+            # the full weights.
+            data[name] = reduce_tensor(p.detach())
+        return {self.device_uuid: data}
 
 
 # ray manages 4 GPUs
@@ -78,6 +129,8 @@ def report_device_id(self) -> str:
         ),
     )(RayTrainingActor).remote()
     training_actors.append(training_actor)
+
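+# query the device IDs only after all the actors are created, so that
+# their constructors (which load and zero the model) can run in parallel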
+for bundle_index, training_actor in enumerate(training_actors):
     device_id = ray.get(training_actor.report_device_id.remote())
     print(f"training actor {bundle_index} is on {device_id}")
     training_actor_device_ids.append(device_id)
@@ -119,3 +172,18 @@ def report_device_id(self) -> str:
 # the last two training actors should be
 # on the same GPUs as the second inference engine
 assert training_actor_device_ids[2:] == inference_engine_device_ids[1]
+
+print("gather all the IPC handles from the training actors")
+ipc_handles = {}
+for actor in training_actors:
+    ipc_handles.update(ray.get(actor.get_weight_ipc_handles.remote()))
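+# ipc_handles now maps every GPU's device UUID to that GPU's weight
+# handles, so each vLLM worker can look up the entry for its own GPU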
+
+print("update the weights of the inference engines")
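+# collective_rpc invokes the named MyWorker method on every worker
+# of the engine and returns the per-worker results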
+for llm in inference_engines:
+    ray.get(
+        llm.collective_rpc.remote("update_weights_from_ipc_handles",
+                                  args=(ipc_handles, )))
+print("check if the weights are updated")
+for llm in inference_engines:
+    assert ray.get(
+        llm.collective_rpc.remote("check_weights_changed", args=tuple()))