PEFT support for Online DPO #2041

Merged: qgallouedec merged 31 commits into main from kaist on Sep 13, 2024

Conversation

qgallouedec (Member) commented Sep 9, 2024

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

HuggingFaceDocBuilderDev commented

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

qgallouedec (Member, Author) commented Sep 9, 2024

accelerate launch examples/scripts/dpo_online.py \
    --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \
    --reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \
    --dataset_name trl-lib/tldr \
    --output_dir pythia-1b-tldr-online-dpo \
    --learning_rate 5.0e-7 \
    --logging_first_step \
    --logging_steps 10 \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 8 \
    --warmup_ratio 0.1 \
    --missing_eos_penalty 1.0 \
    --push_to_hub \
    --dataset_num_proc 32 \
    --use_peft

PEFT stabilises training a lot!
Note that the orange run suffers from overoptimization here.

[Screenshot 2024-09-13 at 10:28:39: training curves]
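
For reference, TRL's example scripts typically turn the --use_peft flag into a LoRA config via get_peft_config. A minimal sketch of that wiring, which may differ slightly from what dpo_online.py actually does:

    # Sketch: how TRL example scripts typically map --use_peft to a peft_config.
    from trl import ModelConfig, get_peft_config

    model_config = ModelConfig(
        model_name_or_path="trl-lib/pythia-1b-deduped-tldr-sft",
        use_peft=True,  # same as passing --use_peft on the command line
    )

    # Returns a peft.LoraConfig when use_peft=True, otherwise None, so the
    # same script supports both full fine-tuning and LoRA training.
    peft_config = get_peft_config(model_config)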

@qgallouedec qgallouedec changed the title from "Further support of Online DPO" to "PEFT support for Online DPO" Sep 13, 2024
@qgallouedec qgallouedec marked this pull request as ready for review September 13, 2024 08:25
kashif (Collaborator) commented Sep 13, 2024

Can you kindly add the above PEFT/LoRA usage command in the script?

qgallouedec (Member, Author) commented

This one needs particular attention in review. I mostly didn't reuse the code from DPO because I don't understand it, so I'm afraid I may have missed something.

  1. Why do you need to merge and unload?

    # if model is a peft model and we have a peft_config, we merge and unload it first
    if isinstance(model, PeftModel):
        model = model.merge_and_unload()

  2. I don't support k-bit training yet

    if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
        _support_gc_kwargs = hasattr(
            args, "gradient_checkpointing_kwargs"
        ) and "gradient_checkpointing_kwargs" in list(
            inspect.signature(prepare_model_for_kbit_training).parameters
        )
        prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
        if _support_gc_kwargs:
            prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
        model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)

  3. Can we drop it for the new trainer?

        model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
    elif getattr(args, "gradient_checkpointing", False):
        # For backward compatibility with older versions of transformers
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

  4. Why do we need to cast? And why only when loaded in 4-bit? Btw, why always use getattr? We know that this arg exists in TrainingArguments.

    if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
        peft_module_casting_to_bf16(model)
        # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
        self._peft_has_been_casted_to_bf16 = True
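
The comment in the snippet above refers to a pattern like the following. A minimal sketch, assuming a trainer context where inputs is a dict of tensors already on the right device (not the trainer's actual code):

    import contextlib

    import torch

    # Sketch: when the PEFT model has been cast to bf16, wrap generation in
    # an autocast context so activations and weights use matching dtypes.
    maybe_autocast = (
        torch.autocast("cuda", dtype=torch.bfloat16)
        if self._peft_has_been_casted_to_bf16
        else contextlib.nullcontext()
    )
    with maybe_autocast:
        output = model.generate(**inputs)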

Overall, I've chosen to go with code that I understand, even if that means reintroducing bugs that were fixed in the past for other trainers.

lewtun (Member) left a comment

Thanks for adding PEFT support! Overall LGTM - have you run an experiment on e.g. TLDR to see if it looks OK?

Edit: sorry, just saw your comment!

[4 review threads on examples/scripts/dpo_online.py, all resolved; 1 outdated]
qgallouedec and others added 3 commits September 13, 2024 12:24
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
lewtun (Member) commented Sep 13, 2024

  1. Why do you need to merge and unload?

This is needed if we have e.g. an SFT LoRA that needs to be merged into the base model, before finally inserting an adapter for DPO. See here for some considerations.
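
A minimal sketch of that flow with PEFT (the adapter path and LoRA hyperparameters below are illustrative, not from this PR):

    from peft import AutoPeftModelForCausalLM, LoraConfig, get_peft_model

    # Load the base model together with its SFT LoRA adapter, then fold the
    # adapter weights into the base weights so they become the new reference.
    model = AutoPeftModelForCausalLM.from_pretrained("my-org/sft-lora-adapter")  # illustrative path
    model = model.merge_and_unload()

    # Attach a fresh, randomly initialized adapter that DPO will train.
    dpo_config = LoraConfig(r=16, lora_alpha=32, task_type="CAUSAL_LM")
    model = get_peft_model(model, dpo_config)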

  2. I don't support k-bit training yet

Sure, let's leave this for a follow-up PR.

  3. Can we drop it for the new trainer?

Fine with me. We can add it later if people open an issue.

  4. Why do we need to cast? And why only when loaded in 4-bit? Btw, why always use getattr? We know that this arg exists in TrainingArguments.

Good question, this was added by Younes but maybe @BenjaminBossan can help clarify here why PEFT models need upcasting in 4-bit?

BenjaminBossan (Member) commented

4. Why do we need to cast? And why only when loaded in 4-bit? Btw, why always use getattr? We know that this arg exists in TrainingArguments.

Good question, this was added by Younes but maybe @BenjaminBossan can help clarify here why PEFT models need upcasting in 4-bit?

The original addition of this function was from #1110 (which in turn references this repo). I don't know the exact context, but I think the reason is to avoid indiscriminately casting all layers to bf16 -- specifically, layer norm stays in float32. This is probably based on some empirical finding that this is better for training, but after skimming the QLoRA paper, I could not find any mention of that, so I'm unsure.
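
A minimal sketch of that idea (not TRL's actual peft_module_casting_to_bf16): cast the LoRA adapter weights to bf16 while keeping norm layers in float32 for numerical stability:

    import torch
    from torch import nn

    # Sketch: selectively cast adapter weights to bf16, keep norms in fp32.
    def cast_peft_modules_to_bf16(model: nn.Module) -> None:
        for name, module in model.named_modules():
            if isinstance(module, nn.LayerNorm) or "norm" in name:
                module.to(torch.float32)
            elif "lora" in name:
                module.to(torch.bfloat16)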

qgallouedec (Member, Author) commented

[Screenshot 2024-09-13 at 16:20:58]

qgallouedec (Member, Author) commented

Thanks a lot @BenjaminBossan and @lewtun. I'll investigate further in a dedicated branch.

@qgallouedec qgallouedec merged commit ebc85b2 into main Sep 13, 2024
10 checks passed
@qgallouedec qgallouedec deleted the kaist branch September 13, 2024 17:15