From d89dc406e0c4768f02a404baf9e4018e05f960b0 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Thu, 16 Jun 2022 13:14:12 +0100 Subject: [PATCH 01/16] Improve vision models --- src/transformers/models/deit/modeling_deit.py | 65 ++++++-------- src/transformers/models/swin/modeling_swin.py | 54 +++++++----- src/transformers/models/vit/modeling_vit.py | 88 +++++++++---------- tests/models/deit/test_modeling_deit.py | 21 +++++ tests/models/swin/test_modeling_swin.py | 38 ++++++-- tests/models/vit/test_modeling_vit.py | 21 +++++ 6 files changed, 174 insertions(+), 113 deletions(-) diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index ac429c0a615fc0..be5c5ab97bf877 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -61,21 +61,9 @@ ] -# Copied from transformers.models.vit.modeling_vit.to_2tuple -def to_2tuple(x): - if isinstance(x, collections.abc.Iterable): - return x - return (x, x) - - -# Based on timm implementation, which can be found here: -# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py - - class DeiTEmbeddings(nn.Module): """ Construct the CLS token, distillation token, position and patch embeddings. Optionally, also the mask token. - """ def __init__(self, config: DeiTConfig, use_mask_token: bool = False) -> None: @@ -84,22 +72,17 @@ def __init__(self, config: DeiTConfig, use_mask_token: bool = False) -> None: self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) self.distillation_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None - self.patch_embeddings = PatchEmbeddings( - image_size=config.image_size, - patch_size=config.patch_size, - num_channels=config.num_channels, - embed_dim=config.hidden_size, - ) + self.patch_embeddings = PatchEmbeddings(config) num_patches = self.patch_embeddings.num_patches self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size)) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor: embeddings = self.patch_embeddings(pixel_values) - batch_size, seq_len, _ = embeddings.size() + batch_size, seq_length, _ = embeddings.size() if bool_masked_pos is not None: - mask_tokens = self.mask_token.expand(batch_size, seq_len, -1) + mask_tokens = self.mask_token.expand(batch_size, seq_length, -1) # replace the masked visual tokens by mask_tokens mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) embeddings = embeddings * (1.0 - mask) + mask_tokens * mask @@ -114,30 +97,36 @@ def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Bo class PatchEmbeddings(nn.Module): """ - Image to Patch Embedding. - + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. 
""" - def __init__( - self, - image_size: int = 224, - patch_size: Union[int, Tuple[int, int]] = 16, - num_channels: int = 3, - embed_dim: int = 768, - ) -> None: + def __init__(self, config): super().__init__() - image_size = to_2tuple(image_size) - patch_size = to_2tuple(patch_size) + image_size, patch_size, num_channels, hidden_size = ( + config.image_size, + config.patch_size, + config.num_channels, + config.hidden_size, + ) + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) self.image_size = image_size self.patch_size = patch_size + self.num_channels = num_channels self.num_patches = num_patches - self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size) + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: batch_size, num_channels, height, width = pixel_values.shape - # FIXME look at relaxing size constraints + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) if height != self.image_size[0] or width != self.image_size[1]: raise ValueError( f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." @@ -570,8 +559,8 @@ def forward(self, hidden_states): @add_start_docstrings( - "DeiT Model with a decoder on top for masked image modeling, as proposed in `SimMIM" - " `__.", + "DeiT Model with a decoder on top for masked image modeling, as proposed in" + " [SimMIM](https://arxiv.org/abs/2111.09886).", DEIT_START_DOCSTRING, ) class DeiTForMaskedImageModeling(DeiTPreTrainedModel): @@ -581,7 +570,11 @@ def __init__(self, config: DeiTConfig) -> None: self.deit = DeiTModel(config, add_pooling_layer=False, use_mask_token=True) self.decoder = nn.Sequential( - nn.Conv2d(in_channels=config.hidden_size, out_channels=config.encoder_stride**2 * 3, kernel_size=1), + nn.Conv2d( + in_channels=config.hidden_size, + out_channels=config.encoder_stride**2 * config.num_channels, + kernel_size=1, + ), nn.PixelShuffle(config.encoder_stride), ) diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index be46b8dc2f8fce..d1943dddd49cf4 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -59,7 +59,7 @@ # See all Swin models at https://huggingface.co/models?filter=swin ] -# to_2tuple, drop_path, SwinPatchEmbeddings, SwinPatchMerging and SwinDropPath are from the timm library. +# drop_path, SwinPatchEmbeddings, SwinPatchMerging and SwinDropPath are from the timm library. @dataclass @@ -203,13 +203,6 @@ class SwinImageClassifierOutput(ModelOutput): reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None -# Copied from transformers.models.vit.modeling_vit.to_2tuple -def to_2tuple(x): - if isinstance(x, collections.abc.Iterable): - return x - return (x, x) - - def window_partition(input_feature, window_size): """ Partitions the given input into windows. 
@@ -254,12 +247,7 @@ class SwinEmbeddings(nn.Module): def __init__(self, config, use_mask_token=False): super().__init__() - self.patch_embeddings = SwinPatchEmbeddings( - image_size=config.image_size, - patch_size=config.patch_size, - num_channels=config.num_channels, - embed_dim=config.embed_dim, - ) + self.patch_embeddings = SwinPatchEmbeddings(config) num_patches = self.patch_embeddings.num_patches self.patch_grid = self.patch_embeddings.grid_size self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None @@ -295,20 +283,29 @@ def forward( class SwinPatchEmbeddings(nn.Module): """ - Image to Patch Embedding. + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. """ - def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768): + def __init__(self, config): super().__init__() - image_size = to_2tuple(image_size) - patch_size = to_2tuple(patch_size) + image_size, patch_size, num_channels, hidden_size = ( + config.image_size, + config.patch_size, + config.num_channels, + config.embed_dim, + ) + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) self.image_size = image_size self.patch_size = patch_size + self.num_channels = num_channels self.num_patches = num_patches self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) - self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size) + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) def maybe_pad(self, pixel_values, height, width): if width % self.patch_size[1] != 0: @@ -320,7 +317,11 @@ def maybe_pad(self, pixel_values, height, width): return pixel_values def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]: - _, _, height, width = pixel_values.shape + _, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." 
+ ) # pad the input to be divisible by self.patch_size, if needed pixel_values = self.maybe_pad(pixel_values, height, width) embeddings = self.projection(pixel_values) @@ -408,7 +409,10 @@ def __init__(self, config, dim, num_heads): self.num_attention_heads = num_heads self.attention_head_size = int(dim / num_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size - self.window_size = to_2tuple(config.window_size) + window_size = config.window_size + self.window_size = ( + window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size) + ) self.relative_position_bias_table = nn.Parameter( torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads) @@ -997,8 +1001,8 @@ def forward( @add_start_docstrings( - "Swin Model with a decoder on top for masked image modeling, as proposed in `SimMIM" - " `__.", + "Swin Model with a decoder on top for masked image modeling, as proposed in" + " [SimMIM](https://arxiv.org/abs/2111.09886).", SWIN_START_DOCSTRING, ) class SwinForMaskedImageModeling(SwinPreTrainedModel): @@ -1009,7 +1013,9 @@ def __init__(self, config): num_features = int(config.embed_dim * 2 ** (config.num_layers - 1)) self.decoder = nn.Sequential( - nn.Conv2d(in_channels=num_features, out_channels=config.encoder_stride**2 * 3, kernel_size=1), + nn.Conv2d( + in_channels=num_features, out_channels=config.encoder_stride**2 * config.num_channels, kernel_size=1 + ), nn.PixelShuffle(config.encoder_stride), ) diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index dde36b45ef5bb5..5b63a65de2c5dd 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -59,23 +59,9 @@ ] -# Inspired by -# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py -# From PyTorch internals -def to_2tuple(x): - if isinstance(x, collections.abc.Iterable): - return x - return (x, x) - - -# Based on timm implementation, which can be found here: -# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py - - class ViTEmbeddings(nn.Module): """ Construct the CLS token, position and patch embeddings. Optionally, also the mask token. 
- """ def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None: @@ -83,12 +69,7 @@ def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None: self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None - self.patch_embeddings = PatchEmbeddings( - image_size=config.image_size, - patch_size=config.patch_size, - num_channels=config.num_channels, - embed_dim=config.hidden_size, - ) + self.patch_embeddings = ViTPatchEmbeddings(config) num_patches = self.patch_embeddings.num_patches self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size)) self.dropout = nn.Dropout(config.hidden_dropout_prob) @@ -103,9 +84,9 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 """ - npatch = embeddings.shape[1] - 1 - N = self.position_embeddings.shape[1] - 1 - if npatch == N and height == width: + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: return self.position_embeddings class_pos_embed = self.position_embeddings[:, 0] patch_pos_embed = self.position_embeddings[:, 1:] @@ -116,8 +97,10 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: # see discussion at https://github.com/facebookresearch/dino/issues/8 h0, w0 = h0 + 0.1, w0 + 0.1 patch_pos_embed = nn.functional.interpolate( - patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), - scale_factor=(h0 / math.sqrt(N), w0 / math.sqrt(N)), + patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim).permute( + 0, 3, 1, 2 + ), + scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), mode="bicubic", align_corners=False, ) @@ -134,9 +117,9 @@ def forward( batch_size, num_channels, height, width = pixel_values.shape embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) - batch_size, seq_len, _ = embeddings.size() if bool_masked_pos is not None: - mask_tokens = self.mask_token.expand(batch_size, seq_len, -1) + seq_length = embeddings.shape[1] + mask_tokens = self.mask_token.expand(batch_size, seq_length, -1) # replace the masked visual tokens by mask_tokens mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) embeddings = embeddings * (1.0 - mask) + mask_tokens * mask @@ -156,41 +139,46 @@ def forward( return embeddings -# Based on timm implementation, which can be found here: -# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py -class PatchEmbeddings(nn.Module): +class ViTPatchEmbeddings(nn.Module): """ - Image to Patch Embedding. - + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. 
""" - def __init__( - self, - image_size: int = 224, - patch_size: Union[int, Tuple[int, int]] = 16, - num_channels: int = 3, - embed_dim: int = 768, - ): + def __init__(self, config): super().__init__() - image_size = to_2tuple(image_size) - patch_size = to_2tuple(patch_size) + image_size, patch_size, num_channels, hidden_size = ( + config.image_size, + config.patch_size, + config.num_channels, + config.hidden_size, + ) + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) self.image_size = image_size self.patch_size = patch_size + self.num_channels = num_channels self.num_patches = num_patches - self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size) + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) if not interpolate_pos_encoding: if height != self.image_size[0] or width != self.image_size[1]: raise ValueError( f"Input image size ({height}*{width}) doesn't match model" f" ({self.image_size[0]}*{self.image_size[1]})." ) - x = self.projection(pixel_values).flatten(2).transpose(1, 2) - return x + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + return embeddings class ViTSelfAttention(nn.Module): @@ -524,7 +512,7 @@ def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_t # Initialize weights and apply final processing self.post_init() - def get_input_embeddings(self) -> PatchEmbeddings: + def get_input_embeddings(self) -> ViTPatchEmbeddings: return self.embeddings.patch_embeddings def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: @@ -613,8 +601,8 @@ def forward(self, hidden_states): @add_start_docstrings( - "ViT Model with a decoder on top for masked image modeling, as proposed in `SimMIM" - " `__.", + "ViT Model with a decoder on top for masked image modeling, as proposed in" + " [SimMIM](https://arxiv.org/abs/2111.09886).", VIT_START_DOCSTRING, ) class ViTForMaskedImageModeling(ViTPreTrainedModel): @@ -624,7 +612,11 @@ def __init__(self, config: ViTConfig) -> None: self.vit = ViTModel(config, add_pooling_layer=False, use_mask_token=True) self.decoder = nn.Sequential( - nn.Conv2d(in_channels=config.hidden_size, out_channels=config.encoder_stride**2 * 3, kernel_size=1), + nn.Conv2d( + in_channels=config.hidden_size, + out_channels=config.encoder_stride**2 * config.num_channels, + kernel_size=1, + ), nn.PixelShuffle(config.encoder_stride), ) diff --git a/tests/models/deit/test_modeling_deit.py b/tests/models/deit/test_modeling_deit.py index 4559fa0c7127bf..82afc70840d500 100644 --- a/tests/models/deit/test_modeling_deit.py +++ b/tests/models/deit/test_modeling_deit.py @@ -131,6 +131,23 @@ def create_and_check_model(self, config, pixel_values, labels): result = model(pixel_values) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + def create_and_check_for_masked_image_modeling(self, config, pixel_values, 
labels): + model = DeiTForMaskedImageModeling(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + self.parent.assertEqual( + result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size) + ) + + # test greyscale images + config.num_channels = 1 + model = DeiTForMaskedImageModeling(config) + + pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) + result = model(pixel_values) + self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size)) + def create_and_check_for_image_classification(self, config, pixel_values, labels): config.num_labels = self.type_sequence_label_size model = DeiTForImageClassification(config) @@ -208,6 +225,10 @@ def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + def test_for_masked_image_modeling(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs) + def test_for_image_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_image_classification(*config_and_inputs) diff --git a/tests/models/swin/test_modeling_swin.py b/tests/models/swin/test_modeling_swin.py index 0c1f266816c7c6..e21d4474508ed6 100644 --- a/tests/models/swin/test_modeling_swin.py +++ b/tests/models/swin/test_modeling_swin.py @@ -14,6 +14,7 @@ # limitations under the License. """ Testing suite for the PyTorch Swin model. """ +import collections import inspect import os import pickle @@ -33,7 +34,7 @@ from torch import nn from transformers import SwinForImageClassification, SwinForMaskedImageModeling, SwinModel - from transformers.models.swin.modeling_swin import SWIN_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple + from transformers.models.swin.modeling_swin import SWIN_PRETRAINED_MODEL_ARCHIVE_LIST if is_vision_available(): from PIL import Image @@ -44,6 +45,12 @@ from transformers.utils.fx import symbolic_trace +def to_2tuple(x): + if isinstance(x, collections.abc.Iterable): + return x + return (x, x) + + class SwinModelTester: def __init__( self, @@ -141,6 +148,23 @@ def create_and_check_model(self, config, pixel_values, labels): self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim)) + def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels): + model = SwinForMaskedImageModeling(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + self.parent.assertEqual( + result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size) + ) + + # test greyscale images + config.num_channels = 1 + model = SwinForMaskedImageModeling(config) + + pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) + result = model(pixel_values) + self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size)) + def create_and_check_for_image_classification(self, config, pixel_values, labels): config.num_labels = self.type_sequence_label_size model = SwinForImageClassification(config) @@ -198,6 +222,14 @@ def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + def test_for_masked_image_modeling(self): + config_and_inputs = 
self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs) + + def test_for_image_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_image_classification(*config_and_inputs) + def test_inputs_embeds(self): # Swin does not use inputs_embeds pass @@ -354,10 +386,6 @@ def test_hidden_states_output_with_padding(self): config.output_hidden_states = True self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width)) - def test_for_image_classification(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_for_image_classification(*config_and_inputs) - @slow def test_model_from_pretrained(self): for model_name in SWIN_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/models/vit/test_modeling_vit.py b/tests/models/vit/test_modeling_vit.py index bfca8bf5cb9aa5..4e2ccccf70f4ee 100644 --- a/tests/models/vit/test_modeling_vit.py +++ b/tests/models/vit/test_modeling_vit.py @@ -120,6 +120,23 @@ def create_and_check_model(self, config, pixel_values, labels): result = model(pixel_values) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels): + model = ViTForMaskedImageModeling(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + self.parent.assertEqual( + result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size) + ) + + # test greyscale images + config.num_channels = 1 + model = ViTForMaskedImageModeling(config) + + pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) + result = model(pixel_values) + self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size)) + def create_and_check_for_image_classification(self, config, pixel_values, labels): config.num_labels = self.type_sequence_label_size model = ViTForImageClassification(config) @@ -197,6 +214,10 @@ def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + def test_for_masked_image_modeling(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs) + def test_for_image_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_image_classification(*config_and_inputs) From 52d76e0734b2f22b90b946a5d61aea405491d16e Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Thu, 16 Jun 2022 15:23:18 +0100 Subject: [PATCH 02/16] Add a lot of improvements --- src/transformers/models/beit/modeling_beit.py | 62 ++++---- .../models/convnext/modeling_convnext.py | 31 ++-- src/transformers/models/cvt/modeling_cvt.py | 27 ++-- .../data2vec/modeling_data2vec_vision.py | 68 ++++---- src/transformers/models/dpt/modeling_dpt.py | 34 ++-- src/transformers/models/glpn/modeling_glpn.py | 27 ++-- .../models/maskformer/modeling_maskformer.py | 67 ++++---- .../models/poolformer/modeling_poolformer.py | 51 +++--- .../models/segformer/modeling_segformer.py | 27 ++-- src/transformers/models/swin/modeling_swin.py | 48 +++--- .../models/swin/modeling_tf_swin.py | 41 +++-- src/transformers/models/van/modeling_van.py 
| 7 +- src/transformers/models/vilt/modeling_vilt.py | 38 +++-- .../models/vit_mae/modeling_tf_vit_mae.py | 41 ++--- .../models/vit_mae/modeling_vit_mae.py | 42 +++-- .../models/yolos/modeling_yolos.py | 53 +++---- .../data2vec/test_modeling_data2vec_vision.py | 112 +------------ tests/models/vit_mae/test_modeling_vit_mae.py | 148 ++---------------- tests/models/yolos/test_modeling_yolos.py | 6 +- 19 files changed, 366 insertions(+), 564 deletions(-) diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 3e16a081072352..2199be718a2222 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -91,17 +91,7 @@ class BeitModelOutputWithPooling(BaseModelOutputWithPooling): """ -# Inspired by -# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py -# From PyTorch internals -def to_2tuple(x): - if isinstance(x, collections.abc.Iterable): - return x - return (x, x) - - -# Based on https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py -def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). @@ -112,16 +102,16 @@ def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) - argument. """ if drop_prob == 0.0 or not training: - return x + return input keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor + output = input.div(keep_prob) * random_tensor return output -class DropPath(nn.Module): +class BeitDropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" def __init__(self, drop_prob: Optional[float] = None) -> None: @@ -151,12 +141,7 @@ def __init__(self, config: BeitConfig) -> None: self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) else: self.mask_token = None - self.patch_embeddings = PatchEmbeddings( - image_size=config.image_size, - patch_size=config.patch_size, - num_channels=config.num_channels, - embed_dim=config.hidden_size, - ) + self.patch_embeddings = BeitPatchEmbeddings(config) num_patches = self.patch_embeddings.num_patches if config.use_absolute_position_embeddings: self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size)) @@ -184,38 +169,45 @@ def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Bo return embeddings -# Based on timm implementation, which can be found here: -# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py -class PatchEmbeddings(nn.Module): +class BeitPatchEmbeddings(nn.Module): """ Image to Patch Embedding. 
""" - def __init__( - self, image_size: int = 224, patch_size: int = 16, num_channels: int = 3, embed_dim: int = 768 - ) -> None: + def __init__(self, config): super().__init__() - image_size = to_2tuple(image_size) - patch_size = to_2tuple(patch_size) + image_size, patch_size, num_channels, hidden_size = ( + config.image_size, + config.patch_size, + config.num_channels, + config.hidden_size, + ) + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) self.image_size = image_size self.patch_size = patch_size + self.num_channels = num_channels self.num_patches = num_patches self.patch_shape = patch_shape - self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size) + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: batch_size, num_channels, height, width = pixel_values.shape - # FIXME look at relaxing size constraints + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) if height != self.image_size[0] or width != self.image_size[1]: raise ValueError( f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." ) - x = self.projection(pixel_values).flatten(2).transpose(1, 2) + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) - return x + return embeddings class BeitSelfAttention(nn.Module): @@ -393,7 +385,7 @@ def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None, drop self.intermediate = BeitIntermediate(config) self.output = BeitOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + self.drop_path = BeitDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) init_values = config.layer_scale_init_value diff --git a/src/transformers/models/convnext/modeling_convnext.py b/src/transformers/models/convnext/modeling_convnext.py index fc484627218a2a..4ce040a30435de 100755 --- a/src/transformers/models/convnext/modeling_convnext.py +++ b/src/transformers/models/convnext/modeling_convnext.py @@ -53,36 +53,41 @@ ] -# Stochastic depth implementation -# Taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py -def drop_path(x, drop_prob: float = 0.0, training: bool = False): +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path(input, drop_prob: float = 0.0, training: bool = False): """ - Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the - DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop - Connect' is a different form of dropout in a separate paper... See discussion: - https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... 
I've opted for changing the layer and - argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument. + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. """ if drop_prob == 0.0 or not training: - return x + return input keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor + output = input.div(keep_prob) * random_tensor return output +# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->ConvNext class ConvNextDropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - def __init__(self, drop_prob=None): + def __init__(self, drop_prob: Optional[float] = None) -> None: super().__init__() self.drop_prob = drop_prob def forward(self, x: torch.Tensor) -> torch.Tensor: return drop_path(x, self.drop_prob, self.training) + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + class ConvNextLayerNorm(nn.Module): r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. diff --git a/src/transformers/models/cvt/modeling_cvt.py b/src/transformers/models/cvt/modeling_cvt.py index ca6d3bd0b31411..2ef8bd7378c622 100644 --- a/src/transformers/models/cvt/modeling_cvt.py +++ b/src/transformers/models/cvt/modeling_cvt.py @@ -79,21 +79,23 @@ class BaseModelOutputWithCLSToken(ModelOutput): # Copied from transformers.models.convnext.modeling_convnext.drop_path -def drop_path(x, drop_prob: float = 0.0, training: bool = False): +def drop_path(input, drop_prob: float = 0.0, training: bool = False): """ - Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the - DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop - Connect' is a different form of dropout in a separate paper... See discussion: - https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and - argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument. + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... 
I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. """ if drop_prob == 0.0 or not training: - return x + return input keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor + output = input.div(keep_prob) * random_tensor return output @@ -101,13 +103,16 @@ def drop_path(x, drop_prob: float = 0.0, training: bool = False): class CvtDropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - def __init__(self, drop_prob=None): + def __init__(self, drop_prob: Optional[float] = None) -> None: super().__init__() self.drop_prob = drop_prob def forward(self, x: torch.Tensor) -> torch.Tensor: return drop_path(x, self.drop_prob, self.training) + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + class CvtEmbeddings(nn.Module): """ diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index 0e286a773d31ec..0cf0aae80d6649 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -91,18 +91,8 @@ class Data2VecVisionModelOutputWithPooling(BaseModelOutputWithPooling): """ -# Inspired by -# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py -# From PyTorch internals -def to_2tuple(x): - if isinstance(x, collections.abc.Iterable): - return x - return (x, x) - - -# Based on https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # Copied from transformers.models.beit.modeling_beit.drop_path -def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). @@ -113,17 +103,17 @@ def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) - argument. 
""" if drop_prob == 0.0 or not training: - return x + return input keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor + output = input.div(keep_prob) * random_tensor return output -# Copied from transformers.models.beit.modeling_beit.DropPath -class DropPath(nn.Module): +# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Data2VecVision +class Data2VecVisionDropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" def __init__(self, drop_prob: Optional[float] = None) -> None: @@ -137,8 +127,6 @@ def extra_repr(self) -> str: return "p={}".format(self.drop_prob) -# Based on timm implementation, which can be found here: -# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # Copied from transformers.models.beit.modeling_beit.BeitEmbeddings with Beit->Data2VecVision class Data2VecVisionEmbeddings(nn.Module): """ @@ -154,12 +142,7 @@ def __init__(self, config: Data2VecVisionConfig) -> None: self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) else: self.mask_token = None - self.patch_embeddings = PatchEmbeddings( - image_size=config.image_size, - patch_size=config.patch_size, - num_channels=config.num_channels, - embed_dim=config.hidden_size, - ) + self.patch_embeddings = Data2VecVisionPatchEmbeddings(config) num_patches = self.patch_embeddings.num_patches if config.use_absolute_position_embeddings: self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size)) @@ -187,39 +170,46 @@ def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Bo return embeddings -# Based on timm implementation, which can be found here: -# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py -# Copied from transformers.models.beit.modeling_beit.PatchEmbeddings -class PatchEmbeddings(nn.Module): +# Copied from transformers.models.beit.modeling_beit.BeitPatchEmbeddings with Beit->Data2VecVision +class Data2VecVisionPatchEmbeddings(nn.Module): """ Image to Patch Embedding. 
""" - def __init__( - self, image_size: int = 224, patch_size: int = 16, num_channels: int = 3, embed_dim: int = 768 - ) -> None: + def __init__(self, config): super().__init__() - image_size = to_2tuple(image_size) - patch_size = to_2tuple(patch_size) + image_size, patch_size, num_channels, hidden_size = ( + config.image_size, + config.patch_size, + config.num_channels, + config.hidden_size, + ) + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) self.image_size = image_size self.patch_size = patch_size + self.num_channels = num_channels self.num_patches = num_patches self.patch_shape = patch_shape - self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size) + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: batch_size, num_channels, height, width = pixel_values.shape - # FIXME look at relaxing size constraints + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) if height != self.image_size[0] or width != self.image_size[1]: raise ValueError( f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." ) - x = self.projection(pixel_values).flatten(2).transpose(1, 2) + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) - return x + return embeddings # Copied from transformers.models.beit.modeling_beit.BeitSelfAttention with Beit->Data2VecVision @@ -405,7 +395,7 @@ def __init__( self.intermediate = Data2VecVisionIntermediate(config) self.output = Data2VecVisionOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + self.drop_path = Data2VecVisionDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) init_values = config.layer_scale_init_value diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index 64ea40a5c534f1..c9700ac44fa501 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -65,13 +65,6 @@ ] -# Copied from transformers.models.vit.modeling_vit.to_2tuple -def to_2tuple(x): - if isinstance(x, collections.abc.Iterable): - return x - return (x, x) - - class DPTViTEmbeddings(nn.Module): """ Construct the CLS token, position and patch embeddings. 
@@ -82,12 +75,7 @@ def __init__(self, config): super().__init__() self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) - self.patch_embeddings = DPTViTPatchEmbeddings( - image_size=config.image_size, - patch_size=config.patch_size, - num_channels=config.num_channels, - embed_dim=config.hidden_size, - ) + self.patch_embeddings = DPTViTPatchEmbeddings(config) num_patches = self.patch_embeddings.num_patches self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size)) self.dropout = nn.Dropout(config.hidden_dropout_prob) @@ -138,19 +126,31 @@ class DPTViTPatchEmbeddings(nn.Module): """ - def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768): + def __init__(self, config): super().__init__() - image_size = to_2tuple(image_size) - patch_size = to_2tuple(patch_size) + image_size, patch_size, num_channels, hidden_size = ( + config.image_size, + config.patch_size, + config.num_channels, + config.hidden_size, + ) + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) self.image_size = image_size self.patch_size = patch_size + self.num_channels = num_channels self.num_patches = num_patches - self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size) + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) def forward(self, pixel_values): batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) return embeddings diff --git a/src/transformers/models/glpn/modeling_glpn.py b/src/transformers/models/glpn/modeling_glpn.py index b7fc18b1d0f9bc..0575a54f6b58fb 100755 --- a/src/transformers/models/glpn/modeling_glpn.py +++ b/src/transformers/models/glpn/modeling_glpn.py @@ -54,21 +54,23 @@ # Copied from transformers.models.segformer.modeling_segformer.drop_path -def drop_path(x, drop_prob: float = 0.0, training: bool = False): +def drop_path(input, drop_prob: float = 0.0, training: bool = False): """ - Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the - DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop - Connect' is a different form of dropout in a separate paper... See discussion: - https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and - argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument. + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... 
I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. """ if drop_prob == 0.0 or not training: - return x + return input keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor + output = input.div(keep_prob) * random_tensor return output @@ -76,13 +78,16 @@ def drop_path(x, drop_prob: float = 0.0, training: bool = False): class GLPNDropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - def __init__(self, drop_prob=None): + def __init__(self, drop_prob: Optional[float] = None) -> None: super().__init__() self.drop_prob = drop_prob def forward(self, x: torch.Tensor) -> torch.Tensor: return drop_path(x, self.drop_prob, self.training) + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + # Copied from transformers.models.segformer.modeling_segformer.SegformerOverlapPatchEmbeddings class GLPNOverlapPatchEmbeddings(nn.Module): diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index 64c8d0029cfbcb..0616adcdfaafab 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -471,13 +471,6 @@ def pair_wise_sigmoid_focal_loss(inputs: Tensor, labels: Tensor, alpha: float = return loss / height_and_width -# Copied from transformers.models.vit.modeling_vit.to_2tuple -def to_2tuple(x): - if isinstance(x, collections.abc.Iterable): - return x - return (x, x) - - # Copied from transformers.models.swin.modeling_swin.window_partition def window_partition(input_feature, window_size): """ @@ -506,15 +499,21 @@ def window_reverse(windows, window_size, height, width): def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True): """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. 
""" if drop_prob == 0.0 or not training: return input keep_prob = 1 - drop_prob shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = input.new_empty(shape).bernoulli_(keep_prob) - if keep_prob > 0.0 and scale_by_keep: - random_tensor.div_(keep_prob) - return input * random_tensor + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output class MaskFormerSwinEmbeddings(nn.Module): @@ -525,12 +524,7 @@ class MaskFormerSwinEmbeddings(nn.Module): def __init__(self, config): super().__init__() - self.patch_embeddings = MaskFormerSwinPatchEmbeddings( - image_size=config.image_size, - patch_size=config.patch_size, - num_channels=config.num_channels, - embed_dim=config.embed_dim, - ) + self.patch_embeddings = MaskFormerSwinPatchEmbeddings(config) num_patches = self.patch_embeddings.num_patches self.patch_grid = self.patch_embeddings.grid_size @@ -559,17 +553,25 @@ class MaskFormerSwinPatchEmbeddings(nn.Module): Image to Patch Embedding, including padding. """ - def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768): + def __init__(self, config): super().__init__() - image_size = to_2tuple(image_size) - patch_size = to_2tuple(patch_size) + image_size, patch_size, num_channels, hidden_size = ( + config.image_size, + config.patch_size, + config.num_channels, + config.embed_dim, + ) + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) self.image_size = image_size self.patch_size = patch_size + self.num_channels = num_channels self.num_patches = num_patches self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) - self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size) + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) def maybe_pad(self, pixel_values, height, width): if width % self.patch_size[1] != 0: @@ -581,7 +583,11 @@ def maybe_pad(self, pixel_values, height, width): return pixel_values def forward(self, pixel_values): - _, _, height, width = pixel_values.shape + _, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." 
+ ) # pad the input to be divisible by self.patch_size, if needed pixel_values = self.maybe_pad(pixel_values, height, width) embeddings = self.projection(pixel_values) @@ -649,13 +655,15 @@ def forward(self, input_feature, input_dimensions): class MaskFormerSwinDropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - def __init__(self, drop_prob=None, scale_by_keep=True): - super(MaskFormerSwinDropPath, self).__init__() + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() self.drop_prob = drop_prob - self.scale_by_keep = scale_by_keep - def forward(self, input): - return drop_path(input, self.drop_prob, self.training, self.scale_by_keep) + def forward(self, x: torch.Tensor) -> torch.Tensor: + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) # Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->MaskFormerSwin @@ -670,7 +678,10 @@ def __init__(self, config, dim, num_heads): self.num_attention_heads = num_heads self.attention_head_size = int(dim / num_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size - self.window_size = to_2tuple(config.window_size) + window_size = config.window_size + self.window_size = ( + window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size) + ) self.relative_position_bias_table = nn.Parameter( torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads) diff --git a/src/transformers/models/poolformer/modeling_poolformer.py b/src/transformers/models/poolformer/modeling_poolformer.py index 2335c7cdc40c36..b4036b506f5a85 100755 --- a/src/transformers/models/poolformer/modeling_poolformer.py +++ b/src/transformers/models/poolformer/modeling_poolformer.py @@ -50,40 +50,41 @@ ] -# Copied from transformers.models.vit.modeling_vit.to_2tuple -def to_2tuple(x): - if isinstance(x, collections.abc.Iterable): - return x - return (x, x) - - -def drop_path(x, drop_prob: float = 0.0, training: bool = False): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, the original name is - misleading as 'Drop Connect' is a different form of dropout in a separate paper... See discussion: - https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and - argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument. +# Copied from transformers.models.convnext.modeling_convnext.drop_path +def drop_path(input, drop_prob: float = 0.0, training: bool = False): + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. 
""" if drop_prob == 0.0 or not training: - return x + return input keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor + output = input.div(keep_prob) * random_tensor return output +# Copied from transformers.models.convnext.modeling_convnext.ConvNextDropPath with ConvNext->PoolFormer class PoolFormerDropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - def __init__(self, drop_prob=None): + def __init__(self, drop_prob: Optional[float] = None) -> None: super().__init__() self.drop_prob = drop_prob - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return drop_path(x, self.drop_prob, self.training) + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + class PoolFormerEmbeddings(nn.Module): """ @@ -92,17 +93,17 @@ class PoolFormerEmbeddings(nn.Module): def __init__(self, hidden_size, num_channels, patch_size, stride, padding, norm_layer=None): super().__init__() - patch_size = to_2tuple(patch_size) - stride = to_2tuple(stride) - padding = to_2tuple(padding) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + stride = stride if isinstance(stride, collections.abc.Iterable) else (stride, stride) + padding = padding if isinstance(padding, collections.abc.Iterable) else (padding, padding) self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=stride, padding=padding) self.norm = norm_layer(hidden_size) if norm_layer else nn.Identity() def forward(self, pixel_values): - x = self.projection(pixel_values) - x = self.norm(x) - return x + embeddings = self.projection(pixel_values) + embeddings = self.norm(embeddings) + return embeddings class PoolFormerGroupNorm(nn.GroupNorm): diff --git a/src/transformers/models/segformer/modeling_segformer.py b/src/transformers/models/segformer/modeling_segformer.py index 55ac976b354422..22bf9234b2aec0 100755 --- a/src/transformers/models/segformer/modeling_segformer.py +++ b/src/transformers/models/segformer/modeling_segformer.py @@ -86,21 +86,23 @@ class SegFormerImageClassifierOutput(ImageClassifierOutput): # Copied from transformers.models.convnext.modeling_convnext.drop_path -def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep=True): +def drop_path(input, drop_prob: float = 0.0, training: bool = False, scale_by_keep=True): """ - Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the - DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop - Connect' is a different form of dropout in a separate paper... See discussion: - https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and - argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument. + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. """ if drop_prob == 0.0 or not training: - return x + return input keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor + output = input.div(keep_prob) * random_tensor return output @@ -108,13 +110,16 @@ def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep=T class SegformerDropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - def __init__(self, drop_prob=None): + def __init__(self, drop_prob: Optional[float] = None) -> None: super().__init__() self.drop_prob = drop_prob def forward(self, x: torch.Tensor) -> torch.Tensor: return drop_path(x, self.drop_prob, self.training) + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + class SegformerOverlapPatchEmbeddings(nn.Module): """Construct the overlapping patch embeddings.""" diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index d1943dddd49cf4..f8f9342eed694c 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -225,20 +225,6 @@ def window_reverse(windows, window_size, height, width): return windows -def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True): - """ - Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - """ - if drop_prob == 0.0 or not training: - return input - keep_prob = 1 - drop_prob - shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = input.new_empty(shape).bernoulli_(keep_prob) - if keep_prob > 0.0 and scale_by_keep: - random_tensor.div_(keep_prob) - return input * random_tensor - - class SwinEmbeddings(nn.Module): """ Construct the patch and position embeddings. Optionally, also the mask token. @@ -386,16 +372,40 @@ def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int] return input_feature +# Copied from transformers.models.convnext.modeling_convnext.drop_path +def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True): + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... 
I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.convnext.modeling_convnext.ConvNextDropPath with ConvNext->Swin class SwinDropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - def __init__(self, drop_prob=None, scale_by_keep=True): - super(SwinDropPath, self).__init__() + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() self.drop_prob = drop_prob - self.scale_by_keep = scale_by_keep - def forward(self, input): - return drop_path(input, self.drop_prob, self.training, self.scale_by_keep) + def forward(self, x: torch.Tensor) -> torch.Tensor: + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) class SwinSelfAttention(nn.Module): diff --git a/src/transformers/models/swin/modeling_tf_swin.py b/src/transformers/models/swin/modeling_tf_swin.py index 5b9ecbeccfafb2..5203d81c49d586 100644 --- a/src/transformers/models/swin/modeling_tf_swin.py +++ b/src/transformers/models/swin/modeling_tf_swin.py @@ -63,7 +63,7 @@ # See all Swin models at https://huggingface.co/models?filter=swin ] -# to_2tuple, drop_path, TFSwinPatchEmbeddings, TFSwinPatchMerging and TFSwinDropPath are tensorflow +# drop_path, TFSwinPatchEmbeddings, TFSwinPatchMerging and TFSwinDropPath are tensorflow # implementations of PyTorch functionalities in the timm library. @@ -208,13 +208,6 @@ class TFSwinImageClassifierOutput(ModelOutput): reshaped_hidden_states: Optional[Tuple[tf.Tensor]] = None -# Copied from transformers.models.vit.modeling_tf_vit.to_2tuple -def to_2tuple(x) -> Tuple[Any, Any]: - if isinstance(x, collections.abc.Iterable): - return x - return (x, x) - - def window_partition(input_feature: tf.Tensor, window_size: int) -> tf.Tensor: """ Partitions the given input into windows. @@ -329,20 +322,29 @@ class TFSwinPatchEmbeddings(tf.keras.layers.Layer): Image to Patch Embedding. 
""" - def __init__( - self, image_size: int = 224, patch_size: int = 16, num_channels: int = 3, embed_dim: int = 768, **kwargs - ) -> None: + def __init__(self, config, **kwargs): super().__init__(**kwargs) - image_size = to_2tuple(image_size) - patch_size = to_2tuple(patch_size) + image_size, patch_size, num_channels, hidden_size = ( + config.image_size, + config.patch_size, + config.num_channels, + config.embed_dim, + ) + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) self.image_size = image_size self.patch_size = patch_size + self.num_channels = num_channels self.num_patches = num_patches self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) self.projection = tf.keras.layers.Conv2D( - filters=embed_dim, kernel_size=self.patch_size, strides=self.patch_size, padding="valid", name="projection" + filters=hidden_size, + kernel_size=self.patch_size, + strides=self.patch_size, + padding="valid", + name="projection", ) def maybe_pad(self, pixel_values: tf.Tensor, height: int, width: int) -> tf.Tensor: @@ -355,7 +357,11 @@ def maybe_pad(self, pixel_values: tf.Tensor, height: int, width: int) -> tf.Tens return pixel_values def call(self, pixel_values: tf.Tensor, training: bool = False) -> Tuple[tf.Tensor, Tuple[int, int]]: - _, _, height, width = shape_list(pixel_values) + _, num_channels, height, width = shape_list(pixel_values) + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) # pad the input to be divisible by self.patch_size, if needed pixel_values = self.maybe_pad(pixel_values, height, width) @@ -460,7 +466,10 @@ def __init__(self, config: SwinConfig, dim: int, num_heads: int, **kwargs) -> No self.num_attention_heads = num_heads self.attention_head_size = int(dim / num_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size - self.window_size = to_2tuple(config.window_size) + window_size = config.window_size + self.window_size = ( + window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size) + ) # get pair-wise relative position index for each token inside the window coords_h = tf.range(self.window_size[0]) diff --git a/src/transformers/models/van/modeling_van.py b/src/transformers/models/van/modeling_van.py index b94f9969ba9d28..05de41360998dc 100644 --- a/src/transformers/models/van/modeling_van.py +++ b/src/transformers/models/van/modeling_van.py @@ -56,7 +56,7 @@ # Stochastic depth implementation # Taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py -def drop_path(x, drop_prob: float = 0.0, training: bool = False): +def drop_path(input, drop_prob: float = 0.0, training: bool = False): """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop @@ -78,13 +78,16 @@ def drop_path(x, drop_prob: float = 0.0, training: bool = False): class VanDropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - def __init__(self, drop_prob=None): + def __init__(self, drop_prob: Optional[float] = None) -> None: super().__init__() self.drop_prob = drop_prob def forward(self, x: torch.Tensor) -> torch.Tensor: return drop_path(x, self.drop_prob, self.training) + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + class VanOverlappingPatchEmbedder(nn.Module): """ diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index 174799318a80d8..77e5487c8cbfa3 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -82,13 +82,6 @@ class ViltForImagesAndTextClassificationOutput(ModelOutput): attentions: Optional[List[Tuple[torch.FloatTensor]]] = None -# Copied from transformers.models.vit.modeling_vit.to_2tuple -def to_2tuple(x): - if isinstance(x, collections.abc.Iterable): - return x - return (x, x) - - class ViltEmbeddings(nn.Module): """ Construct the text and patch embeddings. @@ -105,12 +98,7 @@ def __init__(self, config): self.text_embeddings = TextEmbeddings(config) # patch embeddings self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) - self.patch_embeddings = PatchEmbeddings( - image_size=config.image_size, - patch_size=config.patch_size, - num_channels=config.num_channels, - embed_dim=config.hidden_size, - ) + self.patch_embeddings = ViltPatchEmbeddings(config) num_patches = self.patch_embeddings.num_patches self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size)) # modality type (text/patch) embeddings @@ -304,26 +292,36 @@ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs return embeddings -# Based on timm implementation, which can be found here: -# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py -class PatchEmbeddings(nn.Module): +class ViltPatchEmbeddings(nn.Module): """ Image to Patch Embedding. 
""" - def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768): + def __init__(self, config): super().__init__() - image_size = to_2tuple(image_size) - patch_size = to_2tuple(patch_size) + image_size, patch_size, num_channels, hidden_size = ( + config.image_size, + config.patch_size, + config.num_channels, + config.hidden_size, + ) + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) self.image_size = image_size self.patch_size = patch_size + self.num_channels = num_channels self.num_patches = num_patches - self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size) + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) def forward(self, pixel_values): batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) x = self.projection(pixel_values) return x diff --git a/src/transformers/models/vit_mae/modeling_tf_vit_mae.py b/src/transformers/models/vit_mae/modeling_tf_vit_mae.py index 24d84141c895ba..efb59b2217c2a2 100644 --- a/src/transformers/models/vit_mae/modeling_tf_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_tf_vit_mae.py @@ -133,13 +133,6 @@ class TFViTMAEForPreTrainingOutput(ModelOutput): attentions: Optional[Tuple[tf.Tensor]] = None -# copied from transformers.models.vit.modeling_tf_vit.to_2tuple -def to_2tuple(x): - if isinstance(x, collections.abc.Iterable): - return x - return (x, x) - - def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False): """ Create 2D sin/cos positional embeddings. @@ -212,7 +205,7 @@ class TFViTMAEEmbeddings(tf.keras.layers.Layer): def __init__(self, config: ViTMAEConfig, **kwargs): super().__init__(**kwargs) - self.patch_embeddings = TFPatchEmbeddings(config, name="patch_embeddings") + self.patch_embeddings = TFViTMAEPatchEmbeddings(config, name="patch_embeddings") self.num_patches = self.patch_embeddings.num_patches self.config = config @@ -297,30 +290,34 @@ def call(self, pixel_values: tf.Tensor, noise: tf.Tensor = None) -> tf.Tensor: return embeddings, mask, ids_restore -# Based on timm implementation, which can be found here: -# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py -class TFPatchEmbeddings(tf.keras.layers.Layer): +class TFViTMAEPatchEmbeddings(tf.keras.layers.Layer): """ - Image to Patch Embedding. - + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. 
""" def __init__(self, config: ViTMAEConfig, **kwargs): super().__init__(**kwargs) - image_size = to_2tuple(config.image_size) - patch_size = to_2tuple(config.patch_size) + image_size, patch_size, num_channels, hidden_size = ( + config.image_size, + config.patch_size, + config.num_channels, + config.hidden_size, + ) + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) self.image_size = image_size self.patch_size = patch_size self.num_patches = num_patches - self.num_channels = config.num_channels - self.embed_dim = config.hidden_size + self.num_channels = num_channels self.config = config self.projection = tf.keras.layers.Conv2D( - filters=self.embed_dim, - kernel_size=self.patch_size, - strides=self.patch_size, + filters=hidden_size, + kernel_size=patch_size, + strides=patch_size, padding="valid", data_format="channels_last", kernel_initializer="glorot_uniform", # following torch.nn.Linear @@ -330,6 +327,10 @@ def __init__(self, config: ViTMAEConfig, **kwargs): def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor: batch_size, num_channels, height, width = shape_list(pixel_values) + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) if getattr(height, "numpy", None) and getattr(width, "numpy", None): if height != self.image_size[0] or width != self.image_size[1]: raise ValueError( diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index 1226b6025c8918..34403787d99a69 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -135,13 +135,6 @@ class ViTMAEForPreTrainingOutput(ModelOutput): attentions: Optional[Tuple[torch.FloatTensor]] = None -# copied from transformers.models.vit.modeling_vit.to_2tuple ViT->ViTMAE -def to_2tuple(x): - if isinstance(x, collections.abc.Iterable): - return x - return (x, x) - - def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False): """ Create 2D sin/cos positional embeddings. @@ -213,12 +206,7 @@ def __init__(self, config): super().__init__() self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) - self.patch_embeddings = PatchEmbeddings( - image_size=config.image_size, - patch_size=config.patch_size, - num_channels=config.num_channels, - embed_dim=config.hidden_size, - ) + self.patch_embeddings = ViTMAEPatchEmbeddings(config) self.num_patches = self.patch_embeddings.num_patches # fixed sin-cos embedding self.position_embeddings = nn.Parameter( @@ -291,27 +279,37 @@ def forward(self, pixel_values, noise=None): return embeddings, mask, ids_restore -# Based on timm implementation, which can be found here: -# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py -class PatchEmbeddings(nn.Module): +class ViTMAEPatchEmbeddings(nn.Module): """ - Image to Patch Embedding. - + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. 
""" - def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768): + def __init__(self, config): super().__init__() - image_size = to_2tuple(image_size) - patch_size = to_2tuple(patch_size) + image_size, patch_size, num_channels, hidden_size = ( + config.image_size, + config.patch_size, + config.num_channels, + config.hidden_size, + ) + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) self.image_size = image_size self.patch_size = patch_size + self.num_channels = num_channels self.num_patches = num_patches - self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size) + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) def forward(self, pixel_values): batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) if height != self.image_size[0] or width != self.image_size[1]: raise ValueError( f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index 578e8ca6092794..af6dc35ab00ce4 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -111,13 +111,6 @@ class YolosObjectDetectionOutput(ModelOutput): attentions: Optional[Tuple[torch.FloatTensor]] = None -# Copied from transformers.models.vit.modeling_vit.to_2tuple -def to_2tuple(x): - if isinstance(x, collections.abc.Iterable): - return x - return (x, x) - - class YolosEmbeddings(nn.Module): """ Construct the CLS token, detection tokens, position and patch embeddings. @@ -129,12 +122,7 @@ def __init__(self, config: YolosConfig) -> None: self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) self.detection_tokens = nn.Parameter(torch.zeros(1, config.num_detection_tokens, config.hidden_size)) - self.patch_embeddings = PatchEmbeddings( - image_size=config.image_size, - patch_size=config.patch_size, - num_channels=config.num_channels, - embed_dim=config.hidden_size, - ) + self.patch_embeddings = YolosPatchEmbeddings(config) num_patches = self.patch_embeddings.num_patches self.position_embeddings = nn.Parameter( torch.zeros(1, num_patches + config.num_detection_tokens + 1, config.hidden_size) @@ -228,32 +216,39 @@ def forward(self, pos_embed, img_size=(800, 1344)) -> torch.Tensor: return scale_pos_embed -# Based on timm implementation, which can be found here: -# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py -class PatchEmbeddings(nn.Module): +class YolosPatchEmbeddings(nn.Module): """ - Image to Patch Embedding. - + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. 
""" - def __init__( - self, - image_size: int = 224, - patch_size: Union[int, Tuple[int, int]] = 16, - num_channels: int = 3, - embed_dim: int = 768, - ): + def __init__(self, config): super().__init__() - image_size = to_2tuple(image_size) - patch_size = to_2tuple(patch_size) + image_size, patch_size, num_channels, hidden_size = ( + config.image_size, + config.patch_size, + config.num_channels, + config.hidden_size, + ) + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) self.image_size = image_size self.patch_size = patch_size + self.num_channels = num_channels self.num_patches = num_patches - self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size) + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) return embeddings @@ -620,7 +615,7 @@ def __init__(self, config: YolosConfig, add_pooling_layer: bool = True): # Initialize weights and apply final processing self.post_init() - def get_input_embeddings(self) -> PatchEmbeddings: + def get_input_embeddings(self) -> YolosPatchEmbeddings: return self.embeddings.patch_embeddings def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: diff --git a/tests/models/data2vec/test_modeling_data2vec_vision.py b/tests/models/data2vec/test_modeling_data2vec_vision.py index 8966b909970a28..337606510af220 100644 --- a/tests/models/data2vec/test_modeling_data2vec_vision.py +++ b/tests/models/data2vec/test_modeling_data2vec_vision.py @@ -39,7 +39,6 @@ ) from transformers.models.data2vec.modeling_data2vec_vision import ( DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST, - to_2tuple, ) @@ -94,6 +93,10 @@ def __init__( self.out_indices = out_indices self.num_labels = num_labels + # in BeiT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token) + num_patches = (image_size // patch_size) ** 2 + self.seq_length = num_patches + 1 + def prepare_config_and_inputs(self): pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) @@ -131,9 +134,7 @@ def create_and_check_model(self, config, pixel_values, labels, pixel_labels): model.eval() result = model(pixel_values) # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) - image_size = to_2tuple(self.image_size) - patch_size = to_2tuple(self.patch_size) - num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + num_patches = (self.image_size // self.patch_size)**2 self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size)) def create_and_check_for_image_classification(self, config, pixel_values, labels, pixel_labels): @@ -286,109 +287,6 @@ def test_initialization(self): msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) - def test_attention_outputs(self): - config, inputs_dict = 
self.model_tester.prepare_config_and_inputs_for_common() - config.return_dict = True - - # in Data2VecVision, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token) - image_size = to_2tuple(self.model_tester.image_size) - patch_size = to_2tuple(self.model_tester.patch_size) - num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) - seq_len = num_patches + 1 - encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) - encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) - chunk_length = getattr(self.model_tester, "chunk_length", None) - if chunk_length is not None and hasattr(self.model_tester, "num_hashes"): - encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes - - for model_class in self.all_model_classes: - inputs_dict["output_attentions"] = True - inputs_dict["output_hidden_states"] = False - config.return_dict = True - model = model_class(config) - model.to(torch_device) - model.eval() - with torch.no_grad(): - outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions - self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) - - # check that output_attentions also work using config - del inputs_dict["output_attentions"] - config.output_attentions = True - model = model_class(config) - model.to(torch_device) - model.eval() - with torch.no_grad(): - outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - - attentions = outputs.attentions - self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) - - self.assertListEqual( - list(attentions[0].shape[-3:]), - [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], - ) - out_len = len(outputs) - - # Check attention is always last and order is fine - inputs_dict["output_attentions"] = True - inputs_dict["output_hidden_states"] = True - model = model_class(config) - model.to(torch_device) - model.eval() - with torch.no_grad(): - outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - - self.assertEqual(out_len + 1, len(outputs)) - - self_attentions = outputs.attentions - - self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) - self.assertListEqual( - list(self_attentions[0].shape[-3:]), - [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], - ) - - def test_hidden_states_output(self): - def check_hidden_states_output(inputs_dict, config, model_class): - model = model_class(config) - model.to(torch_device) - model.eval() - - with torch.no_grad(): - outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - - hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states - - expected_num_layers = getattr( - self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1 - ) - self.assertEqual(len(hidden_states), expected_num_layers) - - # Data2VecVision has a different seq_length - image_size = to_2tuple(self.model_tester.image_size) - patch_size = to_2tuple(self.model_tester.patch_size) - num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) - seq_length = num_patches + 1 - - self.assertListEqual( - list(hidden_states[0].shape[-2:]), - [seq_length, self.model_tester.hidden_size], - ) - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in 
self.all_model_classes: - inputs_dict["output_hidden_states"] = True - check_hidden_states_output(inputs_dict, config, model_class) - - # check that output_hidden_states also work using config - del inputs_dict["output_hidden_states"] - config.output_hidden_states = True - - check_hidden_states_output(inputs_dict, config, model_class) - def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=2e-4, name="outputs", attributes=None): # We override with a slightly higher tol value, as semseg models tend to diverge a bit more super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes) diff --git a/tests/models/vit_mae/test_modeling_vit_mae.py b/tests/models/vit_mae/test_modeling_vit_mae.py index 1a749b162893e9..108aff0a1a2500 100644 --- a/tests/models/vit_mae/test_modeling_vit_mae.py +++ b/tests/models/vit_mae/test_modeling_vit_mae.py @@ -35,7 +35,7 @@ from torch import nn from transformers import ViTMAEForPreTraining, ViTMAEModel - from transformers.models.vit.modeling_vit import VIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple + from transformers.models.vit.modeling_vit import VIT_PRETRAINED_MODEL_ARCHIVE_LIST if is_vision_available(): @@ -64,6 +64,7 @@ def __init__( type_sequence_label_size=10, initializer_range=0.02, num_labels=3, + mask_ratio=0.6, scope=None, ): self.parent = parent @@ -82,8 +83,14 @@ def __init__( self.attention_probs_dropout_prob = attention_probs_dropout_prob self.type_sequence_label_size = type_sequence_label_size self.initializer_range = initializer_range + self.mask_ratio = mask_ratio self.scope = scope + # in ViTMAE, the expected sequence length = (num_patches + 1) * (1 - config.mask_ratio), rounded above + # (we add 1 for the [CLS] token) + num_patches = (image_size // patch_size)**2 + self.seq_length = int(math.ceil((1 - mask_ratio) * (num_patches + 1))) + def prepare_config_and_inputs(self): pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) @@ -109,6 +116,7 @@ def get_config(self): attention_probs_dropout_prob=self.attention_probs_dropout_prob, is_decoder=False, initializer_range=self.initializer_range, + mask_ratio=self.mask_ratio, ) def create_and_check_model(self, config, pixel_values, labels): @@ -116,26 +124,16 @@ def create_and_check_model(self, config, pixel_values, labels): model.to(torch_device) model.eval() result = model(pixel_values) - # expected sequence length = (num_patches + 1) * (1 - config.mask_ratio), rounded above - # (we add 1 for the [CLS] token) - image_size = to_2tuple(self.image_size) - patch_size = to_2tuple(self.patch_size) - num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) - expected_seq_len = int(math.ceil((1 - config.mask_ratio) * (num_patches + 1))) - self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, self.hidden_size)) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) def create_and_check_for_pretraining(self, config, pixel_values, labels): model = ViTMAEForPreTraining(config) model.to(torch_device) model.eval() result = model(pixel_values) - # expected sequence length = num_patches - image_size = to_2tuple(self.image_size) - patch_size = to_2tuple(self.patch_size) - num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) - expected_seq_len = num_patches + num_patches = (self.image_size // self.patch_size)**2 expected_num_channels = self.patch_size**2 * self.num_channels - 
self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels)) + self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, expected_num_channels)) # test greyscale images config.num_channels = 1 @@ -175,8 +173,8 @@ def setUp(self): def test_config(self): self.config_tester.run_common_tests() + @unittest.skip(reason="ViTMAE does not use inputs_embeds") def test_inputs_embeds(self): - # ViTMAE does not use inputs_embeds pass def test_model_common_attributes(self): @@ -208,126 +206,6 @@ def test_for_pretraining(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_pretraining(*config_and_inputs) - def test_attention_outputs(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.return_dict = True - - # in ViTMAE, the seq_len equals (number of patches + 1) * (1 - mask_ratio), rounded above - image_size = to_2tuple(self.model_tester.image_size) - patch_size = to_2tuple(self.model_tester.patch_size) - num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) - seq_len = int(math.ceil((1 - config.mask_ratio) * (num_patches + 1))) - encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) - encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) - chunk_length = getattr(self.model_tester, "chunk_length", None) - if chunk_length is not None and hasattr(self.model_tester, "num_hashes"): - encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes - - for model_class in self.all_model_classes: - inputs_dict["output_attentions"] = True - inputs_dict["output_hidden_states"] = False - config.return_dict = True - model = model_class(config) - model.to(torch_device) - model.eval() - with torch.no_grad(): - outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions - self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) - - # check that output_attentions also work using config - del inputs_dict["output_attentions"] - config.output_attentions = True - model = model_class(config) - model.to(torch_device) - model.eval() - with torch.no_grad(): - outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions - self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) - - if chunk_length is not None: - self.assertListEqual( - list(attentions[0].shape[-4:]), - [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length], - ) - else: - self.assertListEqual( - list(attentions[0].shape[-3:]), - [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], - ) - out_len = len(outputs) - - # Check attention is always last and order is fine - inputs_dict["output_attentions"] = True - inputs_dict["output_hidden_states"] = True - model = model_class(config) - model.to(torch_device) - model.eval() - with torch.no_grad(): - outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - - if hasattr(self.model_tester, "num_hidden_states_types"): - added_hidden_states = self.model_tester.num_hidden_states_types - elif self.is_encoder_decoder: - added_hidden_states = 2 - else: - added_hidden_states = 1 - self.assertEqual(out_len + added_hidden_states, len(outputs)) - - self_attentions = 
outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions - - self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) - if chunk_length is not None: - self.assertListEqual( - list(self_attentions[0].shape[-4:]), - [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length], - ) - else: - self.assertListEqual( - list(self_attentions[0].shape[-3:]), - [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], - ) - - def test_hidden_states_output(self): - def check_hidden_states_output(inputs_dict, config, model_class): - model = model_class(config) - model.to(torch_device) - model.eval() - - with torch.no_grad(): - outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - - hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states - - expected_num_layers = getattr( - self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1 - ) - self.assertEqual(len(hidden_states), expected_num_layers) - - # ViTMAE has a different seq_length - image_size = to_2tuple(self.model_tester.image_size) - patch_size = to_2tuple(self.model_tester.patch_size) - num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) - seq_length = int(math.ceil((1 - config.mask_ratio) * (num_patches + 1))) - - self.assertListEqual( - list(hidden_states[0].shape[-2:]), - [seq_length, self.model_tester.hidden_size], - ) - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - inputs_dict["output_hidden_states"] = True - check_hidden_states_output(inputs_dict, config, model_class) - - # check that output_hidden_states also work using config - del inputs_dict["output_hidden_states"] - config.output_hidden_states = True - - check_hidden_states_output(inputs_dict, config, model_class) - # overwrite from common since ViTMAEForPretraining has random masking, we need to fix the noise # to generate masks during test def check_pt_tf_models(self, tf_model, pt_model, pt_inputs_dict): diff --git a/tests/models/yolos/test_modeling_yolos.py b/tests/models/yolos/test_modeling_yolos.py index 75d399eaa7972e..8e67b4ce5fafc1 100644 --- a/tests/models/yolos/test_modeling_yolos.py +++ b/tests/models/yolos/test_modeling_yolos.py @@ -31,7 +31,7 @@ from torch import nn from transformers import YolosForObjectDetection, YolosModel - from transformers.models.yolos.modeling_yolos import YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple + from transformers.models.yolos.modeling_yolos import YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST if is_vision_available(): @@ -86,9 +86,7 @@ def __init__( self.num_detection_tokens = num_detection_tokens # we set the expected sequence length (which is used in several tests) # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) + num_detection_tokens - image_size = to_2tuple(self.image_size) - patch_size = to_2tuple(self.patch_size) - num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + num_patches = (image_size[0] // patch_size)**2 self.expected_seq_len = num_patches + 1 + self.num_detection_tokens def prepare_config_and_inputs(self): From 8c1f418f400e9319c38092d961050242c7036faa Mon Sep 17 00:00:00 2001 From: NielsRogge Date: Thu, 16 Jun 2022 15:23:56 +0000 Subject: [PATCH 03/16] Remove to_2tuple from swin tests --- tests/models/swin/test_modeling_swin.py | 16 ++++++---------- 1 file changed, 6 
insertions(+), 10 deletions(-) diff --git a/tests/models/swin/test_modeling_swin.py b/tests/models/swin/test_modeling_swin.py index e21d4474508ed6..308c9a871c6d3c 100644 --- a/tests/models/swin/test_modeling_swin.py +++ b/tests/models/swin/test_modeling_swin.py @@ -45,12 +45,6 @@ from transformers.utils.fx import symbolic_trace -def to_2tuple(x): - if isinstance(x, collections.abc.Iterable): - return x - return (x, x) - - class SwinModelTester: def __init__( self, @@ -160,6 +154,8 @@ def create_and_check_for_masked_image_modeling(self, config, pixel_values, label # test greyscale images config.num_channels = 1 model = SwinForMaskedImageModeling(config) + model.to(torch_device) + model.eval() pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) result = model(pixel_values) @@ -331,7 +327,7 @@ def check_hidden_states_output(self, inputs_dict, config, model_class, image_siz self.assertEqual(len(hidden_states), expected_num_layers) # Swin has a different seq_length - patch_size = to_2tuple(config.patch_size) + patch_size = config.patch_size if isinstance(config.patch_size, collections.abc.Iterable) else (config.patch_size, config.patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) @@ -355,7 +351,7 @@ def check_hidden_states_output(self, inputs_dict, config, model_class, image_siz def test_hidden_states_output(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - image_size = to_2tuple(self.model_tester.image_size) + image_size = self.model_tester.image_size if isinstance(self.model_tester.image_size, collections.abc.Iterable) else (self.model_tester.image_size, self.model_tester.image_size) for model_class in self.all_model_classes: inputs_dict["output_hidden_states"] = True @@ -371,8 +367,8 @@ def test_hidden_states_output_with_padding(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config.patch_size = 3 - image_size = to_2tuple(self.model_tester.image_size) - patch_size = to_2tuple(config.patch_size) + image_size = self.model_tester.image_size if isinstance(self.model_tester.image_size, collections.abc.Iterable) else (self.model_tester.image_size, self.model_tester.image_size) + patch_size = config.patch_size if isinstance(config.patch_size, collections.abc.Iterable) else (config.patch_size, config.patch_size) padded_height = image_size[0] + patch_size[0] - (image_size[0] % patch_size[0]) padded_width = image_size[1] + patch_size[1] - (image_size[1] % patch_size[1]) From 1c6aa775c0764329142b5886d789fcb2e3e84849 Mon Sep 17 00:00:00 2001 From: NielsRogge Date: Fri, 17 Jun 2022 17:25:08 +0000 Subject: [PATCH 04/16] Fix TF Swin --- src/transformers/models/swin/modeling_tf_swin.py | 8 +------- tests/models/swin/test_modeling_tf_swin.py | 8 +++++++- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/swin/modeling_tf_swin.py b/src/transformers/models/swin/modeling_tf_swin.py index 5203d81c49d586..557b057f7caedc 100644 --- a/src/transformers/models/swin/modeling_tf_swin.py +++ b/src/transformers/models/swin/modeling_tf_swin.py @@ -263,13 +263,7 @@ class TFSwinEmbeddings(tf.keras.layers.Layer): def __init__(self, config: SwinConfig, use_mask_token: bool = False, **kwargs) -> None: super().__init__(**kwargs) - self.patch_embeddings = TFSwinPatchEmbeddings( - image_size=config.image_size, - patch_size=config.patch_size, - num_channels=config.num_channels, - embed_dim=config.embed_dim, - name="patch_embeddings", - ) + 
self.patch_embeddings = TFSwinPatchEmbeddings(config, name="patch_embeddings") self.num_patches = self.patch_embeddings.num_patches self.patch_grid = self.patch_embeddings.grid_size self.embed_dim = config.embed_dim diff --git a/tests/models/swin/test_modeling_tf_swin.py b/tests/models/swin/test_modeling_tf_swin.py index 88323d7fd7a594..dc06993f217e2a 100644 --- a/tests/models/swin/test_modeling_tf_swin.py +++ b/tests/models/swin/test_modeling_tf_swin.py @@ -17,6 +17,7 @@ import inspect import unittest +import collections import numpy as np @@ -36,10 +37,15 @@ TFSwinForImageClassification, TFSwinForMaskedImageModeling, TFSwinModel, - to_2tuple, ) +def to_2tuple(x): + if isinstance(x, collections.abc.Iterable): + return x + return (x, x) + + if is_vision_available(): from PIL import Image From fef92c6dd34f3a68cb378332ec3eb87c30fdc292 Mon Sep 17 00:00:00 2001 From: NielsRogge Date: Sat, 18 Jun 2022 09:27:46 +0000 Subject: [PATCH 05/16] Fix more tests --- src/transformers/testing_utils.py | 7 + tests/models/beit/test_modeling_flax_beit.py | 1 - .../data2vec/test_modeling_data2vec_vision.py | 6 +- tests/models/deit/test_modeling_deit.py | 2 + tests/models/swin/test_modeling_swin.py | 24 +++- tests/models/swin/test_modeling_tf_swin.py | 9 +- tests/models/vit/test_modeling_vit.py | 2 + .../vit_mae/test_modeling_tf_vit_mae.py | 133 ++---------------- tests/models/vit_mae/test_modeling_vit_mae.py | 4 +- tests/models/yolos/test_modeling_yolos.py | 2 +- 10 files changed, 48 insertions(+), 142 deletions(-) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index b73f3c380aeaa2..211f94f1f98efa 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import collections import contextlib import inspect import logging @@ -1534,3 +1535,9 @@ def check_json_file_has_correct_format(file_path): left_indent = len(lines[1]) - len(lines[1].lstrip()) assert left_indent == 2 assert lines[-1].strip() == "}" + + +def to_2tuple(x): + if isinstance(x, collections.abc.Iterable): + return x + return (x, x) diff --git a/tests/models/beit/test_modeling_flax_beit.py b/tests/models/beit/test_modeling_flax_beit.py index 50996dedc7af52..ff10e536bf48a0 100644 --- a/tests/models/beit/test_modeling_flax_beit.py +++ b/tests/models/beit/test_modeling_flax_beit.py @@ -105,7 +105,6 @@ def prepare_config_and_inputs(self): return config, pixel_values, labels def create_and_check_model(self, config, pixel_values, labels): - model = FlaxBeitModel(config=config) result = model(pixel_values) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) diff --git a/tests/models/data2vec/test_modeling_data2vec_vision.py b/tests/models/data2vec/test_modeling_data2vec_vision.py index 337606510af220..99ece4d58502b3 100644 --- a/tests/models/data2vec/test_modeling_data2vec_vision.py +++ b/tests/models/data2vec/test_modeling_data2vec_vision.py @@ -37,9 +37,7 @@ Data2VecVisionForSemanticSegmentation, Data2VecVisionModel, ) - from transformers.models.data2vec.modeling_data2vec_vision import ( - DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST, - ) + from transformers.models.data2vec.modeling_data2vec_vision import DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST if is_vision_available(): @@ -134,7 +132,7 @@ def create_and_check_model(self, config, pixel_values, labels, pixel_labels): model.eval() result = model(pixel_values) # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) - num_patches = (self.image_size // self.patch_size)**2 + num_patches = (self.image_size // self.patch_size) ** 2 self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size)) def create_and_check_for_image_classification(self, config, pixel_values, labels, pixel_labels): diff --git a/tests/models/deit/test_modeling_deit.py b/tests/models/deit/test_modeling_deit.py index 82afc70840d500..b163294cedf8c4 100644 --- a/tests/models/deit/test_modeling_deit.py +++ b/tests/models/deit/test_modeling_deit.py @@ -143,6 +143,8 @@ def create_and_check_for_masked_image_modeling(self, config, pixel_values, label # test greyscale images config.num_channels = 1 model = DeiTForMaskedImageModeling(config) + model.to(torch_device) + model.eval() pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) result = model(pixel_values) diff --git a/tests/models/swin/test_modeling_swin.py b/tests/models/swin/test_modeling_swin.py index 308c9a871c6d3c..33080ed47ffd50 100644 --- a/tests/models/swin/test_modeling_swin.py +++ b/tests/models/swin/test_modeling_swin.py @@ -327,7 +327,11 @@ def check_hidden_states_output(self, inputs_dict, config, model_class, image_siz self.assertEqual(len(hidden_states), expected_num_layers) # Swin has a different seq_length - patch_size = config.patch_size if isinstance(config.patch_size, collections.abc.Iterable) else (config.patch_size, config.patch_size) + patch_size = ( + config.patch_size + if isinstance(config.patch_size, collections.abc.Iterable) + else (config.patch_size, config.patch_size) + ) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) @@ -351,7 +355,11 @@ def check_hidden_states_output(self, inputs_dict, config, 
model_class, image_siz def test_hidden_states_output(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - image_size = self.model_tester.image_size if isinstance(self.model_tester.image_size, collections.abc.Iterable) else (self.model_tester.image_size, self.model_tester.image_size) + image_size = ( + self.model_tester.image_size + if isinstance(self.model_tester.image_size, collections.abc.Iterable) + else (self.model_tester.image_size, self.model_tester.image_size) + ) for model_class in self.all_model_classes: inputs_dict["output_hidden_states"] = True @@ -367,8 +375,16 @@ def test_hidden_states_output_with_padding(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config.patch_size = 3 - image_size = self.model_tester.image_size if isinstance(self.model_tester.image_size, collections.abc.Iterable) else (self.model_tester.image_size, self.model_tester.image_size) - patch_size = config.patch_size if isinstance(config.patch_size, collections.abc.Iterable) else (config.patch_size, config.patch_size) + image_size = ( + self.model_tester.image_size + if isinstance(self.model_tester.image_size, collections.abc.Iterable) + else (self.model_tester.image_size, self.model_tester.image_size) + ) + patch_size = ( + config.patch_size + if isinstance(config.patch_size, collections.abc.Iterable) + else (config.patch_size, config.patch_size) + ) padded_height = image_size[0] + patch_size[0] - (image_size[0] % patch_size[0]) padded_width = image_size[1] + patch_size[1] - (image_size[1] % patch_size[1]) diff --git a/tests/models/swin/test_modeling_tf_swin.py b/tests/models/swin/test_modeling_tf_swin.py index dc06993f217e2a..aa8d425906b684 100644 --- a/tests/models/swin/test_modeling_tf_swin.py +++ b/tests/models/swin/test_modeling_tf_swin.py @@ -17,12 +17,11 @@ import inspect import unittest -import collections import numpy as np from transformers import SwinConfig -from transformers.testing_utils import require_tf, require_vision, slow +from transformers.testing_utils import require_tf, require_vision, slow, to_2tuple from transformers.utils import cached_property, is_tf_available, is_vision_available from ...test_configuration_common import ConfigTester @@ -40,12 +39,6 @@ ) -def to_2tuple(x): - if isinstance(x, collections.abc.Iterable): - return x - return (x, x) - - if is_vision_available(): from PIL import Image diff --git a/tests/models/vit/test_modeling_vit.py b/tests/models/vit/test_modeling_vit.py index 4e2ccccf70f4ee..f5bc6cc5d36ceb 100644 --- a/tests/models/vit/test_modeling_vit.py +++ b/tests/models/vit/test_modeling_vit.py @@ -132,6 +132,8 @@ def create_and_check_for_masked_image_modeling(self, config, pixel_values, label # test greyscale images config.num_channels = 1 model = ViTForMaskedImageModeling(config) + model.to(torch_device) + model.eval() pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) result = model(pixel_values) diff --git a/tests/models/vit_mae/test_modeling_tf_vit_mae.py b/tests/models/vit_mae/test_modeling_tf_vit_mae.py index 7ce6a80098baa4..06c965a4a930f3 100644 --- a/tests/models/vit_mae/test_modeling_tf_vit_mae.py +++ b/tests/models/vit_mae/test_modeling_tf_vit_mae.py @@ -38,7 +38,6 @@ import tensorflow as tf from transformers import TFViTMAEForPreTraining, TFViTMAEModel - from transformers.models.vit_mae.modeling_tf_vit_mae import to_2tuple if is_vision_available(): @@ -67,6 +66,7 @@ def __init__( type_sequence_label_size=10, initializer_range=0.02, num_labels=3, + 
mask_ratio=0.6, scope=None, ): self.parent = parent @@ -85,8 +85,14 @@ def __init__( self.attention_probs_dropout_prob = attention_probs_dropout_prob self.type_sequence_label_size = type_sequence_label_size self.initializer_range = initializer_range + self.mask_ratio = mask_ratio self.scope = scope + # in ViTMAE, the expected sequence length = (num_patches + 1) * (1 - config.mask_ratio), rounded above + # (we add 1 for the [CLS] token) + num_patches = (image_size // patch_size) ** 2 + self.seq_length = int(math.ceil((1 - mask_ratio) * (num_patches + 1))) + def prepare_config_and_inputs(self): pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) @@ -116,29 +122,21 @@ def get_config(self): attention_probs_dropout_prob=self.attention_probs_dropout_prob, is_decoder=False, initializer_range=self.initializer_range, + mask_ratio=self.mask_ratio, ) def create_and_check_model(self, config, pixel_values, labels): model = TFViTMAEModel(config=config) result = model(pixel_values, training=False) - # expected sequence length = (num_patches + 1) * (1 - config.mask_ratio), rounded above - # (we add 1 for the [CLS] token) - image_size = to_2tuple(self.image_size) - patch_size = to_2tuple(self.patch_size) - num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) - expected_seq_len = int(math.ceil((1 - config.mask_ratio) * (num_patches + 1))) - self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, self.hidden_size)) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) def create_and_check_for_pretraining(self, config, pixel_values, labels): model = TFViTMAEForPreTraining(config) result = model(pixel_values, training=False) # expected sequence length = num_patches - image_size = to_2tuple(self.image_size) - patch_size = to_2tuple(self.patch_size) - num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) - expected_seq_len = num_patches + num_patches = (self.image_size // self.patch_size) ** 2 expected_num_channels = self.patch_size**2 * self.num_channels - self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels)) + self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, expected_num_channels)) # test greyscale images config.num_channels = 1 @@ -179,7 +177,6 @@ def test_config(self): @unittest.skip(reason="ViTMAE does not use inputs_embeds") def test_inputs_embeds(self): - # ViTMAE does not use inputs_embeds pass def test_model_common_attributes(self): @@ -266,114 +263,6 @@ def prepare_numpy_arrays(inputs_dict): output_for_kw_input = model(**inputs_np, noise=noise) self.assert_outputs_same(output_for_dict_input, output_for_kw_input) - def test_attention_outputs(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.return_dict = True - - # in ViTMAE, the seq_len equals (number of patches + 1) * (1 - mask_ratio), rounded above - image_size = to_2tuple(self.model_tester.image_size) - patch_size = to_2tuple(self.model_tester.patch_size) - num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) - seq_len = int(math.ceil((1 - config.mask_ratio) * (num_patches + 1))) - encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) - encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) - chunk_length = getattr(self.model_tester, "chunk_length", 
None) - if chunk_length is not None and hasattr(self.model_tester, "num_hashes"): - encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes - - for model_class in self.all_model_classes: - inputs_dict["output_attentions"] = True - inputs_dict["output_hidden_states"] = False - config.return_dict = True - model = model_class(config) - outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False) - attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions - self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) - - # check that output_attentions also work using config - del inputs_dict["output_attentions"] - config.output_attentions = True - model = model_class(config) - outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False) - attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions - self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) - - if chunk_length is not None: - self.assertListEqual( - list(attentions[0].shape[-4:]), - [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length], - ) - else: - self.assertListEqual( - list(attentions[0].shape[-3:]), - [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], - ) - out_len = len(outputs) - - # Check attention is always last and order is fine - inputs_dict["output_attentions"] = True - inputs_dict["output_hidden_states"] = True - model = model_class(config) - outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False) - - if hasattr(self.model_tester, "num_hidden_states_types"): - added_hidden_states = self.model_tester.num_hidden_states_types - elif self.is_encoder_decoder: - added_hidden_states = 2 - else: - added_hidden_states = 1 - self.assertEqual(out_len + added_hidden_states, len(outputs)) - - self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions - - self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) - if chunk_length is not None: - self.assertListEqual( - list(self_attentions[0].shape[-4:]), - [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length], - ) - else: - self.assertListEqual( - list(self_attentions[0].shape[-3:]), - [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], - ) - - def test_hidden_states_output(self): - def check_hidden_states_output(inputs_dict, config, model_class): - model = model_class(config) - - outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - - hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states - - expected_num_layers = getattr( - self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1 - ) - self.assertEqual(len(hidden_states), expected_num_layers) - - # ViTMAE has a different seq_length - image_size = to_2tuple(self.model_tester.image_size) - patch_size = to_2tuple(self.model_tester.patch_size) - num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) - seq_length = int(math.ceil((1 - config.mask_ratio) * (num_patches + 1))) - - self.assertListEqual( - list(hidden_states[0].shape[-2:]), - [seq_length, self.model_tester.hidden_size], - ) - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - 
inputs_dict["output_hidden_states"] = True - check_hidden_states_output(inputs_dict, config, model_class) - - # check that output_hidden_states also work using config - del inputs_dict["output_hidden_states"] - config.output_hidden_states = True - - check_hidden_states_output(inputs_dict, config, model_class) - # overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise # to generate masks during test def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict): diff --git a/tests/models/vit_mae/test_modeling_vit_mae.py b/tests/models/vit_mae/test_modeling_vit_mae.py index 108aff0a1a2500..bddf2c115a136b 100644 --- a/tests/models/vit_mae/test_modeling_vit_mae.py +++ b/tests/models/vit_mae/test_modeling_vit_mae.py @@ -88,7 +88,7 @@ def __init__( # in ViTMAE, the expected sequence length = (num_patches + 1) * (1 - config.mask_ratio), rounded above # (we add 1 for the [CLS] token) - num_patches = (image_size // patch_size)**2 + num_patches = (image_size // patch_size) ** 2 self.seq_length = int(math.ceil((1 - mask_ratio) * (num_patches + 1))) def prepare_config_and_inputs(self): @@ -131,7 +131,7 @@ def create_and_check_for_pretraining(self, config, pixel_values, labels): model.to(torch_device) model.eval() result = model(pixel_values) - num_patches = (self.image_size // self.patch_size)**2 + num_patches = (self.image_size // self.patch_size) ** 2 expected_num_channels = self.patch_size**2 * self.num_channels self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, expected_num_channels)) diff --git a/tests/models/yolos/test_modeling_yolos.py b/tests/models/yolos/test_modeling_yolos.py index 8e67b4ce5fafc1..22cda4f7a81da2 100644 --- a/tests/models/yolos/test_modeling_yolos.py +++ b/tests/models/yolos/test_modeling_yolos.py @@ -86,7 +86,7 @@ def __init__( self.num_detection_tokens = num_detection_tokens # we set the expected sequence length (which is used in several tests) # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) + num_detection_tokens - num_patches = (image_size[0] // patch_size)**2 + num_patches = (image_size[0] // patch_size) ** 2 self.expected_seq_len = num_patches + 1 + self.num_detection_tokens def prepare_config_and_inputs(self): From 6993ecd5407720dd50d4bf22e76dd2ce03bfd67e Mon Sep 17 00:00:00 2001 From: NielsRogge Date: Sat, 18 Jun 2022 12:15:03 +0000 Subject: [PATCH 06/16] Fix copies --- src/transformers/models/deit/modeling_deit.py | 6 ++--- src/transformers/models/van/modeling_van.py | 23 ++++++++++--------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index be5c5ab97bf877..b73f9c36a38ff0 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -72,7 +72,7 @@ def __init__(self, config: DeiTConfig, use_mask_token: bool = False) -> None: self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) self.distillation_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None - self.patch_embeddings = PatchEmbeddings(config) + self.patch_embeddings = DeiTPatchEmbeddings(config) num_patches = self.patch_embeddings.num_patches self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size)) self.dropout = nn.Dropout(config.hidden_dropout_prob) @@ -95,7 +95,7 @@ def forward(self, pixel_values: 
torch.Tensor, bool_masked_pos: Optional[torch.Bo return embeddings -class PatchEmbeddings(nn.Module): +class DeiTPatchEmbeddings(nn.Module): """ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a @@ -472,7 +472,7 @@ def __init__(self, config: DeiTConfig, add_pooling_layer: bool = True, use_mask_ # Initialize weights and apply final processing self.post_init() - def get_input_embeddings(self) -> PatchEmbeddings: + def get_input_embeddings(self) -> DeiTPatchEmbeddings: return self.embeddings.patch_embeddings def _prune_heads(self, heads_to_prune): diff --git a/src/transformers/models/van/modeling_van.py b/src/transformers/models/van/modeling_van.py index 05de41360998dc..5e212d5f485d3e 100644 --- a/src/transformers/models/van/modeling_van.py +++ b/src/transformers/models/van/modeling_van.py @@ -54,23 +54,24 @@ ] -# Stochastic depth implementation -# Taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py +# Copied from transformers.models.convnext.modeling_convnext.drop_path def drop_path(input, drop_prob: float = 0.0, training: bool = False): """ - Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the - DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop - Connect' is a different form of dropout in a separate paper... See discussion: - https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and - argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument. + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. 
""" if drop_prob == 0.0 or not training: - return x + return input keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor + output = input.div(keep_prob) * random_tensor return output From 33b720a4b6f3266f097a27a3d8a8dab989031ace Mon Sep 17 00:00:00 2001 From: NielsRogge Date: Wed, 22 Jun 2022 14:03:06 +0000 Subject: [PATCH 07/16] Improve more models --- .../models/convnext/modeling_convnext.py | 6 ++ src/transformers/models/cvt/modeling_cvt.py | 4 +- .../data2vec/modeling_tf_data2vec_vision.py | 42 ++++++-------- .../models/levit/modeling_levit.py | 6 ++ .../models/poolformer/modeling_poolformer.py | 4 +- .../models/regnet/modeling_regnet.py | 10 +++- .../models/resnet/modeling_resnet.py | 16 ++++-- src/transformers/models/swin/modeling_swin.py | 4 +- .../models/vit/modeling_tf_vit.py | 55 +++++++++---------- .../models/vit_mae/modeling_tf_vit_mae.py | 4 -- 10 files changed, 80 insertions(+), 71 deletions(-) diff --git a/src/transformers/models/convnext/modeling_convnext.py b/src/transformers/models/convnext/modeling_convnext.py index 4ce040a30435de..e9274f1e54d111 100755 --- a/src/transformers/models/convnext/modeling_convnext.py +++ b/src/transformers/models/convnext/modeling_convnext.py @@ -127,8 +127,14 @@ def __init__(self, config): config.num_channels, config.hidden_sizes[0], kernel_size=config.patch_size, stride=config.patch_size ) self.layernorm = ConvNextLayerNorm(config.hidden_sizes[0], eps=1e-6, data_format="channels_first") + self.num_channels = config.num_channels def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) embeddings = self.patch_embeddings(pixel_values) embeddings = self.layernorm(embeddings) return embeddings diff --git a/src/transformers/models/cvt/modeling_cvt.py b/src/transformers/models/cvt/modeling_cvt.py index 2ef8bd7378c622..8f71ed5f8d9635 100644 --- a/src/transformers/models/cvt/modeling_cvt.py +++ b/src/transformers/models/cvt/modeling_cvt.py @@ -78,7 +78,7 @@ class BaseModelOutputWithCLSToken(ModelOutput): hidden_states: Optional[Tuple[torch.FloatTensor]] = None -# Copied from transformers.models.convnext.modeling_convnext.drop_path +# Copied from transformers.models.beit.modeling_beit.drop_path def drop_path(input, drop_prob: float = 0.0, training: bool = False): """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
@@ -99,7 +99,7 @@ def drop_path(input, drop_prob: float = 0.0, training: bool = False): return output -# Copied from transformers.models.convnext.modeling_convnext.ConvNextDropPath +# Copied from transformers.models.beit.modeling_beit.BeitDropPath class CvtDropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" diff --git a/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py b/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py index e7cc7d2449e75b..811580f4bbb2aa 100644 --- a/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py @@ -100,7 +100,7 @@ class TFData2VecVisionModelOutputWithPooling(TFBaseModelOutputWithPooling): attentions: Optional[Tuple[tf.Tensor]] = None -class TFDropPath(tf.keras.layers.Layer): +class TFData2VecVisionDropPath(tf.keras.layers.Layer): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). References: (1) github.com:rwightman/pytorch-image-models @@ -120,8 +120,6 @@ def call(self, x, training=None): return x -# Based on timm implementation, which can be found here: -# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py class TFData2VecVisionEmbeddings(tf.keras.layers.Layer): """ Construct the CLS token, position and patch embeddings. Optionally, also the mask token. @@ -132,9 +130,7 @@ def __init__(self, config: Data2VecVisionConfig, **kwargs): super().__init__(**kwargs) self.config = config - self.patch_embeddings = TFPatchEmbeddings( - config=config, image_size=config.image_size, patch_size=config.patch_size, name="patch_embeddings" - ) + self.patch_embeddings = TFData2VecVisionPatchEmbeddings(config, name="patch_embeddings") self.num_patches = self.patch_embeddings.num_patches self.config = config @@ -192,40 +188,36 @@ def call(self, pixel_values: tf.Tensor, bool_masked_pos: Optional[tf.Tensor] = N return embeddings -# Based on timm implementation, which can be found here: -# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py -class TFPatchEmbeddings(tf.keras.layers.Layer): +class TFData2VecVisionPatchEmbeddings(tf.keras.layers.Layer): """ Image to Patch Embedding. 
""" - def __init__(self, config: Data2VecVisionConfig, image_size: int = 224, patch_size: int = 16, **kwargs): + def __init__(self, config: Data2VecVisionConfig, **kwargs): super().__init__(**kwargs) self.config = config - image_size = ( - config.image_size - if isinstance(config.image_size, collections.abc.Iterable) - else (config.image_size, config.image_size) - ) - patch_size = ( - config.patch_size - if isinstance(config.patch_size, collections.abc.Iterable) - else (config.patch_size, config.patch_size) + image_size, patch_size, num_channels, hidden_size = ( + config.image_size, + config.patch_size, + config.num_channels, + config.hidden_size, ) + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) self.image_size = image_size self.patch_size = patch_size self.num_patches = num_patches self.patch_shape = patch_shape - self.num_channels = config.num_channels - self.embed_dim = config.hidden_size + self.num_channels = num_channels self.projection = tf.keras.layers.Conv2D( - filters=self.embed_dim, - kernel_size=self.patch_size, - strides=self.patch_size, + filters=hidden_size, + kernel_size=patch_size, + strides=patch_size, padding="valid", data_format="channels_last", kernel_initializer="glorot_uniform", # following torch.nn.Linear @@ -465,7 +457,7 @@ def __init__( # Using `layers.Activation` instead of `tf.identity` to better control `training` # behaviour. self.drop_path = ( - TFDropPath(drop_path_rate, name="drop_path") + TFData2VecVisionDropPath(drop_path_rate, name="drop_path") if drop_path_rate > 0.0 else tf.keras.layers.Activation("linear", name="drop_path") ) diff --git a/src/transformers/models/levit/modeling_levit.py b/src/transformers/models/levit/modeling_levit.py index b04a98317d7899..581edf7d7c6c95 100644 --- a/src/transformers/models/levit/modeling_levit.py +++ b/src/transformers/models/levit/modeling_levit.py @@ -126,8 +126,14 @@ def __init__(self, config): self.embedding_layer_4 = LevitConvEmbeddings( config.hidden_sizes[0] // 2, config.hidden_sizes[0], config.kernel_size, config.stride, config.padding ) + self.num_channels = config.num_channels def forward(self, pixel_values): + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) embeddings = self.embedding_layer_1(pixel_values) embeddings = self.activation_layer_1(embeddings) embeddings = self.embedding_layer_2(embeddings) diff --git a/src/transformers/models/poolformer/modeling_poolformer.py b/src/transformers/models/poolformer/modeling_poolformer.py index b4036b506f5a85..b53c482da47b8f 100755 --- a/src/transformers/models/poolformer/modeling_poolformer.py +++ b/src/transformers/models/poolformer/modeling_poolformer.py @@ -50,7 +50,7 @@ ] -# Copied from transformers.models.convnext.modeling_convnext.drop_path +# Copied from transformers.models.beit.modeling_beit.drop_path def drop_path(input, drop_prob: float = 0.0, training: bool = False): """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
@@ -71,7 +71,7 @@ def drop_path(input, drop_prob: float = 0.0, training: bool = False): return output -# Copied from transformers.models.convnext.modeling_convnext.ConvNextDropPath with ConvNext->PoolFormer +# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->PoolFormer class PoolFormerDropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" diff --git a/src/transformers/models/regnet/modeling_regnet.py b/src/transformers/models/regnet/modeling_regnet.py index 8d8098caf1ea14..0b343d6b3e4d13 100644 --- a/src/transformers/models/regnet/modeling_regnet.py +++ b/src/transformers/models/regnet/modeling_regnet.py @@ -93,9 +93,15 @@ def __init__(self, config: RegNetConfig): self.embedder = RegNetConvLayer( config.num_channels, config.embedding_size, kernel_size=3, stride=2, activation=config.hidden_act ) + self.num_channels = config.num_channels - def forward(self, hidden_state): - hidden_state = self.embedder(hidden_state) + def forward(self, pixel_values): + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + hidden_state = self.embedder(pixel_values) return hidden_state diff --git a/src/transformers/models/resnet/modeling_resnet.py b/src/transformers/models/resnet/modeling_resnet.py index 61ed3c98871589..d8804d960443df 100644 --- a/src/transformers/models/resnet/modeling_resnet.py +++ b/src/transformers/models/resnet/modeling_resnet.py @@ -81,9 +81,15 @@ def __init__(self, config: ResNetConfig): config.num_channels, config.embedding_size, kernel_size=7, stride=2, activation=config.hidden_act ) self.pooler = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.num_channels = config.num_channels - def forward(self, input: Tensor) -> Tensor: - embedding = self.embedder(input) + def forward(self, pixel_values: Tensor) -> Tensor: + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + embedding = self.embedder(pixel_values) embedding = self.pooler(embedding) return embedding @@ -107,7 +113,7 @@ def forward(self, input: Tensor) -> Tensor: class ResNetBasicLayer(nn.Module): """ - A classic ResNet's residual layer composed by a two `3x3` convolutions. + A classic ResNet's residual layer composed by two `3x3` convolutions. """ def __init__(self, in_channels: int, out_channels: int, stride: int = 1, activation: str = "relu"): @@ -133,10 +139,10 @@ def forward(self, hidden_state): class ResNetBottleNeckLayer(nn.Module): """ - A classic ResNet's bottleneck layer composed by a three `3x3` convolutions. + A classic ResNet's bottleneck layer composed by three `3x3` convolutions. The first `1x1` convolution reduces the input by a factor of `reduction` in order to make the second `3x3` - convolution faster. The last `1x1` convolution remap the reduced features to `out_channels`. + convolution faster. The last `1x1` convolution remaps the reduced features to `out_channels`. 
""" def __init__( diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index f8f9342eed694c..91ec499a70a74d 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -372,7 +372,7 @@ def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int] return input_feature -# Copied from transformers.models.convnext.modeling_convnext.drop_path +# Copied from transformers.models.beit.modeling_beit.drop_path def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True): """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). @@ -393,7 +393,7 @@ def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True): return output -# Copied from transformers.models.convnext.modeling_convnext.ConvNextDropPath with ConvNext->Swin +# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Swin class SwinDropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" diff --git a/src/transformers/models/vit/modeling_tf_vit.py b/src/transformers/models/vit/modeling_tf_vit.py index 46662596612595..ab07abdc120deb 100644 --- a/src/transformers/models/vit/modeling_tf_vit.py +++ b/src/transformers/models/vit/modeling_tf_vit.py @@ -52,19 +52,6 @@ _IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat" -# Inspired by -# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py -# From PyTorch internals -def to_2tuple(x): - if isinstance(x, collections.abc.Iterable): - return x - return (x, x) - - -# Based on timm implementation, which can be found here: -# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py - - class TFViTEmbeddings(tf.keras.layers.Layer): """ Construct the CLS token, position and patch embeddings. 
@@ -74,7 +61,7 @@ class TFViTEmbeddings(tf.keras.layers.Layer): def __init__(self, config: ViTConfig, **kwargs): super().__init__(**kwargs) - self.patch_embeddings = TFPatchEmbeddings(config, name="patch_embeddings") + self.patch_embeddings = TFViTPatchEmbeddings(config, name="patch_embeddings") self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob) self.config = config @@ -103,19 +90,21 @@ def interpolate_pos_encoding(self, embeddings, height, width) -> tf.Tensor: """ batch_size, seq_len, dim = shape_list(embeddings) - npatch = seq_len - 1 + num_patches = seq_len - 1 - _, N, _ = shape_list(self.position_embeddings) - N -= 1 + _, num_positions, _ = shape_list(self.position_embeddings) + num_positions -= 1 - if npatch == N and height == width: + if num_patches == num_positions and height == width: return self.position_embeddings class_pos_embed = self.position_embeddings[:, :1] patch_pos_embed = self.position_embeddings[:, 1:] h0 = height // self.config.patch_size w0 = width // self.config.patch_size patch_pos_embed = tf.image.resize( - images=tf.reshape(patch_pos_embed, shape=(1, int(math.sqrt(N)), int(math.sqrt(N)), dim)), + images=tf.reshape( + patch_pos_embed, shape=(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + ), size=(h0, w0), method="bicubic", ) @@ -150,27 +139,35 @@ def call( # Based on timm implementation, which can be found here: # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py -class TFPatchEmbeddings(tf.keras.layers.Layer): +class TFViTPatchEmbeddings(tf.keras.layers.Layer): """ - Image to Patch Embedding. + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. """ def __init__(self, config: ViTConfig, **kwargs): super().__init__(**kwargs) - image_size = to_2tuple(config.image_size) - patch_size = to_2tuple(config.patch_size) + image_size, patch_size, num_channels, hidden_size = ( + config.image_size, + config.patch_size, + config.num_channels, + config.hidden_size, + ) + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) self.image_size = image_size self.patch_size = patch_size self.num_patches = num_patches - self.num_channels = config.num_channels - self.embed_dim = config.hidden_size + self.num_channels = num_channels self.config = config self.projection = tf.keras.layers.Conv2D( - filters=self.embed_dim, + filters=hidden_size, kernel_size=patch_size, - strides=self.patch_size, + strides=patch_size, padding="valid", data_format="channels_last", use_bias=True, @@ -201,9 +198,9 @@ def call( # Change the 2D spatial dimensions to a single temporal dimension. 
# shape = (batch_size, num_patches, out_channels=embed_dim) num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0]) - x = tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1)) + embeddings = tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1)) - return x + return embeddings class TFViTSelfAttention(tf.keras.layers.Layer): diff --git a/src/transformers/models/vit_mae/modeling_tf_vit_mae.py b/src/transformers/models/vit_mae/modeling_tf_vit_mae.py index efb59b2217c2a2..747b440d3ac783 100644 --- a/src/transformers/models/vit_mae/modeling_tf_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_tf_vit_mae.py @@ -327,10 +327,6 @@ def __init__(self, config: ViTMAEConfig, **kwargs): def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor: batch_size, num_channels, height, width = shape_list(pixel_values) - if num_channels != self.num_channels: - raise ValueError( - "Make sure that the channel dimension of the pixel values match with the one set in the configuration." - ) if getattr(height, "numpy", None) and getattr(width, "numpy", None): if height != self.image_size[0] or width != self.image_size[1]: raise ValueError( From 96537772987fe027ebccbd64fe6dff4febf590e7 Mon Sep 17 00:00:00 2001 From: NielsRogge Date: Wed, 22 Jun 2022 14:19:17 +0000 Subject: [PATCH 08/16] Fix ViTMAE test --- tests/models/vit_mae/test_modeling_tf_vit_mae.py | 2 +- tests/models/vit_mae/test_modeling_vit_mae.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/vit_mae/test_modeling_tf_vit_mae.py b/tests/models/vit_mae/test_modeling_tf_vit_mae.py index 06c965a4a930f3..465b30c5cde766 100644 --- a/tests/models/vit_mae/test_modeling_tf_vit_mae.py +++ b/tests/models/vit_mae/test_modeling_tf_vit_mae.py @@ -145,7 +145,7 @@ def create_and_check_for_pretraining(self, config, pixel_values, labels): pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) result = model(pixel_values, training=False) expected_num_channels = self.patch_size**2 - self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels)) + self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, expected_num_channels)) def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() diff --git a/tests/models/vit_mae/test_modeling_vit_mae.py b/tests/models/vit_mae/test_modeling_vit_mae.py index bddf2c115a136b..5a48d253a385ee 100644 --- a/tests/models/vit_mae/test_modeling_vit_mae.py +++ b/tests/models/vit_mae/test_modeling_vit_mae.py @@ -143,7 +143,7 @@ def create_and_check_for_pretraining(self, config, pixel_values, labels): pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) result = model(pixel_values) expected_num_channels = self.patch_size**2 - self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels)) + self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, expected_num_channels)) def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() From 57ff8207c18541784dbeb133b0f3afac5afa6ce4 Mon Sep 17 00:00:00 2001 From: NielsRogge Date: Wed, 22 Jun 2022 15:12:59 +0000 Subject: [PATCH 09/16] Add channel check for TF models --- .../models/data2vec/modeling_tf_data2vec_vision.py | 4 ++++ src/transformers/models/vit/modeling_tf_vit.py | 4 ++++ src/transformers/models/vit_mae/modeling_tf_vit_mae.py | 4 ++++ 
 3 files changed, 12 insertions(+)
diff --git a/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py b/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py
index 811580f4bbb2aa..8b45add4c4f194 100644
--- a/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py
+++ b/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py
@@ -227,6 +227,10 @@ def __init__(self, config: Data2VecVisionConfig, **kwargs):
     def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
         batch_size, num_channels, height, width = shape_list(pixel_values)
+        if tf.executing_eagerly() and num_channels != self.num_channels:
+            raise ValueError(
+                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+            )
         if getattr(height, "numpy", None) and getattr(width, "numpy", None):
             if height != self.image_size[0] or width != self.image_size[1]:
                 raise ValueError(
diff --git a/src/transformers/models/vit/modeling_tf_vit.py b/src/transformers/models/vit/modeling_tf_vit.py
index ab07abdc120deb..da78193e96476b 100644
--- a/src/transformers/models/vit/modeling_tf_vit.py
+++ b/src/transformers/models/vit/modeling_tf_vit.py
@@ -180,6 +180,10 @@ def call(
         self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False
     ) -> tf.Tensor:
         batch_size, num_channels, height, width = shape_list(pixel_values)
+        if tf.executing_eagerly() and num_channels != self.num_channels:
+            raise ValueError(
+                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+            )
         if not interpolate_pos_encoding:
             if getattr(height, "numpy", None) and getattr(width, "numpy", None):
                 if height != self.image_size[0] or width != self.image_size[1]:
                     raise ValueError(
diff --git a/src/transformers/models/vit_mae/modeling_tf_vit_mae.py b/src/transformers/models/vit_mae/modeling_tf_vit_mae.py
index 747b440d3ac783..7c1dd4d4a08bea 100644
--- a/src/transformers/models/vit_mae/modeling_tf_vit_mae.py
+++ b/src/transformers/models/vit_mae/modeling_tf_vit_mae.py
@@ -327,6 +327,10 @@ def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
         batch_size, num_channels, height, width = shape_list(pixel_values)
+        if tf.executing_eagerly() and num_channels != self.num_channels:
+            raise ValueError(
+                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ ) if getattr(height, "numpy", None) and getattr(width, "numpy", None): if height != self.image_size[0] or width != self.image_size[1]: raise ValueError( From e10b6cac6a37403c76b74bfb6bb4ba4ce9803b09 Mon Sep 17 00:00:00 2001 From: NielsRogge Date: Thu, 23 Jun 2022 09:57:09 +0000 Subject: [PATCH 10/16] Add proper channel check for TF models --- src/transformers/models/beit/modeling_beit.py | 4 +++- .../models/convnext/modeling_tf_convnext.py | 9 +++++++++ .../models/data2vec/modeling_tf_data2vec_vision.py | 11 ++++++----- src/transformers/models/swin/modeling_tf_swin.py | 2 +- src/transformers/models/vit/modeling_tf_vit.py | 2 +- .../models/vit_mae/modeling_tf_vit_mae.py | 11 ++++++----- 6 files changed, 26 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 2199be718a2222..b266a2d6096387 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -171,7 +171,9 @@ def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Bo class BeitPatchEmbeddings(nn.Module): """ - Image to Patch Embedding. + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. """ def __init__(self, config): diff --git a/src/transformers/models/convnext/modeling_tf_convnext.py b/src/transformers/models/convnext/modeling_tf_convnext.py index 3446925a072c89..58f4c3bba98489 100644 --- a/src/transformers/models/convnext/modeling_tf_convnext.py +++ b/src/transformers/models/convnext/modeling_tf_convnext.py @@ -20,6 +20,8 @@ import numpy as np import tensorflow as tf +from transformers import shape_list + from ...activations_tf import get_tf_activation from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling, TFSequenceClassifierOutput from ...modeling_tf_utils import ( @@ -77,11 +79,18 @@ def __init__(self, config, **kwargs): bias_initializer="zeros", ) self.layernorm = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="layernorm") + self.num_channels = config.num_channels def call(self, pixel_values): if isinstance(pixel_values, dict): pixel_values = pixel_values["pixel_values"] + num_channels = shape_list(pixel_values)[1] + if tf.executing_eagerly() and num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + # When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format. # So change the input format from `NCHW` to `NHWC`. # shape = (batch_size, in_height, in_width, in_channels=num_channels) diff --git a/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py b/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py index 8b45add4c4f194..b325b5c515502b 100644 --- a/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py @@ -227,11 +227,12 @@ def __init__(self, config: Data2VecVisionConfig, **kwargs): def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor: batch_size, num_channels, height, width = shape_list(pixel_values) - if tf.executing_eagerly() and num_channels != self.num_channels: - raise ValueError( - "Make sure that the channel dimension of the pixel values match with the one set in the configuration." 
- ) - if getattr(height, "numpy", None) and getattr(width, "numpy", None): + if tf.executing_eagerly(): + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the" + " configuration." + ) if height != self.image_size[0] or width != self.image_size[1]: raise ValueError( f"Input image size ({height}*{width}) doesn't match model" diff --git a/src/transformers/models/swin/modeling_tf_swin.py b/src/transformers/models/swin/modeling_tf_swin.py index 557b057f7caedc..22166fccd46af5 100644 --- a/src/transformers/models/swin/modeling_tf_swin.py +++ b/src/transformers/models/swin/modeling_tf_swin.py @@ -352,7 +352,7 @@ def maybe_pad(self, pixel_values: tf.Tensor, height: int, width: int) -> tf.Tens def call(self, pixel_values: tf.Tensor, training: bool = False) -> Tuple[tf.Tensor, Tuple[int, int]]: _, num_channels, height, width = shape_list(pixel_values) - if num_channels != self.num_channels: + if tf.executing_eagerly() and num_channels != self.num_channels: raise ValueError( "Make sure that the channel dimension of the pixel values match with the one set in the configuration." ) diff --git a/src/transformers/models/vit/modeling_tf_vit.py b/src/transformers/models/vit/modeling_tf_vit.py index da78193e96476b..327bfd7a34cd23 100644 --- a/src/transformers/models/vit/modeling_tf_vit.py +++ b/src/transformers/models/vit/modeling_tf_vit.py @@ -185,7 +185,7 @@ def call( "Make sure that the channel dimension of the pixel values match with the one set in the configuration." ) if not interpolate_pos_encoding: - if getattr(height, "numpy", None) and getattr(width, "numpy", None): + if tf.executing_eagerly(): if height != self.image_size[0] or width != self.image_size[1]: raise ValueError( f"Input image size ({height}*{width}) doesn't match model" diff --git a/src/transformers/models/vit_mae/modeling_tf_vit_mae.py b/src/transformers/models/vit_mae/modeling_tf_vit_mae.py index 7c1dd4d4a08bea..952f3d22a28b31 100644 --- a/src/transformers/models/vit_mae/modeling_tf_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_tf_vit_mae.py @@ -327,11 +327,12 @@ def __init__(self, config: ViTMAEConfig, **kwargs): def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor: batch_size, num_channels, height, width = shape_list(pixel_values) - if tf.executing_eagerly() and num_channels != self.num_channels: - raise ValueError( - "Make sure that the channel dimension of the pixel values match with the one set in the configuration." - ) - if getattr(height, "numpy", None) and getattr(width, "numpy", None): + if tf.executing_eagerly(): + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the" + " configuration." 
+ ) if height != self.image_size[0] or width != self.image_size[1]: raise ValueError( f"Input image size ({height}*{width}) doesn't match model" From 4b947b3552db47f6549ccd7cb75441680cf3bb8d Mon Sep 17 00:00:00 2001 From: NielsRogge Date: Thu, 23 Jun 2022 10:15:38 +0000 Subject: [PATCH 11/16] Apply suggestion from code review --- .../data2vec/modeling_data2vec_vision.py | 4 ++- .../models/swin/modeling_tf_swin.py | 6 ++--- tests/models/swin/test_modeling_tf_swin.py | 27 ++++++++++++++++--- 3 files changed, 29 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index 0cf0aae80d6649..6173226da9f683 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -173,7 +173,9 @@ def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Bo # Copied from transformers.models.beit.modeling_beit.BeitPatchEmbeddings with Beit->Data2VecVision class Data2VecVisionPatchEmbeddings(nn.Module): """ - Image to Patch Embedding. + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. """ def __init__(self, config): diff --git a/src/transformers/models/swin/modeling_tf_swin.py b/src/transformers/models/swin/modeling_tf_swin.py index 22166fccd46af5..25fb06ae101028 100644 --- a/src/transformers/models/swin/modeling_tf_swin.py +++ b/src/transformers/models/swin/modeling_tf_swin.py @@ -1255,7 +1255,7 @@ class TFSwinDecoder(tf.keras.layers.Layer): def __init__(self, config: SwinConfig, **kwargs): super().__init__(**kwargs) self.conv2d = tf.keras.layers.Conv2D( - filters=config.encoder_stride**2 * 3, kernel_size=1, strides=1, name="0" + filters=config.encoder_stride**2 * config.num_channels, kernel_size=1, strides=1, name="0" ) self._block_size = config.encoder_stride self.pixel_shuffle = PixelShuffle(self._block_size, name="1") @@ -1283,8 +1283,8 @@ def call(self, x: tf.Tensor) -> tf.Tensor: @add_start_docstrings( - "Swin Model with a decoder on top for masked image modeling, as proposed in `SimMIM" - " `__.", + "Swin Model with a decoder on top for masked image modeling, as proposed in" + " [SimMIM](https://arxiv.org/abs/2111.09886).", SWIN_START_DOCSTRING, ) class TFSwinForMaskedImageModeling(TFSwinPreTrainedModel): diff --git a/tests/models/swin/test_modeling_tf_swin.py b/tests/models/swin/test_modeling_tf_swin.py index aa8d425906b684..fba50375a03f71 100644 --- a/tests/models/swin/test_modeling_tf_swin.py +++ b/tests/models/swin/test_modeling_tf_swin.py @@ -140,6 +140,21 @@ def create_and_check_model(self, config, pixel_values, labels): self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim)) + def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels): + model = TFSwinForMaskedImageModeling(config=config) + result = model(pixel_values) + self.parent.assertEqual( + result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size) + ) + + # test greyscale images + config.num_channels = 1 + model = TFSwinForMaskedImageModeling(config) + + pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) + result = model(pixel_values) + self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, 
self.image_size)) + def create_and_check_for_image_classification(self, config, pixel_values, labels): config.num_labels = self.type_sequence_label_size model = TFSwinForImageClassification(config) @@ -191,6 +206,14 @@ def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + def test_for_masked_image_modeling(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs) + + def test_for_image_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_image_classification(*config_and_inputs) + @unittest.skip(reason="Swin does not use inputs_embeds") def test_inputs_embeds(self): pass @@ -335,10 +358,6 @@ def test_inputs_requiring_padding(self): config.output_hidden_states = True self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width)) - def test_for_image_classification(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_for_image_classification(*config_and_inputs) - @slow def test_model_from_pretrained(self): for model_name in TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: From 1d28d9494c8340e6f783b0d62f12c4a4692bd9cb Mon Sep 17 00:00:00 2001 From: NielsRogge Date: Thu, 23 Jun 2022 11:11:06 +0000 Subject: [PATCH 12/16] Apply suggestions from code review --- src/transformers/models/beit/modeling_beit.py | 8 ++------ .../models/data2vec/modeling_data2vec_vision.py | 8 ++------ .../models/data2vec/modeling_tf_data2vec_vision.py | 8 ++------ src/transformers/models/deit/modeling_deit.py | 8 ++------ src/transformers/models/dpt/modeling_dpt.py | 8 ++------ src/transformers/models/maskformer/modeling_maskformer.py | 8 ++------ src/transformers/models/swin/modeling_swin.py | 8 ++------ src/transformers/models/swin/modeling_tf_swin.py | 8 ++------ src/transformers/models/vilt/modeling_vilt.py | 8 ++------ src/transformers/models/vit/modeling_tf_vit.py | 8 ++------ src/transformers/models/vit/modeling_vit.py | 8 ++------ src/transformers/models/vit_mae/modeling_tf_vit_mae.py | 8 ++------ src/transformers/models/vit_mae/modeling_vit_mae.py | 8 ++------ src/transformers/models/yolos/modeling_yolos.py | 8 ++------ 14 files changed, 28 insertions(+), 84 deletions(-) diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index b266a2d6096387..bf4b82da956c6b 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -178,12 +178,8 @@ class BeitPatchEmbeddings(nn.Module): def __init__(self, config): super().__init__() - image_size, patch_size, num_channels, hidden_size = ( - config.image_size, - config.patch_size, - config.num_channels, - config.hidden_size, - ) + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index 6173226da9f683..c92d9c7452f651 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ 
b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -180,12 +180,8 @@ class Data2VecVisionPatchEmbeddings(nn.Module): def __init__(self, config): super().__init__() - image_size, patch_size, num_channels, hidden_size = ( - config.image_size, - config.patch_size, - config.num_channels, - config.hidden_size, - ) + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) diff --git a/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py b/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py index b325b5c515502b..aeb5143426d5ee 100644 --- a/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py @@ -197,12 +197,8 @@ def __init__(self, config: Data2VecVisionConfig, **kwargs): super().__init__(**kwargs) self.config = config - image_size, patch_size, num_channels, hidden_size = ( - config.image_size, - config.patch_size, - config.num_channels, - config.hidden_size, - ) + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index b73f9c36a38ff0..8f8307499fa479 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -104,12 +104,8 @@ class DeiTPatchEmbeddings(nn.Module): def __init__(self, config): super().__init__() - image_size, patch_size, num_channels, hidden_size = ( - config.image_size, - config.patch_size, - config.num_channels, - config.hidden_size, - ) + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index c9700ac44fa501..7dfa244805ff77 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -128,12 +128,8 @@ class DPTViTPatchEmbeddings(nn.Module): def __init__(self, config): super().__init__() - image_size, patch_size, num_channels, hidden_size = ( - config.image_size, - config.patch_size, - config.num_channels, - config.hidden_size, - ) + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index 0616adcdfaafab..2932ee6f73de0c 100644 --- 
a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -555,12 +555,8 @@ class MaskFormerSwinPatchEmbeddings(nn.Module): def __init__(self, config): super().__init__() - image_size, patch_size, num_channels, hidden_size = ( - config.image_size, - config.patch_size, - config.num_channels, - config.embed_dim, - ) + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.embed_dim image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 91ec499a70a74d..b983b6d5ae8632 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -276,12 +276,8 @@ class SwinPatchEmbeddings(nn.Module): def __init__(self, config): super().__init__() - image_size, patch_size, num_channels, hidden_size = ( - config.image_size, - config.patch_size, - config.num_channels, - config.embed_dim, - ) + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.embed_dim image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) diff --git a/src/transformers/models/swin/modeling_tf_swin.py b/src/transformers/models/swin/modeling_tf_swin.py index 25fb06ae101028..886b3b1b478f50 100644 --- a/src/transformers/models/swin/modeling_tf_swin.py +++ b/src/transformers/models/swin/modeling_tf_swin.py @@ -318,12 +318,8 @@ class TFSwinPatchEmbeddings(tf.keras.layers.Layer): def __init__(self, config, **kwargs): super().__init__(**kwargs) - image_size, patch_size, num_channels, hidden_size = ( - config.image_size, - config.patch_size, - config.num_channels, - config.embed_dim, - ) + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.embed_dim image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index 77e5487c8cbfa3..09672ef9afb638 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -299,12 +299,8 @@ class ViltPatchEmbeddings(nn.Module): def __init__(self, config): super().__init__() - image_size, patch_size, num_channels, hidden_size = ( - config.image_size, - config.patch_size, - config.num_channels, - config.hidden_size, - ) + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) diff --git a/src/transformers/models/vit/modeling_tf_vit.py 
b/src/transformers/models/vit/modeling_tf_vit.py index 327bfd7a34cd23..1db9cf58032d0f 100644 --- a/src/transformers/models/vit/modeling_tf_vit.py +++ b/src/transformers/models/vit/modeling_tf_vit.py @@ -148,12 +148,8 @@ class TFViTPatchEmbeddings(tf.keras.layers.Layer): def __init__(self, config: ViTConfig, **kwargs): super().__init__(**kwargs) - image_size, patch_size, num_channels, hidden_size = ( - config.image_size, - config.patch_size, - config.num_channels, - config.hidden_size, - ) + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 5b63a65de2c5dd..89d1514dd674ce 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -148,12 +148,8 @@ class ViTPatchEmbeddings(nn.Module): def __init__(self, config): super().__init__() - image_size, patch_size, num_channels, hidden_size = ( - config.image_size, - config.patch_size, - config.num_channels, - config.hidden_size, - ) + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) diff --git a/src/transformers/models/vit_mae/modeling_tf_vit_mae.py b/src/transformers/models/vit_mae/modeling_tf_vit_mae.py index 952f3d22a28b31..d5fbecabd62dc7 100644 --- a/src/transformers/models/vit_mae/modeling_tf_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_tf_vit_mae.py @@ -299,12 +299,8 @@ class TFViTMAEPatchEmbeddings(tf.keras.layers.Layer): def __init__(self, config: ViTMAEConfig, **kwargs): super().__init__(**kwargs) - image_size, patch_size, num_channels, hidden_size = ( - config.image_size, - config.patch_size, - config.num_channels, - config.hidden_size, - ) + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index 34403787d99a69..0667bdd73c5545 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -288,12 +288,8 @@ class ViTMAEPatchEmbeddings(nn.Module): def __init__(self, config): super().__init__() - image_size, patch_size, num_channels, hidden_size = ( - config.image_size, - config.patch_size, - config.num_channels, - config.hidden_size, - ) + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) 
else (patch_size, patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index af6dc35ab00ce4..0d640212c21be2 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -225,12 +225,8 @@ class YolosPatchEmbeddings(nn.Module): def __init__(self, config): super().__init__() - image_size, patch_size, num_channels, hidden_size = ( - config.image_size, - config.patch_size, - config.num_channels, - config.hidden_size, - ) + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) From 84ce11828fc9b9c22a90955718fc58fd281f91d6 Mon Sep 17 00:00:00 2001 From: NielsRogge Date: Thu, 23 Jun 2022 11:30:41 +0000 Subject: [PATCH 13/16] Add channel check for Flax models, apply suggestion --- .../models/beit/modeling_flax_beit.py | 6 ++++++ src/transformers/models/vit/modeling_flax_vit.py | 16 +++++++++++----- tests/models/yolos/test_modeling_yolos.py | 2 +- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/beit/modeling_flax_beit.py b/src/transformers/models/beit/modeling_flax_beit.py index b8ef84c0cf6b42..94e6eac7a32e0f 100644 --- a/src/transformers/models/beit/modeling_flax_beit.py +++ b/src/transformers/models/beit/modeling_flax_beit.py @@ -171,6 +171,7 @@ class FlaxBeitPatchEmbeddings(nn.Module): dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): + self.num_channels = self.config.num_channels image_size = self.config.image_size patch_size = self.config.patch_size num_patches = (image_size // patch_size) * (image_size // patch_size) @@ -187,6 +188,11 @@ def setup(self): ) def __call__(self, pixel_values): + num_channels = pixel_values.shape[-1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." 
+ ) embeddings = self.projection(pixel_values) batch_size, _, _, channels = embeddings.shape return jnp.reshape(embeddings, (batch_size, -1, channels)) diff --git a/src/transformers/models/vit/modeling_flax_vit.py b/src/transformers/models/vit/modeling_flax_vit.py index f6e7044057361b..b33114b7a4018c 100644 --- a/src/transformers/models/vit/modeling_flax_vit.py +++ b/src/transformers/models/vit/modeling_flax_vit.py @@ -84,7 +84,7 @@ """ -class FlaxPatchEmbeddings(nn.Module): +class FlaxViTPatchEmbeddings(nn.Module): config: ViTConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation @@ -94,6 +94,7 @@ def setup(self): patch_size = self.config.patch_size num_patches = (image_size // patch_size) * (image_size // patch_size) self.num_patches = num_patches + self.num_channels = self.config.num_channels self.projection = nn.Conv( self.config.hidden_size, kernel_size=(patch_size, patch_size), @@ -104,9 +105,14 @@ def setup(self): ) def __call__(self, pixel_values): - x = self.projection(pixel_values) - batch_size, _, _, channels = x.shape - return jnp.reshape(x, (batch_size, -1, channels)) + num_channels = pixel_values.shape[-1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + embeddings = self.projection(pixel_values) + batch_size, _, _, channels = embeddings.shape + return jnp.reshape(embeddings, (batch_size, -1, channels)) class FlaxViTEmbeddings(nn.Module): @@ -117,7 +123,7 @@ class FlaxViTEmbeddings(nn.Module): def setup(self): self.cls_token = self.param("cls_token", nn.initializers.zeros, (1, 1, self.config.hidden_size)) - self.patch_embeddings = FlaxPatchEmbeddings(self.config, dtype=self.dtype) + self.patch_embeddings = FlaxViTPatchEmbeddings(self.config, dtype=self.dtype) num_patches = self.patch_embeddings.num_patches self.position_embeddings = self.param( "position_embeddings", nn.initializers.zeros, (1, num_patches + 1, self.config.hidden_size) diff --git a/tests/models/yolos/test_modeling_yolos.py b/tests/models/yolos/test_modeling_yolos.py index 22cda4f7a81da2..7f8693aa607118 100644 --- a/tests/models/yolos/test_modeling_yolos.py +++ b/tests/models/yolos/test_modeling_yolos.py @@ -86,7 +86,7 @@ def __init__( self.num_detection_tokens = num_detection_tokens # we set the expected sequence length (which is used in several tests) # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) + num_detection_tokens - num_patches = (image_size[0] // patch_size) ** 2 + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) self.expected_seq_len = num_patches + 1 + self.num_detection_tokens def prepare_config_and_inputs(self): From aee248b60f79fda5f5c057c02aeffa2bd11c7695 Mon Sep 17 00:00:00 2001 From: NielsRogge Date: Thu, 23 Jun 2022 14:27:39 +0000 Subject: [PATCH 14/16] Fix bug --- tests/models/yolos/test_modeling_yolos.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/yolos/test_modeling_yolos.py b/tests/models/yolos/test_modeling_yolos.py index 7f8693aa607118..1d07e50ce7b20f 100644 --- a/tests/models/yolos/test_modeling_yolos.py +++ b/tests/models/yolos/test_modeling_yolos.py @@ -86,7 +86,7 @@ def __init__( self.num_detection_tokens = num_detection_tokens # we set the expected sequence length (which is used in several tests) # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) + num_detection_tokens - num_patches = (image_size[1] // patch_size[1]) * 
(image_size[0] // patch_size[0]) + num_patches = (image_size[1] // patch_size) * (image_size[0] // patch_size) self.expected_seq_len = num_patches + 1 + self.num_detection_tokens def prepare_config_and_inputs(self): From 465ab11ad8dcb84d94a663497189a835901ff3ed Mon Sep 17 00:00:00 2001 From: NielsRogge Date: Fri, 24 Jun 2022 08:30:15 +0000 Subject: [PATCH 15/16] Add tests for greyscale images --- .../models/beit/modeling_flax_beit.py | 2 +- .../models/vit/modeling_flax_vit.py | 2 +- tests/models/beit/test_modeling_beit.py | 10 +++++++ tests/models/beit/test_modeling_flax_beit.py | 7 +++++ tests/models/deit/test_modeling_deit.py | 10 +++++++ tests/models/swin/test_modeling_swin.py | 10 +++++++ tests/models/swin/test_modeling_tf_swin.py | 7 +++++ tests/models/vit/test_modeling_flax_vit.py | 26 ++++++++++++++++--- tests/models/vit/test_modeling_tf_vit.py | 7 +++++ tests/models/vit/test_modeling_vit.py | 10 +++++++ 10 files changed, 86 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/beit/modeling_flax_beit.py b/src/transformers/models/beit/modeling_flax_beit.py index 94e6eac7a32e0f..225fb280af4a01 100644 --- a/src/transformers/models/beit/modeling_flax_beit.py +++ b/src/transformers/models/beit/modeling_flax_beit.py @@ -609,7 +609,7 @@ def __init__( ): module = self.module_class(config=config, dtype=dtype, **kwargs) if input_shape is None: - input_shape = (1, config.image_size, config.image_size, 3) + input_shape = (1, config.image_size, config.image_size, config.num_channels) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: diff --git a/src/transformers/models/vit/modeling_flax_vit.py b/src/transformers/models/vit/modeling_flax_vit.py index b33114b7a4018c..7a438abb032938 100644 --- a/src/transformers/models/vit/modeling_flax_vit.py +++ b/src/transformers/models/vit/modeling_flax_vit.py @@ -426,7 +426,7 @@ def __init__( ): module = self.module_class(config=config, dtype=dtype, **kwargs) if input_shape is None: - input_shape = (1, config.image_size, config.image_size, 3) + input_shape = (1, config.image_size, config.image_size, config.num_channels) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: diff --git a/tests/models/beit/test_modeling_beit.py b/tests/models/beit/test_modeling_beit.py index 8c9202c34b1025..b06bd5c21e4059 100644 --- a/tests/models/beit/test_modeling_beit.py +++ b/tests/models/beit/test_modeling_beit.py @@ -153,6 +153,16 @@ def create_and_check_for_image_classification(self, config, pixel_values, labels result = model(pixel_values, labels=labels) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) + # test greyscale images + config.num_channels = 1 + model = BeitForImageClassification(config) + model.to(torch_device) + model.eval() + + pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) + result = model(pixel_values, labels=labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) + def create_and_check_for_semantic_segmentation(self, config, pixel_values, labels, pixel_labels): config.num_labels = self.num_labels model = BeitForSemanticSegmentation(config) diff --git a/tests/models/beit/test_modeling_flax_beit.py 
b/tests/models/beit/test_modeling_flax_beit.py index ff10e536bf48a0..b37dd5bf36b403 100644 --- a/tests/models/beit/test_modeling_flax_beit.py +++ b/tests/models/beit/test_modeling_flax_beit.py @@ -120,6 +120,13 @@ def create_and_check_for_image_classification(self, config, pixel_values, labels result = model(pixel_values) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) + # test greyscale images + config.num_channels = 1 + model = FlaxBeitForImageClassification(config) + + pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) + result = model(pixel_values) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( diff --git a/tests/models/deit/test_modeling_deit.py b/tests/models/deit/test_modeling_deit.py index b163294cedf8c4..27f92c2d976a61 100644 --- a/tests/models/deit/test_modeling_deit.py +++ b/tests/models/deit/test_modeling_deit.py @@ -158,6 +158,16 @@ def create_and_check_for_image_classification(self, config, pixel_values, labels result = model(pixel_values, labels=labels) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) + # test greyscale images + config.num_channels = 1 + model = DeiTForImageClassification(config) + model.to(torch_device) + model.eval() + + pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) + result = model(pixel_values, labels=labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( diff --git a/tests/models/swin/test_modeling_swin.py b/tests/models/swin/test_modeling_swin.py index 33080ed47ffd50..5e07efa2a3dc00 100644 --- a/tests/models/swin/test_modeling_swin.py +++ b/tests/models/swin/test_modeling_swin.py @@ -169,6 +169,16 @@ def create_and_check_for_image_classification(self, config, pixel_values, labels result = model(pixel_values, labels=labels) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) + # test greyscale images + config.num_channels = 1 + model = SwinForImageClassification(config) + model.to(torch_device) + model.eval() + + pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) + result = model(pixel_values) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( diff --git a/tests/models/swin/test_modeling_tf_swin.py b/tests/models/swin/test_modeling_tf_swin.py index fba50375a03f71..94ca1ac2ba86bd 100644 --- a/tests/models/swin/test_modeling_tf_swin.py +++ b/tests/models/swin/test_modeling_tf_swin.py @@ -161,6 +161,13 @@ def create_and_check_for_image_classification(self, config, pixel_values, labels result = model(pixel_values, labels=labels) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) + # test greyscale images + config.num_channels = 1 + model = TFSwinForImageClassification(config) + pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) + result = model(pixel_values) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() config, pixel_values, labels = 
config_and_inputs diff --git a/tests/models/vit/test_modeling_flax_vit.py b/tests/models/vit/test_modeling_flax_vit.py index 56fe28d41bafd8..611f9364885450 100644 --- a/tests/models/vit/test_modeling_flax_vit.py +++ b/tests/models/vit/test_modeling_flax_vit.py @@ -91,8 +91,7 @@ def prepare_config_and_inputs(self): return config, pixel_values - def create_and_check_model(self, config, pixel_values, labels): - + def create_and_check_model(self, config, pixel_values): model = FlaxViTModel(config=config) result = model(pixel_values) # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) @@ -101,6 +100,19 @@ def create_and_check_model(self, config, pixel_values, labels): num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size)) + def create_and_check_for_image_classification(self, config, pixel_values): + config.num_labels = self.type_sequence_label_size + model = FlaxViTForImageClassification(config=config) + result = model(pixel_values) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) + + # test greyscale images + config.num_channels = 1 + model = FlaxViTForImageClassification(config) + + pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) + result = model(pixel_values) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( @@ -123,7 +135,15 @@ def setUp(self) -> None: def test_config(self): self.config_tester.run_common_tests() - # We neeed to override this test because ViT's forward signature is different than text models. + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_for_image_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_image_classification(*config_and_inputs) + + # We need to override this test because ViT's forward signature is different than text models. 
def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/vit/test_modeling_tf_vit.py b/tests/models/vit/test_modeling_tf_vit.py index 096558091ac820..7f452886f150a3 100644 --- a/tests/models/vit/test_modeling_tf_vit.py +++ b/tests/models/vit/test_modeling_tf_vit.py @@ -133,6 +133,13 @@ def create_and_check_for_image_classification(self, config, pixel_values, labels result = model(pixel_values, interpolate_pos_encoding=True, training=False) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) + # test greyscale images + config.num_channels = 1 + model = TFViTForImageClassification(config) + pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) + result = model(pixel_values) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() config, pixel_values, labels = config_and_inputs diff --git a/tests/models/vit/test_modeling_vit.py b/tests/models/vit/test_modeling_vit.py index f5bc6cc5d36ceb..21f26f773735b4 100644 --- a/tests/models/vit/test_modeling_vit.py +++ b/tests/models/vit/test_modeling_vit.py @@ -147,6 +147,16 @@ def create_and_check_for_image_classification(self, config, pixel_values, labels result = model(pixel_values, labels=labels) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) + # test greyscale images + config.num_channels = 1 + model = ViTForImageClassification(config) + model.to(torch_device) + model.eval() + + pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) + result = model(pixel_values) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( From 08efa9de7c2f4f866f39d4607c3333a2fa92de60 Mon Sep 17 00:00:00 2001 From: NielsRogge Date: Fri, 24 Jun 2022 09:12:05 +0000 Subject: [PATCH 16/16] Add test for interpolation of pos encodigns --- src/transformers/models/vit/modeling_vit.py | 6 ++--- tests/models/vit/test_modeling_vit.py | 27 +++++++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 89d1514dd674ce..7017f232f0e9c7 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -96,10 +96,10 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: # we add a small number to avoid floating point error in the interpolation # see discussion at https://github.com/facebookresearch/dino/issues/8 h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) patch_pos_embed = nn.functional.interpolate( - patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim).permute( - 0, 3, 1, 2 - ), + patch_pos_embed, scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), mode="bicubic", align_corners=False, diff --git a/tests/models/vit/test_modeling_vit.py b/tests/models/vit/test_modeling_vit.py index 21f26f773735b4..5f856436f3c0b4 100644 --- a/tests/models/vit/test_modeling_vit.py +++ 
b/tests/models/vit/test_modeling_vit.py @@ -273,3 +273,30 @@ def test_inference_image_classification_head(self): expected_slice = torch.tensor([-0.2744, 0.8215, -0.0836]).to(torch_device) self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) + + @slow + def test_inference_interpolate_pos_encoding(self): + # ViT models have an `interpolate_pos_encoding` argument in their forward method, + # allowing to interpolate the pre-trained position embeddings in order to use + # the model on higher resolutions. The DINO model by Facebook AI leverages this + # to visualize self-attention on higher resolution images. + model = ViTModel.from_pretrained("facebook/dino-vits8").to(torch_device) + + feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/dino-vits8", size=480) + image = prepare_img() + inputs = feature_extractor(images=image, return_tensors="pt") + pixel_values = inputs.pixel_values.to(torch_device) + + # forward pass + with torch.no_grad(): + outputs = model(pixel_values, interpolate_pos_encoding=True) + + # verify the logits + expected_shape = torch.Size((1, 3601, 384)) + self.assertEqual(outputs.last_hidden_state.shape, expected_shape) + + expected_slice = torch.tensor( + [[4.2340, 4.3906, -6.6692], [4.5463, 1.8928, -6.7257], [4.4429, 0.8496, -5.8585]] + ).to(torch_device) + + self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
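A minimal sketch of what the changes above enable once applied: a ViT configured for greyscale input accepts single-channel pixel values, and a channel mismatch now fails fast with a ValueError from the patch embedding layer instead of erroring inside the projection convolution. The toy configuration values below are illustrative assumptions and are not taken from the test suite.

import torch
from transformers import ViTConfig, ViTModel

# Assumed tiny configuration for illustration only.
config = ViTConfig(
    image_size=32,
    patch_size=8,
    num_channels=1,          # greyscale input, as exercised by the new tests
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
)
model = ViTModel(config)
model.eval()

# expected sequence length = (32 // 8) * (32 // 8) patches + 1 for the [CLS] token = 17
pixel_values = torch.randn(1, 1, 32, 32)
with torch.no_grad():
    outputs = model(pixel_values)
print(outputs.last_hidden_state.shape)  # torch.Size([1, 17, 32])

# A 3-channel input to this 1-channel model is rejected by the channel check
# before it ever reaches the projection.
try:
    model(torch.randn(1, 3, 32, 32))
except ValueError as err:
    print(err)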