Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gradient checkpointing to Whisper Flax #22954

Merged
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 63 additions & 21 deletions src/transformers/models/whisper/modeling_flax_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
Expand Down Expand Up @@ -53,6 +54,8 @@
_CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
_CONFIG_FOR_DOC = "WhisperConfig"

remat = nn_partitioning.remat


WHISPER_START_DOCSTRING = r"""
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
Expand Down Expand Up @@ -387,16 +390,24 @@ def __call__(
return outputs


# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartEncoderLayerCollection with MBart->Whisper
# Adapted from transformers.models.mbart.modeling_flax_mbart.FlaxMBartEncoderLayerCollection with MBart->Whisper
versae marked this conversation as resolved.
Show resolved Hide resolved
class FlaxWhisperEncoderLayerCollection(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False

def setup(self):
self.layers = [
FlaxWhisperEncoderLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.encoder_layers)
]
if self.gradient_checkpointing:
FlaxWhisperEncoderCheckpointLayer = remat(FlaxWhisperEncoderLayer, static_argnums=(2, 3))
self.layers = [
FlaxWhisperEncoderCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.encoder_layers)
]
else:
self.layers = [
FlaxWhisperEncoderLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.encoder_layers)
]
self.layerdrop = self.config.encoder_layerdrop

def __call__(
Expand Down Expand Up @@ -531,16 +542,24 @@ def __call__(
return outputs


# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartDecoderLayerCollection with MBart->Whisper
# Adapted from transformers.models.mbart.modeling_flax_mbart.FlaxMBartDecoderLayerCollection with MBart->Whisper
versae marked this conversation as resolved.
Show resolved Hide resolved
class FlaxWhisperDecoderLayerCollection(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False

def setup(self):
self.layers = [
FlaxWhisperDecoderLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.decoder_layers)
]
if self.gradient_checkpointing:
FlaxWhisperDecoderCheckpointLayer = remat(FlaxWhisperDecoderLayer, static_argnums=(4, 5, 6))
self.layers = [
FlaxWhisperDecoderCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.decoder_layers)
]
else:
self.layers = [
FlaxWhisperDecoderLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.decoder_layers)
]
self.layerdrop = self.config.decoder_layerdrop

def __call__(
Expand Down Expand Up @@ -570,12 +589,12 @@ def __call__(
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
output_attentions=output_attentions,
deterministic=deterministic,
attention_mask,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to reviewer: remat does not support key-word arguments, hence the need to change to pure arguments

encoder_hidden_states,
encoder_attention_mask,
init_cache,
output_attentions,
deterministic,
)

hidden_states = layer_outputs[0]
Expand Down Expand Up @@ -605,6 +624,7 @@ def __call__(
class FlaxWhisperEncoder(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False

def setup(self) -> None:
self.conv1 = nn.Conv(
Expand All @@ -628,6 +648,7 @@ def setup(self) -> None:
self.layers = FlaxWhisperEncoderLayerCollection(
self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
self.embed_positions = nn.Embed(self.config.max_source_positions, self.config.d_model, dtype=self.dtype)

Expand Down Expand Up @@ -689,12 +710,15 @@ def __call__(
class FlaxWhisperDecoder(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False

def setup(self) -> None:
self.embed_tokens = nn.Embed(self.config.vocab_size, self.config.d_model, dtype=self.dtype)
self.embed_positions = nn.Embed(self.config.max_target_positions, self.config.d_model, dtype=self.dtype)

self.layers = FlaxWhisperDecoderLayerCollection(self.config, dtype=self.dtype)
self.layers = FlaxWhisperDecoderLayerCollection(
self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)

self.dropout_layer = nn.Dropout(rate=self.config.dropout)

Expand Down Expand Up @@ -753,10 +777,15 @@ def __call__(
class FlaxWhisperModule(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False

def setup(self) -> None:
self.encoder = FlaxWhisperEncoder(self.config, dtype=self.dtype)
self.decoder = FlaxWhisperDecoder(self.config, dtype=self.dtype)
self.encoder = FlaxWhisperEncoder(
self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
self.decoder = FlaxWhisperDecoder(
self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)

def __call__(
self,
Expand Down Expand Up @@ -821,11 +850,21 @@ def __init__(
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
gradient_checkpointing: bool = False,
**kwargs,
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
self.gradient_checkpointing = gradient_checkpointing
versae marked this conversation as resolved.
Show resolved Hide resolved
module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

def enable_gradient_checkpointing(self):
self.gradient_checkpointing = True
versae marked this conversation as resolved.
Show resolved Hide resolved
self._module = self.module_class(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
versae marked this conversation as resolved.
Show resolved Hide resolved
)

def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_features = jnp.zeros(input_shape, dtype="f4")
Expand Down Expand Up @@ -1137,9 +1176,12 @@ class FlaxWhisperModel(FlaxWhisperPreTrainedModel):
class FlaxWhisperForConditionalGenerationModule(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False

def setup(self) -> None:
self.model = FlaxWhisperModule(config=self.config, dtype=self.dtype)
self.model = FlaxWhisperModule(
config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
self.lm_head = nn.Dense(
self.config.vocab_size,
use_bias=False,
Expand Down