
Add FlaxWhisperForAudioClassification model #23173

Merged (6 commits into huggingface:main on May 5, 2023)

Conversation

raghavanone
Contributor

Fixes #21779

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented May 5, 2023

The documentation is not available anymore as the PR was closed or merged.

@sgugger
Collaborator

sgugger commented May 5, 2023

cc @sanchit-gandhi

@sgugger
Collaborator

sgugger commented May 5, 2023

The test failures are appearing on this one. Let's fix them and re-merge!

@sanchit-gandhi
Contributor

sanchit-gandhi commented May 5, 2023

We need to make two changes following the updates in #22954! First, we need to assign the attribute gradient_checkpointing to the class FlaxWhisperForAudioClassificationModule, similar to what we do for FlaxWhisperForConditionalGeneration:

gradient_checkpointing: bool = False

We then need to forward self.gradient_checkpointing to the encoder:

-         self.encoder = FlaxWhisperEncoder(config=self.config, dtype=self.dtype)
+         self.encoder = FlaxWhisperEncoder(config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)

This will facilitate gradient checkpointing for the module!
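The two changes above can be sketched with a minimal stand-in. This uses plain Python dataclasses rather than the real `flax.linen.Module` machinery, and the classes here are simplified stand-ins for the actual modules in `modeling_flax_whisper.py`; it only illustrates the attribute-forwarding pattern, not real gradient checkpointing:

```python
from dataclasses import dataclass

# Simplified stand-ins for the real Flax modules. In transformers, these
# are flax.linen.Module subclasses and setup() is called by Flax itself.

@dataclass
class FlaxWhisperEncoder:
    config: dict
    dtype: str = "float32"
    gradient_checkpointing: bool = False  # rematerializes activations when True

@dataclass
class FlaxWhisperForAudioClassificationModule:
    config: dict
    dtype: str = "float32"
    gradient_checkpointing: bool = False  # change 1: the new attribute

    def setup(self) -> None:
        # Change 2: forward the flag so the encoder actually checkpoints.
        self.encoder = FlaxWhisperEncoder(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )

module = FlaxWhisperForAudioClassificationModule(config={}, gradient_checkpointing=True)
module.setup()
print(module.encoder.gradient_checkpointing)  # True
```

Without the second change, the module would carry the flag but the encoder would silently keep its default of `False`, which is exactly the bug flagged in the review below the merge.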

@raghavanone
Contributor Author

@sgugger @sanchit-gandhi Done, all tests pass!

Collaborator

@sgugger sgugger left a comment


Thanks a lot!

@sgugger sgugger merged commit 312b104 into huggingface:main May 5, 2023
@@ -1512,6 +1512,7 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs):
class FlaxWhisperForAudioClassificationModule(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False

def setup(self) -> None:
self.encoder = FlaxWhisperEncoder(config=self.config, dtype=self.dtype)
Contributor


Hey @raghavanone! Sorry I didn't get the chance to re-review the last changes before merge, there's one small change we need in this line to forward the gradient checkpointing attribute to the encoder (see #23173 (comment)):

Suggested change
-         self.encoder = FlaxWhisperEncoder(config=self.config, dtype=self.dtype)
+         self.encoder = FlaxWhisperEncoder(config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)

Would you like to open a new PR to add this one line?

Contributor Author


Yes, will do.

gojiteji pushed a commit to gojiteji/transformers that referenced this pull request Jun 5, 2023
* Add FlaxWhisperForAudioClassification model

* Add models to init

* Add models to init

* Fix copies

* Fix automapping

* Fix failing test
@@ -1430,7 +1430,7 @@ def __init__(
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=20,
-         max_source_positions=30,
+         max_source_positions=1500,
Collaborator


Hi @raghavanone @sanchit-gandhi

I am wondering why we increased these values so much?

Contributor

@sanchit-gandhi sanchit-gandhi Jun 8, 2023


Whisper works intrinsically on a sequence length of 30s inputs (which corresponds to 3000 log-mel spectrogram frames, downsampled to 1500 positions by the encoder's stride-2 convolution).

We could use a shorter context window (i.e. 30s -> 15s); we just need to initialise the weights accordingly.
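The 1500 used for `max_source_positions` falls out of Whisper's feature extraction, assuming the standard constants (16 kHz audio, 30-second chunks, a hop length of 160 samples for the log-mel spectrogram, and one stride-2 convolution in the encoder); a quick sanity check:

```python
# Standard Whisper feature-extraction constants (assumed here, not taken
# from this PR): 16 kHz audio, 30 s chunks, hop length of 160 samples.
SAMPLE_RATE = 16_000   # Hz
CHUNK_LENGTH = 30      # seconds per input window
HOP_LENGTH = 160       # samples between successive mel frames

mel_frames = SAMPLE_RATE * CHUNK_LENGTH // HOP_LENGTH  # 3000 spectrogram frames
encoder_positions = mel_frames // 2                    # stride-2 conv halves the length
print(mel_frames, encoder_positions)  # 3000 1500
```

Halving the window to 15 s would give 750 encoder positions, which is why shrinking the context also requires re-initialising (or slicing) the positional embedding weights.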

Contributor


Fixed in #24105

novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023
* Add FlaxWhisperForAudioClassification model

* Add models to init

* Add models to init

* Fix copies

* Fix automapping

* Fix failing test