From 10d614df43d26f9e6bedb0e47bd2cb6f00240621 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 8 Sep 2022 09:37:49 +0000 Subject: [PATCH 1/5] tmp commit --- src/transformers/modeling_tf_utils.py | 6 ++ .../models/bart/modeling_tf_bart.py | 68 +++++-------------- tests/models/bart/test_modeling_tf_bart.py | 4 +- 3 files changed, 26 insertions(+), 52 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 484417f7ad33c5..a0a180827a27fa 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -887,6 +887,12 @@ def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False, # If not, make the value to None saved_weight_value = saved_weights.get(symbolic_weight_name, None) + # Retrocompatibility patch: some embeddings are stored with the weights name (e.g. Bart's + # `model.shared/embeddings:0` are stored as `model.shared/weights:0`) + if saved_weight_value is None and symbolic_weight_name.endswith("embeddings:0"): + symbolic_weight_name = symbolic_weight_name[:-12] + "weight:0" + saved_weight_value = saved_weights.get(symbolic_weight_name, None) + # Add the updated name to the final list for computing missing/unexpected values symbolic_weights_names.add(symbolic_weight_name) diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py index c15d0ae50451ae..6402456b87b61c 100644 --- a/src/transformers/models/bart/modeling_tf_bart.py +++ b/src/transformers/models/bart/modeling_tf_bart.py @@ -35,8 +35,6 @@ TFCausalLanguageModelingLoss, TFModelInputType, TFPreTrainedModel, - TFSharedEmbeddings, - TFWrappedEmbeddings, keras_serializable, unpack_inputs, ) @@ -113,7 +111,7 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None): return (one_cst - expanded_mask) * LARGE_NEGATIVE -class TFBartLearnedPositionalEmbedding(TFSharedEmbeddings): +class TFBartLearnedPositionalEmbedding(tf.keras.layers.Embedding): """ This module learns positional embeddings up to a fixed maximum size. 
""" @@ -136,7 +134,7 @@ def call( position_ids = tf.range(seq_len, delta=1, name="range") position_ids += past_key_values_length - return super().call(position_ids + self.offset) + return super().call(position_ids + tf.constant(self.offset, dtype=tf.int32)) class TFBartAttention(tf.keras.layers.Layer): @@ -667,7 +665,7 @@ class TFBartEncoder(tf.keras.layers.Layer): config: BartConfig """ - def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings] = None, **kwargs): + def __init__(self, config: BartConfig, embed_tokens: Optional[tf.keras.layers.Embedding] = None, **kwargs): super().__init__(**kwargs) self.config = config self.dropout = tf.keras.layers.Dropout(config.dropout) @@ -685,12 +683,6 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings self.layers = [TFBartEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)] self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding") - def get_embed_tokens(self): - return self.embed_tokens - - def set_embed_tokens(self, embed_tokens): - self.embed_tokens = embed_tokens - @unpack_inputs def call( self, @@ -750,7 +742,8 @@ def call( raise ValueError("You have to specify either input_ids or inputs_embeds") if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + with tf.name_scope(self.embed_tokens.name + "/"): + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale embed_pos = self.embed_positions(input_shape) hidden_states = inputs_embeds + embed_pos @@ -820,7 +813,7 @@ class TFBartDecoder(tf.keras.layers.Layer): embed_tokens: output embedding """ - def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings] = None, **kwargs): + def __init__(self, config: BartConfig, embed_tokens: Optional[tf.keras.layers.Embedding] = None, **kwargs): super().__init__(**kwargs) self.config = config self.padding_idx = config.pad_token_id @@ -837,12 +830,6 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings self.dropout = tf.keras.layers.Dropout(config.dropout) - def get_embed_tokens(self): - return self.embed_tokens - - def set_embed_tokens(self, embed_tokens): - self.embed_tokens = embed_tokens - @unpack_inputs def call( self, @@ -943,7 +930,8 @@ def call( positions = self.embed_positions(input_shape, position_ids=position_ids) if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + with tf.name_scope(self.embed_tokens.name + "/"): + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale hidden_states = inputs_embeds @@ -1038,36 +1026,19 @@ class TFBartMainLayer(tf.keras.layers.Layer): def __init__(self, config: BartConfig, load_weight_prefix=None, **kwargs): super().__init__(**kwargs) self.config = config - self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared") - - # set tf scope correctly - if load_weight_prefix is None: - load_weight_prefix = "model.shared" - - with tf.compat.v1.variable_scope(load_weight_prefix) as shared_abs_scope_name: - pass + load_weight_prefix = "model.shared" if load_weight_prefix is None else load_weight_prefix + self.shared = tf.keras.layers.Embedding(config.vocab_size, config.d_model, name=load_weight_prefix) - # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope. 
- embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name) - embed_tokens.vocab_size = self.shared.vocab_size - embed_tokens.hidden_size = self.shared.hidden_size - - self.encoder = TFBartEncoder(config, embed_tokens, name="encoder") - self.decoder = TFBartDecoder(config, embed_tokens, name="decoder") + self.encoder = TFBartEncoder(config, self.shared, name="encoder") + self.decoder = TFBartDecoder(config, self.shared, name="decoder") def get_input_embeddings(self): return self.shared def set_input_embeddings(self, new_embeddings): - self.shared.weight = new_embeddings - self.shared.vocab_size = self.shared.weight.shape[0] - # retrieve correct absolute scope for embed token wrapper - with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name: - pass - # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope. - embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name) - self.encoder.set_embed_tokens(embed_tokens) - self.decoder.set_embed_tokens(embed_tokens) + self.shared = new_embeddings + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared @unpack_inputs def call( @@ -1273,11 +1244,7 @@ def call(self, x): BART_START_DOCSTRING, ) class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageModelingLoss): - _keys_to_ignore_on_load_unexpected = [ - r"model.encoder.embed_tokens.weight", - r"model.decoder.embed_tokens.weight", - ] - + _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head.weight"] _requires_load_weight_prefix = True def __init__(self, config, load_weight_prefix=None, *inputs, **kwargs): @@ -1374,7 +1341,8 @@ def call( return_dict=return_dict, training=training, ) - lm_logits = self.model.shared(outputs[0], mode="linear") + # The output layer ("lm_head" in pytorch) is a dense layer whose weights are tied to the input embeddings. + lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True) lm_logits = self.bias_layer(lm_logits) masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits) diff --git a/tests/models/bart/test_modeling_tf_bart.py b/tests/models/bart/test_modeling_tf_bart.py index 5e5c5ee592a119..1b4664a5f29102 100644 --- a/tests/models/bart/test_modeling_tf_bart.py +++ b/tests/models/bart/test_modeling_tf_bart.py @@ -635,7 +635,7 @@ def xsum_1_1_model(self): def test_xsum_1_1_generation(self): model = self.xsum_1_1_model - assert model.model.decoder.embed_tokens._layer == model.model.shared + assert model.model.decoder.embed_tokens == model.model.shared ARTICLE = ( "The Palestinian Authority officially became the 123rd member of the International Criminal Court on" " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The" @@ -685,7 +685,7 @@ def test_xsum_1_1_generation(self): def test_xsum_1_1_xla_generation(self): # same test as above, but with `no_repeat_ngram_size=0` (not compatible with XLA) and XLA comparison enabled model = self.xsum_1_1_model - assert model.model.decoder.embed_tokens._layer == model.model.shared + assert model.model.decoder.embed_tokens == model.model.shared ARTICLE = ( "The Palestinian Authority officially became the 123rd member of the International Criminal Court on" " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. 
The" From 01b403f5ef55043d254e8ca6e184543bd2dc93d5 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 8 Sep 2022 13:57:21 +0000 Subject: [PATCH 2/5] enable embedding resizing --- src/transformers/modeling_tf_utils.py | 101 +++++++++++++++++- .../models/bart/modeling_tf_bart.py | 9 +- .../models/mbart/modeling_tf_mbart.py | 2 +- tests/models/bart/test_modeling_tf_bart.py | 63 ----------- tests/test_modeling_tf_common.py | 32 ++---- 5 files changed, 115 insertions(+), 92 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index a0a180827a27fa..19d0621e32f685 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -1700,7 +1700,9 @@ def get_lm_head(self) -> tf.keras.layers.Layer: """ return None - def resize_token_embeddings(self, new_num_tokens=None) -> tf.Variable: + def resize_token_embeddings( + self, new_num_tokens: Optional[int] = None + ) -> Union[tf.keras.layers.Embedding, tf.Variable]: """ Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`. @@ -1710,11 +1712,17 @@ def resize_token_embeddings(self, new_num_tokens=None) -> tf.Variable: new_num_tokens (`int`, *optional*): The number of new tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just - returns a pointer to the input tokens `tf.Variable` module of the model without doing anything. + returns a pointer to the input tokens without doing anything. Return: - `tf.Variable`: Pointer to the input tokens Embeddings Module of the model. + `tf.Variable` or `tf.keras.layers.Embedding`: Pointer to the input tokens of the model. """ + # TODO (joao): flagged for replacement (by `_v2_resized_token_embeddings`) due to embeddings refactor + + # Run the new code path if the model has a keras embeddings layer + if isinstance(self.get_input_embeddings(), tf.keras.layers.Embedding): + return self._v2_resized_token_embeddings(new_num_tokens) + if new_num_tokens is None or new_num_tokens == self.config.vocab_size: return self._get_word_embedding_weight(self.get_input_embeddings()) @@ -1725,7 +1733,32 @@ def resize_token_embeddings(self, new_num_tokens=None) -> tf.Variable: return model_embeds + def _v2_resized_token_embeddings(self, new_num_tokens: Optional[int] = None) -> tf.keras.layers.Embedding: + """ + Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`. + + Arguments: + new_num_tokens (`int`, *optional*): + The number of new tokens in the embedding matrix. Increasing the size will add newly initialized + vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just + returns a pointer to the input tokens without doing anything. + + Return: + `tf.keras.layers.Embedding`: Pointer to the input tokens of the model. 
+ """ + if new_num_tokens is None or new_num_tokens == self.config.vocab_size: + return self.get_input_embeddings() + + model_embeds = self._v2_resize_token_embeddings(new_num_tokens) + + # Update base model and current model config + self.config.vocab_size = new_num_tokens + + return model_embeds + def _get_word_embedding_weight(model, embedding_layer): + # TODO (joao): flagged for delection due to embeddings refactor + # If the variable holds the weights themselves, return them if isinstance(embedding_layer, tf.Tensor): return embedding_layer @@ -1755,6 +1788,7 @@ def _get_word_embedding_weight(model, embedding_layer): return None def _resize_token_embeddings(self, new_num_tokens): + # TODO (joao): flagged for replacement (by `_v2_resize_token_embeddings`) due to embeddings refactor old_embeddings = self._get_word_embedding_weight(self.get_input_embeddings()) new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens) @@ -1776,6 +1810,27 @@ def _resize_token_embeddings(self, new_num_tokens): return self.get_input_embeddings() + def _v2_resize_token_embeddings(self, new_num_tokens): + old_embeddings = self.get_input_embeddings() + new_embeddings = self._v2_get_resized_embeddings(old_embeddings, new_num_tokens) + self.set_input_embeddings(new_embeddings) + + # If word embeddings are not tied, make sure that lm head bias is resized as well + if self.get_bias() is not None: + old_lm_head_bias = self.get_bias() + new_lm_head_bias = self._get_resized_lm_head_bias(old_lm_head_bias, new_num_tokens) + self.set_bias(new_lm_head_bias) + + # If word embeddings are not tied, make sure that lm head decoder is resized as well. + tied_weights = self.get_input_embeddings() == self.get_output_embeddings() + if self.get_output_embeddings() is not None and not tied_weights: + old_lm_head_decoder = self._get_word_embedding_weight(self.get_output_embeddings()) + # TODO (joao): this one probably needs a v2 version with other models + new_lm_head_decoder = self._get_resized_lm_head_decoder(old_lm_head_decoder, new_num_tokens) + self.set_output_embeddings(new_lm_head_decoder) + + return self.get_input_embeddings() + def _get_resized_lm_head_bias(self, old_lm_head_bias, new_num_tokens): """ Build a resized bias from the old ones. Increasing the size will add newly initialized vectors at the end. @@ -1885,6 +1940,7 @@ def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None) -> tf.Var `tf.Variable`: Pointer to the resized Embedding Module or the old Embedding Module if `new_num_tokens` is `None` """ + # TODO (joao): flagged for replacement (by `_v2_get_resized_embeddings`) due to embeddings refactor old_embedding_dim = shape_list(old_embeddings)[1] init_range = getattr(self.config, "initializer_range", 0.02) embeddings_mask, current_embeddings = init_copy_embeddings(old_embeddings, new_num_tokens) @@ -1900,6 +1956,42 @@ def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None) -> tf.Var return new_embeddings + def _v2_get_resized_embeddings( + self, old_embeddings: tf.keras.layers.Embedding, new_num_tokens: int + ) -> tf.keras.layers.Embedding: + """ + Build a resized Embedding layer from a provided Embedding layer. Increasing the size will add newly initialized + vectors at the end. Reducing the size will remove vectors from the end. + + Args: + old_embeddings (`tf.keras.layers.Embedding`): + Old embeddings to be resized. + new_num_tokens (`int`, *optional*): + New number of tokens in the embedding matrix. 
+ + Return: + `tf.keras.layers.Embedding`: Resized Embedding layer. + """ + # Get a new (initialized) embeddings layer + init_range = getattr(self.config, "initializer_range", 0.02) + new_embeddings = tf.keras.layers.Embedding( + input_dim=new_num_tokens, + output_dim=old_embeddings.output_dim, + embeddings_initializer=get_initializer(init_range), + name=old_embeddings.embeddings.name[:-13], # exact same scoped name except "/embeddings:0" + ) + new_embeddings(tf.constant([[0]])) + + # Copy the old embeddings to the new embeddings + if old_embeddings.input_dim >= new_num_tokens: + init_embeddings = old_embeddings.embeddings[:new_num_tokens] + else: + init_embeddings = tf.concat( + [old_embeddings.embeddings, new_embeddings.embeddings[old_embeddings.input_dim :]], axis=0 + ) + new_embeddings.embeddings.assign(init_embeddings) + return new_embeddings + def prune_heads(self, heads_to_prune): """ Prunes heads of the base model. @@ -2632,6 +2724,7 @@ class TFSharedEmbeddings(tf.keras.layers.Layer): kwargs: Additional keyword arguments passed along to the `__init__` of `tf.keras.layers.Layer`. """ + # TODO (joao): flagged for delection due to embeddings refactor def __init__(self, vocab_size: int, hidden_size: int, initializer_range: Optional[float] = None, **kwargs): super().__init__(**kwargs) @@ -2848,6 +2941,8 @@ class TFWrappedEmbeddings: saving/storing the correct weights """ + # TODO (joao): flagged for delection due to embeddings refactor + def __init__(self, layer, abs_scope_name=None): self._layer = layer self._abs_scope_name = abs_scope_name diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py index 6402456b87b61c..690b420d3f4c74 100644 --- a/src/transformers/models/bart/modeling_tf_bart.py +++ b/src/transformers/models/bart/modeling_tf_bart.py @@ -1244,7 +1244,7 @@ def call(self, x): BART_START_DOCSTRING, ) class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageModelingLoss): - _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head.weight"] + _keys_to_ignore_on_load_missing = [r"final_logits_bias"] _requires_load_weight_prefix = True def __init__(self, config, load_weight_prefix=None, *inputs, **kwargs): @@ -1270,10 +1270,10 @@ def set_output_embeddings(self, value): self.set_input_embeddings(value) def get_bias(self): - return {"final_logits_bias": self.final_logits_bias} + return {"final_logits_bias": self.bias_layer.bias} def set_bias(self, value): - self.final_logits_bias = value["final_logits_bias"] + self.bias_layer.bias = value["final_logits_bias"] @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) @@ -1341,7 +1341,8 @@ def call( return_dict=return_dict, training=training, ) - # The output layer ("lm_head" in pytorch) is a dense layer whose weights are tied to the input embeddings. + # TODO (joao): the line below is for models with tied embeddings. The previous TFBart had tied embeddings. + # The PT Bart does not have tied embeddings. Untie the weights while keeping loading retrocompatibility. 
lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True) lm_logits = self.bias_layer(lm_logits) masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits) diff --git a/src/transformers/models/mbart/modeling_tf_mbart.py b/src/transformers/models/mbart/modeling_tf_mbart.py index 47bad2e21eb272..57b816e27afadf 100644 --- a/src/transformers/models/mbart/modeling_tf_mbart.py +++ b/src/transformers/models/mbart/modeling_tf_mbart.py @@ -137,7 +137,7 @@ def call( position_ids = tf.range(seq_len, delta=1, name="range") position_ids += past_key_values_length - return super().call(position_ids + self.offset) + return super().call(position_ids + tf.constant(self.offset, dtype=tf.int32)) # Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->MBart diff --git a/tests/models/bart/test_modeling_tf_bart.py b/tests/models/bart/test_modeling_tf_bart.py index 1b4664a5f29102..db06c84e0f5b86 100644 --- a/tests/models/bart/test_modeling_tf_bart.py +++ b/tests/models/bart/test_modeling_tf_bart.py @@ -230,69 +230,6 @@ def test_model_common_attributes(self): name = model.get_bias() assert name is None - def test_resize_token_embeddings(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - def _get_word_embedding_weight(model, embedding_layer): - if hasattr(embedding_layer, "weight"): - return embedding_layer.weight - else: - # Here we build the word embeddings weights if not exists. - # And then we retry to get the attribute once built. - model(model.dummy_inputs) - if hasattr(embedding_layer, "weight"): - return embedding_layer.weight - else: - return None - - for model_class in self.all_model_classes: - for size in [config.vocab_size - 10, config.vocab_size + 10, None]: - # build the embeddings - model = model_class(config=config) - old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings()) - old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings()) - old_final_logits_bias = model.get_bias() - - # reshape the embeddings - model.resize_token_embeddings(size) - new_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings()) - new_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings()) - new_final_logits_bias = model.get_bias() - - # check that the resized embeddings size matches the desired size. 
- assert_size = size if size is not None else config.vocab_size - - self.assertEqual(new_input_embeddings.shape[0], assert_size) - - # check that weights remain the same after resizing - models_equal = True - for p1, p2 in zip(old_input_embeddings.value(), new_input_embeddings.value()): - if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: - models_equal = False - self.assertTrue(models_equal) - - if old_output_embeddings is not None and new_output_embeddings is not None: - self.assertEqual(new_output_embeddings.shape[0], assert_size) - - models_equal = True - for p1, p2 in zip(old_output_embeddings.value(), new_output_embeddings.value()): - if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: - models_equal = False - self.assertTrue(models_equal) - - if old_final_logits_bias is not None and new_final_logits_bias is not None: - old_final_logits_bias = old_final_logits_bias["final_logits_bias"] - new_final_logits_bias = new_final_logits_bias["final_logits_bias"] - self.assertEqual(new_final_logits_bias.shape[0], 1) - self.assertEqual(new_final_logits_bias.shape[1], assert_size) - - models_equal = True - for old, new in zip(old_final_logits_bias.value(), new_final_logits_bias.value()): - for p1, p2 in zip(old, new): - if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: - models_equal = False - self.assertTrue(models_equal) - @tooslow def test_saved_model_creation(self): pass diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 0ef457c03523eb..4de6abf157e96a 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -1118,30 +1118,20 @@ def prepare_numpy_arrays(inputs_dict): self.assert_outputs_same(output_for_dict_input, output_for_kw_input) def test_resize_token_embeddings(self): + # TODO (joao): after the embeddings refactor is complete, rework this test so as to rely exclusively on + # tf.keras.layers.Embedding + if not self.test_resize_embeddings: return config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() def _get_word_embedding_weight(model, embedding_layer): - embeds = getattr(embedding_layer, "weight", None) - if embeds is not None: - return embeds - - embeds = getattr(embedding_layer, "decoder", None) - if embeds is not None: - return embeds - - model(model.dummy_inputs) - - embeds = getattr(embedding_layer, "weight", None) - if embeds is not None: - return embeds - - embeds = getattr(embedding_layer, "decoder", None) - if embeds is not None: - return embeds - - return None + if isinstance(embedding_layer, tf.keras.layers.Embedding): + # builds the embeddings layer + model(model.dummy_inputs) + return embedding_layer.embeddings + else: + return model._get_word_embedding_weight(embedding_layer) for model_class in self.all_model_classes: for size in [config.vocab_size - 10, config.vocab_size + 10, None]: @@ -1169,10 +1159,10 @@ def _get_word_embedding_weight(model, embedding_layer): if old_bias is not None and new_bias is not None: for old_weight, new_weight in zip(old_bias.values(), new_bias.values()): - self.assertEqual(new_weight.shape[0], assert_size) + self.assertEqual(new_weight.shape[-1], assert_size) models_equal = True - for p1, p2 in zip(old_weight.value(), new_weight.value()): + for p1, p2 in zip(old_weight.value()[0], new_weight.value()[0]): if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: models_equal = False self.assertTrue(models_equal) From 01b7a9fc0af3d7a7a1071fb23c80ac9a6d722465 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 8 Sep 2022 15:52:37 +0000 Subject: [PATCH 3/5] handle row 
and column vectors --- tests/test_modeling_tf_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 4de6abf157e96a..384d1598ab6392 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -1162,7 +1162,7 @@ def _get_word_embedding_weight(model, embedding_layer): self.assertEqual(new_weight.shape[-1], assert_size) models_equal = True - for p1, p2 in zip(old_weight.value()[0], new_weight.value()[0]): + for p1, p2 in zip(tf.squeeze(old_weight), tf.squeeze(new_weight)): if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: models_equal = False self.assertTrue(models_equal) From 8b651e05d757f4f22c9975e53c2184079088409d Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 8 Sep 2022 16:31:18 +0000 Subject: [PATCH 4/5] correct cast (Matt's suggestion) --- src/transformers/models/bart/modeling_tf_bart.py | 2 +- src/transformers/models/mbart/modeling_tf_mbart.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py index 690b420d3f4c74..3be1831a963e77 100644 --- a/src/transformers/models/bart/modeling_tf_bart.py +++ b/src/transformers/models/bart/modeling_tf_bart.py @@ -134,7 +134,7 @@ def call( position_ids = tf.range(seq_len, delta=1, name="range") position_ids += past_key_values_length - return super().call(position_ids + tf.constant(self.offset, dtype=tf.int32)) + return super().call(position_ids + tf.constant(self.offset, dtype=position_ids.dtype)) class TFBartAttention(tf.keras.layers.Layer): diff --git a/src/transformers/models/mbart/modeling_tf_mbart.py b/src/transformers/models/mbart/modeling_tf_mbart.py index 57b816e27afadf..d3cca2c9ad2f73 100644 --- a/src/transformers/models/mbart/modeling_tf_mbart.py +++ b/src/transformers/models/mbart/modeling_tf_mbart.py @@ -137,7 +137,7 @@ def call( position_ids = tf.range(seq_len, delta=1, name="range") position_ids += past_key_values_length - return super().call(position_ids + tf.constant(self.offset, dtype=tf.int32)) + return super().call(position_ids + tf.constant(self.offset, dtype=position_ids.dtype)) # Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->MBart From 8ae6751910caaab4d76dc35d81a03c4a00081dfe Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 9 Sep 2022 13:12:56 +0000 Subject: [PATCH 5/5] fix cast if input is an integer --- src/transformers/models/bart/modeling_tf_bart.py | 3 ++- src/transformers/models/mbart/modeling_tf_mbart.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py index 3be1831a963e77..17c0ce7a710502 100644 --- a/src/transformers/models/bart/modeling_tf_bart.py +++ b/src/transformers/models/bart/modeling_tf_bart.py @@ -134,7 +134,8 @@ def call( position_ids = tf.range(seq_len, delta=1, name="range") position_ids += past_key_values_length - return super().call(position_ids + tf.constant(self.offset, dtype=position_ids.dtype)) + offset_dtype = position_ids.dtype if isinstance(position_ids, tf.Tensor) else tf.int32 + return super().call(position_ids + tf.constant(self.offset, dtype=offset_dtype)) class TFBartAttention(tf.keras.layers.Layer): diff --git a/src/transformers/models/mbart/modeling_tf_mbart.py b/src/transformers/models/mbart/modeling_tf_mbart.py index d3cca2c9ad2f73..3f6a44fcf4d096 100644 --- 
a/src/transformers/models/mbart/modeling_tf_mbart.py +++ b/src/transformers/models/mbart/modeling_tf_mbart.py @@ -137,7 +137,8 @@ def call( position_ids = tf.range(seq_len, delta=1, name="range") position_ids += past_key_values_length - return super().call(position_ids + tf.constant(self.offset, dtype=position_ids.dtype)) + offset_dtype = position_ids.dtype if isinstance(position_ids, tf.Tensor) else tf.int32 + return super().call(position_ids + tf.constant(self.offset, dtype=offset_dtype)) # Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->MBart
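The snippet below is an illustrative standalone sketch, not part of the patch series: it reproduces, with made-up layer sizes, the two tf.keras.layers.Embedding idioms the series introduces, namely seeding a resized embedding layer from the old one as `_v2_get_resized_embeddings` does, and computing output logits from the tied input-embedding matrix as the new `lm_logits` line in `TFBartForConditionalGeneration.call` does.

import tensorflow as tf

# Resize an embedding layer the way `_v2_get_resized_embeddings` does: create a fresh
# layer of the target size, then copy over the old rows (truncating when shrinking,
# keeping the freshly initialized tail when growing). Sizes here are arbitrary.
old_embeddings = tf.keras.layers.Embedding(input_dim=8, output_dim=4)
old_embeddings(tf.constant([[0]]))  # a forward pass builds the layer, creating `.embeddings`

new_num_tokens = 12
new_embeddings = tf.keras.layers.Embedding(input_dim=new_num_tokens, output_dim=old_embeddings.output_dim)
new_embeddings(tf.constant([[0]]))

if old_embeddings.input_dim >= new_num_tokens:
    init_embeddings = old_embeddings.embeddings[:new_num_tokens]
else:
    init_embeddings = tf.concat(
        [old_embeddings.embeddings, new_embeddings.embeddings[old_embeddings.input_dim:]], axis=0
    )
new_embeddings.embeddings.assign(init_embeddings)

# Tied output projection, mirroring the new `lm_logits` computation: reuse the input
# embedding matrix as the output kernel via a transposed matmul (batch dims folded into axis 0).
hidden_states = tf.random.normal((3, new_embeddings.output_dim))
lm_logits = tf.matmul(hidden_states, new_embeddings.embeddings, transpose_b=True)
assert lm_logits.shape == (3, new_num_tokens)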
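A second minimal sketch, covering the dtype handling settled in patches 4 and 5: `position_ids` may arrive either as a tf.Tensor of some integer dtype or as a plain Python int, so the learned-position offset constant takes the tensor's dtype when one is available and falls back to int32 otherwise. Values here are arbitrary.

import tensorflow as tf

# Offset handling as in `TFBartLearnedPositionalEmbedding.call` after patches 4/5:
# match the constant's dtype to `position_ids` when it is a tensor, else default to int32.
offset = 2  # Bart's learned positional embeddings are shifted by 2
for position_ids in (tf.range(5, dtype=tf.int64), 3):  # a tensor input and a plain int input
    offset_dtype = position_ids.dtype if isinstance(position_ids, tf.Tensor) else tf.int32
    shifted = position_ids + tf.constant(offset, dtype=offset_dtype)
    print(shifted)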