⏩ Train on completion only #3329
Merged
+94 −157
10 commits
538ec20 Train on completion only (qgallouedec)
34b7017 a bit of documentation (qgallouedec)
3a2788b Merge branch 'fix-add_special_tokens' into train-completion-only (qgallouedec)
9821faa minor refinement and fix test (qgallouedec)
5ea8495 now it's fixed! (qgallouedec)
d8b3e47 allow none (qgallouedec)
a5dd899 minors (qgallouedec)
0de43b8 title (qgallouedec)
7f7f2a4 update example (qgallouedec)
eb9d7d1 Merge branch 'fix-add_special_tokens' into train-completion-only (qgallouedec)
@@ -78,6 +78,9 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
    Args:
        pad_token_id (`int`):
            Token ID to use for padding.
        completion_only_loss (`bool`, *optional*, defaults to `True`):
            When the input contains a completion mask (`completion_mask`), the labels are set to -100 for the tokens
            that are not in the completion.
        return_tensors (`str`, *optional*, defaults to `"pt"`):
            Type of Tensor to return. Only `"pt"` is currently supported.
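As an aside on the `-100` sentinel used throughout this diff: it is the default `ignore_index` of PyTorch's cross-entropy, so masked positions simply contribute nothing to the loss. A minimal sketch (not part of the diff):

```python
import torch
import torch.nn.functional as F

# -100 is cross_entropy's default ignore_index, so masked positions are skipped.
logits = torch.randn(1, 3, 10)          # (batch, seq_len, vocab_size)
labels = torch.tensor([[-100, 2, 3]])   # first position masked out
loss = F.cross_entropy(logits.view(-1, 10), labels.view(-1))
```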
@@ -90,29 +93,47 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
    ...     {"input_ids": [4, 5]}
    ... ]
    >>> collator(examples)
    {'input_ids': tensor([[ 1, 2, 3],
                          [ 4, 5, 0]]),
     'attention_mask': tensor([[ 1, 1, 1],
                               [ 1, 1, 0]]),
     'labels': tensor([[ 1, 2, 3],
                       [ 4, 5, -100]])}
    >>> # With completion mask
    >>> examples = [
    ...     {"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]},
    ...     {"input_ids": [4, 5], "completion_mask": [0, 1]}
    ... ]
    >>> collator(examples)
    {'input_ids': tensor([[ 1, 2, 3],
                          [ 4, 5, 0]]),
     'attention_mask': tensor([[ 1, 1, 1],
                               [ 1, 1, 0]]),
     'labels': tensor([[-100, 2, 3],
                       [-100, 5, -100]])}
    ```
    """
    pad_token_id: int
    completion_only_loss: bool = True
    return_tensors: str = "pt"

    def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
        # Convert to tensor
        input_ids = [torch.tensor(example["input_ids"]) for example in examples]
        attention_mask = [torch.ones_like(input_ids) for input_ids in input_ids]
        labels = [torch.tensor(example["input_ids"]) for example in examples]
        if self.completion_only_loss and "completion_mask" in examples[0]:
            completion_mask = [torch.tensor(example["completion_mask"]) for example in examples]

        # Pad
        output = {}
        output["input_ids"] = pad(input_ids, padding_value=self.pad_token_id, padding_side="right")
        output["attention_mask"] = pad(attention_mask, padding_value=0, padding_side="right")
        output["labels"] = pad(labels, padding_value=-100, padding_side="right")
        if self.completion_only_loss and "completion_mask" in examples[0]:
            completion_mask = pad(completion_mask, padding_value=0, padding_side="right")
            output["labels"][completion_mask == 0] = -100  # mask everything that is not in the completion

        return output
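For reference, a minimal sketch of the masking the collator performs, matching the doctest above (plain torch, without TRL's `pad` helper):

```python
import torch

# Labels start as a copy of the right-padded input_ids; every position whose
# completion mask is 0 (prompt tokens and padding) is set to -100.
input_ids = torch.tensor([[1, 2, 3], [4, 5, 0]])        # second row padded with pad_token_id=0
completion_mask = torch.tensor([[0, 1, 1], [0, 1, 0]])  # 0 = prompt/padding, 1 = completion
labels = input_ids.clone()
labels[completion_mask == 0] = -100
print(labels)  # -> [[-100, 2, 3], [-100, 5, -100]]
```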
@@ -278,6 +299,11 @@ def __init__(
            )
            data_collator = DataCollatorWithFlattening()

        if args.completion_only_loss is None:
            first_example = next(iter(train_dataset))
            self.completion_only_loss = "prompt" in first_example
        else:
            self.completion_only_loss = args.completion_only_loss
        if data_collator is None:
            # Get the pad token: if not provided, use the one from the processing class or the eos token
            # if the processing class does not have a pad token.
@@ -289,7 +315,7 @@ def __init__(
                    f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
                    "in the vocabulary before using it as a padding token."
                )
            data_collator = DataCollatorForLanguageModeling(pad_token_id, self.completion_only_loss)

        # Dataset
        preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False)
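A hypothetical end-to-end usage sketch of this behaviour (the model name and `output_dir` are placeholders, not taken from the PR): because the dataset has a "prompt" column and `completion_only_loss` is left unset, the trainer enables completion-only loss automatically.

```python
from datasets import Dataset
from trl import SFTConfig, SFTTrainer

# Prompt-completion dataset: the presence of a "prompt" column, together with
# args.completion_only_loss being None, turns completion-only loss on.
train_dataset = Dataset.from_dict(
    {"prompt": ["The capital of France is"], "completion": [" Paris."]}
)

args = SFTConfig(output_dir="sft-completion-only")  # completion_only_loss left at its default (None)
trainer = SFTTrainer(model="Qwen/Qwen2.5-0.5B", args=args, train_dataset=train_dataset)
trainer.train()
```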
@@ -500,16 +526,6 @@ def _func(example):
            )
            dataset = dataset.map(_func, batched=True, **map_kwargs)

        # If the dataset is prompt-completion, convert it to language modeling type
        first_example = next(iter(dataset))
        if "prompt" in first_example.keys() and "completion" in first_example.keys():
            key = "messages" if is_conversational(first_example) else "text"

            def concat_prompt_completion(example):
                return {key: example["prompt"] + example["completion"]}

            dataset = dataset.map(concat_prompt_completion, remove_columns=["prompt", "completion"])
Review comment on lines -503 to -512: This concatenation needs to be removed, as we lose the information about where the completion starts. Joining the prompt and completion is now handled during tokenization, where the completion mask is built.
        if not is_processed:
            # Convert the dataset to ChatML if needed
            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
@@ -560,14 +576,38 @@ def add_eos(example, eos_token):
            # See https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens
            add_special_tokens = True

        # Tokenize the dataset
        if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
            map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"

        def tokenize(example, processing_class, dataset_text_field, add_special_tokens):
            if "prompt" in example:  # prompt-completion case
                processed_prompt = processing_class(
                    text=example["prompt"],
                    add_special_tokens=add_special_tokens,
                )
                processed = processing_class(
                    text=example["prompt"] + example["completion"], add_special_tokens=add_special_tokens
                )

                # Check if the tokenized prompt starts with the tokenized prompt+completion
                prompt_ids = processed_prompt["input_ids"]
                prompt_completion_ids = processed["input_ids"]
                if not prompt_completion_ids[: len(prompt_ids)] == prompt_ids:
                    warnings.warn(
                        "Mismatch between tokenized prompt and the start of tokenized prompt+completion. "
                        "This may be due to unexpected tokenizer behavior, whitespace issues, or special "
                        "token handling. Verify that the tokenizer is processing text consistently."
                    )

                # Create a completion mask
                completion_mask = [0] * len(prompt_ids) + [1] * (len(prompt_completion_ids) - len(prompt_ids))
                processed = {**processed, "completion_mask": completion_mask}
            else:  # language modeling case
                processed = processing_class(
                    text=example[dataset_text_field], add_special_tokens=add_special_tokens
                )
            return processed

        dataset = dataset.map(
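A small worked example of the mask construction above (the token ids are made up):

```python
# Tokenized prompt and tokenized prompt+completion (hypothetical ids).
prompt_ids = [101, 7592, 2088]                        # 3 prompt tokens
prompt_completion_ids = [101, 7592, 2088, 999, 102]   # 5 tokens total

# The prompt tokenization should be a prefix of the joint tokenization;
# otherwise `tokenize` emits the warning shown above.
assert prompt_completion_ids[: len(prompt_ids)] == prompt_ids

completion_mask = [0] * len(prompt_ids) + [1] * (len(prompt_completion_ids) - len(prompt_ids))
print(completion_mask)  # [0, 0, 0, 1, 1]
```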
@@ -598,6 +638,14 @@ def tokenize(example, processing_class, dataset_text_field, add_special_tokens):

        return dataset

    def _set_signature_columns_if_needed(self):
        # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
        # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids"
        # and "attention_mask"). When using `train_on_completion_only` we add a "completion_mask" column to the
        # dataset. So we need to override the default signature columns to include "completion_mask" as well.
        if self._signature_columns is None:
            self._signature_columns = ["input_ids", "attention_mask", "completion_mask"]

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Compute training loss and additionally compute token accuracies
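Roughly why the override is needed (an illustration of the effect, not the Trainer internals): with `remove_unused_columns=True` the Trainer drops every dataset column not listed in `_signature_columns`, so `completion_mask` has to be whitelisted to ever reach the collator.

```python
signature_columns = ["input_ids", "attention_mask", "completion_mask"]
example = {"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1], "text": "raw text"}

# Columns outside the signature list would be removed before collation.
kept = {k: v for k, v in example.items() if k in signature_columns}
print(kept)  # {'input_ids': [1, 2, 3], 'completion_mask': [0, 1, 1]}
```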
Review comment: This is not related to the core change of this PR. With the new serialisation logic of `TrainingArguments`, passing a wrong dtype fails when you instantiate the `TrainingArguments`, so there is no need for such a test anymore.