Add TFVisionEncoderDecoderModel #14148

Merged
merged 39 commits into master from the tf_vision_encoder_decoder branch on Jan 10, 2022

Commits (39)
462857a
Start the work on TFVisionEncoderDecoderModel
ydshieh Oct 25, 2021
1876bb6
Expose TFVisionEncoderDecoderModel
ydshieh Oct 25, 2021
5171144
fix import
ydshieh Oct 25, 2021
95ba425
Add modeling_tf_vision_encoder_decoder to _ignore_modules in get_mode…
ydshieh Oct 26, 2021
1294bbd
reorder
ydshieh Nov 9, 2021
ac0bd79
Apply the fix for checkpoint loading as in #14016
ydshieh Nov 10, 2021
5a95231
remove attention_mask + fix VISION_DUMMY_INPUTS
ydshieh Nov 10, 2021
26eece4
A minimal change to make TF generate() work for vision models as enco…
ydshieh Nov 13, 2021
d982490
fix wrong condition: shape_list(input_ids) == 2
ydshieh Nov 13, 2021
28a00e5
add tests
ydshieh Nov 13, 2021
bc45ee9
use personal TFViTModel checkpoint (for now)
ydshieh Nov 14, 2021
13d4cfd
Add equivalence tests + projection layer
ydshieh Nov 14, 2021
4884258
style
ydshieh Nov 14, 2021
969dd37
make sure projection layer can run
ydshieh Nov 14, 2021
b4ebc91
Add examples
ydshieh Nov 14, 2021
6b3f00c
Apply suggestions from code review
ydshieh Nov 22, 2021
d1c2b4e
Clean comments (need to work on TODOs for PyTorch models)
ydshieh Nov 22, 2021
01a1583
Remove TF -> PT in check_pt_tf_equivalence for TFVisionEncoderDecoder…
ydshieh Dec 1, 2021
3252f5b
fixes
ydshieh Dec 2, 2021
00c514e
Revert changes in PT code.
ydshieh Dec 11, 2021
7f8de14
Update tests/test_modeling_tf_vision_encoder_decoder.py
ydshieh Dec 11, 2021
1406fe2
Add test_inference_coco_en for TF test
ydshieh Dec 11, 2021
f8a4f75
fix quality
ydshieh Dec 11, 2021
7c40e48
fix name
ydshieh Dec 12, 2021
557184b
build doc
ydshieh Dec 22, 2021
d123794
add main_input_name
ydshieh Dec 22, 2021
e4b92f7
Fix ckpt name in test
ydshieh Dec 22, 2021
cc77071
fix diff between master and this PR
ydshieh Dec 26, 2021
ab43aba
fix doc
ydshieh Dec 26, 2021
1fb215e
Merge branch 'master' into tf_vision_encoder_decoder
ydshieh Dec 31, 2021
e5d235b
fix style and quality
ydshieh Dec 31, 2021
b325586
fix missing doc
ydshieh Dec 31, 2021
79f34e5
fix labels handling
ydshieh Jan 1, 2022
e0d9c9e
Merge branch 'master' into tf_vision_encoder_decoder
ydshieh Jan 3, 2022
d852154
Delete auto.rst
ydshieh Jan 3, 2022
68b32f2
Add the changes done in #14016
ydshieh Jan 6, 2022
0896ccc
fix prefix
ydshieh Jan 9, 2022
c725f8e
Apply suggestions from code review
ydshieh Jan 10, 2022
53c4c6a
make style
ydshieh Jan 10, 2022

Files changed

2 changes: 1 addition & 1 deletion docs/source/index.mdx
@@ -261,7 +261,7 @@ Flax), PyTorch, and/or TensorFlow.
| TrOCR | ❌ | ❌ | ✅ | ❌ | ❌ |
| UniSpeech | ❌ | ❌ | ✅ | ❌ | ❌ |
| UniSpeechSat | ❌ | ❌ | ✅ | ❌ | ❌ |
| Vision Encoder decoder | ❌ | ❌ | ✅ | ❌ | ✅ |
| Vision Encoder decoder | ❌ | ❌ | ✅ | ✅ | ✅ |
| VisionTextDualEncoder | ❌ | ❌ | ✅ | ❌ | ✅ |
| VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ |
| ViT | ❌ | ❌ | ✅ | ✅ | ✅ |
4 changes: 4 additions & 0 deletions docs/source/model_doc/auto.mdx
@@ -194,6 +194,10 @@ Likewise, if your `NewModel` is a subclass of [`PreTrainedModel`], make sure its

[[autodoc]] TFAutoModelForQuestionAnswering

## TFAutoModelForVision2Seq

[[autodoc]] TFAutoModelForVision2Seq

## FlaxAutoModel

[[autodoc]] FlaxAutoModel
6 changes: 6 additions & 0 deletions docs/source/model_doc/vision-encoder-decoder.mdx
@@ -33,6 +33,12 @@ An example of how to use a [`VisionEncoderDecoderModel`] for inference can be se
- forward
- from_encoder_decoder_pretrained

## TFVisionEncoderDecoderModel

[[autodoc]] TFVisionEncoderDecoderModel
- call
- from_encoder_decoder_pretrained

## FlaxVisionEncoderDecoderModel

[[autodoc]] FlaxVisionEncoderDecoderModel
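As a usage note (not part of the diff above), the new [`TFVisionEncoderDecoderModel`] can be assembled from any TF vision encoder and TF causal-LM decoder via `from_encoder_decoder_pretrained`. A minimal sketch, assuming ViT/GPT-2 checkpoints purely as illustration; the image-dependent calls are left commented out:

```python
from transformers import GPT2Tokenizer, TFVisionEncoderDecoderModel, ViTFeatureExtractor

# Illustrative checkpoints: any TF vision encoder + TF causal-LM decoder pair should work.
model = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k", "gpt2"
)
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Generation needs a decoder start token and a pad token configured.
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.eos_token_id

# With a real PIL image `image`:
# pixel_values = feature_extractor(images=image, return_tensors="tf").pixel_values
# generated_ids = model.generate(pixel_values, max_length=16)
# print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))
```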
6 changes: 6 additions & 0 deletions src/transformers/__init__.py
@@ -1487,6 +1487,7 @@
"TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
"TF_MODEL_MAPPING",
"TF_MODEL_WITH_LM_HEAD_MAPPING",
"TFAutoModel",
@@ -1500,6 +1501,7 @@
"TFAutoModelForSequenceClassification",
"TFAutoModelForTableQuestionAnswering",
"TFAutoModelForTokenClassification",
"TFAutoModelForVision2Seq",
"TFAutoModelWithLMHead",
]
)
@@ -1838,6 +1840,7 @@
"TFTransfoXLPreTrainedModel",
]
)
_import_structure["models.vision_encoder_decoder"].extend(["TFVisionEncoderDecoderModel"])
_import_structure["models.vit"].extend(
[
"TFViTForImageClassification",
@@ -3343,6 +3346,7 @@
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
TF_MODEL_MAPPING,
TF_MODEL_WITH_LM_HEAD_MAPPING,
TFAutoModel,
@@ -3356,6 +3360,7 @@
TFAutoModelForSequenceClassification,
TFAutoModelForTableQuestionAnswering,
TFAutoModelForTokenClassification,
TFAutoModelForVision2Seq,
TFAutoModelWithLMHead,
)
from .models.bart import TFBartForConditionalGeneration, TFBartModel, TFBartPretrainedModel
@@ -3625,6 +3630,7 @@
TFTransfoXLModel,
TFTransfoXLPreTrainedModel,
)
from .models.vision_encoder_decoder import TFVisionEncoderDecoderModel
from .models.vit import TFViTForImageClassification, TFViTModel, TFViTPreTrainedModel
from .models.wav2vec2 import (
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
33 changes: 23 additions & 10 deletions src/transformers/generation_tf_utils.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from dataclasses import dataclass
from typing import Optional, Tuple, Union

@@ -628,14 +629,18 @@ def generate(
bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"

# This block corresponds to the following line in `generation_utils`:
patrickvonplaten (Contributor) commented on Jan 3, 2022: This whole function is in dire need of a refactor - I'll try to tackle this with @Rocketknight1 this month. Good for me the way it is now though.

# "input_ids = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))"
# with the following differences:
# 1. In PT, `generate()`'s `model_kwargs` can accept `encoder_outputs`, but not the case in TF.
# 2. There is no shape checking in PT.
# In both PT/TF, if `input_ids` is `None`, we try to create it as it is for a text model.
if input_ids is None:
assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
"you should either supply a context to complete as `input_ids` input "
"or a `bos_token_id` (integer >= 0) as a first token to start the generation."
)
input_ids = tf.fill((batch_size, 1), bos_token_id)
else:
assert len(shape_list(input_ids)) == 2, "Input prompt should be of shape (batch_size, sequence length)."

# not allow to duplicate outputs when greedy decoding
if do_sample is False:
@@ -691,21 +696,29 @@ def generate(
# get encoder and store encoder outputs
encoder = self.get_encoder()

encoder_outputs = encoder(
input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict_in_generate,
)
encoder_kwargs = {
A contributor commented: fine with me!

"attention_mask": attention_mask,
"output_attentions": output_attentions,
"output_hidden_states": output_hidden_states,
"return_dict": return_dict_in_generate,
}

# vision models don't use `attention_mask`.
signature = dict(inspect.signature(encoder.call).parameters)
if "attention_mask" not in signature:
encoder_kwargs.pop("attention_mask")

encoder_outputs = encoder(input_ids, **encoder_kwargs)
if return_dict_in_generate:
if output_attentions:
model_kwargs["encoder_attentions"] = encoder_outputs.attentions
if output_hidden_states:
model_kwargs["encoder_hidden_states"] = encoder_outputs.hidden_states

# The condition `len(shape_list(input_ids)) == 2` is to make this block treat only text inputs.
# (vision inputs might occur when the model is an encoder-decoder model)
# Expand input ids if num_beams > 1 or num_return_sequences > 1
if num_return_sequences > 1 or num_beams > 1:
if len(shape_list(input_ids)) == 2 and (num_return_sequences > 1 or num_beams > 1):
A contributor commented: that's a bit hacky - vision inputs should also work with num_beams > 1, no? But ok for now until we do the big generate refactor @Rocketknight1

ydshieh (Collaborator, Author) replied on Dec 11, 2021: Just as a remark: the code inside this block treats input_ids as text-only inputs; for example, shape_list(input_ids)[-1] assumes that the last dimension is the sequence dimension.

For vision inputs, generate will be called only when a vision model is the encoder in an encoder-decoder model (?). In that case, it is rather the decoder_input_ids that get processed, and I think this is done in the next block:

if self.config.is_encoder_decoder:

It's not clear to me if a standalone vision model will need to call generate(). Maybe @NielsRogge can share some insights here (ImageGPT?).

input_ids_len = shape_list(input_ids)[-1]
input_ids = tf.broadcast_to(
tf.expand_dims(input_ids, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
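For readers tracing the two `generate()` changes above (forwarding `attention_mask` only when the encoder accepts it, and expanding only 2-D text `input_ids`), here is a self-contained sketch of the signature-inspection part. It is not code from the PR; the toy layers are stand-ins for real encoders:

```python
import inspect

import tensorflow as tf


class ToyTextEncoder(tf.keras.layers.Layer):
    def call(self, input_ids, attention_mask=None):
        # Text encoders accept an attention mask.
        return tf.cast(input_ids, tf.float32)


class ToyVisionEncoder(tf.keras.layers.Layer):
    def call(self, pixel_values):
        # Vision encoders have no `attention_mask` parameter at all.
        return tf.reduce_mean(pixel_values, axis=-1)


def run_encoder(encoder, inputs, attention_mask=None):
    encoder_kwargs = {"attention_mask": attention_mask}
    # Drop `attention_mask` if the encoder's `call` signature does not accept it,
    # mirroring the check added to `generate()` above.
    signature = dict(inspect.signature(encoder.call).parameters)
    if "attention_mask" not in signature:
        encoder_kwargs.pop("attention_mask")
    return encoder(inputs, **encoder_kwargs)


text_out = run_encoder(
    ToyTextEncoder(), tf.constant([[1, 2, 3]]), attention_mask=tf.constant([[1, 1, 1]])
)
vision_out = run_encoder(ToyVisionEncoder(), tf.random.uniform((1, 224, 224, 3)))
```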
4 changes: 4 additions & 0 deletions src/transformers/models/auto/__init__.py
@@ -87,6 +87,7 @@
"TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
"TF_MODEL_MAPPING",
"TF_MODEL_WITH_LM_HEAD_MAPPING",
"TFAutoModel",
@@ -100,6 +101,7 @@
"TFAutoModelForSequenceClassification",
"TFAutoModelForTableQuestionAnswering",
"TFAutoModelForTokenClassification",
"TFAutoModelForVision2Seq",
"TFAutoModelWithLMHead",
]

@@ -197,6 +199,7 @@
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
TF_MODEL_MAPPING,
TF_MODEL_WITH_LM_HEAD_MAPPING,
TFAutoModel,
@@ -210,6 +213,7 @@
TFAutoModelForSequenceClassification,
TFAutoModelForTableQuestionAnswering,
TFAutoModelForTokenClassification,
TFAutoModelForVision2Seq,
TFAutoModelWithLMHead,
)

15 changes: 14 additions & 1 deletion src/transformers/models/auto/modeling_tf_auto.py
@@ -156,6 +156,12 @@
]
)

TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
[
("vision-encoder-decoder", "TFVisionEncoderDecoderModel"),
]
)

TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
@@ -182,7 +188,6 @@
]
)


TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
@@ -327,6 +332,7 @@
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
@@ -387,6 +393,13 @@ class TFAutoModelForImageClassification(_BaseAutoModelClass):
AutoModelForImageClassification = auto_class_update(TFAutoModelForImageClassification, head_doc="image classification")


class TFAutoModelForVision2Seq(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING


TFAutoModelForVision2Seq = auto_class_update(TFAutoModelForVision2Seq, head_doc="vision-to-text modeling")


class TFAutoModelForMaskedLM(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING

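To illustrate the new auto class (this example is not in the diff): `TFAutoModelForVision2Seq` resolves a `vision-encoder-decoder` config to `TFVisionEncoderDecoderModel`, so a fine-tuned captioning checkpoint can be loaded directly. The checkpoint name below is only an assumed example with TF weights:

```python
from transformers import AutoFeatureExtractor, AutoTokenizer, TFAutoModelForVision2Seq

# Assumed example checkpoint (ViT encoder + GPT-2 decoder image captioner).
checkpoint = "ydshieh/vit-gpt2-coco-en"

feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = TFAutoModelForVision2Seq.from_pretrained(checkpoint)

# With a real PIL image `image`:
# pixel_values = feature_extractor(images=image, return_tensors="tf").pixel_values
# generated_ids = model.generate(pixel_values)
# captions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
```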
src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
@@ -148,10 +148,10 @@
@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
class TFEncoderDecoderModel(TFPreTrainedModel):
r"""
[`TFEncoderDecoder`] is a generic model class that will be instantiated as a transformer architecture with one of
the base model classes of the library as encoder and another one as decoder when created with the
:meth*~transformers.TFAutoModel.from_pretrained* class method for the encoder and
:meth*~transformers.TFAutoModelForCausalLM.from_pretrained* class method for the decoder.
[`TFEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with one
of the base model classes of the library as encoder and another one as decoder when created with the
[`~TFAutoModel.from_pretrained`] class method for the encoder and [`~TFAutoModelForCausalLM.from_pretrained`] class
method for the decoder.
"""
config_class = EncoderDecoderConfig
base_model_prefix = "encoder_decoder"
@@ -233,13 +233,6 @@ def dummy_inputs(self):
# Add `decoder_input_ids` because `self.decoder` requires it.
input_ids = tf.constant(DUMMY_INPUTS)
dummy = {"input_ids": input_ids, "decoder_input_ids": input_ids}
# Add `encoder_hidden_states` to make the cross-attention layers' weights initialized
if self.config.add_cross_attention:
batch_size, seq_len = input_ids.shape
shape = (batch_size, seq_len) + (self.config.hidden_size,)
h = tf.random.uniform(shape=shape)
dummy["encoder_hidden_states"] = h

Comment on lines -236 to -242

A collaborator asked: Why is this part removed?

ydshieh (Collaborator, Author) replied: TFEncoderDecoderModel.call() doesn't have an encoder_hidden_states parameter, but encoder_outputs. Moreover, encoder_hidden_states is always passed to the decoder as encoder_hidden_states = encoder_outputs[0]. Therefore, there is no need to add encoder_hidden_states to dummy_inputs. (I don't remember why I did that before; it was probably required in some intermediate commits.)

A contributor replied: Looks good to me!

return dummy

def get_encoder(self):
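To make the reasoning in the thread above concrete, here is a toy sketch of the data flow (not the actual implementation): the decoder's `encoder_hidden_states` always comes from `encoder_outputs[0]`, so the dummy inputs only need `input_ids` and `decoder_input_ids`:

```python
import tensorflow as tf


def toy_encoder(input_ids):
    # Stand-in encoder: returns a tuple whose first element is the hidden states,
    # shaped (batch, sequence_length, hidden_size).
    batch, seq_len = input_ids.shape
    return (tf.random.uniform((batch, seq_len, 4)),)


def toy_decoder(decoder_input_ids, encoder_hidden_states):
    # Stand-in decoder: a real decoder's cross-attention consumes `encoder_hidden_states`.
    return tf.reduce_sum(encoder_hidden_states, axis=1)


def encoder_decoder_call(input_ids=None, decoder_input_ids=None, encoder_outputs=None):
    # Mirrors the structure discussed above: run the encoder unless precomputed
    # `encoder_outputs` are passed in, then hand its first element to the decoder.
    if encoder_outputs is None:
        encoder_outputs = toy_encoder(input_ids)
    encoder_hidden_states = encoder_outputs[0]
    return toy_decoder(decoder_input_ids, encoder_hidden_states=encoder_hidden_states)


dummy = {
    "input_ids": tf.constant([[1, 2, 3]]),
    "decoder_input_ids": tf.constant([[1, 2, 3]]),
}
outputs = encoder_decoder_call(**dummy)
```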
8 changes: 7 additions & 1 deletion src/transformers/models/vision_encoder_decoder/__init__.py
@@ -18,7 +18,7 @@

from typing import TYPE_CHECKING

from ...file_utils import _LazyModule, is_flax_available, is_torch_available
from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available


_import_structure = {
@@ -28,6 +28,9 @@
if is_torch_available():
_import_structure["modeling_vision_encoder_decoder"] = ["VisionEncoderDecoderModel"]

if is_tf_available():
_import_structure["modeling_tf_vision_encoder_decoder"] = ["TFVisionEncoderDecoderModel"]

if is_flax_available():
_import_structure["modeling_flax_vision_encoder_decoder"] = ["FlaxVisionEncoderDecoderModel"]

@@ -37,6 +40,9 @@
if is_torch_available():
from .modeling_vision_encoder_decoder import VisionEncoderDecoderModel

if is_tf_available():
from .modeling_tf_vision_encoder_decoder import TFVisionEncoderDecoderModel

if is_flax_available():
from .modeling_flax_vision_encoder_decoder import FlaxVisionEncoderDecoderModel

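Finally, a small sketch of how downstream code can mirror the package's own TensorFlow gating before touching the new class (an assumption about typical usage, not code from the PR); `is_tf_available` is the same helper imported in the diff above:

```python
from transformers.file_utils import is_tf_available

if is_tf_available():
    from transformers import TFVisionEncoderDecoderModel

    model_cls = TFVisionEncoderDecoderModel
else:
    # Fall back to the PyTorch or Flax variant when TensorFlow is not installed.
    model_cls = None
```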