diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index f6f3c6e3462..e6299927ab6 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -234,7 +234,8 @@ class GRPOConfig(TrainingArguments): log_completions (`bool`, *optional*, defaults to `False`): Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is installed, - it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`. + it prints the sample. If `wandb` and/or `trackio` logging is enabled, it logs it to `wandb` and/or + `trackio`. num_completions_to_print (`int`, *optional*): Number of completions to print with `rich`. If `None`, all completions are logged. wandb_log_unique_prompts (`bool`, *optional*, defaults to `False`): diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 267b64b6e3b..4631e5effb9 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -42,6 +42,7 @@ PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, + is_trackio_available, is_wandb_available, ) from transformers.trainer_utils import seed_worker @@ -94,6 +95,9 @@ if is_wandb_available(): import wandb +if is_trackio_available(): + import trackio + logger = logging.get_logger(__name__) @@ -1869,7 +1873,13 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non self.num_completions_to_print, ) + logging_backends = [] if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None: + logging_backends.append(wandb) + if self.args.report_to and "trackio" in self.args.report_to: + logging_backends.append(trackio) + + if logging_backends: import pandas as pd table = { @@ -1880,16 +1890,28 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non "advantage": self._logs["advantages"], } - if self._logs["images"]: - table["images"] = [] - for image_list in self._logs["images"]: - # Convert images to wandb Image objects for proper visualization - table["images"].append([wandb.Image(image) for image in image_list]) + df_base = pd.DataFrame(table) + images_raw = self._logs["images"] or [] + + for logging_backend in logging_backends: + if images_raw: + # Convert images per backend and derive a dataframe that shares base columns + if logging_backend is wandb: + images = [] + for image_list in self._logs["images"]: + images.append([wandb.Image(image) for image in image_list]) + df = pd.concat([df_base, pd.Series(images, name="image")], axis=1, copy=False) + elif logging_backend is trackio: + # TODO: Implement once supported upstream https://github.com/gradio-app/trackio/issues/327 + logger.info("Skipping image logging for Trackio") + df = df_base + else: + df = df_base + + if self.wandb_log_unique_prompts: + df = df.drop_duplicates(subset=["prompt"]) - df = pd.DataFrame(table) - if self.wandb_log_unique_prompts: - df = df.drop_duplicates(subset=["prompt"]) - wandb.log({"completions": wandb.Table(dataframe=df)}) + logging_backend.log({"completions": logging_backend.Table(dataframe=df)}) # Ensure the model card is saved along with the checkpoint def _save_checkpoint(self, model, trial): diff --git a/trl/trainer/rloo_config.py b/trl/trainer/rloo_config.py index 335cc72fd09..28ee1b6fb46 100644 --- a/trl/trainer/rloo_config.py +++ b/trl/trainer/rloo_config.py @@ -186,7 +186,8 @@ class RLOOConfig(TrainingArguments): log_completions (`bool`, *optional*, defaults to `False`): Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is installed, - it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`. + it prints the sample. If `wandb` and/or `trackio` logging is enabled, it logs it to `wandb` and/or + `trackio`. num_completions_to_print (`int`, *optional*): Number of completions to print with `rich`. If `None`, all completions are logged. wandb_log_unique_prompts (`bool`, *optional*, defaults to `False`): diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 3ed56624e83..d108c8132d5 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -42,6 +42,7 @@ PreTrainedTokenizerBase, ProcessorMixin, TrainerCallback, + is_trackio_available, is_wandb_available, ) from transformers.trainer_utils import seed_worker @@ -90,6 +91,9 @@ if is_wandb_available(): import wandb +if is_trackio_available(): + import trackio + logger = logging.get_logger(__name__) @@ -1511,7 +1515,13 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non self.num_completions_to_print, ) + logging_backends = [] if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None: + logging_backends.append(wandb) + if self.args.report_to and "trackio" in self.args.report_to: + logging_backends.append(trackio) + + if logging_backends: import pandas as pd table = { @@ -1522,16 +1532,28 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non "advantage": self._logs["advantages"], } - if self._logs["images"]: - table["images"] = [] - for image_list in self._logs["images"]: - # Convert images to wandb Image objects for proper visualization - table["images"].append([wandb.Image(image) for image in image_list]) + df_base = pd.DataFrame(table) + images_raw = self._logs["images"] or [] + + for logging_backend in logging_backends: + if images_raw: + # Convert images per backend and derive a dataframe that shares base columns + if logging_backend is wandb: + images = [] + for image_list in self._logs["images"]: + images.append([wandb.Image(image) for image in image_list]) + df = pd.concat([df_base, pd.Series(images, name="image")], axis=1, copy=False) + elif logging_backend is trackio: + # TODO: Implement once supported upstream https://github.com/gradio-app/trackio/issues/327 + logger.info("Skipping image logging for Trackio") + df = df_base + else: + df = df_base + + if self.wandb_log_unique_prompts: + df = df.drop_duplicates(subset=["prompt"]) - df = pd.DataFrame(table) - if self.wandb_log_unique_prompts: - df = df.drop_duplicates(subset=["prompt"]) - wandb.log({"completions": wandb.Table(dataframe=df)}) + logging_backend.log({"completions": logging_backend.Table(dataframe=df)}) # Ensure the model card is saved along with the checkpoint def _save_checkpoint(self, model, trial):