-
Notifications
You must be signed in to change notification settings - Fork 260
Add generic trajectory logging for debugging RL training #174
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
8021b62
094ef99
3e2ba62
2b32ba2
76c3742
7bc42dc
c21db34
179699f
9884e4a
3beca95
0387f27
acf7f93
02a3f03
57cdb6d
0210e44
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,101 @@ | ||
| """ | ||
| Trajectory logging utilities for debugging and analysis during training. | ||
|
|
||
| This module provides a simple framework for logging prompts, responses, and rewards. | ||
| """ | ||
|
|
||
| import os | ||
| from abc import ABC, abstractmethod | ||
| from typing import List, Optional | ||
| import pandas as pd | ||
| from omegaconf import DictConfig | ||
|
|
||
|
|
||
| class TrajectoryLogger(ABC): | ||
| """ | ||
| Abstract base class for trajectory logging. | ||
|
|
||
| TODO: Allow users to bring a custom trajectory logger. They should be able to | ||
| define their own class (outside of the skyrl-train package) and add it to | ||
| a registry (see AdvantageEstimatorRegistry for an example) so that it can | ||
| be referenced by name in the config. | ||
| """ | ||
|
Comment on lines
+14
to
+22
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To improve maintainability and reduce code duplication, the trajectory slicing logic, which is currently duplicated in both You could use the template method pattern by making
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. lines of duplicated slicing logic are worth it for the much cleaner user experience when implementing custom loggers |
||
|
|
||
| def __init__(self, max_trajectories: int = -1): | ||
| self.max_trajectories = max_trajectories | ||
|
|
||
| @abstractmethod | ||
| def log( | ||
| self, | ||
| prompts: List[str], | ||
| responses: List[str], | ||
| rewards: List[float] | ||
| ) -> None: | ||
| pass | ||
|
|
||
|
|
||
| class WandbTableTrajectoryLogger(TrajectoryLogger): | ||
|
|
||
| def __init__(self, max_trajectories: int = -1): | ||
| super().__init__(max_trajectories) | ||
| import wandb | ||
| self.wandb = wandb | ||
|
|
||
| def log( | ||
| self, | ||
| prompts: List[str], | ||
| responses: List[str], | ||
| rewards: List[float] | ||
| ) -> None: | ||
| if self.max_trajectories > 0: | ||
| prompts = prompts[:self.max_trajectories] | ||
| responses = responses[:self.max_trajectories] | ||
| rewards = rewards[:self.max_trajectories] | ||
|
|
||
| data = [[i, prompt, response, reward] for i, (prompt, response, reward) in enumerate(zip(prompts, responses, rewards))] | ||
| table = self.wandb.Table(columns=["step", "prompt", "response", "reward"], data=data) | ||
|
|
||
| self.wandb.log({"trajectories": table}) | ||
|
|
||
|
|
||
| class CSVTrajectoryLogger(TrajectoryLogger): | ||
|
|
||
| def __init__(self, output_dir: str, max_trajectories: int = -1): | ||
| super().__init__(max_trajectories) | ||
| self.output_dir = output_dir | ||
| os.makedirs(output_dir, exist_ok=True) | ||
|
|
||
| def log( | ||
| self, | ||
| prompts: List[str], | ||
| responses: List[str], | ||
| rewards: List[float] | ||
| ) -> None: | ||
| if self.max_trajectories > 0: | ||
| prompts = prompts[:self.max_trajectories] | ||
| responses = responses[:self.max_trajectories] | ||
| rewards = rewards[:self.max_trajectories] | ||
|
|
||
| df = pd.DataFrame({ | ||
| "step": list(range(len(prompts))), | ||
| "prompt": prompts, | ||
| "response": responses, | ||
| "reward": rewards | ||
| }) | ||
|
|
||
| filename = os.path.join(self.output_dir, "trajectories.csv") | ||
| df.to_csv(filename, index=False) | ||
|
|
||
|
|
||
| def create_trajectory_logger_from_config(logging_cfg: DictConfig) -> TrajectoryLogger: | ||
| assert logging_cfg.enabled | ||
|
|
||
| if logging_cfg.type == 'wandb': | ||
| return WandbTableTrajectoryLogger(max_trajectories=logging_cfg.max_trajectories) | ||
| elif logging_cfg.type == 'csv': | ||
| return CSVTrajectoryLogger( | ||
| output_dir=logging_cfg.output_dir, | ||
| max_trajectories=logging_cfg.max_trajectories | ||
| ) | ||
| else: | ||
| raise ValueError(f"Unknown trajectory logger type: {logging_cfg.type}") | ||
Uh oh!
There was an error while loading. Please reload this page.