Improve vision models (#17731)
* Improve vision models

* Add a lot of improvements

* Remove to_2tuple from swin tests

* Fix TF Swin

* Fix more tests

* Fix copies

* Improve more models

* Fix ViTMAE test

* Add channel check for TF models

* Add proper channel check for TF models

* Apply suggestion from code review

* Apply suggestions from code review

* Add channel check for Flax models, apply suggestion

* Fix bug

* Add tests for greyscale images

* Add test for interpolation of pos encodings

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
NielsRogge and Niels Rogge authored Jun 24, 2022
1 parent 893ab12 commit 0917870
Showing 39 changed files with 802 additions and 917 deletions.
62 changes: 26 additions & 36 deletions src/transformers/models/beit/modeling_beit.py
@@ -91,17 +91,7 @@ class BeitModelOutputWithPooling(BaseModelOutputWithPooling):
     """
 
 
-# Inspired by
-# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
-# From PyTorch internals
-def to_2tuple(x):
-    if isinstance(x, collections.abc.Iterable):
-        return x
-    return (x, x)
-
-
-# Based on https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py
-def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
     """
     Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
@@ -112,16 +102,16 @@ def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
     argument.
     """
     if drop_prob == 0.0 or not training:
-        return x
+        return input
     keep_prob = 1 - drop_prob
-    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
-    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
     random_tensor.floor_()  # binarize
-    output = x.div(keep_prob) * random_tensor
+    output = input.div(keep_prob) * random_tensor
     return output
 
 
-class DropPath(nn.Module):
+class BeitDropPath(nn.Module):
     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
 
     def __init__(self, drop_prob: Optional[float] = None) -> None:
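For context: a minimal, self-contained sketch of the stochastic-depth math in `drop_path` above (toy values assumed; not part of the diff). Each sample in the batch is either zeroed out or rescaled by `1 / keep_prob`, so the output matches the input in expectation.

```python
# Toy illustration of drop_path: a per-sample binary mask, rescaled to preserve the expectation.
import torch

x = torch.ones(4, 3, 8, 8)  # 4 samples coming out of a residual branch
drop_prob = 0.5
keep_prob = 1 - drop_prob

shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # (4, 1, 1, 1): one Bernoulli draw per sample
random_tensor = (keep_prob + torch.rand(shape)).floor_()  # 0.0 or 1.0 per sample

output = x.div(keep_prob) * random_tensor
print(output[:, 0, 0, 0])  # each sample is either all 0.0 (dropped) or all 2.0 (kept, scaled by 1/keep_prob)
```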
@@ -151,12 +141,7 @@ def __init__(self, config: BeitConfig) -> None:
             self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
         else:
             self.mask_token = None
-        self.patch_embeddings = PatchEmbeddings(
-            image_size=config.image_size,
-            patch_size=config.patch_size,
-            num_channels=config.num_channels,
-            embed_dim=config.hidden_size,
-        )
+        self.patch_embeddings = BeitPatchEmbeddings(config)
         num_patches = self.patch_embeddings.num_patches
         if config.use_absolute_position_embeddings:
             self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
@@ -184,38 +169,43 @@ def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor:
         return embeddings
 
 
-# Based on timm implementation, which can be found here:
-# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
-class PatchEmbeddings(nn.Module):
+class BeitPatchEmbeddings(nn.Module):
     """
-    Image to Patch Embedding.
+    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+    Transformer.
     """
 
-    def __init__(
-        self, image_size: int = 224, patch_size: int = 16, num_channels: int = 3, embed_dim: int = 768
-    ) -> None:
+    def __init__(self, config):
         super().__init__()
-        image_size = to_2tuple(image_size)
-        patch_size = to_2tuple(patch_size)
+        image_size, patch_size = config.image_size, config.patch_size
+        num_channels, hidden_size = config.num_channels, config.hidden_size
+
+        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
         num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
         patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
         self.image_size = image_size
         self.patch_size = patch_size
+        self.num_channels = num_channels
         self.num_patches = num_patches
         self.patch_shape = patch_shape
 
-        self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
 
     def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, height, width = pixel_values.shape
-        # FIXME look at relaxing size constraints
+        if num_channels != self.num_channels:
+            raise ValueError(
+                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+            )
         if height != self.image_size[0] or width != self.image_size[1]:
             raise ValueError(
                 f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
             )
-        x = self.projection(pixel_values).flatten(2).transpose(1, 2)
+        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
 
-        return x
+        return embeddings
 
 
 class BeitSelfAttention(nn.Module):
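The shape flow through the refactored `BeitPatchEmbeddings` is easiest to see with concrete numbers; a sketch assuming the BEiT defaults (224×224 images, 16×16 patches, hidden size 768):

```python
# Sketch of (batch_size, num_channels, height, width) -> (batch_size, seq_length, hidden_size),
# assuming default BEiT config values (not part of the diff).
import torch
from torch import nn

num_channels, hidden_size, patch_size = 3, 768, 16
projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

pixel_values = torch.randn(2, num_channels, 224, 224)
embeddings = projection(pixel_values).flatten(2).transpose(1, 2)
print(embeddings.shape)  # torch.Size([2, 196, 768]); 196 = (224 // 16) ** 2 patches
```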
@@ -393,7 +383,7 @@ def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None, drop_path_rate: float = 0.0) -> None:
         self.intermediate = BeitIntermediate(config)
         self.output = BeitOutput(config)
         self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
+        self.drop_path = BeitDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
         self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
         init_values = config.layer_scale_init_value
8 changes: 7 additions & 1 deletion src/transformers/models/beit/modeling_flax_beit.py
@@ -171,6 +171,7 @@ class FlaxBeitPatchEmbeddings(nn.Module):
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
     def setup(self):
+        self.num_channels = self.config.num_channels
         image_size = self.config.image_size
         patch_size = self.config.patch_size
         num_patches = (image_size // patch_size) * (image_size // patch_size)
@@ -187,6 +188,11 @@ def setup(self):
         )
 
     def __call__(self, pixel_values):
+        num_channels = pixel_values.shape[-1]
+        if num_channels != self.num_channels:
+            raise ValueError(
+                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+            )
         embeddings = self.projection(pixel_values)
         batch_size, _, _, channels = embeddings.shape
         return jnp.reshape(embeddings, (batch_size, -1, channels))
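Note the check reads `pixel_values.shape[-1]`: the Flax models take channels-last (NHWC) input, unlike the NCHW PyTorch models, which is also why the `input_shape` default below ends in `config.num_channels`. A toy sketch of the check catching a greyscale input (shapes assumed; not part of the diff):

```python
# Toy sketch: NHWC input puts channels on the last axis, so a greyscale
# (1-channel) image trips the new check when the config expects 3 channels.
import jax.numpy as jnp

expected_channels = 3
pixel_values = jnp.zeros((1, 224, 224, 1))  # greyscale image, channels last

if pixel_values.shape[-1] != expected_channels:
    # This branch is taken here: 1 != 3, so the ValueError is raised.
    raise ValueError(
        "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
    )
```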
@@ -603,7 +609,7 @@ def __init__(
     ):
         module = self.module_class(config=config, dtype=dtype, **kwargs)
         if input_shape is None:
-            input_shape = (1, config.image_size, config.image_size, 3)
+            input_shape = (1, config.image_size, config.image_size, config.num_channels)
         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
37 changes: 24 additions & 13 deletions src/transformers/models/convnext/modeling_convnext.py
@@ -53,36 +53,41 @@
 ]
 
 
-# Stochastic depth implementation
-# Taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
-def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input, drop_prob: float = 0.0, training: bool = False):
     """
-    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the
-    DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop
-    Connect' is a different form of dropout in a separate paper... See discussion:
-    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and
-    argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument.
+    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+    argument.
     """
     if drop_prob == 0.0 or not training:
-        return x
+        return input
     keep_prob = 1 - drop_prob
-    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
-    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
     random_tensor.floor_()  # binarize
-    output = x.div(keep_prob) * random_tensor
+    output = input.div(keep_prob) * random_tensor
     return output
 
 
+# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->ConvNext
 class ConvNextDropPath(nn.Module):
     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
 
-    def __init__(self, drop_prob=None):
+    def __init__(self, drop_prob: Optional[float] = None) -> None:
         super().__init__()
         self.drop_prob = drop_prob
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return drop_path(x, self.drop_prob, self.training)
 
+    def extra_repr(self) -> str:
+        return "p={}".format(self.drop_prob)
+
 
 class ConvNextLayerNorm(nn.Module):
     r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
@@ -122,8 +127,14 @@ def __init__(self, config):
             config.num_channels, config.hidden_sizes[0], kernel_size=config.patch_size, stride=config.patch_size
         )
         self.layernorm = ConvNextLayerNorm(config.hidden_sizes[0], eps=1e-6, data_format="channels_first")
+        self.num_channels = config.num_channels
 
     def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+        num_channels = pixel_values.shape[1]
+        if num_channels != self.num_channels:
+            raise ValueError(
+                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+            )
         embeddings = self.patch_embeddings(pixel_values)
         embeddings = self.layernorm(embeddings)
         return embeddings
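On the PyTorch side the channel dimension sits at axis 1, and the guard turns what would otherwise be an opaque `Conv2d` weight/shape mismatch into an explicit message; this is the case the greyscale-image tests added in this commit exercise (toy shapes assumed; not part of the diff):

```python
# Toy sketch: a 1-channel (greyscale) batch fails the NCHW channel check before reaching Conv2d.
import torch

expected_channels = 3
pixel_values = torch.randn(1, 1, 224, 224)  # greyscale batch; config expects 3 channels

if pixel_values.shape[1] != expected_channels:
    # This branch is taken here: 1 != 3, so the ValueError is raised.
    raise ValueError(
        "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
    )
```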
9 changes: 9 additions & 0 deletions src/transformers/models/convnext/modeling_tf_convnext.py
@@ -20,6 +20,8 @@
 import numpy as np
 import tensorflow as tf
 
+from transformers import shape_list
+
 from ...activations_tf import get_tf_activation
 from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling, TFSequenceClassifierOutput
 from ...modeling_tf_utils import (
@@ -77,11 +79,18 @@ def __init__(self, config, **kwargs):
             bias_initializer="zeros",
         )
         self.layernorm = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="layernorm")
+        self.num_channels = config.num_channels
 
     def call(self, pixel_values):
         if isinstance(pixel_values, dict):
             pixel_values = pixel_values["pixel_values"]
 
+        num_channels = shape_list(pixel_values)[1]
+        if tf.executing_eagerly() and num_channels != self.num_channels:
+            raise ValueError(
+                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+            )
+
         # When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format.
         # So change the input format from `NCHW` to `NHWC`.
         # shape = (batch_size, in_height, in_width, in_channels=num_channels)
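The `tf.executing_eagerly()` guard exists because under `tf.function` tracing the static channel dimension can be unknown, making the equality comparison meaningless; the check is therefore only enforced in eager mode. A sketch of the pattern with plain tensor shapes (toy values; `shape_list` behavior assumed from its use above):

```python
# Sketch of the guarded channel check: the static channel dim may be None when
# traced inside tf.function, so the comparison only runs in eager mode.
import tensorflow as tf

def check_channels(pixel_values, expected=3):
    num_channels = pixel_values.shape[1]  # static dim; can be None during graph tracing
    if tf.executing_eagerly() and num_channels != expected:
        raise ValueError(
            "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
        )

check_channels(tf.zeros((1, 3, 224, 224)))    # passes silently in eager mode
# check_channels(tf.zeros((1, 1, 224, 224)))  # would raise ValueError in eager mode
```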
31 changes: 18 additions & 13 deletions src/transformers/models/cvt/modeling_cvt.py
@@ -78,36 +78,41 @@ class BaseModelOutputWithCLSToken(ModelOutput):
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
 
 
-# Copied from transformers.models.convnext.modeling_convnext.drop_path
-def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input, drop_prob: float = 0.0, training: bool = False):
     """
-    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the
-    DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop
-    Connect' is a different form of dropout in a separate paper... See discussion:
-    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and
-    argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument.
+    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+    argument.
     """
     if drop_prob == 0.0 or not training:
-        return x
+        return input
     keep_prob = 1 - drop_prob
-    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
-    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
     random_tensor.floor_()  # binarize
-    output = x.div(keep_prob) * random_tensor
+    output = input.div(keep_prob) * random_tensor
     return output
 
 
-# Copied from transformers.models.convnext.modeling_convnext.ConvNextDropPath
+# Copied from transformers.models.beit.modeling_beit.BeitDropPath
 class CvtDropPath(nn.Module):
     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
 
-    def __init__(self, drop_prob=None):
+    def __init__(self, drop_prob: Optional[float] = None) -> None:
         super().__init__()
         self.drop_prob = drop_prob
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return drop_path(x, self.drop_prob, self.training)
 
+    def extra_repr(self) -> str:
+        return "p={}".format(self.drop_prob)
+
 
 class CvtEmbeddings(nn.Module):
     """