Add T5 Encoder for Feature Extraction #8717

Status: merged Nov 30, 2020, with 38 commits. The diff below shows the changes from 33 commits.
Commits
35f502b - Add T5 Encoder class for feature extraction (agemagician, Nov 22, 2020)
b211acd - fix T5 encoder add_start_docstrings indent (agemagician, Nov 22, 2020)
b21313d - update init with T5 encoder (agemagician, Nov 22, 2020)
6b744d0 - update init with TFT5ModelEncoder (agemagician, Nov 22, 2020)
b1c0127 - remove TFT5ModelEncoder (agemagician, Nov 22, 2020)
67d25b7 - change T5ModelEncoder order in init (agemagician, Nov 22, 2020)
3519db6 - add T5ModelEncoder to transformers init (agemagician, Nov 22, 2020)
5d4db69 - clean T5ModelEncoder (agemagician, Nov 22, 2020)
ceb83f2 - update init with TFT5ModelEncoder (agemagician, Nov 22, 2020)
adcfb55 - add TFModelEncoder for Tensorflow (agemagician, Nov 22, 2020)
725b225 - update init with TFT5ModelEncoder (agemagician, Nov 22, 2020)
b846556 - Update src/transformers/models/t5/modeling_t5.py (agemagician, Nov 24, 2020)
493a91c - remove encoder_outputs (agemagician, Nov 24, 2020)
d3ef598 - Authorize missing decoder keys (agemagician, Nov 24, 2020)
9922f51 - remove unnecessary input parameters (agemagician, Nov 24, 2020)
aefed96 - remove use_cache (agemagician, Nov 24, 2020)
81754a5 - add docstring for T5 encoder (agemagician, Nov 24, 2020)
d527699 - change return_dict to dot access (agemagician, Nov 24, 2020)
0f7fc54 - add T5_ENCODER_INPUTS_DOCSTRING for TF T5 (agemagician, Nov 24, 2020)
134dd64 - change TFT5Encoder output type to BaseModelOutput (agemagician, Nov 24, 2020)
bf26228 - remove unnecessary parameters for TFT5Encoder (agemagician, Nov 24, 2020)
6f66a35 - remove unnecessary if statement (agemagician, Nov 24, 2020)
0629549 - add import BaseModelOutput (agemagician, Nov 24, 2020)
2d53021 - fix BaseModelOutput typo to TFBaseModelOutput (agemagician, Nov 24, 2020)
36efefb - update T5 doc with T5ModelEncoder (agemagician, Nov 24, 2020)
831fef6 - add T5ModelEncoder to tests (agemagician, Nov 24, 2020)
e356b6a - finish pytorch (patrickvonplaten, Nov 27, 2020)
f74bb85 - merge (patrickvonplaten, Nov 27, 2020)
d15ade6 - finish docs and mt5 (patrickvonplaten, Nov 27, 2020)
4299356 - add mtf to init (patrickvonplaten, Nov 27, 2020)
1b257ac - fix init (patrickvonplaten, Nov 27, 2020)
fbf8a5a - remove n_positions (patrickvonplaten, Nov 27, 2020)
e43ae9f - finish PR (patrickvonplaten, Nov 27, 2020)
815a76a - Update src/transformers/models/mt5/modeling_mt5.py (patrickvonplaten, Nov 29, 2020)
606b93e - Update src/transformers/models/t5/modeling_t5.py (patrickvonplaten, Nov 29, 2020)
36e0d23 - Update src/transformers/models/t5/modeling_tf_t5.py (patrickvonplaten, Nov 29, 2020)
486283b - Update src/transformers/models/mt5/modeling_tf_mt5.py (patrickvonplaten, Nov 29, 2020)
a6a84d9 - make style (patrickvonplaten, Nov 30, 2020)
14 changes: 14 additions & 0 deletions docs/source/model_doc/mt5.rst
@@ -39,6 +39,13 @@ MT5ForConditionalGeneration
:members:


MT5EncoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MT5EncoderModel
:members:


TFMT5Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -51,3 +58,10 @@ TFMT5ForConditionalGeneration

.. autoclass:: transformers.TFMT5ForConditionalGeneration
:members:


TFMT5EncoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFMT5EncoderModel
:members:
11 changes: 11 additions & 0 deletions docs/source/model_doc/t5.rst
@@ -108,6 +108,11 @@ T5ForConditionalGeneration
.. autoclass:: transformers.T5ForConditionalGeneration
:members: forward, parallelize, deparallelize

T5EncoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.T5EncoderModel
:members: forward

TFT5Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -121,3 +126,9 @@ TFT5ForConditionalGeneration

.. autoclass:: transformers.TFT5ForConditionalGeneration
:members: call

TFT5EncoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFT5EncoderModel
:members: call
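Note: unlike the existing T5Model, the encoder-only classes documented above run without any decoder inputs. Below is a minimal sketch of the difference (our example, not part of the diff; it assumes the t5-small checkpoint and PyTorch).

```python
# Minimal sketch (assumption): T5Model needs decoder inputs even when only the
# encoder features are wanted, while the new T5EncoderModel does not.
from transformers import T5Tokenizer, T5Model, T5EncoderModel

tokenizer = T5Tokenizer.from_pretrained("t5-small")
input_ids = tokenizer("Hello, world!", return_tensors="pt").input_ids

seq2seq = T5Model.from_pretrained("t5-small")
_ = seq2seq(input_ids=input_ids, decoder_input_ids=input_ids)  # decoder input required

encoder = T5EncoderModel.from_pretrained("t5-small")
outputs = encoder(input_ids=input_ids)  # single encoder-only forward pass
print(outputs.last_hidden_state.shape)  # (1, sequence_length, 512) for t5-small
```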
6 changes: 4 additions & 2 deletions src/transformers/__init__.py
@@ -504,7 +504,7 @@
MobileBertPreTrainedModel,
load_tf_weights_in_mobilebert,
)
from .models.mt5 import MT5ForConditionalGeneration, MT5Model
from .models.mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model
from .models.openai import (
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
OpenAIGPTDoubleHeadsModel,
@@ -559,6 +559,7 @@
)
from .models.t5 import (
T5_PRETRAINED_MODEL_ARCHIVE_LIST,
T5EncoderModel,
T5ForConditionalGeneration,
T5Model,
T5PreTrainedModel,
@@ -801,7 +802,7 @@
TFMobileBertModel,
TFMobileBertPreTrainedModel,
)
from .models.mt5 import TFMT5ForConditionalGeneration, TFMT5Model
from .models.mt5 import TFMT5EncoderModel, TFMT5ForConditionalGeneration, TFMT5Model
from .models.openai import (
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFOpenAIGPTDoubleHeadsModel,
@@ -824,6 +825,7 @@
)
from .models.t5 import (
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST,
TFT5EncoderModel,
TFT5ForConditionalGeneration,
TFT5Model,
TFT5PreTrainedModel,
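The net effect of these __init__ additions is that all four encoder-only classes become importable from the package root. A quick check (assuming both PyTorch and TensorFlow are installed, since the TF classes are only exported when TF is available):

```python
# Sanity check: the new encoder-only classes exposed at the top level of transformers.
from transformers import (
    MT5EncoderModel,
    T5EncoderModel,
    TFMT5EncoderModel,
    TFT5EncoderModel,
)

for cls in (T5EncoderModel, MT5EncoderModel, TFT5EncoderModel, TFMT5EncoderModel):
    print(cls.__name__)
```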
4 changes: 2 additions & 2 deletions src/transformers/models/mt5/__init__.py
@@ -7,7 +7,7 @@


if is_torch_available():
from .modeling_mt5 import MT5ForConditionalGeneration, MT5Model
from .modeling_mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model

if is_tf_available():
from .modeling_tf_mt5 import TFMT5ForConditionalGeneration, TFMT5Model
from .modeling_tf_mt5 import TFMT5EncoderModel, TFMT5ForConditionalGeneration, TFMT5Model
31 changes: 26 additions & 5 deletions src/transformers/models/mt5/modeling_mt5.py
@@ -15,7 +15,7 @@
""" PyTorch mT5 model. """

from ...utils import logging
from ..t5.modeling_t5 import T5ForConditionalGeneration, T5Model
from ..t5.modeling_t5 import T5EncoderModel, T5ForConditionalGeneration, T5Model
from .configuration_mt5 import MT5Config


@@ -73,11 +73,32 @@ class MT5ForConditionalGeneration(T5ForConditionalGeneration):
config_class = MT5Config
_keys_to_ignore_on_load_missing = [
r"encoder\.embed_tokens\.weight",
r"decoder\.embed_tokens\.weight",
r"lm_head\.weight",
r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
]
_keys_to_ignore_on_save = [
r"encoder\.embed_tokens\.weight",
r"decoder\.embed_tokens\.weight",
]


class MT5EncoderModel(T5EncoderModel):
r"""
This class overrides :class:`~transformers.T5EncoderModel`. Please check the superclass for the appropriate
documentation alongside usage examples.

Examples::
>>> from transformers import MT5EncoderModel, T5Tokenizer
>>> model = MT5EncoderModel.from_pretrained("google/mt5-small")
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
>>> input_ids = tokenizer(article, return_tensors="pt").input_ids
>>> outputs = model(input_ids)
>>> hidden_state = outputs.last_hidden_state
"""

model_type = "mt5"
config_class = MT5Config
_keys_to_ignore_on_load_missing = [
r"encoder\.embed_tokens\.weight",
]
_keys_to_ignore_on_save = [
r"encoder\.embed_tokens\.weight",
]
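For context, a rough sketch (ours, not from the diff) of what the encoder-only class does with an existing checkpoint: it loads google/mt5-small like the seq2seq model but only instantiates the shared embeddings plus the encoder stack, so decoder and LM-head weights present in the checkpoint go unused.

```python
# Rough sketch (assumption): compare parameter counts of the full seq2seq model
# and the new encoder-only variant loaded from the same checkpoint.
from transformers import MT5EncoderModel, MT5ForConditionalGeneration

seq2seq = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
encoder_only = MT5EncoderModel.from_pretrained("google/mt5-small")

print(sum(p.numel() for p in seq2seq.parameters()))       # full encoder-decoder
print(sum(p.numel() for p in encoder_only.parameters()))  # encoder + shared embeddings only
```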
21 changes: 20 additions & 1 deletion src/transformers/models/mt5/modeling_tf_mt5.py
@@ -15,7 +15,7 @@
""" Tensorflow mT5 model. """

from ...utils import logging
from ..t5.modeling_tf_t5 import TFT5ForConditionalGeneration, TFT5Model
from ..t5.modeling_tf_t5 import TFT5EncoderModel, TFT5ForConditionalGeneration, TFT5Model
from .configuration_mt5 import MT5Config


@@ -64,3 +64,22 @@ class TFMT5ForConditionalGeneration(TFT5ForConditionalGeneration):

model_type = "mt5"
config_class = MT5Config


class TFMT5EncoderModel(TFT5EncoderModel):
r"""
This class overrides :class:`~transformers.TFT5EncoderModel`. Please check the superclass for the appropriate
documentation alongside usage examples.

Examples::
>>> from transformers import TFMT5EncoderModel, T5Tokenizer
>>> model = TFMT5EncoderModel.from_pretrained("google/mt5-small")
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
>>> input_ids = tokenizer(article, return_tensors="tf").input_ids
>>> outputs = model(input_ids)
>>> hidden_state = outputs.last_hidden_state
"""

model_type = "mt5"
config_class = MT5Config
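A short usage sketch for the TensorFlow variant (ours, not from the diff; the checkpoint and tokenizer follow the docstring above, the batch contents and padding handling are illustrative):

```python
# Minimal sketch (assumption): batched token-level feature extraction with the
# new TF encoder-only class, using the attention mask to handle padding.
from transformers import T5Tokenizer, TFMT5EncoderModel

tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
model = TFMT5EncoderModel.from_pretrained("google/mt5-small")

batch = tokenizer(
    ["UN Offizier sagt, dass weiter verhandelt werden muss in Syrien.",
     "Ein zweiter, kürzerer Satz."],
    padding=True,
    return_tensors="tf",
)

outputs = model(batch.input_ids, attention_mask=batch.attention_mask)

# Token-level features with shape (batch_size, sequence_length, d_model).
print(outputs.last_hidden_state.shape)
```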
2 changes: 2 additions & 0 deletions src/transformers/models/t5/__init__.py
@@ -15,6 +15,7 @@
if is_torch_available():
from .modeling_t5 import (
T5_PRETRAINED_MODEL_ARCHIVE_LIST,
T5EncoderModel,
T5ForConditionalGeneration,
T5Model,
T5PreTrainedModel,
@@ -24,6 +25,7 @@
if is_tf_available():
from .modeling_tf_t5 import (
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST,
TFT5EncoderModel,
TFT5ForConditionalGeneration,
TFT5Model,
TFT5PreTrainedModel,
117 changes: 116 additions & 1 deletion src/transformers/models/t5/modeling_t5.py
@@ -700,7 +700,7 @@ def _init_weights(self, module):
factor = self.config.initializer_factor # Used for testing weights initialization
if isinstance(module, T5LayerNorm):
module.weight.data.fill_(factor * 1.0)
elif isinstance(module, (T5Model, T5ForConditionalGeneration)):
elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)):
# Mesh TensorFlow embeddings initialization
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
@@ -1082,6 +1082,45 @@ def forward(
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""

T5_ENCODER_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
should be able to pad the inputs on both the right and the left.

Indices can be obtained using :class:`~transformers.T5Tokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
detail.

To know more on how to prepare :obj:`input_ids` for pretraining take a look a `T5 Training
<./t5.html#training>`__.
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.

`What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:

- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.

inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
vectors than the model's internal embedding lookup matrix.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""


@add_start_docstrings(
"The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.",
@@ -1518,3 +1557,79 @@ def _reorder_cache(self, past, beam_idx):

reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
return reordered_decoder_past


@add_start_docstrings(
"The bare T5 Model transformer outputting encoder's raw hidden-states" "without any specific head on top.",
T5_START_DOCSTRING,
)
class T5EncoderModel(T5PreTrainedModel):
authorized_missing_keys = [
r"encoder\.embed_tokens\.weight",
]

def __init__(self, config: T5Config):
super().__init__(config)
self.shared = nn.Embedding(config.vocab_size, config.d_model)

encoder_config = copy.deepcopy(config)
encoder_config.use_cache = False
encoder_config.is_encoder_decoder = False
self.encoder = T5Stack(encoder_config, self.shared)

self.init_weights()

def get_input_embeddings(self):
return self.shared

def set_input_embeddings(self, new_embeddings):
self.shared = new_embeddings
self.encoder.set_input_embeddings(new_embeddings)

def get_encoder(self):
return self.encoder

def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)

@add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
Returns:

Example::
>>> from transformers import T5Tokenizer, T5EncoderModel
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
>>> model = T5EncoderModel.from_pretrained('t5-small')
>>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids # Batch size 1
>>> outputs = model(input_ids=input_ids)
>>> last_hidden_states = outputs.last_hidden_state
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

return encoder_outputs
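To round off the PyTorch side, here is a minimal sketch (ours, not part of the PR) of the feature-extraction workflow this class enables: encode a padded batch and mean-pool the last hidden states into fixed-size sentence embeddings. The pooling strategy is an illustrative choice, not something the PR prescribes.

```python
# Minimal sketch (assumption): sentence embeddings from T5EncoderModel via
# attention-mask-aware mean pooling of the encoder's last hidden states.
import torch
from transformers import T5Tokenizer, T5EncoderModel

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5EncoderModel.from_pretrained("t5-small")
model.eval()

sentences = [
    "Studies have shown that owning a dog is good for you.",
    "T5 can now be used as a pure encoder for feature extraction.",
]
batch = tokenizer(sentences, padding=True, return_tensors="pt")

with torch.no_grad():
    outputs = model(input_ids=batch.input_ids, attention_mask=batch.attention_mask)

# Zero out padding positions before averaging so they do not dilute the embeddings.
mask = batch.attention_mask.unsqueeze(-1).float()        # (batch, seq_len, 1)
summed = (outputs.last_hidden_state * mask).sum(dim=1)   # (batch, d_model)
embeddings = summed / mask.sum(dim=1).clamp(min=1e-9)    # (batch, d_model)
print(embeddings.shape)  # torch.Size([2, 512]) for t5-small
```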