From c2f8eaf6bcb3a8a3f25289726bdfec880cf178cc Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 30 Mar 2022 17:12:27 +0100 Subject: [PATCH] TF: unpack inputs on Convbert, GPTJ, LED, and templates (#16491) * Add unpack_inputs to remaining models * remove stray use of inputs in the templates; fix tf.debugging of attn masks --- .../models/bart/modeling_tf_bart.py | 4 +- .../models/convbert/modeling_tf_convbert.py | 254 ++----- .../models/gptj/modeling_tf_gptj.py | 111 ++- .../models/led/modeling_tf_led.py | 366 ++++------ .../models/mbart/modeling_tf_mbart.py | 4 +- .../modeling_tf_speech_to_text.py | 4 +- ...tf_{{cookiecutter.lowercase_modelname}}.py | 656 ++++++------------ 7 files changed, 466 insertions(+), 933 deletions(-) diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py index 9599bfe1d1e..106a87c043c 100644 --- a/src/transformers/models/bart/modeling_tf_bart.py +++ b/src/transformers/models/bart/modeling_tf_bart.py @@ -943,12 +943,12 @@ def call( # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired # The tf.debugging asserts are not compliant with XLA then they # have to be disabled in other modes than eager. - for attn_mask in [head_mask, cross_attn_head_mask]: + for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]: if attn_mask is not None and tf.executing_eagerly(): tf.debugging.assert_equal( shape_list(attn_mask)[0], len(self.layers), - message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.", + message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.", ) for idx, decoder_layer in enumerate(self.layers): diff --git a/src/transformers/models/convbert/modeling_tf_convbert.py b/src/transformers/models/convbert/modeling_tf_convbert.py index 61a4f5d69e4..f167325527b 100644 --- a/src/transformers/models/convbert/modeling_tf_convbert.py +++ b/src/transformers/models/convbert/modeling_tf_convbert.py @@ -39,8 +39,8 @@ TFSequenceSummary, TFTokenClassificationLoss, get_initializer, - input_processing, keras_serializable, + unpack_inputs, ) from ...tf_utils import shape_list from ...utils import ( @@ -568,6 +568,7 @@ def get_head_mask(self, head_mask): return head_mask + @unpack_inputs def call( self, input_ids=None, @@ -582,60 +583,36 @@ def call( training=False, **kwargs, ): - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - kwargs_call=kwargs, - ) - - if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: + if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif inputs["input_ids"] is not None: - input_shape = shape_list(inputs["input_ids"]) - elif inputs["inputs_embeds"] is not None: - input_shape = shape_list(inputs["inputs_embeds"])[:-1] + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") - if inputs["attention_mask"] is None: - inputs["attention_mask"] = tf.fill(input_shape, 1) + if attention_mask is None: + attention_mask = tf.fill(input_shape, 1) - if inputs["token_type_ids"] is None: - inputs["token_type_ids"] = tf.fill(input_shape, 0) + if token_type_ids is None: + token_type_ids = tf.fill(input_shape, 0) - hidden_states = self.embeddings( - inputs["input_ids"], - inputs["position_ids"], - inputs["token_type_ids"], - inputs["inputs_embeds"], - training=inputs["training"], - ) - extended_attention_mask = self.get_extended_attention_mask( - inputs["attention_mask"], input_shape, hidden_states.dtype - ) - inputs["head_mask"] = self.get_head_mask(inputs["head_mask"]) + hidden_states = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training) + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, hidden_states.dtype) + head_mask = self.get_head_mask(head_mask) if hasattr(self, "embeddings_project"): - hidden_states = self.embeddings_project(hidden_states, training=inputs["training"]) + hidden_states = self.embeddings_project(hidden_states, training=training) hidden_states = self.encoder( hidden_states, extended_attention_mask, - inputs["head_mask"], - inputs["output_attentions"], - inputs["output_hidden_states"], - inputs["return_dict"], - training=inputs["training"], + head_mask, + output_attentions, + output_hidden_states, + return_dict, + training=training, ) return hidden_states @@ -754,6 +731,7 @@ def __init__(self, config, *inputs, **kwargs): self.convbert = TFConvBertMainLayer(config, name="convbert") + @unpack_inputs @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -775,9 +753,7 @@ def call( training=False, **kwargs, ): - inputs = input_processing( - func=self.call, - config=self.config, + outputs = self.convbert( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -788,19 +764,6 @@ def call( output_hidden_states=output_hidden_states, return_dict=return_dict, training=training, - kwargs_call=kwargs, - ) - outputs = self.convbert( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) return outputs @@ -886,6 +849,7 @@ def get_lm_head(self): def get_prefix_bias_name(self): return self.name + "/" + self.generator_lm_head.name + @unpack_inputs @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -914,9 +878,7 @@ def call( config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` """ - inputs = input_processing( - func=self.call, - config=self.config, + generator_hidden_states = self.convbert( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -926,28 +888,14 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - labels=labels, training=training, - kwargs_call=kwargs, - ) - generator_hidden_states = self.convbert( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) generator_sequence_output = generator_hidden_states[0] - prediction_scores = self.generator_predictions(generator_sequence_output, training=inputs["training"]) - prediction_scores = self.generator_lm_head(prediction_scores, training=inputs["training"]) - loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], prediction_scores) + prediction_scores = self.generator_predictions(generator_sequence_output, training=training) + prediction_scores = self.generator_lm_head(prediction_scores, training=training) + loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores) - if not inputs["return_dict"]: + if not return_dict: output = (prediction_scores,) + generator_hidden_states[1:] return ((loss,) + output) if loss is not None else output @@ -1010,6 +958,7 @@ def __init__(self, config, *inputs, **kwargs): self.convbert = TFConvBertMainLayer(config, name="convbert") self.classifier = TFConvBertClassificationHead(config, name="classifier") + @unpack_inputs @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -1038,10 +987,8 @@ def call( config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, + outputs = self.convbert( + input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, @@ -1050,26 +997,12 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - labels=labels, training=training, - kwargs_call=kwargs, ) - outputs = self.convbert( - inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], - ) - logits = self.classifier(outputs[0], training=inputs["training"]) - loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits) + logits = self.classifier(outputs[0], training=training) + loss = None if labels is None else self.hf_compute_loss(labels, logits) - if not inputs["return_dict"]: + if not return_dict: output = (logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output @@ -1117,6 +1050,7 @@ def dummy_inputs(self): """ return {"input_ids": tf.convert_to_tensor(MULTIPLE_CHOICE_DUMMY_INPUTS)} + @unpack_inputs @add_start_docstrings_to_model_forward( CONVBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") ) @@ -1146,43 +1080,20 @@ def call( Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above) """ - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - labels=labels, - training=training, - kwargs_call=kwargs, - ) - - if inputs["input_ids"] is not None: - num_choices = shape_list(inputs["input_ids"])[1] - seq_length = shape_list(inputs["input_ids"])[2] + if input_ids is not None: + num_choices = shape_list(input_ids)[1] + seq_length = shape_list(input_ids)[2] else: - num_choices = shape_list(inputs["inputs_embeds"])[1] - seq_length = shape_list(inputs["inputs_embeds"])[2] + num_choices = shape_list(inputs_embeds)[1] + seq_length = shape_list(inputs_embeds)[2] - flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None - flat_attention_mask = ( - tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None - ) - flat_token_type_ids = ( - tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None - ) - flat_position_ids = ( - tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None - ) + flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None + flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None + flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None + flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None flat_inputs_embeds = ( - tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3])) - if inputs["inputs_embeds"] is not None + tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3])) + if inputs_embeds is not None else None ) outputs = self.convbert( @@ -1190,19 +1101,19 @@ def call( flat_attention_mask, flat_token_type_ids, flat_position_ids, - inputs["head_mask"], + head_mask, flat_inputs_embeds, - inputs["output_attentions"], - inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], + output_attentions, + output_hidden_states, + return_dict=return_dict, + training=training, ) - logits = self.sequence_summary(outputs[0], training=inputs["training"]) + logits = self.sequence_summary(outputs[0], training=training) logits = self.classifier(logits) reshaped_logits = tf.reshape(logits, (-1, num_choices)) - loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], reshaped_logits) + loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits) - if not inputs["return_dict"]: + if not return_dict: output = (reshaped_logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output @@ -1256,6 +1167,7 @@ def __init__(self, config, *inputs, **kwargs): config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" ) + @unpack_inputs @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -1282,10 +1194,8 @@ def call( labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. """ - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, + outputs = self.convbert( + input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, @@ -1294,28 +1204,14 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - labels=labels, training=training, - kwargs_call=kwargs, - ) - outputs = self.convbert( - inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) sequence_output = outputs[0] - sequence_output = self.dropout(sequence_output, training=inputs["training"]) + sequence_output = self.dropout(sequence_output, training=training) logits = self.classifier(sequence_output) - loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits) + loss = None if labels is None else self.hf_compute_loss(labels, logits) - if not inputs["return_dict"]: + if not return_dict: output = (logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output @@ -1350,6 +1246,7 @@ def __init__(self, config, *inputs, **kwargs): config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" ) + @unpack_inputs @add_start_docstrings_to_model_forward(CONVBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -1383,10 +1280,8 @@ def call( Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. """ - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, + outputs = self.convbert( + input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, @@ -1395,22 +1290,7 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - start_positions=start_positions, - end_positions=end_positions, training=training, - kwargs_call=kwargs, - ) - outputs = self.convbert( - inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) sequence_output = outputs[0] logits = self.qa_outputs(sequence_output) @@ -1419,12 +1299,12 @@ def call( end_logits = tf.squeeze(end_logits, axis=-1) loss = None - if inputs["start_positions"] is not None and inputs["end_positions"] is not None: - labels = {"start_position": inputs["start_positions"]} - labels["end_position"] = inputs["end_positions"] + if start_positions is not None and end_positions is not None: + labels = {"start_position": start_positions} + labels["end_position"] = end_positions loss = self.hf_compute_loss(labels, (start_logits, end_logits)) - if not inputs["return_dict"]: + if not return_dict: output = (start_logits, end_logits) + outputs[1:] return ((loss,) + output) if loss is not None else output diff --git a/src/transformers/models/gptj/modeling_tf_gptj.py b/src/transformers/models/gptj/modeling_tf_gptj.py index e6d9a4f308a..ce5c5d78e5a 100644 --- a/src/transformers/models/gptj/modeling_tf_gptj.py +++ b/src/transformers/models/gptj/modeling_tf_gptj.py @@ -40,7 +40,6 @@ TFSequenceClassificationLoss, TFSharedEmbeddings, get_initializer, - input_processing, keras_serializable, unpack_inputs, ) @@ -376,6 +375,7 @@ def _prune_heads(self, heads_to_prune): """ raise NotImplementedError + @unpack_inputs def call( self, input_ids=None, @@ -392,53 +392,34 @@ def call( training=False, **kwargs, ): - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - kwargs_call=kwargs, - ) - if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: + if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif inputs["input_ids"] is not None: - input_shape = shape_list(inputs["input_ids"]) - inputs["input_ids"] = tf.reshape(inputs["input_ids"], [-1, input_shape[-1]]) - elif inputs["inputs_embeds"] is not None: - input_shape = shape_list(inputs["inputs_embeds"])[:-1] + elif input_ids is not None: + input_shape = shape_list(input_ids) + input_ids = tf.reshape(input_ids, [-1, input_shape[-1]]) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") - if inputs["past_key_values"] is None: + if past_key_values is None: past_length = 0 - inputs["past_key_values"] = [None] * len(self.h) + past_key_values = [None] * len(self.h) else: - past_length = shape_list(inputs["past_key_values"][0][0])[-2] + past_length = shape_list(past_key_values[0][0])[-2] - if inputs["position_ids"] is None: - inputs["position_ids"] = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0) + if position_ids is None: + position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0) - if inputs["attention_mask"] is not None: + if attention_mask is not None: # We create a 3D attention mask from a 2D tensor mask. # Sizes are [batch_size, 1, 1, to_seq_length] # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # this attention mask is more simple than the triangular masking of causal attention # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask_shape = shape_list(inputs["attention_mask"]) - inputs["attention_mask"] = tf.reshape( - inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) - ) + attention_mask_shape = shape_list(attention_mask) + attention_mask = tf.reshape(attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1])) # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for @@ -446,78 +427,74 @@ def call( # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. one_cst = tf.constant(1.0) - inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=one_cst.dtype) - inputs["attention_mask"] = tf.multiply( - tf.subtract(one_cst, inputs["attention_mask"]), tf.constant(-10000.0) - ) + attention_mask = tf.cast(attention_mask, dtype=one_cst.dtype) + attention_mask = tf.multiply(tf.subtract(one_cst, attention_mask), tf.constant(-10000.0)) # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if inputs["head_mask"] is not None: + if head_mask is not None: raise NotImplementedError else: - inputs["head_mask"] = [None] * self.num_hidden_layers + head_mask = [None] * self.num_hidden_layers # head_mask = tf.constant([0] * self.num_hidden_layers) - inputs["position_ids"] = tf.reshape(inputs["position_ids"], [-1, shape_list(inputs["position_ids"])[-1]]) + position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]]) - if inputs["inputs_embeds"] is None: - inputs["inputs_embeds"] = self.wte(inputs["input_ids"], mode="embedding") + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids, mode="embedding") - if inputs["token_type_ids"] is not None: - inputs["token_type_ids"] = tf.reshape( - inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]] - ) - token_type_embeds = self.wte(inputs["token_type_ids"], mode="embedding") + if token_type_ids is not None: + token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]]) + token_type_embeds = self.wte(token_type_ids, mode="embedding") else: token_type_embeds = tf.constant(0.0) - token_type_embeds = tf.cast(token_type_embeds, dtype=inputs["inputs_embeds"].dtype) - hidden_states = inputs["inputs_embeds"] + token_type_embeds - hidden_states = self.drop(hidden_states, training=inputs["training"]) + token_type_embeds = tf.cast(token_type_embeds, dtype=inputs_embeds.dtype) + hidden_states = inputs_embeds + token_type_embeds + hidden_states = self.drop(hidden_states, training=training) output_shape = input_shape + [shape_list(hidden_states)[-1]] - presents = () if inputs["use_cache"] else None - all_attentions = () if inputs["output_attentions"] else None - all_hidden_states = () if inputs["output_hidden_states"] else None - for i, (block, layer_past) in enumerate(zip(self.h, inputs["past_key_values"])): - if inputs["output_hidden_states"]: + presents = () if use_cache else None + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + if output_hidden_states: all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),) outputs = block( hidden_states, layer_past, - inputs["attention_mask"], - inputs["head_mask"][i], - inputs["use_cache"], - inputs["output_attentions"], - training=inputs["training"], + attention_mask, + head_mask[i], + use_cache, + output_attentions, + training=training, ) hidden_states = outputs[0] - if inputs["use_cache"]: + if use_cache: presents = presents + (outputs[1],) - if inputs["output_attentions"]: - all_attentions = all_attentions + (outputs[2 if inputs["use_cache"] else 1],) + if output_attentions: + all_attentions = all_attentions + (outputs[2 if use_cache else 1],) hidden_states = self.ln_f(hidden_states) hidden_states = tf.reshape(hidden_states, output_shape) # Add last hidden state - if inputs["output_hidden_states"]: + if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if inputs["output_attentions"]: + if output_attentions: # let the number of heads free (-1) so we can extract attention even after head pruning attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:] all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions) - if not inputs["return_dict"]: + if not return_dict: return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None) return TFBaseModelOutputWithPast( diff --git a/src/transformers/models/led/modeling_tf_led.py b/src/transformers/models/led/modeling_tf_led.py index cdd82c6c505..4519f5df980 100644 --- a/src/transformers/models/led/modeling_tf_led.py +++ b/src/transformers/models/led/modeling_tf_led.py @@ -30,8 +30,8 @@ TFSharedEmbeddings, TFWrappedEmbeddings, get_initializer, - input_processing, keras_serializable, + unpack_inputs, ) from ...tf_utils import shape_list from ...utils import ( @@ -1654,6 +1654,7 @@ def get_embed_tokens(self): def set_embed_tokens(self, embed_tokens): self.embed_tokens = embed_tokens + @unpack_inputs def call( self, input_ids=None, @@ -1703,95 +1704,74 @@ def call( return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - global_attention_mask=global_attention_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - kwargs_call=kwargs, - ) - if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: + if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif inputs["input_ids"] is not None: - input_shape = shape_list(inputs["input_ids"]) - inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) - elif inputs["inputs_embeds"] is not None: - input_shape = shape_list(inputs["inputs_embeds"])[:-1] + elif input_ids is not None: + input_shape = shape_list(input_ids) + inputs_embeds = self.embed_tokens(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") - if inputs["attention_mask"] is None: - inputs["attention_mask"] = tf.fill(input_shape, 1) + if attention_mask is None: + attention_mask = tf.fill(input_shape, 1) # merge `global_attention_mask` and `attention_mask` - if inputs["global_attention_mask"] is not None: - inputs["attention_mask"] = inputs["attention_mask"] * tf.cast( - (inputs["global_attention_mask"] + 1), dtype=inputs["attention_mask"].dtype - ) + if global_attention_mask is not None: + attention_mask = attention_mask * tf.cast((global_attention_mask + 1), dtype=attention_mask.dtype) - ( - padding_len, - inputs["input_ids"], - inputs["attention_mask"], - inputs["inputs_embeds"], - ) = self._pad_to_window_size( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - inputs_embeds=inputs["inputs_embeds"], + (padding_len, input_ids, attention_mask, inputs_embeds,) = self._pad_to_window_size( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, pad_token_id=self.padding_idx, ) - input_shape = shape_list(inputs["attention_mask"]) + input_shape = shape_list(attention_mask) # is index masked or global attention - is_index_masked = tf.math.less(tf.cast(inputs["attention_mask"], tf.int8), 1) - is_index_global_attn = tf.math.greater(tf.cast(inputs["attention_mask"], tf.int8), 1) + is_index_masked = tf.math.less(tf.cast(attention_mask, tf.int8), 1) + is_index_global_attn = tf.math.greater(tf.cast(attention_mask, tf.int8), 1) is_global_attn = tf.math.reduce_any(is_index_global_attn) embed_pos = self.embed_positions(input_shape) - hidden_states = inputs["inputs_embeds"] + embed_pos + hidden_states = inputs_embeds + embed_pos hidden_states = self.layernorm_embedding(hidden_states) - hidden_states = self.dropout(hidden_states, training=inputs["training"]) + hidden_states = self.dropout(hidden_states, training=training) # check attention mask and invert - if inputs["attention_mask"] is not None: + if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - inputs["attention_mask"] = _expand_mask(inputs["attention_mask"])[:, 0, 0, :] - inputs["attention_mask"] = inputs["attention_mask"][:, :, None, None] + attention_mask = _expand_mask(attention_mask)[:, 0, 0, :] + attention_mask = attention_mask[:, :, None, None] - encoder_states = () if inputs["output_hidden_states"] else None - all_attentions = all_global_attentions = () if inputs["output_attentions"] else None + encoder_states = () if output_hidden_states else None + all_attentions = all_global_attentions = () if output_attentions else None # check if head_mask has a correct number of layers specified if desired - if inputs["head_mask"] is not None and tf.executing_eagerly(): + if head_mask is not None and tf.executing_eagerly(): tf.debugging.assert_equal( - shape_list(inputs["head_mask"])[0], + shape_list(head_mask)[0], len(self.layers), - message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", + message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.", ) # encoder layers for idx, encoder_layer in enumerate(self.layers): - if inputs["output_hidden_states"]: + if output_hidden_states: hidden_states_to_add = self.compute_hidden_states(hidden_states, padding_len) encoder_states = encoder_states + (hidden_states_to_add,) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) dropout_probability = random.uniform(0, 1) - if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer + if training and (dropout_probability < self.layerdrop): # skip the layer continue layer_outputs = encoder_layer( hidden_states=hidden_states, - attention_mask=inputs["attention_mask"], - layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, + attention_mask=attention_mask, + layer_head_mask=head_mask[idx] if head_mask is not None else None, is_index_masked=is_index_masked, is_index_global_attn=is_index_global_attn, is_global_attn=is_global_attn, @@ -1799,7 +1779,7 @@ def call( hidden_states = layer_outputs[0] - if inputs["output_attentions"]: + if output_attentions: # bzs x seq_len x num_attn_heads x (num_global_attn + attention_window_len + 1) => bzs x num_attn_heads x seq_len x (num_global_attn + attention_window_len + 1) all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),) @@ -1811,17 +1791,17 @@ def call( hidden_states = self.compute_hidden_states(hidden_states, padding_len) # undo padding - if inputs["output_attentions"]: + if output_attentions: all_attentions = ( tuple([state[:, :, :-padding_len, :] for state in all_attentions]) if padding_len > 0 else all_attentions ) - if inputs["output_hidden_states"]: + if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if not inputs["return_dict"]: + if not return_dict: return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return TFLEDEncoderBaseModelOutput( last_hidden_state=hidden_states, @@ -1915,6 +1895,7 @@ def __init__(self, config: LEDConfig, embed_tokens: Optional[TFSharedEmbeddings] def set_embed_tokens(self, embed_tokens): self.embed_tokens = embed_tokens + @unpack_inputs def call( self, input_ids=None, @@ -1985,45 +1966,25 @@ def call( return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - head_mask=head_mask, - encoder_head_mask=encoder_head_mask, - inputs_embeds=inputs_embeds, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - kwargs_call=kwargs, - ) - if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: + if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif inputs["input_ids"] is not None: - input_shape = shape_list(inputs["input_ids"]) - elif inputs["inputs_embeds"] is not None: - input_shape = shape_list(inputs["inputs_embeds"])[:-1] + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - past_key_values_length = ( - shape_list(inputs["past_key_values"][0][0])[2] if inputs["past_key_values"] is not None else 0 - ) + past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0 # embed positions positions = self.embed_positions(input_shape, past_key_values_length) - if inputs["inputs_embeds"] is None: - inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) - hidden_states = inputs["inputs_embeds"] + hidden_states = inputs_embeds # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] if input_shape[-1] > 1: @@ -2033,17 +1994,15 @@ def call( tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1] ) - if inputs["attention_mask"] is not None and input_shape[-1] > 1: - combined_attention_mask = combined_attention_mask + _expand_mask( - inputs["attention_mask"], tgt_len=input_shape[-1] - ) + if attention_mask is not None and input_shape[-1] > 1: + combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1]) - if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None: + if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - inputs["encoder_attention_mask"] = _expand_mask(inputs["encoder_attention_mask"], tgt_len=input_shape[-1]) + encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1]) hidden_states = self.layernorm_embedding(hidden_states + positions) - hidden_states = self.dropout(hidden_states, training=inputs["training"]) + hidden_states = self.dropout(hidden_states, training=training) # decoder layers all_hidden_states = () @@ -2052,54 +2011,52 @@ def call( present_key_values = () # check if head_mask has a correct number of layers specified if desired - if inputs["head_mask"] is not None and tf.executing_eagerly(): + if head_mask is not None and tf.executing_eagerly(): tf.debugging.assert_equal( - shape_list(inputs["head_mask"])[0], + shape_list(head_mask)[0], len(self.layers), - message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", + message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.", ) for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) - if inputs["output_hidden_states"]: + if output_hidden_states: all_hidden_states += (hidden_states,) dropout_probability = random.uniform(0, 1) - if inputs["training"] and (dropout_probability < self.layerdrop): + if training and (dropout_probability < self.layerdrop): continue - past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None + past_key_value = past_key_values[idx] if past_key_values is not None else None hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer( hidden_states, attention_mask=combined_attention_mask, - encoder_hidden_states=inputs["encoder_hidden_states"], - encoder_attention_mask=inputs["encoder_attention_mask"], - layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, - encoder_layer_head_mask=inputs["encoder_head_mask"][idx] - if inputs["encoder_head_mask"] is not None - else None, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=head_mask[idx] if head_mask is not None else None, + encoder_layer_head_mask=encoder_head_mask[idx] if encoder_head_mask is not None else None, past_key_value=past_key_value, ) - if inputs["use_cache"]: + if use_cache: present_key_values += (present_key_value,) - if inputs["output_attentions"]: + if output_attentions: all_self_attns += (layer_self_attn,) all_cross_attentions += (layer_cross_attn,) - if inputs["output_hidden_states"]: + if output_hidden_states: all_hidden_states += (hidden_states,) else: all_hidden_states = None - all_self_attns = all_self_attns if inputs["output_attentions"] else None - all_cross_attentions = all_cross_attentions if inputs["output_attentions"] else None + all_self_attns = all_self_attns if output_attentions else None + all_cross_attentions = all_cross_attentions if output_attentions else None - present_key_values = present_key_values if inputs["use_cache"] else None + present_key_values = present_key_values if use_cache else None - if not inputs["return_dict"]: + if not return_dict: return tuple( v for v in [hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attentions] @@ -2149,6 +2106,7 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_embed_tokens(embed_tokens) self.decoder.set_embed_tokens(embed_tokens) + @unpack_inputs def call( self, input_ids=None, @@ -2169,72 +2127,51 @@ def call( training=False, **kwargs ): - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - encoder_outputs=encoder_outputs, - global_attention_mask=global_attention_mask, + + if decoder_input_ids is None and decoder_inputs_embeds is None: + use_cache = False + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + global_attention_mask=global_attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a TFLEDEncoderBaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, TFLEDEncoderBaseModelOutput): + encoder_outputs = TFLEDEncoderBaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + # If the user passed a TFLEDEncoderBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False + elif not return_dict and not isinstance(encoder_outputs, tuple): + encoder_outputs = encoder_outputs.to_tuple() + + decoder_outputs = self.decoder( + decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + encoder_head_mask=head_mask, past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, + inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training, - kwargs_call=kwargs, - ) - - if inputs["decoder_input_ids"] is None and inputs["decoder_inputs_embeds"] is None: - inputs["use_cache"] = False - - if inputs["encoder_outputs"] is None: - inputs["encoder_outputs"] = self.encoder( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - global_attention_mask=inputs["global_attention_mask"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], - ) - # If the user passed a tuple for encoder_outputs, we wrap it in a TFLEDEncoderBaseModelOutput when return_dict=True - elif inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], TFLEDEncoderBaseModelOutput): - inputs["encoder_outputs"] = TFLEDEncoderBaseModelOutput( - last_hidden_state=inputs["encoder_outputs"][0], - hidden_states=inputs["encoder_outputs"][1] if len(inputs["encoder_outputs"]) > 1 else None, - attentions=inputs["encoder_outputs"][2] if len(inputs["encoder_outputs"]) > 2 else None, - ) - # If the user passed a TFLEDEncoderBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False - elif not inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], tuple): - inputs["encoder_outputs"] = inputs["encoder_outputs"].to_tuple() + ) - decoder_outputs = self.decoder( - inputs["decoder_input_ids"], - attention_mask=inputs["decoder_attention_mask"], - encoder_hidden_states=inputs["encoder_outputs"][0], - encoder_attention_mask=inputs["attention_mask"], - head_mask=inputs["decoder_head_mask"], - encoder_head_mask=inputs["head_mask"], - past_key_values=inputs["past_key_values"], - inputs_embeds=inputs["decoder_inputs_embeds"], - use_cache=inputs["use_cache"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], - ) - - if not inputs["return_dict"]: - return decoder_outputs + inputs["encoder_outputs"] + if not return_dict: + return decoder_outputs + encoder_outputs return TFLEDSeq2SeqModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, @@ -2242,10 +2179,10 @@ def call( decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, - encoder_hidden_states=inputs["encoder_outputs"].hidden_states, - encoder_attentions=inputs["encoder_outputs"].attentions, - encoder_global_attentions=inputs["encoder_outputs"].global_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + encoder_global_attentions=encoder_outputs.global_attentions, ) @@ -2265,6 +2202,7 @@ def get_encoder(self): def get_decoder(self): return self.led.decoder + @unpack_inputs @add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -2292,17 +2230,16 @@ def call( training=False, **kwargs ): - inputs = input_processing( - func=self.call, - config=self.config, + + outputs = self.led( input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, global_attention_mask=global_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, @@ -2311,25 +2248,6 @@ def call( output_hidden_states=output_hidden_states, return_dict=return_dict, training=training, - kwargs_call=kwargs, - ) - outputs = self.led( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - decoder_input_ids=inputs["decoder_input_ids"], - decoder_attention_mask=inputs["decoder_attention_mask"], - encoder_outputs=inputs["encoder_outputs"], - global_attention_mask=inputs["global_attention_mask"], - head_mask=inputs["head_mask"], - decoder_head_mask=inputs["decoder_head_mask"], - past_key_values=inputs["past_key_values"], - inputs_embeds=inputs["inputs_embeds"], - decoder_inputs_embeds=inputs["decoder_inputs_embeds"], - use_cache=inputs["use_cache"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) return outputs @@ -2393,6 +2311,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, value): self.set_input_embeddings(value) + @unpack_inputs @add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=TFLEDSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) def call( @@ -2435,17 +2354,22 @@ def call( >>> # probs[5] is associated with the mask token ```""" - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, + if labels is not None: + use_cache = False + if decoder_input_ids is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.led( + input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, encoder_outputs=encoder_outputs, global_attention_mask=global_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, @@ -2453,41 +2377,13 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - labels=labels, training=training, - kwargs_call=kwargs, - ) - - if inputs["labels"] is not None: - inputs["use_cache"] = False - if inputs["decoder_input_ids"] is None: - inputs["decoder_input_ids"] = shift_tokens_right( - inputs["labels"], self.config.pad_token_id, self.config.decoder_start_token_id - ) - - outputs = self.led( - inputs["input_ids"], - attention_mask=inputs["attention_mask"], - decoder_input_ids=inputs["decoder_input_ids"], - decoder_attention_mask=inputs["decoder_attention_mask"], - encoder_outputs=inputs["encoder_outputs"], - global_attention_mask=inputs["global_attention_mask"], - head_mask=inputs["head_mask"], - decoder_head_mask=inputs["decoder_head_mask"], - past_key_values=inputs["past_key_values"], - inputs_embeds=inputs["inputs_embeds"], - decoder_inputs_embeds=inputs["decoder_inputs_embeds"], - use_cache=inputs["use_cache"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) lm_logits = self.led.shared(outputs[0], mode="linear") lm_logits = lm_logits + self.final_logits_bias - masked_lm_loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], lm_logits) + masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits) - if not inputs["return_dict"]: + if not return_dict: output = (lm_logits,) + outputs[1:] return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output return TFLEDSeq2SeqLMOutput( diff --git a/src/transformers/models/mbart/modeling_tf_mbart.py b/src/transformers/models/mbart/modeling_tf_mbart.py index f9f3014c67f..3f2ea655f45 100644 --- a/src/transformers/models/mbart/modeling_tf_mbart.py +++ b/src/transformers/models/mbart/modeling_tf_mbart.py @@ -965,12 +965,12 @@ def call( # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired # The tf.debugging asserts are not compliant with XLA then they # have to be disabled in other modes than eager. - for attn_mask in [head_mask, cross_attn_head_mask]: + for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]: if attn_mask is not None and tf.executing_eagerly(): tf.debugging.assert_equal( shape_list(attn_mask)[0], len(self.layers), - message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.", + message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.", ) for idx, decoder_layer in enumerate(self.layers): diff --git a/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py index 7bdf620c655..6c78ab1b58f 100755 --- a/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py @@ -1060,12 +1060,12 @@ def call( # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired # The tf.debugging asserts are not compliant with XLA then they have to be disabled in other modes than eager. - for attn_mask in [head_mask, cross_attn_head_mask]: + for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]: if attn_mask is not None and tf.executing_eagerly(): tf.debugging.assert_equal( shape_list(attn_mask)[0], len(self.layers), - message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.", + message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.", ) for idx, decoder_layer in enumerate(self.layers): diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py index 0fd552e0290..da2a0a3828f 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py @@ -50,8 +50,8 @@ TFSequenceSummary, TFTokenClassificationLoss, get_initializer, - input_processing, keras_serializable, + unpack_inputs, ) from ...tf_utils import shape_list from ...utils import logging @@ -636,6 +636,7 @@ class PreTrainedModel """ raise NotImplementedError + @unpack_inputs def call( self, input_ids: Optional[TFModelInputType] = None, @@ -654,59 +655,40 @@ def call( training: bool = False, **kwargs, ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]: - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - kwargs_call=kwargs, - ) if not self.config.is_decoder: - inputs["use_cache"] = False + use_cache = False - if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: + if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif inputs["input_ids"] is not None: - input_shape = shape_list(inputs["input_ids"]) - elif inputs["inputs_embeds"] is not None: - input_shape = shape_list(inputs["inputs_embeds"])[:-1] + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") batch_size, seq_length = input_shape - if inputs["past_key_values"] is None: + if past_key_values is None: past_key_values_length = 0 - inputs["past_key_values"] = [None] * len(self.encoder.layer) + past_key_values = [None] * len(self.encoder.layer) else: - past_key_values_length = shape_list(inputs["past_key_values"][0][0])[-2] + past_key_values_length = shape_list(past_key_values[0][0])[-2] - if inputs["attention_mask"] is None: - inputs["attention_mask"] = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) + if attention_mask is None: + attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) - if inputs["token_type_ids"] is None: - inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0) + if token_type_ids is None: + token_type_ids = tf.fill(dims=input_shape, value=0) embedding_output = self.embeddings( - input_ids=inputs["input_ids"], - position_ids=inputs["position_ids"], - token_type_ids=inputs["token_type_ids"], - inputs_embeds=inputs["inputs_embeds"], + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, past_key_values_length=past_key_values_length, - training=inputs["training"], + training=training, ) # We create a 3D attention mask from a 2D tensor mask. @@ -714,7 +696,7 @@ def call( # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # this attention mask is more simple than the triangular masking of causal attention # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask_shape = shape_list(inputs["attention_mask"]) + attention_mask_shape = shape_list(attention_mask) mask_seq_length = seq_length + past_key_values_length # Copied from `modeling_tf_t5.py` @@ -727,18 +709,18 @@ def call( tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), seq_ids[None, :, None], ) - causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype) - extended_attention_mask = causal_mask * inputs["attention_mask"][:, None, :] + causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype) + extended_attention_mask = causal_mask * attention_mask[:, None, :] attention_mask_shape = shape_list(extended_attention_mask) extended_attention_mask = tf.reshape( extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) ) - if inputs["past_key_values"][0] is not None: + if past_key_values[0] is not None: # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length] extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] else: extended_attention_mask = tf.reshape( - inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) + attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) ) # Since attention_mask is 1.0 for positions we want to attend and 0.0 for @@ -752,18 +734,18 @@ def call( extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) # Copied from `modeling_tf_t5.py` with -1e9 -> -10000 - if self.is_decoder and inputs["encoder_attention_mask"] is not None: + if self.is_decoder and encoder_attention_mask is not None: # If a 2D ou 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - inputs["encoder_attention_mask"] = tf.cast( - inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype + encoder_attention_mask = tf.cast( + encoder_attention_mask, dtype=extended_attention_mask.dtype ) - num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"])) + num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask)) if num_dims_encoder_attention_mask == 3: - encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :] + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] if num_dims_encoder_attention_mask == 2: - encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :] + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 @@ -779,28 +761,28 @@ def call( # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - if inputs["head_mask"] is not None: + if head_mask is not None: raise NotImplementedError else: - inputs["head_mask"] = [None] * self.config.num_hidden_layers + head_mask = [None] * self.config.num_hidden_layers encoder_outputs = self.encoder( hidden_states=embedding_output, attention_mask=extended_attention_mask, - head_mask=inputs["head_mask"], - encoder_hidden_states=inputs["encoder_hidden_states"], + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, - past_key_values=inputs["past_key_values"], - use_cache=inputs["use_cache"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, ) sequence_output = encoder_outputs[0] - if not inputs["return_dict"]: + if not return_dict: return ( sequence_output, ) + encoder_outputs[1:] @@ -943,6 +925,7 @@ def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, *inputs, self.{{cookiecutter.lowercase_modelname}} = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="{{cookiecutter.lowercase_modelname}}") + @unpack_inputs @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -988,9 +971,7 @@ def call( If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). Set to `False` during training, `True` during generation """ - inputs = input_processing( - func=self.call, - config=self.config, + outputs = self.{{cookiecutter.lowercase_modelname}}( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -1005,23 +986,6 @@ def call( output_hidden_states=output_hidden_states, return_dict=return_dict, training=training, - kwargs_call=kwargs, - ) - outputs = self.{{cookiecutter.lowercase_modelname}}( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - encoder_hidden_states=inputs["encoder_hidden_states"], - encoder_attention_mask=inputs["encoder_attention_mask"], - past_key_values=inputs["past_key_values"], - use_cache=inputs["use_cache"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) return outputs @@ -1064,6 +1028,7 @@ def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, *inputs, def get_lm_head(self) -> tf.keras.layers.Layer: return self.mlm.predictions + @unpack_inputs @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -1091,9 +1056,7 @@ def call( Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` """ - inputs = input_processing( - func=self.call, - config=self.config, + outputs = self.{{cookiecutter.lowercase_modelname}}( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -1103,29 +1066,15 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - labels=labels, training=training, - kwargs_call=kwargs, - ) - outputs = self.{{cookiecutter.lowercase_modelname}}( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) sequence_output = outputs[0] - prediction_scores = self.mlm(sequence_output=sequence_output, training=inputs["training"]) + prediction_scores = self.mlm(sequence_output=sequence_output, training=training) loss = ( - None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=prediction_scores) + None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores) ) - if not inputs["return_dict"]: + if not return_dict: output = (prediction_scores,) + outputs[2:] return ((loss,) + output) if loss is not None else output @@ -1173,6 +1122,7 @@ def prepare_inputs_for_generation(self, inputs, past=None, attention_mask=None, "use_cache": model_kwargs["use_cache"], } + @unpack_inputs @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, checkpoint=_CHECKPOINT_FOR_DOC, @@ -1220,9 +1170,7 @@ def call( labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., config.vocab_size - 1]`. """ - inputs = input_processing( - func=self.call, - config=self.config, + outputs = self.{{cookiecutter.lowercase_modelname}}( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -1236,37 +1184,19 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - labels=labels, training=training, - kwargs_call=kwargs, - ) - outputs = self.{{cookiecutter.lowercase_modelname}}( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - encoder_hidden_states=inputs["encoder_hidden_states"], - encoder_attention_mask=inputs["encoder_attention_mask"], - past_key_values=inputs["past_key_values"], - use_cache=inputs["use_cache"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) sequence_output = outputs[0] - logits = self.mlm(sequence_output=sequence_output, training=inputs["training"]) + logits = self.mlm(sequence_output=sequence_output, training=training) loss = None - if inputs["labels"] is not None: + if labels is not None: # shift labels to the left and cut last logit token shifted_logits = logits[:, :-1] - labels = inputs["labels"][:, 1:] + labels = labels[:, 1:] loss = self.hf_compute_loss(labels=labels, logits=shifted_logits) - if not inputs["return_dict"]: + if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output @@ -1338,6 +1268,7 @@ def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, *inputs, self.{{cookiecutter.lowercase_modelname}} = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="{{cookiecutter.lowercase_modelname}}") self.classifier = TF{{cookiecutter.camelcase_modelname}}ClassificationHead(config, name="classifier") + @unpack_inputs @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -1365,9 +1296,7 @@ def call( Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - inputs = input_processing( - func=self.call, - config=self.config, + outputs = self.{{cookiecutter.lowercase_modelname}}( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -1377,26 +1306,12 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - labels=labels, training=training, - kwargs_call=kwargs, ) - outputs = self.{{cookiecutter.lowercase_modelname}}( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], - ) - logits = self.classifier(hidden_states=outputs[0], training=inputs["training"]) - loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits) - - if not inputs["return_dict"]: + logits = self.classifier(hidden_states=outputs[0], training=training) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) + + if not return_dict: output = (logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output @@ -1443,6 +1358,7 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]: """ return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)} + @unpack_inputs @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -1470,53 +1386,37 @@ def call( Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above) """ - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - labels=labels, - training=training, - kwargs_call=kwargs, - ) - if inputs["input_ids"] is not None: - num_choices = shape_list(inputs["input_ids"])[1] - seq_length = shape_list(inputs["input_ids"])[2] + if input_ids is not None: + num_choices = shape_list(input_ids)[1] + seq_length = shape_list(input_ids)[2] else: - num_choices = shape_list(inputs["inputs_embeds"])[1] - seq_length = shape_list(inputs["inputs_embeds"])[2] + num_choices = shape_list(inputs_embeds)[1] + seq_length = shape_list(inputs_embeds)[2] flat_input_ids = ( - tf.reshape(tensor=inputs["input_ids"], shape=(-1, seq_length)) if inputs["input_ids"] is not None else None + tf.reshape(tensor=input_ids, shape=(-1, seq_length)) if input_ids is not None else None ) flat_attention_mask = ( - tf.reshape(tensor=inputs["attention_mask"], shape=(-1, seq_length)) - if inputs["attention_mask"] is not None + tf.reshape(tensor=attention_mask, shape=(-1, seq_length)) + if attention_mask is not None else None ) flat_token_type_ids = ( - tf.reshape(tensor=inputs["token_type_ids"], shape=(-1, seq_length)) - if inputs["token_type_ids"] is not None + tf.reshape(tensor=token_type_ids, shape=(-1, seq_length)) + if token_type_ids is not None else None ) flat_position_ids = ( - tf.reshape(tensor=inputs["position_ids"], shape=(-1, seq_length)) - if inputs["position_ids"] is not None + tf.reshape(tensor=position_ids, shape=(-1, seq_length)) + if position_ids is not None else None ) flat_inputs_embeds = ( tf.reshape( - tensor=inputs["inputs_embeds"], shape=(-1, seq_length, shape_list(inputs["inputs_embeds"])[3]) + tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3]) ) - if inputs["inputs_embeds"] is not None + if inputs_embeds is not None else None ) outputs = self.{{cookiecutter.lowercase_modelname}}( @@ -1524,19 +1424,19 @@ def call( attention_mask=flat_attention_mask, token_type_ids=flat_token_type_ids, position_ids=flat_position_ids, - head_mask=inputs["head_mask"], + head_mask=head_mask, inputs_embeds=flat_inputs_embeds, - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, ) - logits = self.sequence_summary(inputs=outputs[0], training=inputs["training"]) + logits = self.sequence_summary(inputs=outputs[0], training=training) logits = self.classifier(inputs=logits) reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices)) - loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=reshaped_logits) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits) - if not inputs["return_dict"]: + if not return_dict: output = (reshaped_logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output @@ -1585,6 +1485,7 @@ def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, *inputs, units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" ) + @unpack_inputs @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -1611,9 +1512,8 @@ def call( labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. """ - inputs = input_processing( - func=self.call, - config=self.config, + + outputs = self.{{cookiecutter.lowercase_modelname}}( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -1623,28 +1523,14 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - labels=labels, training=training, - kwargs_call=kwargs, - ) - outputs = self.{{cookiecutter.lowercase_modelname}}( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) sequence_output = outputs[0] - sequence_output = self.dropout(inputs=sequence_output, training=inputs["training"]) + sequence_output = self.dropout(inputs=sequence_output, training=training) logits = self.classifier(inputs=sequence_output) - loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) - if not inputs["return_dict"]: + if not return_dict: output = (logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output @@ -1680,6 +1566,7 @@ def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, *inputs, units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" ) + @unpack_inputs @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -1713,9 +1600,7 @@ def call( Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence are not taken into account for computing the loss. """ - inputs = input_processing( - func=self.call, - config=self.config, + outputs = self.{{cookiecutter.lowercase_modelname}}( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -1725,22 +1610,7 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - start_positions=start_positions, - end_positions=end_positions, training=training, - kwargs_call=kwargs, - ) - outputs = self.{{cookiecutter.lowercase_modelname}}( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) sequence_output = outputs[0] logits = self.qa_outputs(inputs=sequence_output) @@ -1749,12 +1619,12 @@ def call( end_logits = tf.squeeze(input=end_logits, axis=-1) loss = None - if inputs["start_positions"] is not None and inputs["end_positions"] is not None: - labels = {"start_position": inputs["start_positions"]} - labels["end_position"] = inputs["end_positions"] + if start_positions is not None and end_positions is not None: + labels = {"start_position": start_positions} + labels["end_position"] = end_positions loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits)) - if not inputs["return_dict"]: + if not return_dict: output = (start_logits, end_logits) + outputs[2:] return ((loss,) + output) if loss is not None else output @@ -1801,8 +1671,8 @@ def serving_output(self, output: TFQuestionAnsweringModelOutput) -> TFQuestionAn TFPreTrainedModel, TFSharedEmbeddings, TFWrappedEmbeddings, - input_processing, keras_serializable, + unpack_inputs, ); from ...tf_utils import (shape_list, ) from ...utils import logging @@ -2381,6 +2251,7 @@ def get_embed_tokens(self): def set_embed_tokens(self, embed_tokens): self.embed_tokens = embed_tokens + @unpack_inputs def call( self, input_ids=None, @@ -2435,78 +2306,65 @@ def call( Whether or not to use the model in training mode (some modules like dropout modules have different behaviors between training and evaluation). """ - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - kwargs_call=kwargs, - ) - if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: + if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif inputs["input_ids"] is not None: - input_shape = shape_list(inputs["input_ids"]) - elif inputs["inputs_embeds"] is not None: - input_shape = shape_list(inputs["inputs_embeds"])[:-1] + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") - if inputs["inputs_embeds"] is None: - inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) * self.embed_scale + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale embed_pos = self.embed_positions(input_shape) - hidden_states = inputs["inputs_embeds"] + embed_pos + hidden_states = inputs_embeds + embed_pos hidden_states = self.layernorm_embedding(hidden_states) - hidden_states = self.dropout(hidden_states, training=inputs["training"]) + hidden_states = self.dropout(hidden_states, training=training) # check attention mask and invert - if inputs["attention_mask"] is not None: + if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - inputs["attention_mask"] = _expand_mask(inputs["attention_mask"]) + attention_mask = _expand_mask(attention_mask) - encoder_states = () if inputs["output_hidden_states"] else None - all_attentions = () if inputs["output_attentions"] else None + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None # check if head_mask has a correct number of layers specified if desired # The tf.debugging asserts are not compliant with XLA then they # have to be disabled in other modes than eager. - if inputs["head_mask"] is not None and tf.executing_eagerly(): + if head_mask is not None and tf.executing_eagerly(): tf.debugging.assert_equal( - shape_list(inputs["head_mask"])[0], + shape_list(head_mask)[0], len(self.layers), - message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", + message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(head_mask)[0]}.", ) # encoder layers for idx, encoder_layer in enumerate(self.layers): - if inputs["output_hidden_states"]: + if output_hidden_states: encoder_states = encoder_states + (hidden_states,) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) dropout_probability = random.uniform(0, 1) - if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer + if training and (dropout_probability < self.layerdrop): # skip the layer continue hidden_states, attn = encoder_layer( hidden_states, - inputs["attention_mask"], - inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, + attention_mask, + head_mask[idx] if head_mask is not None else None, ) - if inputs["output_attentions"]: + if output_attentions: all_attentions += (attn,) - if inputs["output_hidden_states"]: + if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if not inputs["return_dict"]: + if not return_dict: return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return TFBaseModelOutput( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions @@ -2547,6 +2405,7 @@ def get_embed_tokens(self): def set_embed_tokens(self, embed_tokens): self.embed_tokens = embed_tokens + @unpack_inputs def call( self, input_ids=None, @@ -2632,111 +2491,93 @@ def call( Whether or not to use the model in training mode (some modules like dropout modules have different behaviors between training and evaluation). """ - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - head_mask=head_mask, - cross_attn_head_mask=cross_attn_head_mask, - inputs_embeds=inputs_embeds, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - kwargs_call=kwargs, - ) - if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None: + if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif inputs["input_ids"] is not None: - input_shape = shape_list(inputs["input_ids"]) - elif inputs["inputs_embeds"] is not None: - input_shape = shape_list(inputs["inputs_embeds"])[:-1] + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") past_key_values_length = ( - shape_list(inputs["past_key_values"][0][0])[2] if inputs["past_key_values"] is not None else 0 + shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0 ) # embed positions positions = self.embed_positions(input_shape, past_key_values_length) - if inputs["inputs_embeds"] is None: - inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) - hidden_states = inputs["inputs_embeds"] + hidden_states = inputs_embeds - inputs["attention_mask"], combined_attention_mask = self.compute_combined_attns_mask( - inputs, input_shape, past_key_values_length + attention_mask, combined_attention_mask = self.compute_combined_attns_mask( + input_ids, attention_mask, input_shape, past_key_values_length ) - if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None: + if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - inputs["encoder_attention_mask"] = _expand_mask(inputs["encoder_attention_mask"], tgt_len=input_shape[-1]) + encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1]) hidden_states = self.layernorm_embedding(hidden_states + positions) - hidden_states = self.dropout(hidden_states, training=inputs["training"]) + hidden_states = self.dropout(hidden_states, training=training) # decoder layers - all_hidden_states = () if inputs["output_hidden_states"] else None - all_self_attns = () if inputs["output_attentions"] else None - all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None - present_key_values = () if inputs["use_cache"] else None + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attns = () if (output_attentions and encoder_hidden_states is not None) else None + present_key_values = () if use_cache else None # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired # The tf.debugging asserts are not compliant with XLA then they # have to be disabled in other modes than eager. - for attn_mask in ["head_mask", "cross_attn_head_mask"]: - if inputs[attn_mask] is not None and tf.executing_eagerly(): + for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]: + if attn_mask is not None and tf.executing_eagerly(): tf.debugging.assert_equal( - shape_list(inputs[attn_mask])[0], + shape_list(attn_mask)[0], len(self.layers), - message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.", + message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.", ) for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) - if inputs["output_hidden_states"]: + if output_hidden_states: all_hidden_states += (hidden_states,) dropout_probability = random.uniform(0, 1) - if inputs["training"] and (dropout_probability < self.layerdrop): + if training and (dropout_probability < self.layerdrop): continue - past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None + past_key_value = past_key_values[idx] if past_key_values is not None else None hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer( hidden_states, attention_mask=combined_attention_mask, - encoder_hidden_states=inputs["encoder_hidden_states"], - encoder_attention_mask=inputs["encoder_attention_mask"], - layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None, - cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx] - if inputs["cross_attn_head_mask"] is not None + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=head_mask[idx] if head_mask is not None else None, + cross_attn_layer_head_mask=cross_attn_head_mask[idx] + if cross_attn_head_mask is not None else None, past_key_value=past_key_value, ) - if inputs["use_cache"]: + if use_cache: present_key_values += (present_key_value,) - if inputs["output_attentions"]: + if output_attentions: all_self_attns += (layer_self_attn,) - if inputs["encoder_hidden_states"] is not None: + if encoder_hidden_states is not None: all_cross_attns += (layer_cross_attn,) - if inputs["output_hidden_states"]: + if output_hidden_states: all_hidden_states += (hidden_states,) - if not inputs["return_dict"]: + if not return_dict: return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns else: return TFBaseModelOutputWithPastAndCrossAttentions( @@ -2748,7 +2589,7 @@ def call( ) @tf.function - def compute_combined_attns_mask(self, inputs, input_shape, past_key_values_length): + def compute_combined_attns_mask(self, input_ids, attention_mask, input_shape, past_key_values_length): # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] combined_attention_mask = None if input_shape[-1] > 1: @@ -2758,9 +2599,9 @@ def compute_combined_attns_mask(self, inputs, input_shape, past_key_values_lengt tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1] ) - if inputs["attention_mask"] is None and inputs["input_ids"] is not None and input_shape[-1] > 1: + if attention_mask is None and input_ids is not None and input_shape[-1] > 1: attention_mask = tf.cast( - tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), inputs["input_ids"].dtype + tf.math.not_equal(input_ids, self.config.pad_token_id), input_ids.dtype ) attention_mask = tf.concat( [ @@ -2810,6 +2651,7 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_embed_tokens(embed_tokens) self.decoder.set_embed_tokens(embed_tokens) + @unpack_inputs def call( self, input_ids=None, @@ -2830,71 +2672,50 @@ def call( training=False, **kwargs ): - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, + + if decoder_input_ids is None and decoder_inputs_embeds is None: + use_cache = False + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput): + encoder_outputs = TFBaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False + elif not return_dict and not isinstance(encoder_outputs, tuple): + encoder_outputs = encoder_outputs.to_tuple() + + decoder_outputs = self.decoder( + decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, - encoder_outputs=encoder_outputs, past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, + inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training, - kwargs_call=kwargs, - ) - - if inputs["decoder_input_ids"] is None and inputs["decoder_inputs_embeds"] is None: - inputs["use_cache"] = False - - if inputs["encoder_outputs"] is None: - inputs["encoder_outputs"] = self.encoder( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], - ) - # If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True - elif inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], TFBaseModelOutput): - inputs["encoder_outputs"] = TFBaseModelOutput( - last_hidden_state=inputs["encoder_outputs"][0], - hidden_states=inputs["encoder_outputs"][1] if len(inputs["encoder_outputs"]) > 1 else None, - attentions=inputs["encoder_outputs"][2] if len(inputs["encoder_outputs"]) > 2 else None, - ) - # If the user passed a TFBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False - elif not inputs["return_dict"] and not isinstance(inputs["encoder_outputs"], tuple): - inputs["encoder_outputs"] = inputs["encoder_outputs"].to_tuple() + ) - decoder_outputs = self.decoder( - inputs["decoder_input_ids"], - attention_mask=inputs["decoder_attention_mask"], - encoder_hidden_states=inputs["encoder_outputs"][0], - encoder_attention_mask=inputs["attention_mask"], - head_mask=inputs["decoder_head_mask"], - cross_attn_head_mask=inputs["cross_attn_head_mask"], - past_key_values=inputs["past_key_values"], - inputs_embeds=inputs["decoder_inputs_embeds"], - use_cache=inputs["use_cache"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], - ) - - if not inputs["return_dict"]: - return decoder_outputs + inputs["encoder_outputs"] + if not return_dict: + return decoder_outputs + encoder_outputs return TFSeq2SeqModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, @@ -2902,9 +2723,9 @@ def call( decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state, - encoder_hidden_states=inputs["encoder_outputs"].hidden_states, - encoder_attentions=inputs["encoder_outputs"].attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, ) @@ -2924,6 +2745,7 @@ def get_encoder(self): def get_decoder(self): return self.model.decoder + @unpack_inputs @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( processor_class=_TOKENIZER_FOR_DOC, @@ -2951,9 +2773,8 @@ def call( training=False, **kwargs ): - inputs = input_processing( - func=self.call, - config=self.config, + + outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, @@ -2970,26 +2791,6 @@ def call( output_hidden_states=output_hidden_states, return_dict=return_dict, training=training, - kwargs_call=kwargs, - ) - - outputs = self.model( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - decoder_input_ids=inputs["decoder_input_ids"], - decoder_attention_mask=inputs["decoder_attention_mask"], - head_mask=inputs["head_mask"], - decoder_head_mask=inputs["decoder_head_mask"], - cross_attn_head_mask=inputs["cross_attn_head_mask"], - encoder_outputs=inputs["encoder_outputs"], - past_key_values=inputs["past_key_values"], - inputs_embeds=inputs["inputs_embeds"], - decoder_inputs_embeds=inputs["decoder_inputs_embeds"], - use_cache=inputs["use_cache"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) return outputs @@ -3053,6 +2854,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, value): self.set_input_embeddings(value) + @unpack_inputs @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) def call( @@ -3093,17 +2895,23 @@ def call( >>> probs = tf.nn.softmax(logits[0]) >>> # probs[5] is associated with the mask token ```""" - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, + + if labels is not None: + use_cache = False + if decoder_input_ids is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.model( + input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, decoder_attention_mask=decoder_attention_mask, head_mask=head_mask, decoder_head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, - encoder_outputs=encoder_outputs, past_key_values=past_key_values, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, @@ -3111,41 +2919,13 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - labels=labels, - training=training, - kwargs_call=kwargs, - ) - - if inputs["labels"] is not None: - inputs["use_cache"] = False - if inputs["decoder_input_ids"] is None: - inputs["decoder_input_ids"] = shift_tokens_right( - inputs["labels"], self.config.pad_token_id, self.config.decoder_start_token_id - ) - - outputs = self.model( - inputs["input_ids"], - attention_mask=inputs["attention_mask"], - decoder_input_ids=inputs["decoder_input_ids"], - encoder_outputs=inputs["encoder_outputs"], - decoder_attention_mask=inputs["decoder_attention_mask"], - head_mask=inputs["head_mask"], - decoder_head_mask=inputs["decoder_head_mask"], - cross_attn_head_mask=inputs["cross_attn_head_mask"], - past_key_values=inputs["past_key_values"], - inputs_embeds=inputs["inputs_embeds"], - decoder_inputs_embeds=inputs["decoder_inputs_embeds"], - use_cache=inputs["use_cache"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"] + training=training ) lm_logits = self.model.shared(outputs[0], mode="linear") lm_logits = lm_logits + self.final_logits_bias - masked_lm_loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], lm_logits) + masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits) - if not inputs["return_dict"]: + if not return_dict: output = (lm_logits,) + outputs[1:] return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output return TFSeq2SeqLMOutput(