From f03866fd5c714846a1792d38fdb6b9e42ce7985b Mon Sep 17 00:00:00 2001 From: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Date: Thu, 11 Aug 2022 17:32:11 +0100 Subject: [PATCH] Return the permuted hidden states if return_dict=True (#18578) --- src/transformers/models/convnext/modeling_tf_convnext.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/convnext/modeling_tf_convnext.py b/src/transformers/models/convnext/modeling_tf_convnext.py index 405aeff6e0bdd5..0be2d291923812 100644 --- a/src/transformers/models/convnext/modeling_tf_convnext.py +++ b/src/transformers/models/convnext/modeling_tf_convnext.py @@ -330,7 +330,8 @@ def call( hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]]) if not return_dict: - return (last_hidden_state, pooled_output) + encoder_outputs[1:] + hidden_states = hidden_states if output_hidden_states else () + return (last_hidden_state, pooled_output) + hidden_states return TFBaseModelOutputWithPooling( last_hidden_state=last_hidden_state,