diff --git a/skyrl-train/skyrl_train/config/ppo_base_config.yaml b/skyrl-train/skyrl_train/config/ppo_base_config.yaml
index d34f53e5ab..34c35ee91b 100644
--- a/skyrl-train/skyrl_train/config/ppo_base_config.yaml
+++ b/skyrl-train/skyrl_train/config/ppo_base_config.yaml
@@ -228,6 +228,13 @@ generator:
   # number of samples per prompt for evaluation
   eval_n_samples_per_prompt: 1
 
+  # Trajectory logging configuration
+  trajectory_logging:
+    enabled: false  # Set to true to enable trajectory logging
+    type: "wandb"  # Type of logger to use ("wandb" or "csv")
+    max_trajectories: -1  # Maximum number of trajectories to log per batch (-1 for unlimited)
+    output_dir: "${trainer.export_path}/trajectory_logs"  # Output directory for the CSV logger
+
   # NOTE (sumanthrh): This flag sets the reward to 0 if the `stop_reason` is not `stop`.
   # This is useful in cases where the LLM generation was truncated or aborted.
diff --git a/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py b/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py
index eacf37677f..f40d8f362a 100644
--- a/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py
+++ b/skyrl-train/skyrl_train/generators/skyrl_gym_generator.py
@@ -26,6 +26,7 @@
     apply_overlong_filtering,
     get_rollout_metrics,
 )
+from skyrl_train.generators.trajectory_logger import create_trajectory_logger_from_config
 
 
 @dataclass
@@ -73,7 +74,15 @@ def __init__(
             )
         else:
             self.env_executor = None
-
+
+        # Initialize trajectory logging if enabled
+        if generator_cfg.trajectory_logging.enabled:
+            self.trajectory_logger = create_trajectory_logger_from_config(
+                generator_cfg.trajectory_logging
+            )
+        else:
+            self.trajectory_logger = None
+
         if getattr(self.generator_cfg.sampling_params, "logprobs", None) is not None and not self.generator_cfg.batched:
             raise ValueError("`sampling_params.logprobs` should be `None` if `batched` is `False`")
@@ -448,6 +457,17 @@ async def generate(self, input_batch: GeneratorInput) -> GeneratorOutput:
             rollout_logprobs = [output.rollout_logprobs for output in all_outputs]
         else:
             rollout_logprobs = None
+
+        # Log trajectories if logging is enabled
+        if self.trajectory_logger:
+            # Detokenize prompts and responses for logging
+            log_prompts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in prompt_token_ids]
+            log_responses = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in responses]
+
+            # Rewards may be a single float or a list of per-token floats; reduce lists to a scalar sum
+            log_rewards = [sum(r) if isinstance(r, list) else r for r in rewards]
+
+            self.trajectory_logger.log(log_prompts, log_responses, log_rewards)
 
         rollout_metrics = get_rollout_metrics(responses, rewards)
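The wiring above only constructs the logger and hands it decoded prompts, responses, and scalar rewards. For reference, the logger can also be exercised standalone; the following is a minimal sketch using only the `create_trajectory_logger_from_config` factory and the `log()` signature introduced in the new module below — the config values are illustrative overrides, not defaults:

```python
from omegaconf import OmegaConf

from skyrl_train.generators.trajectory_logger import create_trajectory_logger_from_config

# Mirror the `generator.trajectory_logging` config section from ppo_base_config.yaml.
cfg = OmegaConf.create(
    {
        "enabled": True,
        "type": "csv",
        "max_trajectories": 8,
        "output_dir": "./trajectory_logs",
    }
)

logger = create_trajectory_logger_from_config(cfg)
logger.log(
    prompts=["What is 2+2?"],
    responses=["The answer is 4."],
    rewards=[1.0],
)
```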
diff --git a/skyrl-train/skyrl_train/generators/trajectory_logger.py b/skyrl-train/skyrl_train/generators/trajectory_logger.py
new file mode 100644
index 0000000000..d0e6377ad0
--- /dev/null
+++ b/skyrl-train/skyrl_train/generators/trajectory_logger.py
@@ -0,0 +1,101 @@
+"""
+Trajectory logging utilities for debugging and analysis during training.
+
+This module provides a simple framework for logging prompts, responses, and rewards.
+"""
+
+import os
+from abc import ABC, abstractmethod
+from typing import List
+
+import pandas as pd
+from omegaconf import DictConfig
+
+
+class TrajectoryLogger(ABC):
+    """
+    Abstract base class for trajectory logging.
+
+    TODO: Allow users to bring a custom trajectory logger. They should be able to
+    define their own class (outside of the skyrl-train package) and add it to
+    a registry (see AdvantageEstimatorRegistry for an example) so that it can
+    be referenced by name in the config.
+    """
+
+    def __init__(self, max_trajectories: int = -1):
+        self.max_trajectories = max_trajectories
+
+    @abstractmethod
+    def log(self, prompts: List[str], responses: List[str], rewards: List[float]) -> None:
+        pass
+
+
+class WandbTableTrajectoryLogger(TrajectoryLogger):
+    """Logs each batch of trajectories as a wandb.Table under the "trajectories" key."""
+
+    def __init__(self, max_trajectories: int = -1):
+        super().__init__(max_trajectories)
+        import wandb
+
+        self.wandb = wandb
+
+    def log(self, prompts: List[str], responses: List[str], rewards: List[float]) -> None:
+        if self.max_trajectories > 0:
+            prompts = prompts[: self.max_trajectories]
+            responses = responses[: self.max_trajectories]
+            rewards = rewards[: self.max_trajectories]
+
+        # "step" is the trajectory's index within the logged batch, not a training step
+        data = [
+            [i, prompt, response, reward]
+            for i, (prompt, response, reward) in enumerate(zip(prompts, responses, rewards))
+        ]
+        table = self.wandb.Table(columns=["step", "prompt", "response", "reward"], data=data)
+
+        self.wandb.log({"trajectories": table})
+
+
+class CSVTrajectoryLogger(TrajectoryLogger):
+    """Writes trajectories to `<output_dir>/trajectories.csv`, overwriting the file on each call."""
+
+    def __init__(self, output_dir: str, max_trajectories: int = -1):
+        super().__init__(max_trajectories)
+        self.output_dir = output_dir
+        os.makedirs(output_dir, exist_ok=True)
+
+    def log(self, prompts: List[str], responses: List[str], rewards: List[float]) -> None:
+        if self.max_trajectories > 0:
+            prompts = prompts[: self.max_trajectories]
+            responses = responses[: self.max_trajectories]
+            rewards = rewards[: self.max_trajectories]
+
+        # "step" is the trajectory's index within the logged batch, not a training step
+        df = pd.DataFrame(
+            {
+                "step": list(range(len(prompts))),
+                "prompt": prompts,
+                "response": responses,
+                "reward": rewards,
+            }
+        )
+
+        filename = os.path.join(self.output_dir, "trajectories.csv")
+        df.to_csv(filename, index=False)
+
+
+def create_trajectory_logger_from_config(logging_cfg: DictConfig) -> TrajectoryLogger:
+    assert logging_cfg.enabled, "create_trajectory_logger_from_config requires trajectory logging to be enabled"
+
+    if logging_cfg.type == "wandb":
+        return WandbTableTrajectoryLogger(max_trajectories=logging_cfg.max_trajectories)
+    elif logging_cfg.type == "csv":
+        return CSVTrajectoryLogger(
+            output_dir=logging_cfg.output_dir,
+            max_trajectories=logging_cfg.max_trajectories,
+        )
+    else:
+        raise ValueError(f"Unknown trajectory logger type: {logging_cfg.type}")
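The TODO on `TrajectoryLogger` sketches a registry mechanism like `AdvantageEstimatorRegistry` so that user-defined loggers can be referenced by name in the config. A minimal sketch of that direction, assuming nothing beyond the base class above — the registry, the decorator, and the `stdout` type are hypothetical and not part of this PR:

```python
from typing import Dict, List, Type

from skyrl_train.generators.trajectory_logger import TrajectoryLogger

# Hypothetical registry mirroring the AdvantageEstimatorRegistry pattern the
# TODO references; none of these names exist in the codebase yet.
_TRAJECTORY_LOGGER_REGISTRY: Dict[str, Type[TrajectoryLogger]] = {}


def register_trajectory_logger(name: str):
    """Class decorator that makes a logger available under `name` in the config."""

    def decorator(cls: Type[TrajectoryLogger]) -> Type[TrajectoryLogger]:
        _TRAJECTORY_LOGGER_REGISTRY[name] = cls
        return cls

    return decorator


@register_trajectory_logger("stdout")
class StdoutTrajectoryLogger(TrajectoryLogger):
    """Toy logger that prints trajectories instead of persisting them."""

    def log(self, prompts: List[str], responses: List[str], rewards: List[float]) -> None:
        for prompt, response, reward in zip(prompts, responses, rewards):
            print(f"reward={reward:.3f} prompt={prompt!r} response={response!r}")
```

With something like this in place, `create_trajectory_logger_from_config` could resolve `logging_cfg.type` through the registry instead of the hard-coded `if`/`elif` chain.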
diff --git a/skyrl-train/tests/cpu/test_trajectory_logging_integration.py b/skyrl-train/tests/cpu/test_trajectory_logging_integration.py
new file mode 100644
index 0000000000..5353f1a2f9
--- /dev/null
+++ b/skyrl-train/tests/cpu/test_trajectory_logging_integration.py
@@ -0,0 +1,293 @@
+"""
+Integration tests for the trajectory logging feature.
+
+Tests the complete trajectory logging workflow with real configurations:
+- Configuration loading and parsing
+- Trainer integration with trajectory logging
+- Generator trajectory collection
+- End-to-end logging functionality
+
+Run with: uv run --extra dev --isolated pytest tests/cpu/test_trajectory_logging_integration.py
+"""
+
+import os
+import tempfile
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import hydra
+import pandas as pd
+import pytest
+from omegaconf import DictConfig
+
+from skyrl_gym.envs.base_text_env import BaseTextEnvStepOutput
+from skyrl_train.entrypoints.main_base import config_dir
+from skyrl_train.generators.base import GeneratorInput
+from skyrl_train.generators.skyrl_gym_generator import SkyRLGymGenerator
+from skyrl_train.trainer import RayPPOTrainer
+
+
+@pytest.fixture
+def base_config() -> DictConfig:
+    """Load the base configuration for testing."""
+    with hydra.initialize_config_dir(config_dir=config_dir):
+        cfg = hydra.compose(config_name="ppo_base_config")
+
+    # Override for CPU testing
+    cfg.trainer.placement.policy_num_gpus_per_node = 0
+    cfg.trainer.placement.critic_num_gpus_per_node = 0
+    cfg.trainer.train_batch_size = 2
+    cfg.trainer.micro_train_batch_size_per_gpu = 1
+    cfg.trainer.epochs = 1
+    cfg.trainer.logger = "console"
+    cfg.generator.n_samples_per_prompt = 1
+
+    return cfg
+
+
+@pytest.fixture
+def trajectory_logging_config(base_config) -> DictConfig:
+    """Configure trajectory logging in the base config."""
+    base_config.generator.trajectory_logging.enabled = True
+    base_config.generator.trajectory_logging.type = "csv"
+    base_config.generator.trajectory_logging.max_trajectories = 5
+    base_config.generator.trajectory_logging.output_dir = "./test_trajectory_logs"
+    return base_config
+
+
+@pytest.fixture
+def mock_env_cfg():
+    """Create a mock environment configuration."""
+    cfg = MagicMock()
+    cfg.max_env_workers = 0
+    cfg.env_class = "gsm8k"
+    return cfg
+
+
+@pytest.fixture
+def mock_tokenizer():
+    """Create a realistic mock tokenizer."""
+    tokenizer = MagicMock()
+
+    def mock_apply_chat_template(messages, **kwargs):
+        if kwargs.get("tokenize", True):
+            # A batch of conversations gets a list of token-id lists;
+            # a single conversation gets one token-id list
+            if isinstance(messages, list) and len(messages) > 0 and isinstance(messages[0], list):
+                return [[1, 2, 3, 4] for _ in messages]
+            else:
+                return [1, 2, 3, 4]
+        else:
+            if isinstance(messages, list) and isinstance(messages[0], dict):
+                return " ".join([msg.get("content", "") for msg in messages])
+            return str(messages)
+
+    def mock_decode(tokens, **kwargs):
+        return f"decoded_{len(tokens)}_tokens"
+
+    def mock_tokenizer_call(text):
+        return {"input_ids": [1, 2, 3, 4]}
+
+    tokenizer.apply_chat_template.side_effect = mock_apply_chat_template
+    tokenizer.decode.side_effect = mock_decode
+    tokenizer.side_effect = mock_tokenizer_call
+    tokenizer.eos_token_id = 4
+    tokenizer.eos_token = ""
+    tokenizer.pad_token_id = 0
+
+    return tokenizer
["stop"] * num_prompts + } + + engine.generate = AsyncMock(side_effect=mock_generate) + return engine + + +@pytest.fixture +def mock_env(): + """Create a mock SkyRL environment.""" + env = MagicMock() + + # Set return values instead of side_effect + env.init.return_value = ([{"role": "user", "content": "Test question"}], {"test_param": "value"}) + + def mock_step(response): + return BaseTextEnvStepOutput( + observations=[{"role": "assistant", "content": response}], + reward=1.0, + done=True, + metadata={"success": True} + ) + + env.step.side_effect = mock_step + return env + + +class TestTrajectoryLoggingIntegration: + """Integration tests for trajectory logging feature.""" + + def test_config_loading(self, trajectory_logging_config): + """Test that trajectory logging configuration is properly loaded.""" + cfg = trajectory_logging_config + + # Verify trajectory logging config is present + assert hasattr(cfg.generator, 'trajectory_logging') + assert cfg.generator.trajectory_logging.enabled is True + assert cfg.generator.trajectory_logging.type == "csv" + assert cfg.generator.trajectory_logging.max_trajectories == 5 + + @patch("skyrl_gym.make") + def test_generator_trajectory_collection( + self, mock_make, trajectory_logging_config, mock_tokenizer, + mock_inference_engine, mock_env, mock_env_cfg + ): + """Test that SkyRLGymGenerator collects trajectories when configured.""" + mock_make.return_value = mock_env + + with tempfile.TemporaryDirectory() as tmpdir: + # Update config with test directory + trajectory_logging_config.generator.trajectory_logging.output_dir = tmpdir + + # Create generator with trajectory logging from config + generator = SkyRLGymGenerator( + generator_cfg=trajectory_logging_config.generator, + skyrl_gym_cfg=mock_env_cfg, + inference_engine_client=mock_inference_engine, + tokenizer=mock_tokenizer, + model_name="test_model", + ) + + # Verify trajectory logger was created + assert generator.trajectory_logger is not None + + @pytest.mark.asyncio + @patch("skyrl_gym.make") + async def test_end_to_end_trajectory_logging( + self, mock_make, trajectory_logging_config, mock_tokenizer, + mock_inference_engine, mock_env, mock_env_cfg + ): + """Test complete end-to-end trajectory logging workflow.""" + mock_make.return_value = mock_env + + with tempfile.TemporaryDirectory() as tmpdir: + # Update config with test directory + trajectory_logging_config.generator.trajectory_logging.output_dir = tmpdir + + # Create generator with dependency injection + generator = SkyRLGymGenerator( + generator_cfg=trajectory_logging_config.generator, + skyrl_gym_cfg=mock_env_cfg, + inference_engine_client=mock_inference_engine, + tokenizer=mock_tokenizer, + model_name="test_model", + ) + + # Generate trajectories + input_batch: GeneratorInput = { + "prompts": [ + [{"role": "user", "content": "What is machine learning?"}], + [{"role": "user", "content": "Explain neural networks"}] + ], + "env_extras": [{"topic": "ML"}, {"topic": "DL"}], + "env_classes": ["test_env", "test_env"], + } + + # Execute generation - this should automatically trigger trajectory logging + generator_output = await generator.generate(input_batch) + + # Verify trajectory logging worked by checking for CSV file + csv_file = os.path.join(tmpdir, "trajectories.csv") + assert os.path.exists(csv_file), f"Expected trajectory file not found at {csv_file}" + + # Verify CSV content contains expected trajectory data + df = pd.read_csv(csv_file) + + # Should have 2 trajectories from the input batch + assert len(df) == 2 + + # Verify 
+    @pytest.mark.asyncio
+    @patch("skyrl_gym.make")
+    async def test_end_to_end_trajectory_logging(
+        self, mock_make, trajectory_logging_config, mock_tokenizer,
+        mock_inference_engine, mock_env, mock_env_cfg
+    ):
+        """Test the complete end-to-end trajectory logging workflow."""
+        mock_make.return_value = mock_env
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Update config with the test directory
+            trajectory_logging_config.generator.trajectory_logging.output_dir = tmpdir
+
+            # Create a generator with dependency injection
+            generator = SkyRLGymGenerator(
+                generator_cfg=trajectory_logging_config.generator,
+                skyrl_gym_cfg=mock_env_cfg,
+                inference_engine_client=mock_inference_engine,
+                tokenizer=mock_tokenizer,
+                model_name="test_model",
+            )
+
+            # Generate trajectories
+            input_batch: GeneratorInput = {
+                "prompts": [
+                    [{"role": "user", "content": "What is machine learning?"}],
+                    [{"role": "user", "content": "Explain neural networks"}],
+                ],
+                "env_extras": [{"topic": "ML"}, {"topic": "DL"}],
+                "env_classes": ["test_env", "test_env"],
+            }
+
+            # Execute generation - this should automatically trigger trajectory logging
+            await generator.generate(input_batch)
+
+            # Verify trajectory logging worked by checking for the CSV file
+            csv_file = os.path.join(tmpdir, "trajectories.csv")
+            assert os.path.exists(csv_file), f"Expected trajectory file not found at {csv_file}"
+
+            # Verify the CSV content contains the expected trajectory data
+            df = pd.read_csv(csv_file)
+
+            # Should have 2 trajectories from the input batch
+            assert len(df) == 2
+
+            # Verify the simplified structure
+            assert "prompt" in df.columns
+            assert "response" in df.columns
+            assert "reward" in df.columns
+            assert "step" in df.columns
+
+            # Verify step values are batch indices (0, 1 for 2 trajectories)
+            expected_steps = list(range(len(df)))
+            assert list(df["step"]) == expected_steps, f"Expected steps {expected_steps}, got {list(df['step'])}"
+            assert all(df["reward"] == 1.0)
+
+    def test_trainer_trajectory_logging_integration(self, trajectory_logging_config, mock_tokenizer):
+        """Test that RayPPOTrainer integrates with the trajectory logging configuration."""
+
+        # Create a minimal dataset for the trainer
+        class DummyDataset:
+            def __len__(self):
+                return 2
+
+            def __getitem__(self, idx):
+                return f"sample_{idx}"
+
+            def collate_fn(self, batch):
+                return batch
+
+        # Create a mock generator
+        mock_generator = MagicMock()
+
+        # Create the trainer
+        trainer = RayPPOTrainer(
+            cfg=trajectory_logging_config,
+            tracker=None,
+            tokenizer=mock_tokenizer,
+            train_dataset=DummyDataset(),
+            eval_dataset=None,
+            inference_engine_client=None,
+            generator=mock_generator,
+        )
+
+        # Verify the trajectory logging configuration is accessible
+        assert hasattr(trainer.cfg.generator, "trajectory_logging")
+        assert trainer.cfg.generator.trajectory_logging.enabled is True
+
+        # Verify that trajectory logging is properly configured
+        traj_config = trainer.cfg.generator.trajectory_logging
+        assert traj_config.type == "csv"
+        assert traj_config.max_trajectories == 5
+
+    @patch("skyrl_gym.make")
+    def test_trajectory_logging_disabled(
+        self, mock_make, base_config, mock_tokenizer, mock_inference_engine, mock_env, mock_env_cfg
+    ):
+        """Test that trajectory logging can be disabled via configuration."""
+        mock_make.return_value = mock_env
+
+        # Ensure trajectory logging is disabled
+        base_config.generator.trajectory_logging.enabled = False
+
+        # Create a generator without a trajectory logger
+        generator = SkyRLGymGenerator(
+            generator_cfg=base_config.generator,
+            skyrl_gym_cfg=mock_env_cfg,
+            inference_engine_client=mock_inference_engine,
+            tokenizer=mock_tokenizer,
+            model_name="test_model",
+        )
+
+        # Verify no trajectory logger was created
+        assert generator.trajectory_logger is None
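Beyond the assertions above, the CSV output is handy to inspect interactively. A small sketch — the path assumes the default `output_dir` of `${trainer.export_path}/trajectory_logs`, and recall that `CSVTrajectoryLogger` overwrites the file on each call, so only the latest batch is present:

```python
import pandas as pd

# trajectories.csv holds one row per trajectory of the most recently logged batch;
# "step" is the trajectory's index within that batch.
df = pd.read_csv("trajectory_logs/trajectories.csv")
assert list(df.columns) == ["step", "prompt", "response", "reward"]

# Surface the lowest-reward trajectories for debugging.
print(df.sort_values("reward").head(10)[["reward", "prompt"]])
```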
diff --git a/skyrl-train/tests/cpu/utils/test_trajectory_logger.py b/skyrl-train/tests/cpu/utils/test_trajectory_logger.py
new file mode 100644
index 0000000000..e89d37915c
--- /dev/null
+++ b/skyrl-train/tests/cpu/utils/test_trajectory_logger.py
@@ -0,0 +1,194 @@
+"""
+Unit tests for the trajectory logging functionality.
+
+This module tests the simplified TrajectoryLogger system with prompts, responses, and rewards.
+
+Run with: uv run --extra dev --isolated pytest tests/cpu/utils/test_trajectory_logger.py
+"""
+
+import os
+import tempfile
+from unittest.mock import MagicMock, patch
+
+import pandas as pd
+import pytest
+from omegaconf import DictConfig
+
+from skyrl_train.generators.trajectory_logger import (
+    CSVTrajectoryLogger,
+    WandbTableTrajectoryLogger,
+    create_trajectory_logger_from_config,
+)
+
+
+# Test fixtures
+@pytest.fixture
+def sample_data():
+    """Create simple sample trajectory data."""
+    return {
+        "prompts": ["What is 2+2?", "What is 3+3?", "What is 4+4?"],
+        "responses": ["The answer is 4.", "The answer is 6.", "The answer is 8."],
+        "rewards": [1.0, 0.8, 0.5],
+    }
+
+
+class TestWandbTableTrajectoryLogger:
+    """Test the WandB table trajectory logger."""
+
+    def test_initialization(self):
+        """Test WandB logger initialization."""
+        mock_wandb_module = MagicMock()
+
+        with patch.dict("sys.modules", {"wandb": mock_wandb_module}):
+            logger = WandbTableTrajectoryLogger(max_trajectories=5)
+
+        assert logger.max_trajectories == 5
+        assert logger.wandb is mock_wandb_module
+
+    def test_log_trajectories(self, sample_data):
+        """Test logging trajectories to wandb."""
+        mock_wandb_module = MagicMock()
+        mock_table = MagicMock()
+        mock_wandb_module.Table.return_value = mock_table
+
+        with patch.dict("sys.modules", {"wandb": mock_wandb_module}):
+            logger = WandbTableTrajectoryLogger(max_trajectories=2)
+
+            logger.log(
+                sample_data["prompts"],
+                sample_data["responses"],
+                sample_data["rewards"],
+            )
+
+            # Verify table creation with the expected schema; the rows are passed
+            # to the wandb.Table constructor rather than added via add_data
+            mock_wandb_module.Table.assert_called_once()
+            _, table_kwargs = mock_wandb_module.Table.call_args
+            assert table_kwargs["columns"] == ["step", "prompt", "response", "reward"]
+
+            # Verify trajectory limiting (max_trajectories=2)
+            assert len(table_kwargs["data"]) == 2
+
+            # Verify the wandb.log call with the fixed "trajectories" key
+            mock_wandb_module.log.assert_called_once()
+            log_args = mock_wandb_module.log.call_args[0][0]
+            assert log_args["trajectories"] is mock_table
+
+    def test_unlimited_trajectories(self, sample_data):
+        """Test logging with unlimited trajectories."""
+        mock_wandb_module = MagicMock()
+        mock_table = MagicMock()
+        mock_wandb_module.Table.return_value = mock_table
+
+        with patch.dict("sys.modules", {"wandb": mock_wandb_module}):
+            logger = WandbTableTrajectoryLogger(max_trajectories=-1)
+
+            logger.log(
+                sample_data["prompts"],
+                sample_data["responses"],
+                sample_data["rewards"],
+            )
+
+            # Should log all 3 trajectories
+            _, table_kwargs = mock_wandb_module.Table.call_args
+            assert len(table_kwargs["data"]) == 3
+
+
+class TestCSVTrajectoryLogger:
+    """Test the CSV trajectory logger."""
+
+    def test_initialization(self):
+        """Test CSV logger initialization."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            logger = CSVTrajectoryLogger(output_dir=tmpdir, max_trajectories=5)
+
+            assert logger.output_dir == tmpdir
+            assert logger.max_trajectories == 5
+            assert os.path.exists(tmpdir)
+
+    def test_log_trajectories_to_csv(self, sample_data):
+        """Test logging trajectories to a CSV file."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            logger = CSVTrajectoryLogger(output_dir=tmpdir, max_trajectories=2)
+
+            logger.log(
+                sample_data["prompts"],
+                sample_data["responses"],
+                sample_data["rewards"],
+            )
+
+            # Verify CSV file creation
+            csv_file = os.path.join(tmpdir, "trajectories.csv")
+            assert os.path.exists(csv_file)
+
+            # Verify the CSV content
+            df = pd.read_csv(csv_file)
+            assert len(df) == 2  # Limited by max_trajectories
+
+            # Verify step values are batch indices (0, 1)
+            assert list(df["step"]) == [0, 1]
+            assert "What is 2+2?" in df["prompt"].values
+            assert "The answer is 4." in df["response"].values
+            assert 1.0 in df["reward"].values
+
+
+class TestTrajectoryLoggerFactory:
+    """Test the trajectory logger factory function."""
+
+    def test_create_wandb_logger(self):
+        """Test creating a WandB logger from configuration."""
+        config = DictConfig({
+            "enabled": True,
+            "type": "wandb",
+            "max_trajectories": 5,
+        })
+
+        with patch.dict("sys.modules", {"wandb": MagicMock()}):
+            logger = create_trajectory_logger_from_config(config)
+
+        assert isinstance(logger, WandbTableTrajectoryLogger)
+        assert logger.max_trajectories == 5
+
+    def test_create_csv_logger(self):
+        """Test creating a CSV logger from configuration."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            config = DictConfig({
+                "enabled": True,
+                "type": "csv",
+                "output_dir": tmpdir,
+                "max_trajectories": 10,
+            })
+
+            logger = create_trajectory_logger_from_config(config)
+
+            assert isinstance(logger, CSVTrajectoryLogger)
+            assert logger.output_dir == tmpdir
+            assert logger.max_trajectories == 10
+
+    def test_disabled_config_assertion(self):
+        """Test that a disabled configuration raises an assertion error."""
+        config = DictConfig({
+            "enabled": False,
+            "type": "wandb",
+        })
+
+        with pytest.raises(AssertionError):
+            create_trajectory_logger_from_config(config)
+
+    def test_unknown_logger_type_error(self):
+        """Test that an unknown logger type raises a ValueError."""
+        config = DictConfig({
+            "enabled": True,
+            "type": "unknown",
+        })
+
+        with pytest.raises(ValueError, match="Unknown trajectory logger type"):
+            create_trajectory_logger_from_config(config)