Add T5 Encoder for Feature Extraction (huggingface#8717)
* Add T5 Encoder class for feature extraction

* fix T5 encoder add_start_docstrings indent

* update init with T5 encoder

* update init with TFT5ModelEncoder

* remove TFT5ModelEncoder

* change T5ModelEncoder order in init

* add T5ModelEncoder to transformers init

* clean T5ModelEncoder

* update init with TFT5ModelEncoder

* add TFModelEncoder for Tensorflow

* update init with TFT5ModelEncoder

* Update src/transformers/models/t5/modeling_t5.py

change output from Seq2SeqModelOutput to BaseModelOutput

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* remove encoder_outputs

1. remove encoder_outputs from the function call.
2. remove the encoder_outputs if statement.
3. remove the isinstance check on return_dict.

* Authorize missing decoder keys

* remove unnecessary input parameters

remove past_key_values and use_cache

* remove use_cache

remove use_cache from the forward method

* add docstring for T5 encoder

add docstring for T5 encoder with T5_ENCODER_INPUTS_DOCSTRING

* change return_dict to dot access

* add T5_ENCODER_INPUTS_DOCSTRING for TF T5

* change TFT5Encoder output type to BaseModelOutput

* remove unnecessary parameters for TFT5Encoder

* remove unnecessary if statement

* add import BaseModelOutput

* fix BaseModelOutput typo to TFBaseModelOutput

* update T5 doc with T5ModelEncoder

* add T5ModelEncoder to tests

* finish pytorch

* finish docs and mt5

* add mt5 to init

* fix init

* remove n_positions

* finish PR

* Update src/transformers/models/mt5/modeling_mt5.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/models/t5/modeling_t5.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/models/t5/modeling_tf_t5.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update src/transformers/models/mt5/modeling_tf_mt5.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* make style

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
3 people authored and stas00 committed Dec 2, 2020
1 parent bd63610 commit 688b4ec
Showing 15 changed files with 656 additions and 18 deletions.
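In short, this commit adds encoder-only classes (T5EncoderModel and MT5EncoderModel for PyTorch, TFT5EncoderModel and TFMT5EncoderModel for TensorFlow) that run just the encoder stack and return a BaseModelOutput, so sentence features can be extracted without building the decoder. A minimal usage sketch, adapted from the docstring examples added below (checkpoint names as used there):

from transformers import T5Tokenizer, T5EncoderModel

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5EncoderModel.from_pretrained("t5-small")

# Run the encoder only; no decoder_input_ids are required.
inputs = tokenizer("Studies have shown that owning a dog is good for you", return_tensors="pt")
outputs = model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)
features = outputs.last_hidden_state  # (batch_size, sequence_length, d_model)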
14 changes: 14 additions & 0 deletions docs/source/model_doc/mt5.rst
@@ -39,6 +39,13 @@ MT5ForConditionalGeneration
:members:


MT5EncoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MT5EncoderModel
:members:


TFMT5Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -51,3 +58,10 @@ TFMT5ForConditionalGeneration

.. autoclass:: transformers.TFMT5ForConditionalGeneration
:members:


TFMT5EncoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFMT5EncoderModel
:members:
11 changes: 11 additions & 0 deletions docs/source/model_doc/t5.rst
@@ -108,6 +108,11 @@ T5ForConditionalGeneration
.. autoclass:: transformers.T5ForConditionalGeneration
:members: forward, parallelize, deparallelize

T5EncoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.T5EncoderModel
:members: forward

TFT5Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -121,3 +126,9 @@ TFT5ForConditionalGeneration

.. autoclass:: transformers.TFT5ForConditionalGeneration
:members: call

TFT5EncoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFT5EncoderModel
:members: call
6 changes: 4 additions & 2 deletions src/transformers/__init__.py
@@ -506,7 +506,7 @@
MobileBertPreTrainedModel,
load_tf_weights_in_mobilebert,
)
from .models.mt5 import MT5ForConditionalGeneration, MT5Model
from .models.mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model
from .models.openai import (
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
OpenAIGPTDoubleHeadsModel,
@@ -561,6 +561,7 @@
)
from .models.t5 import (
T5_PRETRAINED_MODEL_ARCHIVE_LIST,
T5EncoderModel,
T5ForConditionalGeneration,
T5Model,
T5PreTrainedModel,
@@ -803,7 +804,7 @@
TFMobileBertModel,
TFMobileBertPreTrainedModel,
)
from .models.mt5 import TFMT5ForConditionalGeneration, TFMT5Model
from .models.mt5 import TFMT5EncoderModel, TFMT5ForConditionalGeneration, TFMT5Model
from .models.openai import (
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFOpenAIGPTDoubleHeadsModel,
@@ -826,6 +827,7 @@
)
from .models.t5 import (
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST,
TFT5EncoderModel,
TFT5ForConditionalGeneration,
TFT5Model,
TFT5PreTrainedModel,
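With the init changes above, all four encoder classes become importable from the top-level package (a quick check, assuming both the PyTorch and TensorFlow backends are installed):

from transformers import (
    MT5EncoderModel,
    T5EncoderModel,
    TFMT5EncoderModel,
    TFT5EncoderModel,
)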
4 changes: 2 additions & 2 deletions src/transformers/models/mt5/__init__.py
@@ -7,7 +7,7 @@


if is_torch_available():
from .modeling_mt5 import MT5ForConditionalGeneration, MT5Model
from .modeling_mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model

if is_tf_available():
from .modeling_tf_mt5 import TFMT5ForConditionalGeneration, TFMT5Model
from .modeling_tf_mt5 import TFMT5EncoderModel, TFMT5ForConditionalGeneration, TFMT5Model
32 changes: 27 additions & 5 deletions src/transformers/models/mt5/modeling_mt5.py
@@ -15,7 +15,7 @@
""" PyTorch mT5 model. """

from ...utils import logging
from ..t5.modeling_t5 import T5ForConditionalGeneration, T5Model
from ..t5.modeling_t5 import T5EncoderModel, T5ForConditionalGeneration, T5Model
from .configuration_mt5 import MT5Config


@@ -73,11 +73,33 @@ class MT5ForConditionalGeneration(T5ForConditionalGeneration):
config_class = MT5Config
_keys_to_ignore_on_load_missing = [
r"encoder\.embed_tokens\.weight",
r"decoder\.embed_tokens\.weight",
r"lm_head\.weight",
r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
]
_keys_to_ignore_on_save = [
r"encoder\.embed_tokens\.weight",
r"decoder\.embed_tokens\.weight",
]


class MT5EncoderModel(T5EncoderModel):
r"""
This class overrides :class:`~transformers.T5EncoderModel`. Please check the superclass for the appropriate
documentation alongside usage examples.
Examples::
>>> from transformers import MT5EncoderModel, T5Tokenizer
>>> model = MT5EncoderModel.from_pretrained("google/mt5-small")
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
>>> input_ids = tokenizer(article, return_tensors="pt").input_ids
>>> outputs = model(input_ids)
>>> hidden_state = outputs.last_hidden_state
"""

model_type = "mt5"
config_class = MT5Config
_keys_to_ignore_on_load_missing = [
r"encoder\.embed_tokens\.weight",
]
_keys_to_ignore_on_save = [
r"encoder\.embed_tokens\.weight",
]
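A note on the ignore lists above: the encoder's embed_tokens is tied to the model-level shared embedding, so encoder\.embed_tokens\.weight is skipped on save and tolerated when missing at load time; it is rebuilt from shared. A sketch of the practical effect (checkpoint name as in the docstring; warnings about unused decoder weights are expected):

from transformers import MT5EncoderModel

# Loading the encoder-only class from a full seq2seq checkpoint simply
# leaves the decoder and lm_head weights unused rather than raising.
model = MT5EncoderModel.from_pretrained("google/mt5-small")
print(model.shared is model.encoder.embed_tokens)  # True: one shared embedding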
22 changes: 21 additions & 1 deletion src/transformers/models/mt5/modeling_tf_mt5.py
@@ -15,7 +15,7 @@
""" Tensorflow mT5 model. """

from ...utils import logging
from ..t5.modeling_tf_t5 import TFT5ForConditionalGeneration, TFT5Model
from ..t5.modeling_tf_t5 import TFT5EncoderModel, TFT5ForConditionalGeneration, TFT5Model
from .configuration_mt5 import MT5Config


@@ -64,3 +64,23 @@ class TFMT5ForConditionalGeneration(TFT5ForConditionalGeneration):

model_type = "mt5"
config_class = MT5Config


class TFMT5EncoderModel(TFT5EncoderModel):
r"""
This class overrides :class:`~transformers.TFT5EncoderModel`. Please check the superclass for the appropriate
documentation alongside usage examples.
Examples::
>>> from transformers import TFMT5EncoderModel, T5Tokenizer
>>> model = TFMT5EncoderModel.from_pretrained("google/mt5-small")
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
>>> input_ids = tokenizer(article, return_tensors="tf").input_ids
>>> outputs = model(input_ids)
>>> hidden_state = outputs.last_hidden_state
"""

model_type = "mt5"
config_class = MT5Config
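The TensorFlow classes mirror the PyTorch ones; per the commit messages above, the encoder output type is TFBaseModelOutput rather than a seq2seq output. A small sanity-check sketch (assuming return_dict resolves to True from the config, the library's 4.x default):

from transformers import T5Tokenizer, TFMT5EncoderModel
from transformers.modeling_tf_outputs import TFBaseModelOutput

tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
model = TFMT5EncoderModel.from_pretrained("google/mt5-small")

inputs = tokenizer("UN Offizier sagt, dass weiter verhandelt werden muss in Syrien.", return_tensors="tf")
outputs = model(inputs.input_ids, attention_mask=inputs.attention_mask)
assert isinstance(outputs, TFBaseModelOutput)  # encoder-only: no past key values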
2 changes: 2 additions & 0 deletions src/transformers/models/t5/__init__.py
@@ -15,6 +15,7 @@
if is_torch_available():
from .modeling_t5 import (
T5_PRETRAINED_MODEL_ARCHIVE_LIST,
T5EncoderModel,
T5ForConditionalGeneration,
T5Model,
T5PreTrainedModel,
@@ -24,6 +25,7 @@
if is_tf_available():
from .modeling_tf_t5 import (
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST,
TFT5EncoderModel,
TFT5ForConditionalGeneration,
TFT5Model,
TFT5PreTrainedModel,
118 changes: 117 additions & 1 deletion src/transformers/models/t5/modeling_t5.py
@@ -700,7 +700,7 @@ def _init_weights(self, module):
factor = self.config.initializer_factor # Used for testing weights initialization
if isinstance(module, T5LayerNorm):
module.weight.data.fill_(factor * 1.0)
elif isinstance(module, (T5Model, T5ForConditionalGeneration)):
elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)):
# Mesh TensorFlow embeddings initialization
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
@@ -1082,6 +1082,45 @@ def forward(
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""

T5_ENCODER_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
should be able to pad the inputs on both the right and the left.
Indices can be obtained using :class:`~transformers.T5Tokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
detail.
To know more on how to prepare :obj:`input_ids` for pretraining take a look a `T5 Training
<./t5.html#training>`__.
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
vectors than the model's internal embedding lookup matrix.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""


@add_start_docstrings(
"The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.",
@@ -1518,3 +1557,80 @@ def _reorder_cache(self, past, beam_idx):

reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
return reordered_decoder_past


@add_start_docstrings(
"The bare T5 Model transformer outputting encoder's raw hidden-states" "without any specific head on top.",
T5_START_DOCSTRING,
)
class T5EncoderModel(T5PreTrainedModel):
authorized_missing_keys = [
r"encoder\.embed_tokens\.weight",
]

def __init__(self, config: T5Config):
super().__init__(config)
self.shared = nn.Embedding(config.vocab_size, config.d_model)

encoder_config = copy.deepcopy(config)
encoder_config.use_cache = False
encoder_config.is_encoder_decoder = False
self.encoder = T5Stack(encoder_config, self.shared)

self.init_weights()

def get_input_embeddings(self):
return self.shared

def set_input_embeddings(self, new_embeddings):
self.shared = new_embeddings
self.encoder.set_input_embeddings(new_embeddings)

def get_encoder(self):
return self.encoder

def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See the
base class PreTrainedModel.
"""
for layer, heads in heads_to_prune.items():
self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)

@add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
Returns:
Example::
>>> from transformers import T5Tokenizer, T5EncoderModel
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
>>> model = T5EncoderModel.from_pretrained('t5-small')
>>> input_ids = tokenizer("Studies have shown that owning a dog is good for you", return_tensors="pt").input_ids  # Batch size 1
>>> outputs = model(input_ids=input_ids)
>>> last_hidden_states = outputs.last_hidden_state
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

return encoder_outputs
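Because the class targets feature extraction, a common downstream recipe (not prescribed by this commit) is to mask out padding and mean-pool the encoder states into one vector per sentence. A hedged sketch:

import torch
from transformers import T5Tokenizer, T5EncoderModel

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5EncoderModel.from_pretrained("t5-small")

sentences = ["Studies have shown that owning a dog is good for you", "Dogs are great companions"]
# T5 uses relative position embeddings, so padding on either side is fine.
batch = tokenizer(sentences, padding=True, return_tensors="pt")

with torch.no_grad():
    out = model(input_ids=batch.input_ids, attention_mask=batch.attention_mask)

# Zero out padded positions, then average over the sequence dimension.
mask = batch.attention_mask.unsqueeze(-1).float()  # (batch, seq_len, 1)
features = (out.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)  # (batch, d_model)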
