Disable mlm by default in DataCollatorForCompletionOnlyLM, add ignore_index and docstring #476
Thanks @BramVanroy for fixing - I left two small comments!
@@ -81,7 +96,7 @@ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> D
                 response_token_ids_end_idx = response_token_ids_start_idx + len(response_token_ids)

                 # Make pytorch loss function ignore all tokens up through the end of the response key
-                labels[i, :response_token_ids_end_idx] = -100
+                labels[i, :response_token_ids_end_idx] = self.ignore_index
Not sure if this makes sense here, since -100 is the default value ignored by torch loss functions.
@lvwerra Indeed, it is the default, but it can be changed. So if users set a different ignore index in their loss function, they may want to use the same one in the trl library as well. I think it is always a good idea to allow customization with sensible defaults so that expert users have the chance to modify what they need.
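For illustration, a minimal pure-Python sketch of the masking step discussed above, assuming token ids are plain lists. `mask_prompt` is a hypothetical helper and the toy ids are made up; this is not trl's implementation. It shows that every label up to the end of the response template is replaced by whatever `ignore_index` is configured, so the collator and the loss function need to agree on that value.

```python
from typing import List


def mask_prompt(input_ids: List[int], response_token_ids: List[int], ignore_index: int = -100) -> List[int]:
    """Hypothetical helper: copy input_ids to labels and mask everything up to
    and including the response template with ignore_index."""
    labels = list(input_ids)
    n = len(response_token_ids)
    for start in range(len(input_ids) - n + 1):
        if input_ids[start:start + n] == response_token_ids:
            end = start + n
            # Mask the prompt and the template itself; keep the completion.
            labels[:end] = [ignore_index] * end
            return labels
    # Template not found: this sketch masks the whole example
    # (an assumption made here, not necessarily what trl does).
    return [ignore_index] * len(labels)


# Toy ids (assumed): 7, 8 stand in for the tokenized response template.
ids = [1, 2, 7, 8, 5, 6]
default_labels = mask_prompt(ids, [7, 8])
custom_labels = mask_prompt(ids, [7, 8], ignore_index=-1)
```

With the default, only the two completion tokens keep their labels; with `ignore_index=-1` the same positions are masked with -1 instead, which only makes sense if the loss is configured to ignore -1 too.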
trl/trainer/utils.py
Args:
    response_template (`str`): the template form that indicates the start of the response, typically something like
        '### Response:\n'
    mlm (`bool`, *optional*, defaults to `False`):
We don't support mlm, so I don't think we need to support it here. I would just hardcode it to False.
I agree, although I just followed the recommendation of @younesbelkada
If you both agree to disable it, I can hardcode it.
Let's go with the hardcoded variant. I'll discuss with @younesbelkada when he gets back from vacation in case he feels there is a use-case where someone wants to use it.
The documentation is not available anymore as the PR was closed or merged.
@lvwerra It doesn't seem straightforward to do the quality fixes automatically from my end. The makefile contains a
Did you install pre-commit?
pip install pre-commit
make precommit
trl/trainer/utils.py
    """

    def __init__(self, response_template: str, *args, ignore_index: int = -100, **kwargs):
        super().__init__(*args, mlm=False, **kwargs)
E.g. here we pass mlm explicitly, which is the reason the tests fail:
https://github.com/lvwerra/trl/blob/2b531b92238f9fc40c6cc9cf27b8bbd3ae2727b1/trl/trainer/sft_trainer.py#L170
We should check if mlm is a key in kwargs and only update with False if not.
Doesn't that make it the same as the initial PR (not hardcoded) but just less transparent?
Original:
def __init__(self, response_template: str, *args, mlm: bool = False, ignore_index: int = -100, **kwargs):
    super().__init__(*args, mlm=mlm, **kwargs)
Requested refactor:
def __init__(self, response_template: str, *args, ignore_index: int = -100, **kwargs):
    super().__init__(*args, mlm=kwargs.pop("mlm", False), **kwargs)
From a user/typing/IDE-completion perspective, I'd argue the first one is the better choice, no?
True, let's go with the original one. Sorry for going in a circle to find the best solution :)
No problem ;-) Re-added
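To make the failure mode in this thread concrete, here is a small self-contained sketch. `BaseCollator`, `HardcodedCollator` and `ExplicitCollator` are hypothetical stand-ins written for this example, not the real transformers or trl classes; the only assumption carried over from the discussion is that the parent defaults to `mlm=True`. The hardcoded variant raises a `TypeError` as soon as a caller also passes `mlm`, while the explicit-parameter variant accepts it and still defaults to `False`.

```python
class BaseCollator:
    """Stand-in for the parent collator (assumed to default to mlm=True)."""

    def __init__(self, tokenizer=None, mlm: bool = True):
        self.tokenizer = tokenizer
        self.mlm = mlm


class HardcodedCollator(BaseCollator):
    def __init__(self, response_template: str, *args, ignore_index: int = -100, **kwargs):
        # If the caller passes mlm=..., it is also in kwargs, so the parent
        # receives the keyword twice and Python raises a TypeError.
        super().__init__(*args, mlm=False, **kwargs)
        self.response_template = response_template
        self.ignore_index = ignore_index


class ExplicitCollator(BaseCollator):
    def __init__(self, response_template: str, *args, mlm: bool = False, ignore_index: int = -100, **kwargs):
        # mlm is a visible, typed parameter with a sensible default.
        super().__init__(*args, mlm=mlm, **kwargs)
        self.response_template = response_template
        self.ignore_index = ignore_index


try:
    HardcodedCollator("### Response:\n", mlm=False)
    hardcoded_error = None
except TypeError as exc:
    hardcoded_error = str(exc)

explicit = ExplicitCollator("### Response:\n", mlm=False)
default = ExplicitCollator("### Response:\n")
```

The `kwargs.pop("mlm", False)` variant would avoid the error as well, but it hides the parameter from the signature, which is the transparency concern raised above.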
As discussed here, DataCollatorForCompletionOnlyLM inherits from the MLM data collator, which means that mlm is set to True by default. In reality that does not matter (because the appropriate methods are overridden), but it can still be confusing for the user. So now mlm is explicitly set to False (with the option to be changed by the user).

The user now also has the option to decide which label index is ignored, via ignore_index.

A docstring is added for DataCollatorForCompletionOnlyLM, which was missing. I urge future PRs to always include good documentation and docstrings; otherwise the overview gets lost quickly.