[DPO] Resolve logging for DPOTrainer #570

Merged 3 commits into huggingface:main on Jul 26, 2023

Conversation

@tomaarsen (Member)

Resolves #569, resolves #568

Hello!

Pull Request overview

  • Resolve logging for the DPOTrainer to:
    • prevent spamminess,
    • conform to TrainingArguments parameters like logging_steps,
    • propagate logs to callbacks, e.g. W&B and TensorBoard

Details

In the transformers Trainer, the log method is called on a handful of different occasions, but primarily here:
https://github.com/huggingface/transformers/blob/ee1eb3b325ce360bbd6c910c1402bca9dfb418f9/src/transformers/trainer.py#L2188-L2208
This method is essentially always called with a very specific dictionary, e.g. one containing only a loss and learning_rate key. These dictionaries cannot easily be extended without overriding large methods like _maybe_log_save_evaluate or _inner_training_loop, in which these calls occur.
So, we can't easily change the values with which log is initially called.

Instead, we can override log itself and insert the means of the stored metrics directly in that method. That is what this PR does. In particular, I create a store_metrics method, call it wherever log_metrics used to be forcibly called, and override log with a middleman, roughly as sketched below.
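For reference, here is a minimal sketch of this store-then-flush pattern (my own illustration of the idea, not necessarily the exact code in the PR; class and method names are illustrative):

from collections import defaultdict
from typing import Dict

import torch
from transformers import Trainer


class DPOTrainerSketch(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Buffers for metrics accumulated between two log calls.
        self._stored_metrics = defaultdict(lambda: defaultdict(list))

    def store_metrics(self, metrics: Dict[str, float], train_eval: str = "train") -> None:
        # Called from the loss computation instead of logging immediately.
        for key, value in metrics.items():
            self._stored_metrics[train_eval][key].append(value)

    def log(self, logs: Dict[str, float]) -> None:
        # The Trainer calls log every logging_steps; piggyback on that call and
        # merge in the means of the metrics stored since the previous flush.
        train_eval = "train" if "loss" in logs else "eval"
        for key, values in self._stored_metrics[train_eval].items():
            logs[key] = torch.tensor(values).mean().item()
        del self._stored_metrics[train_eval]
        super().log(logs)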

Limitations

One of the primary annoyances with this implementation is that the DPO-specific output metrics of dpo_trainer.evaluate(metric_key_prefix="test") will still start with eval_, e.g.:

{'test_loss': 0.6931471228599548, 'test_runtime': 14.3839, 'test_samples_per_second': 6.952, 'test_steps_per_second': 0.904, 'eval_rewards/chosen': 0.0, 'eval_rewards/rejected': 0.0, 'eval_rewards/accuracies': 0.0, 'eval_rewards/margins': 0.0, 'eval_logps/rejected': -138.52774047851562, 'eval_logps/chosen': -112.80802154541016, 'eval_logits/rejected': -108.01325988769531, 'eval_logits/chosen': -107.01180267333984}
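Presumably this happens because Trainer.evaluate() renames its own metrics using metric_key_prefix, while the stored DPO metrics are flushed by the overridden log under their hard-coded eval_ names. A hypothetical illustration, assuming the sketch above:

metrics = dpo_trainer.evaluate(metric_key_prefix="test")
# Trainer-produced keys pick up the custom prefix...
print(metrics["test_loss"], metrics["test_runtime"])
# ...while the DPO-specific metrics keep the eval_ prefix.
print(metrics["eval_rewards/margins"])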

Results

I used the following example script to try this out:

from typing import Dict
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer
from datasets import Dataset, load_dataset

model = AutoModelForCausalLM.from_pretrained("gpt2")

model_ref = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token


def extract_anthropic_prompt(prompt_and_response):
    """Extract the anthropic prompt from a prompt and response pair."""
    search_term = "\n\nAssistant:"
    search_term_idx = prompt_and_response.rfind(search_term)
    assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
    return prompt_and_response[: search_term_idx + len(search_term)]


def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: str = None) -> Dataset:
    """Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.

    The dataset is converted to a dictionary with the following structure:
    {
        'prompt': List[str],
        'chosen': List[str],
        'rejected': List[str],
    }

    Prompts should be structured as follows:
      \n\nHuman: <prompt>\n\nAssistant:
    Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
    """
    dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)
    if sanity_check:
        dataset = dataset.select(range(min(len(dataset), 1000)))

    def split_prompt_and_responses(sample) -> Dict[str, str]:
        prompt = extract_anthropic_prompt(sample["chosen"])
        return {
            "prompt": prompt,
            "chosen": sample["chosen"][len(prompt) :],
            "rejected": sample["rejected"][len(prompt) :],
        }

    return dataset.map(split_prompt_and_responses)


train_dataset = get_hh(split="train").select(range(1000))
eval_dataset = get_hh(split="test").select(range(100))
training_args = TrainingArguments(
    per_device_train_batch_size=4,
    remove_unused_columns=False,
    output_dir="./tmp",
    report_to="wandb",
    num_train_epochs=1,
    evaluation_strategy="steps",
    eval_steps=50,
    logging_steps=20,
)

dpo_trainer = DPOTrainer(
    model,
    model_ref,
    args=training_args,
    beta=0.1,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
)
dpo_trainer.train()

This produced the following terminal logs:

Found cached dataset json ([path]
Loading cached processed dataset at [path]
Found cached dataset json ([path]
Loading cached processed dataset at [path]
Could not estimate the number of tokens of the input, floating-point operations will not be computed
wandb: Currently logged in as: tomaarsen. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.15.6
wandb: Run data is saved locally in [path]
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run lunar-elevator-1274
wandb:  View project at [link]
wandb:  View run at [link]
  0%|                                                                                                                                                            | 0/250 [00:00<?, ?it/s]
{'loss': 0.8474, 'learning_rate': 4.600000000000001e-05, 'rewards/chosen': -0.9715192914009094, 'rewards/rejected': -1.0397388935089111, 'rewards/accuracies': 0.5249999761581421, 'rewards/margins': 0.06821951270103455, 'logps/rejected': -155.26760864257812, 'logps/chosen': -129.82192993164062, 'logits/rejected': -106.0434341430664, 'logits/chosen': -106.28038024902344, 'epoch': 0.08}
{'loss': 0.9073, 'learning_rate': 4.2e-05, 'rewards/chosen': -1.17503821849823, 'rewards/rejected': -1.4742968082427979, 'rewards/accuracies': 0.5625, 'rewards/margins': 0.2992585599422455, 'logps/rejected': -148.04446411132812, 'logps/chosen': -138.34176635742188, 'logits/rejected': -109.33012390136719, 'logits/chosen': -109.74278259277344, 'epoch': 0.16}
{'eval_loss': 0.8106440901756287, 'eval_runtime': 9.5817, 'eval_samples_per_second': 10.437, 'eval_steps_per_second': 1.357, 'eval_rewards/chosen': -0.26931989192962646, 'eval_rewards/rejected': -0.19057123363018036, 'eval_rewards/accuracies': 0.45192307233810425, 'eval_rewards/margins': -0.0787486881017685, 'eval_logps/rejected': -140.43345642089844, 'eval_logps/chosen': -115.501220703125, 'eval_logits/rejected': -112.14944458007812, 'eval_logits/chosen': -112.40476989746094, 'epoch': 0.2}
{'loss': 0.8165, 'learning_rate': 3.8e-05, 'rewards/chosen': -1.1095190048217773, 'rewards/rejected': -1.5209710597991943, 'rewards/accuracies': 0.625, 'rewards/margins': 0.41145211458206177, 'logps/rejected': -175.07534790039062, 'logps/chosen': -148.30918884277344, 'logits/rejected': -108.8155517578125, 'logits/chosen': -109.02369689941406, 'epoch': 0.24}
{'loss': 0.7638, 'learning_rate': 3.4000000000000007e-05, 'rewards/chosen': -0.923618495464325, 'rewards/rejected': -1.2611225843429565, 'rewards/accuracies': 0.5249999761581421, 'rewards/margins': 0.33750391006469727, 'logps/rejected': -145.17909240722656, 'logps/chosen': -112.1527328491211, 'logits/rejected': -106.68702697753906, 'logits/chosen': -106.28450012207031, 'epoch': 0.32}
{'loss': 0.7775, 'learning_rate': 3e-05, 'rewards/chosen': -1.0363740921020508, 'rewards/rejected': -1.6949989795684814, 'rewards/accuracies': 0.625, 'rewards/margins': 0.6586247682571411, 'logps/rejected': -183.61566162109375, 'logps/chosen': -133.77381896972656, 'logits/rejected': -108.0135726928711, 'logits/chosen': -107.72100830078125, 'epoch': 0.4}
{'eval_loss': 0.7498546838760376, 'eval_runtime': 8.838, 'eval_samples_per_second': 11.315, 'eval_steps_per_second': 1.471, 'eval_rewards/chosen': -0.16537557542324066, 'eval_rewards/rejected': -0.28140151500701904, 'eval_rewards/accuracies': 0.5480769276618958, 'eval_rewards/margins': 0.11602596193552017, 'eval_logps/rejected': -141.3417510986328, 'eval_logps/chosen': -114.46178436279297, 'eval_logits/rejected': -111.23967742919922, 'eval_logits/chosen': -110.92505645751953, 'epoch': 0.4}
{'loss': 0.7083, 'learning_rate': 2.6000000000000002e-05, 'rewards/chosen': -0.7516697645187378, 'rewards/rejected': -1.2517406940460205, 'rewards/accuracies': 0.625, 'rewards/margins': 0.500071108341217, 'logps/rejected': -139.60842895507812, 'logps/chosen': -103.69940185546875, 'logits/rejected': -109.14103698730469, 'logits/chosen': -108.06363677978516, 'epoch': 0.48}
{'loss': 0.7915, 'learning_rate': 2.2000000000000003e-05, 'rewards/chosen': -0.9171980619430542, 'rewards/rejected': -1.350841760635376, 'rewards/accuracies': 0.612500011920929, 'rewards/margins': 0.43364372849464417, 'logps/rejected': -157.98974609375, 'logps/chosen': -118.56217956542969, 'logits/rejected': -106.94883728027344, 'logits/chosen': -105.9101791381836, 'epoch': 0.56}
{'eval_loss': 0.7917570471763611, 'eval_runtime': 11.0, 'eval_samples_per_second': 9.091, 'eval_steps_per_second': 1.182, 'eval_rewards/chosen': -0.21202737092971802, 'eval_rewards/rejected': -0.2734401226043701, 'eval_rewards/accuracies': 0.5288461446762085, 'eval_rewards/margins': 0.0614127553999424, 'eval_logps/rejected': -141.26211547851562, 'eval_logps/chosen': -114.92829895019531, 'eval_logits/rejected': -108.82428741455078, 'eval_logits/chosen': -109.04977416992188, 'epoch': 0.6}
{'loss': 0.5332, 'learning_rate': 1.8e-05, 'rewards/chosen': -0.6318769454956055, 'rewards/rejected': -1.7545849084854126, 'rewards/accuracies': 0.6875, 'rewards/margins': 1.1227079629898071, 'logps/rejected': -181.09495544433594, 'logps/chosen': -116.23958587646484, 'logits/rejected': -106.14888763427734, 'logits/chosen': -105.6807632446289, 'epoch': 0.64}
{'loss': 0.6824, 'learning_rate': 1.4000000000000001e-05, 'rewards/chosen': -1.0098148584365845, 'rewards/rejected': -1.4946296215057373, 'rewards/accuracies': 0.637499988079071, 'rewards/margins': 0.4848148226737976, 'logps/rejected': -125.6796646118164, 'logps/chosen': -121.78050231933594, 'logits/rejected': -105.04019927978516, 'logits/chosen': -105.38191986083984, 'epoch': 0.72}
{'loss': 0.7617, 'learning_rate': 1e-05, 'rewards/chosen': -0.9180719256401062, 'rewards/rejected': -1.9269310235977173, 'rewards/accuracies': 0.5625, 'rewards/margins': 1.0088589191436768, 'logps/rejected': -174.40341186523438, 'logps/chosen': -111.41725158691406, 'logits/rejected': -105.29117584228516, 'logits/chosen': -103.72891998291016, 'epoch': 0.8}
{'eval_loss': 0.7312757968902588, 'eval_runtime': 10.7689, 'eval_samples_per_second': 9.286, 'eval_steps_per_second': 1.207, 'eval_rewards/chosen': -0.3517668843269348, 'eval_rewards/rejected': -0.5344953536987305, 'eval_rewards/accuracies': 0.5865384340286255, 'eval_rewards/margins': 0.18272852897644043, 'eval_logps/rejected': -143.87269592285156, 'eval_logps/chosen': -116.32568359375, 'eval_logits/rejected': -108.92617797851562, 'eval_logits/chosen': -109.03649139404297, 'epoch': 0.8}
{'loss': 0.6901, 'learning_rate': 6e-06, 'rewards/chosen': -0.8109361529350281, 'rewards/rejected': -1.3173637390136719, 'rewards/accuracies': 0.6000000238418579, 'rewards/margins': 0.5064277648925781, 'logps/rejected': -139.4099578857422, 'logps/chosen': -116.5167007446289, 'logits/rejected': -107.23912048339844, 'logits/chosen': -106.95794677734375, 'epoch': 0.88}
{'loss': 0.93, 'learning_rate': 2.0000000000000003e-06, 'rewards/chosen': -1.6479976177215576, 'rewards/rejected': -1.9559228420257568, 'rewards/accuracies': 0.574999988079071, 'rewards/margins': 0.3079252243041992, 'logps/rejected': -195.25650024414062, 'logps/chosen': -162.7464141845703, 'logits/rejected': -106.31349182128906, 'logits/chosen': -106.6192398071289, 'epoch': 0.96}
{'eval_loss': 0.7238949537277222, 'eval_runtime': 9.8789, 'eval_samples_per_second': 10.123, 'eval_steps_per_second': 1.316, 'eval_rewards/chosen': -0.3252321481704712, 'eval_rewards/rejected': -0.4981166124343872, 'eval_rewards/accuracies': 0.6057692170143127, 'eval_rewards/margins': 0.1728844791650772, 'eval_logps/rejected': -143.50889587402344, 'eval_logps/chosen': -116.06034851074219, 'eval_logits/rejected': -108.62019348144531, 'eval_logits/chosen': -108.63538360595703, 'epoch': 1.0}
{'train_runtime': 199.7196, 'train_samples_per_second': 5.007, 'train_steps_per_second': 1.252, 'train_loss': 0.7629560508728027, 'epoch': 1.0}                                           
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [03:15<00:00,  1.28it/s] 
wandb: Waiting for W&B process to finish... (success).
wandb: - 0.001 MB of 0.001 MB uploaded (0.000 MB deduped)
wandb: Run history:
wandb:             eval/logits/chosen ▁▄▇▇█
wandb:           eval/logits/rejected ▁▃█▇█
wandb:              eval/logps/chosen ▄█▆▁▂
wandb:            eval/logps/rejected █▆▆▁▂
wandb:                      eval/loss █▃▆▂▁
wandb:        eval/rewards/accuracies ▁▅▅▇█
wandb:            eval/rewards/chosen ▄█▆▁▂
wandb:           eval/rewards/margins ▁▆▅██
wandb:          eval/rewards/rejected █▆▆▁▂
wandb:                   eval/runtime ▃▁█▇▄
wandb:        eval/samples_per_second ▅█▁▂▄
wandb:          eval/steps_per_second ▅█▁▂▄
wandb:                    train/epoch ▁▂▂▂▃▃▃▄▅▅▅▆▆▆▇███
wandb:              train/global_step ▁▂▂▂▃▃▃▄▅▅▅▆▆▆▇███
wandb:            train/learning_rate █▇▇▆▅▅▄▄▃▂▂▁
wandb:            train/logits/chosen ▅▁▂▅▃▃▅▆▆█▄▅
wandb:          train/logits/rejected ▆▁▂▅▃▁▅▆██▄▆
wandb:             train/logps/chosen ▅▄▃▇▄█▆▇▆▇▆▁
wandb:           train/logps/rejected ▅▆▃▆▂▇▅▂█▃▇▁
wandb:                     train/loss ▇█▆▅▅▄▆▁▄▅▄█
wandb:       train/rewards/accuracies ▁▃▅▁▅▅▅█▆▃▄▃
wandb:           train/rewards/chosen ▆▄▅▆▅▇▆█▅▆▇▁
wandb:          train/rewards/margins ▁▃▃▃▅▄▃█▄▇▄▃
wandb:         train/rewards/rejected █▅▄▆▃▆▆▃▅▁▆▁
wandb:               train/total_flos ▁
wandb:               train/train_loss ▁
wandb:            train/train_runtime ▁
wandb: train/train_samples_per_second ▁
wandb:   train/train_steps_per_second ▁
wandb:
wandb: Run summary:
wandb:             eval/logits/chosen -108.63538
wandb:           eval/logits/rejected -108.62019
wandb:              eval/logps/chosen -116.06035
wandb:            eval/logps/rejected -143.5089
wandb:                      eval/loss 0.72389
wandb:        eval/rewards/accuracies 0.60577
wandb:            eval/rewards/chosen -0.32523
wandb:           eval/rewards/margins 0.17288
wandb:          eval/rewards/rejected -0.49812
wandb:                   eval/runtime 9.8789
wandb:        eval/samples_per_second 10.123
wandb:          eval/steps_per_second 1.316
wandb:                    train/epoch 1.0
wandb:              train/global_step 250
wandb:            train/learning_rate 0.0
wandb:            train/logits/chosen -106.61924
wandb:          train/logits/rejected -106.31349
wandb:             train/logps/chosen -162.74641
wandb:           train/logps/rejected -195.2565
wandb:                     train/loss 0.93
wandb:       train/rewards/accuracies 0.575
wandb:           train/rewards/chosen -1.648
wandb:          train/rewards/margins 0.30793
wandb:         train/rewards/rejected -1.95592
wandb:               train/total_flos 0.0
wandb:               train/train_loss 0.76296
wandb:            train/train_runtime 199.7196
wandb: train/train_samples_per_second 5.007
wandb:   train/train_steps_per_second 1.252
wandb:
wandb:  View run lunar-elevator-1274 at: [link]
wandb:  View job at [link]
wandb: Synced 6 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s)
wandb: Find logs at: [path]

It also produced the following W&B logs:
https://wandb.ai/tomaarsen/huggingface/reports/TRL-DPOTrainer-on-GPT2--Vmlldzo0OTY3MzAx
Crucially: [W&B dashboard screenshot]

What now?

I'd love some feedback: is it too much logging? In particular, do we need to log logps and logits?

Feedback & discussion are welcome, as always.

cc: @kashif

  • Tom Aarsen

@HuggingFaceDocBuilderDev commented Jul 25, 2023

The documentation is not available anymore as the PR was closed or merged.

Whoops, hadn't run `pre-commit install` yet
@younesbelkada (Contributor) left a comment:


Amazing work @tomaarsen , as always! Thanks very much for this detailed PR and very clean fix!

@lvwerra (Member) left a comment:


Thanks for the very clean PR!

@lvwerra lvwerra merged commit b3c2e73 into huggingface:main Jul 26, 2023
@soumyasanyal commented Jun 13, 2024

Hi @tomaarsen,

I'm trying to understand the external logging in DPO. I don't see where the logging object gets created. For instance, in PPO we can access the trackers as follows: ppo_trainer.accelerator.trackers. But the same is not defined for DPO. The only things I was able to figure out are that an Accelerator is not defined in the DPO code and that the DPO code inherits a different base class from the transformers library, Trainer. I checked inside that Trainer, but it does not have many details about the trackers.

My use case: I want to use the wandb logging object to push more things such as hyper-parameters to the dashboard. For PPO I would do the following and it worked:

wandb_tracker = ppo_trainer.accelerator.get_tracker('wandb', unwrap=True)
wandb_tracker.config.update({'exp_args': vars(args)})

I want to do something similar for DPO. Any pointers would be great. Thanks!


Update: I found a hack to make it work. Depending on the use case, the transformers callback list might contain other callbacks first, so the index (1) will have to be figured out for the specific case; a type-based lookup that avoids the index is sketched below.

dpo_trainer.callback_handler.callbacks[1]._wandb.init()
dpo_trainer.callback_handler.callbacks[1]._wandb.config.update({'exp_args': vars(args)})
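
A sketch of that type-based lookup (assuming report_to="wandb" so a WandbCallback is present, and that args is the experiment's argparse namespace as in the snippet above):

from transformers.integrations import WandbCallback

# Locate the W&B callback by type instead of relying on its position in the list.
wandb_callback = next(
    cb for cb in dpo_trainer.callback_handler.callbacks
    if isinstance(cb, WandbCallback)
)
# _wandb is the wandb module held by the callback; the run must already be
# initialized (e.g. training has started, or _wandb.init() was called) before
# updating the config.
wandb_callback._wandb.config.update({"exp_args": vars(args)})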

Successfully merging this pull request may close these issues:

  • DPOTrainer logging too frequent
  • support wandb log in dpo