☄️ Add support for Comet experiment management SDK integration #2462

Merged: 18 commits merged into main from add-comet_ml-integration on Dec 13, 2024.

Changes from all commits (18 commits)
f4caa54
Added support for Comet URL integration into model cards created by t…
yaricom Dec 11, 2024
775478d
Moved `get_comet_experiment_url()` into utils.py
yaricom Dec 11, 2024
adc024f
Updated Comet badge in the model card to use PNG image instead of text.
yaricom Dec 11, 2024
cd923ed
Merge branch 'main' into add-comet_ml-integration
yaricom Dec 11, 2024
d4bac7f
Fixed bug related to running PPO example during model saving. The err…
yaricom Dec 11, 2024
7cfedfe
Merge branch 'main' into add-comet_ml-integration
yaricom Dec 12, 2024
9085de1
Implemented utility method to handle logging of tabular data to the C…
yaricom Dec 12, 2024
8e4482d
Implemented logging of the completions table to Comet by `PPOTrainer`.
yaricom Dec 12, 2024
233f164
Implemented logging of the completions table to Comet by `WinRateCall…
yaricom Dec 12, 2024
4611889
Implemented logging of the completions table to Comet by `RLOOTrainer…
yaricom Dec 12, 2024
b9869a7
Restored line to the main branch version.
yaricom Dec 12, 2024
c920c98
Merge branch 'main' into add-comet_ml-integration
yaricom Dec 13, 2024
5a0222e
Moved Comet related utility methods into `trainer/utils.py` to resolv…
yaricom Dec 13, 2024
7b74c08
Merge branch 'main' into add-comet_ml-integration
yaricom Dec 13, 2024
5928a14
Update trl/trainer/utils.py
yaricom Dec 13, 2024
048b06d
Implemented raising of `ModuleNotFoundError` error when logging table…
yaricom Dec 13, 2024
da22de6
Merge branch 'main' into add-comet_ml-integration
qgallouedec Dec 13, 2024
bb466c1
import comet with other imports
qgallouedec Dec 13, 2024
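The diffs below import two new helpers, `get_comet_experiment_url()` and `log_table_to_comet_experiment()`, from `trl/trainer/utils.py`, whose diff is not included in this view. A minimal sketch of what they plausibly do, reconstructed from the commit messages above — the bodies and the specific `comet_ml` calls are assumptions, not the merged code:

```python
# Plausible sketch of the two helpers this PR adds to trl/trainer/utils.py.
# Names and behavior come from the commits above; the bodies are assumptions.
import importlib.util
from typing import Optional

import pandas as pd


def _is_comet_available() -> bool:
    # Hypothetical availability check; the merged code may use a shared helper
    # and import comet_ml with the other imports (see commit bb466c1).
    return importlib.util.find_spec("comet_ml") is not None


def get_comet_experiment_url() -> Optional[str]:
    """Return the URL of the running Comet experiment, or None if there is none."""
    if not _is_comet_available():
        return None
    import comet_ml

    experiment = comet_ml.get_running_experiment()
    return experiment.url if experiment is not None else None


def log_table_to_comet_experiment(name: str, table: pd.DataFrame) -> None:
    """Log a pandas DataFrame as a table to the running Comet experiment."""
    if not _is_comet_available():
        # Commit 048b06d: a ModuleNotFoundError is raised when comet_ml is missing.
        raise ModuleNotFoundError("comet-ml is not installed. Run `pip install comet_ml`.")
    import comet_ml

    experiment = comet_ml.get_running_experiment()
    if experiment is not None:
        experiment.log_table(tabular_data=table, filename=name)
```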
3 changes: 3 additions & 0 deletions tests/test_utils.py
@@ -147,6 +147,7 @@ def test_full(self):
dataset_name="username/my_dataset",
tags=["trl", "trainer-tag"],
wandb_url="https://wandb.ai/username/project_id/runs/abcd1234",
comet_url="https://www.comet.com/username/project_id/experiment_id",
trainer_name="My Trainer",
trainer_citation="@article{my_trainer, ...}",
paper_title="My Paper",
@@ -158,6 +159,7 @@ def test_full(self):
self.assertIn('pipeline("text-generation", model="username/my_hub_model", device="cuda")', card_text)
self.assertIn("datasets: username/my_dataset", card_text)
self.assertIn("](https://wandb.ai/username/project_id/runs/abcd1234)", card_text)
self.assertIn("](https://www.comet.com/username/project_id/experiment_id", card_text)
self.assertIn("My Trainer", card_text)
self.assertIn("```bibtex\n@article{my_trainer, ...}\n```", card_text)
self.assertIn("[My Paper](https://huggingface.co/papers/1234.56789)", card_text)
@@ -170,6 +172,7 @@ def test_val_none(self):
dataset_name=None,
tags=[],
wandb_url=None,
+comet_url=None,
trainer_name="My Trainer",
trainer_citation=None,
paper_title=None,
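These tests exercise `generate_model_card` from `trl/trainer/utils.py`, whose diff is not shown here. Judging from the tests and the trainer call sites below, the utility gains a `comet_url` parameter alongside `wandb_url`. A hedged sketch of the presumed signature — the parameter order and the untouched parameters are assumptions:

```python
# Presumed post-PR signature of generate_model_card, inferred from the tests
# above and the trainer call sites below; parameter order is an assumption.
from typing import List, Optional

from huggingface_hub import ModelCard


def generate_model_card(
    base_model: Optional[str],
    model_name: str,
    hub_model_id: str,
    dataset_name: Optional[str],
    tags: List[str],
    wandb_url: Optional[str],
    comet_url: Optional[str],  # new in this PR
    trainer_name: str,
    trainer_citation: Optional[str] = None,
    paper_title: Optional[str] = None,
    paper_id: Optional[str] = None,
) -> ModelCard:
    # Body omitted; the real implementation renders trl/templates/lm_model_card.md
    # (see the template change below, which adds a Comet badge).
    ...
```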
3 changes: 2 additions & 1 deletion trl/templates/lm_model_card.md
@@ -20,7 +20,8 @@ print(output["generated_text"])

## Training procedure

-{% if wandb_url %}[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="150" height="24"/>]({{ wandb_url }}){% endif %}
+{% if wandb_url %}[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="150" height="24"/>]({{ wandb_url }}){% endif %}
+{% if comet_url %}[<img src="https://raw.githubusercontent.com/comet-ml/comet-examples/master/logo/comet_badge.png" alt="Visualize in Comet" width="135" height="20"/>]({{ comet_url }}){% endif %}
Review comment (Member): Can you make an SVG instead? Or at least increase the quality?

Reply (Contributor, author): I'll ask a GFX designer to fix this.

This model was trained with {{ trainer_name }}{% if paper_id %}, a method introduced in [{{ paper_title }}](https://huggingface.co/papers/{{ paper_id }}){% endif %}.

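For reference, when `comet_url` is set, the added template line renders a clickable badge in the generated model card along these lines (the URL is the placeholder from the test above):

```
[<img src="https://raw.githubusercontent.com/comet-ml/comet-examples/master/logo/comet_badge.png" alt="Visualize in Comet" width="135" height="20"/>](https://www.comet.com/username/project_id/experiment_id)
```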
3 changes: 2 additions & 1 deletion trl/trainer/alignprop_trainer.py
@@ -26,7 +26,7 @@

from ..models import DDPOStableDiffusionPipeline
from . import AlignPropConfig, BaseTrainer
-from .utils import generate_model_card
+from .utils import generate_model_card, get_comet_experiment_url


if is_wandb_available():
@@ -438,6 +438,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+comet_url=get_comet_experiment_url(),
trainer_name="AlignProp",
trainer_citation=citation,
paper_title="Aligning Text-to-Image Diffusion Models with Reward Backpropagation",
2 changes: 2 additions & 0 deletions trl/trainer/bco_trainer.py
@@ -59,6 +59,7 @@
RunningMoments,
disable_dropout_in_model,
generate_model_card,
+get_comet_experiment_url,
pad_to_length,
peft_module_casting_to_bf16,
)
@@ -1514,6 +1515,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+comet_url=get_comet_experiment_url(),
trainer_name="BCO",
trainer_citation=citation,
paper_title="Binary Classifier Optimization for Large Language Model Alignment",
59 changes: 46 additions & 13 deletions trl/trainer/callbacks.py
@@ -13,7 +13,7 @@
# limitations under the License.

import os
-from typing import Optional, Union
+from typing import List, Optional, Union

import pandas as pd
import torch
@@ -42,6 +42,7 @@
from ..mergekit_utils import MergeConfig, merge_models, upload_model_to_hf
from ..models.utils import unwrap_model_for_generation
from .judges import BasePairwiseJudge
+from .utils import log_table_to_comet_experiment


if is_deepspeed_available():
@@ -199,6 +200,16 @@ def on_train_end(self, args, state, control, **kwargs):
self.current_step = None


+def _win_rate_completions_df(
+    state: TrainerState, prompts: List[str], completions: List[str], winner_indices: List[str]
+) -> pd.DataFrame:
+    global_step = [str(state.global_step)] * len(prompts)
+    data = list(zip(global_step, prompts, completions, winner_indices))
+    # Split completions from reference model and policy
+    split_data = [(item[0], item[1], item[2][0], item[2][1], item[3]) for item in data]
+    return pd.DataFrame(split_data, columns=["step", "prompt", "reference_model", "policy", "winner_index"])


class WinRateCallback(TrainerCallback):
"""
A [`~transformers.TrainerCallback`] that computes the win rate of a model based on a reference.
@@ -311,15 +322,26 @@ def on_train_begin(self, args: TrainingArguments, state: TrainerState, control:
import wandb

if wandb.run is not None:
-global_step = [str(state.global_step)] * len(prompts)
-data = list(zip(global_step, prompts, completions, winner_indices))
-# Split completions from referenece model and policy
-split_data = [(item[0], item[1], item[2][0], item[2][1], item[3]) for item in data]
-df = pd.DataFrame(
-split_data, columns=["step", "prompt", "reference_model", "policy", "winner_index"]
+df = _win_rate_completions_df(
+state=state,
+prompts=prompts,
+completions=completions,
+winner_indices=winner_indices,
)
wandb.log({"win_rate_completions": wandb.Table(dataframe=df)})

if "comet_ml" in args.report_to:
df = _win_rate_completions_df(
state=state,
prompts=prompts,
completions=completions,
winner_indices=winner_indices,
)
log_table_to_comet_experiment(
name="win_rate_completions.csv",
table=df,
)

def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
# At every evaluation step, we generate completions for the model and compare them with the reference
# completions that have been generated at the beginning of training. We then compute the win rate and log it to
@@ -363,15 +385,26 @@ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra
import wandb

if wandb.run is not None:
-global_step = [str(state.global_step)] * len(prompts)
-data = list(zip(global_step, prompts, completions, winner_indices))
-# Split completions from referenece model and policy
-split_data = [(item[0], item[1], item[2][0], item[2][1], item[3]) for item in data]
-df = pd.DataFrame(
-split_data, columns=["step", "prompt", "reference_model", "policy", "winner_index"]
+df = _win_rate_completions_df(
+state=state,
+prompts=prompts,
+completions=completions,
+winner_indices=winner_indices,
)
wandb.log({"win_rate_completions": wandb.Table(dataframe=df)})

if "comet_ml" in args.report_to:
df = _win_rate_completions_df(
state=state,
prompts=prompts,
completions=completions,
winner_indices=winner_indices,
)
log_table_to_comet_experiment(
name="win_rate_completions.csv",
table=df,
)


class LogCompletionsCallback(WandbCallback):
r"""
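To make the new `_win_rate_completions_df` helper concrete, here is a small illustrative example with made-up data; note that each entry of `completions` is a (reference, policy) pair, which the helper splits into separate columns:

```python
# Illustrative only: exercising _win_rate_completions_df with toy values.
from transformers import TrainerState

state = TrainerState()
state.global_step = 42

df = _win_rate_completions_df(
    state=state,
    prompts=["Write a haiku about RLHF."],
    completions=[("reference model text", "policy model text")],  # (reference, policy) pairs
    winner_indices=["1"],  # judge's pick; assumed here to mean the policy completion won
)
print(df.columns.tolist())
# ['step', 'prompt', 'reference_model', 'policy', 'winner_index']
```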
2 changes: 2 additions & 0 deletions trl/trainer/cpo_trainer.py
@@ -55,6 +55,7 @@
add_eos_token_if_needed,
disable_dropout_in_model,
generate_model_card,
+get_comet_experiment_url,
pad_to_length,
peft_module_casting_to_bf16,
)
@@ -1052,6 +1053,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+comet_url=get_comet_experiment_url(),
trainer_name="CPO",
trainer_citation=citation,
paper_title="Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation",
3 changes: 2 additions & 1 deletion trl/trainer/ddpo_trainer.py
@@ -27,7 +27,7 @@

from ..models import DDPOStableDiffusionPipeline
from . import BaseTrainer, DDPOConfig
-from .utils import PerPromptStatTracker, generate_model_card
+from .utils import PerPromptStatTracker, generate_model_card, get_comet_experiment_url


if is_wandb_available():
@@ -641,6 +641,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+comet_url=get_comet_experiment_url(),
trainer_name="DDPO",
trainer_citation=citation,
paper_title="Training Diffusion Models with Reinforcement Learning",
2 changes: 2 additions & 0 deletions trl/trainer/dpo_trainer.py
@@ -60,6 +60,7 @@
cap_exp,
disable_dropout_in_model,
generate_model_card,
+get_comet_experiment_url,
pad,
pad_to_length,
peft_module_casting_to_bf16,
@@ -1483,6 +1484,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+comet_url=get_comet_experiment_url(),
trainer_name="DPO",
trainer_citation=citation,
paper_title="Direct Preference Optimization: Your Language Model is Secretly a Reward Model",
9 changes: 8 additions & 1 deletion trl/trainer/gkd_trainer.py
@@ -42,7 +42,13 @@
from ..models.utils import unwrap_model_for_generation
from .gkd_config import GKDConfig
from .sft_trainer import SFTTrainer
-from .utils import DataCollatorForChatML, disable_dropout_in_model, empty_cache, generate_model_card
+from .utils import (
+    DataCollatorForChatML,
+    disable_dropout_in_model,
+    empty_cache,
+    generate_model_card,
+    get_comet_experiment_url,
+)


if is_deepspeed_available():
Expand Down Expand Up @@ -378,6 +384,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+comet_url=get_comet_experiment_url(),
trainer_name="GKD",
trainer_citation=citation,
paper_title="On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes",
3 changes: 2 additions & 1 deletion trl/trainer/iterative_sft_trainer.py
@@ -36,7 +36,7 @@
from transformers.utils import is_peft_available

from ..core import PPODecorators
-from .utils import generate_model_card
+from .utils import generate_model_card, get_comet_experiment_url


if is_peft_available():
@@ -434,6 +434,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+comet_url=get_comet_experiment_url(),
trainer_name="Iterative SFT",
)

2 changes: 2 additions & 0 deletions trl/trainer/kto_trainer.py
@@ -58,6 +58,7 @@
DPODataCollatorWithPadding,
disable_dropout_in_model,
generate_model_card,
+get_comet_experiment_url,
pad_to_length,
peft_module_casting_to_bf16,
)
@@ -1526,6 +1527,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+comet_url=get_comet_experiment_url(),
trainer_name="KTO",
trainer_citation=citation,
paper_title="KTO: Model Alignment as Prospect Theoretic Optimization",
10 changes: 9 additions & 1 deletion trl/trainer/nash_md_trainer.py
@@ -40,7 +40,14 @@
from .judges import BasePairwiseJudge
from .nash_md_config import NashMDConfig
from .online_dpo_trainer import OnlineDPOTrainer
-from .utils import SIMPLE_CHAT_TEMPLATE, empty_cache, generate_model_card, get_reward, truncate_right
+from .utils import (
+    SIMPLE_CHAT_TEMPLATE,
+    empty_cache,
+    generate_model_card,
+    get_comet_experiment_url,
+    get_reward,
+    truncate_right,
+)


if is_apex_available():
@@ -500,6 +507,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+comet_url=get_comet_experiment_url(),
trainer_name="Nash-MD",
trainer_citation=citation,
paper_title="Nash Learning from Human Feedback",
2 changes: 2 additions & 0 deletions trl/trainer/online_dpo_trainer.py
@@ -58,6 +58,7 @@
disable_dropout_in_model,
empty_cache,
generate_model_card,
+get_comet_experiment_url,
get_reward,
prepare_deepspeed,
truncate_right,
@@ -734,6 +735,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+comet_url=get_comet_experiment_url(),
trainer_name="Online DPO",
trainer_citation=citation,
paper_title="Direct Language Model Alignment from Online AI Feedback",
2 changes: 2 additions & 0 deletions trl/trainer/orpo_trainer.py
@@ -59,6 +59,7 @@
add_eos_token_if_needed,
disable_dropout_in_model,
generate_model_card,
+get_comet_experiment_url,
pad_to_length,
peft_module_casting_to_bf16,
)
@@ -1077,6 +1078,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+comet_url=get_comet_experiment_url(),
trainer_name="ORPO",
trainer_citation=citation,
paper_title="ORPO: Monolithic Preference Optimization without Reference Model",
9 changes: 9 additions & 0 deletions trl/trainer/ppo_trainer.py
@@ -60,7 +60,9 @@
first_true_indices,
forward,
generate_model_card,
+get_comet_experiment_url,
get_reward,
+log_table_to_comet_experiment,
peft_module_casting_to_bf16,
prepare_deepspeed,
print_rich_table,
@@ -727,6 +729,12 @@ def generate_completions(self, sampling: bool = False):
if wandb.run is not None:
wandb.log({"completions": wandb.Table(dataframe=df)})

if "comet_ml" in args.report_to:
log_table_to_comet_experiment(
name="completions.csv",
table=df,
)

def create_model_card(
self,
model_name: Optional[str] = None,
Expand Down Expand Up @@ -774,6 +782,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+comet_url=get_comet_experiment_url(),
trainer_name="PPO",
trainer_citation=citation,
paper_title="Fine-Tuning Language Models from Human Preferences",
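From the user's side, nothing beyond `report_to` changes: a minimal sketch (output dir and other values are placeholders) of opting in to the new Comet logging for `PPOTrainer`:

```python
# Minimal sketch: report_to="comet_ml" makes transformers start a Comet
# experiment; with this PR, PPOTrainer also logs its completions table there
# and embeds the experiment URL in the generated model card.
from trl import PPOConfig

config = PPOConfig(
    output_dir="ppo-comet-demo",  # placeholder
    report_to="comet_ml",
)
```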
9 changes: 9 additions & 0 deletions trl/trainer/reward_trainer.py
@@ -48,6 +48,8 @@
compute_accuracy,
decode_and_strip_padding,
generate_model_card,
+get_comet_experiment_url,
+log_table_to_comet_experiment,
print_rich_table,
)

@@ -359,6 +361,12 @@ def visualize_samples(self, num_print_samples: int):
if wandb.run is not None:
wandb.log({"completions": wandb.Table(dataframe=df)})

if "comet_ml" in self.args.report_to:
log_table_to_comet_experiment(
name="completions.csv",
table=df,
)

def create_model_card(
self,
model_name: Optional[str] = None,
@@ -398,6 +406,7 @@ def create_model_card(
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+comet_url=get_comet_experiment_url(),
trainer_name="Reward",
)
