
Fix kwargs handling in generate_with_fallback #29225

Merged

Conversation

@cifkao (Contributor) commented on Feb 22, 2024

What does this PR do?

Fixes #29312.

  1. Changes pop() to get() so that kwargs is no longer modified between loop iterations.
  2. Makes a copy of kwargs as the first step in generate_with_fallback(), so that changes to it cannot propagate outside the method call.
  3. Makes sure the keys that were assigned to generation_config are removed from the keyword arguments passed to super().generate() (to avoid overriding the former), but does this on a per-iteration copy of kwargs that is not reused between iterations.
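
Concretely, the resulting pattern looks roughly like the sketch below. This is illustrative only, not the actual transformers source: generate_with_fallback_sketch, generate_fn and the SimpleNamespace config are stand-ins for the real method, super().generate() and GenerationConfig, and the num_beams/do_sample handling is simplified.

import copy
from types import SimpleNamespace

def generate_with_fallback_sketch(temperatures, kwargs, generate_fn):
    # Stand-in for WhisperGenerationMixin.generate_with_fallback; only the
    # kwargs handling touched by this PR is shown.
    generation_config = SimpleNamespace()

    # (2) Shallow-copy kwargs up front so no mutation leaks back to the caller.
    kwargs = copy.copy(kwargs)

    seek_outputs = None
    for temperature in temperatures:
        generation_config.do_sample = temperature is not None and temperature > 0.0
        # (1) get() instead of pop(): popping removed num_beams from the shared
        # kwargs, so later chunks never saw it (the behaviour reported in #29312).
        generation_config.num_beams = (
            kwargs.get("num_beams", 1) if not generation_config.do_sample else 1
        )

        # (3) Drop the keys already applied to generation_config, but only from a
        # per-iteration copy, so kwargs itself stays intact between iterations.
        generate_kwargs = copy.copy(kwargs)
        for key in ["do_sample", "temperature", "num_beams"]:
            generate_kwargs.pop(key, None)

        seek_outputs = generate_fn(generation_config=generation_config, **generate_kwargs)
    return seek_outputs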

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@patrickvonplaten @sanchit-gandhi @ylacombe

@ylacombe (Contributor) left a comment:

Hey @cifkao, thanks for the great work here, it's a nice catch.

The fix seems okay to me. I don't think we have a way to test whether it works, otherwise I'd have asked you for that!

@sanchit-gandhi could we have your review here as well?


generate_kwargs = dict(kwargs)
for key in ["do_sample", "temperature", "num_beams"]:

Contributor (comment on the lines above):

temperature shouldn't be in kwargs as it's already an argument of .generate here, right?

It seems okay to check for do_sample and num_beams here.

Contributor Author (@cifkao):

Just wanted to be extra cautious here and make sure everything is safe locally, rather than relying on what gets passed down from 2 call frames up the stack. But I can remove temperature if you prefer.

Contributor:

Looks good to me as is; there's a preference for more explicit handling of kwargs over more buried handling.

@@ -759,6 +759,8 @@ def generate_with_fallback(
do_condition_on_prev_tokens,
kwargs,
):
kwargs = dict(kwargs)

Contributor (comment on the added line above):

Why do you use dict(...) here and below? Is it to copy? If so, shouldn't we use copy.deepcopy instead?

Contributor Author (@cifkao):

Yes, it's just to make a copy. My thinking here was that a shallow copy (using dict() or copy.copy()) has the same effect as using the **kwargs syntax.

Contributor:

copy.deepcopy should do the trick then, right? I'm just afraid that using dict might not be self-explanatory.

Contributor Author (@cifkao, Mar 8, 2024):

I still don't think we want to make a deep copy (what if kwargs contains a large object like assistant_model, for example?). So I changed the dict to copy.copy, which is equivalent and more readable.
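
For illustration (not part of the PR): a shallow copy decouples the dict itself while still sharing its values, which is the same effect as forwarding with **kwargs, whereas copy.deepcopy would also clone every value, including something as large as an assistant_model. The list below merely stands in for such a large object.

import copy

kwargs = {"num_beams": 2, "assistant_model": ["pretend", "this", "is", "huge"]}

shallow = copy.copy(kwargs)  # equivalent to dict(kwargs) or {**kwargs}
shallow.pop("num_beams")
print("num_beams" in kwargs)  # True: mutating the copy leaves the original intact
print(shallow["assistant_model"] is kwargs["assistant_model"])  # True: values are shared, not duplicated

deep = copy.deepcopy(kwargs)
print(deep["assistant_model"] is kwargs["assistant_model"])  # False: deepcopy would clone the large object too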


@ArthurZucker (Collaborator) left a comment:

Thanks @cifkao! Sounds good, just want to make sure you have a reproducer!

@cifkao cifkao requested a review from ArthurZucker March 31, 2024 00:02
@sanchit-gandhi (Contributor) left a comment:

Thanks for the great issue and super clear PR description @cifkao! The PR looks good to me. My only request is that we add a test to confirm beam search is working as expected. Could we modify your reproducer to do this, possibly with something like the following?

import datasets
from transformers import AutoProcessor, GenerationMixin, WhisperForConditionalGeneration
import numpy as np

processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

orig_generate = GenerationMixin.generate
NUM_BEAMS = 2

# Every internal .generate() call made during long-form transcription should
# receive a generation config that still carries num_beams == NUM_BEAMS
def generate(self, *args, **kwargs):
    assert args[1].num_beams == NUM_BEAMS
    return orig_generate(self, *args, **kwargs)


GenerationMixin.generate = generate

ds = datasets.load_dataset(
    "google/fleurs", "en_us", split="test", trust_remote_code=True
)
ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16000))
# Concatenate the first 16 samples into one long input to trigger long-form (chunked) generation
raw_audio = np.concatenate([x["array"].astype(np.float32) for x in ds[:16]["audio"]])

inputs = processor(
    [raw_audio],
    return_tensors="pt",
    truncation=False,
    padding="longest",
    return_attention_mask=True,
    sampling_rate=16_000,
)

model.generate(
    **inputs,
    num_beams=NUM_BEAMS,
    task="transcribe",
    language="en",
)

@cifkao (Contributor Author) commented on Apr 2, 2024:

@sanchit-gandhi Test added!

On main:

FAILED tests/models/whisper/test_modeling_whisper.py::WhisperModelIntegrationTests::test_whisper_longform_multi_batch_beam - assert 1 == 2

After fix:

PASSED tests/models/whisper/test_modeling_whisper.py::WhisperModelIntegrationTests::test_whisper_longform_multi_batch_beam

@ArthurZucker (Collaborator) left a comment:

Thanks for iterating and adding a test!

@ArthurZucker ArthurZucker merged commit bcd42c4 into huggingface:main Apr 3, 2024
17 checks passed
ArthurZucker pushed a commit that referenced this pull request Apr 22, 2024
* Fix generate_with_fallback **kwargs

* Change pop to get

* Delete keys from kwargs to prevent overriding generation_config

* Revert to passing kwargs by reference, but make a (shallow) copy

* dict -> copy.copy

* Add test_whisper_longform_multi_batch_beam
itazap pushed a commit that referenced this pull request May 14, 2024 (same commit message as above)

Successfully merging this pull request may close these issues.

All chunks except the first one ignore num_beams in Whisper long-form transcription