Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion tests/test_rich_progress_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,3 @@ def test_rich_progress_callback_logging(self):
)

trainer.train()
trainer.train()
131 changes: 82 additions & 49 deletions trl/trainer/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,12 @@


if is_rich_available():
from rich.columns import Columns
from rich.console import Console, Group
from rich.live import Live
from rich.panel import Panel
from rich.progress import Progress
from rich.table import Table

if is_wandb_available():
import wandb
Expand Down Expand Up @@ -152,74 +154,105 @@ def __init__(self):
raise ImportError("RichProgressCallback requires the `rich` extra. To install, run `pip install rich`.")

self.training_bar = None
self.prediction_bar = None

self.training_task_id = None
self.prediction_task_id = None

self.evaluation_bar = None
self.training_task = None
self.evaluation_task = None
self.rich_group = None
self.rich_console = None

self.training_status = None
self.current_step = None

def on_train_begin(self, args, state, control, **kwargs):
    """Build and start the live Rich display used for the rest of training.

    Creates one progress bar for training and one for evaluation, a console
    status line, groups the three inside a single live panel, starts the
    live display, and registers the training task with ``state.max_steps``
    as its total.

    Args:
        args: Unused here.
        state: Object exposing ``is_world_process_zero`` and ``max_steps``.
        control: Unused here.
        **kwargs: Ignored.
    """
    # Only the main process draws anything; every other process returns
    # without creating Rich objects.
    if not state.is_world_process_zero:
        return

    self.training_bar = Progress()
    self.evaluation_bar = Progress()
    self.rich_console = Console()
    self.training_status = self.rich_console.status("Nothing to log yet ...")

    # A single Live instance owns all three renderables so they are
    # redrawn together inside one panel.
    self.rich_group = Live(Panel(Group(self.training_bar, self.evaluation_bar, self.training_status)))
    self.rich_group.start()

    self.training_task = self.training_bar.add_task("[blue]Training ", total=state.max_steps)
    # Last global step reflected in the bar; on_step_end advances by the delta.
    self.current_step = 0

def on_step_end(self, args, state, control, **kwargs):
    """Advance the training progress bar to ``state.global_step``.

    The bar is advanced by the difference between ``state.global_step`` and
    the step recorded on the previous call, then that record is refreshed.

    Args:
        args: Unused here.
        state: Object exposing ``is_world_process_zero`` and ``global_step``.
        control: Unused here.
        **kwargs: Ignored.
    """
    if not state.is_world_process_zero:
        return

    # NOTE(review): ``update=True`` is forwarded to the bar as an extra
    # keyword; confirm whether ``refresh=True`` was intended.
    self.training_bar.update(self.training_task, advance=state.global_step - self.current_step, update=True)
    self.current_step = state.global_step

def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
    """Advance the evaluation progress bar by one step.

    The evaluation task is created lazily on the first call, using
    ``len(eval_dataloader)`` as its total, and is advanced by one on every
    call. Nothing is drawn when ``has_length(eval_dataloader)`` is false.

    Args:
        args: Unused here.
        state: Object exposing ``is_world_process_zero``.
        control: Unused here.
        eval_dataloader: Dataloader being evaluated; only its length is used.
        **kwargs: Ignored.
    """
    if not state.is_world_process_zero:
        return

    if has_length(eval_dataloader):
        # Create the task on first use so the bar only appears once an
        # evaluation loop actually runs.
        if self.evaluation_task is None:
            self.evaluation_task = self.evaluation_bar.add_task("[blue]Evaluation", total=len(eval_dataloader))
        self.evaluation_bar.update(self.evaluation_task, advance=1, update=True)

def on_evaluate(self, args, state, control, **kwargs):
    """Remove the evaluation task from its progress bar, if one exists.

    Clearing ``self.evaluation_task`` lets the next call to
    ``on_prediction_step`` create a fresh task.

    Args:
        args: Unused here.
        state: Object exposing ``is_world_process_zero``.
        control: Unused here.
        **kwargs: Ignored.
    """
    if not state.is_world_process_zero:
        return

    if self.evaluation_task is not None:
        self.evaluation_bar.remove_task(self.evaluation_task)
        self.evaluation_task = None

def on_predict(self, args, state, control, **kwargs):
    """Remove the evaluation task from its progress bar, if one exists.

    Performs the same cleanup as ``on_evaluate``: the task created by
    ``on_prediction_step`` is removed and ``self.evaluation_task`` is reset
    to ``None``.

    Args:
        args: Unused here.
        state: Object exposing ``is_world_process_zero``.
        control: Unused here.
        **kwargs: Ignored.
    """
    if not state.is_world_process_zero:
        return

    if self.evaluation_task is not None:
        self.evaluation_bar.remove_task(self.evaluation_task)
        self.evaluation_task = None

def on_log(self, args, state, control, logs=None, **kwargs):
    """Render the logged metrics as per-group tables in the status area.

    Keys containing ``/`` are grouped under the text before the first ``/``
    and listed under the remainder of the key; keys without ``/`` go into
    a single untitled group. Each group becomes one two-column table
    (metric, value), and the tables are laid out side by side inside a
    panel titled with the current global step.

    ``logs`` is only read, never modified.

    Args:
        args: Unused here.
        state: Object exposing ``is_world_process_zero`` and ``global_step``.
        control: Unused here.
        logs: Mapping of metric name to value. ``None`` is treated as
            "nothing to display".
        **kwargs: Ignored.
    """
    if not (state.is_world_process_zero and self.training_bar):
        return
    # The parameter defaults to None, so avoid calling .items() on it.
    if logs is None:
        return

    # Group keys by top-level prefix
    grouped_logs = {}
    for key, value in logs.items():
        group, separator, subkey = key.partition("/")
        if not separator:
            # No prefix: file the metric under the untitled group.
            group, subkey = None, key
        grouped_logs.setdefault(group, {})[subkey] = value

    # Create a table per group
    tables = []
    for group_name, metrics in grouped_logs.items():
        table = Table(
            title=f"[bold blue]{group_name}[/]" if group_name else None, header_style="bold magenta", box=None
        )
        table.add_column("Metric", justify="left", no_wrap=True)
        table.add_column("Value", justify="right")

        for metric, val in metrics.items():
            # Numbers are shown with three decimals; anything else as text.
            formatted = f"{val:.3f}" if isinstance(val, (float, int)) else str(val)
            table.add_row(metric, formatted)

        tables.append(Panel(table, border_style="cyan", padding=(0, 1)))

    # Arrange tables in columns using Columns
    column_layout = Columns(tables, equal=False, expand=True)
    self.training_status.update(
        Panel(column_layout, title=f"[bold green]Step {state.global_step}[/bold green]", border_style="green")
    )

def on_train_end(self, args, state, control, **kwargs):
    """Stop the live display and drop every Rich object held by the callback.

    All display-related attributes are reset to ``None``.

    Args:
        args: Unused here.
        state: Object exposing ``is_world_process_zero``.
        control: Unused here.
        **kwargs: Ignored.
    """
    if not state.is_world_process_zero:
        return

    # Stop the display before releasing the references it renders.
    self.rich_group.stop()
    self.training_bar = None
    self.evaluation_bar = None
    self.training_task = None
    self.evaluation_task = None
    self.rich_group = None
    self.rich_console = None
    self.training_status = None
    self.current_step = None


def _win_rate_completions_df(
Expand Down
Loading