diff --git a/skyrl-train/examples/async/async_trainer.py b/skyrl-train/examples/async/async_trainer.py
index e4b5ae0f24..a65de80b00 100644
--- a/skyrl-train/examples/async/async_trainer.py
+++ b/skyrl-train/examples/async/async_trainer.py
@@ -4,7 +4,8 @@
 from loguru import logger
 from skyrl_train.trainer import RayPPOTrainer
 from tqdm import tqdm
-from skyrl_train.utils import Timer, normalize_advantages_dict
+from skyrl_train.utils import Timer
+from skyrl_train.utils.ppo_utils import normalize_advantages_dict
 from skyrl_train.training_batch import TrainingInputBatch
 from skyrl_train.generators.base import GeneratorOutput
 from skyrl_train.utils.trainer_utils import ResumeMode
diff --git a/skyrl-train/skyrl_train/config/ppo_base_config.yaml b/skyrl-train/skyrl_train/config/ppo_base_config.yaml
index 2fe93746b4..fdddf84757 100644
--- a/skyrl-train/skyrl_train/config/ppo_base_config.yaml
+++ b/skyrl-train/skyrl_train/config/ppo_base_config.yaml
@@ -115,6 +115,7 @@ trainer:
     type: null # filter, replace, or null
     max_sample_batches: 30 # sample at most this many batches before stopping, -1 to sample forever
     min_replace_ratio: 0.3 # minimum proportion of good samples with which to replace bad samples (for replace strategy only)
+  gradient_checkpointing: true
   gradient_checkpointing_use_reentrant: false
@@ -158,6 +159,8 @@ generator:
   num_inference_engines: 1
   backend: "vllm"
   weight_sync_backend: "nccl"
+  # if using cuda_ipc, weights are sent in batches of up to this size (in GB)
+  weight_transfer_threshold_cuda_ipc_GB: 1.0
   inference_engine_tensor_parallel_size: 4
   n_samples_per_prompt: 5
   async_engine: true
diff --git a/skyrl-train/skyrl_train/inference_engines/base.py b/skyrl-train/skyrl_train/inference_engines/base.py
index ff073a79ff..b01d574f70 100644
--- a/skyrl-train/skyrl_train/inference_engines/base.py
+++ b/skyrl-train/skyrl_train/inference_engines/base.py
@@ -20,11 +20,11 @@ class InferenceEngineOutput(TypedDict):
     response_logprobs: Optional[List[List[float]]]


-class NamedWeightUpdateRequest(TypedDict):
-    name: str
-    dtype: str
-    shape: List[int]
-    extras: Optional[Dict[str, Any]]
+class NamedWeightsUpdateRequest(TypedDict):
+    names: List[str]
+    dtypes: List[str]
+    shapes: List[List[int]]
+    extras: Optional[List[Dict[str, Any]]]


 class InferenceEngineInterface(ABC):
@@ -48,7 +48,7 @@ async def init_weight_update_communicator(
         raise NotImplementedError()

     @abstractmethod
-    async def update_named_weight(self, request: NamedWeightUpdateRequest):
+    async def update_named_weights(self, request: NamedWeightsUpdateRequest):
         raise NotImplementedError()

     @abstractmethod
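Note: `NamedWeightsUpdateRequest` now carries parallel lists so that several weights can travel in one request. A minimal sketch of building a batched request (the tensor names, dtypes, and shapes below are illustrative, not taken from this PR):

    from skyrl_train.inference_engines.base import NamedWeightsUpdateRequest

    # Illustrative two-tensor batch; "extras" is either None or one dict per name.
    request: NamedWeightsUpdateRequest = {
        "names": ["model.layers.0.mlp.up_proj.weight", "model.layers.0.mlp.down_proj.weight"],
        "dtypes": ["bfloat16", "bfloat16"],
        "shapes": [[11008, 4096], [4096, 11008]],
        "extras": None,  # e.g. [{"ipc_handles": ...}, {"ipc_handles": ...}] on the CUDA IPC path
    }
    # await inference_engine_client.update_named_weights(request)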
diff --git a/skyrl-train/skyrl_train/inference_engines/inference_engine_client.py b/skyrl-train/skyrl_train/inference_engines/inference_engine_client.py
index 6cccf3f7ed..1d82d4695f 100644
--- a/skyrl-train/skyrl_train/inference_engines/inference_engine_client.py
+++ b/skyrl-train/skyrl_train/inference_engines/inference_engine_client.py
@@ -2,7 +2,7 @@
     InferenceEngineInterface,
     InferenceEngineInput,
     InferenceEngineOutput,
-    NamedWeightUpdateRequest,
+    NamedWeightsUpdateRequest,
 )
 import asyncio
 from typing import List, Any, Optional
@@ -96,9 +96,8 @@ async def _generate_with_trajectory_routing(self, prompts, prompt_token_ids, tra
                 add_resp_logprobs = True
                 response_logprobs[original_idx] = result["response_logprobs"][local_idx]

-        # something went wrong
-        if any([len(response) == 0 for response in responses]) or not all(
-            [isinstance(sample_ids, list) for sample_ids in response_ids]
+        if any([len(response) == 0 for response in responses]) or (
+            add_resp_ids and not all([isinstance(sample_ids, list) for sample_ids in response_ids])
         ):
             raise RuntimeError(
                 "Did not receive responses / response ids for some prompts. This should never happen. There is likely something wrong with the inference engine"
             )
@@ -192,8 +191,8 @@ async def init_weight_update_communicator(
             rank_offset_count += engine.tp_size
         await asyncio.gather(*tasks)

-    async def update_named_weight(self, request: NamedWeightUpdateRequest):
-        return await self._run_on_all_engines("update_named_weight", request=request)
+    async def update_named_weights(self, request: NamedWeightsUpdateRequest):
+        return await self._run_on_all_engines("update_named_weights", request=request)

     async def reset_prefix_cache(self):
         return await self._run_on_all_engines("reset_prefix_cache")
diff --git a/skyrl-train/skyrl_train/inference_engines/ray_wrapped_inference_engine.py b/skyrl-train/skyrl_train/inference_engines/ray_wrapped_inference_engine.py
index 204f4803da..229e506ff9 100644
--- a/skyrl-train/skyrl_train/inference_engines/ray_wrapped_inference_engine.py
+++ b/skyrl-train/skyrl_train/inference_engines/ray_wrapped_inference_engine.py
@@ -8,7 +8,7 @@
     InferenceEngineInterface,
     InferenceEngineInput,
     InferenceEngineOutput,
-    NamedWeightUpdateRequest,
+    NamedWeightsUpdateRequest,
 )
@@ -41,8 +41,8 @@ async def init_weight_update_communicator(
             master_addr, master_port, rank_offset, world_size, group_name, backend, override_existing
         )

-    async def update_named_weight(self, request: NamedWeightUpdateRequest):
-        return await self.inference_engine_actor.update_named_weight.remote(request)
+    async def update_named_weights(self, request: NamedWeightsUpdateRequest):
+        return await self.inference_engine_actor.update_named_weights.remote(request)

     async def teardown(self):
         return await self.inference_engine_actor.teardown.remote()
diff --git a/skyrl-train/skyrl_train/inference_engines/remote_inference_engine.py b/skyrl-train/skyrl_train/inference_engines/remote_inference_engine.py
index 494021a17e..bb1767e68c 100644
--- a/skyrl-train/skyrl_train/inference_engines/remote_inference_engine.py
+++ b/skyrl-train/skyrl_train/inference_engines/remote_inference_engine.py
@@ -3,7 +3,7 @@
     InferenceEngineInterface,
     InferenceEngineInput,
     InferenceEngineOutput,
-    NamedWeightUpdateRequest,
+    NamedWeightsUpdateRequest,
 )
 from typing import List, Optional, Dict, Any
 import json
@@ -103,25 +103,36 @@ async def init_weight_update_communicator(
         ) as response:
             return await response.json()

-    async def update_named_weight(self, request: NamedWeightUpdateRequest):
-        if request.get("extras") and "ipc_handles" in request["extras"]:
+    async def update_named_weights(self, request: NamedWeightsUpdateRequest):
+        if "names" not in request:
+            raise ValueError(f"Expected update weight request with 'names' entry, got keys: {request.keys()}")
+
+        assert (
+            len(request["names"]) == 1
+        ), f"Remote inference engines support only requests with a single named weight at a time, got request with {len(request['names'])} entries"
+
+        if request.get("extras") and "ipc_handles" in request["extras"][0]:
             raise ValueError(
                 "Remote inference engines do not support CUDA IPC weight updates. Only local engines support IPC."
             )

         if self.engine_backend == "vllm":
-            weight_update_method = "update_weight"
+            weight_update_method = "update_weights"
         elif self.engine_backend == "sglang":
             weight_update_method = "update_weights_from_distributed"
         else:
             raise ValueError(f"Invalid engine backend: {self.engine_backend}")

         async with aiohttp.ClientSession() as session:
+            name = request["names"][0]
+            dtype = request["dtypes"][0]
+            shape = request["shapes"][0]
+
             resp = await session.post(
                 f"{self.url}/{weight_update_method}",
                 json={
-                    "name": request["name"],
-                    "dtype": request["dtype"],
-                    "shape": request["shape"],
+                    "name": name,
+                    "dtype": dtype,
+                    "shape": shape,
                 },
             )
             return await resp.json()
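Note: remote engines still accept only one named weight per request, so a caller holding a batched request has to fan it out. A hypothetical helper (not part of this PR) that performs the split:

    from typing import Iterator

    def split_into_single_weight_requests(request) -> Iterator[dict]:
        """Yield one single-name request per entry of a batched request."""
        extras = request.get("extras")
        for i, name in enumerate(request["names"]):
            yield {
                "names": [name],
                "dtypes": [request["dtypes"][i]],
                "shapes": [request["shapes"][i]],
                "extras": [extras[i]] if extras else None,
            }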
diff --git a/skyrl-train/skyrl_train/inference_engines/sglang/sglang_engine.py b/skyrl-train/skyrl_train/inference_engines/sglang/sglang_engine.py
index fb62939915..ce34d427ea 100644
--- a/skyrl-train/skyrl_train/inference_engines/sglang/sglang_engine.py
+++ b/skyrl-train/skyrl_train/inference_engines/sglang/sglang_engine.py
@@ -29,7 +29,7 @@
     InferenceEngineInterface,
     InferenceEngineInput,
     InferenceEngineOutput,
-    NamedWeightUpdateRequest,
+    NamedWeightsUpdateRequest,
 )
 from skyrl_train.utils import torch_dtype_to_str

@@ -104,12 +104,12 @@ def setup_envvars_for_sglang(kwargs, bundle_indices):
         os.environ["CUDA_VISIBLE_DEVICES"] = str(ray.get_gpu_ids()[0])


-def update_weight_cuda_ipc(model, named_tensors):
+def update_weights_cuda_ipc(model, named_tensors):
     """
     Custom weight loader for SGLang that handles IPC handles.

     This function is called by SGLang's model runner to load weights.
-    It reconstructs tensors from SkyRL's NamedWeightUpdateRequest that contains IPC handles
+    It reconstructs tensors from SkyRL's NamedWeightsUpdateRequest that contains IPC handles
     and loads them into the model.
     """
     import torch
@@ -128,36 +128,40 @@ def update_weight_cuda_ipc(model, named_tensors):
     request_data = tensor_bytes[:end_index]
     try:
         request_data_decoded = base64.b64decode(request_data)
-        request = pickle.loads(request_data_decoded)
+        request: NamedWeightsUpdateRequest = pickle.loads(request_data_decoded)
     except Exception as e:
         raise ValueError(f"Failed to deserialize request data: {e}")

-    # Extract the request data
-    ipc_handles = request["extras"]["ipc_handles"]
-    dtype = request["dtype"]
-    _ = request["shape"]
-    weight_name = request["name"]
+    weights_to_load = []
+    for i in range(len(request["names"])):
+        # Extract the request data
+        ipc_handles = request["extras"][i]["ipc_handles"]
+        dtype = request["dtypes"][i]
+        _ = request["shapes"][i]
+        weight_name = request["names"][i]

-    device = torch.cuda.current_device()
-    props = torch.cuda.get_device_properties(device)
-    physical_gpu_id = str(props.uuid)
+        device = torch.cuda.current_device()
+        props = torch.cuda.get_device_properties(device)
+        physical_gpu_id = str(props.uuid)

-    # Infer model dtype and device index from first parameter
-    model_dtype = torch_dtype_to_str(next(model.parameters()).dtype)
-    assert dtype == model_dtype, f"mismatch dtype: src {dtype}, dst {model_dtype}"
-    device_id = next(model.parameters()).device.index
+        # Infer model dtype and device index from first parameter
+        model_dtype = torch_dtype_to_str(next(model.parameters()).dtype)
+        assert dtype == model_dtype, f"mismatch dtype: src {dtype}, dst {model_dtype}"
+        device_id = next(model.parameters()).device.index

-    handle = ipc_handles[physical_gpu_id]
-    func, args = handle
-    list_args = list(args)
-    # the key is to change device id to the current device id
-    # in case two processes have different CUDA_VISIBLE_DEVICES
-    list_args[6] = device_id
-    weight = func(*list_args)
-    model.load_weights([(weight_name, weight)])
+        handle = ipc_handles[physical_gpu_id]
+        func, args = handle
+        list_args = list(args)
+        # the key is to change device id to the current device id
+        # in case two processes have different CUDA_VISIBLE_DEVICES
+        list_args[6] = device_id
+        weight = func(*list_args)
+        weights_to_load.append((weight_name, weight))
+    model.load_weights(weights_to_load)


-CUSTOM_WEIGHT_LOADER_PATH = "skyrl_train.inference_engines.sglang.sglang_engine.update_weight_cuda_ipc"
+
+CUSTOM_WEIGHT_LOADER_PATH = "skyrl_train.inference_engines.sglang.sglang_engine.update_weights_cuda_ipc"


 class SGLangInferenceEngine(InferenceEngineInterface):
@@ -256,12 +260,15 @@ async def init_weight_update_communicator(
         success, message = await self.engine.tokenizer_manager.init_weights_update_group(obj, None)
         return success, message

-    async def update_named_weight(self, request: NamedWeightUpdateRequest) -> Tuple[bool, str]:
+    async def update_named_weights(self, request: NamedWeightsUpdateRequest) -> Tuple[bool, str]:
         """Update named weights in SGLang engine."""
+        if "names" not in request:
+            raise ValueError(f"Expected update weight request with 'names' entry, got keys: {request.keys()}")
+
         extras = request.get("extras")
-        if extras is not None and "ipc_handles" in extras:
+        if extras is not None and "ipc_handles" in extras[0]:
             # CUDA IPC -- Here we reuse SGLang's update_weights_from_tensor, but actually load the
-            # weight from our request data. This will use the update_weight_cuda_ipc defined above.
+            # weights from our request data. This will use the update_weights_cuda_ipc defined above.
             # This is a bit hacky, but the only way as of now, since there is no other way to
             # write per-TP worker code besides using `custom_weight_loader`, unlike in vLLM we can
             # use `WorkerWrap`.
@@ -293,14 +300,19 @@ async def update_named_weight(self, request: NamedWeightUpdateRequest) -> Tuple[
             success, message = await self.engine.tokenizer_manager.update_weights_from_tensor(obj, None)
             return success, message
         else:
+            assert (
+                len(request["names"]) == 1
+            ), f"Updating weights without CUDA IPC supports only a single named weight at a time, got request with {len(request['names'])} entries"
             # Broadcast
             obj = UpdateWeightsFromDistributedReqInput(
-                name=request["name"], dtype=request["dtype"], shape=request["shape"]
+                name=request["names"][0], dtype=request["dtypes"][0], shape=request["shapes"][0]
             )
             # Call the underlying async method for the same reason as in `init_weight_update_communicator`
             success, message = await self.engine.tokenizer_manager.update_weights_from_distributed(obj, None)
-            return success, message
+            if not success:
+                raise RuntimeError(f"Update weight request failed with message: {message}")
+            return

     async def wake_up(self, tags: Optional[List[str]] = None):
         """Wake up the engine.
         For multi-stage waking up, pass in `"weight"` or `"kv_cache"` to tags."""
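Note: the custom loader above receives the whole request through SGLang's named-tensor transport, so the engine side has to serialize it to bytes first. A rough sketch of the encoding that `update_weights_cuda_ipc` expects (transport framing and padding details are elided):

    import base64
    import pickle

    def encode_request_for_sglang(request) -> bytes:
        # The loader recovers the request with base64.b64decode + pickle.loads,
        # so the producer side is simply the inverse pair.
        return base64.b64encode(pickle.dumps(request))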
diff --git a/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py b/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py
index a61393f972..0e873bc4d8 100644
--- a/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py
+++ b/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py
@@ -15,7 +15,7 @@
     InferenceEngineInterface,
     InferenceEngineInput,
     InferenceEngineOutput,
-    NamedWeightUpdateRequest,
+    NamedWeightsUpdateRequest,
 )
 from skyrl_train.utils import str_to_torch_dtype

@@ -95,37 +95,49 @@ def init_weight_update_communicator(
             f"rank={rank}, world_size={world_size}, group_name={group_name}",
         )

-    def update_weight(self, name: str, dtype: str, shape: List[int]):
+    def update_weights(self, names: List[str], dtypes: List[str], shapes: List[List[int]]):
         """Broadcast weight to all vllm workers from source rank 0 (actor model)"""
-        dtype = str_to_torch_dtype(dtype)
-        assert dtype == self.model_config.dtype, f"mismatch dtype: src {dtype}, dst {self.model_config.dtype}"
-        weight = torch.empty(shape, dtype=dtype, device="cuda")
-        torch.distributed.broadcast(weight, 0, group=self._model_update_group)
+        weight_list = []
+        for name, dtype, shape in zip(names, dtypes, shapes):
+            dtype = str_to_torch_dtype(dtype)
+            assert dtype == self.model_config.dtype, f"mismatch dtype: src {dtype}, dst {self.model_config.dtype}"
+            weight = torch.empty(shape, dtype=dtype, device="cuda")
+            torch.distributed.broadcast(weight, 0, group=self._model_update_group)
+            weight_list.append((name, weight))
+
+        self.model_runner.model.load_weights(weights=weight_list)
+        # release the local references so the broadcast buffers can be freed
+        del weight_list
+
+    def update_weights_cuda_ipc(
+        self, names: List[str], dtypes: List[str], shapes: List[List[int]], ipc_handles: List[Dict[str, Any]]
+    ):
-        self.model_runner.model.load_weights(weights=[(name, weight)])
+        weight_list = []
+        for name, dtype, shape, ipc_handle in zip(names, dtypes, shapes, ipc_handles):
-        del weight
+            dtype = str_to_torch_dtype(dtype)
+            device = torch.cuda.current_device()
+            props = torch.cuda.get_device_properties(device)
+            physical_gpu_id = str(props.uuid)
-    def update_weight_cuda_ipc(self, name: str, dtype: str, shape: List[int], ipc_handles: Dict[str, Any]):
+            assert dtype == self.model_config.dtype, f"mismatch dtype: src {dtype}, dst {self.model_config.dtype}"
-        dtype = str_to_torch_dtype(dtype)
-        device = torch.cuda.current_device()
-        props = torch.cuda.get_device_properties(device)
-        physical_gpu_id = str(props.uuid)
+            handle = ipc_handle[physical_gpu_id]
-        assert dtype == self.model_config.dtype, f"mismatch dtype: src {dtype}, dst {self.model_config.dtype}"
+            device_id = self.device.index
+            func, args = handle
+            list_args = list(args)
+            # the key is to change device id to the current device id
+            # in case two processes have different CUDA_VISIBLE_DEVICES
+            list_args[6] = device_id
+            weight = func(*list_args)
+            weight_list.append((name, weight))
-        handle = ipc_handles[physical_gpu_id]
+        self.model_runner.model.load_weights(weights=weight_list)
-        device_id = self.device.index
-        func, args = handle
-        list_args = list(args)
-        # the key is to change device id to the current device id
-        # in case two processes have different CUDA_VISIBLE_DEVICES
-        list_args[6] = device_id
-        weight = func(*list_args)
-        self.model_runner.model.load_weights(weights=[(name, weight)])
-        torch.cuda.synchronize()
+        # release the local references so the IPC-backed tensors can be freed
+        del weight_list

     # TODO (sumanthrh): Add destroy process group RPC as a atexit handler to Trainer code.
     def destroy_weights_update_group(self):
@@ -267,23 +279,32 @@ async def init_weight_update_communicator(
             args=(master_addr, master_port, rank_offset, world_size, group_name, backend, override_existing),
         )

-    async def update_named_weight(self, request: NamedWeightUpdateRequest):
+    async def update_named_weights(self, request: NamedWeightsUpdateRequest):
+        if "names" not in request:
+            raise ValueError(f"Expected update weight request with 'names' entry, got keys: {request.keys()}")
+
+        if not len(request["names"]):
+            raise ValueError("Update weight request should have at least one entry in 'names'")
+
         engine = self._get_engine()
         # Use IPC if handles are provided
-        if request.get("extras") and "ipc_handles" in request["extras"]:
+        if request.get("extras") and "ipc_handles" in request["extras"][0]:
             return await asyncio.to_thread(
                 engine.collective_rpc,
-                "update_weight_cuda_ipc",
+                "update_weights_cuda_ipc",
                 args=(
-                    request["name"],
-                    request["dtype"],
-                    request["shape"],
-                    request["extras"]["ipc_handles"],
+                    request["names"],
+                    request["dtypes"],
+                    request["shapes"],
+                    [extra["ipc_handles"] for extra in request["extras"]],
                 ),
             )
         else:
+            assert (
+                len(request["names"]) == 1
+            ), f"Updating weights without CUDA IPC supports only a single named weight at a time, got request with {len(request['names'])} entries"
             return await asyncio.to_thread(
-                engine.collective_rpc, "update_weight", args=(request["name"], request["dtype"], request["shape"])
+                engine.collective_rpc, "update_weights", args=(request["names"], request["dtypes"], request["shapes"])
             )

     async def teardown(self):
@@ -350,22 +371,39 @@ async def init_weight_update_communicator(
             args=(master_addr, master_port, rank_offset, world_size, group_name, backend, override_existing),
         )

-    async def update_named_weight(self, request: NamedWeightUpdateRequest):
+    async def update_named_weights(self, request: NamedWeightsUpdateRequest):
+        if "names" not in request:
+            raise ValueError(f"Expected update weight request with 'names' entry, got keys: {request.keys()}")
+
+        if not len(request["names"]):
+            raise ValueError("Update weight request should have at least one entry in 'names'")
+
         engine = self._get_engine()
         # Use IPC if handles are provided
-        if request.get("extras") and "ipc_handles" in request["extras"]:
+
+        is_ipc = request.get("extras") and "ipc_handles" in request["extras"][0]
+
+        if is_ipc:
             return await engine.collective_rpc(
-                "update_weight_cuda_ipc",
+                "update_weights_cuda_ipc",
                 args=(
-                    request["name"],
-                    request["dtype"],
-                    request["shape"],
-                    request["extras"]["ipc_handles"],
+                    request["names"],
+                    request["dtypes"],
+                    request["shapes"],
+                    [extra["ipc_handles"] for extra in request["extras"]],
                 ),
             )
         else:
+            assert (
+                len(request["names"]) == 1
+            ), f"Updating weights without CUDA IPC supports only a single named weight at a time, got request with {len(request['names'])} entries"
             return await engine.collective_rpc(
-                "update_weight", args=(request["name"], request["dtype"], request["shape"])
+                "update_weights",
+                args=(
+                    request["names"],
+                    request["dtypes"],
+                    request["shapes"],
+                ),
             )

     async def teardown(self):
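Note: in the non-IPC path, `WorkerWrap.update_weights` above allocates empty buffers and waits on a broadcast from rank 0 of the model-update group. A hedged sketch of the matching sender side (the group handle and parameter iteration are assumptions, not code from this PR):

    import torch

    def broadcast_weights_from_trainer(named_params, group):
        # Assumed sender: rank 0 of the same process group the workers joined
        # (self._model_update_group on the worker side). The engines were already
        # told the names/dtypes/shapes via the RPC and are waiting on broadcasts.
        for _, param in named_params:
            torch.distributed.broadcast(param.data, src=0, group=group)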
diff --git a/skyrl-train/skyrl_train/inference_engines/vllm/vllm_server.py b/skyrl-train/skyrl_train/inference_engines/vllm/vllm_server.py
index 1622d1d295..aeda134be0 100644
--- a/skyrl-train/skyrl_train/inference_engines/vllm/vllm_server.py
+++ b/skyrl-train/skyrl_train/inference_engines/vllm/vllm_server.py
@@ -98,15 +98,16 @@ async def _reset_prefix_cache(request: Request):
         await engine.reset_prefix_cache()
         return {"status": "ok"}

-    @app.post("/update_weight")
-    async def _update_weight(request: Request):
+    @app.post("/update_weights")
+    async def _update_weights(request: Request):
         data = await request.json()
-        name = data.get("name")
-        dtype = data.get("dtype")
-        shape = data.get("shape")
+        # the engine-side RPC now expects parallel lists, so wrap the single entry
+        names = [data.get("name")]
+        dtypes = [data.get("dtype")]
+        shapes = [data.get("shape")]
         await engine.collective_rpc(
-            "update_weight",
-            args=(name, dtype, shape),
+            "update_weights",
+            args=(names, dtypes, shapes),
         )
         return {"status": "ok"}
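Note: external callers reach the renamed route with the same single-weight JSON payload as before. An illustrative request against a locally running server (host, port, and tensor metadata are made up):

    import aiohttp

    async def push_one_weight():
        async with aiohttp.ClientSession() as session:
            resp = await session.post(
                "http://localhost:8000/update_weights",  # illustrative host/port
                json={"name": "lm_head.weight", "dtype": "bfloat16", "shape": [32000, 4096]},
            )
            print(await resp.json())  # expected: {"status": "ok"}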
{"status": "ok"} - @app.post("/update_weight") - async def _update_weight(request: Request): + @app.post("/update_weights") + async def _update_weights(request: Request): data = await request.json() - name = data.get("name") - dtype = data.get("dtype") - shape = data.get("shape") + # engine expects a list of objects + names = [data.get("name")] + dtypes = [data.get("dtype")] + shapes = [data.get("shape")] await engine.collective_rpc( - "update_weight", - args=(name, dtype, shape), + "update_weights", + args=(names, dtypes, shapes), ) return {"status": "ok"} diff --git a/skyrl-train/skyrl_train/trainer.py b/skyrl-train/skyrl_train/trainer.py index 77bb2b379a..283ed0472a 100644 --- a/skyrl-train/skyrl_train/trainer.py +++ b/skyrl-train/skyrl_train/trainer.py @@ -27,17 +27,16 @@ from skyrl_train.dataset.preprocess import ( convert_prompts_responses_to_batch_tensors, ) -from skyrl_train import utils +from skyrl_train.utils import ppo_utils from skyrl_train.utils import trainer_utils -from skyrl_train.utils import ( - Timer, +from skyrl_train.utils import Timer, get_ray_pg_ready_with_timeout +from skyrl_train.utils.ppo_utils import ( compute_approx_kl, masked_mean, - normalize_advantages_dict, - get_ray_pg_ready_with_timeout, get_kl_controller, FixedKLController, AdaptiveKLController, + normalize_advantages_dict, ) from skyrl_train.distributed.dispatch import MeshRank, concatenate_outputs_after_mesh_dispatch, ActorInfo from skyrl_train.workers.worker import PPORayActorGroup @@ -796,7 +795,7 @@ def compute_advantages_and_returns(self, data: TrainingInputBatch) -> TrainingIn # TODO (erictang000): we are just supporting custom rewards for now token_level_rewards = data["custom_rewards"] - advantages, returns = utils.compute_advantages_and_returns( + advantages, returns = ppo_utils.compute_advantages_and_returns( token_level_rewards=token_level_rewards, response_mask=data["response_mask"], index=data.metadata["uids"], diff --git a/skyrl-train/skyrl_train/utils/__init__.py b/skyrl-train/skyrl_train/utils/__init__.py index 4ecc09f4c4..3be533b78f 100644 --- a/skyrl-train/skyrl_train/utils/__init__.py +++ b/skyrl-train/skyrl_train/utils/__init__.py @@ -1,2 +1 @@ from .utils import * # noqa: F401, F403 -from .ppo_utils import * # noqa: F401, F403 diff --git a/skyrl-train/skyrl_train/utils/utils.py b/skyrl-train/skyrl_train/utils/utils.py index d68a3e70fb..57b74c062f 100644 --- a/skyrl-train/skyrl_train/utils/utils.py +++ b/skyrl-train/skyrl_train/utils/utils.py @@ -6,11 +6,6 @@ from loguru import logger from omegaconf import DictConfig, OmegaConf from ray.util.placement_group import placement_group, PlacementGroupSchedulingStrategy, PlacementGroup -from skyrl_train.utils.ppo_utils import ( - AdvantageEstimatorRegistry, - PolicyLossRegistry, - sync_registries, -) class Timer: @@ -120,6 +115,8 @@ def validate_batch_sizes(cfg: DictConfig): def validate_cfg(cfg: DictConfig): + from .ppo_utils import AdvantageEstimatorRegistry, PolicyLossRegistry + if cfg.generator.max_turns == 1: assert ( cfg.generator.max_input_length == cfg.trainer.max_prompt_length @@ -397,6 +394,10 @@ def initialize_ray(cfg: DictConfig): Args: cfg: Training config """ + from .ppo_utils import ( + sync_registries, + ) + env_vars = prepare_runtime_environment(cfg) ray.init(runtime_env={"env_vars": env_vars}) diff --git a/skyrl-train/skyrl_train/workers/deepspeed/deepspeed_worker.py b/skyrl-train/skyrl_train/workers/deepspeed/deepspeed_worker.py index ca823683ca..a78429b18a 100644 --- 
diff --git a/skyrl-train/skyrl_train/workers/deepspeed/deepspeed_worker.py b/skyrl-train/skyrl_train/workers/deepspeed/deepspeed_worker.py
index ca823683ca..a78429b18a 100644
--- a/skyrl-train/skyrl_train/workers/deepspeed/deepspeed_worker.py
+++ b/skyrl-train/skyrl_train/workers/deepspeed/deepspeed_worker.py
@@ -123,22 +123,22 @@ async def broadcast_to_inference_engines(self, inference_engine_client):
         torch.cuda.empty_cache()
         model = self.model.model.module

-        for name, param in model.named_parameters():
-            # broadcast
-            if not self.use_cuda_ipc:
+        if not self.use_cuda_ipc:
+            for name, param in model.named_parameters():
                 if torch.distributed.get_rank() == 0:
                     shape = param.shape if self.zero_stage != 3 else param.ds_shape
                     update_weight_task = asyncio.create_task(
-                        inference_engine_client.update_named_weight(
+                        inference_engine_client.update_named_weights(
                             {
-                                "name": name,
-                                "dtype": self.cfg.generator.model_dtype,
-                                "shape": shape,
+                                "names": [name],
+                                "dtypes": [self.cfg.generator.model_dtype],
+                                "shapes": [shape],
                             }
                         )
                     )

+                # broadcast
                 def gather_and_broadcast(param):
                     # For ZeRO-3, allgather sharded parameter and broadcast to all InferenceEngines by rank 0
                     with deepspeed.zero.GatheredParameters([param], enabled=self.zero_stage == 3):
@@ -149,11 +149,15 @@ def gather_and_broadcast(param):
                 await asyncio.to_thread(gather_and_broadcast, param)
                 if torch.distributed.get_rank() == 0:
                     await update_weight_task
+            torch.distributed.barrier()
+        # CUDA IPC
+        else:
+            from torch.multiprocessing.reductions import reduce_tensor

-        # CUDA IPC
-        else:
-            from torch.multiprocessing.reductions import reduce_tensor
+            weights_update_request = {"names": [], "dtypes": [], "shapes": [], "extras": []}
+            current_size = 0

+            for name, param in model.named_parameters():
                 # For ZeRO-3, allgather sharded parameter and broadcast to all InferenceEngines by rank 0
                 with deepspeed.zero.GatheredParameters([param], enabled=self.zero_stage == 3):
                     weight = param.data.clone()
@@ -171,22 +175,34 @@ def gather_and_broadcast(param):

                     shape = param.shape if self.zero_stage != 3 else param.ds_shape

-                    await asyncio.create_task(
-                        inference_engine_client.update_named_weight(
-                            {
-                                "name": name,
-                                "dtype": self.cfg.generator.model_dtype,
-                                "shape": shape,
-                                "extras": {
-                                    "ipc_handles": ipc_handles,
-                                },
-                            }
-                        )
+                    weights_update_request["names"].append(name)
+                    weights_update_request["dtypes"].append(self.cfg.generator.model_dtype)
+                    weights_update_request["shapes"].append(shape)
+                    weights_update_request["extras"].append(
+                        {
+                            "ipc_handles": ipc_handles,
+                        }
                     )

+                    current_size += weight.nbytes
+                    # We send in batches as an optimization
+                    # sync if the size threshold is reached
+                    if current_size / (1024**3) > self.cfg.generator.weight_transfer_threshold_cuda_ipc_GB:
+                        await inference_engine_client.update_named_weights(weights_update_request)
+                        current_size = 0
+                        weights_update_request = {"names": [], "dtypes": [], "shapes": [], "extras": []}
+                        # force-collect any sent tensors if possible to be memory efficient
+                        torch.cuda.ipc_collect()

         torch.distributed.barrier()
         torch.cuda.synchronize()

+        # sync any remaining weights
+        if self.use_cuda_ipc and torch.distributed.get_rank() == 0 and len(weights_update_request["names"]) > 0:
+            await asyncio.create_task(inference_engine_client.update_named_weights(weights_update_request))
+            current_size = 0
+            weights_update_request = {"names": [], "dtypes": [], "shapes": [], "extras": []}
+        torch.distributed.barrier()
+
         if cache_reset_task is not None:
             await cache_reset_task
         torch.cuda.empty_cache()
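Note: the batching policy in both workers is plain size-based chunking against `generator.weight_transfer_threshold_cuda_ipc_GB`: accumulate per-tensor metadata until the running byte count crosses the threshold, flush, and reset. A stripped-down sketch of the same policy with the engine plumbing stubbed out (helper name is hypothetical):

    def send_in_size_batches(named_tensors, threshold_gb, flush):
        """Group (name, tensor) pairs into flushes of roughly threshold_gb each."""
        batch, nbytes = [], 0
        for name, tensor in named_tensors:
            batch.append(name)
            nbytes += tensor.nbytes
            if nbytes / (1024**3) > threshold_gb:
                flush(batch)  # one update_named_weights() call per batch
                batch, nbytes = [], 0
        if batch:  # remainder, mirroring the workers' final flush
            flush(batch)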
diff --git a/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py b/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py
index 0661dc82df..a05e209a55 100644
--- a/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py
+++ b/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py
@@ -107,22 +107,22 @@ async def broadcast_to_inference_engines(self, inference_engine_client):
         )
         params = self.model.model.state_dict()

-        for name, param in params.items():
-            # broadcast
-            if not self.use_cuda_ipc:
+        if not self.use_cuda_ipc:
+            for name, param in params.items():
                 if torch.distributed.get_rank() == 0:
                     shape = param.shape
                     update_weight_task = asyncio.create_task(
-                        inference_engine_client.update_named_weight(
+                        inference_engine_client.update_named_weights(
                             {
-                                "name": name,
-                                "dtype": self.cfg.generator.model_dtype,
-                                "shape": shape,
+                                "names": [name],
+                                "dtypes": [self.cfg.generator.model_dtype],
+                                "shapes": [shape],
                             }
                         )
                     )

+                # broadcast
                 def gather_and_broadcast(param):
                     # For FSDP, gather parameter and broadcast to all InferenceEngines by rank 0
                     device = torch.cuda.current_device()
@@ -135,19 +135,22 @@ def gather_and_broadcast(param):
                 await asyncio.to_thread(gather_and_broadcast, param)
                 if torch.distributed.get_rank() == 0:
                     await update_weight_task
+            torch.distributed.barrier()
+        # CUDA IPC
+        else:
+            weights_update_request = {"names": [], "dtypes": [], "shapes": [], "extras": []}
+            current_size = 0

-        # CUDA IPC
-        else:
+            for name, param in params.items():
                 from torch.multiprocessing.reductions import reduce_tensor

                 device = torch.cuda.current_device()
                 param = param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param
                 param = param.to(generator_dtype)
-                weight = param.data.clone()
+                weight = param.detach().contiguous()
                 ipc_handle = reduce_tensor(weight)

                 ipc_handle = {get_physical_gpu_id(): ipc_handle}
-
                 ipc_handle_list = [None] * torch.distributed.get_world_size()
                 torch.distributed.all_gather_object(ipc_handle_list, ipc_handle)

@@ -156,24 +159,31 @@ def gather_and_broadcast(param):
                     for d in ipc_handle_list:
                         ipc_handles.update(d)

-                    shape = param.shape
-
-                    await asyncio.create_task(
-                        inference_engine_client.update_named_weight(
-                            {
-                                "name": name,
-                                "dtype": self.cfg.generator.model_dtype,
-                                "shape": shape,
-                                "extras": {
-                                    "ipc_handles": ipc_handles,
-                                },
-                            }
-                        )
-                    )
-
+                    current_size += weight.nbytes
+                    weights_update_request["names"].append(name)
+                    weights_update_request["dtypes"].append(self.cfg.generator.model_dtype)
+                    weights_update_request["shapes"].append(param.shape)
+                    weights_update_request["extras"].append({"ipc_handles": ipc_handles})

+                    # We send in batches as an optimization
+                    # sync if the size threshold is reached
+                    if current_size / (1024**3) > self.cfg.generator.weight_transfer_threshold_cuda_ipc_GB:
+                        await inference_engine_client.update_named_weights(weights_update_request)
+
+                        current_size = 0
+                        weights_update_request = {"names": [], "dtypes": [], "shapes": [], "extras": []}
+                        # force-collect any sent tensors if possible to be memory efficient
+                        torch.cuda.ipc_collect()

         torch.distributed.barrier()
         torch.cuda.synchronize()

+        # sync any remaining weights
+        if self.use_cuda_ipc and len(weights_update_request["names"]) > 0 and torch.distributed.get_rank() == 0:
+            await asyncio.create_task(inference_engine_client.update_named_weights(weights_update_request))
+            current_size = 0
+            weights_update_request = {"names": [], "dtypes": [], "shapes": [], "extras": []}
+        torch.distributed.barrier()
+        torch.cuda.synchronize()
+
         if cache_reset_task is not None:
             await cache_reset_task
         torch.cuda.empty_cache()
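Note: both workers obtain CUDA IPC handles via `torch.multiprocessing.reductions.reduce_tensor`, which returns a picklable `(rebuild_func, args)` pair that a different process on the same machine can invoke to map the same device memory. A sketch of the intended round trip (the consumer side is shown as comments, since an IPC handle cannot be reopened in the exporting process):

    import torch
    from torch.multiprocessing.reductions import reduce_tensor

    weight = torch.ones(4, device="cuda")
    rebuild_func, args = reduce_tensor(weight)  # picklable (func, args) pair

    # In the consuming process (after the handle arrives, e.g. via pickle):
    # list_args = list(args)
    # list_args[6] = local_device_index  # remap if CUDA_VISIBLE_DEVICES differs
    # view = rebuild_func(*list_args)    # maps the producer's CUDA allocation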
diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py
index 50e6ac82f0..c34eb741d8 100644
--- a/skyrl-train/skyrl_train/workers/worker.py
+++ b/skyrl-train/skyrl_train/workers/worker.py
@@ -16,7 +16,8 @@
 from ray import ObjectRef
 from ray.util.placement_group import PlacementGroup, PlacementGroupSchedulingStrategy, placement_group

-from skyrl_train.utils import masked_mean, ray_noset_visible_devices, get_ray_pg_ready_with_timeout
+from skyrl_train.utils import ray_noset_visible_devices, get_ray_pg_ready_with_timeout
+from skyrl_train.utils.ppo_utils import masked_mean
 from skyrl_train.distributed.dispatch import MeshRank, ActorInfo, DispatchRegistry, Dispatch
 from skyrl_train.distributed.strategy import DistributedStrategy
 from transformers import PreTrainedModel
diff --git a/skyrl-train/tests/cpu/algorithms/test_losses.py b/skyrl-train/tests/cpu/algorithms/test_losses.py
index 5f2b8eda56..a51c7954df 100644
--- a/skyrl-train/tests/cpu/algorithms/test_losses.py
+++ b/skyrl-train/tests/cpu/algorithms/test_losses.py
@@ -7,8 +7,7 @@
 import pytest
 import torch
 from omegaconf import DictConfig
-from skyrl_train.utils.ppo_utils import PolicyLossRegistry
-from skyrl_train.utils import masked_mean
+from skyrl_train.utils.ppo_utils import PolicyLossRegistry, masked_mean


 # Adapted a good test from NeMO-RL
diff --git a/skyrl-train/tests/cpu/test_trainer.py b/skyrl-train/tests/cpu/test_trainer.py
index 87e7976d45..3b4047629d 100644
--- a/skyrl-train/tests/cpu/test_trainer.py
+++ b/skyrl-train/tests/cpu/test_trainer.py
@@ -179,7 +179,7 @@ def test_calculate_kl_create_experience_batched(dummy_config):
     assert metrics["avg_kl"] == approx(0.1249, abs=1e-4)


-@patch("skyrl_train.utils.compute_advantages_and_returns", new_callable=MagicMock)
+@patch("skyrl_train.utils.ppo_utils.compute_advantages_and_returns", new_callable=MagicMock)
 def test_calc_advantages_and_returns(mock_compute_adv_and_ret, dummy_config):
     trainer = RayPPOTrainer(
         cfg=dummy_config,
diff --git a/skyrl-train/tests/gpu/test_grpo_sp_sanity.py b/skyrl-train/tests/gpu/test_grpo_sp_sanity.py
index 175513e337..d5c9a3154b 100644
--- a/skyrl-train/tests/gpu/test_grpo_sp_sanity.py
+++ b/skyrl-train/tests/gpu/test_grpo_sp_sanity.py
@@ -12,7 +12,8 @@
 from skyrl_train.trainer import RayPPOTrainer
 import ray
 from tqdm import tqdm
-from skyrl_train.utils import Timer, normalize_advantages_dict
+from skyrl_train.utils import Timer
+from skyrl_train.utils.ppo_utils import normalize_advantages_dict
 import asyncio