From c2e70787fbb61f9689405dc0ad3395efa1ccd7a8 Mon Sep 17 00:00:00 2001
From: Karim Foda <35491698+KMFODA@users.noreply.github.com>
Date: Sun, 14 Aug 2022 17:27:13 +0200
Subject: [PATCH] Flax Remat for LongT5 (#17994)

* [Flax] Add remat (gradient checkpointing)
* fix variable naming in test
* flip: checkpoint using a method
* fix naming
* fix class naming
* apply PVP's suggestions from code review
* add gradient_checkpointing to examples
* Add gradient_checkpointing to run_mlm_flax
* Add remat to longt5
* Add gradient checkpointing test longt5
* Fix args errors
* Fix remaining tests
* Make fixup & quality fixes
* replace kwargs
* remove unecessary kwargs
* Make fixup changes
* revert long_t5_flax changes
* Remove return_dict and copy to LongT5
* Remove test_gradient_checkpointing

Co-authored-by: sanchit-gandhi
---
 .../flax/language-modeling/run_mlm_flax.py    |  9 ++
 .../summarization/run_summarization_flax.py   |  9 ++
 .../models/longt5/modeling_flax_longt5.py     | 75 +++++++++++----
 .../models/t5/modeling_flax_t5.py             | 95 +++++++++++++++----
 4 files changed, 149 insertions(+), 39 deletions(-)

diff --git a/examples/flax/language-modeling/run_mlm_flax.py b/examples/flax/language-modeling/run_mlm_flax.py
index 65f6a2285d9c34..408e09fc111cb3 100755
--- a/examples/flax/language-modeling/run_mlm_flax.py
+++ b/examples/flax/language-modeling/run_mlm_flax.py
@@ -107,6 +107,12 @@ class TrainingArguments:
         default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
     )
     hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
+    gradient_checkpointing: bool = field(
+        default=False,
+        metadata={
+            "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
+        },
+    )
 
     def __post_init__(self):
         if self.output_dir is not None:
@@ -640,6 +646,9 @@ def group_texts(examples):
         dtype=getattr(jnp, model_args.dtype),
     )
 
+    if training_args.gradient_checkpointing:
+        model.enable_gradient_checkpointing()
+
     # Store some constant
     num_epochs = int(training_args.num_train_epochs)
     train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
diff --git a/examples/flax/summarization/run_summarization_flax.py b/examples/flax/summarization/run_summarization_flax.py
index c193fe0bc3745a..2813c88a3bd6fd 100644
--- a/examples/flax/summarization/run_summarization_flax.py
+++ b/examples/flax/summarization/run_summarization_flax.py
@@ -121,6 +121,12 @@ class TrainingArguments:
         default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
     )
     hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
+    gradient_checkpointing: bool = field(
+        default=False,
+        metadata={
+            "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
+        },
+    )
 
     def __post_init__(self):
         if self.output_dir is not None:
@@ -535,6 +541,9 @@ def main():
         dtype=getattr(jnp, model_args.dtype),
     )
 
+    if training_args.gradient_checkpointing:
+        model.enable_gradient_checkpointing()
+
     if model.config.decoder_start_token_id is None:
         raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
 
diff --git a/src/transformers/models/longt5/modeling_flax_longt5.py b/src/transformers/models/longt5/modeling_flax_longt5.py
index 766dc36888e228..224515cd12a200 100644
--- a/src/transformers/models/longt5/modeling_flax_longt5.py
+++ b/src/transformers/models/longt5/modeling_flax_longt5.py
@@ -25,6 +25,7 @@
 import jax.numpy as jnp
 from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen import combine_masks, make_causal_mask
+from flax.linen import partitioning as nn_partitioning
 from flax.linen.attention import dot_product_attention_weights
 from flax.traverse_util import flatten_dict, unflatten_dict
 from jax.random import PRNGKey
@@ -53,6 +54,8 @@
 _CONFIG_FOR_DOC = "LongT5Config"
 _TOKENIZER_FOR_DOC = "T5Tokenizer"
 
+remat = nn_partitioning.remat
+
 
 # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
 def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
@@ -1356,7 +1359,6 @@ def __call__(
         encoder_attention_mask=None,
         encoder_decoder_position_bias=None,
         output_attentions=False,
-        return_dict=True,
         deterministic=True,
         init_cache=False,
     ):
@@ -1377,13 +1379,31 @@ def __call__(
 class FlaxLongT5BlockCollection(nn.Module):
     config: LongT5Config
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def setup(self):
         self.causal = self.config.causal
-        self.blocks = [
-            FlaxLongT5LayerCollection(self.config, has_relative_attention_bias=(i == 0), dtype=self.dtype, name=str(i))
-            for i in range(self.config.num_layers)
-        ]
+        if self.gradient_checkpointing:
+            FlaxLongT5CheckpointLayer = remat(FlaxLongT5LayerCollection, static_argnums=(6, 7, 8))
+            self.blocks = [
+                FlaxLongT5CheckpointLayer(
+                    self.config,
+                    has_relative_attention_bias=(i == 0),
+                    dtype=self.dtype,
+                    name=str(i),
+                )
+                for i in range(self.config.num_layers)
+            ]
+        else:
+            self.blocks = [
+                FlaxLongT5LayerCollection(
+                    self.config,
+                    has_relative_attention_bias=(i == 0),
+                    dtype=self.dtype,
+                    name=str(i),
+                )
+                for i in range(self.config.num_layers)
+            ]
 
     def __call__(
         self,
@@ -1409,14 +1429,14 @@ def __call__(
 
             layer_outputs = layer_module(
                 hidden_states,
-                attention_mask=attention_mask,
-                position_bias=position_bias,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                encoder_decoder_position_bias=encoder_decoder_position_bias,
-                output_attentions=output_attentions,
-                deterministic=deterministic,
-                init_cache=init_cache,
+                attention_mask,
+                position_bias,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                encoder_decoder_position_bias,
+                output_attentions,
+                deterministic,
+                init_cache,
             )
 
             hidden_states = layer_outputs[0]
@@ -1447,11 +1467,14 @@ class FlaxLongT5Stack(nn.Module):
     config: LongT5Config
     embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def setup(self):
         self.causal = self.config.causal
 
-        self.block = FlaxLongT5BlockCollection(self.config, dtype=self.dtype)
+        self.block = FlaxLongT5BlockCollection(
+            self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
         self.final_layer_norm = FlaxLongT5LayerNorm(
             self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
         )
@@ -1989,6 +2012,7 @@ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs
 class FlaxLongT5Module(nn.Module):
     config: LongT5Config
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def _get_encoder_module(self):
         return self.encoder
@@ -2005,12 +2029,22 @@ def setup(self):
 
         encoder_config = copy.deepcopy(self.config)
         encoder_config.causal = False
-        self.encoder = FlaxLongT5Stack(encoder_config, embed_tokens=self.shared, dtype=self.dtype)
+        self.encoder = FlaxLongT5Stack(
+            encoder_config,
+            embed_tokens=self.shared,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
 
         decoder_config = copy.deepcopy(self.config)
         decoder_config.causal = True
         decoder_config.num_layers = self.config.num_decoder_layers
-        self.decoder = FlaxLongT5Stack(decoder_config, embed_tokens=self.shared, dtype=self.dtype)
+        self.decoder = FlaxLongT5Stack(
+            decoder_config,
+            embed_tokens=self.shared,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
 
     def __call__(
         self,
@@ -2104,6 +2138,7 @@ class FlaxLongT5Model(FlaxLongT5PreTrainedModel):
 class FlaxLongT5ForConditionalGenerationModule(nn.Module):
     config: LongT5Config
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def _get_encoder_module(self):
         return self.encoder
@@ -2124,13 +2159,17 @@ def setup(self):
         encoder_config.causal = False
         encoder_config.use_cache = False
         encoder_config.is_encoder_decoder = False
-        self.encoder = FlaxLongT5Stack(encoder_config, self.shared, dtype=self.dtype)
+        self.encoder = FlaxLongT5Stack(
+            encoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
 
         decoder_config = copy.deepcopy(self.config)
         decoder_config.causal = True
         decoder_config.is_encoder_decoder = False
         decoder_config.num_layers = self.config.num_decoder_layers
-        self.decoder = FlaxLongT5Stack(decoder_config, self.shared, dtype=self.dtype)
+        self.decoder = FlaxLongT5Stack(
+            decoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
 
         self.lm_head = nn.Dense(
             self.config.vocab_size,
diff --git a/src/transformers/models/t5/modeling_flax_t5.py b/src/transformers/models/t5/modeling_flax_t5.py
index 06ad5105429767..918a605fc4813a 100644
--- a/src/transformers/models/t5/modeling_flax_t5.py
+++ b/src/transformers/models/t5/modeling_flax_t5.py
@@ -25,6 +25,7 @@
 import jax.numpy as jnp
 from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen import combine_masks, make_causal_mask
+from flax.linen import partitioning as nn_partitioning
 from flax.linen.attention import dot_product_attention_weights
 from flax.traverse_util import flatten_dict, unflatten_dict
 from jax.random import PRNGKey
@@ -53,6 +54,8 @@
 _CONFIG_FOR_DOC = "T5Config"
 _TOKENIZER_FOR_DOC = "T5Tokenizer"
 
+remat = nn_partitioning.remat
+
 
 # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
 def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
@@ -622,7 +625,6 @@ def __call__(
         encoder_attention_mask=None,
         encoder_decoder_position_bias=None,
         output_attentions=False,
-        return_dict=True,
         deterministic=True,
         init_cache=False,
     ):
@@ -642,13 +644,31 @@ def __call__(
 class FlaxT5BlockCollection(nn.Module):
     config: T5Config
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def setup(self):
         self.causal = self.config.causal
-        self.blocks = [
-            FlaxT5LayerCollection(self.config, has_relative_attention_bias=(i == 0), dtype=self.dtype, name=str(i))
-            for i in range(self.config.num_layers)
-        ]
+        if self.gradient_checkpointing:
+            FlaxT5CheckpointLayer = remat(FlaxT5LayerCollection, static_argnums=(6, 7, 8))
+            self.blocks = [
+                FlaxT5CheckpointLayer(
+                    self.config,
+                    has_relative_attention_bias=(i == 0),
+                    dtype=self.dtype,
+                    name=str(i),
+                )
+                for i in range(self.config.num_layers)
+            ]
+        else:
+            self.blocks = [
+                FlaxT5LayerCollection(
+                    self.config,
+                    has_relative_attention_bias=(i == 0),
+                    dtype=self.dtype,
+                    name=str(i),
+                )
+                for i in range(self.config.num_layers)
+            ]
 
     def __call__(
         self,
@@ -674,14 +694,14 @@ def __call__(
 
             layer_outputs = layer_module(
                 hidden_states,
-                attention_mask=attention_mask,
-                position_bias=position_bias,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                encoder_decoder_position_bias=encoder_decoder_position_bias,
-                output_attentions=output_attentions,
-                deterministic=deterministic,
-                init_cache=init_cache,
+                attention_mask,
+                position_bias,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                encoder_decoder_position_bias,
+                output_attentions,
+                deterministic,
+                init_cache,
             )
 
             hidden_states = layer_outputs[0]
@@ -711,11 +731,14 @@ class FlaxT5Stack(nn.Module):
     config: T5Config
     embed_tokens: nn.Embed
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def setup(self):
         self.causal = self.config.causal
 
-        self.block = FlaxT5BlockCollection(self.config, dtype=self.dtype)
+        self.block = FlaxT5BlockCollection(
+            self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
         self.final_layer_norm = FlaxT5LayerNorm(
             self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
         )
@@ -919,11 +942,19 @@ def __init__(
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
         _do_init: bool = True,
+        gradient_checkpointing: bool = False,
         **kwargs
     ):
-        module = self.module_class(config=config, dtype=dtype, **kwargs)
+        module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
+    def enable_gradient_checkpointing(self):
+        self._module = self.module_class(
+            config=self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=True,
+        )
+
     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
         # init input tensors
         input_ids = jnp.zeros(input_shape, dtype="i4")
@@ -1248,6 +1279,7 @@ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs
 class FlaxT5Module(nn.Module):
     config: T5Config
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def _get_encoder_module(self):
         return self.encoder
@@ -1264,12 +1296,22 @@ def setup(self):
 
         encoder_config = copy.deepcopy(self.config)
         encoder_config.causal = False
-        self.encoder = FlaxT5Stack(encoder_config, embed_tokens=self.shared, dtype=self.dtype)
+        self.encoder = FlaxT5Stack(
+            encoder_config,
+            embed_tokens=self.shared,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
 
         decoder_config = copy.deepcopy(self.config)
         decoder_config.causal = True
         decoder_config.num_layers = self.config.num_decoder_layers
-        self.decoder = FlaxT5Stack(decoder_config, embed_tokens=self.shared, dtype=self.dtype)
+        self.decoder = FlaxT5Stack(
+            decoder_config,
+            embed_tokens=self.shared,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
 
     def __call__(
         self,
@@ -1364,6 +1406,7 @@ class FlaxT5Model(FlaxT5PreTrainedModel):
 class FlaxT5EncoderModule(nn.Module):
     config: T5Config
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def setup(self):
         self.shared = nn.Embed(
@@ -1376,7 +1419,12 @@ def setup(self):
         encoder_config.is_decoder = False
         encoder_config.is_encoder_decoder = False
         encoder_config.causal = False
-        self.encoder = FlaxT5Stack(encoder_config, embed_tokens=self.shared, dtype=self.dtype)
+        self.encoder = FlaxT5Stack(
+            encoder_config,
+            embed_tokens=self.shared,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
 
     def __call__(
         self,
@@ -1384,7 +1432,7 @@ def __call__(
         attention_mask=None,
         output_attentions=False,
         output_hidden_states=False,
-        return_dict=True,
+        return_dict: bool = True,
         deterministic: bool = True,
     ):
 
@@ -1445,6 +1493,7 @@ def __call__(
 class FlaxT5ForConditionalGenerationModule(nn.Module):
     config: T5Config
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def _get_encoder_module(self):
         return self.encoder
@@ -1465,13 +1514,17 @@ def setup(self):
         encoder_config.causal = False
         encoder_config.use_cache = False
         encoder_config.is_encoder_decoder = False
-        self.encoder = FlaxT5Stack(encoder_config, self.shared, dtype=self.dtype)
+        self.encoder = FlaxT5Stack(
+            encoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
 
         decoder_config = copy.deepcopy(self.config)
         decoder_config.causal = True
         decoder_config.is_encoder_decoder = False
         decoder_config.num_layers = self.config.num_decoder_layers
-        self.decoder = FlaxT5Stack(decoder_config, self.shared, dtype=self.dtype)
+        self.decoder = FlaxT5Stack(
+            decoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
 
         self.lm_head = nn.Dense(
             self.config.vocab_size,
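
Usage sketch (assumes a transformers build with this patch applied; "t5-small" is purely an
illustrative checkpoint name). Calling `enable_gradient_checkpointing()` rebuilds the Flax module
with `gradient_checkpointing=True`, so each layer collection is wrapped in
`flax.linen.partitioning.remat` and its activations are recomputed during the backward pass instead
of being stored. This is also presumably why the block call above switches from keyword to
positional arguments: `remat`'s `static_argnums` refers to positional argument indices.

    # Minimal sketch: enable the remat-based gradient checkpointing added by this patch.
    from transformers import FlaxT5ForConditionalGeneration

    model = FlaxT5ForConditionalGeneration.from_pretrained("t5-small")
    # Re-instantiates the underlying module with gradient_checkpointing=True; training then
    # trades extra compute in the backward pass for a smaller activation memory footprint.
    model.enable_gradient_checkpointing()

In the example scripts, the same switch is exposed through the new `gradient_checkpointing`
training argument (passed on the command line, e.g. `--gradient_checkpointing`), which simply calls
`model.enable_gradient_checkpointing()` before training starts.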