-
Notifications
You must be signed in to change notification settings - Fork 28k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enabling users to provide their own stopping_criteria
+ logits_processor
to generate
.
#12219
Enabling users to provide their own stopping_criteria
+ logits_processor
to generate
.
#12219
Conversation
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored. |
@patrickvonplaten (Not urgent, get some rest :)) |
Sorry for the late reply here @Narsil - I'm happy with the PR I think :-) If we could add a test that would be great |
`logits_processor` to `generate`.
463a1ba
to
5e5db45
Compare
stopping_criteria
+ logits_processor
to generate
.stopping_criteria
+ logits_processor
to generate
.
@patrickvonplaten Should I merge this ? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for adding this :)
@patrickvonplaten do you want to take a look again ?
src/transformers/generation_utils.py
Outdated
) -> StoppingCriteriaList: | ||
stopping_criteria = StoppingCriteriaList() | ||
if stopping_criteria is None: | ||
stopping_criteria = StoppingCriteriaList() | ||
if max_length is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
max_length always defaults to 20
-> so if someone passes a stopping_criteria
list then there are two stopping criteria no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not good no? E.g. if the stopping criteria is 30 in the list, the generation will still stop at 20. So IMO if someone passes a stopping_criteria
list we should check for each item if the class already exsits in the list (if it's not the case, only then we'll add it). This means that the priority is as follows:
1st priority: stopping_criteria
2nd priority: directly passing max_length
3rd priority: using max_length
of config
4th priority: using default max_length
=> think the same should hold true for logits_processor
.
Think we should not merge as it is right now
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Think the workflow is not optimal at the moment -> see comment here: https://github.com/huggingface/transformers/pull/12219/files#r705307868
Keen to hear your opinion @Narsil
I think we shouldn't check anything. If you defined something we pass it |
But also happy to drop the PR, the issue didn't seem to generate that much traction. |
I think it would be nice to merge the PR, but it just doesn't make much sense to me that a default, always-defined value like |
So, you think, we should if logits_processor is None:
logist_processort = self._get_logits_process(...) instead ? Make sense. |
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored. |
Leaving it as closed for now - reopening in case the community expresses interest in this PR again... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can't approve since this is my PR, but it LGTM.
if isinstance(stopping_criterion, MaxLengthCriteria): | ||
if max_length is not None: | ||
warnings.warn( | ||
"A stopping criteria of type `MaxLengthCriteria` as well as `max_length` was passed to `generate`. The `MaxLengthCriteria` will be used.", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's not necessarily true.
On line 889, we can defined max_length
to self.config.max_length
even if it's not user defined.
At least we override here which makes behavior what a user would expect IMO (respecting it's own defined things)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh that's right, I missed that one. I think the stopping_criteria
should be moved before that. Otherwise you will always get a warning if neither max_length is None and max_new_tokens is None
which would be the standard case when using a custom MaxLengthCriteria
. What do you think?
Thanks a lot for taking this over @lvwerra ! Let me know if you need any help with the remaining tests |
if stopping_criteria is None: | ||
stopping_criteria = StoppingCriteriaList() | ||
max_length_in_criteria = any([isinstance(criteria, MaxLengthCriteria) for criteria in stopping_criteria]) | ||
if max_length is not None and not max_length_in_criteria: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm fine with this I think -> it means that stopping_criteria
is always more important than max_length
which sounds good to me
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually thinking more about I think it would be cleaner to just have the following logic:
if stopping_criteria
is provided than only this list is used and nothing else. In a first step this allows us some nasty nested use cases. IMO we could just check if len(stopping_criteria) > 0
and if this is the case we don't even call the function _get_stopping_criteria
. IMO someone that uses that functionality understands generate()
quite well and doesn't need much magic under-the-hood.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you think @Narsil ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's an option that is definitely viable.
The core things I think are important:
- If a user specified something, we need to respect it
- If something comes as a default it cannot override anything user specified.
- If user specification are unclear/unsound, yell loud and clear about what is going on, and what the code is going to do to save the
generate
.
Option 1 (current):
model.generate(...., no_repeat_n_tokens=3)
I need to add some even more clever functionality
mylogitsProcessor=LogitsProcessorList(MyLogits())
model.generate(...., logits_processor=mylogitsProcessor, no_repeat_n_tokens=3)
(You can keep the easiness of generate
).
Option 2 (logits_processor is a full override):
model.generate(..., no_repeat_n_tokens=3)
becomes
my_logits_processor = LogitsProcessorList(
NoRepeatLogitsProcessor(3),
MyLogits()
)
model.generate(..., logits_processor=logits_processor)
Option 1 has the advantage that we can keep some options simpler to use and still add some custom logits processor, if we're careful enough that no non-user defined variable can ever override the logits_processor
in a hidden way, then we're good to go. (Only user defined arguments are able to modify). This is ofc a guarantee that might be tricky to keep in the future, so we are taking a risk of silently breaking things. It also makes user code trickier to understand since there is no ONE way to define logits_processor, so it might lead to hard to understand behavior.
Option 2 has the advantage that it has a single point of definition of logits processor. The disadvantage is that it requires more changes on the user the first time he uses this variable. We also need to yell quite strongly that we're simply ignoring every other argument, which might not be obvious (we could even crash at this point, since it's a really easy to overlook thing and will definitely yield poor results).
IMHO, both are really fine, we just need to stick to one option. I suggested to @lvwerra Option 1, because I thought that max_length
was the only variable that was always defined even if not user supplied. If this assumption is wrong and hard to make on a small list of variable, then I don't think we should stick with Option 1. Option 2 does have drawbacks as we're rejecting complexity back into user code instead of absorbing it like we're doing right now. But it does help separation of concern.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the great write-up! Actually there are a bunch of things that are always defined (top_k is always defined e.g.) but then also lots of models always define num_beams
in their config alongside other parameters so I would very much prefer option 2 here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@lvwerra are you ok switching to Option 2, seems my assumption was incorrect :( ?
max_length: Optional[int], | ||
max_time: Optional[float], | ||
max_new_tokens: Optional[int], | ||
start_length: int, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
start_length is not needed no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is max_new_tokens
needed for? I think this PR was from quite some time ago where we had this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I think they come from my attempt to merge the main branch into this one. I'll fix this.
) -> LogitsProcessorList: | ||
""" | ||
This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant | ||
:obj:`~transformers.LogitsProcessor` instances used to modify the scores of the language model head. | ||
""" | ||
processors = LogitsProcessorList() | ||
|
||
if logits_processor is None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here I also think we shouldn't even call the function if logits_processor
is used. It makes our live much easier and the design a bit cleaner.
What I don't like about the current design is that:
- if say both
forced_eos_token_id
is provided ingenerate()
and as an input at the moment, then we have twoforced_eos_token_id
processors in the list which leads to weird behavior...
A first simple solution is to just not call this function IMO. We could always later adapt it for more advanced functionality if needed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One thing to keep in mind is that we essentially deactivate a whole bunch of options without much transparency about it so a user would need to know which options correspond to a logits_processor
. E.g. I just had to look up if temperature
would be affected if I added a logits_processor
. Same with stopping_criteria
.
The other way around it is more transparent I think if something goes wrong: Ah I passed forced_eos_token_id
and a logits_processor
doing the same thing, maybe that is not good.
For me the main use-case is to add a custom processer/criteria in addition to what I can already do with the kwargs. Otherwise, why go the extra step of using the arguably more complicated API if you can just pass it as an argument.
What do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's an argument in favour of @Narsil's option 1.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a good argument @lvwerra and I fully understand what you mean.
I'm however really concerned about the complexity that option 1 adds for IMO very few use cases. Also from a backwards breaking point of view it's pretty much impossible to go from option 1 to option 2 in the future if this feature becomes more important where as it's much easier to go from option 2 to option 1 in the future.
For me the main use-case is to add a custom processer/criteria in addition to what I can already do with the kwargs. Otherwise, why go the extra step of using the arguably more complicated API if you can just pass it as an argument.
Very good point and that's the big drawback of option 2. I believe that people having to use special logits processors are able to create it themselves.
Ok, how about we do something in between option 1 and option 2 that doesn't create a crazy complex logic.
If one passes a logits processor, we do the following:
- just create the normal logits_processor that would have been created without passing one.
- if any object of the passed logits processor is in the already created logits processor then we raise an error and tell the user which logits processor was created twice (and ideally which paramter has to be changed for this)
- If there is no error make a union of the two lists and we're good to go.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like this compromise, what do you think @Narsil?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Think we can go for the solution if you want @lvwerra :-)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can do that.
@@ -792,6 +806,12 @@ def generate( | |||
crash. Note that using ``remove_invalid_values`` can slow down generation. | |||
synced_gpus (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3) | |||
logits_processor (:obj:`LogitsProcessorList`, `optional`): | |||
This object is created automatically from other arguments of this function. `logits_processor` is meant |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not a huge fan of the docstring here - could we maybe rephrase it a bit. I think the user is mostly interested in what happens when this object is passed and how it can be passed - not really what happens when it is not passed.
So maybe something more like:
This object is created automatically from other arguments of this function. `logits_processor` is meant | |
If provided `logits_processor` will overwrite all passed arguments that can process logits as well as those saved in the model's config. It can be very useful to enable custom logits processing logic. |
We should also note somewhere that this is an experimental feature IMO
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
keen to hear your input here as well @Narsil . IMO this feature is really not for the "unexperienced" HF user but for the advanced ones that know more or less what happens under the hood in generate()
(otherwise why need costum logit processors?). To give this functionality while keeping complexity at a minimum I think the best first step is to simply say:
If we pass logits_processor
or stopping_criteria
it will overwrite everything else...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's what I named Option 2
above. While I think it's viable I don't think it's the only way.
Option 1
which is ("we're adding your stuff too without looking") is also perfectly viable. Let's continue discussion above since I think that's where the main point is, no ?
This object is created automatically from other arguments of this function. `logits_processor` is meant | ||
to be used to add another layer with custom logic. | ||
stopping_criteria (:obj:`StoppingCriteriaList`, `optional`): | ||
This object is created automatically from other arguments of this function. `stopping_criteria` is |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here
Superseeded by #14779 (comment) |
What does this PR do?
Fixes #12118
Fixes # (issue)
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.