From 61400e1ec7898892e77314dd819c1a1a52bd5268 Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Wed, 7 Jul 2021 08:50:38 +0200 Subject: [PATCH] [Flax] Add FlaxMBart (#12236) * Copy BART to MBart and rename some stuff * Add copy statements pointing to FlaxBart * Update/add some common files * Update shift_tokens_rigth + fix imports * Fix shift_tokens_right method according to MBart implementation * Update shift_tokens_right in tests accordingly * Fix the import issue and update docs file * make style quality * Do some minor changes according to patil-suraj suggestions * Change the order of normalization layer and attention * Add some copu statementes * Update generate method and add integration test for mBart * Make a few updates after a review Besides, add `lang_code_to_id` to MBartTokenizeFast * fix-copies; make style quality * Apply suggestions from code review * Apply suggestions from code review * Apply suggestions from code review * fix output type, style * add copied from * resolve conflicts Co-authored-by: Suraj Patil --- docs/source/index.rst | 2 +- docs/source/model_doc/mbart.rst | 28 + src/transformers/__init__.py | 16 + .../models/auto/modeling_flax_auto.py | 12 + src/transformers/models/mbart/__init__.py | 19 + .../models/mbart/modeling_flax_mbart.py | 1748 +++++++++++++++++ .../models/mbart/tokenization_mbart_fast.py | 3 + src/transformers/utils/dummy_flax_objects.py | 45 + tests/test_modeling_flax_mbart.py | 464 +++++ 9 files changed, 2336 insertions(+), 1 deletion(-) create mode 100644 src/transformers/models/mbart/modeling_flax_mbart.py create mode 100644 tests/test_modeling_flax_mbart.py diff --git a/docs/source/index.rst b/docs/source/index.rst index 31088e1ff1a..9a6dd5d1e5b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -421,7 +421,7 @@ Flax), PyTorch, and/or TensorFlow. +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | XLNet | ✅ | ✅ | ✅ | ✅ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ -| mBART | ✅ | ✅ | ✅ | ✅ | ❌ | +| mBART | ✅ | ✅ | ✅ | ✅ | ✅ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | mT5 | ✅ | ✅ | ✅ | ✅ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ diff --git a/docs/source/model_doc/mbart.rst b/docs/source/model_doc/mbart.rst index a94cd385b10..dbe1d4e435d 100644 --- a/docs/source/model_doc/mbart.rst +++ b/docs/source/model_doc/mbart.rst @@ -240,3 +240,31 @@ TFMBartForConditionalGeneration .. autoclass:: transformers.TFMBartForConditionalGeneration :members: call + + +FlaxMBartModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxMBartModel + :members: __call__, encode, decode + + +FlaxMBartForConditionalGeneration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxMBartForConditionalGeneration + :members: __call__, encode, decode + + +FlaxMBartForSequenceClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autoclass:: transformers.FlaxMBartForSequenceClassification + :members: __call__, encode, decode + + +FlaxMBartForQuestionAnswering +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxMBartForQuestionAnswering + :members: __call__, encode, decode diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 48774de079a..f539484a2c2 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1633,6 +1633,15 @@ _import_structure["models.gpt_neo"].extend( ["FlaxGPTNeoForCausalLM", "FlaxGPTNeoModel", "FlaxGPTNeoPreTrainedModel"] ) + _import_structure["models.mbart"].extend( + [ + "FlaxMBartForConditionalGeneration", + "FlaxMBartForQuestionAnswering", + "FlaxMBartForSequenceClassification", + "FlaxMBartModel", + "FlaxMBartPreTrainedModel", + ] + ) _import_structure["models.roberta"].extend( [ "FlaxRobertaForMaskedLM", @@ -3019,6 +3028,13 @@ ) from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel from .models.gpt_neo import FlaxGPTNeoForCausalLM, FlaxGPTNeoModel, FlaxGPTNeoPreTrainedModel + from .models.mbart import ( + FlaxMBartForConditionalGeneration, + FlaxMBartForQuestionAnswering, + FlaxMBartForSequenceClassification, + FlaxMBartModel, + FlaxMBartPreTrainedModel, + ) from .models.roberta import ( FlaxRobertaForMaskedLM, FlaxRobertaForMultipleChoice, diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index 72a1c2ef6d8..8b9584fcc91 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -55,6 +55,12 @@ ) from ..gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model from ..gpt_neo.modeling_flax_gpt_neo import FlaxGPTNeoForCausalLM, FlaxGPTNeoModel +from ..mbart.modeling_flax_mbart import ( + FlaxMBartForConditionalGeneration, + FlaxMBartForQuestionAnswering, + FlaxMBartForSequenceClassification, + FlaxMBartModel, +) from ..roberta.modeling_flax_roberta import ( FlaxRobertaForMaskedLM, FlaxRobertaForMultipleChoice, @@ -75,6 +81,7 @@ ElectraConfig, GPT2Config, GPTNeoConfig, + MBartConfig, RobertaConfig, T5Config, ViTConfig, @@ -97,6 +104,7 @@ (ElectraConfig, FlaxElectraModel), (CLIPConfig, FlaxCLIPModel), (ViTConfig, FlaxViTModel), + (MBartConfig, FlaxMBartModel), (T5Config, FlaxT5Model), (Wav2Vec2Config, FlaxWav2Vec2Model), ] @@ -110,6 +118,7 @@ (BigBirdConfig, FlaxBigBirdForPreTraining), (BartConfig, FlaxBartForConditionalGeneration), (ElectraConfig, FlaxElectraForPreTraining), + (MBartConfig, FlaxMBartForConditionalGeneration), (T5Config, FlaxT5ForConditionalGeneration), (Wav2Vec2Config, FlaxWav2Vec2ForPreTraining), ] @@ -123,6 +132,7 @@ (BigBirdConfig, FlaxBigBirdForMaskedLM), (BartConfig, FlaxBartForConditionalGeneration), (ElectraConfig, FlaxElectraForMaskedLM), + (MBartConfig, FlaxMBartForConditionalGeneration), ] ) @@ -157,6 +167,7 @@ (BigBirdConfig, FlaxBigBirdForSequenceClassification), (BartConfig, FlaxBartForSequenceClassification), (ElectraConfig, FlaxElectraForSequenceClassification), + (MBartConfig, FlaxMBartForSequenceClassification), ] ) @@ -168,6 +179,7 @@ (BigBirdConfig, FlaxBigBirdForQuestionAnswering), (BartConfig, FlaxBartForQuestionAnswering), (ElectraConfig, FlaxElectraForQuestionAnswering), + (MBartConfig, FlaxMBartForQuestionAnswering), ] ) diff --git a/src/transformers/models/mbart/__init__.py b/src/transformers/models/mbart/__init__.py index 
414d33a9fa7..1c7c41704ad 100644 --- a/src/transformers/models/mbart/__init__.py +++ b/src/transformers/models/mbart/__init__.py @@ -19,6 +19,7 @@ from ...file_utils import ( _BaseLazyModule, + is_flax_available, is_sentencepiece_available, is_tf_available, is_tokenizers_available, @@ -56,6 +57,15 @@ "TFMBartPreTrainedModel", ] +if is_flax_available(): + _import_structure["modeling_flax_mbart"] = [ + "FlaxMBartForConditionalGeneration", + "FlaxMBartForQuestionAnswering", + "FlaxMBartForSequenceClassification", + "FlaxMBartModel", + "FlaxMBartPreTrainedModel", + ] + if TYPE_CHECKING: from .configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig @@ -82,6 +92,15 @@ if is_tf_available(): from .modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel, TFMBartPreTrainedModel + if is_flax_available(): + from .modeling_flax_mbart import ( + FlaxMBartForConditionalGeneration, + FlaxMBartForQuestionAnswering, + FlaxMBartForSequenceClassification, + FlaxMBartModel, + FlaxMBartPreTrainedModel, + ) + else: import importlib import os diff --git a/src/transformers/models/mbart/modeling_flax_mbart.py b/src/transformers/models/mbart/modeling_flax_mbart.py new file mode 100644 index 00000000000..fd8e64ca0a7 --- /dev/null +++ b/src/transformers/models/mbart/modeling_flax_mbart.py @@ -0,0 +1,1748 @@ +# coding=utf-8 +# Copyright 2021, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Flax MBart model. """ + +import math +import random +from functools import partial +from typing import Callable, Optional, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict, unfreeze +from flax.linen import combine_masks, make_causal_mask +from flax.linen.attention import dot_product_attention_weights +from jax import lax +from jax.random import PRNGKey + +from ...file_utils import add_start_docstrings, replace_return_docstrings +from ...modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxBaseModelOutputWithPastAndCrossAttentions, + FlaxCausalLMOutputWithCrossAttentions, + FlaxSeq2SeqLMOutput, + FlaxSeq2SeqModelOutput, + FlaxSeq2SeqQuestionAnsweringModelOutput, + FlaxSeq2SeqSequenceClassifierOutput, +) +from ...modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_call_sample_docstring, + append_replace_return_docstrings, + overwrite_call_docstring, +) +from ...utils import logging +from .configuration_mbart import MBartConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "facebook/mbart-large-cc25" +_CONFIG_FOR_DOC = "MBartConfig" +_TOKENIZER_FOR_DOC = "MBartTokenizer" + + +MBART_START_DOCSTRING = r""" + This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the + generic methods the library implements for all its model (such as downloading or saving, resizing the input + embeddings, pruning heads etc.) 
+ + This model is also a Flax Linen `flax.nn.Module + `__ subclass. Use it as a regular Flax + Module and refer to the Flax documentation for all matter related to general usage and behavior. + + Finally, this model supports inherent JAX features such as: + + - `Just-In-Time (JIT) compilation `__ + - `Automatic Differentiation `__ + - `Vectorization `__ + - `Parallelization `__ + + Parameters: + config (:class:`~transformers.MBartConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the + model weights. +""" + +MBART_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using :class:`~transformers.MBartTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + decoder_input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~transformers.MBartTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are decoder input IDs? <../glossary.html#decoder-input-ids>`__ + + For translation and summarization training, :obj:`decoder_input_ids` should be provided. If no + :obj:`decoder_input_ids` is provided, the model will create this tensor by shifting the :obj:`input_ids` to + the right for denoising pre-training following the paper. + decoder_attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will + also be used by default. + + If you want to change padding behavior, you should modify to your needs. See diagram 1 in `the paper + `__ for more information on the default strategy. + position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + decoder_position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range ``[0, config.max_position_embeddings - 1]``. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. 
+ return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + +MBART_ENCODE_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using :class:`~transformers.MBartTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + +MBART_DECODE_INPUTS_DOCSTRING = r""" + Args: + decoder_input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~transformers.MBartTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are decoder input IDs? <../glossary.html#decoder-input-ids>`__ + + For translation and summarization training, :obj:`decoder_input_ids` should be provided. If no + :obj:`decoder_input_ids` is provided, the model will create this tensor by shifting the :obj:`input_ids` to + the right for denoising pre-training following the paper. + encoder_outputs (:obj:`tuple(tuple(jnp.ndarray)`): + Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: + :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, + `optional`) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the + cross-attention of the decoder. + encoder_attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + decoder_attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will + also be used by default. + + If you want to change padding behavior, you should modify to your needs. 
See diagram 1 in `the paper + `__ for more information on the default strategy. + decoder_position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range ``[0, config.max_position_embeddings - 1]``. + past_key_values (:obj:`Dict[str, np.ndarray]`, `optional`, returned by ``init_cache`` or when passing previous ``past_key_values``): + Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast + auto-regressive decoding. Pre-computed key and value hidden-states are of shape `[batch_size, max_length]`. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int) -> jnp.ndarray: + """ + Shift input ids one token to the right, and wrap the last non pad token (the token) Note that MBart does not + have a single `decoder_start_token_id` in contrast to other Bart-like models. + """ + prev_output_tokens = jnp.array(input_ids).clone() + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + + # replace possible -100 values in labels by `pad_token_id` + prev_output_tokens = jnp.where(prev_output_tokens == -100, pad_token_id, input_ids) + index_of_eos = (jnp.where(prev_output_tokens != pad_token_id, 1, 0).sum(axis=-1) - 1).reshape(-1, 1) + decoder_start_tokens = jnp.array( + [prev_output_tokens[i, eos_idx] for i, eos_idx in enumerate(index_of_eos)] + ).squeeze() + # for loop basically does jax-compatible version of prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone() + for i in range(prev_output_tokens.shape[1], 0, -1): + prev_output_tokens = jax.ops.index_update(prev_output_tokens, (..., i), prev_output_tokens[:, i - 1]) + prev_output_tokens = jax.ops.index_update(prev_output_tokens, (..., 0), decoder_start_tokens) + + return prev_output_tokens + + +# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->MBart +class FlaxMBartAttention(nn.Module): + config: MBartConfig + embed_dim: int + num_heads: int + dropout: float = 0.0 + causal: bool = False + bias: bool = True + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self) -> None: + self.head_dim = self.embed_dim // self.num_heads + assert ( + self.head_dim * self.num_heads == self.embed_dim + ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})." 
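+        # Note: the query/key/value and output projections below are all embed_dim -> embed_dim
+        # dense layers that differ only in their learned parameters, so a single `partial`
+        # fixes the shared width, bias flag, dtype and weight init for all four of them.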
+ + dense = partial( + nn.Dense, + self.embed_dim, + use_bias=self.bias, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + ) + + self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() + self.out_proj = dense() + + self.dropout_layer = nn.Dropout(rate=self.dropout) + + if self.causal: + self.causal_mask = make_causal_mask( + jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" + ) + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) + + @nn.compact + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slighly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. 
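+            # Illustration: with max_length=6, cur_index=2 and one new query position, the mask
+            # over the key axis is [True, True, True, False, False, False] — the query may attend
+            # to the two cached positions and to itself, but not to the still-empty cache slots.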
+ pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def __call__( + self, + hidden_states: jnp.ndarray, + key_value_states: Optional[jnp.ndarray] = None, + attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + ) -> Tuple[jnp.ndarray]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + batch_size = hidden_states.shape[0] + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + if is_cross_attention: + # cross_attentions + key_states = self.k_proj(key_value_states) + value_states = self.v_proj(key_value_states) + else: + # self_attention + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # handle cache prepare causal attention mask + if self.causal: + query_length, key_length = query_states.shape[1], key_states.shape[1] + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) + ) + else: + causal_mask = self.causal_mask[:, :, :query_length, :key_length] + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + + # combine masks if needed + if attention_mask is not None and self.causal: + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + elif self.causal: + attention_mask = causal_mask + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. + if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) + + # Convert the boolean attention mask to an attention bias. 
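+        # Illustration: a boolean/int mask of [1, 1, 0] becomes a bias of [0.0, 0.0, -inf];
+        # the bias is added to the attention logits before the softmax, so masked positions
+        # receive (numerically) zero attention weight.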
+ if attention_mask is not None: + # attention mask in the form of attention bias + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype), + ) + else: + attention_bias = None + + dropout_rng = None + if not deterministic and self.dropout > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.dropout, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = self._merge_heads(attn_output) + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class FlaxMBartEncoderLayer(nn.Module): + config: MBartConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.embed_dim = self.config.d_model + self.self_attn = FlaxMBartAttention( + config=self.config, + embed_dim=self.embed_dim, + num_heads=self.config.encoder_attention_heads, + dropout=self.config.attention_dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + self.activation_fn = ACT2FN[self.config.activation_function] + self.acticvation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) + self.fc1 = nn.Dense( + self.config.encoder_ffn_dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + ) + self.fc2 = nn.Dense( + self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) + ) + self.final_layer_norm = nn.LayerNorm(dtype=self.dtype) + + def __call__( + self, + hidden_states: jnp.ndarray, + attention_mask: jnp.ndarray, + output_attentions: bool = True, + deterministic: bool = True, + ) -> Tuple[jnp.ndarray]: + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.acticvation_dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartEncoderLayerCollection with Bart->MBart +class FlaxMBartEncoderLayerCollection(nn.Module): + config: MBartConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers = [ + FlaxMBartEncoderLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.encoder_layers) + ] + self.layerdrop = self.config.encoder_layerdrop + + def __call__( + self, + hidden_states, + attention_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + 
all_hidden_states = () if output_hidden_states else None + + for encoder_layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if not deterministic and (dropout_probability < self.layerdrop): # skip the layer + layer_outputs = (None, None) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions, + deterministic, + ) + hidden_states = layer_outputs[0] + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = (hidden_states, all_hidden_states, all_attentions) + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +class FlaxMBartDecoderLayer(nn.Module): + config: MBartConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.embed_dim = self.config.d_model + self.self_attn = FlaxMBartAttention( + config=self.config, + embed_dim=self.embed_dim, + num_heads=self.config.decoder_attention_heads, + dropout=self.config.attention_dropout, + causal=True, + ) + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + self.activation_fn = ACT2FN[self.config.activation_function] + self.acticvation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) + + self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) + self.encoder_attn = FlaxMBartAttention( + config=self.config, + embed_dim=self.embed_dim, + num_heads=self.config.decoder_attention_heads, + dropout=self.config.attention_dropout, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) + self.fc1 = nn.Dense( + self.config.encoder_ffn_dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + ) + self.fc2 = nn.Dense( + self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) + ) + self.final_layer_norm = nn.LayerNorm(dtype=self.dtype) + + def __call__( + self, + hidden_states: jnp.ndarray, + attention_mask: jnp.ndarray, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + output_attentions: bool = True, + deterministic: bool = True, + ) -> Tuple[jnp.ndarray]: + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache + ) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + hidden_states = self.encoder_attn_layer_norm(hidden_states) + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + ) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = 
self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.acticvation_dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderLayerCollection with Bart->MBart +class FlaxMBartDecoderLayerCollection(nn.Module): + config: MBartConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers = [ + FlaxMBartDecoderLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.decoder_layers) + ] + self.layerdrop = self.config.decoder_layerdrop + + def __call__( + self, + hidden_states, + attention_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if not deterministic and (dropout_probability < self.layerdrop): + layer_outputs = (None, None, None) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + output_attentions=output_attentions, + deterministic=deterministic, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions] + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartClassificationHead with Bart->MBart +class FlaxMBartClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + config: MBartConfig + inner_dim: int + num_classes: int + pooler_dropout: float + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.dense = nn.Dense( + self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) + ) + self.dropout = nn.Dropout(rate=self.pooler_dropout) + self.out_proj = nn.Dense( + self.num_classes, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + ) + + def __call__(self, hidden_states: jnp.ndarray, deterministic: bool): + hidden_states = self.dropout(hidden_states, 
deterministic=deterministic) + hidden_states = self.dense(hidden_states) + hidden_states = jnp.tanh(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class FlaxMBartEncoder(nn.Module): + config: MBartConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + embed_tokens: Optional[nn.Embed] = None + + def setup(self): + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + + embed_dim = self.config.d_model + self.padding_idx = self.config.pad_token_id + self.max_source_positions = self.config.max_position_embeddings + self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0 + + if self.embed_tokens is None: + self.embed_tokens = nn.Embed( + self.config.vocab_size, + embed_dim, + embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + dtype=self.dtype, + ) + + # MBart is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models don't have this hack + self.offset = 2 + self.embed_positions = nn.Embed( + self.config.max_position_embeddings + self.offset, + embed_dim, + embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + dtype=self.dtype, + ) + self.layers = FlaxMBartEncoderLayerCollection(self.config, self.dtype) + self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype) + self.layer_norm = nn.LayerNorm(dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + input_shape = input_ids.shape + input_ids = input_ids.reshape(-1, input_shape[-1]) + + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + embed_pos = self.embed_positions(position_ids + self.offset) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + + outputs = self.layers( + hidden_states, + attention_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_states = outputs[0] + last_hidden_states = self.layer_norm(last_hidden_states) + + if not return_dict: + return (last_hidden_states,) + outputs[1:] + + return FlaxBaseModelOutput( + last_hidden_state=last_hidden_states, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class FlaxMBartDecoder(nn.Module): + config: MBartConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + embed_tokens: Optional[nn.Embed] = None + + def setup(self): + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + + embed_dim = self.config.d_model + self.padding_idx = self.config.pad_token_id + self.max_target_positions = self.config.max_position_embeddings + self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0 + + if self.embed_tokens is None: + self.embed_tokens = nn.Embed( + self.config.vocab_size, + embed_dim, + embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + dtype=self.dtype, + ) + + # MBart is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. 
Other models don't have this hack + self.offset = 2 + self.embed_positions = nn.Embed( + self.config.max_position_embeddings + self.offset, + embed_dim, + embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + dtype=self.dtype, + ) + + self.layers = FlaxMBartDecoderLayerCollection(self.config, self.dtype) + self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype) + self.layer_norm = nn.LayerNorm(dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + input_shape = input_ids.shape + input_ids = input_ids.reshape(-1, input_shape[-1]) + + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + # embed positions + positions = self.embed_positions(position_ids + self.offset) + + hidden_states = inputs_embeds + positions + hidden_states = self.layernorm_embedding(hidden_states) + + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + + outputs = self.layers( + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_states = outputs[0] + last_hidden_states = self.layer_norm(last_hidden_states) + + if not return_dict: + return (last_hidden_states,) + outputs[1:] + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=last_hidden_states, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartModule with Bart->MBart +class FlaxMBartModule(nn.Module): + config: MBartConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.shared = nn.Embed( + self.config.vocab_size, + self.config.d_model, + embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + dtype=self.dtype, + ) + + self.encoder = FlaxMBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared) + self.decoder = FlaxMBartDecoder(self.config, dtype=self.dtype, embed_tokens=self.shared) + + def _get_encoder_module(self): + return self.encoder + + def _get_decoder_module(self): + return self.decoder + + def __call__( + self, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + position_ids, + decoder_position_ids, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + if not return_dict: + return 
decoder_outputs + encoder_outputs + + return FlaxSeq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +class FlaxMBartPreTrainedModel(FlaxPreTrainedModel): + config_class = MBartConfig + base_model_prefix: str = "model" + module_class: nn.Module = None + + def __init__( + self, + config: MBartConfig, + input_shape: Tuple[int] = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + **kwargs + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + # make sure initialization pass will work for FlaxMBartForSequenceClassificationModule + input_ids = jax.ops.index_update(input_ids, (..., -1), self.config.eos_token_id) + attention_mask = jnp.ones_like(input_ids) + decoder_input_ids = input_ids + decoder_attention_mask = jnp.ones_like(input_ids) + + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + return self.module.init( + rngs, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + position_ids, + decoder_position_ids, + )["params"] + + # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartPreTrainedModel.init_cache with Bart->MBart + def init_cache(self, batch_size, max_length, encoder_outputs): + r""" + Args: + batch_size (:obj:`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (:obj:`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + encoder_outputs (:obj:`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): + ``encoder_outputs`` consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, + `optional`: :obj:`attentions`). :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, + hidden_size)`, `optional`) is a sequence of hidden-states at the output of the last layer of the + encoder. Used in the cross-attention of the decoder. 
+ """ + # init input variables to retrieve cache + decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + decoder_position_ids = jnp.broadcast_to( + jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape + ) + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + **kwargs, + ) + + init_variables = self.module.init( + jax.random.PRNGKey(0), + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_hidden_states=encoder_outputs[0], + init_cache=True, + method=_decoder_forward, # we only need to call the decoder to init the cache + ) + return unfreeze(init_variables["cache"]) + + @add_start_docstrings(MBART_ENCODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=MBartConfig) + def encode( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example:: + + >>> from transformers import MBartTokenizer, FlaxMBartForConditionalGeneration + + >>> model = FlaxMBartForConditionalGeneration.from_pretrained('facebook/mbart-large-cc25') + >>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-cc25') + + >>> text = "My friends are cool but they eat too many carbs." 
+ >>> inputs = tokenizer(text, max_length=1024, return_tensors='jax') + >>> encoder_outputs = model.encode(**inputs) + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + if position_ids is None: + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs): + encode_module = module._get_encoder_module() + return encode_module(input_ids, attention_mask, position_ids, **kwargs) + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + method=_encoder_forward, + ) + + @add_start_docstrings(MBART_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=MBartConfig) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + past_key_values: dict = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example:: + + >>> from transformers import MBartTokenizer, FlaxMBartForConditionalGeneration + + >>> model = FlaxMBartForConditionalGeneration.from_pretrained('facebook/mbart-large-cc25') + >>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-cc25') + + >>> text = "My friends are cool but they eat too many carbs." 
+ >>> inputs = tokenizer(text, max_length=1024, return_tensors='jax') + >>> encoder_outputs = model.encode(**inputs) + + >>> decoder_start_token_id = model.config.decoder_start_token_id + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> last_decoder_hidden_states = outputs.last_hidden_state + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + if decoder_position_ids is None: + if past_key_values is not None: + raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") + + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxMBartAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + **kwargs, + ) + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past = outputs + outputs["past_key_values"] = unfreeze(past["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past = outputs + outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + + return outputs + + def __call__( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + decoder_input_ids: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, 
+ train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # prepare encoder inputs + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + if position_ids is None: + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + # prepare decoder inputs + if decoder_input_ids is None: + decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id) + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + if decoder_position_ids is None: + batch_size, sequence_length = decoder_input_ids.shape + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + ) + + +@add_start_docstrings( + "The bare MBart Model transformer outputting raw hidden-states without any specific head on top.", + MBART_START_DOCSTRING, +) +class FlaxMBartModel(FlaxMBartPreTrainedModel): + config: MBartConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + module_class = FlaxMBartModule + + +append_call_sample_docstring( + FlaxMBartModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC +) + + +# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartForConditionalGenerationModule with Bart->MBart +class FlaxMBartForConditionalGenerationModule(nn.Module): + config: MBartConfig + dtype: jnp.dtype = jnp.float32 + bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros + + def setup(self): + self.model = FlaxMBartModule(config=self.config, dtype=self.dtype) + self.lm_head = nn.Dense( + self.model.shared.num_embeddings, + use_bias=False, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype), + ) + self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings)) + + def _get_encoder_module(self): + return self.model.encoder + + def _get_decoder_module(self): + return self.model.decoder + + def __call__( + self, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + position_ids, + decoder_position_ids, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + 
position_ids=position_ids,
+            decoder_position_ids=decoder_position_ids,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            deterministic=deterministic,
+        )
+
+        hidden_states = outputs[0]
+
+        if self.config.tie_word_embeddings:
+            shared_embedding = self.model.variables["params"]["shared"]["embedding"]
+            lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
+        else:
+            lm_logits = self.lm_head(hidden_states)
+
+        lm_logits += self.final_logits_bias
+
+        if not return_dict:
+            output = (lm_logits,) + outputs[1:]
+            return output
+
+        return FlaxSeq2SeqLMOutput(
+            logits=lm_logits,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+        )
+
+
+@add_start_docstrings(
+    "The MBart Model with a language modeling head. Can be used for summarization.", MBART_START_DOCSTRING
+)
+class FlaxMBartForConditionalGeneration(FlaxMBartPreTrainedModel):
+    module_class = FlaxMBartForConditionalGenerationModule
+    dtype: jnp.dtype = jnp.float32
+
+    @add_start_docstrings(MBART_DECODE_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=MBartConfig)
+    def decode(
+        self,
+        decoder_input_ids,
+        encoder_outputs,
+        encoder_attention_mask: Optional[jnp.ndarray] = None,
+        decoder_attention_mask: Optional[jnp.ndarray] = None,
+        decoder_position_ids: Optional[jnp.ndarray] = None,
+        past_key_values: dict = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        deterministic: bool = True,
+        params: dict = None,
+        dropout_rng: PRNGKey = None,
+    ):
+        r"""
+        Returns:
+
+        Example::
+
+            >>> from transformers import MBartTokenizer, FlaxMBartForConditionalGeneration
+
+            >>> model = FlaxMBartForConditionalGeneration.from_pretrained('facebook/mbart-large-cc25')
+            >>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-cc25')
+
+            >>> text = "My friends are cool but they eat too many carbs."
+ >>> inputs = tokenizer(text, max_length=1024, return_tensors='jax') + >>> encoder_outputs = model.encode(**inputs) + + >>> decoder_start_token_id = model.config.decoder_start_token_id + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> logits = outputs.logits + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + if decoder_position_ids is None: + if past_key_values is not None: + raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") + + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxMBartAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + outputs = decoder_module( + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + **kwargs, + ) + hidden_states = outputs[0] + + if self.config.tie_word_embeddings: + shared_embedding = module.model.variables["params"]["shared"]["embedding"] + lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) + else: + lm_logits = module.lm_head(hidden_states) + + lm_logits += module.final_logits_bias + return lm_logits, outputs + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + if past_key_values is None: + lm_logits, decoder_outputs = outputs + else: + (lm_logits, decoder_outputs), past = outputs + + if return_dict: + outputs = FlaxCausalLMOutputWithCrossAttentions( + logits=lm_logits, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + ) + else: + outputs = (lm_logits,) + decoder_outputs[1:] + + # add updated cache to model 
output
+        if past_key_values is not None and return_dict:
+            outputs["past_key_values"] = unfreeze(past["cache"])
+            return outputs
+        elif past_key_values is not None and not return_dict:
+            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
+
+        return outputs
+
+    def prepare_inputs_for_generation(
+        self,
+        decoder_input_ids,
+        max_length,
+        attention_mask: Optional[jnp.DeviceArray] = None,
+        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        encoder_outputs=None,
+        **kwargs
+    ):
+        # initializing the cache
+        batch_size, seq_length = decoder_input_ids.shape
+
+        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
+        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
+        # But since the decoder uses a causal mask, those positions are masked anyways.
+        # Thus we can create a single static attention_mask here, which is more efficient for compilation
+        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+        if decoder_attention_mask is not None:
+            position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
+            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
+        else:
+            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
+
+        return {
+            "past_key_values": past_key_values,
+            "encoder_outputs": encoder_outputs,
+            "encoder_attention_mask": attention_mask,
+            "decoder_attention_mask": extended_attention_mask,
+            "decoder_position_ids": position_ids,
+        }
+
+    def update_inputs_for_generation(self, model_outputs, model_kwargs):
+        model_kwargs["past_key_values"] = model_outputs.past_key_values
+        model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
+        return model_kwargs
+
+
+FLAX_MBART_CONDITIONAL_GENERATION_DOCSTRING = """
+    Returns:
+
+    Summarization example::
+
+        >>> from transformers import MBartTokenizer, FlaxMBartForConditionalGeneration
+
+        >>> model = FlaxMBartForConditionalGeneration.from_pretrained('facebook/mbart-large-cc25')
+        >>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-cc25')
+
+        >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
+        >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='jax')
+
+        >>> # Generate Summary
+        >>> summary_ids = model.generate(inputs['input_ids'], decoder_start_token_id=tokenizer.lang_code_to_id['en_XX']).sequences
+        >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
+
+    Mask filling example::
+
+        >>> from transformers import MBartTokenizer, FlaxMBartForConditionalGeneration
+        >>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-cc25')
+        >>> TXT = "My friends are <mask> but they eat too many carbs."
+
+        >>> model = FlaxMBartForConditionalGeneration.from_pretrained('facebook/mbart-large-cc25')
+        >>> input_ids = tokenizer([TXT], return_tensors='jax')['input_ids']
+        >>> logits = model(input_ids).logits
+
+        >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero()[0].item()
+        >>> probs = jax.nn.softmax(logits[0, masked_index], axis=0)
+        >>> values, predictions = jax.lax.top_k(probs, k=5)
+
+        >>> tokenizer.decode(predictions).split()
+"""
+
+overwrite_call_docstring(
+    FlaxMBartForConditionalGeneration, MBART_INPUTS_DOCSTRING + FLAX_MBART_CONDITIONAL_GENERATION_DOCSTRING
+)
+append_replace_return_docstrings(
+    FlaxMBartForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
+)
+
+
+# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartForSequenceClassificationModule with Bart->MBart
+class FlaxMBartForSequenceClassificationModule(nn.Module):
+    config: MBartConfig
+    dtype: jnp.dtype = jnp.float32
+    num_labels: Optional[int] = None
+
+    def setup(self):
+        self.model = FlaxMBartModule(config=self.config, dtype=self.dtype)
+        self.classification_head = FlaxMBartClassificationHead(
+            config=self.config,
+            inner_dim=self.config.d_model,
+            num_classes=self.num_labels if self.num_labels is not None else self.config.num_labels,
+            pooler_dropout=self.config.classifier_dropout,
+        )
+
+    def _get_encoder_module(self):
+        return self.model.encoder
+
+    def _get_decoder_module(self):
+        return self.model.decoder
+
+    def __call__(
+        self,
+        input_ids,
+        attention_mask,
+        decoder_input_ids,
+        decoder_attention_mask,
+        position_ids,
+        decoder_position_ids,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+        deterministic: bool = True,
+    ):
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            position_ids=position_ids,
+            decoder_position_ids=decoder_position_ids,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            deterministic=deterministic,
+        )
+
+        hidden_states = outputs[0]  # last hidden state
+
+        eos_mask = jnp.where(input_ids == self.config.eos_token_id, 1, 0)
+
+        # The first condition is necessary to overcome jax._src.errors.ConcretizationTypeError during JIT compilation
+        if type(eos_mask) != jax.interpreters.partial_eval.DynamicJaxprTracer:
+            if len(jnp.unique(eos_mask.sum(1))) > 1:
+                raise ValueError("All examples must have the same number of <eos> tokens.")
+
+            if any(eos_mask.sum(1) == 0):
+                raise ValueError("There are missing <eos> tokens in input_ids")
+
+            # Ensure to keep 1 only for the last <eos> token for each example
+            eos_mask_noised = eos_mask + jnp.arange(eos_mask.shape[1]) * 1e-6
+            eos_mask = jnp.where(eos_mask_noised == eos_mask_noised.max(1).reshape(-1, 1), 1, 0)
+
+        sentence_representation = jnp.einsum("ijk, ij -> ijk", hidden_states, eos_mask).sum(1)
+        logits = self.classification_head(sentence_representation, deterministic=deterministic)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return output
+
+        return FlaxSeq2SeqSequenceClassifierOutput(
+            logits=logits,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    MBart
model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE + tasks. + """, + MBART_START_DOCSTRING, +) +class FlaxMBartForSequenceClassification(FlaxMBartPreTrainedModel): + module_class = FlaxMBartForSequenceClassificationModule + dtype = jnp.float32 + + +append_call_sample_docstring( + FlaxMBartForSequenceClassification, + _TOKENIZER_FOR_DOC, + _CHECKPOINT_FOR_DOC, + FlaxSeq2SeqSequenceClassifierOutput, + _CONFIG_FOR_DOC, +) + + +# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartForQuestionAnsweringModule with Bart->MBart +class FlaxMBartForQuestionAnsweringModule(nn.Module): + config: MBartConfig + dtype: jnp.dtype = jnp.float32 + num_labels = 2 + + def setup(self): + self.model = FlaxMBartModule(config=self.config, dtype=self.dtype) + self.qa_outputs = nn.Dense( + self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype) + ) + + def _get_encoder_module(self): + return self.model.encoder + + def _get_decoder_module(self): + return self.model.decoder + + def __call__( + self, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + position_ids, + decoder_position_ids, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + position_ids=position_ids, + decoder_position_ids=decoder_position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = jnp.split(logits, logits.shape[-1], axis=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + if not return_dict: + output = (start_logits, end_logits) + outputs[1:] + return output + + return FlaxSeq2SeqQuestionAnsweringModelOutput( + start_logits=start_logits, + end_logits=end_logits, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +@add_start_docstrings( + """ + MBart Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layer on top of the hidden-states output to compute `span start logits` and `span end logits`). 
+ """, + MBART_START_DOCSTRING, +) +class FlaxMBartForQuestionAnswering(FlaxMBartPreTrainedModel): + module_class = FlaxMBartForQuestionAnsweringModule + dtype = jnp.float32 + + +append_call_sample_docstring( + FlaxMBartForQuestionAnswering, + _TOKENIZER_FOR_DOC, + _CHECKPOINT_FOR_DOC, + FlaxSeq2SeqQuestionAnsweringModelOutput, + _CONFIG_FOR_DOC, +) diff --git a/src/transformers/models/mbart/tokenization_mbart_fast.py b/src/transformers/models/mbart/tokenization_mbart_fast.py index 33cbd678e8f..94f5eda640e 100644 --- a/src/transformers/models/mbart/tokenization_mbart_fast.py +++ b/src/transformers/models/mbart/tokenization_mbart_fast.py @@ -139,6 +139,9 @@ def __init__( ) self.add_special_tokens({"additional_special_tokens": _additional_special_tokens}) + self.lang_code_to_id = { + lang_code: self.convert_tokens_to_ids(lang_code) for lang_code in FAIRSEQ_LANGUAGE_CODES + } self._src_lang = src_lang if src_lang is not None else "en_XX" self.cur_lang_code = self.convert_tokens_to_ids(self._src_lang) diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index f935d32c310..3641a42c933 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -570,6 +570,51 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) +class FlaxMBartForConditionalGeneration: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxMBartForQuestionAnswering: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxMBartForSequenceClassification: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxMBartModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxMBartPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + class FlaxRobertaForMaskedLM: def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) diff --git a/tests/test_modeling_flax_mbart.py b/tests/test_modeling_flax_mbart.py new file mode 100644 index 00000000000..007df91ab1e --- /dev/null +++ b/tests/test_modeling_flax_mbart.py @@ -0,0 +1,464 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
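+
+# These tests follow the layout of the other Flax model test files: shape checks for each task head,
+# MBart-specific shift_tokens_right behaviour, cached (incremental) decoding, jitted encode/decode,
+# and a slow English->Romanian generation test against facebook/mbart-large-en-ro.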
+ +import unittest + +import numpy as np +import timeout_decorator # noqa + +from transformers import MBartConfig, is_flax_available +from transformers.file_utils import cached_property +from transformers.testing_utils import require_flax, require_sentencepiece, require_tokenizers, slow + +from .test_generation_flax_utils import FlaxGenerationTesterMixin +from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor + + +if is_flax_available(): + import os + + # The slow tests are often failing with OOM error on GPU + # This makes JAX allocate exactly what is needed on demand, and deallocate memory that is no longer needed + # but will be slower as stated here https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html + os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" + + import jax + import jax.numpy as jnp + from transformers import AutoTokenizer + from transformers.models.mbart.modeling_flax_mbart import ( + FlaxMBartForConditionalGeneration, + FlaxMBartForQuestionAnswering, + FlaxMBartForSequenceClassification, + FlaxMBartModel, + shift_tokens_right, + ) + + +def prepare_mbart_inputs_dict( + config, + input_ids, + decoder_input_ids=None, + attention_mask=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, +): + if attention_mask is None: + attention_mask = np.where(input_ids != config.pad_token_id, 1, 0) + if decoder_attention_mask is None: + decoder_attention_mask = np.where(decoder_input_ids != config.pad_token_id, 1, 0) + if head_mask is None: + head_mask = np.ones((config.encoder_layers, config.encoder_attention_heads)) + if decoder_head_mask is None: + decoder_head_mask = np.ones((config.decoder_layers, config.decoder_attention_heads)) + if cross_attn_head_mask is None: + cross_attn_head_mask = np.ones((config.decoder_layers, config.decoder_attention_heads)) + return { + "input_ids": input_ids, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + } + + +class FlaxMBartModelTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_labels=False, + vocab_size=99, + hidden_size=16, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=4, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=32, + eos_token_id=2, + pad_token_id=1, + bos_token_id=0, + decoder_start_token_id=2, + initializer_range=0.02, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.decoder_start_token_id = decoder_start_token_id + self.initializer_range = initializer_range + + def prepare_config_and_inputs(self): + input_ids = np.clip(ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size), 3, self.vocab_size) + input_ids = np.concatenate((input_ids, 2 * np.ones((self.batch_size, 1), dtype=np.int64)), -1) + + 
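+        # Every row above is forced to end with the eos token (id 2); MBart's shift_tokens_right wraps the
+        # last non-pad token around to position 0, so the decoder inputs created next start with that token.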
decoder_input_ids = shift_tokens_right(input_ids, 1) + + config = MBartConfig( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + encoder_layers=self.num_hidden_layers, + decoder_layers=self.num_hidden_layers, + encoder_attention_heads=self.num_attention_heads, + decoder_attention_heads=self.num_attention_heads, + encoder_ffn_dim=self.intermediate_size, + decoder_ffn_dim=self.intermediate_size, + dropout=self.hidden_dropout_prob, + attention_dropout=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + eos_token_id=self.eos_token_id, + bos_token_id=self.bos_token_id, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.decoder_start_token_id, + initializer_range=self.initializer_range, + use_cache=False, + ) + inputs_dict = prepare_mbart_inputs_dict(config, input_ids, decoder_input_ids) + return config, inputs_dict + + def prepare_config_and_inputs_for_common(self): + config, inputs_dict = self.prepare_config_and_inputs() + return config, inputs_dict + + def check_use_cache_forward(self, model_class_name, config, inputs_dict): + max_decoder_length = 20 + model = model_class_name(config) + + encoder_outputs = model.encode(inputs_dict["input_ids"]) + + decoder_input_ids, decoder_attention_mask = ( + inputs_dict["decoder_input_ids"], + inputs_dict["decoder_attention_mask"], + ) + + past_key_values = model.init_cache(decoder_input_ids.shape[0], max_decoder_length, encoder_outputs) + decoder_attention_mask = jnp.ones((decoder_input_ids.shape[0], max_decoder_length), dtype="i4") + + decoder_position_ids = jnp.broadcast_to( + jnp.arange(decoder_input_ids.shape[-1] - 1)[None, :], + (decoder_input_ids.shape[0], decoder_input_ids.shape[-1] - 1), + ) + outputs_cache = model.decode( + decoder_input_ids[:, :-1], + encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + past_key_values=past_key_values, + decoder_position_ids=decoder_position_ids, + ) + + decoder_position_ids = jnp.array(decoder_input_ids.shape[0] * [[decoder_input_ids.shape[-1] - 1]], dtype="i4") + outputs_cache_next = model.decode( + decoder_input_ids[:, -1:], + encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + past_key_values=outputs_cache.past_key_values, + decoder_position_ids=decoder_position_ids, + ) + + outputs = model.decode(decoder_input_ids, encoder_outputs) + + diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5]))) + self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}") + + def check_use_cache_forward_with_attn_mask(self, model_class_name, config, inputs_dict): + max_decoder_length = 20 + model = model_class_name(config) + + encoder_outputs = model.encode(inputs_dict["input_ids"]) + + decoder_input_ids, decoder_attention_mask = ( + inputs_dict["decoder_input_ids"], + inputs_dict["decoder_attention_mask"], + ) + + decoder_attention_mask_cache = jnp.concatenate( + [ + decoder_attention_mask, + jnp.zeros((decoder_attention_mask.shape[0], max_decoder_length - decoder_attention_mask.shape[1])), + ], + axis=-1, + ) + + past_key_values = model.init_cache(decoder_input_ids.shape[0], max_decoder_length, encoder_outputs) + decoder_position_ids = jnp.broadcast_to( + jnp.arange(decoder_input_ids.shape[-1] - 1)[None, :], + (decoder_input_ids.shape[0], decoder_input_ids.shape[-1] - 1), + ) + + outputs_cache = model.decode( + decoder_input_ids[:, :-1], + encoder_outputs, + decoder_attention_mask=decoder_attention_mask_cache, + past_key_values=past_key_values, + decoder_position_ids=decoder_position_ids, + ) + 
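+        # Re-decode only the final token, reusing the cache filled by the call above; its position id is the
+        # index of the last original decoder token.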
decoder_position_ids = jnp.array(decoder_input_ids.shape[0] * [[decoder_input_ids.shape[-1] - 1]], dtype="i4") + outputs_cache_next = model.decode( + decoder_input_ids[:, -1:], + encoder_outputs, + past_key_values=outputs_cache.past_key_values, + decoder_attention_mask=decoder_attention_mask_cache, + decoder_position_ids=decoder_position_ids, + ) + + outputs = model.decode(decoder_input_ids, encoder_outputs, decoder_attention_mask=decoder_attention_mask) + + diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5]))) + self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}") + + +@require_flax +class MBartHeadTests(unittest.TestCase): + vocab_size = 99 + + def _get_config_and_data(self): + input_ids = np.array( + [ + [71, 82, 18, 33, 46, 91, 2], + [68, 34, 26, 58, 30, 82, 2], + [5, 97, 17, 39, 94, 40, 2], + [76, 83, 94, 25, 70, 78, 2], + [87, 59, 41, 35, 48, 66, 2], + [55, 13, 16, 58, 5, 2, 1], # note padding + [64, 27, 31, 51, 12, 75, 2], + [52, 64, 86, 17, 83, 39, 2], + [48, 61, 9, 24, 71, 82, 2], + [26, 1, 60, 48, 22, 13, 2], + [21, 5, 62, 28, 14, 76, 2], + [45, 98, 37, 86, 59, 48, 2], + [70, 70, 50, 9, 28, 0, 2], + ], + dtype=np.int64, + ) + + batch_size = input_ids.shape[0] + config = MBartConfig( + vocab_size=self.vocab_size, + d_model=24, + encoder_layers=2, + decoder_layers=2, + encoder_attention_heads=2, + decoder_attention_heads=2, + encoder_ffn_dim=32, + decoder_ffn_dim=32, + max_position_embeddings=48, + eos_token_id=2, + pad_token_id=1, + bos_token_id=0, + ) + return config, input_ids, batch_size + + def test_sequence_classification_forward(self): + config, input_ids, batch_size = self._get_config_and_data() + model = FlaxMBartForSequenceClassification(config) + outputs = model(input_ids=input_ids, decoder_input_ids=input_ids) + expected_shape = (batch_size, config.num_labels) + self.assertEqual(outputs["logits"].shape, expected_shape) + + def test_question_answering_forward(self): + config, input_ids, batch_size = self._get_config_and_data() + model = FlaxMBartForQuestionAnswering(config) + outputs = model(input_ids=input_ids) + + self.assertEqual(outputs["start_logits"].shape, input_ids.shape) + self.assertEqual(outputs["end_logits"].shape, input_ids.shape) + + # @timeout_decorator.timeout(1) # not working with the decorator so far + def test_lm_forward(self): + config, input_ids, batch_size = self._get_config_and_data() + lm_model = FlaxMBartForConditionalGeneration(config) + outputs = lm_model(input_ids=input_ids) + expected_shape = (batch_size, input_ids.shape[1], config.vocab_size) + self.assertEqual(outputs["logits"].shape, expected_shape) + + def test_lm_uneven_forward(self): + config = MBartConfig( + vocab_size=self.vocab_size, + d_model=14, + encoder_layers=2, + decoder_layers=2, + encoder_attention_heads=2, + decoder_attention_heads=2, + encoder_ffn_dim=8, + decoder_ffn_dim=8, + max_position_embeddings=48, + ) + lm_model = FlaxMBartForConditionalGeneration(config) + context = np.array([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]], dtype=np.int64) + summary = np.array([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]], dtype=np.int64) + outputs = lm_model(input_ids=context, decoder_input_ids=summary) + expected_shape = (*summary.shape, config.vocab_size) + self.assertEqual(outputs["logits"].shape, expected_shape) + + def test_shift_tokens_right(self): + input_ids = np.array([[71, 82, 18, 33, 2, 1, 1], [68, 34, 26, 58, 30, 82, 2]], dtype=np.int64) + shifted = shift_tokens_right(input_ids, 1) + n_pad_before = np.equal(input_ids, 
1).astype(np.float32).sum() + n_pad_after = np.equal(shifted, 1).astype(np.float32).sum() + self.assertEqual(shifted.shape, input_ids.shape) + self.assertEqual(n_pad_after, n_pad_before - 1) + self.assertTrue(np.equal(shifted[:, 0], 2).all()) + + +@require_flax +class FlaxMBartModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationTesterMixin): + is_encoder_decoder = True + all_model_classes = ( + ( + FlaxMBartModel, + FlaxMBartForConditionalGeneration, + FlaxMBartForSequenceClassification, + FlaxMBartForQuestionAnswering, + ) + if is_flax_available() + else () + ) + all_generative_model_classes = (FlaxMBartForConditionalGeneration,) if is_flax_available() else () + + def setUp(self): + self.model_tester = FlaxMBartModelTester(self) + + def test_use_cache_forward(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs() + for model_class in self.all_model_classes: + self.model_tester.check_use_cache_forward(model_class, config, inputs_dict) + + def test_use_cache_forward_with_attn_mask(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs() + for model_class in self.all_model_classes: + self.model_tester.check_use_cache_forward_with_attn_mask(model_class, config, inputs_dict) + + def test_encode(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + model = model_class(config) + + @jax.jit + def encode_jitted(input_ids, attention_mask=None, **kwargs): + return model.encode(input_ids=input_ids, attention_mask=attention_mask) + + with self.subTest("JIT Enabled"): + jitted_outputs = encode_jitted(**prepared_inputs_dict).to_tuple() + + with self.subTest("JIT Disabled"): + with jax.disable_jit(): + outputs = encode_jitted(**prepared_inputs_dict).to_tuple() + + self.assertEqual(len(outputs), len(jitted_outputs)) + for jitted_output, output in zip(jitted_outputs, outputs): + self.assertEqual(jitted_output.shape, output.shape) + + def test_decode(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + model = model_class(config) + encoder_outputs = model.encode(inputs_dict["input_ids"], inputs_dict["attention_mask"]) + + prepared_inputs_dict = { + "decoder_input_ids": inputs_dict["decoder_input_ids"], + "decoder_attention_mask": inputs_dict["decoder_attention_mask"], + "encoder_outputs": encoder_outputs, + } + + @jax.jit + def decode_jitted(decoder_input_ids, decoder_attention_mask, encoder_outputs): + return model.decode( + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + ) + + with self.subTest("JIT Enabled"): + jitted_outputs = decode_jitted(**prepared_inputs_dict).to_tuple() + + with self.subTest("JIT Disabled"): + with jax.disable_jit(): + outputs = decode_jitted(**prepared_inputs_dict).to_tuple() + + self.assertEqual(len(outputs), len(jitted_outputs)) + for jitted_output, output in zip(jitted_outputs, outputs): + self.assertEqual(jitted_output.shape, output.shape) + + @slow + def test_model_from_pretrained(self): + for model_class_name in self.all_model_classes: + model = model_class_name.from_pretrained("facebook/mbart-large-cc25", from_pt=True) + # FlaxMBartForSequenceClassification expects eos token in input_ids + input_ids = np.ones((1, 1)) * 
model.config.eos_token_id + outputs = model(input_ids) + self.assertIsNotNone(outputs) + + +@require_flax +@require_sentencepiece +@require_tokenizers +class FlaxMBartModelIntegrationTest(unittest.TestCase): + src_text = [ + " UN Chief Says There Is No Military Solution in Syria", + ] + expected_text = [ + "Şeful ONU declară că nu există o soluţie militară în Siria", + ] + model_name = "facebook/mbart-large-en-ro" + + @cached_property + def tokenizer(self): + return AutoTokenizer.from_pretrained(self.model_name) + + @cached_property + def model(self): + model = FlaxMBartForConditionalGeneration.from_pretrained(self.model_name, from_pt=True) + return model + + def _assert_generated_batch_equal_expected(self, **tokenizer_kwargs): + generated_words = self.translate_src_text(**tokenizer_kwargs) + self.assertListEqual(self.expected_text, generated_words) + + def translate_src_text(self, **tokenizer_kwargs): + model_inputs = self.tokenizer(self.src_text, **tokenizer_kwargs, return_tensors="np") + generated_ids = self.model.generate( + model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + decoder_start_token_id=self.tokenizer.lang_code_to_id["ro_RO"], + early_stopping=True, + num_beams=2, + ).sequences + generated_words = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + return generated_words + + @slow + def test_batch_generation_en_ro(self): + self._assert_generated_batch_equal_expected()
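
Taken together, the classes added above are exercised end to end roughly as in the following sketch. This is a
minimal, hypothetical usage example distilled from the integration test; the `facebook/mbart-large-en-ro`
checkpoint, the `from_pt=True` conversion and the `lang_code_to_id` lookup are taken from that test rather than
being new API::

    >>> from transformers import AutoTokenizer, FlaxMBartForConditionalGeneration

    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-en-ro")
    >>> model = FlaxMBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro", from_pt=True)

    >>> inputs = tokenizer(["UN Chief Says There Is No Military Solution in Syria"], return_tensors="np")
    >>> # the target-language code selects the decoder start token, as in translate_src_text above
    >>> generated_ids = model.generate(
    ...     inputs.input_ids,
    ...     attention_mask=inputs.attention_mask,
    ...     decoder_start_token_id=tokenizer.lang_code_to_id["ro_RO"],
    ...     num_beams=2,
    ...     early_stopping=True,
    ... ).sequences
    >>> tokenizer.batch_decode(generated_ids, skip_special_tokens=True)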