From a3eb7666ce88a49bc60f4c9daeb35be5ccbb9981 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
Date: Wed, 12 Nov 2025 22:41:26 +0000
Subject: [PATCH 1/4] Add step time metric to GRPO Trainer for performance tracking

---
 docs/source/grpo_trainer.md |  1 +
 trl/trainer/grpo_trainer.py | 14 +++++++++++++-
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md
index a3d99953706..bdc132e4115 100644
--- a/docs/source/grpo_trainer.md
+++ b/docs/source/grpo_trainer.md
@@ -149,6 +149,7 @@ This constant is recommended to be the maximum completion length. To use this fo
 While training and evaluating, we record the following reward metrics:
 
 - `num_tokens`: The total number of tokens processed so far, including both prompts and completions.
+- `step_time`: The average time (in seconds) taken per training step (including generation).
 - `completions/mean_length`: The average length of generated completions.
 - `completions/min_length`: The minimum length of generated completions.
 - `completions/max_length`: The maximum length of generated completions.
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 6ce2cba52b0..52d60fcd276 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -15,6 +15,7 @@
 import inspect
 import os
 import textwrap
+import time
 import warnings
 from collections import defaultdict, deque
 from collections.abc import Callable
@@ -531,6 +532,7 @@ def cast_outputs_to_original_dtype(module, args, output):
         # Initialize the metrics
         self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
         self._total_train_tokens = 0
+        self._current_train_step_time = 0.0
         self.log_completions = args.log_completions
         self.wandb_log_unique_prompts = args.wandb_log_unique_prompts
         self.num_completions_to_print = args.num_completions_to_print
@@ -1044,6 +1046,17 @@ def _move_model_to_vllm(self):
         elif self.vllm_mode == "colocate":
             self.llm.reset_prefix_cache()
 
+    def training_step(self, model, inputs, num_items_in_batch):
+        time_before = time.time()
+        output = super().training_step(model, inputs, num_items_in_batch)
+        self._step += 1
+        time_after = time.time()
+        self._current_train_step_time += time_after - time_before
+        if self._step % self.current_gradient_accumulation_steps == 0:
+            self._metrics["train"]["step_time"].append(self._current_train_step_time)
+            self._current_train_step_time = 0.0
+        return output
+
     @profiling_decorator
     def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]:
         # Prepares inputs for model training/evaluation by managing completion generation and batch handling.
@@ -1070,7 +1083,6 @@ def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> di
                 generation_batches = split_tensor_dict(generation_batch, self.args.steps_per_generation)
                 self._buffered_inputs = [unsplit_pixel_values_by_grid(batch) for batch in generation_batches]
             inputs = self._buffered_inputs[self._step % self.args.steps_per_generation]
-            self._step += 1
         else:
             # In evaluation, there is neither batch grouping for generation, nor multiple iterations, hence
             # local generation batch == local eval batch

From 2e1339c46bd5b57bb52415a1f1cfc362adeb5137 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com>
Date: Thu, 13 Nov 2025 10:18:26 -0700
Subject: [PATCH 2/4] Apply suggestions from code review

Co-authored-by: lewtun
---
 trl/trainer/grpo_trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 7966e2f0eb2..b2517eb7903 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -1047,10 +1047,10 @@ def _move_model_to_vllm(self):
             self.llm.reset_prefix_cache()
 
     def training_step(self, model, inputs, num_items_in_batch):
-        time_before = time.time()
+        time_before = time.perf_counter()
         output = super().training_step(model, inputs, num_items_in_batch)
         self._step += 1
-        time_after = time.time()
+        time_after = time.perf_counter()
         self._current_train_step_time += time_after - time_before
         if self._step % self.current_gradient_accumulation_steps == 0:
             self._metrics["train"]["step_time"].append(self._current_train_step_time)

From 0d7b2501d2bfae6eb9ca03cc682dac7ca1622882 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
Date: Thu, 13 Nov 2025 17:22:18 +0000
Subject: [PATCH 3/4] Add training step timing metrics to RLOOTrainer

---
 docs/source/rloo_trainer.md |  1 +
 trl/trainer/rloo_trainer.py | 13 +++++++++++++
 2 files changed, 14 insertions(+)

diff --git a/docs/source/rloo_trainer.md b/docs/source/rloo_trainer.md
index 36d315e678d..1b8089337a9 100644
--- a/docs/source/rloo_trainer.md
+++ b/docs/source/rloo_trainer.md
@@ -135,6 +135,7 @@ In a fully online, single-step setting (default), \\( \frac{\pi_\theta(o_i \mid
 While training and evaluating, we record the following reward metrics:
 
 - `num_tokens`: The total number of tokens processed so far, including both prompts and completions.
+- `step_time`: The average time (in seconds) taken per training step (including generation).
 - `completions/mean_length`: The average length of generated completions.
 - `completions/min_length`: The minimum length of generated completions.
 - `completions/max_length`: The maximum length of generated completions.
diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py
index 7203862fa1d..34bf82508a9 100644
--- a/trl/trainer/rloo_trainer.py
+++ b/trl/trainer/rloo_trainer.py
@@ -22,6 +22,7 @@
 from functools import partial
 from pathlib import Path
 from typing import Any
+import time
 
 import datasets
 import pandas as pd
@@ -448,6 +449,7 @@ def __init__(
         # Initialize the metrics
         self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
         self._total_train_tokens = 0
+        self._current_train_step_time = 0.0
         self.log_completions = args.log_completions
         self.log_unique_prompts = args.log_unique_prompts
         self.num_completions_to_print = args.num_completions_to_print
@@ -878,6 +880,17 @@ def _move_model_to_vllm(self):
         elif self.vllm_mode == "colocate":
             self.llm.reset_prefix_cache()
 
+    def training_step(self, model, inputs, num_items_in_batch):
+        time_before = time.perf_counter()
+        output = super().training_step(model, inputs, num_items_in_batch)
+        self._step += 1
+        time_after = time.perf_counter()
+        self._current_train_step_time += time_after - time_before
+        if self._step % self.current_gradient_accumulation_steps == 0:
+            self._metrics["train"]["step_time"].append(self._current_train_step_time)
+            self._current_train_step_time = 0.0
+        return output
+
     @profiling_decorator
     def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]:
         # Prepares inputs for model training/evaluation by managing completion generation and batch handling.

From 8c71ec127ceb98f706ebdf8cb9bd8bcc165fa0cc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
Date: Thu, 13 Nov 2025 17:22:53 +0000
Subject: [PATCH 4/4] style

---
 trl/trainer/rloo_trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py
index 34bf82508a9..2ee731dbd65 100644
--- a/trl/trainer/rloo_trainer.py
+++ b/trl/trainer/rloo_trainer.py
@@ -15,6 +15,7 @@
 import inspect
 import os
 import textwrap
+import time
 import warnings
 from collections import defaultdict, deque
 from collections.abc import Callable
@@ -22,7 +23,6 @@
 from functools import partial
 from pathlib import Path
 from typing import Any
-import time
 
 import datasets
 import pandas as pd
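Note: the following is a minimal, self-contained sketch of the timing pattern these patches introduce, not TRL code; the DummyTrainer class, its _run_micro_batch placeholder, and the accumulation-step count are illustrative assumptions. Each micro-step is timed with time.perf_counter() (a monotonic clock that, unlike time.time(), is unaffected by system clock adjustments, hence the change in PATCH 2/4), the durations are summed, and a single step_time value is recorded once a full optimizer step (gradient_accumulation_steps micro-steps) has completed.

import time
from collections import defaultdict


class DummyTrainer:
    """Illustrative stand-in for a trainer that uses gradient accumulation."""

    def __init__(self, gradient_accumulation_steps=4):
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self._metrics = {"train": defaultdict(list)}
        self._step = 0
        self._current_train_step_time = 0.0

    def _run_micro_batch(self):
        # Placeholder for completion generation plus forward/backward on one micro-batch.
        time.sleep(0.01)

    def training_step(self):
        # Time one micro-step and add it to the running total for the current optimizer step.
        time_before = time.perf_counter()
        self._run_micro_batch()
        self._step += 1
        self._current_train_step_time += time.perf_counter() - time_before
        # Record a single step_time per optimizer step, covering all of its micro-steps.
        if self._step % self.gradient_accumulation_steps == 0:
            self._metrics["train"]["step_time"].append(self._current_train_step_time)
            self._current_train_step_time = 0.0


trainer = DummyTrainer()
for _ in range(8):  # two optimizer steps' worth of micro-batches
    trainer.training_step()
print(trainer._metrics["train"]["step_time"])  # two accumulated durations, one per optimizer step

The docs entries added in these patches describe step_time as "including generation" because completion generation runs as part of the trainers' training step, so it is covered by the measured interval.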