Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion trl/trainer/grpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,8 @@ class GRPOConfig(TrainingArguments):
log_completions (`bool`, *optional*, defaults to `False`):
Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is installed,
it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`.
it prints the sample. If `wandb` and/or `trackio` logging is enabled, it logs it to `wandb` and/or
`trackio`.
num_completions_to_print (`int`, *optional*):
Number of completions to print with `rich`. If `None`, all completions are logged.
wandb_log_unique_prompts (`bool`, *optional*, defaults to `False`):
Expand Down
40 changes: 31 additions & 9 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
PreTrainedTokenizerBase,
ProcessorMixin,
TrainerCallback,
is_trackio_available,
is_wandb_available,
)
from transformers.trainer_utils import seed_worker
Expand Down Expand Up @@ -94,6 +95,9 @@
if is_wandb_available():
import wandb

if is_trackio_available():
import trackio


logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -1869,7 +1873,13 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non
self.num_completions_to_print,
)

logging_backends = []
if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
logging_backends.append(wandb)
if self.args.report_to and "trackio" in self.args.report_to:
logging_backends.append(trackio)

if logging_backends:
import pandas as pd

table = {
Expand All @@ -1880,16 +1890,28 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non
"advantage": self._logs["advantages"],
}

if self._logs["images"]:
table["images"] = []
for image_list in self._logs["images"]:
# Convert images to wandb Image objects for proper visualization
table["images"].append([wandb.Image(image) for image in image_list])
df_base = pd.DataFrame(table)
images_raw = self._logs["images"] or []

for logging_backend in logging_backends:
if images_raw:
# Convert images per backend and derive a dataframe that shares base columns
if logging_backend is wandb:
images = []
for image_list in self._logs["images"]:
images.append([wandb.Image(image) for image in image_list])
df = pd.concat([df_base, pd.Series(images, name="image")], axis=1, copy=False)
elif logging_backend is trackio:
# TODO: Implement once supported upstream https://github.com/gradio-app/trackio/issues/327
logger.info("Skipping image logging for Trackio")
df = df_base
else:
df = df_base

if self.wandb_log_unique_prompts:
df = df.drop_duplicates(subset=["prompt"])

df = pd.DataFrame(table)
if self.wandb_log_unique_prompts:
df = df.drop_duplicates(subset=["prompt"])
wandb.log({"completions": wandb.Table(dataframe=df)})
logging_backend.log({"completions": logging_backend.Table(dataframe=df)})

# Ensure the model card is saved along with the checkpoint
def _save_checkpoint(self, model, trial):
Expand Down
3 changes: 2 additions & 1 deletion trl/trainer/rloo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ class RLOOConfig(TrainingArguments):
log_completions (`bool`, *optional*, defaults to `False`):
Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is installed,
it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`.
it prints the sample. If `wandb` and/or `trackio` logging is enabled, it logs it to `wandb` and/or
`trackio`.
num_completions_to_print (`int`, *optional*):
Number of completions to print with `rich`. If `None`, all completions are logged.
wandb_log_unique_prompts (`bool`, *optional*, defaults to `False`):
Expand Down
40 changes: 31 additions & 9 deletions trl/trainer/rloo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
PreTrainedTokenizerBase,
ProcessorMixin,
TrainerCallback,
is_trackio_available,
is_wandb_available,
)
from transformers.trainer_utils import seed_worker
Expand Down Expand Up @@ -90,6 +91,9 @@
if is_wandb_available():
import wandb

if is_trackio_available():
import trackio


logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -1511,7 +1515,13 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non
self.num_completions_to_print,
)

logging_backends = []
if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
logging_backends.append(wandb)
if self.args.report_to and "trackio" in self.args.report_to:
logging_backends.append(trackio)

if logging_backends:
import pandas as pd

table = {
Expand All @@ -1522,16 +1532,28 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non
"advantage": self._logs["advantages"],
}

if self._logs["images"]:
table["images"] = []
for image_list in self._logs["images"]:
# Convert images to wandb Image objects for proper visualization
table["images"].append([wandb.Image(image) for image in image_list])
df_base = pd.DataFrame(table)
images_raw = self._logs["images"] or []

for logging_backend in logging_backends:
if images_raw:
# Convert images per backend and derive a dataframe that shares base columns
if logging_backend is wandb:
images = []
for image_list in self._logs["images"]:
images.append([wandb.Image(image) for image in image_list])
df = pd.concat([df_base, pd.Series(images, name="image")], axis=1, copy=False)
elif logging_backend is trackio:
# TODO: Implement once supported upstream https://github.com/gradio-app/trackio/issues/327
logger.info("Skipping image logging for Trackio")
df = df_base
else:
df = df_base

if self.wandb_log_unique_prompts:
df = df.drop_duplicates(subset=["prompt"])

df = pd.DataFrame(table)
if self.wandb_log_unique_prompts:
df = df.drop_duplicates(subset=["prompt"])
wandb.log({"completions": wandb.Table(dataframe=df)})
logging_backend.log({"completions": logging_backend.Table(dataframe=df)})

# Ensure the model card is saved along with the checkpoint
def _save_checkpoint(self, model, trial):
Expand Down
Loading