Skip to content

Commit 1e39eb6

Browse files
Add support for Trackio completions logging in GRPOTrainer (#4359)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
1 parent 97830a3 commit 1e39eb6

File tree

4 files changed

+66
-20
lines changed

4 files changed

+66
-20
lines changed

trl/trainer/grpo_config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,8 @@ class GRPOConfig(TrainingArguments):
234234
235235
log_completions (`bool`, *optional*, defaults to `False`):
236236
Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is installed,
237-
it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`.
237+
it prints the sample. If `wandb` and/or `trackio` logging is enabled, it logs it to `wandb` and/or
238+
`trackio`.
238239
num_completions_to_print (`int`, *optional*):
239240
Number of completions to print with `rich`. If `None`, all completions are logged.
240241
wandb_log_unique_prompts (`bool`, *optional*, defaults to `False`):

trl/trainer/grpo_trainer.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
PreTrainedTokenizerBase,
4343
ProcessorMixin,
4444
TrainerCallback,
45+
is_trackio_available,
4546
is_wandb_available,
4647
)
4748
from transformers.trainer_utils import seed_worker
@@ -94,6 +95,9 @@
9495
if is_wandb_available():
9596
import wandb
9697

98+
if is_trackio_available():
99+
import trackio
100+
97101

98102
logger = logging.get_logger(__name__)
99103

@@ -1869,7 +1873,13 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non
18691873
self.num_completions_to_print,
18701874
)
18711875

1876+
logging_backends = []
18721877
if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
1878+
logging_backends.append(wandb)
1879+
if self.args.report_to and "trackio" in self.args.report_to:
1880+
logging_backends.append(trackio)
1881+
1882+
if logging_backends:
18731883
import pandas as pd
18741884

18751885
table = {
@@ -1880,16 +1890,28 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non
18801890
"advantage": self._logs["advantages"],
18811891
}
18821892

1883-
if self._logs["images"]:
1884-
table["images"] = []
1885-
for image_list in self._logs["images"]:
1886-
# Convert images to wandb Image objects for proper visualization
1887-
table["images"].append([wandb.Image(image) for image in image_list])
1893+
df_base = pd.DataFrame(table)
1894+
images_raw = self._logs["images"] or []
1895+
1896+
for logging_backend in logging_backends:
1897+
if images_raw:
1898+
# Convert images per backend and derive a dataframe that shares base columns
1899+
if logging_backend is wandb:
1900+
images = []
1901+
for image_list in self._logs["images"]:
1902+
images.append([wandb.Image(image) for image in image_list])
1903+
df = pd.concat([df_base, pd.Series(images, name="image")], axis=1, copy=False)
1904+
elif logging_backend is trackio:
1905+
# TODO: Implement once supported upstream https://github.com/gradio-app/trackio/issues/327
1906+
logger.info("Skipping image logging for Trackio")
1907+
df = df_base
1908+
else:
1909+
df = df_base
1910+
1911+
if self.wandb_log_unique_prompts:
1912+
df = df.drop_duplicates(subset=["prompt"])
18881913

1889-
df = pd.DataFrame(table)
1890-
if self.wandb_log_unique_prompts:
1891-
df = df.drop_duplicates(subset=["prompt"])
1892-
wandb.log({"completions": wandb.Table(dataframe=df)})
1914+
logging_backend.log({"completions": logging_backend.Table(dataframe=df)})
18931915

18941916
# Ensure the model card is saved along with the checkpoint
18951917
def _save_checkpoint(self, model, trial):

trl/trainer/rloo_config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,8 @@ class RLOOConfig(TrainingArguments):
186186
187187
log_completions (`bool`, *optional*, defaults to `False`):
188188
Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is installed,
189-
it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`.
189+
it prints the sample. If `wandb` and/or `trackio` logging is enabled, it logs it to `wandb` and/or
190+
`trackio`.
190191
num_completions_to_print (`int`, *optional*):
191192
Number of completions to print with `rich`. If `None`, all completions are logged.
192193
wandb_log_unique_prompts (`bool`, *optional*, defaults to `False`):

trl/trainer/rloo_trainer.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
PreTrainedTokenizerBase,
4343
ProcessorMixin,
4444
TrainerCallback,
45+
is_trackio_available,
4546
is_wandb_available,
4647
)
4748
from transformers.trainer_utils import seed_worker
@@ -90,6 +91,9 @@
9091
if is_wandb_available():
9192
import wandb
9293

94+
if is_trackio_available():
95+
import trackio
96+
9397

9498
logger = logging.get_logger(__name__)
9599

@@ -1511,7 +1515,13 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non
15111515
self.num_completions_to_print,
15121516
)
15131517

1518+
logging_backends = []
15141519
if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
1520+
logging_backends.append(wandb)
1521+
if self.args.report_to and "trackio" in self.args.report_to:
1522+
logging_backends.append(trackio)
1523+
1524+
if logging_backends:
15151525
import pandas as pd
15161526

15171527
table = {
@@ -1522,16 +1532,28 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non
15221532
"advantage": self._logs["advantages"],
15231533
}
15241534

1525-
if self._logs["images"]:
1526-
table["images"] = []
1527-
for image_list in self._logs["images"]:
1528-
# Convert images to wandb Image objects for proper visualization
1529-
table["images"].append([wandb.Image(image) for image in image_list])
1535+
df_base = pd.DataFrame(table)
1536+
images_raw = self._logs["images"] or []
1537+
1538+
for logging_backend in logging_backends:
1539+
if images_raw:
1540+
# Convert images per backend and derive a dataframe that shares base columns
1541+
if logging_backend is wandb:
1542+
images = []
1543+
for image_list in self._logs["images"]:
1544+
images.append([wandb.Image(image) for image in image_list])
1545+
df = pd.concat([df_base, pd.Series(images, name="image")], axis=1, copy=False)
1546+
elif logging_backend is trackio:
1547+
# TODO: Implement once supported upstream https://github.com/gradio-app/trackio/issues/327
1548+
logger.info("Skipping image logging for Trackio")
1549+
df = df_base
1550+
else:
1551+
df = df_base
1552+
1553+
if self.wandb_log_unique_prompts:
1554+
df = df.drop_duplicates(subset=["prompt"])
15301555

1531-
df = pd.DataFrame(table)
1532-
if self.wandb_log_unique_prompts:
1533-
df = df.drop_duplicates(subset=["prompt"])
1534-
wandb.log({"completions": wandb.Table(dataframe=df)})
1556+
logging_backend.log({"completions": logging_backend.Table(dataframe=df)})
15351557

15361558
# Ensure the model card is saved along with the checkpoint
15371559
def _save_checkpoint(self, model, trial):

0 commit comments

Comments
 (0)