
[WhisperForCausalLM] Add WhisperForCausalLM for speculative decoding #27195

Merged: 10 commits merged into main on Nov 1, 2023

Conversation

patrickvonplaten (Contributor) commented Oct 31, 2023

What does this PR do?

This PR enables speculative decoding for all cases where the assistant model is stripped of its encoder weights because they are shared with the teacher model. For now, Distil-Whisper is the main use case.
In addition, a WhisperForCausalLM class is added, as it didn't exist yet and is needed for Distil-Whisper.

The following code should therefore be enabled:

from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, AutoModelForCausalLM
from datasets import load_dataset
import torch
import time

# load models and processor
processor = AutoProcessor.from_pretrained("openai/whisper-large-v2")
model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-large-v2", torch_dtype=torch.float16, low_cpu_mem_usage=True)
model.cuda()
assistant_model = AutoModelForCausalLM.from_pretrained("patrickvonplaten/whisper-large-v2-32-2", torch_dtype=torch.float16, low_cpu_mem_usage=True)
assistant_model.cuda()

print(f"Assistant num params compared to teacher: {100 * assistant_model.num_parameters() / model.num_parameters():.1f} %.")

# load audio file
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = ds[0]["audio"]
input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features 
input_features = input_features.to(dtype=torch.float16, device="cuda")

# warm-up
_ = model.generate(input_features)

# generate token ids with teacher
start_time = time.time()
predicted_ids = model.generate(input_features)

print("Time normal", time.time() - start_time)
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
print(transcription)
print(20 * "-")

start_time = time.time()
predicted_ids = model.generate(input_features, assistant_model=assistant_model)

print("Time speculative decoding", time.time() - start_time)
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
print(transcription)
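The speed-up comes from the assisted-generation loop inside `generate`. The greedy variant can be sketched in pure Python (toy next-token callables stand in for real models; the helper names are hypothetical, not the transformers implementation): the assistant drafts a few tokens cheaply, the main model verifies them, and the longest agreeing prefix plus one corrected token is accepted.

```python
# Toy sketch of greedy assisted (speculative) decoding. `main_next` and
# `assistant_next` are callables mapping a token list to the next token id.
def assisted_greedy_generate(main_next, assistant_next, prompt, max_new_tokens=8, k=3):
    tokens = list(prompt)
    produced = 0
    while produced < max_new_tokens:
        # 1) the assistant drafts up to k candidate tokens
        draft = []
        for _ in range(k):
            draft.append(assistant_next(tokens + draft))
        # 2) the main model verifies the draft (a single forward pass in practice)
        new = []
        for candidate in draft:
            expected = main_next(tokens + new)
            new.append(expected)       # always keep the main model's token
            if expected != candidate:  # first mismatch ends the accepted run
                break
        else:
            # whole draft accepted: the main model adds one extra token for free
            new.append(main_next(tokens + new))
        new = new[:max_new_tokens - produced]
        tokens.extend(new)
        produced += len(new)
    return tokens


def plain_greedy(main_next, prompt, max_new_tokens=8):
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        tokens.append(main_next(tokens))
    return tokens
```

Because every accepted token is re-checked by the main model, the output is identical to plain greedy decoding with the main model alone; a good assistant only changes how fast you get there.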

@patrickvonplaten patrickvonplaten changed the title finish [WhisperForCausalLM] Add WhisperForCausalLM for speculative decoding Oct 31, 2023
HuggingFaceDocBuilderDev commented Oct 31, 2023

The documentation is not available anymore as the PR was closed or merged.

patrickvonplaten (Contributor Author) commented:

The failing tests seem to be unrelated:

FAILED tests/trainer/test_trainer.py::TrainerIntegrationWithHubTester::test_push_to_hub_with_saves_each_n_steps - Failed: Timeout >120.0s
UNEXPECTED EXCEPTION: ChunkedEncodingError(ProtocolError('Connection broken: IncompleteRead(3430997229 bytes read, 2742372923 more expected)', IncompleteRead(3430997229 bytes read, 2742372923 more expected)))
FAILED tests/models/marian/test_modeling_marian.py::MarianModelTest::test_save_load_keys_to_ignore_on_save - FileNotFoundError: [Errno 2] No such file or directory: '/tmp/tmpvv597zd7/pytorch_model.bin'
FAILED tests/models/prophetnet/test_modeling_prophetnet.py::ProphetNetModelTest::test_causal_lm_from_pretrained - AssertionError: False is not true
FAILED tests/models/seamless_m4t/test_modeling_seamless_m4t.py::SeamlessM4TGenerationTest::test_speech_generation - Failed: Timeout >120.0s
FAILED tests/models/seamless_m4t/test_modeling_seamless_m4t.py::SeamlessM4TGenerationTest::test_text_generation - AssertionError: Lists differ: [3, 4, 8, 3, 3, 4, 3, 0] != [3, 4, 8, 7, 1, 11, 7, 18, 18, 3, 0, 0, 0, 0, 0, 0,[74 chars]8, 6]


patrickvonplaten commented Nov 1, 2023

Not exactly sure what's going on with the docs. They appear just fine with the doc-builder for me:

[Screenshot from 2023-11-01 11-36-05]

amyeroberts (Collaborator) left a comment:

Thanks for adding this!

Just some nits on formatting, docstrings and tests. Would like to have 👍 from @gante before merging to check the changes to generation logic.

Review comment on tests/models/whisper/test_modeling_whisper.py:
# PT-only test: TF doesn't support assisted decoding yet.
# Bart subclass with a kwarg that distorts the output
class FakeBart(BartForConditionalGeneration):
def forward(self, input_ids, foo=False, **kwargs):
Collaborator comment:

Can we use a more descriptive name than foo here? Or maybe just clarify in the comment that foo is said kwarg e.g. # Bart subclass with kwarg 'foo' that distorts the output
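The reviewer's naming suggestion amounts to something like the following toy (illustrative only, not the actual `FakeBart` test class): a `forward` with an explicitly named kwarg whose effect on the output is observable, so a test can verify that `generate()` forwards extra kwargs to the model.

```python
# Toy stand-in for the test subclass: `distort_logits` (a hypothetical name,
# more descriptive than `foo`) visibly alters the output when passed through.
class ToyModel:
    def forward(self, input_ids, distort_logits=False, **kwargs):
        logits = [float(t) for t in input_ids]
        if distort_logits:
            # flipping the sign makes the kwarg's effect observable in a test
            logits = [-x for x in logits]
        return logits
```

A test can then assert that calling with and without the kwarg produces different, predictable outputs.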

(Additional resolved review threads on tests/models/whisper/test_modeling_whisper.py and src/transformers/models/whisper/modeling_whisper.py.)
LysandreJik (Member) left a comment:

Impressive PR, it's very well contained!

The skeleton and API look good to me, I'm out of my depth for the generation logic.

@@ -945,6 +946,8 @@ class WhisperDecoder(WhisperPreTrainedModel):
config: WhisperConfig
"""

main_input_name = "input_ids"
Member comment:

Was this an oversight in the implementation?

patrickvonplaten (Contributor Author):

WhisperDecoder was never tested as a stand-alone model before. It also doesn't really make sense to use it alone, because one always needs the encoded audio features.

patrickvonplaten and others added 3 commits November 1, 2023 13:57
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
patrickvonplaten (Contributor Author):

Merging as I think Joao is off today and changes in assisted generation are quite minimal IMO. @gante would be great if you could nevertheless take a look once back :-)

@patrickvonplaten patrickvonplaten merged commit 391d14e into main Nov 1, 2023
19 of 22 checks passed
@patrickvonplaten patrickvonplaten deleted the whisper_decoder_only branch November 1, 2023 15:01
gante (Member) left a comment:

@patrickvonplaten All good on the generation front 👍

Cool strategy for handling the case of a shared encoder. I hope people realize that encoder-decoder LLMs may be viable (and faster) for input-grounded tasks.

EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 19, 2023
[WhisperForCausalLM] Add WhisperForCausalLM for speculative decoding (huggingface#27195)

* finish

* add tests

* fix all tests

* [Assistant Decoding] Add test

* fix more

* better

* finish

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* finish

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Labels: none · Projects: none · 5 participants