Add TFVisionEncoderDecoderModel #14148
Conversation
Thanks for adding this model!
```python
# Add `encoder_hidden_states` to make the cross-attention layers' weights initialized
if self.config.add_cross_attention:
    batch_size, seq_len = input_ids.shape
    shape = (batch_size, seq_len) + (self.config.hidden_size,)
    h = tf.random.uniform(shape=shape)
    dummy["encoder_hidden_states"] = h
```
Why is this part removed?
`TFEncoderDecoderModel.call()` doesn't have an `encoder_hidden_states` parameter, but `encoder_outputs`. Moreover, `encoder_hidden_states` is always passed to the decoder with `encoder_hidden_states = encoder_outputs[0]`. Therefore, there is no need to add `encoder_hidden_states` in `dummy_inputs`.
(I don't remember why I did that before; it was probably required in some intermediate commits.)
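For readers following along, a simplified sketch of the call flow being described (a paraphrased fragment, not the exact source):

```python
# Inside `TFEncoderDecoderModel.call()` (simplified): the encoder runs
# first when no precomputed outputs are passed in.
if encoder_outputs is None:
    encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

# The decoder always receives the encoder's last hidden state this way,
# which is why `dummy_inputs` never needs an explicit
# `encoder_hidden_states` entry.
encoder_hidden_states = encoder_outputs[0]

decoder_outputs = self.decoder(
    input_ids=decoder_input_ids,
    encoder_hidden_states=encoder_hidden_states,
)
```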
Looks good to me!
@sgugger, thank you for your review; I have addressed the comments. I am impressed by your ability to spot these details. I feel that sometimes I dive a bit deeper, and found: this one won't be reformatted by the style script, but the following one will work well. The difference is the trailing comma after the last argument. Is this a bug? I can open an issue if so.
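The original snippets aren't shown in this thread, but the described behavior is consistent with black's "magic trailing comma"; a hypothetical illustration (with trivial stand-ins so the snippet is self-contained):

```python
def model(input_ids, attention_mask):
    # Trivial stand-in so the example runs on its own.
    return input_ids, attention_mask

input_ids, attention_mask = [1, 2, 3], [1, 1, 1]

# Without a trailing comma, black may collapse this call onto one line
# if it fits within the line-length limit:
outputs = model(
    input_ids,
    attention_mask
)

# With a trailing comma after the last argument, black keeps the call
# exploded, one argument per line:
outputs = model(
    input_ids,
    attention_mask,
)
```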
```python
    output_hidden_states=output_hidden_states,
    return_dict=return_dict_in_generate,
)
encoder_kwargs = {
```
fine with me!
```diff
 # Expand input ids if num_beams > 1 or num_return_sequences > 1
-if num_return_sequences > 1 or num_beams > 1:
+if len(shape_list(input_ids)) == 2 and (num_return_sequences > 1 or num_beams > 1):
```
That's a bit hacky - vision inputs should also work with `num_beams > 1`, no? But OK for now, until we do the big `generate` refactor. @Rocketknight1
Just as a remark: the code inside this block treats `input_ids` as text-only inputs; for example, `shape_list(input_ids)[-1]` assumes that the last dimension is the sequence dimension.

For vision inputs, `generate` will be called only if it is a vision model used as the encoder in an encoder-decoder model (?). In that case, it is rather the `decoder_input_ids` that need to be processed, and I think this is done in the next block:

```python
if self.config.is_encoder_decoder:
```

It's not clear to me whether a standalone vision model will need to call `generate()`. Maybe @NielsRogge can share some insights here (`ImageGPT`?).
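To make the shape assumption concrete, a minimal sketch (the `shape_list` stand-in mimics the transformers helper; the shapes are illustrative):

```python
import tensorflow as tf

def shape_list(t):
    # Minimal stand-in for the `shape_list` helper in transformers:
    # static dimensions where known, dynamic ones otherwise.
    static = t.shape.as_list()
    dynamic = tf.shape(t)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]

# Text input: rank 2, last dimension is the sequence dimension,
# so `len(shape_list(input_ids)) == 2` lets the expansion block run.
input_ids = tf.ones((2, 7), dtype=tf.int32)
print(len(shape_list(input_ids)))     # 2

# Vision input: rank 4 (batch, channels, height, width); the last
# dimension is not a sequence dimension, so the rank guard skips the
# text-style expansion and only `decoder_input_ids` are expanded later.
pixel_values = tf.random.uniform((2, 3, 224, 224))
print(len(shape_list(pixel_values)))  # 4
```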
```diff
@@ -153,7 +153,7 @@
 @add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
 class TFEncoderDecoderModel(TFPreTrainedModel):
     r"""
-    :class:`~transformers.TFEncoderDecoder` is a generic model class that will be instantiated as a transformer
+    :class:`~transformers.TFEncoderDecoderModel` is a generic model class that will be instantiated as a transformer
```
Thanks!
This is a great addition @ydshieh! Thanks a lot for the contribution.
From my side it would be great if we could:
- remove all pytorch specific changes that are not needed to get the TF version working (I'll tackle this in a future PR :-) )
- add one slow test that ensures that the model works correctly
Thanks a bunch!
# Conflicts:
#	docs/source/model_doc/auto.rst
#	src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
Hi @patrickvonplaten, @Rocketknight1, @NielsRogge: I removed the changes to the PT code. I also added a slow test here. This PR is ready for review when you have the time :-)
```python
    return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

@classmethod
def from_encoder_decoder_pretrained(
```
We should probably add this part:

```python
# src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py, line 402 (at 8f6373c)
if kwargs_encoder.get("from_pt", None):
```

here as well, no? Otherwise it'll be difficult to load TF from PT.
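For context, a hedged sketch of what that branch does in `from_encoder_decoder_pretrained` (a paraphrased fragment from memory; the temp-directory round trip and exact variable names are assumptions, not a verbatim copy):

```python
import tempfile

from transformers import TFAutoModel

# Hypothetical paraphrase: load the PyTorch checkpoint into a TF model,
# save it as a TF checkpoint, then reload it so that the weight names pick
# up the composite model's `load_weight_prefix`.
if kwargs_encoder.get("from_pt", None):
    del kwargs_encoder["from_pt"]
    with tempfile.TemporaryDirectory() as tmp_dir:
        encoder = TFAutoModel.from_pretrained(
            encoder_pretrained_model_name_or_path, from_pt=True, **kwargs_encoder
        )
        encoder.save_pretrained(tmp_dir)
        del encoder
        encoder = TFAutoModel.from_pretrained(tmp_dir, **kwargs_encoder)
```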
Yes, great catch! For some reason I missed it in this PR. I will check this PR against the `TFEncoderDecoderModel` code again.
@patrickvonplaten I corrected this part, and also updated the corresponding test (which I had also forgotten in the previous commits).
Thanks a lot!
```diff
@@ -628,14 +629,18 @@ def generate(
     bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
 ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"

+# This block corresponds to the following line in `generation_utils`:
```
This whole function is in dire need of a refactor - I'll try to tackle this with @Rocketknight1 this month. Good for me the way it is now though
This looks more or less ready to be merged to me:
- I think we should add the `from_pt` loading hack to `from_encoder_decoder(...)` here as well
- I think we can delete the file docs/source/model_doc/auto.rst, no?
Yes. I added the (empty) file to the commit by mistake during a git rebase/merge.
```python
max_diff = np.max(np.abs(logits_tf_2.numpy() - logits_tf.numpy()))
self.assertAlmostEqual(max_diff, 0.0, places=3)
```
There was a `# TensorFlow => PyTorch` block which did nothing (and if it did, it would fail, since we can't use `from_pretrained` together with `from_pt` or `from_tf` in this composite model). I removed it - the corresponding part in test_modeling_tf_encoder_decoder.py needs to be removed too, in another PR.
@NielsRogge @Rocketknight1 - could you guys take a look here as well?
Looks good to me now! I'll let @sgugger take a final look here.
```python
    Provide for sequence to sequence training to the decoder. Indices can be obtained using
    [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for
    details.
```
The PyTorch model (`VisionEncoderDecoderModel`) automatically creates the `decoder_input_ids` by shifting the `labels`. Is this not the case for the TF one?
No. I started this work before the new change (about `decoder_input_ids`) was done in `VisionEncoderDecoderModel`, and didn't follow it. Would it be possible to leave it to another PR? I can make it, but I'd prefer to do it in a separate PR.
Ok makes sense!
See #14469
See #14139 (for info)
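For reference, the PT behavior under discussion builds `decoder_input_ids` by shifting `labels` one position to the right; a minimal TF sketch of that pattern (an illustrative port, not the code added in the linked PRs):

```python
import tensorflow as tf

def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int) -> tf.Tensor:
    # Prepend the decoder start token, drop the last label, and replace
    # any -100 (the loss ignore index) with the pad token id.
    start_tokens = tf.fill([tf.shape(input_ids)[0], 1], tf.cast(decoder_start_token_id, input_ids.dtype))
    shifted = tf.concat([start_tokens, input_ids[:, :-1]], axis=-1)
    return tf.where(shifted == -100, tf.cast(pad_token_id, shifted.dtype), shifted)

labels = tf.constant([[5, 6, -100]])
print(shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=2).numpy())  # [[2 5 6]]
```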
""" | ||
config_class = VisionEncoderDecoderConfig | ||
base_model_prefix = "vision_encoder_decoder" | ||
load_weight_prefix = "tf_vision_encoder_decoder_model_1" |
Where does the _1 come from?
I originally looked at `TFRagModel` as a reference for implementing the TF composite model, and saw:

```python
# tests/test_modeling_tf_rag.py, line 973 (at ac224bb)
load_weight_prefix = "tf_rag_model_1"
```

I copied it and it worked well (fixing the problems I had at that time), and I didn't think about this part in more detail.
I just spent some time checking this again - in fact, we can use `tf_vision_encoder_decoder_model` here, and `tf_encoder_decoder_model` for `TFEncoderDecoderModel`.

Changing this might break some user models - but since these are fairly new models, and not popular yet, maybe it is worth the change. cc @patrickvonplaten, @Rocketknight1, @sgugger, @LysandreJik for their thoughts on this.
(For `tf_rag_model_1`, maybe there was a particular reason it was needed to make something work, though.)
I changed it to `tf_vision_encoder_decoder_model` - better not to continue with the strange `_1`.
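After this change, the class attributes would read as follows (a sketch assembled from the diff above):

```python
class TFVisionEncoderDecoderModel(TFPreTrainedModel):
    config_class = VisionEncoderDecoderConfig
    base_model_prefix = "vision_encoder_decoder"
    # Prefix applied to sub-model weight names when loading checkpoints
    # into the composite model; the trailing `_1` copied from the RAG
    # reference is dropped.
    load_weight_prefix = "tf_vision_encoder_decoder_model"
```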
There are a few typos in the docstrings, but LGTM once they're fixed!
Thanks for all your work on this @ydshieh !
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Applied @sgugger's review suggestions. The failed tests are unrelated.
Thanks again for all your work on this!
What does this PR do?
This PR makes the Vision-Encoder-Text-Decoder family complete by adding `TFVisionEncoderDecoderModel`.

To complete this PR, it requires #13778 to be merged to master first (then rebase). (And if we want to include a real integration test using the recent image-captioning ViT + GPT2 model, we need to wait for #14038 too.)
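As a quick usage illustration once merged (a minimal sketch; the checkpoint names are placeholders for any compatible vision encoder / text decoder pair, not taken from this PR):

```python
import tensorflow as tf

from transformers import TFVisionEncoderDecoderModel

# Compose a TF vision encoder with a TF text decoder; cross-attention is
# added to the decoder configuration automatically.
model = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k", "gpt2"
)

# Dummy image batch: (batch, channels, height, width).
pixel_values = tf.random.uniform((1, 3, 224, 224))
decoder_input_ids = tf.constant([[model.config.decoder.bos_token_id]])

outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids)
print(outputs.logits.shape)  # (1, 1, vocab_size)
```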