diff --git a/skyrl-train/docs/checkpointing-logging/checkpointing.rst b/skyrl-train/docs/checkpointing-logging/checkpointing.rst
index a84a7c5b35..489df0b5bf 100644
--- a/skyrl-train/docs/checkpointing-logging/checkpointing.rst
+++ b/skyrl-train/docs/checkpointing-logging/checkpointing.rst
@@ -28,24 +28,50 @@ FSDP checkpoints are organized according to the following directory hierarchy:
 
 .. code-block::
 
     {ckpt_path}/
-    ├── latest_ckpt_global_step.txt # Holds the global step of the latest checkpoint
-    ├── global_step_10/             # Checkpoint at training step 10
+    ├── latest_ckpt_global_step.txt     # Holds the global step of the latest checkpoint
+    ├── global_step_10/                 # Checkpoint at training step 10
     │   ├── policy/                     # Policy model checkpoint directory
+    │   │   ├── fsdp_config.json        # Stores the FSDP strategy (fsdp/fsdp2) and world size
+    │   │   ├── huggingface/            # HuggingFace config and tokenizer
+    │   │   │   ├── config.json             # Model config
+    │   │   │   ├── tokenizer_config.json   # Tokenizer config
+    │   │   │   ├── generation_config.json  # Generation config
+    │   │   │   └── ...                     # Other tokenizer config files
     │   │   ├── model_state.pt          # Model parameters
     │   │   ├── optimizer_state.pt      # Optimizer state
     │   │   └── lr_scheduler_state.pt   # Learning rate scheduler state
     │   ├── critic/                     # Critic model checkpoint (if enabled)
+    │   │   ├── fsdp_config.json
+    │   │   ├── huggingface/
     │   │   ├── model_state.pt
     │   │   ├── optimizer_state.pt
     │   │   └── lr_scheduler_state.pt
     │   ├── data.pt                     # Dataloader state
     │   └── trainer_state.pt            # High-level trainer state
-    ├── global_step_20/             # Checkpoint at training step 20
+    ├── global_step_20/                 # Checkpoint at training step 20
     │   └── ...
-    └── global_step_30/             # Checkpoint at training step 30
+    └── global_step_30/                 # Checkpoint at training step 30
         └── ...
 
-DeepSpeed checkpoints follow a similar directory structure but the files under ``policy`` and ``critic`` are created by the DeepSpeed checkpoint API, and are not explicitly managed by SkyRL.
+DeepSpeed checkpoints follow a similar directory structure, but the model checkpoint files under ``policy`` and ``critic`` are created by the DeepSpeed checkpoint API and are not explicitly managed by SkyRL.
+
+.. code-block::
+
+    {ckpt_path}/
+    ├── latest_ckpt_global_step.txt     # Holds the global step of the latest checkpoint
+    ├── global_step_10/                 # Checkpoint at training step 10
+    │   ├── policy/                     # Policy model checkpoint directory
+    │   │   ├── huggingface/            # HuggingFace config and tokenizer
+    │   │   ├── global_step10/          # DeepSpeed checkpoint directory
+    │   │   └── ...                     # Other DeepSpeed checkpointing files
+    │   └── critic/                     # Critic model checkpoint (if enabled)
+    │       ├── huggingface/
+    │       ├── global_step10/
+    │       └── ...
+    ├── global_step_20/                 # Checkpoint at training step 20
+    │   └── ...
+    └── global_step_30/                 # Checkpoint at training step 30
+        └── ...
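For orientation, a minimal sketch of how a consumer of this layout could resolve the latest checkpoint and load the artifacts saved under ``huggingface/`` (the base path is a placeholder, and it is assumed that ``latest_ckpt_global_step.txt`` stores the bare step number, as its name suggests):

.. code-block:: python

    import os
    from transformers import AutoConfig, AutoTokenizer

    ckpt_path = "/path/to/ckpts"  # placeholder for {ckpt_path} in the trees above

    # Assumption: the tracking file holds just the latest global step, e.g. "10"
    with open(os.path.join(ckpt_path, "latest_ckpt_global_step.txt")) as f:
        step = int(f.read().strip())

    hf_dir = os.path.join(ckpt_path, f"global_step_{step}", "policy", "huggingface")
    config = AutoConfig.from_pretrained(hf_dir)        # reads config.json
    tokenizer = AutoTokenizer.from_pretrained(hf_dir)  # reads the tokenizer files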
 Key Configuration Parameters
diff --git a/skyrl-train/skyrl_train/distributed/deepspeed_strategy.py b/skyrl-train/skyrl_train/distributed/deepspeed_strategy.py
index 6b5dd65787..5d3e57885f 100644
--- a/skyrl-train/skyrl_train/distributed/deepspeed_strategy.py
+++ b/skyrl-train/skyrl_train/distributed/deepspeed_strategy.py
@@ -257,6 +257,7 @@ def save_ckpt(
         scheduler=None,
         client_state={},
         tag=None,
+        tokenizer=None,
     ):
         if isinstance(model, Actor):
             model = model.model
@@ -277,6 +278,11 @@ def save_ckpt(
 
         model.save_checkpoint(ckpt_dir, tag=tag, client_state=extra_state_dict)
 
+        # Save HuggingFace config and tokenizer
+        if self.is_rank_0():
+            config_save_model = self._unwrap_model(model)
+            self.save_hf_configs(config_save_model, ckpt_dir, tokenizer)
+
     def load_ckpt(
         self,
         model,
diff --git a/skyrl-train/skyrl_train/distributed/fsdp_strategy.py b/skyrl-train/skyrl_train/distributed/fsdp_strategy.py
index dc6ac4d8c2..19ca909ff2 100644
--- a/skyrl-train/skyrl_train/distributed/fsdp_strategy.py
+++ b/skyrl-train/skyrl_train/distributed/fsdp_strategy.py
@@ -6,6 +6,7 @@
 from typing import List, Union, Optional
 from jaxtyping import Float
 import gc
+import json
 
 import numpy as np
 import torch
@@ -374,6 +375,7 @@ def save_ckpt(
         scheduler=None,
         client_state={},
         tag=None,
+        tokenizer=None,
     ):
         """Save model checkpoint for FSDP"""
         import warnings
@@ -445,6 +447,15 @@ def save_ckpt(
         # Garbage collect temporary buffers from materializing the state dicts
         gc.collect()
 
+        if self.is_rank_0():
+            config_save_model = self._unwrap_model(model)
+            self.save_hf_configs(config_save_model, ckpt_dir, tokenizer)
+
+            # Also save runtime FSDP config
+            fsdp_config_path = os.path.join(ckpt_dir, "fsdp_config.json")
+            with open(fsdp_config_path, "w") as f:
+                json.dump({"fsdp_strategy": self.fsdp_strategy, "world_size": self.world_size}, f, indent=4)
+
         # Final barrier to ensure all operations complete
         dist.barrier()
         torch.cuda.synchronize()
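Given the ``json.dump`` call above, the new ``fsdp_config.json`` is a two-key document. A sketch of a resume-time sanity check that could be built on it (the path is a placeholder and the guard itself is illustrative, not part of this change):

.. code-block:: python

    import json

    # Written by the FSDP strategy's save_ckpt,
    # e.g. {"fsdp_strategy": "fsdp2", "world_size": 2}
    with open("/path/to/ckpts/global_step_10/policy/fsdp_config.json") as f:
        saved = json.load(f)

    # Illustrative guard: resuming under a different strategy or world size
    # than the checkpoint was written with is almost certainly a mistake.
    assert saved["fsdp_strategy"] in ("fsdp", "fsdp2")
    assert saved["world_size"] > 0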
diff --git a/skyrl-train/skyrl_train/distributed/strategy.py b/skyrl-train/skyrl_train/distributed/strategy.py
index 913177b8b3..f05ffedefd 100644
--- a/skyrl-train/skyrl_train/distributed/strategy.py
+++ b/skyrl-train/skyrl_train/distributed/strategy.py
@@ -1,4 +1,5 @@
 import random
+import os
 from abc import ABC, abstractmethod
 
 import numpy as np
@@ -7,6 +8,7 @@
 from typing import Optional, Dict, Any, Union, TypeVar
 import torch.optim as optim
 from jaxtyping import Float
+from transformers import GenerationConfig
 
 DataT = TypeVar("DataT", bound=Union[Dict[str, Any], torch.Tensor])
@@ -45,7 +47,7 @@ def optimizer_step(
         pass
 
     @abstractmethod
-    def save_ckpt(self, model, optimizer, scheduler, ckpt_dir, global_step, node_local_rank):
+    def save_ckpt(self, model, optimizer, scheduler, ckpt_dir, global_step, node_local_rank, tokenizer=None):
         """Save checkpoint"""
         pass
 
@@ -72,6 +74,34 @@ def get_rank(self) -> int:
         """Get current process rank"""
         return dist.get_rank()
 
+    def save_hf_configs(self, model, ckpt_dir: str, tokenizer=None):
+        """
+        Save model and tokenizer configs to ckpt_dir/huggingface
+
+        Args:
+            model: AutoModel - the model to save the configs for
+            ckpt_dir: str - the directory to save the configs to
+            tokenizer: AutoTokenizer - tokenizer to save
+        """
+        hf_config_tokenizer_path = os.path.join(ckpt_dir, "huggingface")
+        os.makedirs(hf_config_tokenizer_path, exist_ok=True)
+        model_config = model.config
+        generation_config = None
+        if model.can_generate() and hasattr(model_config, "name_or_path") and model_config.name_or_path:
+            try:
+                # A model's name_or_path is empty if it was not initialized from a
+                # pretrained checkpoint; in that case, we don't save a generation config.
+                generation_config = GenerationConfig.from_pretrained(model_config.name_or_path)
+                generation_config.save_pretrained(hf_config_tokenizer_path)
+            except Exception as e:
+                # If the generation config isn't available, we don't save it
+                print(f"Warning: Could not save generation config for '{model_config.name_or_path}'. Error: {e}")
+                pass
+
+        model_config.save_pretrained(hf_config_tokenizer_path)
+        if tokenizer is not None:
+            tokenizer.save_pretrained(hf_config_tokenizer_path)
+
     @staticmethod
     def get_rng_state():
         """Get current RNG state for reproducibility"""
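A self-contained sketch of what ``save_hf_configs`` leaves on disk, mirroring the helper above (the output directory is a placeholder; the model name matches the one used by the tests in this change):

.. code-block:: python

    import os
    from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

    model_name = "Qwen/Qwen2.5-0.5B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    hf_dir = "/tmp/ckpt/global_step_1/policy/huggingface"  # placeholder for {ckpt_dir}/huggingface
    os.makedirs(hf_dir, exist_ok=True)

    model.config.save_pretrained(hf_dir)  # -> config.json
    if model.can_generate() and model.config.name_or_path:
        try:
            # -> generation_config.json
            GenerationConfig.from_pretrained(model_name).save_pretrained(hf_dir)
        except Exception:
            pass  # generation config unavailable; skipped, as in the helper
    tokenizer.save_pretrained(hf_dir)  # -> tokenizer_config.json, tokenizer.json, ...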
diff --git a/skyrl-train/skyrl_train/trainer.py b/skyrl-train/skyrl_train/trainer.py
index 7edda4310a..e8f576c8ce 100644
--- a/skyrl-train/skyrl_train/trainer.py
+++ b/skyrl-train/skyrl_train/trainer.py
@@ -1037,6 +1037,7 @@ def save_checkpoints(self):
                 "save_ckpt",
                 global_step=self.global_step,
                 ckpt_dir=policy_save_dir,
+                tokenizer=self.tokenizer,
             )
         )
 
@@ -1052,6 +1053,7 @@ def save_checkpoints(self):
                 "save_ckpt",
                 global_step=self.global_step,
                 ckpt_dir=critic_save_dir,
+                tokenizer=self.tokenizer,
             )
         )
 
diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py
index 84d13206f9..112f445553 100644
--- a/skyrl-train/skyrl_train/workers/worker.py
+++ b/skyrl-train/skyrl_train/workers/worker.py
@@ -847,7 +847,7 @@ def training_step(self, experience: Experience, global_step, local_step, accumul
         status["response_length"] = num_actions
         return status
 
-    def save_ckpt(self, global_step: int, ckpt_dir: Path):
+    def save_ckpt(self, global_step: int, ckpt_dir: Path, tokenizer=None):
         self.strategy.save_ckpt(
             model=self.model,
             optimizer=self.optimizer,
@@ -855,6 +855,7 @@ def save_ckpt(self, global_step: int, ckpt_dir: Path):
             ckpt_dir=ckpt_dir,
             global_step=global_step,
             node_local_rank=self.get_node_local_rank(),
+            tokenizer=tokenizer,
         )
 
     def load_ckpt(self, ckpt_dir: Path, load_optimizer_states: bool = True, load_lr_scheduler_states: bool = True):
@@ -1052,7 +1053,7 @@ def training_step(self, experience: Experience, global_step, local_step, accumul
         status["raw_grad_norm"] = grad_norm
         return status
 
-    def save_ckpt(self, global_step: int, ckpt_dir: str):
+    def save_ckpt(self, global_step: int, ckpt_dir: str, tokenizer=None):
         self.strategy.save_ckpt(
             model=self.model,
             optimizer=self.optimizer,
@@ -1060,6 +1061,7 @@ def save_ckpt(self, global_step: int, ckpt_dir: str):
             ckpt_dir=ckpt_dir,
             global_step=global_step,
             node_local_rank=self.get_node_local_rank(),
+            tokenizer=tokenizer,
         )
 
     def load_ckpt(self, ckpt_dir=None, load_optimizer_states=True, load_lr_scheduler_states=True):
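To make the plumbing above easy to follow, a toy sketch of how the ``tokenizer`` keyword now threads trainer → worker → strategy (the classes here are minimal stand-ins, not the real implementations):

.. code-block:: python

    class ToyStrategy:
        def save_ckpt(self, model, optimizer, scheduler, ckpt_dir, global_step,
                      node_local_rank, tokenizer=None):
            # On rank 0 the real strategies now call save_hf_configs(...)
            if tokenizer is not None:
                print(f"would write HF config/tokenizer under {ckpt_dir}/huggingface")

    class ToyWorker:
        def __init__(self):
            self.strategy = ToyStrategy()
            self.model = self.optimizer = self.scheduler = None

        def save_ckpt(self, global_step, ckpt_dir, tokenizer=None):
            # Mirrors Worker.save_ckpt: forwards the trainer's tokenizer to the strategy
            self.strategy.save_ckpt(self.model, self.optimizer, self.scheduler,
                                    ckpt_dir, global_step, node_local_rank=0,
                                    tokenizer=tokenizer)

    ToyWorker().save_ckpt(global_step=1, ckpt_dir="/tmp/ckpt/policy", tokenizer="tokenizer")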
diff --git a/skyrl-train/tests/gpu/test_save_load_ckpt.py b/skyrl-train/tests/gpu/test_save_load_ckpt.py
index 045bf469a9..73383fc33c 100644
--- a/skyrl-train/tests/gpu/test_save_load_ckpt.py
+++ b/skyrl-train/tests/gpu/test_save_load_ckpt.py
@@ -1,6 +1,6 @@
 """
 Run with:
-uv run --isolated --extra dev -- pytest tests/gpu/test_save_load_ckpt.py
+uv run --isolated --extra dev --with deepspeed -- pytest tests/gpu/test_save_load_ckpt.py
 """
 
 import ray
@@ -9,7 +9,9 @@
 import torch
 import os
 import shutil
+import json
 from omegaconf import DictConfig
+from transformers import AutoTokenizer
 
 from tests.gpu.utils import init_worker_with_type, make_dummy_experience, get_model_logits_from_actor
 from skyrl_train.entrypoints.main_base import config_dir
@@ -40,7 +42,7 @@ def get_test_actor_config(strategy: str) -> DictConfig:
         "fsdp2",
     ],
 )
-def test_save_load_checkpoint(strategy):
+def test_save_load_checkpoint(ray_init_fixture, strategy):
     """
     Test checkpointing logic by:
     1. Creating model and doing one training step
@@ -59,6 +61,7 @@ def test_save_load_checkpoint(strategy):
         num_gpus_per_node=cfg.trainer.placement.policy_num_gpus_per_node,
         cfg=cfg,
     )
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    checkpoint_dir = None
 
     # Create dummy experiences for training steps
@@ -82,7 +85,25 @@ def test_save_load_checkpoint(strategy):
         checkpoint_dir = os.path.expandvars(os.path.join(cfg.trainer.ckpt_path, "global_step_1"))  # Store for cleanup
 
         # Step 2: Save checkpoint
-        ray.get(actor_group.async_run_ray_method("pass_through", "save_ckpt", global_step=1, ckpt_dir=checkpoint_path))
+        ray.get(
+            actor_group.async_run_ray_method(
+                "pass_through", "save_ckpt", global_step=1, ckpt_dir=checkpoint_path, tokenizer=tokenizer
+            )
+        )
+
+        # check that relevant files are saved
+        huggingface_dir = os.path.join(checkpoint_path, "huggingface")
+        expected_files = ["config.json", "generation_config.json", "tokenizer.json"]
+        for file in expected_files:
+            assert os.path.exists(
+                os.path.join(huggingface_dir, file)
+            ), f"File {file} not found in huggingface directory"
+        if "fsdp" in strategy:
+            fsdp_config_path = os.path.join(checkpoint_path, "fsdp_config.json")
+            with open(fsdp_config_path, "r") as f:
+                fsdp_config = json.load(f)
+            assert fsdp_config["fsdp_strategy"] == strategy
+            assert fsdp_config["world_size"] == 2
 
         # Step 3: Do second training step and record results
         ray.get(
@@ -117,9 +138,6 @@ def test_save_load_checkpoint(strategy):
         torch.testing.assert_close(logits_after_second_training, logits_after_reload_and_training, atol=0.0, rtol=0.0)
 
     finally:
-        # Clean up ray
-        ray.shutdown()
-
         # Clean up checkpoint directory
         if checkpoint_dir and os.path.exists(checkpoint_dir):
             print(f"Removing checkpoint directory: {checkpoint_dir}")
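Both GPU tests now take a ``ray_init_fixture`` instead of calling ``ray.shutdown()`` in their ``finally`` blocks. A hypothetical sketch of such a fixture, assuming the real one lives in the tests' shared conftest (its exact body is not shown in this diff):

.. code-block:: python

    import pytest
    import ray

    @pytest.fixture
    def ray_init_fixture():
        ray.init(ignore_reinit_error=True)
        yield
        # Teardown replaces the manual ray.shutdown() removed from the tests
        ray.shutdown()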
Error: {e}") - if checkpoint_dir and os.path.exists(os.path.dirname(checkpoint_dir)): print(f"Cleaning up checkpoint directory: {os.path.dirname(checkpoint_dir)}") shutil.rmtree(os.path.dirname(checkpoint_dir))