Full rework of the TF input/output embeddings and bias resizing #9193

Merged
merged 53 commits on Jan 11, 2021

Commits
326e1d4
Start rework resizing
jplu Dec 16, 2020
369419e
Rework bias/decoder resizing
jplu Dec 16, 2020
d97acd8
Full resizing rework
jplu Dec 16, 2020
ac0a2dc
Full resizing rework
jplu Dec 17, 2020
1dcb6b1
Start to update the models with the new approach
jplu Dec 17, 2020
ebf5faf
Finish to update the models
jplu Dec 18, 2020
5ee3b5f
Update all the tests
jplu Dec 18, 2020
0210659
Update the template
jplu Dec 18, 2020
4d361bd
Fix tests
jplu Dec 18, 2020
07ce9f2
Fix tests
jplu Dec 18, 2020
405d600
Test a new approach
jplu Dec 18, 2020
a696958
Refactoring
jplu Dec 18, 2020
396c2a0
Refactoring
jplu Dec 18, 2020
ea79994
Refactoring
jplu Dec 18, 2020
ddec821
New rework
jplu Dec 21, 2020
b84c3e6
Rework BART
jplu Dec 21, 2020
b582a0a
Rework bert+blenderbot
jplu Dec 21, 2020
31dca3a
Rework CTRL
jplu Dec 21, 2020
e80167c
Rework Distilbert
jplu Dec 21, 2020
eb617cd
Rework DPR
jplu Dec 21, 2020
2b2c532
Rework Electra
jplu Dec 21, 2020
8a1894a
Rework Flaubert
jplu Dec 21, 2020
1bb9acd
Rework Funnel
jplu Dec 21, 2020
d85aa49
Rework GPT2
jplu Dec 21, 2020
2f31e56
Rework Longformer
jplu Dec 21, 2020
0bcd08c
Rework Lxmert
jplu Dec 21, 2020
ce0833e
Rework marian+mbart
jplu Dec 21, 2020
69da314
Rework mobilebert
jplu Dec 21, 2020
0774304
Rework mpnet
jplu Dec 21, 2020
4286e0d
Rework openai
jplu Dec 21, 2020
b8f2181
Rework pegasus
jplu Dec 21, 2020
044e63f
Rework Roberta
jplu Dec 21, 2020
7082cfe
Rework T5
jplu Dec 21, 2020
9baf654
Rework xlm+xlnet
jplu Dec 21, 2020
f8e6f1b
Rework template
jplu Dec 21, 2020
98201a3
Fix TFT5EncoderOnly + DPRs
jplu Dec 21, 2020
d1165b1
Restore previous methods
jplu Dec 21, 2020
b7071f9
Fix Funnel
jplu Dec 21, 2020
6795e73
Fix CTRL and TransforXL
jplu Dec 21, 2020
09e7602
Apply style
jplu Dec 21, 2020
008ca77
Apply Sylvain's comments
jplu Dec 24, 2020
59086de
Restore a test in DPR
jplu Jan 5, 2021
fbeb6c8
Address the comments
jplu Jan 6, 2021
47a07aa
Fix bug
jplu Jan 7, 2021
4927da0
Apply style
jplu Jan 7, 2021
af9ee3c
remove unused import
jplu Jan 7, 2021
969e7db
Fix test
jplu Jan 7, 2021
8197422
Forgot a method
jplu Jan 7, 2021
951f899
missing test
jplu Jan 7, 2021
1f1dcef
Trigger CI
jplu Jan 7, 2021
22414d9
naming update
jplu Jan 7, 2021
5137cb3
Rebase
jplu Jan 7, 2021
48fa8c1
Trigger CI
jplu Jan 7, 2021
390 changes: 270 additions & 120 deletions src/transformers/modeling_tf_utils.py

Large diffs are not rendered by default.
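The heart of the rework lives in that un-rendered modeling_tf_utils.py diff; the per-model diffs below only show the other half of the contract, where each model or LM head exposes small hooks — get_output_embeddings/set_output_embeddings, get_bias/set_bias, get_lm_head — for the shared base-class code to drive. As a rough illustration only (not the actual implementation: it ignores Keras weight tracking, and resize_bias / resize_all_biases are hypothetical helpers), generic bias resizing against those hooks could look like this:

import tensorflow as tf


def resize_bias(old_bias, new_num_tokens):
    # Grow or truncate the last axis of a bias tensor to the new vocabulary
    # size, copying the overlapping values and zero-filling any new slots.
    old_num_tokens = int(old_bias.shape[-1])
    if new_num_tokens >= old_num_tokens:
        paddings = [[0, 0]] * (len(old_bias.shape) - 1) + [[0, new_num_tokens - old_num_tokens]]
        values = tf.pad(tf.convert_to_tensor(old_bias), paddings)
    else:
        values = tf.convert_to_tensor(old_bias)[..., :new_num_tokens]
    return tf.Variable(values, trainable=old_bias.trainable, name=old_bias.name.split(":")[0])


def resize_all_biases(layer, new_num_tokens):
    # `layer` is anything exposing the new dict-based hooks, e.g. an LM head
    # returned by get_lm_head() or a model that defines get_bias()/set_bias().
    old_biases = layer.get_bias()  # e.g. {"bias": ..., "decoder_bias": ...}
    layer.set_bias({name: resize_bias(bias, new_num_tokens) for name, bias in old_biases.items()})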

80 changes: 20 additions & 60 deletions src/transformers/models/albert/modeling_tf_albert.py
@@ -470,6 +470,21 @@ def build(self, input_shape):

super().build(input_shape)

def get_output_embeddings(self):
return self.decoder

def set_output_embeddings(self, value):
self.decoder.word_embeddings = value
self.decoder.vocab_size = shape_list(value)[0]

def get_bias(self):
return {"bias": self.bias, "decoder_bias": self.decoder_bias}

def set_bias(self, value):
self.bias = value["bias"]
self.decoder_bias = value["decoder_bias"]
self.vocab_size = shape_list(value["bias"])[0]

def call(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.activation(hidden_states)
@@ -505,10 +520,7 @@ def get_input_embeddings(self):

def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0]

def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
self.embeddings.vocab_size = shape_list(value)[0]

def _prune_heads(self, heads_to_prune):
"""
@@ -834,34 +846,8 @@ def __init__(self, config, *inputs, **kwargs):
self.predictions = TFAlbertMLMHead(config, self.albert.embeddings, name="predictions")
self.sop_classifier = TFAlbertSOPHead(config, name="sop_classifier")

def get_output_embeddings(self):
return self.albert.embeddings

def resize_token_embeddings(self, new_num_tokens):
super().resize_token_embeddings(new_num_tokens=new_num_tokens)

# ALBERT is a special case where there are two bias to update
# even though self.bias is not used anywhere and is here
# just to make the loading weights from a PT model happy
if new_num_tokens is not None:
num_tokens_to_copy = min(self.predictions.bias.shape[0], new_num_tokens)
self.predictions.vocab_size = num_tokens_to_copy
init_bias = tf.zeros((new_num_tokens,))
init_bias[:num_tokens_to_copy] = self.predictions.bias.value()[:num_tokens_to_copy]
name = self.name + "/" + self.predictions.name + "/bias"
self.predictions.bias = self.add_weight(
shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name
)
self.predictions.bias.assign(init_bias)

init_decoder_bias = tf.zeros((new_num_tokens,))
init_decoder_bias[:num_tokens_to_copy] = self.predictions.decoder_bias.value()[:num_tokens_to_copy]
name = self.name + "/" + self.predictions.name + "/decoder_bias"
self.predictions.decoder_bias = self.add_weight(
shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name
)

self.predictions.decoder_bias.assign(init_decoder_bias)
def get_lm_head(self):
return self.predictions

@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFAlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
@@ -979,34 +965,8 @@ def __init__(self, config, *inputs, **kwargs):
self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
self.predictions = TFAlbertMLMHead(config, self.albert.embeddings, name="predictions")

def get_output_embeddings(self):
return self.albert.embeddings

def resize_token_embeddings(self, new_num_tokens):
super().resize_token_embeddings(new_num_tokens=new_num_tokens)

# ALBERT is a special case where there are two bias to update
# even though self.bias is not used anywhere and is here
# just to make the loading weights from a PT model happy
if new_num_tokens is not None:
num_tokens_to_copy = min(self.predictions.bias.shape[0], new_num_tokens)
self.predictions.vocab_size = num_tokens_to_copy
init_bias = tf.zeros((new_num_tokens,))
init_bias[:num_tokens_to_copy] = self.predictions.bias.value()[:num_tokens_to_copy]
name = self.name + "/" + self.predictions.name + "/bias"
self.predictions.bias = self.add_weight(
shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name
)
self.predictions.bias.assign(init_bias)

init_decoder_bias = tf.zeros((new_num_tokens,))
init_decoder_bias[:num_tokens_to_copy] = self.predictions.decoder_bias.value()[:num_tokens_to_copy]
name = self.name + "/" + self.predictions.name + "/decoder_bias"
self.predictions.decoder_bias = self.add_weight(
shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name
)

self.predictions.decoder_bias.assign(init_decoder_bias)
def get_lm_head(self):
return self.predictions

@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
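ALBERT is the model that motivated the dict-based bias contract: its MLM head carries two vectors, bias and decoder_bias, so the two near-identical resize_token_embeddings overrides deleted above collapse into get_lm_head() plus get_bias()/set_bias(). A usage sketch, assuming a model instance and new_num_tokens, with resize_bias being the illustrative helper above:

head = model.get_lm_head()   # the TFAlbertMLMHead registered as "predictions"
biases = head.get_bias()     # {"bias": <(vocab_size,)>, "decoder_bias": <(vocab_size,)>}
head.set_bias({name: resize_bias(bias, new_num_tokens) for name, bias in biases.items()})
# set_bias also refreshes head.vocab_size from shape_list(value["bias"])[0]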
77 changes: 46 additions & 31 deletions src/transformers/models/bart/modeling_tf_bart.py
@@ -481,6 +481,29 @@ def dummy_inputs(self):
}
return dummy_inputs

def get_input_embeddings(self):
base_model = getattr(self, self.base_model_prefix, self)

return base_model.shared

def set_input_embeddings(self, value):
base_model = getattr(self, self.base_model_prefix, self)

try:
base_model.shared.weight = value
except AttributeError:
self(self.dummy_inputs)
base_model.shared.weight = value

base_model.shared.vocab_size = shape_list(base_model.shared.weight)[0]

with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
pass

embed_tokens = TFWrappedEmbeddings(base_model.shared, abs_scope_name=shared_abs_scope_name)
base_model.encoder.set_embed_tokens(embed_tokens)
base_model.decoder.set_embed_tokens(embed_tokens)

@tf.function(
input_signature=[
{
@@ -634,6 +657,9 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings
else None
)

def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens

def call(
self,
input_ids=None,
@@ -791,6 +817,9 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings
self.dropout = tf.keras.layers.Dropout(config.dropout)
self.do_blenderbot_90_layernorm = config.do_blenderbot_90_layernorm

def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens

def call(
self,
input_ids=None,
@@ -1009,6 +1038,9 @@ def __init__(self, config: BartConfig, *inputs, **kwargs):
self.encoder = TFBartEncoder(config, embed_tokens, name="encoder")
self.decoder = TFBartDecoder(config, embed_tokens, name="decoder")

def get_encoder(self):
return self.encoder

def get_decoder(self):
return self.decoder

@@ -1134,15 +1166,6 @@ def serving_output(self, output):
encoder_attentions=enc_attns,
)

def get_input_embeddings(self):
return self.shared

def set_input_embeddings(self, value):
self.shared = value

def get_output_embeddings(self):
return self.shared


@add_start_docstrings(
"The BART Model with a language modeling head. Can be used for summarization.",
@@ -1166,22 +1189,20 @@ def __init__(self, config, *inputs, **kwargs):
def get_decoder(self):
return self.model.decoder

def resize_token_embeddings(self, new_num_tokens):
super().resize_token_embeddings(new_num_tokens=new_num_tokens)

# BART is a special case where the bias has two dimensions
# and not named just `bias`
if new_num_tokens is not None:
num_tokens_to_copy = min(self.final_logits_bias.shape[0], new_num_tokens)
init_bias = tf.zeros((new_num_tokens,))
init_bias[:num_tokens_to_copy] = self.final_logits_bias.value()[:num_tokens_to_copy]
self.final_logits_bias = self.add_weight(
shape=(1, new_num_tokens),
initializer="zeros",
trainable=False,
name="final_logits_bias",
)
self.final_logits_bias.assign(init_bias)
def get_encoder(self):
return self.model.encoder

def get_output_embeddings(self):
return self.get_input_embeddings()

def set_output_embeddings(self, value):
self.set_input_embeddings(value)

def get_bias(self):
return {"final_logits_bias": self.final_logits_bias}

def set_bias(self, value):
self.final_logits_bias = value["final_logits_bias"]

@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@@ -1356,12 +1377,6 @@ def adjust_logits_during_generation(self, logits, cur_len, max_length):
else:
return logits

def get_output_embeddings(self):
return self.model.shared

def get_encoder(self):
return self.model.encoder

def compute_loss(self, labels, logits):
"""CrossEntropyLoss that ignores pad tokens"""
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
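For BART the output embeddings are simply the shared input embeddings (the new get_output_embeddings/set_output_embeddings delegate to the input-embedding hooks, which also re-wrap the encoder and decoder embed_tokens), and the only bias is the two-dimensional, non-trainable final_logits_bias, now exposed by key instead of being rebuilt by the override deleted above. Illustrative usage only, assuming a TFBartForConditionalGeneration instance and the resize_bias sketch near the top of the file list:

shared = model.get_input_embeddings()           # self.model.shared, a TFSharedEmbeddings
assert model.get_output_embeddings() is shared  # input and output embeddings are tied

bias = model.get_bias()["final_logits_bias"]    # shape (1, vocab_size), trainable=False
model.set_bias({"final_logits_bias": resize_bias(bias, new_num_tokens)})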
35 changes: 22 additions & 13 deletions src/transformers/models/bert/modeling_tf_bert.py
@@ -15,6 +15,7 @@
# limitations under the License.
""" TF 2.0 BERT model. """

import warnings
from dataclasses import dataclass
from typing import Optional, Tuple

@@ -526,6 +527,20 @@ def build(self, input_shape):

super().build(input_shape)

def get_output_embeddings(self):
return self.input_embeddings

def set_output_embeddings(self, value):
self.input_embeddings.word_embeddings = value
self.input_embeddings.vocab_size = shape_list(value)[0]

def get_bias(self):
return {"bias": self.bias}

def set_bias(self, value):
self.bias = value["bias"]
self.vocab_size = shape_list(value["bias"])[0]

def call(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.input_embeddings(hidden_states, mode="linear")
@@ -582,7 +597,7 @@ def get_input_embeddings(self):

def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0]
self.embeddings.vocab_size = shape_list(value)[0]

def _prune_heads(self, heads_to_prune):
"""
@@ -918,13 +933,11 @@ def __init__(self, config, *inputs, **kwargs):
self.nsp = TFBertNSPHead(config, name="nsp___cls")
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")

def get_output_embeddings(self):
return self.bert.embeddings

def get_output_layer_with_bias(self):
def get_lm_head(self):
return self.mlm.predictions

def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@@ -1044,13 +1057,11 @@ def __init__(self, config, *inputs, **kwargs):
self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")

def get_output_embeddings(self):
return self.bert.embeddings

def get_output_layer_with_bias(self):
def get_lm_head(self):
return self.mlm.predictions

def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@@ -1153,13 +1164,11 @@ def __init__(self, config, *inputs, **kwargs):
self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")

def get_output_embeddings(self):
return self.bert.embeddings

def get_output_layer_with_bias(self):
def get_lm_head(self):
return self.mlm.predictions

def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name

@add_code_sample_docstrings(
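The three BERT LM models follow the same pattern: the prediction head owns the output-embedding hooks and a single-entry bias dict, the models expose it through get_lm_head(), and get_prefix_bias_name() survives only as a deprecated shim that warns and points to get_bias(). A possible migration, assuming model is a TFBertForMaskedLM and resize_bias is the helper sketched earlier:

# Deprecated: the weight-name prefix was only ever needed to locate the bias.
prefix = model.get_prefix_bias_name()           # emits a FutureWarning

# Preferred: read and write the bias directly through the new hooks.
head = model.get_lm_head()                      # self.mlm.predictions
bias = head.get_bias()["bias"]                  # tf.Variable of shape (vocab_size,)
head.set_bias({"bias": resize_bias(bias, new_num_tokens)})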
27 changes: 19 additions & 8 deletions src/transformers/models/ctrl/modeling_tf_ctrl.py
@@ -15,6 +15,8 @@
# limitations under the License.
""" TF 2.0 CTRL model."""

import warnings

import numpy as np
import tensorflow as tf

@@ -242,10 +244,7 @@ def get_input_embeddings(self):

def set_input_embeddings(self, value):
self.w.weight = value
self.w.vocab_size = value.shape[0]

def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
self.w.vocab_size = shape_list(value)[0]

def _prune_heads(self, heads_to_prune):
"""
@@ -620,6 +619,20 @@ def build(self, input_shape):
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)

def get_output_embeddings(self):
return self.input_embeddings

def set_output_embeddings(self, value):
self.input_embeddings.weight = value
self.input_embeddings.vocab_size = shape_list(value)[0]

def get_bias(self):
return {"bias": self.bias}

def set_bias(self, value):
self.bias = value["bias"]
self.vocab_size = shape_list(value["bias"])[0]

def call(self, hidden_states):
hidden_states = self.input_embeddings(hidden_states, mode="linear")
hidden_states = hidden_states + self.bias
@@ -640,13 +653,11 @@ def __init__(self, config, *inputs, **kwargs):

self.lm_head = TFCTRLLMHead(config, self.transformer.w, name="lm_head")

def get_output_embeddings(self):
return self.lm_head.input_embeddings

def get_output_layer_with_bias(self):
def get_lm_head(self):
return self.lm_head

def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.lm_head.name

def prepare_inputs_for_generation(self, inputs, past, **kwargs):
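CTRL gets the same treatment, with the head-level hooks replacing get_output_layer_with_bias and a deprecated get_prefix_bias_name shim mirroring BERT's. From the user's side none of this changes the public entry point; the goal of the rework is that the single resize call below also resizes every bias a model declares through get_bias. A rough end-to-end check (it downloads the CTRL checkpoint, and the final assertion describes the behaviour this PR is meant to guarantee rather than something verified here):

from transformers import CTRLTokenizer, TFCTRLLMHeadModel

tokenizer = CTRLTokenizer.from_pretrained("ctrl")
model = TFCTRLLMHeadModel.from_pretrained("ctrl")

tokenizer.add_tokens(["<new_token>"])             # grow the vocabulary by one token
model.resize_token_embeddings(len(tokenizer))     # public API, unchanged by this PR

# The LM-head bias should now match the new vocabulary size.
assert int(model.get_lm_head().get_bias()["bias"].shape[0]) == len(tokenizer)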