7 changes: 7 additions & 0 deletions skyrl-train/skyrl_train/config/ppo_base_config.yaml
@@ -228,6 +228,13 @@ generator:

# number of samples per prompt for evaluation
eval_n_samples_per_prompt: 1

# Trajectory logging configuration
trajectory_logging:
enabled: false # Set to true to enable trajectory logging
type: "wandb" # Type of logger to use ("wandb" or "csv")
max_trajectories: -1 # Maximum number of trajectories to log per batch (-1 for unlimited)
output_dir: "${trainer.export_path}/trajectory_logs" # Output directory for CSV logger

# NOTE (sumanthrh): This flag sets the reward to 0 if the `stop_reason` is not `stop`.
# This is useful in cases where the LLM generation was truncated or aborted.
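For orientation, a minimal sketch of how this config block is consumed, assuming the section is parsed into an OmegaConf DictConfig and handed to create_trajectory_logger_from_config from the new trajectory_logger module (the values below are illustrative; in a real run they come from this YAML plus user overrides, with ${trainer.export_path} resolved against the full config):

from omegaconf import OmegaConf
from skyrl_train.generators.trajectory_logger import create_trajectory_logger_from_config

# Illustrative stand-in for the resolved generator.trajectory_logging node.
trajectory_logging_cfg = OmegaConf.create({
    "enabled": True,
    "type": "csv",
    "max_trajectories": 16,
    "output_dir": "/tmp/trajectory_logs",
})

logger = create_trajectory_logger_from_config(trajectory_logging_cfg)
# Writes /tmp/trajectory_logs/trajectories.csv with columns step, prompt, response, reward.
logger.log(["example prompt"], ["example response"], [1.0])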
24 changes: 23 additions & 1 deletion skyrl-train/skyrl_train/generators/skyrl_gym_generator.py
@@ -26,6 +26,7 @@
apply_overlong_filtering,
get_rollout_metrics,
)
from skyrl_train.generators.trajectory_logger import create_trajectory_logger_from_config


@dataclass
@@ -73,7 +74,15 @@ def __init__(
)
else:
self.env_executor = None


# Initialize trajectory logging if enabled
if generator_cfg.trajectory_logging.enabled:
self.trajectory_logger = create_trajectory_logger_from_config(
generator_cfg.trajectory_logging
)
else:
self.trajectory_logger = None

if getattr(self.generator_cfg.sampling_params, "logprobs", None) is not None and not self.generator_cfg.batched:
raise ValueError("`sampling_params.logprobs` should be `None` if `batched` is `False`")

@@ -448,6 +457,17 @@ async def generate(self, input_batch: GeneratorInput) -> GeneratorOutput:
rollout_logprobs = [output.rollout_logprobs for output in all_outputs]
else:
rollout_logprobs = None

# Log trajectories if logging is enabled
if self.trajectory_logger:
# Detokenize prompts and responses for logging
log_prompts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in prompt_token_ids]
log_responses = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in responses]

# Handle both single float and list of floats for rewards
log_rewards = [sum(r) if isinstance(r, list) else r for r in rewards]

self.trajectory_logger.log(log_prompts, log_responses, log_rewards)

rollout_metrics = get_rollout_metrics(responses, rewards)

@@ -666,3 +686,5 @@ def _get_next_input_ids_with_single_turn_chat_template(
input_ids += obs_tokens

return loss_mask, input_ids, logprobs, response_end_idx


101 changes: 101 additions & 0 deletions skyrl-train/skyrl_train/generators/trajectory_logger.py
@@ -0,0 +1,101 @@
"""
Trajectory logging utilities for debugging and analysis during training.

This module provides a simple framework for logging prompts, responses, and rewards.
"""

import os
from abc import ABC, abstractmethod
from typing import List, Optional
import pandas as pd
from omegaconf import DictConfig


class TrajectoryLogger(ABC):
"""
Abstract base class for trajectory logging.

TODO: Allow users to bring a custom trajectory logger. They should be able to
define their own class (outside of the skyrl-train package) and add it to
a registry (see AdvantageEstimatorRegistry for an example) so that it can
be referenced by name in the config.
"""
Comment on lines +14 to +22

Contributor (severity: medium):

To improve maintainability and reduce code duplication, the trajectory slicing logic, which is currently duplicated in both WandbTableTrajectoryLogger and CSVTrajectoryLogger, could be moved into this base class.

You could use the template method pattern: make log a concrete method that handles the slicing and then calls a new abstract method (e.g., _log) for the specific logging implementation. This would centralize the slicing logic and make the logger implementations cleaner.

Author:

A few lines of duplicated slicing logic are worth it for the much cleaner user experience when implementing custom loggers.
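For illustration only (not part of this PR), a sketch of the refactor the reviewer describes: log becomes a concrete method on the base class that applies the max_trajectories cap once and delegates to an abstract _log hook implemented by each backend.

from abc import ABC, abstractmethod
from typing import List

class TrajectoryLogger(ABC):
    def __init__(self, max_trajectories: int = -1):
        self.max_trajectories = max_trajectories

    def log(self, prompts: List[str], responses: List[str], rewards: List[float]) -> None:
        # Shared slicing logic lives here once instead of in every subclass.
        if self.max_trajectories > 0:
            prompts = prompts[: self.max_trajectories]
            responses = responses[: self.max_trajectories]
            rewards = rewards[: self.max_trajectories]
        self._log(prompts, responses, rewards)

    @abstractmethod
    def _log(self, prompts: List[str], responses: List[str], rewards: List[float]) -> None:
        # Backend-specific write (wandb table, CSV file, ...).
        ...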


def __init__(self, max_trajectories: int = -1):
self.max_trajectories = max_trajectories

@abstractmethod
def log(
self,
prompts: List[str],
responses: List[str],
rewards: List[float]
) -> None:
pass


class WandbTableTrajectoryLogger(TrajectoryLogger):

def __init__(self, max_trajectories: int = -1):
super().__init__(max_trajectories)
import wandb
self.wandb = wandb

def log(
self,
prompts: List[str],
responses: List[str],
rewards: List[float]
) -> None:
if self.max_trajectories > 0:
prompts = prompts[:self.max_trajectories]
responses = responses[:self.max_trajectories]
rewards = rewards[:self.max_trajectories]

data = [[i, prompt, response, reward] for i, (prompt, response, reward) in enumerate(zip(prompts, responses, rewards))]
table = self.wandb.Table(columns=["step", "prompt", "response", "reward"], data=data)

self.wandb.log({"trajectories": table})


class CSVTrajectoryLogger(TrajectoryLogger):

def __init__(self, output_dir: str, max_trajectories: int = -1):
super().__init__(max_trajectories)
self.output_dir = output_dir
os.makedirs(output_dir, exist_ok=True)

def log(
self,
prompts: List[str],
responses: List[str],
rewards: List[float]
) -> None:
if self.max_trajectories > 0:
prompts = prompts[:self.max_trajectories]
responses = responses[:self.max_trajectories]
rewards = rewards[:self.max_trajectories]

df = pd.DataFrame({
"step": list(range(len(prompts))),
"prompt": prompts,
"response": responses,
"reward": rewards
})

filename = os.path.join(self.output_dir, "trajectories.csv")
df.to_csv(filename, index=False)


def create_trajectory_logger_from_config(logging_cfg: DictConfig) -> TrajectoryLogger:
assert logging_cfg.enabled

if logging_cfg.type == 'wandb':
return WandbTableTrajectoryLogger(max_trajectories=logging_cfg.max_trajectories)
elif logging_cfg.type == 'csv':
return CSVTrajectoryLogger(
output_dir=logging_cfg.output_dir,
max_trajectories=logging_cfg.max_trajectories
)
else:
raise ValueError(f"Unknown trajectory logger type: {logging_cfg.type}")