RFC: Replace custom TF embeddings by Keras embeddings #18939

Merged
merged 5 commits into from
Sep 10, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
107 changes: 104 additions & 3 deletions src/transformers/modeling_tf_utils.py
@@ -887,6 +887,12 @@ def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False,
# If not, make the value to None
saved_weight_value = saved_weights.get(symbolic_weight_name, None)

# Retrocompatibility patch: some embeddings are stored with the weights name (e.g. Bart's
# `model.shared/embeddings:0` are stored as `model.shared/weights:0`)
if saved_weight_value is None and symbolic_weight_name.endswith("embeddings:0"):
symbolic_weight_name = symbolic_weight_name[:-12] + "weight:0"
saved_weight_value = saved_weights.get(symbolic_weight_name, None)

# Add the updated name to the final list for computing missing/unexpected values
symbolic_weights_names.add(symbolic_weight_name)
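
A rough standalone sketch of the remapping rule above, using a plain dict in place of the real checkpoint's saved-weights mapping (the entry names here are illustrative, not taken from an actual file):

```python
# Illustrative only: a toy dict stands in for the saved-weights mapping.
saved_weights = {"model.shared/weight:0": "legacy tensor"}  # hypothetical legacy entry

symbolic_weight_name = "model.shared/embeddings:0"  # name produced by the new Keras Embedding
saved_weight_value = saved_weights.get(symbolic_weight_name, None)

# Same fallback as the patch above: swap the "embeddings:0" suffix for the legacy "weight:0".
if saved_weight_value is None and symbolic_weight_name.endswith("embeddings:0"):
    symbolic_weight_name = symbolic_weight_name[:-12] + "weight:0"
    saved_weight_value = saved_weights.get(symbolic_weight_name, None)

assert symbolic_weight_name == "model.shared/weight:0"
assert saved_weight_value == "legacy tensor"
```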

@@ -1694,7 +1700,9 @@ def get_lm_head(self) -> tf.keras.layers.Layer:
"""
return None

def resize_token_embeddings(self, new_num_tokens=None) -> tf.Variable:
def resize_token_embeddings(
self, new_num_tokens: Optional[int] = None
) -> Union[tf.keras.layers.Embedding, tf.Variable]:
"""
Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.

@@ -1704,11 +1712,17 @@ def resize_token_embeddings(self, new_num_tokens=None) -> tf.Variable:
new_num_tokens (`int`, *optional*):
The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
returns a pointer to the input tokens `tf.Variable` module of the model without doing anything.
returns a pointer to the input tokens without doing anything.

Return:
`tf.Variable`: Pointer to the input tokens Embeddings Module of the model.
`tf.Variable` or `tf.keras.layers.Embedding`: Pointer to the input tokens of the model.
"""
# TODO (joao): flagged for replacement (by `_v2_resized_token_embeddings`) due to embeddings refactor

# Run the new code path if the model has a keras embeddings layer
if isinstance(self.get_input_embeddings(), tf.keras.layers.Embedding):
return self._v2_resized_token_embeddings(new_num_tokens)

if new_num_tokens is None or new_num_tokens == self.config.vocab_size:
return self._get_word_embedding_weight(self.get_input_embeddings())

@@ -1719,7 +1733,32 @@ def resize_token_embeddings(self, new_num_tokens=None) -> tf.Variable:

return model_embeds

def _v2_resized_token_embeddings(self, new_num_tokens: Optional[int] = None) -> tf.keras.layers.Embedding:
"""
Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.

Arguments:
new_num_tokens (`int`, *optional*):
The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
returns a pointer to the input tokens without doing anything.

Return:
`tf.keras.layers.Embedding`: Pointer to the input tokens of the model.
"""
if new_num_tokens is None or new_num_tokens == self.config.vocab_size:
return self.get_input_embeddings()

model_embeds = self._v2_resize_token_embeddings(new_num_tokens)

# Update base model and current model config
self.config.vocab_size = new_num_tokens

return model_embeds
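
For reference, a usage sketch of the dispatch above (the checkpoint name is just an example; any TF model whose input embeddings are already a `tf.keras.layers.Embedding` would take the v2 path):

```python
import tensorflow as tf

from transformers import TFBartForConditionalGeneration

model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-base")

# With this PR, BART's input embeddings are a Keras Embedding layer, so
# `resize_token_embeddings` dispatches to `_v2_resized_token_embeddings`.
resized = model.resize_token_embeddings(new_num_tokens=50272)
print(isinstance(resized, tf.keras.layers.Embedding))  # expected: True
print(model.config.vocab_size)                         # expected: 50272
```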

def _get_word_embedding_weight(model, embedding_layer):
# TODO (joao): flagged for deletion due to embeddings refactor

# If the variable holds the weights themselves, return them
if isinstance(embedding_layer, tf.Tensor):
return embedding_layer
@@ -1749,6 +1788,7 @@ def _get_word_embedding_weight(model, embedding_layer):
return None

def _resize_token_embeddings(self, new_num_tokens):
# TODO (joao): flagged for replacement (by `_v2_resize_token_embeddings`) due to embeddings refactor
old_embeddings = self._get_word_embedding_weight(self.get_input_embeddings())
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)

@@ -1770,6 +1810,27 @@ def _resize_token_embeddings(self, new_num_tokens):

return self.get_input_embeddings()

def _v2_resize_token_embeddings(self, new_num_tokens):
old_embeddings = self.get_input_embeddings()
new_embeddings = self._v2_get_resized_embeddings(old_embeddings, new_num_tokens)
self.set_input_embeddings(new_embeddings)

# If word embeddings are not tied, make sure that lm head bias is resized as well
if self.get_bias() is not None:
old_lm_head_bias = self.get_bias()
new_lm_head_bias = self._get_resized_lm_head_bias(old_lm_head_bias, new_num_tokens)
self.set_bias(new_lm_head_bias)

# If word embeddings are not tied, make sure that lm head decoder is resized as well.
tied_weights = self.get_input_embeddings() == self.get_output_embeddings()
if self.get_output_embeddings() is not None and not tied_weights:
old_lm_head_decoder = self._get_word_embedding_weight(self.get_output_embeddings())
# TODO (joao): this one probably needs a v2 version with other models
new_lm_head_decoder = self._get_resized_lm_head_decoder(old_lm_head_decoder, new_num_tokens)
self.set_output_embeddings(new_lm_head_decoder)

return self.get_input_embeddings()
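
The tied-weights check above compares layer objects, not values: only when the input and output embeddings are the very same layer is the lm head decoder left alone. A toy illustration:

```python
import tensorflow as tf

shared = tf.keras.layers.Embedding(10, 4)     # sizes are arbitrary
separate = tf.keras.layers.Embedding(10, 4)

print(shared == shared)    # True  -> treated as tied, decoder branch skipped
print(shared == separate)  # False -> untied, decoder would be resized as well
```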

def _get_resized_lm_head_bias(self, old_lm_head_bias, new_num_tokens):
"""
Build a resized bias from the old ones. Increasing the size will add newly initialized vectors at the end.
@@ -1879,6 +1940,7 @@ def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None) -> tf.Var
`tf.Variable`: Pointer to the resized Embedding Module or the old Embedding Module if `new_num_tokens` is
`None`
"""
# TODO (joao): flagged for replacement (by `_v2_get_resized_embeddings`) due to embeddings refactor
old_embedding_dim = shape_list(old_embeddings)[1]
init_range = getattr(self.config, "initializer_range", 0.02)
embeddings_mask, current_embeddings = init_copy_embeddings(old_embeddings, new_num_tokens)
@@ -1894,6 +1956,42 @@ def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None) -> tf.Var

return new_embeddings

def _v2_get_resized_embeddings(
self, old_embeddings: tf.keras.layers.Embedding, new_num_tokens: int
) -> tf.keras.layers.Embedding:
"""
Build a resized Embedding layer from a provided Embedding layer. Increasing the size will add newly initialized
vectors at the end. Reducing the size will remove vectors from the end.

Args:
old_embeddings (`tf.keras.layers.Embedding`):
Old embeddings to be resized.
new_num_tokens (`int`, *optional*):
New number of tokens in the embedding matrix.

Return:
`tf.keras.layers.Embedding`: Resized Embedding layer.
"""
# Get a new (initialized) embeddings layer
init_range = getattr(self.config, "initializer_range", 0.02)
new_embeddings = tf.keras.layers.Embedding(
input_dim=new_num_tokens,
output_dim=old_embeddings.output_dim,
embeddings_initializer=get_initializer(init_range),
name=old_embeddings.embeddings.name[:-13], # exact same scoped name except "/embeddings:0"
)
new_embeddings(tf.constant([[0]]))

# Copy the old embeddings to the new embeddings
if old_embeddings.input_dim >= new_num_tokens:
init_embeddings = old_embeddings.embeddings[:new_num_tokens]
else:
init_embeddings = tf.concat(
[old_embeddings.embeddings, new_embeddings.embeddings[old_embeddings.input_dim :]], axis=0
)
new_embeddings.embeddings.assign(init_embeddings)
return new_embeddings
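
A minimal standalone sketch of the copy step above, outside of any model (sizes are arbitrary): shrinking keeps the first `new_num_tokens` rows, growing keeps every old row and appends the freshly initialized ones.

```python
import tensorflow as tf

old = tf.keras.layers.Embedding(input_dim=6, output_dim=4)
old(tf.constant([[0]]))  # build the layer so `old.embeddings` exists

new_num_tokens = 9
new = tf.keras.layers.Embedding(input_dim=new_num_tokens, output_dim=old.output_dim)
new(tf.constant([[0]]))  # build

if old.input_dim >= new_num_tokens:
    init_embeddings = old.embeddings[:new_num_tokens]
else:
    init_embeddings = tf.concat(
        [old.embeddings, new.embeddings[old.input_dim :]], axis=0
    )
new.embeddings.assign(init_embeddings)

print(new.embeddings.shape)  # (9, 4)
# The first 6 rows are copied verbatim from the old table.
print(tf.reduce_all(new.embeddings[:6] == old.embeddings).numpy())  # True
```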

def prune_heads(self, heads_to_prune):
"""
Prunes heads of the base model.
@@ -2626,6 +2724,7 @@ class TFSharedEmbeddings(tf.keras.layers.Layer):
kwargs:
Additional keyword arguments passed along to the `__init__` of `tf.keras.layers.Layer`.
"""
# TODO (joao): flagged for deletion due to embeddings refactor

def __init__(self, vocab_size: int, hidden_size: int, initializer_range: Optional[float] = None, **kwargs):
super().__init__(**kwargs)
@@ -2842,6 +2941,8 @@ class TFWrappedEmbeddings:
saving/storing the correct weights
"""

# TODO (joao): flagged for deletion due to embeddings refactor

def __init__(self, layer, abs_scope_name=None):
self._layer = layer
self._abs_scope_name = abs_scope_name
73 changes: 21 additions & 52 deletions src/transformers/models/bart/modeling_tf_bart.py
@@ -35,8 +35,6 @@
TFCausalLanguageModelingLoss,
TFModelInputType,
TFPreTrainedModel,
TFSharedEmbeddings,
TFWrappedEmbeddings,
keras_serializable,
unpack_inputs,
)
@@ -113,7 +111,7 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
return (one_cst - expanded_mask) * LARGE_NEGATIVE


class TFBartLearnedPositionalEmbedding(TFSharedEmbeddings):
class TFBartLearnedPositionalEmbedding(tf.keras.layers.Embedding):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
@@ -136,7 +134,7 @@ def call(
position_ids = tf.range(seq_len, delta=1, name="range")
position_ids += past_key_values_length

return super().call(position_ids + self.offset)
return super().call(position_ids + tf.constant(self.offset, dtype=tf.int32))


class TFBartAttention(tf.keras.layers.Layer):
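
A sketch of the new inheritance pattern, with a toy subclass (the names and the offset value only loosely mirror BART; treat them as assumptions):

```python
import tensorflow as tf

class ToyLearnedPositionalEmbedding(tf.keras.layers.Embedding):
    """Illustrative subclass: reserves `offset` rows and shifts lookups, as BART does."""

    def __init__(self, num_positions: int, embedding_dim: int, **kwargs):
        self.offset = 2  # BART reserves two positions; plain int attribute
        super().__init__(num_positions + self.offset, embedding_dim, **kwargs)

    def call(self, seq_len, past_key_values_length: int = 0):
        position_ids = tf.range(seq_len, delta=1, name="range")
        position_ids += past_key_values_length
        # Cast the offset explicitly, as the diff does, so the addition stays int32.
        return super().call(position_ids + tf.constant(self.offset, dtype=tf.int32))

pos_embed = ToyLearnedPositionalEmbedding(num_positions=128, embedding_dim=16)
print(pos_embed(5).shape)  # (5, 16): one learned vector per position
```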
@@ -667,7 +665,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
config: BartConfig
"""

def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings] = None, **kwargs):
def __init__(self, config: BartConfig, embed_tokens: Optional[tf.keras.layers.Embedding] = None, **kwargs):
super().__init__(**kwargs)
self.config = config
self.dropout = tf.keras.layers.Dropout(config.dropout)
@@ -685,12 +683,6 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings
self.layers = [TFBartEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")

def get_embed_tokens(self):
return self.embed_tokens

def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens

@unpack_inputs
def call(
self,
@@ -750,7 +742,8 @@ def call(
raise ValueError("You have to specify either input_ids or inputs_embeds")

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
with tf.name_scope(self.embed_tokens.name + "/"):
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

embed_pos = self.embed_positions(input_shape)
hidden_states = inputs_embeds + embed_pos
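
The `tf.name_scope(... + "/")` wrapper exists because a Keras `Embedding` only creates its weight on first call, and a scope name ending in `/` replaces the current scope instead of nesting under it. A rough sketch of that mechanism, with illustrative names:

```python
import tensorflow as tf

shared = tf.keras.layers.Embedding(4, 2, name="shared")

with tf.name_scope("some_caller_scope"):  # e.g. the encoder calling the shared layer
    with tf.name_scope("model.shared/"):  # trailing "/" -> absolute scope, caller scope dropped
        shared(tf.constant([[0]]))        # first call builds the weight under this scope

# The weight name is rooted at "model.shared/...", not "some_caller_scope/...",
# which keeps checkpoint names independent of whether the encoder or the decoder
# happens to build the shared embedding first.
print(shared.embeddings.name)
```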
@@ -820,7 +813,7 @@ class TFBartDecoder(tf.keras.layers.Layer):
embed_tokens: output embedding
"""

def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings] = None, **kwargs):
def __init__(self, config: BartConfig, embed_tokens: Optional[tf.keras.layers.Embedding] = None, **kwargs):
super().__init__(**kwargs)
self.config = config
self.padding_idx = config.pad_token_id
@@ -837,12 +830,6 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings

self.dropout = tf.keras.layers.Dropout(config.dropout)

def get_embed_tokens(self):
return self.embed_tokens

def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens

@unpack_inputs
def call(
self,
@@ -943,7 +930,8 @@ def call(
positions = self.embed_positions(input_shape, position_ids=position_ids)

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
with tf.name_scope(self.embed_tokens.name + "/"):
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

hidden_states = inputs_embeds

@@ -1038,36 +1026,19 @@ class TFBartMainLayer(tf.keras.layers.Layer):
def __init__(self, config: BartConfig, load_weight_prefix=None, **kwargs):
super().__init__(**kwargs)
self.config = config
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")

# set tf scope correctly
if load_weight_prefix is None:
load_weight_prefix = "model.shared"

with tf.compat.v1.variable_scope(load_weight_prefix) as shared_abs_scope_name:
pass
load_weight_prefix = "model.shared" if load_weight_prefix is None else load_weight_prefix
self.shared = tf.keras.layers.Embedding(config.vocab_size, config.d_model, name=load_weight_prefix)

# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
embed_tokens.vocab_size = self.shared.vocab_size
embed_tokens.hidden_size = self.shared.hidden_size

self.encoder = TFBartEncoder(config, embed_tokens, name="encoder")
self.decoder = TFBartDecoder(config, embed_tokens, name="decoder")
self.encoder = TFBartEncoder(config, self.shared, name="encoder")
self.decoder = TFBartDecoder(config, self.shared, name="decoder")

def get_input_embeddings(self):
return self.shared

def set_input_embeddings(self, new_embeddings):
self.shared.weight = new_embeddings
self.shared.vocab_size = self.shared.weight.shape[0]
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
pass
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
self.encoder.set_embed_tokens(embed_tokens)
self.decoder.set_embed_tokens(embed_tokens)
self.shared = new_embeddings
self.encoder.embed_tokens = self.shared
self.decoder.embed_tokens = self.shared

@unpack_inputs
def call(
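
With `TFWrappedEmbeddings` gone, encoder and decoder simply hold references to the same `tf.keras.layers.Embedding`, so `set_input_embeddings` only has to repoint those references. A toy sketch of that sharing (class names are made up for illustration):

```python
import tensorflow as tf

class ToySubmodule(tf.keras.layers.Layer):
    # Stand-in for TFBartEncoder / TFBartDecoder: stores a reference to the shared layer.
    def __init__(self, embed_tokens, **kwargs):
        super().__init__(**kwargs)
        self.embed_tokens = embed_tokens

shared = tf.keras.layers.Embedding(8, 4, name="model.shared")
encoder, decoder = ToySubmodule(shared), ToySubmodule(shared)
assert encoder.embed_tokens is decoder.embed_tokens  # one weight matrix, not two copies

# What `set_input_embeddings` above boils down to:
new_shared = tf.keras.layers.Embedding(16, 4, name="model.shared")
encoder.embed_tokens = new_shared
decoder.embed_tokens = new_shared
assert encoder.embed_tokens is decoder.embed_tokens
```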
@@ -1273,11 +1244,7 @@ def call(self, x):
BART_START_DOCSTRING,
)
class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageModelingLoss):
_keys_to_ignore_on_load_unexpected = [
r"model.encoder.embed_tokens.weight",
r"model.decoder.embed_tokens.weight",
]

_keys_to_ignore_on_load_missing = [r"final_logits_bias"]
_requires_load_weight_prefix = True

def __init__(self, config, load_weight_prefix=None, *inputs, **kwargs):
@@ -1303,10 +1270,10 @@ def set_output_embeddings(self, value):
self.set_input_embeddings(value)

def get_bias(self):
return {"final_logits_bias": self.final_logits_bias}
return {"final_logits_bias": self.bias_layer.bias}

def set_bias(self, value):
self.final_logits_bias = value["final_logits_bias"]
self.bias_layer.bias = value["final_logits_bias"]

@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@@ -1374,7 +1341,9 @@ def call(
return_dict=return_dict,
training=training,
)
lm_logits = self.model.shared(outputs[0], mode="linear")
# TODO (joao): the line below is for models with tied embeddings. The previous TFBart had tied embeddings.
# The PT Bart does not have tied embeddings. Untie the weights while keeping loading retrocompatibility.
lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True)
lm_logits = self.bias_layer(lm_logits)
masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)
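
The matmul above replaces the old `TFSharedEmbeddings(..., mode="linear")` call: projecting hidden states onto the vocabulary is a multiplication by the transposed embedding matrix. A standalone sketch with made-up shapes, using the layer's `embeddings` variable directly:

```python
import tensorflow as tf

shared = tf.keras.layers.Embedding(input_dim=50265, output_dim=16)  # (vocab_size, d_model), toy d_model
shared(tf.constant([[0]]))  # build so the embedding matrix exists

hidden_states = tf.random.normal((2, 7, 16))  # (batch, seq_len, d_model)
lm_logits = tf.matmul(hidden_states, shared.embeddings, transpose_b=True)
print(lm_logits.shape)  # (2, 7, 50265): one score per vocabulary entry
```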

2 changes: 1 addition & 1 deletion src/transformers/models/mbart/modeling_tf_mbart.py
@@ -137,7 +137,7 @@ def call(
position_ids = tf.range(seq_len, delta=1, name="range")
position_ids += past_key_values_length

return super().call(position_ids + self.offset)
return super().call(position_ids + tf.constant(self.offset, dtype=tf.int32))
gante (Member, Author) commented: (make fix-copies)

# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->MBart