
EOS token processing for multi-turn DPO #741

Merged 6 commits from dpo_token_fix into main on Sep 12, 2023
Conversation

@natolambert (Contributor) commented Sep 5, 2023

Instead of asserts, mask out EOS tokens in the attention mask (see the sketch below).
My dev setup for TRL is out of date; will fix the pre-commit stuff.
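For context, a minimal sketch of what the masking amounts to; the names (prompt_tokens, tokenizer) are illustrative and this is not the exact code in the PR:

# Illustrative sketch only: zero out the attention mask wherever the prompt
# already contains an EOS token, so multi-turn prompts with </s> separators
# no longer trip an assert.
eos_token_id = tokenizer.eos_token_id

prompt_tokens["attention_mask"] = [
    0 if token_id == eos_token_id else mask
    for token_id, mask in zip(prompt_tokens["input_ids"], prompt_tokens["attention_mask"])
]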

CC @kashif

Nathan Lambert added 2 commits September 5, 2023 14:51
@HuggingFaceDocBuilderDev commented Sep 5, 2023

The documentation is not available anymore as the PR was closed or merged.

@natolambert requested a review from lvwerra on September 7, 2023 at 14:50
@kashif (Collaborator) commented Sep 7, 2023

@natolambert the recent seq-2-seq PR might have caused some merge conflicts

@natolambert (Contributor, Author) commented

kk @kashif and @lvwerra, the merge conflict should be fixed now. Will double-check via my testing with H4 this afternoon!

@younesbelkada (Contributor) left a comment

Thanks a lot for adding the EOS token support for DPO!
My tiny suggestion would be to replace the indices logic with something like:

attention_mask = torch.Tensor(prompt_tokens["input_ids"]).ne(eos_token_id)
prompt_tokens["attention_mask"] = attention_mask.long().tolist()

This already looks great though! Feel free to merge as is if you prefer your approach (mine requires first converting to a torch tensor and then converting the attention mask back to a list).
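As a quick, hedged sanity check of that one-liner on a made-up example (the token ids and eos_token_id below are purely illustrative):

import torch

# Toy example: pretend 2 is the EOS token id.
prompt_tokens = {"input_ids": [5, 9, 2, 7, 2]}
eos_token_id = 2

attention_mask = torch.Tensor(prompt_tokens["input_ids"]).ne(eos_token_id)
print(attention_mask.long().tolist())  # [1, 1, 0, 1, 0] -> EOS positions masked out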

@natolambert (Contributor, Author) commented Sep 12, 2023

Merging this as a starting point; expect more DPO improvements and PRs soon!
I forget if I'm supposed to do that with TRL 😅 let me know.

@natolambert merged commit 9141aa4 into main on Sep 12, 2023
@natolambert deleted the dpo_token_fix branch on September 12, 2023 at 16:49
kushal-tri pushed a commit to kushalarora/trl that referenced this pull request on Sep 19, 2023 (commit message: init; fix; add doc; style; clarify example)
@robertgshaw2-neuralmagic commented Nov 12, 2023

@natolambert

I am just curious: what is the reason for setting the attention mask to 0 at all positions with eos_token_id?

My thinking is that once these models have been aligned with DPO, they will typically be used with chat templates (as in the Zephyr model card: https://huggingface.co/HuggingFaceH4/zephyr-7b-beta#intended-uses--limitations).

In that model card, the eos_token is present in the prompt for both single- and multi-turn generation:

import torch
from transformers import pipeline

pipe = pipeline("text-generation", model="HuggingFaceH4/zephyr-7b-beta", torch_dtype=torch.bfloat16, device_map="auto")

# We use the tokenizer's chat template to format each message - see https://huggingface.co/docs/transformers/main/en/chat_templating
messages = [
    {
        "role": "system",
        "content": "You are a friendly chatbot who always responds in the style of a pirate",
    },
    {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
]
prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
outputs = pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
print(outputs[0]["generated_text"])
# <|system|>
# You are a friendly chatbot who always responds in the style of a pirate.</s>
# <|user|>
# How many helicopters can a human eat in one sitting?</s>
# <|assistant|>
# Ah, me hearty matey! But yer question be a puzzler! A human cannot eat a helicopter in one sitting, as helicopters are not edible. They be made of metal, plastic, and other materials, not food!

So during inference, the eos_token_ids will have attention_mask=1. Shouldn't we mirror this during training?
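For reference, a small hedged sketch of the point being made here, reusing the messages list from the snippet above: by default the tokenizer assigns attention_mask=1 to every token it keeps, including the </s> separators in the chat-templated prompt.

from transformers import AutoTokenizer

# Sketch: inspect the attention mask the tokenizer produces by default for a
# chat-templated prompt (reuses `messages` from the example above).
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
encoded = tokenizer(prompt)

eos_positions = [i for i, t in enumerate(encoded["input_ids"]) if t == tokenizer.eos_token_id]
print([encoded["attention_mask"][i] for i in eos_positions])  # all 1s by default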

@natolambert (Contributor, Author) commented Nov 12, 2023

Hey @rsnm2 - it was really a hack to make things work at all; there was some technical issue during training that made the default masking not work. It would be nice to revisit it.

@robertgshaw2-neuralmagic commented

Hey @natolambert - makes sense

I think the solution would just be to remove all of this logic (instead leaving attention_mask=1 for the eos_token_ids, which is the default in tokenizers), unless there is a good reason to ignore the eos_tokens during training (I don't think there is).

@natolambert (Contributor, Author) commented

You should try reverting it and playing with it. IIRC the Transformers Trainer errors out with no processing / truncation.
@rsnm2

lapp0 pushed a commit to lapp0/trl that referenced this pull request on May 10, 2024 (commit message: init; fix; add doc; style; clarify example)