Fix kwargs handling in generate_with_fallback #29225
Conversation
Hey @cifkao, thanks for the great work here, it's a nice catch.
The fix seems okay to me. I don't think we have a way to test whether it works, otherwise I'd have asked you to!
@sanchit-gandhi could we have your review here as well?
generate_kwargs = dict(kwargs)
for key in ["do_sample", "temperature", "num_beams"]:
temperature shouldn't be in kwargs as it's already an argument of .generate here, right? It seems okay to check for do_sample and num_beams here.
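For context, a minimal sketch of the pattern the quoted hunk belongs to (my reading, not the exact transformers source): the fallback loop has already written its per-retry values of these parameters onto generation_config, so the stale copies in kwargs are dropped from a shallow copy before it is splatted into super().generate(), where they would otherwise override generation_config:

generate_kwargs = dict(kwargs)  # shallow copy; the caller's kwargs stays intact
for key in ["do_sample", "temperature", "num_beams"]:
    if key in generate_kwargs:
        del generate_kwargs[key]  # generation_config already carries these values
# later, roughly: super().generate(..., generation_config, **generate_kwargs)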
Just wanted to be extra cautious here and make sure everything is safe locally, rather than relying on what gets passed down from 2 call frames up the stack. But I can remove temperature if you prefer.
Looks good to me as is - there's a preference for more explicit handling of kwargs over more buried ones.
@@ -759,6 +759,8 @@ def generate_with_fallback(
         do_condition_on_prev_tokens,
         kwargs,
     ):
+        kwargs = dict(kwargs)
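To see why this copy matters: kwargs arrives by reference from two call frames up, so any key popped or deleted inside generate_with_fallback would also vanish from the caller's dict and from every subsequent loop iteration. A minimal, self-contained illustration (consume and opts are hypothetical names):

def consume(kwargs):
    kwargs = dict(kwargs)  # drop this line and the caller's dict is mutated
    kwargs.pop("num_beams", None)

opts = {"num_beams": 2}
consume(opts)
assert "num_beams" in opts  # holds only because consume() copied first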
Why do you use dict(...) here and below? Is it to copy? If yes, shouldn't we use copy.deepcopy instead?
Yes, it's just to make a copy. My thinking here was that a shallow copy (using dict() or copy.copy()) has the same effect as using the **kwargs syntax.
copy.deepcopy should do the trick then, right? I'm just afraid that using dict might not be self-explanatory.
I still don't think we want to make a deep copy (what if kwargs contains a large object like assistant_model, for example?). So I changed the dict to copy.copy, which is equivalent and more readable.
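A quick illustration of the trade-off (BigModel is a hypothetical stand-in for a large object such as assistant_model):

import copy

class BigModel:  # stand-in for a large object carried in kwargs
    pass

big_model = BigModel()
kwargs = {"num_beams": 2, "assistant_model": big_model}

shallow = copy.copy(kwargs)  # same effect as dict(kwargs) or {**kwargs}
assert shallow is not kwargs  # key deletions on shallow won't touch kwargs
assert shallow["assistant_model"] is big_model  # values shared, not cloned

deep = copy.deepcopy(kwargs)  # would clone the model object as well
assert deep["assistant_model"] is not big_model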
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks @cifkao! Sounds good, just want to make sure you have a reproducer!
Thanks for the great issue and super clear PR description @cifkao! The PR looks good to me. My only request is that we add a test to confirm beam search is working as expected. Could we modify your reproducer to do this, possibly with something like the following?
import datasets
from transformers import AutoProcessor, GenerationMixin, WhisperForConditionalGeneration
import numpy as np
processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
orig_generate = GenerationMixin.generate
NUM_BEAMS = 2
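# Patch GenerationMixin.generate to check that num_beams actually reaches the
# inner generate call on every chunk; args[1] should be the generation_config
# that generate_with_fallback passes positionally.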
def generate(self, *args, **kwargs):
assert args[1].num_beams == NUM_BEAMS
return orig_generate(self, *args, **kwargs)
GenerationMixin.generate = generate
ds = datasets.load_dataset(
"google/fleurs", "en_us", split="test", trust_remote_code=True
)
ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16000))
raw_audio = np.concatenate([x["array"].astype(np.float32) for x in ds[:16]["audio"]])
inputs = processor(
[raw_audio],
return_tensors="pt",
truncation=False,
padding="longest",
return_attention_mask=True,
sampling_rate=16_000,
)
model.generate(
**inputs,
num_beams=NUM_BEAMS,
task="transcribe",
language="en",
)
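If the bug is present, the patched generate should trip the assertion once long-form generation moves past the first 30-second segment, since num_beams was previously popped out of kwargs after the first inner call; with the fix it should pass on every segment.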
@sanchit-gandhi Test added!
[Outputs before and after the fix not captured in this transcript.]
Thanks for iterating and adding a test!
* Fix generate_with_fallback **kwargs
* Change pop to get
* Delete keys from kwargs to prevent overriding generation_config
* Revert to passing kwargs by reference, but make a (shallow) copy
* dict -> copy.copy
* Add test_whisper_longform_multi_batch_beam
What does this PR do?
Fixes #29312.
- pop() is changed to get() to avoid modifying kwargs between loop iterations,
- a (shallow) copy of kwargs is made as the first step in generate_with_fallback() to prevent any changes to it from propagating outside the method call,
- keys that are already in generation_config are removed from the keyword arguments to super().generate() (to avoid overriding the former), but this is done in a copy of kwargs that is not reused between iterations.

Before submitting
Who can review?
@patrickvonplaten @sanchit-gandhi @ylacombe