Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable mlm by default in DataCollatorForCompletionOnlyLM, add ignore_index and docstring #476

Merged
merged 4 commits into from
Jul 6, 2023
Merged

Conversation

BramVanroy
Copy link
Contributor

As discussed here, DataCollatorForCompletionOnlyLM inherits from the MLM data collator but that means that mlm is set to True by default. In reality that does not matter (because the appropriate methods are overridden) but it can still be confusing for the user. So now mlm is explicitly set to False (with the option to be changed by the user).

The user now also has the option to decide which index to use to ignore ignore_index.

A docstring is added for DataCollatorForCompletionOnlyLM, which was missing. I urge future PRs to always include good documentation and docstrings, otherwise the overview gets lost quickly.

Copy link
Member

@lvwerra lvwerra left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @BramVanroy for fixing - I left two small comments!

@@ -81,7 +96,7 @@ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> D
response_token_ids_end_idx = response_token_ids_start_idx + len(response_token_ids)

# Make pytorch loss function ignore all tokens up through the end of the response key
labels[i, :response_token_ids_end_idx] = -100
labels[i, :response_token_ids_end_idx] = self.ignore_index
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if this makes sense here, since -100 is the default value ignored by torch loss functions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@lvwerra Indeed, it is the default but can be changed. Soif users set a different ignore index, they may want to use a different one in the trl library as well. I think it is always a good idea to allow customization with sensible defaults so that expert users have the chance to modify what they need.

Args:
response_template (`str`): the template form that indicates the start of the response, typically something like
'### Response:\n'
mlm (`bool`, *optional*, defaults to `False`):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't support mlm so I don't think we need to support it here. so i would just hardcode it to False

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, although I just followed the recommendation of @younesbelkada

#445 (comment)

If you both agree to disable it, I can hardcode it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's go with the hardcoded variant. I'll discuss with @younesbelkada when he gets back from vacation in case he feels there is a use-case where someone wants to use it.

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Jul 3, 2023

The documentation is not available anymore as the PR was closed or merged.

@BramVanroy
Copy link
Contributor Author

@lvwerra It doesn't seem straightforward to do the quality fixes automatically from my end. The makefile contains a precommit key but that seems specific to Github (with the related yaml file). Unlike transformers I am missing a make style. The CONTRIBUTING guide mentions make commit but that command does not exist.

@lvwerra
Copy link
Member

lvwerra commented Jul 4, 2023

Did you install the pre-commit library?

pip install pre-commit
make precommit

"""

def __init__(self, response_template: str, *args, ignore_index: int = -100, **kwargs):
super().__init__(*args, mlm=False, **kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

E.g. here we pass the mlm explicitly which is the reason the tests fail:
https://github.com/lvwerra/trl/blob/2b531b92238f9fc40c6cc9cf27b8bbd3ae2727b1/trl/trainer/sft_trainer.py#L170

We should check if mlm is a key in kwargs and only update with False if not.

Copy link
Contributor Author

@BramVanroy BramVanroy Jul 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't that make it the same as the initial PR (not hardcoded) but just less transparent?

Original:

    def __init__(self, response_template: str, *args, mlm: bool = False, ignore_index: int = -100, **kwargs):
        super().__init__(*args, mlm=mlm, **kwargs)

Requested refactor:

    def __init__(self, response_template: str, *args, ignore_index: int = -100, **kwargs):
        super().__init__(*args, mlm=kwargs.pop("mlm", False), **kwargs)

From a user/typing/IDE completion, I'd argue the first one is the better choice, no?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, let's go with the original one. Sorry for going in a circle to find the best solution :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No problem ;-) Re-added

@lvwerra lvwerra merged commit 25d4d81 into huggingface:main Jul 6, 2023
@BramVanroy BramVanroy deleted the mlm_collator branch July 6, 2023 08:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants