Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flax Remat for LongT5 #17994

Merged
merged 21 commits into from
Aug 14, 2022
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions examples/flax/language-modeling/run_mlm_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@ class TrainingArguments:
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
)
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
gradient_checkpointing: bool = field(
default=False,
metadata={
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)

def __post_init__(self):
if self.output_dir is not None:
Expand Down Expand Up @@ -635,6 +641,9 @@ def group_texts(examples):
use_auth_token=True if model_args.use_auth_token else None,
)

if training_args.gradient_checkpointing:
model.enable_gradient_checkpointing()

# Store some constant
num_epochs = int(training_args.num_train_epochs)
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
Expand Down
9 changes: 9 additions & 0 deletions examples/flax/summarization/run_summarization_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ class TrainingArguments:
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
)
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
gradient_checkpointing: bool = field(
default=False,
metadata={
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)

def __post_init__(self):
if self.output_dir is not None:
Expand Down Expand Up @@ -532,6 +538,9 @@ def main():
use_auth_token=True if model_args.use_auth_token else None,
)

if training_args.gradient_checkpointing:
model.enable_gradient_checkpointing()

if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

Expand Down
29 changes: 27 additions & 2 deletions src/transformers/models/longt5/modeling_flax_longt5.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey
Expand Down Expand Up @@ -53,6 +54,7 @@
_CONFIG_FOR_DOC = "LongT5Config"
_TOKENIZER_FOR_DOC = "T5Tokenizer"

remat = nn_partitioning.remat

# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
Expand Down Expand Up @@ -119,7 +121,8 @@ def _make_3block_relative_position_ids(block_len: int) -> jnp.ndarray:


def _mask_local_attention_mask(local_attention_mask: np.ndarray, block_len: int) -> jnp.ndarray:
"""Mask local attention mask to enforce that tokens are not allowed to attend tokens farther than ``local_radius."""
"""Mask local attention mask to enforce that tokens are not allowed to attend tokens farther than ``local_radius.
"""
relative_position_ids = _make_3block_relative_position_ids(block_len)
locality_mask = jnp.abs(relative_position_ids) < block_len
locality_mask = locality_mask[None, None, :, :]
Expand Down Expand Up @@ -777,6 +780,7 @@ def __call__(

# create dropout rng
dropout_rng = None
# breakpoint()
if not deterministic and self.dropout > 0.0:
dropout_rng = self.make_rng("dropout")

Expand Down Expand Up @@ -1124,6 +1128,7 @@ def __call__(
**kwargs: Any, # to accept init_cache kwargs
):
normed_hidden_states = self.layer_norm(hidden_states)
# breakpoint()
attention_output = self.LocalSelfAttention(
normed_hidden_states,
attention_mask=attention_mask,
Expand Down Expand Up @@ -1658,11 +1663,25 @@ def __init__(
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
gradient_checkpointing: bool = False,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
module = self.module_class(
config=config,
dtype=dtype,
gradient_checkpointing=gradient_checkpointing,
**kwargs,
)
# breakpoint()
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

def enable_gradient_checkpointing(self):
self._module = self.module_class(
config=self.config,
dtype=self.dtype,
gradient_checkpointing=True,
)

def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
Expand All @@ -1673,11 +1692,17 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: Froz

params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
# breakpoint()

encoder_hidden_states = None
encoder_attention_mask = None

random_params = self.module.init(
rngs,
input_ids,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
decoder_input_ids,
decoder_attention_mask,
)["params"]
Expand Down
Loading