1 change: 1 addition & 0 deletions docs/source/grpo_trainer.md
@@ -149,6 +149,7 @@ This constant is recommended to be the maximum completion length. To use this fo
While training and evaluating, we record the following reward metrics:

- `num_tokens`: The total number of tokens processed so far, including both prompts and completions.
- `step_time`: The average time (in seconds) taken per training step (including generation).
- `completions/mean_length`: The average length of generated completions.
- `completions/min_length`: The minimum length of generated completions.
- `completions/max_length`: The maximum length of generated completions.
1 change: 1 addition & 0 deletions docs/source/rloo_trainer.md
@@ -135,6 +135,7 @@ In a fully online, single-step setting (default), \\( \frac{\pi_\theta(o_i \mid
While training and evaluating, we record the following reward metrics:

- `num_tokens`: The total number of tokens processed so far, including both prompts and completions.
- `step_time`: The average time (in seconds) taken per training step (including generation).
- `completions/mean_length`: The average length of generated completions.
- `completions/min_length`: The minimum length of generated completions.
- `completions/max_length`: The maximum length of generated completions.
14 changes: 13 additions & 1 deletion trl/trainer/grpo_trainer.py
@@ -15,6 +15,7 @@
import inspect
import os
import textwrap
import time
import warnings
from collections import defaultdict, deque
from collections.abc import Callable
@@ -531,6 +532,7 @@ def cast_outputs_to_original_dtype(module, args, output):
        # Initialize the metrics
        self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
        self._total_train_tokens = 0
        self._current_train_step_time = 0.0
        self.log_completions = args.log_completions
        self.log_unique_prompts = args.log_unique_prompts
        self.num_completions_to_print = args.num_completions_to_print
@@ -1044,6 +1046,17 @@ def _move_model_to_vllm(self):
        elif self.vllm_mode == "colocate":
            self.llm.reset_prefix_cache()

    def training_step(self, model, inputs, num_items_in_batch):
        time_before = time.perf_counter()
        output = super().training_step(model, inputs, num_items_in_batch)
        self._step += 1
        time_after = time.perf_counter()
        self._current_train_step_time += time_after - time_before
        if self._step % self.current_gradient_accumulation_steps == 0:
            self._metrics["train"]["step_time"].append(self._current_train_step_time)
            self._current_train_step_time = 0.0
        return output

    @profiling_decorator
    def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]:
        # Prepares inputs for model training/evaluation by managing completion generation and batch handling.
@@ -1070,7 +1083,6 @@ def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> di
                generation_batches = split_tensor_dict(generation_batch, self.args.steps_per_generation)
                self._buffered_inputs = [unsplit_pixel_values_by_grid(batch) for batch in generation_batches]
            inputs = self._buffered_inputs[self._step % self.args.steps_per_generation]
            self._step += 1
        else:
            # In evaluation, there is neither batch grouping for generation, nor multiple iterations, hence
            # local generation batch == local eval batch
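Aside: below is a minimal, self-contained sketch (not part of this PR) of the timing pattern used in `training_step` above. Per-micro-step durations are measured with `time.perf_counter()` and accumulated, and one total is recorded once every `gradient_accumulation_steps` micro-steps, i.e. one `step_time` value per optimizer step. The loop body and `sleep` are stand-ins for the real forward/backward (and generation) work.

```python
import time
from collections import defaultdict

metrics = defaultdict(list)         # mirrors self._metrics["train"]
gradient_accumulation_steps = 4     # stand-in for self.current_gradient_accumulation_steps
current_step_time = 0.0

for step in range(1, 13):           # 12 micro-steps -> 3 optimizer steps
    before = time.perf_counter()
    time.sleep(0.01)                # placeholder for the actual training_step work
    current_step_time += time.perf_counter() - before
    if step % gradient_accumulation_steps == 0:
        # Flush one accumulated wall-clock time per optimizer step.
        metrics["step_time"].append(current_step_time)
        current_step_time = 0.0

print(metrics["step_time"])         # three values, each the sum of 4 micro-step durations
```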
13 changes: 13 additions & 0 deletions trl/trainer/rloo_trainer.py
@@ -15,6 +15,7 @@
import inspect
import os
import textwrap
import time
import warnings
from collections import defaultdict, deque
from collections.abc import Callable
@@ -448,6 +449,7 @@ def __init__(
        # Initialize the metrics
        self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
        self._total_train_tokens = 0
        self._current_train_step_time = 0.0
        self.log_completions = args.log_completions
        self.log_unique_prompts = args.log_unique_prompts
        self.num_completions_to_print = args.num_completions_to_print
@@ -878,6 +880,17 @@ def _move_model_to_vllm(self):
        elif self.vllm_mode == "colocate":
            self.llm.reset_prefix_cache()

    def training_step(self, model, inputs, num_items_in_batch):
        time_before = time.perf_counter()
        output = super().training_step(model, inputs, num_items_in_batch)
        self._step += 1
        time_after = time.perf_counter()
        self._current_train_step_time += time_after - time_before
        if self._step % self.current_gradient_accumulation_steps == 0:
            self._metrics["train"]["step_time"].append(self._current_train_step_time)
            self._current_train_step_time = 0.0
        return output

    @profiling_decorator
    def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]:
        # Prepares inputs for model training/evaluation by managing completion generation and batch handling.
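For reference, a hypothetical usage sketch (not from this PR): assuming `step_time` is flushed into the trainer logs like the other entries in `self._metrics["train"]`, the recorded values should be retrievable from `trainer.state.log_history` after training. The model name, dataset slice, reward function, and config values below are illustrative placeholders.

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

# Illustrative placeholders: a small dataset slice and a toy length-based reward.
dataset = load_dataset("trl-lib/tldr", split="train[:64]")

def reward_len(completions, **kwargs):
    return [-abs(50 - len(c)) for c in completions]

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=reward_len,
    args=GRPOConfig(output_dir="grpo-step-time-demo", logging_steps=1, max_steps=4),
    train_dataset=dataset,
)
trainer.train()

# If `step_time` is logged like the other train metrics, each entry is the mean
# wall-clock time (seconds) per optimizer step since the previous log, generation included.
step_times = [log["step_time"] for log in trainer.state.log_history if "step_time" in log]
print(step_times)
```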