diff --git a/skyrl-train/skyrl_train/trainer.py b/skyrl-train/skyrl_train/trainer.py index 6a9695ebd..e570e6df4 100644 --- a/skyrl-train/skyrl_train/trainer.py +++ b/skyrl-train/skyrl_train/trainer.py @@ -1,65 +1,77 @@ +import copy import math import os import shutil -from typing import Any, List, Optional, Dict, Tuple, Union -from jaxtyping import Float +from collections import defaultdict from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np import ray -from ray import ObjectRef import torch +from jaxtyping import Float from loguru import logger from omegaconf import DictConfig +from ray import ObjectRef from ray.util.placement_group import PlacementGroup, placement_group from tqdm import tqdm from transformers import AutoTokenizer -import numpy as np -from collections import defaultdict from skyrl_train.dataset import PromptDataset -from skyrl_train.utils.tracking import Tracking -from skyrl_train.training_batch import TrainingInputBatch, TrainingOutputBatch +from skyrl_train.dataset.preprocess import ( + convert_prompts_responses_to_batch_tensors, +) +from skyrl_train.distributed.dispatch import ( + ActorInfo, + MeshRank, +) +from skyrl_train.evaluate import evaluate, evaluate_step_wise from skyrl_train.generators.base import ( GeneratorInput, - GeneratorOutput, GeneratorInterface, + GeneratorOutput, ) -import copy -from skyrl_train.generators.utils import get_metrics_from_generator_output, prepare_generator_input -from skyrl_train.dataset.preprocess import ( - convert_prompts_responses_to_batch_tensors, +from skyrl_train.generators.utils import ( + get_metrics_from_generator_output, + prepare_generator_input, +) +from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient +from skyrl_train.inference_engines.utils import get_sampling_params_for_backend +from skyrl_train.training_batch import TrainingInputBatch +from skyrl_train.utils import ( + Timer, + get_ray_pg_ready_with_timeout, + ppo_utils, + trainer_utils, ) -from skyrl_train.utils import ppo_utils, trainer_utils -from skyrl_train.utils.io import io -from skyrl_train.utils import Timer, get_ray_pg_ready_with_timeout from skyrl_train.utils.constants import SKYRL_RAY_PG_TIMEOUT_IN_S +from skyrl_train.utils.io import io +from skyrl_train.utils.logging_utils import log_example from skyrl_train.utils.ppo_utils import ( + AdaptiveKLController, + FixedKLController, compute_approx_kl, - masked_mean, get_kl_controller, - FixedKLController, - AdaptiveKLController, + masked_mean, normalize_advantages_dict, ) -from skyrl_train.distributed.dispatch import MeshRank, concatenate_outputs_after_mesh_dispatch, ActorInfo -from skyrl_train.workers.worker import PPORayActorGroup -from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient -from skyrl_train.inference_engines.utils import get_sampling_params_for_backend +from skyrl_train.utils.tracking import Tracking from skyrl_train.utils.trainer_utils import ( + GLOBAL_STEP_PREFIX, + DynamicSamplingState, + ResumeMode, + build_dataloader, cleanup_old_checkpoints, - run_on_each_node, - get_node_ids, extract_step_from_path, + run_on_each_node, validate_consistency_for_latest_checkpoint, validate_generator_output, - GLOBAL_STEP_PREFIX, - ResumeMode, - DynamicSamplingState, - build_dataloader, zero_variance_filter, ) from skyrl_train.utils.utils import configure_ray_worker_logging -from skyrl_train.evaluate import evaluate, evaluate_step_wise -from skyrl_train.utils.logging_utils import 
log_example +from skyrl_train.workers.worker import PPORayActorGroup +from skyrl_train.workers.worker_dispatch import WorkerDispatch +from skyrl_train.workers.worker_utils import reduce_metrics class RayPPOTrainer: @@ -107,8 +119,14 @@ def __init__( self.dynamic_sampling_state: Optional[DynamicSamplingState] = None self.reward_kl_controller: Optional[Union[FixedKLController, AdaptiveKLController]] = None + self.dispatch: WorkerDispatch = None configure_ray_worker_logging() + @property + def has_critic(self) -> bool: + """Check if critic model is configured.""" + return self.cfg.trainer.critic.model.path is not None + def _build_train_dataloader_and_compute_training_steps(self): """ Hook for constructing the training dataloader. Subclasses can override @@ -156,23 +174,18 @@ async def train(self): with Timer("init_weight_sync_state"): self.init_weight_sync_state() - # Load policy model to GPU before loading checkpoint. - if self.colocate_all: - self.policy_model.backload_to_gpu() - # Load checkpoint state if resumption is enabled. if self.resume_mode != ResumeMode.NONE: with Timer("load_checkpoints"): self.global_step, _ = self.load_checkpoints() + self.dispatch.prepare_for_weight_sync() if self.colocate_all: - self.policy_model.offload_to_cpu(offload_optimizer=True, offload_model=False) await self.inference_engine_client.wake_up(tags=["weights"]) with Timer("sync_weights"): - ray.get(self.sync_policy_weights_to_inference_engines()) + self.dispatch.broadcast_to_inference_engines(self.inference_engine_client) + self.dispatch.finish_weight_sync() if self.colocate_all: - with Timer("offload_policy_model_to_cpu"): - self.policy_model.offload_to_cpu(offload_optimizer=False, offload_model=True) await self.inference_engine_client.wake_up(tags=["kv_cache"]) # Eval before training @@ -295,14 +308,13 @@ async def train(self): self.update_ref_with_policy() # 7. sync weights to inference engines + self.dispatch.prepare_for_weight_sync() if self.colocate_all: - self.policy_model.offload_to_cpu(offload_optimizer=True, offload_model=False) await self.inference_engine_client.wake_up(tags=["weights"]) with Timer("sync_weights", self.all_timings): - ray.get(self.sync_policy_weights_to_inference_engines()) + self.dispatch.broadcast_to_inference_engines(self.inference_engine_client) + self.dispatch.finish_weight_sync() if self.colocate_all: - with Timer("offload_policy_model_to_cpu"): - self.policy_model.offload_to_cpu(offload_optimizer=False, offload_model=True) await self.inference_engine_client.wake_up(tags=["kv_cache"]) # 8. set logs @@ -335,7 +347,6 @@ async def train(self): pbar.close() if self.colocate_all: await self.inference_engine_client.sleep() - self.policy_model.backload_to_gpu() if self.cfg.trainer.ckpt_interval > 0: with Timer("save_checkpoints", self.all_timings): self.save_checkpoints() @@ -353,11 +364,7 @@ def _remove_tail_data(self, entries: List[Any]) -> List[Any]: training we care that the total number of samples is nicely splittable across the (combined) data-parallel size of all enabled models. 
""" - lcm_dp_size = self.policy_model.actor_infos[0].rank.dp_size - if self.critic_model is not None: - lcm_dp_size = math.lcm(lcm_dp_size, self.critic_model.actor_infos[0].rank.dp_size) - if self.ref_model is not None: - lcm_dp_size = math.lcm(lcm_dp_size, self.ref_model.actor_infos[0].rank.dp_size) + lcm_dp_size = self.dispatch.get_lcm_dp_size() n_samples_per_prompt = self.cfg.generator.n_samples_per_prompt @@ -549,6 +556,18 @@ def build_models(self, PolicyWorker, CriticWorker, RefWorker): self.critic_model: Optional[PPORayActorGroup] = critic_model self.ref_model: Optional[PPORayActorGroup] = ref_model + # Create unified dispatch that manages all actor groups + self.dispatch = WorkerDispatch( + cfg=self.cfg, + policy_actor_group=policy_model, + critic_actor_group=critic_model, + ref_actor_group=ref_model, + ) + + # Mark all models as offloaded if colocate_all (they were offloaded above) + if self.colocate_all: + self.dispatch.mark_all_offloaded() + logger.info("init policy/ref/critic models done") def init_weight_sync_state(self): @@ -841,12 +860,7 @@ def pad_batch(self, training_input: TrainingInputBatch) -> TrainingInputBatch: """Pad the batch to be divisible by dp size""" import math - dp_size = self.policy_model.actor_infos[0].rank.dp_size - if self.critic_model is not None: - dp_size = math.lcm(dp_size, self.critic_model.actor_infos[0].rank.dp_size) - if self.ref_model is not None: - dp_size = math.lcm(dp_size, self.ref_model.actor_infos[0].rank.dp_size) - + dp_size = self.dispatch.get_lcm_dp_size() pad_size = math.ceil(training_input.batch_size / dp_size) * dp_size - training_input.batch_size new_tensors = {} training_input.metadata["pad_size"] = pad_size @@ -885,8 +899,9 @@ def fwd_logprobs_values_reward( training_input: TrainingInputBatch, ): """ - Calculate values from the critic, log probs from the policy and ref model, and rewards from the reward model - and then calculate the kl divergence between the action log probs and the base action log probs. + Calculate values from the critic, log probs from the policy and ref model. + + Dispatch handles offload/backload automatically for all colocation configurations. 
Expects: - `["sequences"]`: Integer[torch.Tensor, "batch_size seqlen"] @@ -900,81 +915,27 @@ def fwd_logprobs_values_reward( """ data_fwd_pass = training_input.select(keys=["sequences", "attention_mask"], metadata_keys=["response_length"]) - def collect_results(actor_infos, results, key): - ret_outputs: TrainingOutputBatch = concatenate_outputs_after_mesh_dispatch(actor_infos, results) - return ret_outputs[key] - + values = None base_log_probs = None action_log_probs = None - values = None - - # calculate critic values - if self.colocate_all and self.critic_model is not None: - self.critic_model.backload_to_gpu(backload_optimizer=False, backload_model=True) - if self.critic_model is not None: - value_refs = self.critic_model.async_run_ray_method("mesh", "forward", data=data_fwd_pass) - if self.colocate_all: - all_rank_values = ray.get(value_refs) - values = collect_results(self.critic_model.actor_infos, all_rank_values, key="output") - self.critic_model.offload_to_cpu(offload_optimizer=False, offload_model=True) + # Critic forward (dispatch handles offload/backload automatically) + if self.has_critic: + critic_output = self.dispatch.forward("critic", data_fwd_pass) + values = critic_output["output"] - # calculate ref log probs + # Ref forward if self.ref_model is not None: - if self.cfg.trainer.placement.colocate_policy_ref or self.colocate_all: - self.ref_model.backload_to_gpu() + ref_output = self.dispatch.forward("ref", data_fwd_pass) + base_log_probs = ref_output["output"] + self.dispatch.empty_cache("ref") - base_action_log_probs_refs = self.ref_model.async_run_ray_method("mesh", "forward", data=data_fwd_pass) + # Policy forward + policy_output = self.dispatch.forward("policy", data_fwd_pass) + action_log_probs = policy_output["output"] - if self.ref_model is not None: - # handle colocate policy and ref model - if self.cfg.trainer.placement.colocate_policy_ref or self.colocate_all: - all_rank_base_log_probs: List[TrainingOutputBatch] = ray.get(base_action_log_probs_refs) - base_log_probs = collect_results(self.ref_model.actor_infos, all_rank_base_log_probs, key="output") - self.ref_model.offload_to_cpu() - ray.get(self.ref_model.async_run_ray_method("pass_through", "empty_cache")) - else: - base_log_probs = None - - # calculate action log probs - if self.colocate_all: - self.policy_model.backload_to_gpu(backload_optimizer=False, backload_model=True) - - action_log_probs_refs = self.policy_model.async_run_ray_method("mesh", "forward", data=data_fwd_pass) - if self.colocate_all: - all_rank_action_log_probs: List[TrainingOutputBatch] = ray.get(action_log_probs_refs) - action_log_probs = collect_results(self.policy_model.actor_infos, all_rank_action_log_probs, key="output") - self.policy_model.offload_to_cpu(offload_optimizer=False, offload_model=True) - - # wait all models done - # if not colocate_policy_ref, then need to gather base_log_probs - # if self.critic_model is not None, then need to gather value - if not self.colocate_all: - if not self.cfg.trainer.placement.colocate_policy_ref: - if self.critic_model is not None: - all_rank_values = ray.get(value_refs) - values = collect_results(self.critic_model.actor_infos, all_rank_values, key="output") - - if self.ref_model is not None: - all_rank_base_log_probs: List[TrainingOutputBatch] = ray.get(base_action_log_probs_refs) - base_log_probs = collect_results(self.ref_model.actor_infos, all_rank_base_log_probs, key="output") - else: - base_log_probs = None - - elif self.critic_model is not None: - all_rank_values = ray.get(value_refs) - 
values = collect_results(self.critic_model.actor_infos, all_rank_values, key="output") - - all_rank_action_log_probs: List[TrainingOutputBatch] = ray.get(action_log_probs_refs) - action_log_probs = collect_results(self.policy_model.actor_infos, all_rank_action_log_probs, key="output") - - if not self.colocate_all: - empty_cache_refs = self.policy_model.async_run_ray_method("pass_through", "empty_cache") - if self.ref_model is not None: - empty_cache_refs.extend(self.ref_model.async_run_ray_method("pass_through", "empty_cache")) - if self.critic_model is not None: - empty_cache_refs.extend(self.critic_model.async_run_ray_method("pass_through", "empty_cache")) - ray.get(empty_cache_refs) + # Empty cache after all forward passes + self.dispatch.empty_cache() sequences_all: torch.Tensor = training_input["sequences"] # NOTE (sumanthrh): The slicing is needed to make sure that the batch dimension doesn't change for the tensordict. @@ -1065,43 +1026,84 @@ def sync_policy_weights_to_inference_engines(self) -> List[ObjectRef]: "pass_through", "broadcast_to_inference_engines", self.inference_engine_client ) + def _execute_training_step(self, model: str, data: TrainingInputBatch) -> Dict[str, float]: + """ + Execute training step for FSDP strategy using forward_backward + optim_step. + + The trainer loops over epochs and mini-batches. Workers handle micro-batching + internally for gradient accumulation (memory efficiency). + + Args: + model: Model name ("policy" or "critic") + data: Training data batch + + Returns: + Dict of reduced metrics from training + """ + # Compute mini batch size from config (algorithm-level concept) + n_samples = self.cfg.generator.n_samples_per_prompt + if model == "policy": + mini_batch_size = self.cfg.trainer.policy_mini_batch_size * n_samples + else: + mini_batch_size = self.cfg.trainer.critic_mini_batch_size * n_samples + + all_metrics: Dict[str, List[float]] = defaultdict(list) + + # Training loop over epochs and mini-batches + for _epoch in range(self.cfg.trainer.update_epochs_per_batch): + num_mini_batches = len(data) // mini_batch_size + for local_step in range(num_mini_batches): + start_idx = local_step * mini_batch_size + end_idx = (local_step + 1) * mini_batch_size + mini_batch = data[start_idx:end_idx] + + status = self.dispatch.forward_backward(model, mini_batch) + for k, v in status.items(): + all_metrics[k].append(v) + + # Optimizer step after each mini batch + grad_norm = self.dispatch.optim_step(model) + if grad_norm is not None: + all_metrics["grad_norm"].append(grad_norm) + + # Reduce metrics across all mini-batches and epochs + reduced_metrics = reduce_metrics(all_metrics) + return reduced_metrics + def train_critic_and_policy(self, data: TrainingInputBatch): """ - Run the training step for the policy and critic models (this is overlapped if colocate_all is False). + Run the training step for the policy and critic models. 
+ + For Megatron strategy: uses ppo_train (training loop inside worker) + For FSDP strategy: uses forward_backward + optim_step (training loop in trainer) """ data.metadata["global_step"] = self.global_step - if self.colocate_all: - if self.critic_model is not None: + critic_status = None + + if self.cfg.trainer.strategy == "megatron": + # Megatron: training loop inside worker via ppo_train + if self.has_critic: with Timer("critic_train", self.all_timings): - self.critic_model.backload_to_gpu() - critic_statuses = ray.get(self.critic_model.async_run_ray_method("mesh", "ppo_train", data)) - self.critic_model.offload_to_cpu() + critic_status = self.dispatch.ppo_train("critic", data) with Timer("policy_train", self.all_timings): - self.policy_model.backload_to_gpu() - policy_statuses = ray.get(self.policy_model.async_run_ray_method("mesh", "ppo_train", data)) + policy_status = self.dispatch.ppo_train("policy", data) else: - if self.critic_model is not None: - with Timer("policy_critic_overlap_train", self.all_timings): - policy_refs = self.policy_model.async_run_ray_method("mesh", "ppo_train", data) - critic_refs = self.critic_model.async_run_ray_method("mesh", "ppo_train", data) - policy_statuses = ray.get(policy_refs) - critic_statuses = ray.get(critic_refs) - else: - with Timer("policy_train", self.all_timings): - policy_statuses = ray.get(self.policy_model.async_run_ray_method("mesh", "ppo_train", data)) + # FSDP: training loop in trainer via forward_backward + optim_step + if self.has_critic: + with Timer("critic_train", self.all_timings): + critic_status = self._execute_training_step("critic", data) + with Timer("policy_train", self.all_timings): + policy_status = self._execute_training_step("policy", data) - empty_cache_refs = [] - if self.critic_model is not None: - critic_status = critic_statuses[0].metadata["train_status"] + # Update metrics + if critic_status is not None: for k, v in critic_status.items(): self.all_metrics.update({f"critic/{k}": v}) - empty_cache_refs += self.critic_model.async_run_ray_method("pass_through", "empty_cache") - policy_status = policy_statuses[0].metadata["train_status"] for k, v in policy_status.items(): self.all_metrics.update({f"policy/{k}": v}) - empty_cache_refs += self.policy_model.async_run_ray_method("pass_through", "empty_cache") - ray.get(empty_cache_refs) + + self.dispatch.empty_cache() return policy_status @@ -1180,7 +1182,7 @@ def save_checkpoints(self): """ Save the model, optimizer, and training states to disk. - If colocate_all is True, assumes that the policy model is currently on GPU. + Dispatch handles offload/backload automatically for all colocation configurations. 
""" # Create global step folder structure global_step_folder = os.path.join(self.cfg.trainer.ckpt_path, f"global_step_{self.global_step}") @@ -1189,34 +1191,12 @@ def save_checkpoints(self): io.makedirs(global_step_folder, exist_ok=True) - # Save policy checkpoint - ray.get( - self.policy_model.async_run_ray_method( - "pass_through", - "save_checkpoint", - ckpt_dir=policy_save_dir, - tokenizer=self.tokenizer, - ) - ) + # Save policy checkpoint (dispatch handles offload/backload) + self.dispatch.save_checkpoint("policy", policy_save_dir, self.tokenizer) # Save critic checkpoint (if it exists) - if self.critic_model is not None: - if self.colocate_all: - self.policy_model.offload_to_cpu() - self.critic_model.backload_to_gpu() - - ray.get( - self.critic_model.async_run_ray_method( - "pass_through", - "save_checkpoint", - ckpt_dir=critic_save_dir, - tokenizer=self.tokenizer, - ) - ) - - if self.colocate_all: - self.critic_model.offload_to_cpu() - self.policy_model.backload_to_gpu() + if self.has_critic: + self.dispatch.save_checkpoint("critic", critic_save_dir, self.tokenizer) # Save dataloader state dataloader_save_path = os.path.join(global_step_folder, "data.pt") @@ -1251,7 +1231,7 @@ def save_checkpoints(self): def _cleanup_old_checkpoints(self): if not self._node_ids: - self._node_ids = get_node_ids(self.policy_model, self.critic_model, self.ref_model) + self._node_ids = self.dispatch.get_node_ids() run_on_each_node( self._node_ids, cleanup_old_checkpoints, @@ -1351,30 +1331,24 @@ def load_checkpoints(self) -> Tuple[int, str]: f"No dataloader state found at {dataloader_state_path}. Dataloader will start from beginning." ) - # 3. Load policy checkpoint + # 3. Load policy checkpoint (dispatch handles offload/backload) logger.info(f"Loading policy checkpoint from {policy_ckpt_dir}") - _ = ray.get( - self.policy_model.async_run_ray_method( - "pass_through", - "load_checkpoint", - ckpt_dir=policy_ckpt_dir, - load_optimizer_states=True, - load_lr_scheduler_states=True, - ) + self.dispatch.load_checkpoint( + "policy", + policy_ckpt_dir, + load_optimizer_states=True, + load_lr_scheduler_states=True, ) logger.info("Successfully loaded policy checkpoint") # 4. Load critic checkpoint if it exists and we have a critic model - if self.critic_model is not None: + if self.has_critic: logger.info(f"Loading critic checkpoint from {critic_ckpt_dir}") - _ = ray.get( - self.critic_model.async_run_ray_method( - "pass_through", - "load_checkpoint", - ckpt_dir=critic_ckpt_dir, - load_optimizer_states=True, - load_lr_scheduler_states=True, - ) + self.dispatch.load_checkpoint( + "critic", + critic_ckpt_dir, + load_optimizer_states=True, + load_lr_scheduler_states=True, ) logger.info("Successfully loaded critic checkpoint") @@ -1384,41 +1358,33 @@ def load_checkpoints(self) -> Tuple[int, str]: def save_models(self): """ Save the model parameters in HF format at `cfg.trainer.export_path`. + + Dispatch handles offload/backload automatically for all colocation configurations. 
""" policy_export_dir = os.path.join(self.cfg.trainer.export_path, f"global_step_{self.global_step}", "policy") - ray.get( - self.policy_model.async_run_ray_method("pass_through", "save_hf_model", policy_export_dir, self.tokenizer) - ) - if self.critic_model is not None: + self.dispatch.save_hf_model("policy", policy_export_dir, self.tokenizer) + + if self.has_critic: critic_export_dir = os.path.join(self.cfg.trainer.export_path, f"global_step_{self.global_step}", "critic") - ray.get( - self.critic_model.async_run_ray_method( - "pass_through", "save_hf_model", critic_export_dir, self.tokenizer - ) - ) + self.dispatch.save_hf_model("critic", critic_export_dir, self.tokenizer) + logger.info("Successfully saved model weights.") def update_ref_with_policy(self): """ - Update the reference model with the policy model weights (required by some algorithms) + Update the reference model with the policy model weights (required by some algorithms). - If colocate_all is enabled: - - before calling this method, the policy model should be on GPU, and inference engine should be asleep / on CPU. - - after calling this method, the same model placement still holds. + Dispatch handles offload/backload automatically for all colocation configurations. + After this method, prepare_for_weight_sync() should be called to ensure policy is on GPU. """ # TODO(tgriggs): Make policy-to-ref sync faster. policy_export_dir = os.path.join(self.cfg.trainer.export_path, f"global_step_{self.global_step}", "policy") - ray.get( - self.policy_model.async_run_ray_method("pass_through", "save_hf_model", policy_export_dir, self.tokenizer) - ) - # NOTE (sumanthrh): This is for the memory efficient case where we can't keep policy and ref model state on GPU together - # We thus offload the policy model to CPU and then load the ref model from the policy model checkpoint, and then backload the policy model to GPU - if self.colocate_all: - self.policy_model.offload_to_cpu() - ray.get(self.ref_model.async_init_model(policy_export_dir)) - if self.colocate_all: - self.ref_model.offload_to_cpu() - self.policy_model.backload_to_gpu() + + # Save policy model (dispatch handles GPU state) + self.dispatch.save_hf_model("policy", policy_export_dir, self.tokenizer) + + # Re-initialize ref model from saved policy (dispatch handles offloading policy first) + self.dispatch.init_model("ref", policy_export_dir) # Clean up temporary saved model files try: @@ -1427,4 +1393,4 @@ def update_ref_with_policy(self): except Exception as e: logger.warning(f"Failed to clean up temporary policy export directory {policy_export_dir}: {e}") - logger.info("Successfully update ref model with policy model, training continue.") + logger.info("Successfully updated ref model with policy model, training continues.") diff --git a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py index 080b9bb42..53bb6a709 100644 --- a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py @@ -431,6 +431,20 @@ def _broadcast_no_grad(*args, **kwargs): pp_size=mpu.get_pipeline_model_parallel_world_size(), ) + def _normalize_mini_batch_size(self): + """ + Override to set Megatron-specific batch size attributes. + + Megatron's ppo_train method needs policy_mini_batch_size_per_gpu to compute + how many micro batches fit in a mini batch for gradient accumulation. 
+ """ + super()._normalize_mini_batch_size() # Sets _micro_batches_accumulated + + # Megatron-specific: compute mini batch size per GPU for ppo_train + n_samples = self.cfg.generator.n_samples_per_prompt + dp_size = self.mesh_rank.dp_size + self.policy_mini_batch_size_per_gpu = (self.cfg.trainer.policy_mini_batch_size * n_samples) // dp_size + def init_model(self, model_path, num_training_steps: int = 1e9): """ Initialize the model, optimizer, and scheduler for the policy worker. @@ -537,8 +551,8 @@ def ppo_train(self, train_data) -> "TrainingOutputBatch": # TODO: Convert this into 2 loops for minibatches and microbatches. micro_buffer = [] - for local_step, microbatch in enumerate(pbar): - experience = BatchIterator.batch_to_experience(microbatch) + for local_step, experience in enumerate(pbar): + # BatchIterator now yields Experience objects directly experience.to_device(torch.cuda.current_device()) sequences = experience.sequences attention_mask = experience.attention_mask diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py index 361a0777e..912eb2dc3 100644 --- a/skyrl-train/skyrl_train/workers/worker.py +++ b/skyrl-train/skyrl_train/workers/worker.py @@ -2,18 +2,18 @@ import logging import os import socket -from datetime import timedelta -from typing import Dict, Optional, Type, List, Any, Callable -from ctypes import CDLL, POINTER, Structure, c_char_p, c_int, c_ulong, c_void_p -from tqdm import tqdm from collections import defaultdict +from ctypes import CDLL, POINTER, Structure, c_char_p, c_int, c_ulong, c_void_p +from datetime import timedelta +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Type import ray import torch -import torch.nn as nn -from torch.optim.lr_scheduler import LRScheduler -from torch.optim import Optimizer import torch.distributed +import torch.nn as nn +from loguru import logger +from omegaconf import DictConfig from ray import ObjectRef from ray.util.placement_group import ( PlacementGroup, @@ -21,24 +21,42 @@ placement_group, placement_group_table, ) - -from skyrl_train.utils import ray_noset_visible_devices, get_ray_pg_ready_with_timeout, get_reordered_bundle_indices -from skyrl_train.utils.constants import SKYRL_RAY_PG_TIMEOUT_IN_S, SKYRL_WORKER_NCCL_TIMEOUT_IN_S -from skyrl_train.utils.io import io -from skyrl_train.utils.ppo_utils import masked_mean -from skyrl_train.distributed.dispatch import MeshRank, ActorInfo, DispatchRegistry, Dispatch -from skyrl_train.distributed.strategy import DistributedStrategy +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler from transformers import PreTrainedModel -from loguru import logger -from skyrl_train.distributed.ulysses import set_ulysses_sequence_parallel_group, apply_monkey_patch -from skyrl_train.utils.ppo_utils import PolicyLossRegistry, ppo_critic_loss, compute_approx_kl -from skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics + from skyrl_train.dataset.replay_buffer import Experience -from skyrl_train.training_batch import TrainingInputBatch, TrainingOutputBatch +from skyrl_train.distributed.dispatch import ( + ActorInfo, + Dispatch, + DispatchRegistry, + MeshRank, +) +from skyrl_train.distributed.strategy import DistributedStrategy +from skyrl_train.distributed.ulysses import ( + apply_monkey_patch, + set_ulysses_sequence_parallel_group, +) from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient +from skyrl_train.training_batch import 
TrainingInputBatch, TrainingOutputBatch +from skyrl_train.utils import ( + get_ray_pg_ready_with_timeout, + get_reordered_bundle_indices, + ray_noset_visible_devices, +) +from skyrl_train.utils.constants import ( + SKYRL_RAY_PG_TIMEOUT_IN_S, + SKYRL_WORKER_NCCL_TIMEOUT_IN_S, +) +from skyrl_train.utils.io import io +from skyrl_train.utils.ppo_utils import ( + PolicyLossRegistry, + compute_approx_kl, + masked_mean, + ppo_critic_loss, +) from skyrl_train.utils.utils import configure_ray_worker_logging -from omegaconf import DictConfig -from pathlib import Path +from skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics _SET_AFFINITY = False @@ -234,21 +252,33 @@ def get_cuda_memory(self) -> Dict[str, Any]: "total": total, } - def save_memory_snapshot(self, global_step=None, local_step=None): + def save_memory_snapshot(self, tag: str = ""): """Save a snapshot of memory usage on the Worker's CUDA device. + No-ops if record_memory is False. + + Args: + tag: Label for the snapshot (e.g., "forward_backward", "optim_step") + .. note:: This function should be called on all the ranks in the worker group simultaneously. """ + if not self.record_memory: + return + + # Track snapshot count for unique filenames + if not hasattr(self, "_snapshot_count"): + self._snapshot_count = 0 + self._snapshot_count += 1 + rank = torch.distributed.get_rank() save_path = os.path.join(self.cfg.trainer.ckpt_path, "memory_snapshots") if self._local_rank == 0 and not io.exists(save_path): io.makedirs(save_path, exist_ok=True) torch.distributed.barrier() - if global_step is None or local_step is None: - file_name = f"policy_rank_{rank}.pickle" - else: - file_name = f"policy_rank_{rank}_training_step_{global_step}_{local_step}.pickle" + + tag_str = f"_{tag}" if tag else "" + file_name = f"rank_{rank}{tag_str}_{self._snapshot_count}.pickle" record_memory_path = os.path.join(save_path, file_name) if io.exists(record_memory_path): # seeing issues if we don't remove the file first @@ -612,30 +642,60 @@ def __init__(self, **kwargs): def _normalize_mini_batch_size(self): """ - Normalize mini batch sizes to per-gpu mini batch sizes.. + Initialize micro batch tracking for gradient accumulation. + + The worker no longer needs to know mini batch size - it processes whatever + batch it receives, breaking it into micro batches. Gradient scaling happens + at optim_step time based on how many micro batches were accumulated. + + TODO: Rename to _init_gradient_accumulation_state once Megatron no longer + requires mini-batch normalization in its override. The name is kept for + backwards compatibility with Megatron which still does actual normalization. """ if not hasattr(self, "mesh_rank") or self.mesh_rank is None: raise RuntimeError("mesh_rank must be initialized before calling _normalize_mini_batch_size()") - dp_size = self.mesh_rank.dp_size - self.policy_mini_batch_size_per_gpu = ( - self.cfg.trainer.policy_mini_batch_size * self.cfg.generator.n_samples_per_prompt // dp_size - ) + # Track micro batches for gradient scaling at optim_step + self._micro_batches_accumulated = 0 + + def forward_backward(self, data: TrainingInputBatch) -> Dict[str, float]: + """ + Perform forward and backward passes for a batch, handling micro-batching internally. - def forward_backward(self, experience: Experience, microbatch_weight: float) -> Dict[str, float]: - """Perform the forward and backward pass for one micro-batch. + The batch is split into micro batches based on micro_train_batch_size_per_gpu. 
+ Gradients accumulate across micro batches. Gradient scaling happens at optim_step. Args: - experience: The microbatch data to run the forward and backward pass on. - microbatch_weight: Weight of the microbatch, used to scale the loss contribution - for the microbatch. For example, if you accumulate gradients over 2 microbatches, - then each microbatch should have a weight of 1/2. + data: TrainingInputBatch (already DP-sharded by WorkerDispatch/MeshDispatch) Returns: - Dict containing the status (including loss and some other metrics) - for the microbatch. + Aggregated metrics dict across all micro batches + """ + micro_batch_size = self.cfg.trainer.micro_train_batch_size_per_gpu + all_metrics = defaultdict(list) + + for micro_batch in BatchIterator(data, micro_batch_size, drop_last=False): + metrics = self._forward_backward_micro(micro_batch) + self._micro_batches_accumulated += 1 + for k, v in metrics.items(): + all_metrics[k].append(v) + + return reduce_metrics(dict(all_metrics)) + + def _forward_backward_micro(self, experience: Experience) -> Dict[str, float]: + """ + Perform forward and backward pass for one micro batch. + + Loss is not scaled here - gradient scaling happens at optim_step time. + + Args: + experience: Experience object for one micro batch + + Returns: + All-reduced metrics dict for this micro batch """ self.model.train() + experience.to_device(torch.cuda.current_device()) sequences = experience.sequences @@ -698,7 +758,6 @@ def forward_backward(self, experience: Experience, microbatch_weight: float) -> kl_loss_term = kl_loss * self.cfg.trainer.algorithm.kl_loss_coef loss = policy_loss + kl_loss_term - entropy_loss_term - loss = loss * microbatch_weight self.strategy.backward(loss, self.model, self.optimizer) status = { @@ -707,124 +766,57 @@ def forward_backward(self, experience: Experience, microbatch_weight: float) -> "ppo_clip_ratio": clip_ratio, "policy_entropy": entropy.item(), "response_length": num_actions, + "policy_lr": self.scheduler.get_last_lr()[0], } if self.cfg.trainer.algorithm.use_kl_loss: status["policy_kl"] = kl_loss.item() + # All-reduce metrics across DP workers + status = self.strategy.all_reduce(status) + return status def optim_step(self) -> float: """ - Perform optimizer step and return the gradient norm. + Scale gradients by 1/micro_batches_accumulated, perform optimizer step, and reset counter. + + Returns: + The gradient norm (before scaling, after clipping) """ + # Scale accumulated gradients by 1/N to get correct average + if self._micro_batches_accumulated > 0: + scale = 1.0 / self._micro_batches_accumulated + for param in self.model.parameters(): + if param.grad is not None: + param.grad.mul_(scale) + + # Perform optimizer step (includes gradient clipping) grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler, name="actor") + + # Reset counter for next accumulation cycle + self._micro_batches_accumulated = 0 + if grad_norm is not None: grad_norm = grad_norm.detach().cpu().item() return grad_norm - def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch: - global_step = train_data.metadata["global_step"] - minibatch_iterator = BatchIterator( - train_data, sample_batch_size=self.policy_mini_batch_size_per_gpu, drop_last=False - ) - - status_list = [] - all_metrics = defaultdict(list) - num_minibatches = len(minibatch_iterator) - local_step = 0 - - def record_status(status: Dict[str, float]): - """Record the aggregated (all-reduced) training status for the latest microbatch. 
- Also, update the progress bar with the latest status.""" - status["policy_lr"] = self.scheduler.get_last_lr()[0] - - # for DP - # TODO (sumanthrh): this assumes all workers are data parallel. - # We assume that outputs are replicated within tp or sp group, otherwise this is not correct. - status = self.strategy.all_reduce(status) - - # weighted mean for kl - # TODO (sumanthrh): this weighted mean is no longer correct since we use the max response length in the batch. - # we can log this in the driver - # if "kl" in status: - # status["kl"] *= status["response_length"] - # status["kl"] /= status["response_length"] - - short_status = {} - - if "policy_loss" in status: - short_status = { - "pg": status["policy_loss"], - "glen": status["response_length"], - "policy_lr": status["policy_lr"], - "ent": status["policy_entropy"], - } - if "raw_grad_norm" in status: - short_status["grad_norm"] = status["raw_grad_norm"] - if "reward" in status: - short_status["rm"] = status["reward"] - - if "critic_loss" in status: - short_status["cri"] = status["critic_loss"] - short_status["vals"] = status["values"] - short_status["cri_lr"] = status["critic_lr"] - - if "ptx_loss" in status: - short_status["ptx"] = status["ptx_loss"] - - status_list.append(status) - for k, v in status.items(): - all_metrics[k].append(v) - minibatch_pbar.set_postfix(short_status) - - for epoch in range(self.cfg.trainer.update_epochs_per_batch): - minibatch_pbar = tqdm( - minibatch_iterator, - desc=f"Policy Train epoch [{epoch + 1}/{self.cfg.trainer.update_epochs_per_batch}]", - disable=not self.strategy.is_rank_0(), - ) - for minibatch in minibatch_pbar: - microbatch_iterator = BatchIterator( - minibatch, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False - ) - num_microbatches = len(microbatch_iterator) - microbatch_weight = 1.0 / num_microbatches - - for microbatch_idx, microbatch in enumerate(microbatch_iterator): - microbatch_experience = BatchIterator.batch_to_experience(microbatch) - status = self.forward_backward(microbatch_experience, microbatch_weight=microbatch_weight) - - # Record status for all but the last microbatch in the minibatch. - # The last microbatch should be recorded after the optimizer step. - if microbatch_idx < num_microbatches - 1: - if self.record_memory: - self.save_memory_snapshot(global_step, local_step) - record_status(status) - - # Local step counts the number of processed microbatches. - local_step += 1 - - grad_norm = self.optim_step() - if grad_norm is not None: - status["raw_grad_norm"] = grad_norm - - if self.record_memory: - self.save_memory_snapshot(global_step, local_step) + def all_reduce_metrics(self, status: Dict[str, float]) -> Dict[str, float]: + """ + All-reduce metrics across data parallel workers. + """ + return self.strategy.all_reduce(status) - # Record status for the last microbatch in the minibatch. - record_status(status) + def get_lr(self) -> float: + """ + Get current learning rate from scheduler. + """ + return self.scheduler.get_last_lr()[0] + def barrier(self) -> None: + """ + Synchronization barrier across all workers. 
+ """ torch.distributed.barrier() - # not needed beyond status logging - all_metrics.pop("response_length", None) - - status_mean = reduce_metrics(all_metrics) - status_mean["policy_update_steps"] = num_minibatches * self.cfg.trainer.update_epochs_per_batch - - # should return an `TrainingOutputBatch` - output = TrainingOutputBatch() - output.metadata = {"train_status": status_mean} - return output def save_checkpoint(self, ckpt_dir: Path, tokenizer=None): self.strategy.save_checkpoint( @@ -897,21 +889,60 @@ def __init__(self, **kwargs): def _normalize_mini_batch_size(self): """ - Normalize batch sizes based on device mesh and generation parameters. + Initialize micro batch tracking for gradient accumulation. + + The worker no longer needs to know mini batch size - it processes whatever + batch it receives, breaking it into micro batches. Gradient scaling happens + at optim_step time based on how many micro batches were accumulated. + + TODO: Rename to _init_gradient_accumulation_state once Megatron no longer + requires mini-batch normalization in its override. The name is kept for + backwards compatibility with Megatron which still does actual normalization. """ if not hasattr(self, "mesh_rank") or self.mesh_rank is None: raise RuntimeError("mesh_rank must be initialized before calling _normalize_mini_batch_size()") - dp_size = self.mesh_rank.dp_size - self.critic_mini_batch_size_per_gpu = ( - self.cfg.trainer.critic_mini_batch_size * self.cfg.generator.n_samples_per_prompt // dp_size - ) + # Track micro batches for gradient scaling at optim_step + self._micro_batches_accumulated = 0 - def forward_backward(self, experience: Experience, microbatch_weight: float) -> Dict[str, float]: + def forward_backward(self, data: TrainingInputBatch) -> Dict[str, float]: """ - Perform the forward and backward pass for one micro-batch. + Perform forward and backward passes for a batch, handling micro-batching internally. + + The batch is split into micro batches based on micro_train_batch_size_per_gpu. + Gradients accumulate across micro batches. Gradient scaling happens at optim_step. + + Args: + data: TrainingInputBatch (already DP-sharded by WorkerDispatch/MeshDispatch) + + Returns: + Aggregated metrics dict across all micro batches + """ + micro_batch_size = self.cfg.trainer.micro_train_batch_size_per_gpu + all_metrics = defaultdict(list) + + for micro_batch in BatchIterator(data, micro_batch_size, drop_last=False): + metrics = self._forward_backward_micro(micro_batch) + self._micro_batches_accumulated += 1 + for k, v in metrics.items(): + all_metrics[k].append(v) + + return reduce_metrics(dict(all_metrics)) + + def _forward_backward_micro(self, experience: Experience) -> Dict[str, float]: + """ + Perform forward and backward pass for one micro batch. + + Loss is NOT scaled here - gradient scaling happens at optim_step time. 
+ + Args: + experience: Experience object for one micro batch + + Returns: + All-reduced metrics dict for this micro batch """ self.model.train() + experience.to_device(torch.cuda.current_device()) sequences = experience.sequences @@ -923,7 +954,7 @@ def forward_backward(self, experience: Experience, microbatch_weight: float) -> with torch.autocast(dtype=torch.bfloat16, device_type="cuda"): # critic loss - values, output = self.model( + values, _ = self.model( sequences, num_actions=num_actions, attention_mask=attention_mask, @@ -937,25 +968,63 @@ def forward_backward(self, experience: Experience, microbatch_weight: float) -> config=self.cfg.trainer.algorithm, loss_mask=loss_mask, ) - loss = loss * microbatch_weight + # NO loss scaling here - gradient scaling happens at optim_step self.strategy.backward(loss, self.model, self.optimizer) status = { "critic_loss": loss.item(), "values_mean": masked_mean(values, loss_mask).item(), "values_clipfrac": clipfrac, + "critic_lr": self.scheduler.get_last_lr()[0], } + + # All-reduce metrics across DP workers + status = self.strategy.all_reduce(status) + return status def optim_step(self) -> float: """ - Perform optimizer step and return the gradient norm. + Scale gradients by 1/micro_batches_accumulated, perform optimizer step, and reset counter. + + Returns: + The gradient norm (before scaling, after clipping) """ + # Scale accumulated gradients by 1/N to get correct average + if self._micro_batches_accumulated > 0: + scale = 1.0 / self._micro_batches_accumulated + for param in self.model.parameters(): + if param.grad is not None: + param.grad.mul_(scale) + + # Perform optimizer step (includes gradient clipping) grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler, name="critic") + + # Reset counter for next accumulation cycle + self._micro_batches_accumulated = 0 + if grad_norm is not None: grad_norm = grad_norm.detach().cpu().item() return grad_norm + def all_reduce_metrics(self, status: Dict[str, float]) -> Dict[str, float]: + """ + All-reduce metrics across data parallel workers. + """ + return self.strategy.all_reduce(status) + + def get_lr(self) -> float: + """ + Get current learning rate from scheduler. + """ + return self.scheduler.get_last_lr()[0] + + def barrier(self) -> None: + """ + Synchronization barrier across all workers. + """ + torch.distributed.barrier() + def _forward_micro_batch( self, micro_batch: TrainingInputBatch, @@ -989,70 +1058,6 @@ def save_hf_model(self, export_dir: str, tokenizer): tokenizer=tokenizer, ) - def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch: - global_step = train_data.metadata["global_step"] - minibatch_iterator = BatchIterator( - train_data, sample_batch_size=self.critic_mini_batch_size_per_gpu, drop_last=False - ) - - all_metrics = defaultdict(list) - num_minibatches = len(minibatch_iterator) - local_step = 0 - - def record_status(status: Dict[str, float]): - status["critic_lr"] = self.scheduler.get_last_lr()[0] - - # for DP - # TODO (sumanthrh): this assumes all workers are data parallel. - # We should get more accurate metrics with seq parallel or TP. 
- # There are metrics like entropy where we get average over local data size - status = self.strategy.all_reduce(status) - - for k, v in status.items(): - all_metrics[k].append(v) - minibatch_pbar.set_postfix(status) - - for epoch in range(self.cfg.trainer.update_epochs_per_batch): - minibatch_pbar = tqdm( - minibatch_iterator, - desc=f"Critic Train epoch [{epoch + 1}/{self.cfg.trainer.update_epochs_per_batch}]", - disable=not self.strategy.is_rank_0(), - ) - for minibatch in minibatch_pbar: - microbatch_iterator = BatchIterator( - minibatch, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False - ) - num_microbatches = len(microbatch_iterator) - microbatch_weight = 1.0 / num_microbatches - - for microbatch_idx, microbatch in enumerate(microbatch_iterator): - microbatch_experience = BatchIterator.batch_to_experience(microbatch) - status = self.forward_backward(microbatch_experience, microbatch_weight=microbatch_weight) - - if microbatch_idx < num_microbatches - 1: - if self.record_memory: - self.save_memory_snapshot(global_step, local_step) - record_status(status) - - local_step += 1 - - grad_norm = self.optim_step() - if grad_norm is not None: - status["raw_grad_norm"] = grad_norm - - if self.record_memory: - self.save_memory_snapshot(global_step, local_step) - record_status(status) - - torch.distributed.barrier() - - status_mean = reduce_metrics(all_metrics) - status_mean["critic_update_steps"] = num_minibatches * self.cfg.trainer.update_epochs_per_batch - - output = TrainingOutputBatch() - output.metadata = {"train_status": status_mean} - return output - def save_checkpoint(self, ckpt_dir: str, tokenizer=None): self.strategy.save_checkpoint( model=self.model, diff --git a/skyrl-train/skyrl_train/workers/worker_dispatch.py b/skyrl-train/skyrl_train/workers/worker_dispatch.py new file mode 100644 index 000000000..cff67e703 --- /dev/null +++ b/skyrl-train/skyrl_train/workers/worker_dispatch.py @@ -0,0 +1,291 @@ +""" +WorkerDispatch: Manages all actor groups with automatic offload/onload. + +Automatically handles GPU placement: +- Tracks which model is currently on GPU +- If colocation is enabled, offloads other models when one is requested + +The trainer interacts with the worker dispatch as if all models were always on GPU. +""" + +from dataclasses import dataclass +from typing import Dict, List, Optional + +import ray +from omegaconf import DictConfig + +from skyrl_train.distributed.dispatch import concatenate_outputs_after_mesh_dispatch +from skyrl_train.training_batch import TrainingInputBatch, TrainingOutputBatch +from skyrl_train.workers.worker import PPORayActorGroup + + +@dataclass +class GPUState: + """Tracks what's on GPU for a model.""" + + model_on_gpu: bool = False + optimizer_on_gpu: bool = False + + +class WorkerDispatch: + """ + Unified dispatch layer that manages all actor groups (policy, critic, ref). + + Handles automatic offload/onload when colocate_all=True. + """ + + def __init__( + self, + cfg: DictConfig, + policy_actor_group: PPORayActorGroup, + critic_actor_group: Optional[PPORayActorGroup] = None, + ref_actor_group: Optional[PPORayActorGroup] = None, + ): + self.cfg = cfg + self.colocate_all = cfg.trainer.placement.colocate_all + self.colocate_policy_ref = cfg.trainer.placement.colocate_policy_ref + + # Actor groups by name. + # TODO: Remove these role-specific identifiers. We will move to using model IDs and add support for generic models beyond these.
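+        # Illustrative lifecycle when colocate_all=True (sketch only; the calls below are defined later in this class):
+        #   dispatch.forward("critic", batch)    # backloads critic weights (no optimizer), offloading policy/ref if resident
+        #   dispatch.forward("policy", batch)    # offloads the critic, backloads the policy weights
+        #   dispatch.prepare_for_weight_sync()   # keeps policy weights on GPU, moves its optimizer state to CPU if resident
+        #   dispatch.finish_weight_sync()        # offloads policy weights again once the broadcast is done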
+ self._actor_groups: Dict[str, PPORayActorGroup] = {"policy": policy_actor_group} + if critic_actor_group is not None: + self._actor_groups["critic"] = critic_actor_group + if ref_actor_group is not None: + self._actor_groups["ref"] = ref_actor_group + + # GPU state tracking (only matters when colocated) + self._gpu_state: Dict[str, GPUState] = {name: GPUState() for name in self._actor_groups.keys()} + + def get_lcm_dp_size(self) -> int: + """Get LCM of all models' dp_size.""" + import math + + dp_size = self._actor_groups["policy"].actor_infos[0].rank.dp_size + if "critic" in self._actor_groups: + dp_size = math.lcm(dp_size, self._actor_groups["critic"].actor_infos[0].rank.dp_size) + if "ref" in self._actor_groups: + dp_size = math.lcm(dp_size, self._actor_groups["ref"].actor_infos[0].rank.dp_size) + return dp_size + + def _should_manage_offload(self, model: str) -> bool: + """Check if we need to manage offload for this model.""" + if self.colocate_all: + return True + if self.colocate_policy_ref and model in ("policy", "ref"): + return True + return False + + def _get_colocation_group(self, model: str) -> List[str]: + """Get which models share GPU with the given model.""" + if self.colocate_all: + return list(self._actor_groups.keys()) + elif self.colocate_policy_ref and model in ("policy", "ref"): + return [m for m in ["policy", "ref"] if m in self._actor_groups] + return [model] + + def _ensure_on_gpu(self, model: str, need_optimizer: bool = True, need_model: bool = True) -> None: + """Ensure model is on GPU, offloading others in same colocation group if needed.""" + if not self._should_manage_offload(model): + return + + if model not in self._actor_groups: + return + + group = self._get_colocation_group(model) + + # Offload others in the same colocation group + for other in group: + if other != model and other in self._actor_groups: + state = self._gpu_state[other] + if state.model_on_gpu or state.optimizer_on_gpu: + self._actor_groups[other].offload_to_cpu() + self._gpu_state[other] = GPUState() + + # Backload requested model + state = self._gpu_state[model] + needs_backload = (need_model and not state.model_on_gpu) or (need_optimizer and not state.optimizer_on_gpu) + + if needs_backload: + self._actor_groups[model].backload_to_gpu( + backload_optimizer=need_optimizer, + backload_model=need_model, + ) + if need_model: + self._gpu_state[model].model_on_gpu = True + if need_optimizer: + self._gpu_state[model].optimizer_on_gpu = True + + def _offload(self, model: str, offload_optimizer: bool = True, offload_model: bool = True) -> None: + """Offload model to CPU.""" + if not self._should_manage_offload(model): + return + + if model not in self._actor_groups: + return + + self._actor_groups[model].offload_to_cpu( + offload_optimizer=offload_optimizer, + offload_model=offload_model, + ) + + if offload_model: + self._gpu_state[model].model_on_gpu = False + if offload_optimizer: + self._gpu_state[model].optimizer_on_gpu = False + + def mark_all_offloaded(self) -> None: + """Mark all models as offloaded (call after build_models when colocate_all).""" + for model in self._actor_groups: + self._gpu_state[model] = GPUState() + + def forward(self, model: str, data: TrainingInputBatch) -> TrainingOutputBatch: + """Run inference forward pass. 
Only loads model (not optimizer).""" + self._ensure_on_gpu(model, need_optimizer=False, need_model=True) + + refs = self._actor_groups[model].async_run_ray_method("mesh", "forward", data=data) + results = ray.get(refs) + + output = concatenate_outputs_after_mesh_dispatch(self._actor_groups[model].actor_infos, results) + return output + + def forward_backward(self, model: str, data: TrainingInputBatch) -> Dict[str, float]: + """Run forward/backward pass. Needs model + optimizer.""" + self._ensure_on_gpu(model, need_optimizer=True, need_model=True) + + refs = self._actor_groups[model].async_run_ray_method("mesh", "forward_backward", data) + statuses = ray.get(refs) + + self._save_memory_snapshot(model, "forward_backward") + return statuses[0] + + def optim_step(self, model: str) -> Optional[float]: + """Run optimizer step. Model should already be on GPU from forward_backward.""" + refs = self._actor_groups[model].async_run_ray_method("pass_through", "optim_step") + grad_norms = ray.get(refs) + + self._save_memory_snapshot(model, "optim_step") + return grad_norms[0] + + # TODO(tgriggs): Remove this when Megatron supports forward_backward and optim_step. + def ppo_train(self, model: str, data: TrainingInputBatch) -> Dict[str, float]: + """Run full PPO training loop (for Megatron).""" + self._ensure_on_gpu(model, need_optimizer=True, need_model=True) + + refs = self._actor_groups[model].async_run_ray_method("mesh", "ppo_train", data) + statuses = ray.get(refs) + + return statuses[0].metadata["train_status"] + + def _save_memory_snapshot(self, model: str, tag: str) -> None: + """Save memory snapshot on workers.""" + ray.get( + self._actor_groups[model].async_run_ray_method("pass_through", "save_memory_snapshot", tag=f"{model}_{tag}") + ) + + def save_checkpoint(self, model: str, ckpt_dir: str, tokenizer=None) -> None: + """Save checkpoint for model.""" + self._ensure_on_gpu(model, need_optimizer=True, need_model=True) + + ray.get( + self._actor_groups[model].async_run_ray_method( + "pass_through", "save_checkpoint", ckpt_dir=ckpt_dir, tokenizer=tokenizer + ) + ) + + def load_checkpoint( + self, + model: str, + ckpt_dir: str, + load_optimizer_states: bool = True, + load_lr_scheduler_states: bool = True, + ) -> None: + """Load checkpoint for model.""" + self._ensure_on_gpu(model, need_optimizer=load_optimizer_states, need_model=True) + + ray.get( + self._actor_groups[model].async_run_ray_method( + "pass_through", + "load_checkpoint", + ckpt_dir=ckpt_dir, + load_optimizer_states=load_optimizer_states, + load_lr_scheduler_states=load_lr_scheduler_states, + ) + ) + + def save_hf_model(self, model: str, export_dir: str, tokenizer) -> None: + """Save model in HuggingFace format.""" + self._ensure_on_gpu(model, need_optimizer=False, need_model=True) + + ray.get(self._actor_groups[model].async_run_ray_method("pass_through", "save_hf_model", export_dir, tokenizer)) + + def init_model(self, model: str, model_path: str, num_training_steps: Optional[int] = None) -> None: + """Initialize model from path. 
Offloads others in colocation group first.""" + # Offload others in colocation group before init + if self._should_manage_offload(model): + group = self._get_colocation_group(model) + for other in group: + if other != model and other in self._actor_groups: + state = self._gpu_state[other] + if state.model_on_gpu or state.optimizer_on_gpu: + self._actor_groups[other].offload_to_cpu() + self._gpu_state[other] = GPUState() + + kwargs = {"model_path": model_path} + if num_training_steps is not None: + kwargs["num_training_steps"] = num_training_steps + + ray.get(self._actor_groups[model].async_init_model(**kwargs)) + + # After init, model is on GPU + self._gpu_state[model].model_on_gpu = True + self._gpu_state[model].optimizer_on_gpu = model != "ref" # ref has no optimizer + + def init_weight_sync_state(self, inference_engine_client) -> None: + """Initialize weight sync state for policy model.""" + ray.get( + self._actor_groups["policy"].async_run_ray_method( + "pass_through", "init_weight_sync_state", inference_engine_client + ) + ) + + def broadcast_to_inference_engines(self, inference_engine_client) -> None: + """Broadcast policy weights to inference engines.""" + ray.get( + self._actor_groups["policy"].async_run_ray_method( + "pass_through", "broadcast_to_inference_engines", inference_engine_client + ) + ) + + def prepare_for_weight_sync(self) -> None: + """Prepare for weight sync: ensure policy model is on GPU, offload optimizer.""" + if not self.colocate_all: + return + # Ensure policy model is on GPU (will offload others in colocation group) + self._ensure_on_gpu("policy", need_optimizer=False, need_model=True) + # Offload optimizer if it's on GPU + if self._gpu_state["policy"].optimizer_on_gpu: + self._offload("policy", offload_optimizer=True, offload_model=False) + + def finish_weight_sync(self) -> None: + """Finish weight sync: offload model.""" + if not self.colocate_all: + return + self._offload("policy", offload_optimizer=False, offload_model=True) + + def empty_cache(self, model: Optional[str] = None) -> None: + """Empty GPU cache for model(s).""" + if model is not None: + ray.get(self._actor_groups[model].async_run_ray_method("pass_through", "empty_cache")) + else: + refs = [] + for group in self._actor_groups.values(): + refs.extend(group.async_run_ray_method("pass_through", "empty_cache")) + ray.get(refs) + + def get_node_ids(self) -> List[str]: + """Get unique node IDs from all actor groups.""" + all_node_ids = [] + for group in self._actor_groups.values(): + node_ids = ray.get(group.async_run_ray_method("pass_through", "get_ray_node_id")) + all_node_ids.extend(node_ids) + return list(set(all_node_ids)) diff --git a/skyrl-train/skyrl_train/workers/worker_utils.py b/skyrl-train/skyrl_train/workers/worker_utils.py index 897d032ea..43b99d91a 100644 --- a/skyrl-train/skyrl_train/workers/worker_utils.py +++ b/skyrl-train/skyrl_train/workers/worker_utils.py @@ -37,9 +37,11 @@ def __len__(self): def __iter__(self): return self - def __next__(self) -> TrainingInputBatch: + def __next__(self) -> Experience: try: - return next(self._iter) + batch = next(self._iter) + exp = self.batch_to_experience(batch) + return exp except StopIteration: self._iter = iter(self._chunks) raise StopIteration diff --git a/skyrl-train/tests/cpu/test_trainer.py b/skyrl-train/tests/cpu/test_trainer.py index 599f8a3f4..a704278e6 100644 --- a/skyrl-train/tests/cpu/test_trainer.py +++ b/skyrl-train/tests/cpu/test_trainer.py @@ -2,21 +2,21 @@ uv run --isolated --extra dev pytest tests/cpu/test_trainer.py """ 
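The `BatchIterator.__next__` change above (yielding `Experience` objects directly) moves `batch_to_experience` out of the call sites. A minimal sketch of the resulting consumer pattern, assuming a `cfg` and a DP-sharded `TrainingInputBatch` named `data` are in scope (names follow this diff):

```python
# Sketch only: micro-batch iteration now yields Experience objects directly.
import torch

from skyrl_train.workers.worker_utils import BatchIterator

micro_bs = cfg.trainer.micro_train_batch_size_per_gpu
for experience in BatchIterator(data, sample_batch_size=micro_bs, drop_last=False):
    experience.to_device(torch.cuda.current_device())
    # ... run the forward/backward for this micro-batch (see _forward_backward_micro above) ...
```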
-import torch +from unittest.mock import MagicMock, patch + +import numpy as np import pytest +import torch from jaxtyping import Float, Integer from pytest import approx -from unittest.mock import MagicMock, patch - - +from skyrl_train.config.utils import get_default_config from skyrl_train.distributed.dispatch import MeshRank from skyrl_train.trainer import RayPPOTrainer from skyrl_train.training_batch import TrainingInputBatch -import numpy as np -from skyrl_train.workers.worker import PolicyWorkerBase, CriticWorkerBase -from skyrl_train.workers.worker_utils import BatchIterator from skyrl_train.utils.utils import validate_batch_sizes -from skyrl_train.config.utils import get_default_config +from skyrl_train.workers.worker import CriticWorkerBase, PolicyWorkerBase +from skyrl_train.workers.worker_utils import BatchIterator + from tests.cpu.util import example_dummy_config @@ -170,7 +170,14 @@ def test_calc_advantages_and_returns(mock_compute_adv_and_ret, dummy_config): def test_normalize_mini_batch_size(): - """Test the _normalize_mini_batch_size method with various configurations.""" + """Test the _normalize_mini_batch_size method initializes micro batch tracking. + + Workers don't need to know mini batch sizes per GPU. + They receive batches from the trainer and split them into micro batches. + _normalize_mini_batch_size only initializes micro batch tracking for gradient scaling. + + # TODO: Update naming once Megatron is updated to not be aware of mini batch sizes. + """ # Create minimal worker instances for testing class TestPolicyWorker(PolicyWorkerBase): @@ -199,16 +206,10 @@ def backload_to_gpu(self, non_blocking=True): def _forward_micro_batch(self, micro_batch): pass - def create_policy_worker_with_config( - train_batch_size, policy_mini_batch_size, micro_train_batch_size_per_gpu, n_samples_per_prompt, dp_size - ): + def create_policy_worker_with_config(dp_size): """Helper to create policy worker with specific config.""" cfg = get_default_config() - cfg.trainer.train_batch_size = train_batch_size - cfg.trainer.policy_mini_batch_size = policy_mini_batch_size - cfg.trainer.micro_train_batch_size_per_gpu = micro_train_batch_size_per_gpu cfg.trainer.algorithm.policy_loss_type = "regular" - cfg.generator.n_samples_per_prompt = n_samples_per_prompt worker = TestPolicyWorker( cfg=cfg, @@ -225,15 +226,9 @@ def create_policy_worker_with_config( return worker - def create_critic_worker_with_config( - train_batch_size, critic_mini_batch_size, micro_train_batch_size_per_gpu, n_samples_per_prompt, dp_size - ): + def create_critic_worker_with_config(dp_size): """Helper to create critic worker with specific config.""" cfg = get_default_config() - cfg.trainer.train_batch_size = train_batch_size - cfg.trainer.critic_mini_batch_size = critic_mini_batch_size - cfg.trainer.micro_train_batch_size_per_gpu = micro_train_batch_size_per_gpu - cfg.generator.n_samples_per_prompt = n_samples_per_prompt worker = TestCriticWorker( cfg=cfg, @@ -250,66 +245,29 @@ def create_critic_worker_with_config( return worker - # Test Case 1: Basic valid configuration for PolicyWorker - policy_worker = create_policy_worker_with_config( - train_batch_size=128, - policy_mini_batch_size=16, - micro_train_batch_size_per_gpu=2, - n_samples_per_prompt=2, - dp_size=4, - ) + # Test Case 1: PolicyWorker initializes _micro_batches_accumulated + policy_worker = create_policy_worker_with_config(dp_size=4) policy_worker._normalize_mini_batch_size() - expected_policy_mini_batch_size_per_gpu = (16 * 2) // 4 # 8 - assert 
policy_worker.policy_mini_batch_size_per_gpu == expected_policy_mini_batch_size_per_gpu + assert hasattr(policy_worker, "_micro_batches_accumulated") + assert policy_worker._micro_batches_accumulated == 0 - # Test Case 2: Basic valid configuration for CriticWorker - critic_worker = create_critic_worker_with_config( - train_batch_size=128, - critic_mini_batch_size=8, - micro_train_batch_size_per_gpu=2, - n_samples_per_prompt=2, - dp_size=4, - ) + # Test Case 2: CriticWorker initializes _micro_batches_accumulated + critic_worker = create_critic_worker_with_config(dp_size=4) critic_worker._normalize_mini_batch_size() - expected_critic_mini_batch_size_per_gpu = (8 * 2) // 4 # 4 - assert critic_worker.critic_mini_batch_size_per_gpu == expected_critic_mini_batch_size_per_gpu + assert hasattr(critic_worker, "_micro_batches_accumulated") + assert critic_worker._micro_batches_accumulated == 0 # Test Case 3: Single GPU (dp_size=1) for PolicyWorker - policy_worker = create_policy_worker_with_config( - train_batch_size=32, - policy_mini_batch_size=8, - micro_train_batch_size_per_gpu=4, - n_samples_per_prompt=1, - dp_size=1, - ) + policy_worker = create_policy_worker_with_config(dp_size=1) policy_worker._normalize_mini_batch_size() - expected_policy_mini_batch_size_per_gpu = (8 * 1) // 1 # 8 - assert policy_worker.policy_mini_batch_size_per_gpu == expected_policy_mini_batch_size_per_gpu - - # Test Case 4: High n_samples_per_prompt for CriticWorker - critic_worker = create_critic_worker_with_config( - train_batch_size=256, - critic_mini_batch_size=32, - micro_train_batch_size_per_gpu=8, - n_samples_per_prompt=4, - dp_size=2, - ) - critic_worker._normalize_mini_batch_size() - - expected_critic_mini_batch_size_per_gpu = (32 * 4) // 2 # 64 - assert critic_worker.critic_mini_batch_size_per_gpu == expected_critic_mini_batch_size_per_gpu + assert hasattr(policy_worker, "_micro_batches_accumulated") + assert policy_worker._micro_batches_accumulated == 0 - # Test Case 5: Error case - mesh_rank not initialized - policy_worker_no_mesh = create_policy_worker_with_config( - train_batch_size=128, - policy_mini_batch_size=16, - micro_train_batch_size_per_gpu=2, - n_samples_per_prompt=1, - dp_size=4, - ) + # Test Case 4: Error case - mesh_rank not initialized + policy_worker_no_mesh = create_policy_worker_with_config(dp_size=4) policy_worker_no_mesh.mesh_rank = None with pytest.raises(RuntimeError, match="mesh_rank must be initialized"): @@ -486,8 +444,13 @@ def create_test_config( validate_batch_sizes(cfg) -def test_ppo_train_batch_calculations(): - """Test the key batch calculations and control flow in ppo_train methods.""" +def test_forward_backward_batch_calculations(): + """Test the key batch calculations and control flow in forward_backward methods. 
+ + FSDP workers use the forward_backward + optim_step pattern: + - forward_backward handles micro-batching internally and accumulates gradients + - optim_step scales gradients by 1/num_accumulated and takes optimizer step + """ # Create test configuration cfg = get_default_config() @@ -526,18 +489,13 @@ def create_test_worker(worker_class): master_port=12345, sequence_parallel_size=1, ) - # Set appropriate mini batch size per gpu based on worker type - if worker_class == PolicyWorkerBase: - worker.policy_mini_batch_size_per_gpu = 6 # Should result in 3 micro batches per mini batch - elif worker_class == CriticWorkerBase: - worker.critic_mini_batch_size_per_gpu = 6 # Should result in 3 micro batches per mini batch # Mock dependencies worker.strategy = MagicMock() worker.strategy.is_rank_0.return_value = False # Disable progress bars worker.strategy.all_reduce.return_value = {"loss": 0.5, "lr": 1e-4} - # Always set model for all worker types (policy/critic need this for ppo_train) + # Always set model for all worker types worker.model = MagicMock() return worker @@ -545,97 +503,71 @@ def create_test_worker(worker_class): # Test PolicyWorkerBase policy_worker = create_test_worker(PolicyWorkerBase) - # Mock forward_backward and optim_step to track calls and verify accumulation behavior - policy_forward_backward_calls = [] + # Initialize _micro_batches_accumulated (normally done in _normalize_mini_batch_size) + policy_worker._micro_batches_accumulated = 0 + + # Mock _forward_backward_micro to track calls + policy_forward_backward_micro_calls = [] - def mock_policy_forward_backward(experience, microbatch_weight): - policy_forward_backward_calls.append({"microbatch_weight": microbatch_weight}) + def mock_policy_forward_backward_micro(experience): + policy_forward_backward_micro_calls.append(experience) return {"policy_loss": 0.5, "ppo_clip_ratio": 0.1, "policy_entropy": 2.0, "response_length": response_length} - policy_worker.forward_backward = mock_policy_forward_backward - policy_worker.optim_step = MagicMock(return_value=None) - policy_worker.scheduler = MagicMock() - policy_worker.scheduler.get_last_lr.return_value = [1e-4] + policy_worker._forward_backward_micro = mock_policy_forward_backward_micro policy_worker.record_memory = False - # Calculate expected values based on new accumulation logic + # Calculate expected values dataloader = BatchIterator( dummy_databatch, sample_batch_size=cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False ) - total_micro_batches = len(dataloader) # Should be 6 - micro_batches_per_mini_batch = ( - policy_worker.policy_mini_batch_size_per_gpu // cfg.trainer.micro_train_batch_size_per_gpu - ) # 6 // 2 = 3 - # New logic: accumulation_steps = micro_batches_per_mini_batch (accumulate within mini-batch) - expected_accumulation_steps = micro_batches_per_mini_batch # Should be 3 - expected_microbatch_weight = 1.0 / expected_accumulation_steps - - # Run policy ppo_train with minimal mocking - with ( - patch("torch.distributed.barrier"), - patch("tqdm.tqdm", side_effect=lambda x, **kwargs: x), - ): # Disable progress bar - result = policy_worker.ppo_train(dummy_databatch) + expected_micro_batches = len(dataloader) # Should be 6 + + # Run forward_backward + with (patch("torch.distributed.barrier"),): + result = policy_worker.forward_backward(dummy_databatch) # Verify Policy Worker Results assert ( - len(policy_forward_backward_calls) == total_micro_batches - ), f"PolicyWorker: Expected {total_micro_batches} forward_backward calls, got 
{len(policy_forward_backward_calls)}" + len(policy_forward_backward_micro_calls) == expected_micro_batches + ), f"PolicyWorker: Expected {expected_micro_batches} _forward_backward_micro calls, got {len(policy_forward_backward_micro_calls)}" - # Verify accumulation_steps are consistent (should equal micro_batches_per_mini_batch) - for call in policy_forward_backward_calls: - assert ( - call["microbatch_weight"] == expected_microbatch_weight - ), f"PolicyWorker: Expected microbatch_weight={expected_microbatch_weight}, got {call['microbatch_weight']}" + # Verify _micro_batches_accumulated is set correctly + assert policy_worker._micro_batches_accumulated == expected_micro_batches # Verify result structure - assert "train_status" in result.metadata - train_status = result.metadata["train_status"] - assert "policy_update_steps" in train_status + assert isinstance(result, dict) + assert "policy_loss" in result - # Verify policy_update_steps calculation (should be total_calls / accumulation_steps) - expected_policy_update_steps_normalized = len(policy_forward_backward_calls) / expected_accumulation_steps - assert train_status["policy_update_steps"] == expected_policy_update_steps_normalized - - # Test CriticWorkerBase with same accumulation logic + # Test CriticWorkerBase with same pattern critic_worker = create_test_worker(CriticWorkerBase) - # Mock forward_backward and optim_step for critic - critic_forward_backward_calls = [] + # Initialize _micro_batches_accumulated (normally done in _normalize_mini_batch_size) + critic_worker._micro_batches_accumulated = 0 - def mock_critic_forward_backward(experience, microbatch_weight): - critic_forward_backward_calls.append({"microbatch_weight": microbatch_weight}) - return {"critic_loss": 0.3, "values": 1.0} + # Mock _forward_backward_micro for critic + critic_forward_backward_micro_calls = [] - critic_worker.forward_backward = mock_critic_forward_backward - critic_worker.optim_step = MagicMock(return_value=None) - critic_worker.scheduler = MagicMock() - critic_worker.scheduler.get_last_lr.return_value = [1e-4] + def mock_critic_forward_backward_micro(experience): + critic_forward_backward_micro_calls.append(experience) + return {"critic_loss": 0.3, "values_mean": 1.0} - # Run critic ppo_train - with ( - patch("torch.distributed.barrier"), - patch("tqdm.tqdm", side_effect=lambda x, **kwargs: x), - patch("torch.cuda.empty_cache"), - ): - result = critic_worker.ppo_train(dummy_databatch) + critic_worker._forward_backward_micro = mock_critic_forward_backward_micro + + # Run forward_backward for critic + with (patch("torch.distributed.barrier"),): + result = critic_worker.forward_backward(dummy_databatch) # Verify Critic Worker Results assert ( - len(critic_forward_backward_calls) == total_micro_batches - ), f"CriticWorker: Expected {total_micro_batches} forward_backward calls, got {len(critic_forward_backward_calls)}" + len(critic_forward_backward_micro_calls) == expected_micro_batches + ), f"CriticWorker: Expected {expected_micro_batches} _forward_backward_micro calls, got {len(critic_forward_backward_micro_calls)}" - # Verify accumulation_steps are consistent for critic (should equal micro_batches_per_mini_batch) - for call in critic_forward_backward_calls: - assert ( - call["microbatch_weight"] == expected_microbatch_weight - ), f"CriticWorker: Expected microbatch_weight={expected_microbatch_weight}, got {call['microbatch_weight']}" + # Verify _micro_batches_accumulated is set correctly + assert critic_worker._micro_batches_accumulated == 
expected_micro_batches # Verify result structure for critic - assert "train_status" in result.metadata - train_status = result.metadata["train_status"] - assert "critic_update_steps" in train_status - assert train_status["critic_update_steps"] == len(critic_forward_backward_calls) / expected_accumulation_steps + assert isinstance(result, dict) + assert "critic_loss" in result def test_validate_batch_sizes_lcm_dp_requirement(): diff --git a/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py b/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py index 592a0518b..928bd9df4 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py @@ -531,9 +531,14 @@ async def test_megatron_train( cfg=cfg, ) + # FSDP uses forward_backward + optim_step instead of ppo_train batch.metadata["global_step"] = 0 - results_fsdp = ray.get(actor_group.async_run_ray_method("pass_through", "ppo_train", batch)) - results_fsdp = [results_fsdp[i].metadata["train_status"] for i in range(len(results_fsdp))] + results_fsdp = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", batch)) + ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) + # Get learning rate from worker + lr_results = ray.get(actor_group.async_run_ray_method("pass_through", "get_lr")) + for i, result in enumerate(results_fsdp): + result["policy_lr"] = lr_results[i] print("megatron results: ", results_megatron[0]) print("\n\n") diff --git a/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py b/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py deleted file mode 100644 index c9880a017..000000000 --- a/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py +++ /dev/null @@ -1,220 +0,0 @@ -""" -Tests for ppo_train method in worker classes. - -Run with: -uv run --isolated --extra dev pytest tests/gpu/gpu_ci/test_ppo_train.py -""" - -import pytest -import ray -from omegaconf import DictConfig - -from tests.gpu.utils import init_worker_with_type, make_dummy_training_batch, get_test_actor_config, validate_cfg - - -@pytest.fixture -def cfg() -> DictConfig: - """Get test configuration with minimal settings for fast testing.""" - cfg = get_test_actor_config() - - cfg.trainer.update_epochs_per_batch = 1 - cfg.trainer.micro_train_batch_size_per_gpu = 1 - cfg.trainer.policy_mini_batch_size = 2 - cfg.generator.n_samples_per_prompt = 1 - cfg.trainer.placement.policy_num_gpus_per_node = 2 - cfg.trainer.logger = "console" - cfg.generator.inference_engine_tensor_parallel_size = 2 - validate_cfg(cfg) - - return cfg - - -@pytest.mark.parametrize("use_entropy_loss, use_kl_loss", [(False, False), (True, True), (True, False), (False, True)]) -def test_ppo_train_basic_execution(ray_init_fixture, cfg, use_entropy_loss, use_kl_loss): - """ - Test that ppo_train runs and returns correct structure. - - This test validates: - - ppo_train method executes successfully - - Returns TrainingOutputBatch with correct metadata structure - - Contains expected training metrics - """ - try: - cfg.trainer.strategy = "fsdp2" # Strategy logic is not tested here. 
- if use_entropy_loss: - cfg.trainer.algorithm.use_entropy_loss = True - cfg.trainer.algorithm.entropy_loss_coef = 0.01 - if use_kl_loss: - cfg.trainer.algorithm.use_kl_loss = True - cfg.trainer.algorithm.kl_loss_coef = 0.001 - - actor_group = init_worker_with_type( - "policy", - shared_pg=None, - colocate_all=False, - num_gpus_per_node=cfg.trainer.placement.policy_num_gpus_per_node, - cfg=cfg, - ) - - train_data = make_dummy_training_batch(batch_size=2, seq_len=10, num_actions=4) - train_data.metadata["global_step"] = 0 - - # Run ppo_train - results = ray.get(actor_group.async_run_ray_method("pass_through", "ppo_train", train_data)) - assert len(results) == cfg.trainer.placement.policy_num_gpus_per_node, "Should get result from each GPU" - - result = results[0] # Check first worker result - assert hasattr(result, "metadata"), "Result should have metadata attribute" - assert "train_status" in result.metadata, "Should have train_status in metadata" - - train_status = result.metadata["train_status"] - - # Validate expected training metrics are present - expected_metrics = [ - "policy_loss", - "policy_update_steps", - "policy_lr", - "ppo_clip_ratio", - "policy_entropy", - "final_loss", - ] - - for metric in expected_metrics: - assert metric in train_status, f"Should have {metric} in train_status" - assert isinstance(train_status[metric], (int, float)), f"{metric} should be numeric" - - # Simple check for metric values - assert train_status["policy_update_steps"] > 0, "Should have completed at least one update step" - assert train_status["policy_lr"] > 0, "Should have positive learning rate" - - finally: - ray.shutdown() - - -def test_ppo_train_critic_worker(ray_init_fixture, cfg): - """ - Test that ppo_train works for critic worker as well. - """ - try: - cfg.trainer.strategy = "fsdp2" # Strategy logic is not tested here. 
- - actor_group = init_worker_with_type( - "critic", - shared_pg=None, - colocate_all=False, - num_gpus_per_node=cfg.trainer.placement.policy_num_gpus_per_node, - cfg=cfg, - ) - - # Create training batch directly - train_data = make_dummy_training_batch(batch_size=2, seq_len=10, num_actions=4) - train_data.metadata["global_step"] = 0 - - # Run ppo_train - results = ray.get(actor_group.async_run_ray_method("pass_through", "ppo_train", train_data)) - - result = results[0] - assert hasattr(result, "metadata"), "Result should have metadata attribute" - assert "train_status" in result.metadata, "Should have train_status in metadata" - - train_status = result.metadata["train_status"] - - # Validate critic-specific metrics - expected_critic_metrics = ["critic_loss", "critic_update_steps", "values_mean", "critic_lr"] - - for metric in expected_critic_metrics: - assert metric in train_status, f"Should have {metric} in critic train_status" - assert isinstance(train_status[metric], (int, float)), f"{metric} should be numeric" - - assert train_status["critic_update_steps"] > 0, "Should have completed at least one critic update step" - - print(f"Critic ppo_train completed successfully with metrics: {train_status}") - finally: - ray.shutdown() - - -@pytest.mark.parametrize( - "test_id, micro_train_batch_size_per_gpu, policy_mini_batch_size, n_samples_per_prompt, update_epochs_per_batch, batch_size, expected_optimizer_steps", - [ - ("accumulation_calculation", 2, 8, 2, 1, 8, 1), - ("optimizer_stepping", 1, 8, 1, 1, 12, 3), - ("multiple_epochs", 1, 4, 1, 3, 6, 9), - ], - ids=["accumulation_calculation", "optimizer_stepping", "multiple_epochs"], -) -def test_gradient_accumulation_scenarios( - ray_init_fixture, - test_id, - micro_train_batch_size_per_gpu, - policy_mini_batch_size, - n_samples_per_prompt, - update_epochs_per_batch, - batch_size, - expected_optimizer_steps, -): - """ - Test that gradient accumulation and optimizer stepping work correctly across various scenarios. - - Validates: - - Correct calculation of accumulation steps based on configuration. - - Optimizer stepping at correct intervals for single and multiple epochs. - - Consistent behavior across different batch and minibatch sizes. - """ - try: - cfg = get_test_actor_config() - cfg.trainer.strategy = "fsdp2" # Strategy logic is not tested here. - cfg.trainer.placement.policy_num_gpus_per_node = 2 - - # Set scenario-specific config - cfg.trainer.micro_train_batch_size_per_gpu = micro_train_batch_size_per_gpu - cfg.trainer.policy_mini_batch_size = policy_mini_batch_size - cfg.generator.n_samples_per_prompt = n_samples_per_prompt - cfg.trainer.update_epochs_per_batch = update_epochs_per_batch - cfg.generator.inference_engine_tensor_parallel_size = 2 - - # For logging and assertions, calculate expected accumulation steps - dp_size = cfg.trainer.placement.policy_num_gpus_per_node - policy_mini_batch_size_per_gpu = (policy_mini_batch_size * n_samples_per_prompt) // dp_size - # If micro_train_batch_size_per_gpu is 0, this indicates an issue in configuration, but for safety: - accumulation_steps = ( - policy_mini_batch_size_per_gpu // micro_train_batch_size_per_gpu - if micro_train_batch_size_per_gpu > 0 - else 1 - ) - if accumulation_steps == 0: - accumulation_steps = 1 # Should not be 0, must step at least once. 
- - actor_group = init_worker_with_type( - "policy", - shared_pg=None, - colocate_all=False, - num_gpus_per_node=cfg.trainer.placement.policy_num_gpus_per_node, - cfg=cfg, - ) - - train_data = make_dummy_training_batch(batch_size=batch_size, seq_len=10, num_actions=4) - train_data.metadata["global_step"] = 0 - - result = ray.get(actor_group.async_run_ray_method("pass_through", "ppo_train", train_data))[0] - - train_status = result.metadata["train_status"] - actual_optimizer_steps = train_status["policy_update_steps"] - - assert actual_optimizer_steps == expected_optimizer_steps, ( - f"Test '{test_id}' failed: Expected {expected_optimizer_steps} optimizer steps, got {actual_optimizer_steps}. " - f"Config: micro_batch={micro_train_batch_size_per_gpu}, mini_batch={policy_mini_batch_size}, " - f"n_samples={n_samples_per_prompt}, epochs={update_epochs_per_batch}, " - f"data_batch_size={batch_size}, accumulation_steps={accumulation_steps}" - ) - - print(f"Gradient accumulation scenario '{test_id}' PASSED:") - print( - f" - Config: micro_batch={micro_train_batch_size_per_gpu}, mini_batch={policy_mini_batch_size}, " - f"n_samples={n_samples_per_prompt}, epochs={update_epochs_per_batch}" - ) - print(f" - Data batch size: {batch_size}") - print(f" - Expected accumulation steps: {accumulation_steps}") - print(f" - Expected optimizer steps: {expected_optimizer_steps}") - print(f" - Actual optimizer steps: {actual_optimizer_steps}") - finally: - ray.shutdown() diff --git a/skyrl-train/tests/gpu/gpu_ci/test_save_load_checkpoint.py b/skyrl-train/tests/gpu/gpu_ci/test_save_load_checkpoint.py index a3669ff37..7918bb4b0 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_save_load_checkpoint.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_save_load_checkpoint.py @@ -17,7 +17,7 @@ from transformers import AutoTokenizer from skyrl_train.utils.utils import print_mem -from tests.gpu.utils import init_worker_with_type, make_dummy_experience, get_model_logits_from_actor, validate_cfg +from tests.gpu.utils import init_worker_with_type, make_dummy_training_batch, get_model_logits_from_actor, validate_cfg from skyrl_train.entrypoints.main_base import config_dir MODEL_NAME = "Qwen/Qwen3-0.6B" @@ -28,7 +28,7 @@ def run_one_training_step( actor_group, strategy, - experience=None, + data=None, megatron_batch=None, ): """Run forward_backward + optim_step to perform one training step.""" @@ -36,8 +36,8 @@ def run_one_training_step( assert megatron_batch is not None, "Megatron requires a TrainingInputBatch for ppo_train" return ray.get(actor_group.async_run_ray_method("mesh", "ppo_train", megatron_batch)) else: - assert experience is not None, f"{strategy} requires an Experience for forward_backward" - ray.get(actor_group.async_run_ray_method("pass_through", "forward_backward", experience, 1)) + assert data is not None, f"{strategy} requires a TrainingInputBatch for forward_backward" + ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=data)) ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) @@ -92,13 +92,13 @@ def test_save_load_checkpoint(ray_init_fixture, strategy, lora): tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) checkpoint_dir = None - # Create dummy experiences for training steps - dummy_experience_1 = make_dummy_experience() # First training step - dummy_experience_2 = make_dummy_experience() # Second training step + # Create dummy training batches for training steps + dp_size = actor_group.actor_infos[0].rank.dp_size + dummy_batch_1 = 
make_dummy_training_batch(batch_size=dp_size) # First training step + dummy_batch_2 = make_dummy_training_batch(batch_size=dp_size) # Second training step - # Ensure the second experience is different from the first - for i, seq in enumerate(dummy_experience_2.sequences): - dummy_experience_2.sequences[i] = torch.randint(100, 200, seq.shape, device=seq.device) + # Ensure the second batch is different from the first + dummy_batch_2["sequences"] = torch.randint(100, 200, dummy_batch_2["sequences"].shape, device="cpu") # For Megatron, build training batches and reuse the second one pre/post checkpoint resume if "megatron" in strategy: @@ -115,7 +115,7 @@ def test_save_load_checkpoint(ray_init_fixture, strategy, lora): run_one_training_step( actor_group, strategy, - experience=dummy_experience_1, + data=dummy_batch_1, megatron_batch=train_batch_1, ) @@ -161,7 +161,7 @@ def test_save_load_checkpoint(ray_init_fixture, strategy, lora): run_one_training_step( actor_group, strategy, - experience=dummy_experience_2, + data=dummy_batch_2, megatron_batch=train_batch_2, ) @@ -181,7 +181,7 @@ def test_save_load_checkpoint(ray_init_fixture, strategy, lora): run_one_training_step( actor_group, strategy, - experience=dummy_experience_2, + data=dummy_batch_2, megatron_batch=train_batch_2, ) diff --git a/skyrl-train/tests/gpu/gpu_ci/test_training_step.py b/skyrl-train/tests/gpu/gpu_ci/test_training_step.py index e81103434..aea3784be 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_training_step.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_training_step.py @@ -8,7 +8,7 @@ import hydra from omegaconf import DictConfig -from tests.gpu.utils import init_worker_with_type, make_dummy_experience, validate_cfg +from tests.gpu.utils import init_worker_with_type, make_dummy_training_batch, validate_cfg from skyrl_train.utils.utils import print_mem from skyrl_train.entrypoints.main_base import config_dir @@ -61,9 +61,11 @@ async def test_policy_forward_backward_and_optim_step(ray_init_fixture, cfg, pac cfg=cfg, ) - dummy_experience = make_dummy_experience() + # Create TrainingInputBatch - worker's forward_backward handles micro-batching internally + dp_size = actor_group.actor_infos[0].rank.dp_size + dummy_batch = make_dummy_training_batch(batch_size=dp_size) - results = ray.get(actor_group.async_run_ray_method("pass_through", "forward_backward", dummy_experience, 1)) + results = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=dummy_batch)) ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) memory = ray.get(actor_group.async_run_ray_method("pass_through", "get_cuda_memory")) @@ -109,9 +111,11 @@ async def test_critic_forward_backward_and_optim_step(ray_init_fixture, cfg, pac cfg=cfg, ) - dummy_experience = make_dummy_experience() + # Create TrainingInputBatch - worker's forward_backward handles micro-batching internally + dp_size = actor_group.actor_infos[0].rank.dp_size + dummy_batch = make_dummy_training_batch(batch_size=dp_size) - results = ray.get(actor_group.async_run_ray_method("pass_through", "forward_backward", dummy_experience, 1)) + results = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=dummy_batch)) ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) for result in results: diff --git a/skyrl-train/tests/gpu/gpu_ci/test_worker_dispatch_offload.py b/skyrl-train/tests/gpu/gpu_ci/test_worker_dispatch_offload.py new file mode 100644 index 000000000..7418d5c01 --- /dev/null +++ 
b/skyrl-train/tests/gpu/gpu_ci/test_worker_dispatch_offload.py @@ -0,0 +1,293 @@ +""" +Test WorkerDispatch automatic offload/onload with colocation policies. + +Run with: +uv run --isolated --extra dev -- pytest tests/gpu/gpu_ci/test_worker_dispatch_offload.py -v + +These tests validate that WorkerDispatch correctly manages GPU memory when +multiple models share the same GPU (colocate_all=True or colocate_policy_ref=True). +""" + +import ray +import pytest +import hydra +from omegaconf import DictConfig +from ray.util.placement_group import placement_group + +from tests.gpu.utils import make_dummy_training_batch, get_rank_0_memory +from skyrl_train.utils.utils import validate_cfg +from skyrl_train.utils import get_ray_pg_ready_with_timeout +from skyrl_train.entrypoints.main_base import config_dir +from skyrl_train.workers.worker_dispatch import WorkerDispatch, GPUState +from skyrl_train.workers.fsdp.fsdp_worker import PolicyWorker, RefWorker, CriticWorker +from skyrl_train.workers.worker import PPORayActorGroup + +MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct" + + +def get_test_config() -> DictConfig: + with hydra.initialize_config_dir(config_dir=config_dir): + cfg = hydra.compose(config_name="ppo_base_config") + + cfg.trainer.policy.model.path = MODEL_NAME + cfg.trainer.placement.policy_num_gpus_per_node = 1 + cfg.generator.inference_engine_tensor_parallel_size = 1 + cfg.trainer.placement.colocate_all = True + cfg.trainer.use_sample_packing = False + cfg.trainer.logger = "console" + cfg.trainer.strategy = "fsdp2" + cfg.trainer.ref.fsdp_config.cpu_offload = False + + validate_cfg(cfg) + return cfg + + +def init_colocated_actor_group( + worker_cls, + shared_pg, + cfg: DictConfig, +) -> PPORayActorGroup: + """Initialize an actor group that shares a placement group with others.""" + return PPORayActorGroup( + cfg, + num_nodes=1, + num_gpus_per_node=1, + ray_actor_type=worker_cls, + pg=shared_pg, + num_gpus_per_actor=0.4, # Share GPU + colocate_all=True, + sequence_parallel_size=cfg.trainer.policy.sequence_parallel_size, + record_memory=cfg.trainer.policy.record_memory, + ) + + +@pytest.mark.asyncio +async def test_colocate_all_only_one_model_on_gpu(ray_init_fixture): + """ + Test that with colocate_all=True, only one model is on GPU at a time. + + Scenario: + 1. Initialize policy and ref on shared GPU + 2. Call dispatch.forward("ref", ...) - ref should be on GPU, policy offloaded + 3. Call dispatch.forward_backward("policy", ...) - policy on GPU, ref offloaded + 4. 
Verify memory drops when switching (indicates offload happened) + """ + cfg = get_test_config() + + try: + # Create shared placement group + pg = placement_group([{"GPU": 1, "CPU": 2}], strategy="PACK") + get_ray_pg_ready_with_timeout(pg, timeout=30) + + # Initialize both actor groups on shared GPU + policy_group = init_colocated_actor_group(PolicyWorker, pg, cfg) + ref_group = init_colocated_actor_group(RefWorker, pg, cfg) + + # Init models - after init, models are on GPU + ray.get(policy_group.async_init_model(cfg.trainer.policy.model.path)) + ray.get(ref_group.async_init_model(cfg.trainer.policy.model.path)) + + # Create dispatch with colocate_all=True + dispatch = WorkerDispatch( + cfg, + policy_actor_group=policy_group, + ref_actor_group=ref_group, + ) + + # Mark both as on GPU after init + dispatch._gpu_state["policy"] = GPUState(model_on_gpu=True, optimizer_on_gpu=True) + dispatch._gpu_state["ref"] = GPUState(model_on_gpu=True, optimizer_on_gpu=False) + + # Manually offload both to start from clean state + policy_group.offload_to_cpu() + ref_group.offload_to_cpu() + dispatch.mark_all_offloaded() + + dp_size = policy_group.actor_infos[0].rank.dp_size + dummy_batch = make_dummy_training_batch(batch_size=dp_size) + + # === Test 1: Load ref model === + dispatch.forward("ref", dummy_batch) + + # Verify state tracking + assert dispatch._gpu_state["ref"].model_on_gpu, "ref should be marked on GPU" + assert not dispatch._gpu_state["policy"].model_on_gpu, "policy should be marked offloaded" + + # Verify ref memory increased (measure from ref_group since it's a separate actor) + ref_mem_after_load = get_rank_0_memory(ref_group, "After ref forward") + assert ref_mem_after_load > 1e8, f"Ref model should use significant memory: {ref_mem_after_load}" + + # === Test 2: Switch to policy (should offload ref) === + dispatch.forward_backward("policy", dummy_batch) + + # Verify state tracking + assert dispatch._gpu_state["policy"].model_on_gpu, "policy should be on GPU" + assert dispatch._gpu_state["policy"].optimizer_on_gpu, "policy optimizer should be on GPU" + assert not dispatch._gpu_state["ref"].model_on_gpu, "ref should be offloaded" + + # Verify policy is on GPU and ref was offloaded + policy_mem = get_rank_0_memory(policy_group, "After policy forward_backward") + ref_mem_after_offload = get_rank_0_memory(ref_group, "Ref after being offloaded") + assert policy_mem > 1e8, f"Policy model should use significant memory: {policy_mem}" + assert ( + ref_mem_after_offload < ref_mem_after_load + ), f"Ref memory should decrease after offload: {ref_mem_after_offload} < {ref_mem_after_load}" + + # === Test 3: Switch back to ref (should offload policy) === + dispatch.forward("ref", dummy_batch) + + # Verify state tracking + assert dispatch._gpu_state["ref"].model_on_gpu, "ref should be on GPU" + assert not dispatch._gpu_state["policy"].model_on_gpu, "policy should be offloaded" + assert not dispatch._gpu_state["policy"].optimizer_on_gpu, "policy optimizer should be offloaded" + + # Verify policy was offloaded + policy_mem_after_offload = get_rank_0_memory(policy_group, "Policy after being offloaded") + assert ( + policy_mem_after_offload < policy_mem + ), f"Policy memory should decrease after offload: {policy_mem_after_offload} < {policy_mem}" + + finally: + ray.shutdown() + + +@pytest.mark.asyncio +async def test_gpu_state_tracking_accuracy(ray_init_fixture): + """ + Test that _gpu_state accurately reflects what's actually on GPU. 
+ + This verifies the internal state tracking matches the actual offload/onload operations. + """ + cfg = get_test_config() + + try: + pg = placement_group([{"GPU": 1, "CPU": 2}], strategy="PACK") + get_ray_pg_ready_with_timeout(pg, timeout=30) + + policy_group = init_colocated_actor_group(PolicyWorker, pg, cfg) + ref_group = init_colocated_actor_group(RefWorker, pg, cfg) + + ray.get(policy_group.async_init_model(cfg.trainer.policy.model.path)) + ray.get(ref_group.async_init_model(cfg.trainer.policy.model.path)) + + dispatch = WorkerDispatch( + cfg, + policy_actor_group=policy_group, + ref_actor_group=ref_group, + ) + + # Start from clean state + policy_group.offload_to_cpu() + ref_group.offload_to_cpu() + dispatch.mark_all_offloaded() + + # Verify initial state + assert dispatch._gpu_state["policy"] == GPUState(model_on_gpu=False, optimizer_on_gpu=False) + assert dispatch._gpu_state["ref"] == GPUState(model_on_gpu=False, optimizer_on_gpu=False) + + # Load policy for training (needs model + optimizer) + dp_size = policy_group.actor_infos[0].rank.dp_size + dummy_batch = make_dummy_training_batch(batch_size=dp_size) + dispatch.forward_backward("policy", dummy_batch) + + assert dispatch._gpu_state["policy"] == GPUState(model_on_gpu=True, optimizer_on_gpu=True) + assert dispatch._gpu_state["ref"] == GPUState(model_on_gpu=False, optimizer_on_gpu=False) + + # Load ref for inference (only needs model) + dispatch.forward("ref", dummy_batch) + + assert dispatch._gpu_state["ref"] == GPUState(model_on_gpu=True, optimizer_on_gpu=False) + assert dispatch._gpu_state["policy"] == GPUState(model_on_gpu=False, optimizer_on_gpu=False) + + finally: + ray.shutdown() + + +@pytest.mark.asyncio +async def test_colocate_policy_critic_training_switch(ray_init_fixture): + """ + Test switching between policy and critic training with colocate_all=True. + + This tests the common PPO training pattern where we alternate between + training policy and critic on the same GPU. + + Scenario: + 1. Train policy (forward_backward + optim_step) + 2. Train critic (forward_backward + optim_step) + 3. Train policy again + 4. 
Verify correct offload/onload at each switch + """ + cfg = get_test_config() + + try: + pg = placement_group([{"GPU": 1, "CPU": 2}], strategy="PACK") + get_ray_pg_ready_with_timeout(pg, timeout=30) + + policy_group = init_colocated_actor_group(PolicyWorker, pg, cfg) + critic_group = init_colocated_actor_group(CriticWorker, pg, cfg) + + ray.get(policy_group.async_init_model(cfg.trainer.policy.model.path)) + ray.get(critic_group.async_init_model(cfg.trainer.policy.model.path)) + + dispatch = WorkerDispatch( + cfg, + policy_actor_group=policy_group, + critic_actor_group=critic_group, + ) + + # Start from clean state + policy_group.offload_to_cpu() + critic_group.offload_to_cpu() + dispatch.mark_all_offloaded() + + dp_size = policy_group.actor_infos[0].rank.dp_size + dummy_batch = make_dummy_training_batch(batch_size=dp_size) + + # === Step 1: Train policy === + dispatch.forward_backward("policy", dummy_batch) + dispatch.optim_step("policy") + + assert dispatch._gpu_state["policy"].model_on_gpu + assert dispatch._gpu_state["policy"].optimizer_on_gpu + assert not dispatch._gpu_state["critic"].model_on_gpu + assert not dispatch._gpu_state["critic"].optimizer_on_gpu + + policy_mem = get_rank_0_memory(policy_group, "After policy training") + assert policy_mem > 1e8, f"Policy model should use significant memory: {policy_mem}" + + # === Step 2: Train critic (should offload policy) === + dispatch.forward_backward("critic", dummy_batch) + dispatch.optim_step("critic") + + assert dispatch._gpu_state["critic"].model_on_gpu + assert dispatch._gpu_state["critic"].optimizer_on_gpu + assert not dispatch._gpu_state["policy"].model_on_gpu + assert not dispatch._gpu_state["policy"].optimizer_on_gpu + + # Verify critic is loaded and policy was offloaded + critic_mem = get_rank_0_memory(critic_group, "After critic training") + policy_mem_after_offload = get_rank_0_memory(policy_group, "Policy after offload") + assert critic_mem > 1e8, f"Critic model should use significant memory: {critic_mem}" + assert ( + policy_mem_after_offload < policy_mem + ), f"Policy memory should decrease after offload: {policy_mem_after_offload} < {policy_mem}" + + # === Step 3: Train policy again (should offload critic) === + dispatch.forward_backward("policy", dummy_batch) + dispatch.optim_step("policy") + + assert dispatch._gpu_state["policy"].model_on_gpu + assert dispatch._gpu_state["policy"].optimizer_on_gpu + assert not dispatch._gpu_state["critic"].model_on_gpu + assert not dispatch._gpu_state["critic"].optimizer_on_gpu + + # Verify policy is loaded again and critic was offloaded + policy_mem_reloaded = get_rank_0_memory(policy_group, "Policy reloaded") + critic_mem_after_offload = get_rank_0_memory(critic_group, "Critic after offload") + assert policy_mem_reloaded > 1e8, f"Policy should be back on GPU: {policy_mem_reloaded}" + assert ( + critic_mem_after_offload < critic_mem + ), f"Critic memory should decrease after offload: {critic_mem_after_offload} < {critic_mem}" + + finally: + ray.shutdown() diff --git a/skyrl-train/tests/gpu/gpu_ci/test_worker_offload.py b/skyrl-train/tests/gpu/gpu_ci/test_worker_offload.py index e82f4ac54..450aa7511 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_worker_offload.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_worker_offload.py @@ -9,7 +9,7 @@ import os import shutil -from tests.gpu.utils import init_worker_with_type, make_dummy_experience, make_dummy_tensorbatch, get_rank_0_memory +from tests.gpu.utils import init_worker_with_type, make_dummy_training_batch, make_dummy_tensorbatch, 
get_rank_0_memory from skyrl_train.utils.utils import validate_cfg from skyrl_train.entrypoints.main_base import config_dir from skyrl_train.training_batch import TrainingOutputBatch @@ -92,9 +92,10 @@ async def test_critic_policy_offload_memory_and_correctness(ray_init_fixture, cf actor_group.backload_to_gpu() get_rank_0_memory(actor_group, "Before training") - dummy_experience = make_dummy_experience() + dp_size = actor_group.actor_infos[0].rank.dp_size + dummy_batch = make_dummy_training_batch(batch_size=dp_size) # Run first forward_backward + optim_step to get optimizer initialized and stepped - results = ray.get(actor_group.async_run_ray_method("pass_through", "forward_backward", dummy_experience, 1)) + results = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=dummy_batch)) ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) after_training = get_rank_0_memory(actor_group, "After training") @@ -140,9 +141,7 @@ async def test_critic_policy_offload_memory_and_correctness(ray_init_fixture, cf ), f"Memory after backload model should be greater than after backload optimizer: {after_backload} bytes, after backload optimizer: {after_backload_optimizer} bytes" # Run training again and ensure output consistency - results_backload = ray.get( - actor_group.async_run_ray_method("pass_through", "forward_backward", dummy_experience, 1) - ) + results_backload = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=dummy_batch)) ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) for i, result in enumerate(results): @@ -330,11 +329,12 @@ def test_offload_after_ckpt(ray_init_fixture, strategy): ) get_rank_0_memory(actor_group, "After init") - # Create dummy experiences for training steps - dummy_experience_1 = make_dummy_experience() # First training step + # Create dummy training batch for training steps + dp_size = actor_group.actor_infos[0].rank.dp_size + dummy_batch_1 = make_dummy_training_batch(batch_size=dp_size) # Step 1: Do initial forward_backward + optim_step - ray.get(actor_group.async_run_ray_method("pass_through", "forward_backward", dummy_experience_1, 1)) + ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=dummy_batch_1)) ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) get_rank_0_memory(actor_group, "After training step 1") diff --git a/skyrl-train/tests/gpu/test_save_load_model.py b/skyrl-train/tests/gpu/test_save_load_model.py index c0c48e826..71593cb86 100644 --- a/skyrl-train/tests/gpu/test_save_load_model.py +++ b/skyrl-train/tests/gpu/test_save_load_model.py @@ -21,7 +21,7 @@ from tests.gpu.utils import ( init_worker_with_type, - make_dummy_experience, + make_dummy_training_batch, get_model_logits_from_actor, ray_init_for_tests, validate_cfg, @@ -54,7 +54,7 @@ def get_test_actor_config(strategy: str) -> DictConfig: def run_one_training_step( actor_group, strategy, - experience=None, + data=None, megatron_batch=None, ): """Run forward_backward + optim_step to perform one training step.""" @@ -62,8 +62,8 @@ def run_one_training_step( assert megatron_batch is not None, "Megatron requires a TrainingInputBatch for ppo_train" return ray.get(actor_group.async_run_ray_method("mesh", "ppo_train", megatron_batch)) else: - assert experience is not None, f"{strategy} requires an Experience for forward_backward" - ray.get(actor_group.async_run_ray_method("pass_through", "forward_backward", experience, 1)) + assert data is not None, f"{strategy} requires a 
TrainingInputBatch for forward_backward" + ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=data)) ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) @@ -97,23 +97,23 @@ def test_save_load_hf_model(ray_init_fixture, strategy): ) # Prepare training input and run one training step + dp_size = actor_group_1.actor_infos[0].rank.dp_size if "megatron" in strategy: from tests.gpu.test_megatron_worker import get_test_training_batch - dp_size = actor_group_1.actor_infos[0].rank.dp_size train_batch_1 = get_test_training_batch(dp_size if dp_size % NUM_GPUS == 0 else NUM_GPUS) run_one_training_step( actor_group_1, strategy, - experience=None, + data=None, megatron_batch=train_batch_1, ) else: - dummy_experience = make_dummy_experience() + dummy_batch = make_dummy_training_batch(batch_size=dp_size) run_one_training_step( actor_group_1, strategy, - experience=dummy_experience, + data=dummy_batch, megatron_batch=None, )
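Illustrative usage (a minimal sketch, not an excerpt from this diff): the tests above exercise the WorkerDispatch API piecemeal; the snippet below strings the same calls together the way a trainer-side step might, assuming `dispatch` was built with policy/critic/ref actor groups and that `inference_engine_client` and `batch` (a TrainingInputBatch) already exist. Every name outside the WorkerDispatch methods shown above is an assumption for illustration only.

def train_one_step(dispatch, inference_engine_client, batch):
    # Reference forward pass: WorkerDispatch loads only the ref model onto the GPU,
    # offloading colocated models first (see test_worker_dispatch_offload.py above).
    ref_output = dispatch.forward("ref", batch)

    # Policy update with the forward_backward + optim_step pattern: forward_backward
    # micro-batches internally and accumulates gradients; optim_step scales by
    # 1/num_accumulated and steps the optimizer, returning the grad norm (if any).
    policy_metrics = dispatch.forward_backward("policy", batch)
    policy_grad_norm = dispatch.optim_step("policy")

    # Critic update follows the same pattern (only when a critic is configured);
    # under colocate_all the dispatch offloads the policy before onloading the critic.
    critic_metrics = dispatch.forward_backward("critic", batch)
    critic_grad_norm = dispatch.optim_step("critic")

    # Weight sync: when colocated, prepare_for_weight_sync onloads the policy model
    # and offloads its optimizer; weights are then broadcast to the inference
    # engines, and finish_weight_sync offloads the model again.
    dispatch.prepare_for_weight_sync()
    dispatch.broadcast_to_inference_engines(inference_engine_client)
    dispatch.finish_weight_sync()

    return ref_output, policy_metrics, policy_grad_norm, critic_metrics, critic_grad_norm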