From 3581100b94eda01da39cad0b4119fff53df00f5b Mon Sep 17 00:00:00 2001 From: ydshieh Date: Sat, 1 Jan 2022 17:01:04 +0100 Subject: [PATCH] Fix #14357 --- .../models/encoder_decoder/modeling_tf_encoder_decoder.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py index 8c725b05cc9..8b22e25ac52 100644 --- a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py @@ -547,6 +547,12 @@ def call( encoder_inputs = input_processing(**encoder_processing_inputs) + # Handle the case where the inputs are passed as a single dict which contains `labels`. + # The `labels` shouldn't be passed to `self.encoder` below, because it is a based model without this + # parameter (otherwise, an error occurs when `input_processing` is called inside `self.encoder.call()`). + if "labels" in encoder_inputs: + labels = encoder_inputs.pop("labels") + # handle the init case where `dummy_inputs` returns a dict containing `decoder_input_ids`. if "decoder_input_ids" in encoder_inputs: decoder_input_ids = encoder_inputs.pop("decoder_input_ids")