-
Notifications
You must be signed in to change notification settings - Fork 27.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add FlaxWhisperForAudioClassification model #23173
Conversation
The documentation is not available anymore as the PR was closed or merged. |
The test failures are appearing on this one. Let's fix them and re-merge! |
We need to make two changes following the updates in #22954! First, we need to assign the attribute
We then need to forward self.gradient_checkpointing to the encoder:
- self.encoder = FlaxWhisperEncoder(config=self.config, dtype=self.dtype)
+ self.encoder = FlaxWhisperEncoder(config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing) This will facilitate gradient checkpointing for the module! |
@sgugger @sanchit-gandhi Done, all tests pass ! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thnaks a lot!
@@ -1512,6 +1512,7 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs): | |||
class FlaxWhisperForAudioClassificationModule(nn.Module): | |||
config: WhisperConfig | |||
dtype: jnp.dtype = jnp.float32 | |||
gradient_checkpointing: bool = False | |||
|
|||
def setup(self) -> None: | |||
self.encoder = FlaxWhisperEncoder(config=self.config, dtype=self.dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @raghavanone! Sorry I didn't get the chance to re-review the last changes before merge, there's one small change we need in this line to forward the gradient checkpointing attribute to the encoder (see #23173 (comment)):
self.encoder = FlaxWhisperEncoder(config=self.config, dtype=self.dtype) | |
self.encoder = FlaxWhisperEncoder(config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing) |
Would you like to open a new PR to add this one line?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes , will do
* Add FlaxWhisperForAudioClassification model * Add models to init * Add models to init * Fix copies * Fix automapping * Fix failing test
@@ -1430,7 +1430,7 @@ def __init__( | |||
hidden_dropout_prob=0.1, | |||
attention_probs_dropout_prob=0.1, | |||
max_position_embeddings=20, | |||
max_source_positions=30, | |||
max_source_positions=1500, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @raghavanone @sanchit-gandhi
I am wondering why we increase these values a lot ...?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Whisper works intrinsically on a sequence length of 30s inputs (which corresponds to 1500 log mel spectrogram frames)
We could use a shorter context window (i.e. 30s -> 15s), we just need to initialise the weights accordingly
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in #24105
* Add FlaxWhisperForAudioClassification model * Add models to init * Add models to init * Fix copies * Fix automapping * Fix failing test
Fixes #21779