Add TFVisionEncoderDecoderModel #14148
First changed file:
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import inspect
from dataclasses import dataclass
from typing import Optional, Tuple, Union

@@ -628,14 +629,18 @@ def generate(
            bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
        ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"

+        # This block corresponds to the following line in `generation_utils`:
+        #   "input_ids = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))"
+        # with the following differences:
+        #   1. In PT, `generate()`'s `model_kwargs` can accept `encoder_outputs`, but this is not the case in TF.
+        #   2. There is no shape checking in PT.
+        # In both PT/TF, if `input_ids` is `None`, we try to create it as we would for a text model.
        if input_ids is None:
            assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
                "you should either supply a context to complete as `input_ids` input "
                "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
            )
            input_ids = tf.fill((batch_size, 1), bos_token_id)
        else:
            assert len(shape_list(input_ids)) == 2, "Input prompt should be of shape (batch_size, sequence length)."

        # not allow to duplicate outputs when greedy decoding
        if do_sample is False:
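
As a side note, here is a minimal sketch of the two branches in the hunk above (`input_ids is None` vs. a supplied prompt), using made-up values (a toy `batch_size` and token ids, not wired into `generate()`):

import tensorflow as tf

batch_size, bos_token_id = 2, 101  # illustrative values only

# Branch 1: no prompt supplied -> start every sequence with the BOS token
input_ids = tf.fill((batch_size, 1), bos_token_id)
print(input_ids.shape)  # (2, 1)

# Branch 2: a prompt is supplied -> it must already be (batch_size, sequence_length)
prompt = tf.constant([[101, 7592], [101, 2088]])
assert len(prompt.shape) == 2, "Input prompt should be of shape (batch_size, sequence length)."
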
@@ -691,21 +696,29 @@ def generate( | |||
# get encoder and store encoder outputs | ||||
encoder = self.get_encoder() | ||||
|
||||
encoder_outputs = encoder( | ||||
input_ids, | ||||
attention_mask=attention_mask, | ||||
output_attentions=output_attentions, | ||||
output_hidden_states=output_hidden_states, | ||||
return_dict=return_dict_in_generate, | ||||
) | ||||
encoder_kwargs = { | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fine with me! |
||||
"attention_mask": attention_mask, | ||||
"output_attentions": output_attentions, | ||||
"output_hidden_states": output_hidden_states, | ||||
"return_dict": return_dict_in_generate, | ||||
} | ||||
|
||||
# vision models don't use `attention_mask`. | ||||
ydshieh marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
signature = dict(inspect.signature(encoder.call).parameters) | ||||
if "attention_mask" not in signature: | ||||
encoder_kwargs.pop("attention_mask") | ||||
|
||||
encoder_outputs = encoder(input_ids, **encoder_kwargs) | ||||
if return_dict_in_generate: | ||||
if output_attentions: | ||||
model_kwargs["encoder_attentions"] = encoder_outputs.attentions | ||||
if output_hidden_states: | ||||
model_kwargs["encoder_hidden_states"] = encoder_outputs.hidden_states | ||||
|
||||
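
As an illustration of the trick above, here is a minimal standalone sketch with toy encoder classes and a hypothetical `filter_encoder_kwargs` helper (none of these names are part of transformers): the kwargs are built once, then any entry the encoder's `call` signature does not declare is dropped.

import inspect

class TextEncoder:
    def call(self, input_ids, attention_mask=None, output_attentions=False):
        return input_ids

class VisionEncoder:
    def call(self, pixel_values, output_attentions=False):  # no attention_mask
        return pixel_values

def filter_encoder_kwargs(encoder, **encoder_kwargs):
    # keep only the kwargs that the encoder's call() actually declares
    signature = dict(inspect.signature(encoder.call).parameters)
    return {k: v for k, v in encoder_kwargs.items() if k in signature}

print(filter_encoder_kwargs(TextEncoder(), attention_mask=None, output_attentions=True))
# -> {'attention_mask': None, 'output_attentions': True}
print(filter_encoder_kwargs(VisionEncoder(), attention_mask=None, output_attentions=True))
# -> {'output_attentions': True}

Inspecting the signature rather than hard-coding a model-type check keeps `generate()` agnostic to whichever encoder the encoder-decoder model wraps.
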
+        # The condition `len(shape_list(input_ids)) == 2` is to make this block treat only text inputs.
+        # (vision inputs might occur when the model is an encoder-decoder model)
        # Expand input ids if num_beams > 1 or num_return_sequences > 1
-        if num_return_sequences > 1 or num_beams > 1:
+        if len(shape_list(input_ids)) == 2 and (num_return_sequences > 1 or num_beams > 1):

Review thread:
Reviewer: that's a bit hacky - vision inputs should also work with `num_beams > 1`.
Reply: Just as a remark: the code inside this block treats `input_ids` as text input. For vision inputs, [...]. It's not clear to me if a standalone vision model will need to call `generate()`.

            input_ids_len = shape_list(input_ids)[-1]
            input_ids = tf.broadcast_to(
                tf.expand_dims(input_ids, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
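
To make the broadcast step concrete, here is a small self-contained sketch with made-up shapes (the final reshape is added here only to show the usual flattening of the beam dimension; it is not among the lines above):

import tensorflow as tf

batch_size, num_beams, effective_batch_mult = 2, 3, 1  # toy values
input_ids = tf.constant([[5, 6], [7, 8]])  # (batch_size, seq_len)
input_ids_len = input_ids.shape[-1]

# repeat each prompt row once per beam / return sequence
expanded = tf.broadcast_to(
    tf.expand_dims(input_ids, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
)
expanded = tf.reshape(expanded, (-1, input_ids_len))
print(expanded.numpy().tolist())
# [[5, 6], [5, 6], [5, 6], [7, 8], [7, 8], [7, 8]]
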
Second changed file:
@@ -148,10 +148,10 @@
@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
class TFEncoderDecoderModel(TFPreTrainedModel):
    r"""
-    [`TFEncoderDecoder`] is a generic model class that will be instantiated as a transformer architecture with one of
-    the base model classes of the library as encoder and another one as decoder when created with the
-    :meth*~transformers.TFAutoModel.from_pretrained* class method for the encoder and
-    :meth*~transformers.TFAutoModelForCausalLM.from_pretrained* class method for the decoder.
+    [`TFEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with one
+    of the base model classes of the library as encoder and another one as decoder when created with the
+    [`~TFAutoModel.from_pretrained`] class method for the encoder and [`~TFAutoModelForCausalLM.from_pretrained`] class
+    method for the decoder.
    """
    config_class = EncoderDecoderConfig
    base_model_prefix = "encoder_decoder"

@@ -233,13 +233,6 @@ def dummy_inputs(self):
        # Add `decoder_input_ids` because `self.decoder` requires it.
        input_ids = tf.constant(DUMMY_INPUTS)
        dummy = {"input_ids": input_ids, "decoder_input_ids": input_ids}
-        # Add `encoder_hidden_states` to make the cross-attention layers' weights initialized
-        if self.config.add_cross_attention:
-            batch_size, seq_len = input_ids.shape
-            shape = (batch_size, seq_len) + (self.config.hidden_size,)
-            h = tf.random.uniform(shape=shape)
-            dummy["encoder_hidden_states"] = h
-
        return dummy

    def get_encoder(self):

Review thread on lines -236 to -242:
Reviewer: Why is this part removed?
Reply: [...] (I don't remember why I did that before; probably it was required in some intermediate commits.)
Reviewer: Looks good to me!
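
For context on the thread above, here is a toy sketch (an invented layer, not the transformers decoder) of why the dummy `encoder_hidden_states` could matter in the first place: Keras creates a sublayer's variables lazily on its first call, so a sublayer reached only through a branch that never runs gets no weights.

import tensorflow as tf

class ToyDecoderBlock(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.self_attn = tf.keras.layers.Dense(8)
        self.cross_attn = tf.keras.layers.Dense(8)

    def call(self, hidden_states, encoder_hidden_states=None):
        out = self.self_attn(hidden_states)
        if encoder_hidden_states is not None:
            out = out + self.cross_attn(encoder_hidden_states)  # built only if this branch runs
        return out

without_cross = ToyDecoderBlock()
without_cross(tf.ones((1, 4, 8)))
print(len(without_cross.weights))  # 2 -> cross-attention weights were never created

with_cross = ToyDecoderBlock()
with_cross(tf.ones((1, 4, 8)), encoder_hidden_states=tf.ones((1, 4, 8)))
print(len(with_cross.weights))     # 4 -> both sublayers built

If the cross-attention sublayers get exercised through the combined encoder-decoder forward pass anyway, the explicit dummy becomes redundant, which is presumably why removing it was fine here.
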
Review comment: This whole function is in dire need of a refactor - I'll try to tackle this with @Rocketknight1 this month. Good for me the way it is now though.