Skip to content

Commit

Permalink
Add gradient checkpointing to Whisper Flax (huggingface#22954)
Browse files Browse the repository at this point in the history
* Add gradient checkpointing to Whisper Flax

* self.gradient_checkpointing only needed in nn.Module, removing unnecessary comments
  • Loading branch information
versae authored and gojiteji committed Jun 5, 2023
1 parent 991ad06 commit a90ec9e
Showing 1 changed file with 59 additions and 21 deletions.
80 changes: 59 additions & 21 deletions src/transformers/models/whisper/modeling_flax_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
Expand Down Expand Up @@ -53,6 +54,8 @@
_CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
_CONFIG_FOR_DOC = "WhisperConfig"

remat = nn_partitioning.remat


WHISPER_START_DOCSTRING = r"""
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
Expand Down Expand Up @@ -387,16 +390,23 @@ def __call__(
return outputs


# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartEncoderLayerCollection with MBart->Whisper
class FlaxWhisperEncoderLayerCollection(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False

def setup(self):
self.layers = [
FlaxWhisperEncoderLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.encoder_layers)
]
if self.gradient_checkpointing:
FlaxWhisperEncoderCheckpointLayer = remat(FlaxWhisperEncoderLayer, static_argnums=(2, 3))
self.layers = [
FlaxWhisperEncoderCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.encoder_layers)
]
else:
self.layers = [
FlaxWhisperEncoderLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.encoder_layers)
]
self.layerdrop = self.config.encoder_layerdrop

def __call__(
Expand Down Expand Up @@ -531,16 +541,23 @@ def __call__(
return outputs


# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartDecoderLayerCollection with MBart->Whisper
class FlaxWhisperDecoderLayerCollection(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
gradient_checkpointing: bool = False

def setup(self):
self.layers = [
FlaxWhisperDecoderLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.decoder_layers)
]
if self.gradient_checkpointing:
FlaxWhisperDecoderCheckpointLayer = remat(FlaxWhisperDecoderLayer, static_argnums=(4, 5, 6))
self.layers = [
FlaxWhisperDecoderCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.decoder_layers)
]
else:
self.layers = [
FlaxWhisperDecoderLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.decoder_layers)
]
self.layerdrop = self.config.decoder_layerdrop

def __call__(
Expand Down Expand Up @@ -570,12 +587,12 @@ def __call__(
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
output_attentions=output_attentions,
deterministic=deterministic,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
init_cache,
output_attentions,
deterministic,
)

hidden_states = layer_outputs[0]
Expand Down Expand Up @@ -605,6 +622,7 @@ def __call__(
class FlaxWhisperEncoder(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False

def setup(self) -> None:
self.conv1 = nn.Conv(
Expand All @@ -628,6 +646,7 @@ def setup(self) -> None:
self.layers = FlaxWhisperEncoderLayerCollection(
self.config,
dtype=self.dtype,
gradient_checkpointing=self.gradient_checkpointing,
)
self.embed_positions = nn.Embed(self.config.max_source_positions, self.config.d_model, dtype=self.dtype)

Expand Down Expand Up @@ -689,12 +708,15 @@ def __call__(
class FlaxWhisperDecoder(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False

def setup(self) -> None:
self.embed_tokens = nn.Embed(self.config.vocab_size, self.config.d_model, dtype=self.dtype)
self.embed_positions = nn.Embed(self.config.max_target_positions, self.config.d_model, dtype=self.dtype)

self.layers = FlaxWhisperDecoderLayerCollection(self.config, dtype=self.dtype)
self.layers = FlaxWhisperDecoderLayerCollection(
self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)

self.dropout_layer = nn.Dropout(rate=self.config.dropout)

Expand Down Expand Up @@ -753,10 +775,15 @@ def __call__(
class FlaxWhisperModule(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False

def setup(self) -> None:
self.encoder = FlaxWhisperEncoder(self.config, dtype=self.dtype)
self.decoder = FlaxWhisperDecoder(self.config, dtype=self.dtype)
self.encoder = FlaxWhisperEncoder(
self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
self.decoder = FlaxWhisperDecoder(
self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)

def __call__(
self,
Expand Down Expand Up @@ -821,11 +848,19 @@ def __init__(
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
gradient_checkpointing: bool = False,
**kwargs,
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

def enable_gradient_checkpointing(self):
self._module = self.module_class(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=True,
)

def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_features = jnp.zeros(input_shape, dtype="f4")
Expand Down Expand Up @@ -1137,9 +1172,12 @@ class FlaxWhisperModel(FlaxWhisperPreTrainedModel):
class FlaxWhisperForConditionalGenerationModule(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False

def setup(self) -> None:
self.model = FlaxWhisperModule(config=self.config, dtype=self.dtype)
self.model = FlaxWhisperModule(
config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
)
self.lm_head = nn.Dense(
self.config.vocab_size,
use_bias=False,
Expand Down

0 comments on commit a90ec9e

Please sign in to comment.