From eed3fa5e0c40befd7e5bc6c6f27175db160f30fe Mon Sep 17 00:00:00 2001 From: Denisa Roberts Date: Wed, 31 May 2023 05:43:12 -0400 Subject: [PATCH] Add TensorFlow implementation of EfficientFormer (#22620) * Add tf code for efficientformer * Fix return dict bug - return last hidden state after last stage * Fix corresponding return dict bug * Override test tol * Change default values of training to False * Set training to default False X3 * Rm axis from ln * Set init in dense projection * Rm debug stuff * Make style; all tests pass. * Modify year to 2023 * Fix attention biases codes * Update the shape list logic * Add a batch norm eps config * Remove extract comments in test files * Add conditional attn and hidden states return for serving output * Change channel dim checking logic * Add exception for withteacher model in training mode * Revert layer count for now * Add layer count for conditional layer naming * Transpose for conv happens only in main layer * Make tests smaller * Make style * Update doc * Rm from_pt * Change to actual expect image class label * Remove stray print in tests * Update image processor test * Remove the old serving output logic * Make style * Make style * Complete test --- docs/source/en/index.mdx | 2 +- docs/source/en/model_doc/efficientformer.mdx | 17 +- src/transformers/__init__.py | 16 + .../models/auto/modeling_tf_auto.py | 5 + .../models/efficientformer/__init__.py | 35 +- .../configuration_efficientformer.py | 8 +- .../modeling_efficientformer.py | 37 +- .../modeling_tf_efficientformer.py | 986 ++++++++++++++++++ src/transformers/utils/dummy_tf_objects.py | 31 + .../test_modeling_efficientformer.py | 29 +- .../test_modeling_tf_efficientformer.py | 393 +++++++ utils/documentation_tests.txt | 1 + 12 files changed, 1537 insertions(+), 23 deletions(-) create mode 100644 src/transformers/models/efficientformer/modeling_tf_efficientformer.py create mode 100644 tests/models/efficientformer/test_modeling_tf_efficientformer.py diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 63f1f8d7f4a90c..1813efbd09a88f 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -314,7 +314,7 @@ Flax), PyTorch, and/or TensorFlow. | DonutSwin | ❌ | ❌ | ✅ | ❌ | ❌ | | DPR | ✅ | ✅ | ✅ | ✅ | ❌ | | DPT | ❌ | ❌ | ✅ | ❌ | ❌ | -| EfficientFormer | ❌ | ❌ | ✅ | ❌ | ❌ | +| EfficientFormer | ❌ | ❌ | ✅ | ✅ | ❌ | | EfficientNet | ❌ | ❌ | ✅ | ❌ | ❌ | | ELECTRA | ✅ | ✅ | ✅ | ✅ | ✅ | | Encoder decoder | ❌ | ❌ | ✅ | ✅ | ✅ | diff --git a/docs/source/en/model_doc/efficientformer.mdx b/docs/source/en/model_doc/efficientformer.mdx index 2b512932ef82ba..0ef8cfb53f808e 100644 --- a/docs/source/en/model_doc/efficientformer.mdx +++ b/docs/source/en/model_doc/efficientformer.mdx @@ -37,7 +37,7 @@ EfficientFormer-L7, obtains 83.3% accuracy with only 7.0 ms latency. Our work pr reach extremely low latency on mobile devices while maintaining high performance.* This model was contributed by [novice03](https://huggingface.co/novice03) and [Bearnardd](https://huggingface.co/Bearnardd). -The original code can be found [here](https://github.com/snap-research/EfficientFormer). +The original code can be found [here](https://github.com/snap-research/EfficientFormer). The TensorFlow version of this model was added by [D-Roberts](https://huggingface.co/D-Roberts). ## Documentation resources @@ -66,3 +66,18 @@ The original code can be found [here](https://github.com/snap-research/Efficient [[autodoc]] EfficientFormerForImageClassificationWithTeacher - forward + +## TFEfficientFormerModel + +[[autodoc]] TFEfficientFormerModel + - call + +## TFEfficientFormerForImageClassification + +[[autodoc]] TFEfficientFormerForImageClassification + - call + +## TFEfficientFormerForImageClassificationWithTeacher + +[[autodoc]] TFEfficientFormerForImageClassificationWithTeacher + - call diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 302204b74d222f..0f0b3b4765cbac 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3151,6 +3151,15 @@ "TFDPRReader", ] ) + _import_structure["models.efficientformer"].extend( + [ + "TF_EFFICIENTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFEfficientFormerForImageClassification", + "TFEfficientFormerForImageClassificationWithTeacher", + "TFEfficientFormerModel", + "TFEfficientFormerPreTrainedModel", + ] + ) _import_structure["models.electra"].extend( [ "TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -6487,6 +6496,13 @@ TFDPRQuestionEncoder, TFDPRReader, ) + from .models.efficientformer import ( + TF_EFFICIENTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, + TFEfficientFormerForImageClassification, + TFEfficientFormerForImageClassificationWithTeacher, + TFEfficientFormerModel, + TFEfficientFormerPreTrainedModel, + ) from .models.electra import ( TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST, TFElectraForMaskedLM, diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index bfc29f2dd35f67..bd86431c8cb256 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -47,6 +47,7 @@ ("deit", "TFDeiTModel"), ("distilbert", "TFDistilBertModel"), ("dpr", "TFDPRQuestionEncoder"), + ("efficientformer", "TFEfficientFormerModel"), ("electra", "TFElectraModel"), ("esm", "TFEsmModel"), ("flaubert", "TFFlaubertModel"), @@ -202,6 +203,10 @@ ("cvt", "TFCvtForImageClassification"), ("data2vec-vision", "TFData2VecVisionForImageClassification"), ("deit", ("TFDeiTForImageClassification", "TFDeiTForImageClassificationWithTeacher")), + ( + "efficientformer", + ("TFEfficientFormerForImageClassification", "TFEfficientFormerForImageClassificationWithTeacher"), + ), ("mobilevit", "TFMobileViTForImageClassification"), ("regnet", "TFRegNetForImageClassification"), ("resnet", "TFResNetForImageClassification"), diff --git a/src/transformers/models/efficientformer/__init__.py b/src/transformers/models/efficientformer/__init__.py index ea7bcdffd45931..25d60d1ee765ef 100644 --- a/src/transformers/models/efficientformer/__init__.py +++ b/src/transformers/models/efficientformer/__init__.py @@ -13,7 +13,13 @@ # limitations under the License. from typing import TYPE_CHECKING -from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_tf_available, + is_torch_available, + is_vision_available, +) _import_structure = { @@ -45,6 +51,20 @@ "EfficientFormerPreTrainedModel", ] +try: + if not is_tf_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_tf_efficientformer"] = [ + "TF_EFFICIENTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFEfficientFormerForImageClassification", + "TFEfficientFormerForImageClassificationWithTeacher", + "TFEfficientFormerModel", + "TFEfficientFormerPreTrainedModel", + ] + if TYPE_CHECKING: from .configuration_efficientformer import EFFICIENTFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, EfficientFormerConfig @@ -69,6 +89,19 @@ EfficientFormerModel, EfficientFormerPreTrainedModel, ) + try: + if not is_tf_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_tf_efficientformer import ( + TF_EFFICIENTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, + TFEfficientFormerForImageClassification, + TFEfficientFormerForImageClassificationWithTeacher, + TFEfficientFormerModel, + TFEfficientFormerPreTrainedModel, + ) else: import sys diff --git a/src/transformers/models/efficientformer/configuration_efficientformer.py b/src/transformers/models/efficientformer/configuration_efficientformer.py index 5f30664ff325a0..fecb90a886e8eb 100644 --- a/src/transformers/models/efficientformer/configuration_efficientformer.py +++ b/src/transformers/models/efficientformer/configuration_efficientformer.py @@ -52,7 +52,7 @@ class EfficientFormerConfig(PretrainedConfig): The size of the key in meta3D block. attention_ratio (`int`, *optional*, defaults to 4): Ratio of the dimension of the query and value to the dimension of the key in MSHA block - resolution (`int`, *optional*, defaults to 5) + resolution (`int`, *optional*, defaults to 7) Size of each patch num_hidden_layers (`int`, *optional*, defaults to 5): Number of hidden layers in the Transformer encoder. @@ -91,6 +91,8 @@ class EfficientFormerConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (`float`, *optional*, defaults to 1e-12): The epsilon used by the layer normalization layers. + image_size (`int`, *optional*, defaults to `224`): + The size (resolution) of each image. Example: @@ -136,6 +138,8 @@ def __init__( hidden_act: str = "gelu", initializer_range: float = 0.02, layer_norm_eps: float = 1e-12, + image_size: int = 224, + batch_norm_eps: float = 1e-05, **kwargs, ) -> None: super().__init__(**kwargs) @@ -165,3 +169,5 @@ def __init__( self.distillation = distillation self.use_layer_scale = use_layer_scale self.layer_scale_init_value = layer_scale_init_value + self.image_size = image_size + self.batch_norm_eps = batch_norm_eps diff --git a/src/transformers/models/efficientformer/modeling_efficientformer.py b/src/transformers/models/efficientformer/modeling_efficientformer.py index b6264e60c8b899..cd38b2d813e1fb 100644 --- a/src/transformers/models/efficientformer/modeling_efficientformer.py +++ b/src/transformers/models/efficientformer/modeling_efficientformer.py @@ -43,7 +43,7 @@ # Base docstring _CHECKPOINT_FOR_DOC = "snap-research/efficientformer-l1-300" -_EXPECTED_OUTPUT_SHAPE = [1, 197, 768] +_EXPECTED_OUTPUT_SHAPE = [1, 49, 448] # Image classification docstring _IMAGE_CLASS_CHECKPOINT = "snap-research/efficientformer-l1-300" @@ -73,7 +73,7 @@ def __init__(self, config: EfficientFormerConfig, num_channels: int, embed_dim: stride=config.downsample_stride, padding=config.downsample_pad, ) - self.norm = nn.BatchNorm2d(embed_dim) if apply_norm else nn.Identity() + self.norm = nn.BatchNorm2d(embed_dim, eps=config.batch_norm_eps) if apply_norm else nn.Identity() def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: batch_size, num_channels, height, width = pixel_values.shape @@ -157,10 +157,10 @@ def __init__(self, config: EfficientFormerConfig, out_channels: int): super().__init__() self.convolution1 = nn.Conv2d(config.num_channels, out_channels // 2, kernel_size=3, stride=2, padding=1) - self.batchnorm_before = nn.BatchNorm2d(out_channels // 2) + self.batchnorm_before = nn.BatchNorm2d(out_channels // 2, eps=config.batch_norm_eps) self.convolution2 = nn.Conv2d(out_channels // 2, out_channels, kernel_size=3, stride=2, padding=1) - self.batchnorm_after = nn.BatchNorm2d(out_channels) + self.batchnorm_after = nn.BatchNorm2d(out_channels, eps=config.batch_norm_eps) self.activation = nn.ReLU() @@ -224,24 +224,24 @@ def __init__( hidden_features = hidden_features or in_features self.convolution1 = nn.Conv2d(in_features, hidden_features, 1) - self.actvation = ACT2FN[config.hidden_act] + self.activation = ACT2FN[config.hidden_act] self.convolution2 = nn.Conv2d(hidden_features, out_features, 1) self.dropout = nn.Dropout(drop) - self.batchnorm_before = nn.BatchNorm2d(hidden_features) - self.batchnorm_after = nn.BatchNorm2d(out_features) + self.batchnorm_before = nn.BatchNorm2d(hidden_features, eps=config.batch_norm_eps) + self.batchnorm_after = nn.BatchNorm2d(out_features, eps=config.batch_norm_eps) def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: hidden_state = self.convolution1(hidden_state) hidden_state = self.batchnorm_before(hidden_state) - hidden_state = self.actvation(hidden_state) + hidden_state = self.activation(hidden_state) hidden_state = self.dropout(hidden_state) hidden_state = self.convolution2(hidden_state) hidden_state = self.batchnorm_after(hidden_state) - hidden_state = self.dropout(hidden_state) + return hidden_state @@ -266,7 +266,7 @@ def drop_path(input, drop_prob: float = 0.0, training: bool = False): return output -# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Bit +# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->EfficientFormer class EfficientFormerDropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" @@ -301,8 +301,10 @@ def __init__(self, config: EfficientFormerConfig, dim: int, drop_path: float = 0 attention_ratio=config.attention_ratio, resolution=config.resolution, ) - self.layernorm1 = nn.LayerNorm(dim) - self.layernorm2 = nn.LayerNorm(dim) + + self.layernorm1 = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.layernorm2 = nn.LayerNorm(dim, eps=config.layer_norm_eps) + mlp_hidden_dim = int(dim * config.mlp_expansion_ratio) self.mlp = EfficientFormerDenseMlp(config, in_features=dim, hidden_features=mlp_hidden_dim) @@ -346,15 +348,20 @@ def __init__(self, config: EfficientFormerConfig): def forward(self, hidden_states: torch.Tensor, output_attentions: bool = False) -> Tuple[torch.Tensor]: all_attention_outputs = () if output_attentions else None + for layer_module in self.blocks: if isinstance(hidden_states, tuple): hidden_states = hidden_states[0] + hidden_states = layer_module(hidden_states, output_attentions) + if output_attentions: all_attention_outputs = all_attention_outputs + (hidden_states[1],) + if output_attentions: outputs = (hidden_states[0],) + all_attention_outputs return outputs + return hidden_states @@ -379,6 +386,7 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]: if self.use_layer_scale: layer_output = hidden_states + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * outputs) + layer_output = layer_output + self.drop_path( self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(layer_output) ) @@ -398,6 +406,7 @@ def __init__(self, config: EfficientFormerConfig, stage_idx: int): drop_paths = [ config.drop_path_rate * (block_idx + sum(config.depths[:stage_idx])) for block_idx in range(num_layers) ] + self.blocks = nn.ModuleList( [ EfficientFormerMeta4D(config, config.hidden_sizes[stage_idx], drop_path=drop_path) @@ -446,6 +455,7 @@ def __init__(self, config: EfficientFormerConfig): for i in range(num_intermediate_stages) ] intermediate_stages = [] + for i in range(num_intermediate_stages): intermediate_stages.append(EfficientFormerIntermediateStage(config, i)) if downsamples[i]: @@ -475,6 +485,7 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) layer_output = self.last_stage(hidden_states, output_attentions=output_attentions) + if output_attentions: all_self_attentions = all_self_attentions + layer_output[1:] @@ -482,7 +493,7 @@ def forward( all_hidden_states = all_hidden_states + (layer_output[0],) if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return tuple(v for v in [layer_output[0], all_hidden_states, all_self_attentions] if v is not None) return BaseModelOutput( last_hidden_state=layer_output[0], diff --git a/src/transformers/models/efficientformer/modeling_tf_efficientformer.py b/src/transformers/models/efficientformer/modeling_tf_efficientformer.py new file mode 100644 index 00000000000000..1907af388f9274 --- /dev/null +++ b/src/transformers/models/efficientformer/modeling_tf_efficientformer.py @@ -0,0 +1,986 @@ +# coding=utf-8 +# Copyright 2023 Snapchat Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" TensorFlow EfficientFormer model.""" + +import itertools +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import tensorflow as tf + +from ...activations_tf import ACT2FN +from ...modeling_tf_outputs import ( + TFBaseModelOutput, + TFBaseModelOutputWithPooling, + TFImageClassifierOutput, +) +from ...modeling_tf_utils import ( + TFPreTrainedModel, + TFSequenceClassificationLoss, + get_initializer, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import shape_list, stable_softmax +from ...utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from .configuration_efficientformer import EfficientFormerConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "EfficientFormerConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "snap-research/efficientformer-l1-300" +_EXPECTED_OUTPUT_SHAPE = [1, 49, 448] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "snap-research/efficientformer-l1-300" +_IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_281" + + +TF_EFFICIENTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "snap-research/efficientformer-l1-300", + # See all EfficientFormer models at https://huggingface.co/models?filter=efficientformer +] + + +class TFEfficientFormerPatchEmbeddings(tf.keras.layers.Layer): + """ + This class performs downsampling between two stages. For the input tensor with the shape [batch_size, num_channels, + height, width] it produces output tensor with the shape [batch_size, num_channels, height/stride, width/stride] + """ + + def __init__( + self, config: EfficientFormerConfig, num_channels: int, embed_dim: int, apply_norm: bool = True, **kwargs + ) -> None: + super().__init__(**kwargs) + self.num_channels = num_channels + + self.padding = tf.keras.layers.ZeroPadding2D(padding=config.downsample_pad) + self.projection = tf.keras.layers.Conv2D( + filters=embed_dim, + kernel_size=config.downsample_patch_size, + strides=config.downsample_stride, + padding="valid", + name="projection", + ) + # Use same default momentum and epsilon as PyTorch equivalent for BatchNormalization + self.norm = ( + tf.keras.layers.BatchNormalization(axis=-1, epsilon=config.batch_norm_eps, momentum=0.9, name="norm") + if apply_norm + else tf.identity + ) + + def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor: + tf.debugging.assert_shapes( + [(pixel_values, (..., None, None, self.num_channels))], + message="Make sure that the channel dimension of the pixel values match with the one set in the configuration.", + ) + embeddings = self.projection(self.padding(pixel_values)) + embeddings = self.norm(embeddings, training=training) + return embeddings + + +class TFEfficientFormerSelfAttention(tf.keras.layers.Layer): + def __init__( + self, + dim: int, + key_dim: int, + num_heads: int, + attention_ratio: int, + resolution: int, + config: EfficientFormerConfig, + **kwargs, + ): + super().__init__(**kwargs) + + self.num_heads = num_heads + self.key_dim = key_dim + self.attention_ratio = attention_ratio + self.scale = key_dim**-0.5 + self.total_key_dim = key_dim * num_heads + self.expanded_key_dim = int(attention_ratio * key_dim) + self.total_expanded_key_dim = int(self.expanded_key_dim * num_heads) + hidden_size = self.total_expanded_key_dim + self.total_key_dim * 2 + + self.qkv = tf.keras.layers.Dense( + units=hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="qkv" + ) + self.projection = tf.keras.layers.Dense( + units=dim, kernel_initializer=get_initializer(config.initializer_range), name="projection" + ) + self.resolution = resolution + + def build(self, input_shape: tf.TensorShape) -> None: + points = list(itertools.product(range(self.resolution), range(self.resolution))) + num_points = len(points) + attention_offsets = {} + + idxs = [] + + for point_1 in points: + for point_2 in points: + offset = (abs(point_1[0] - point_2[0]), abs(point_1[1] - point_2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + + self.attention_biases = self.add_weight( + shape=(self.num_heads, len(attention_offsets)), + initializer=tf.keras.initializers.zeros(), + trainable=True, + name="attention_biases", + ) + self.attention_bias_idxs = self.add_weight( + shape=(num_points, num_points), + trainable=False, + dtype=tf.int32, + name="attention_bias_idxs", + ) + + self.attention_bias_idxs.assign(tf.reshape(tf.cast(idxs, dtype=tf.int32), (num_points, num_points))) + + super().build(input_shape) + + def call( + self, hidden_states: tf.Tensor, output_attentions: bool = False, training: bool = False + ) -> Tuple[tf.Tensor]: + batch_size, sequence_length, *_ = shape_list(hidden_states) + qkv = self.qkv(inputs=hidden_states) + + query_layer, key_layer, value_layer = tf.split( + tf.reshape(tensor=qkv, shape=(batch_size, sequence_length, self.num_heads, -1)), + num_or_size_splits=[self.key_dim, self.key_dim, self.expanded_key_dim], + axis=3, + ) + + query_layer = tf.transpose(query_layer, perm=[0, 2, 1, 3]) + key_layer = tf.transpose(key_layer, perm=[0, 2, 1, 3]) + value_layer = tf.transpose(value_layer, perm=[0, 2, 1, 3]) + + attention_probs = tf.matmul(query_layer, tf.transpose(key_layer, perm=[0, 1, 3, 2])) + scale = tf.cast(self.scale, dtype=attention_probs.dtype) + attention_probs = tf.multiply(attention_probs, scale) + + attention_biases = tf.gather(params=self.attention_biases, indices=self.attention_bias_idxs, axis=1) + attention_probs = attention_probs + attention_biases + attention_probs = stable_softmax(logits=attention_probs, axis=-1) + + context_layer = tf.matmul(attention_probs, value_layer) + context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3]) + + context_layer = tf.reshape( + tensor=context_layer, shape=(batch_size, sequence_length, self.total_expanded_key_dim) + ) + context_layer = self.projection(context_layer) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class TFEfficientFormerConvStem(tf.keras.layers.Layer): + def __init__(self, config: EfficientFormerConfig, out_channels: int, **kwargs): + super().__init__(**kwargs) + + self.padding = tf.keras.layers.ZeroPadding2D(padding=1) + self.convolution1 = tf.keras.layers.Conv2D( + filters=out_channels // 2, kernel_size=3, strides=2, padding="valid", name="convolution1" + ) + # Use same default momentum and epsilon as PyTorch equivalent for BatchNormalization + self.batchnorm_before = tf.keras.layers.BatchNormalization( + axis=-1, epsilon=config.batch_norm_eps, momentum=0.9, name="batchnorm_before" + ) + + self.convolution2 = tf.keras.layers.Conv2D( + filters=out_channels, + kernel_size=3, + strides=2, + padding="valid", + name="convolution2", + ) + # Use same default momentum and epsilon as PyTorch equivalent for BatchNormalization + self.batchnorm_after = tf.keras.layers.BatchNormalization( + axis=-1, epsilon=config.batch_norm_eps, momentum=0.9, name="batchnorm_after" + ) + + self.activation = tf.keras.layers.Activation(activation=tf.keras.activations.relu, name="activation") + + def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor: + features = self.batchnorm_before(self.convolution1(self.padding(pixel_values)), training=training) + features = self.activation(features) + features = self.batchnorm_after(self.convolution2(self.padding(features)), training=training) + features = self.activation(features) + return features + + +class TFEfficientFormerPooling(tf.keras.layers.Layer): + def __init__(self, pool_size: int, **kwargs): + super().__init__(**kwargs) + self.pool = tf.keras.layers.AveragePooling2D(pool_size=pool_size, strides=1, padding="same") + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + output = self.pool(hidden_states) + output = output - hidden_states + return output + + +class TFEfficientFormerDenseMlp(tf.keras.layers.Layer): + def __init__( + self, + config: EfficientFormerConfig, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + **kwargs, + ): + super().__init__(**kwargs) + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.linear_in = tf.keras.layers.Dense( + units=hidden_features, kernel_initializer=get_initializer(config.initializer_range), name="linear_in" + ) + self.activation = ACT2FN[config.hidden_act] + self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob) + + self.linear_out = tf.keras.layers.Dense( + units=out_features, kernel_initializer=get_initializer(config.initializer_range), name="linear_out" + ) + + def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.linear_in(inputs=hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = self.linear_out(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + + return hidden_states + + +class TFEfficientFormerConvMlp(tf.keras.layers.Layer): + def __init__( + self, + config: EfficientFormerConfig, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + drop: float = 0.0, + **kwargs, + ): + super().__init__(**kwargs) + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.convolution1 = tf.keras.layers.Conv2D( + filters=hidden_features, + kernel_size=1, + name="convolution1", + padding="valid", + ) + + self.activation = ACT2FN[config.hidden_act] + + self.convolution2 = tf.keras.layers.Conv2D( + filters=out_features, + kernel_size=1, + name="convolution2", + padding="valid", + ) + + self.dropout = tf.keras.layers.Dropout(rate=drop) + + # Use same default momentum and epsilon as PyTorch equivalent for BatchNormalization + self.batchnorm_before = tf.keras.layers.BatchNormalization( + axis=-1, epsilon=config.batch_norm_eps, momentum=0.9, name="batchnorm_before" + ) + # Use same default momentum and epsilon as PyTorch equivalent for BatchNormalization + self.batchnorm_after = tf.keras.layers.BatchNormalization( + axis=-1, epsilon=config.batch_norm_eps, momentum=0.9, name="batchnorm_after" + ) + + def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_state = self.convolution1(hidden_state) + hidden_state = self.batchnorm_before(hidden_state, training=training) + hidden_state = self.activation(hidden_state) + hidden_state = self.dropout(hidden_state, training=training) + hidden_state = self.convolution2(hidden_state) + hidden_state = self.batchnorm_after(hidden_state, training=training) + hidden_state = self.dropout(hidden_state, training=training) + return hidden_state + + +# Copied from transformers.models.convnext.modeling_tf_convnext.TFConvNextDropPath with ConvNext->EfficientFormer +class TFEfficientFormerDropPath(tf.keras.layers.Layer): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + References: + (1) github.com:rwightman/pytorch-image-models + """ + + def __init__(self, drop_path, **kwargs): + super().__init__(**kwargs) + self.drop_path = drop_path + + def call(self, x, training=None): + if training: + keep_prob = 1 - self.drop_path + shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1) + random_tensor = keep_prob + tf.random.uniform(shape, 0, 1) + random_tensor = tf.floor(random_tensor) + return (x / keep_prob) * random_tensor + return x + + +class TFEfficientFormerFlat(tf.keras.layers.Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def call(self, hidden_states: tf.Tensor) -> Tuple[tf.Tensor]: + batch_size, _, _, in_channels = shape_list(hidden_states) + hidden_states = tf.reshape(hidden_states, shape=[batch_size, -1, in_channels]) + return hidden_states + + +class TFEfficientFormerMeta3D(tf.keras.layers.Layer): + def __init__(self, config: EfficientFormerConfig, dim: int, drop_path: float = 0.0, **kwargs): + super().__init__(**kwargs) + + self.token_mixer = TFEfficientFormerSelfAttention( + dim=config.dim, + key_dim=config.key_dim, + num_heads=config.num_attention_heads, + attention_ratio=config.attention_ratio, + resolution=config.resolution, + name="token_mixer", + config=config, + ) + self.dim = dim + self.config = config + + self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm1") + self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm2") + mlp_hidden_dim = int(dim * config.mlp_expansion_ratio) + self.mlp = TFEfficientFormerDenseMlp(config, in_features=dim, hidden_features=mlp_hidden_dim, name="mlp") + + # Using `layers.Activation` instead of `tf.identity` to better control `training' behavior. + self.drop_path = ( + TFEfficientFormerDropPath(drop_path) + if drop_path > 0.0 + else tf.keras.layers.Activation("linear", name="drop_path") + ) + self.config = config + + def build(self, input_shape: tf.TensorShape): + self.layer_scale_1 = None + self.layer_scale_2 = None + + if self.config.use_layer_scale: + self.layer_scale_1 = self.add_weight( + shape=(self.dim,), + initializer=tf.keras.initializers.Constant(value=self.config.layer_scale_init_value), + trainable=True, + name="layer_scale_1", + ) + self.layer_scale_2 = self.add_weight( + shape=(self.dim,), + initializer=tf.keras.initializers.Constant(value=self.config.layer_scale_init_value), + trainable=True, + name="layer_scale_2", + ) + super().build(input_shape) + + def call( + self, hidden_states: tf.Tensor, output_attentions: bool = False, training: bool = False + ) -> Tuple[tf.Tensor]: + self_attention_outputs = self.token_mixer( + hidden_states=self.layernorm1(hidden_states, training=training), + output_attentions=output_attentions, + training=training, + ) + + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + if self.config.use_layer_scale: + layer_output = hidden_states + self.drop_path( + tf.expand_dims(tf.expand_dims(self.layer_scale_1, 0), 0) * attention_output, + training=training, + ) + layer_output = layer_output + self.drop_path( + tf.expand_dims(tf.expand_dims(self.layer_scale_2, 0), 0) + * self.mlp(hidden_states=self.layernorm2(inputs=layer_output, training=training), training=training), + training=training, + ) + else: + layer_output = hidden_states + self.drop_path(attention_output, training=training) + layer_output = layer_output + self.drop_path( + self.mlp(hidden_states=self.layernorm2(inputs=layer_output, training=training), training=training), + training=training, + ) + + outputs = (layer_output,) + outputs + + return outputs + + +class TFEfficientFormerMeta3DLayers(tf.keras.layers.Layer): + def __init__(self, config: EfficientFormerConfig, **kwargs): + super().__init__(**kwargs) + drop_paths = [ + config.drop_path_rate * (block_idx + sum(config.depths[:-1])) + for block_idx in range(config.num_meta3d_blocks) + ] + self.blocks = [ + TFEfficientFormerMeta3D(config, config.hidden_sizes[-1], drop_path=drop_path, name=f"blocks.{i}") + for i, drop_path in enumerate(drop_paths) + ] + + def call( + self, hidden_states: tf.Tensor, output_attentions: bool = False, training: bool = False + ) -> Tuple[tf.Tensor]: + all_attention_outputs = () if output_attentions else None + + for i, layer_module in enumerate(self.blocks): + if isinstance(hidden_states, tuple): + hidden_states = hidden_states[0] + + hidden_states = layer_module( + hidden_states=hidden_states, output_attentions=output_attentions, training=training + ) + if output_attentions: + all_attention_outputs = all_attention_outputs + (hidden_states[1],) + + if output_attentions: + outputs = (hidden_states[0],) + all_attention_outputs + return outputs + + return hidden_states + + +class TFEfficientFormerMeta4D(tf.keras.layers.Layer): + def __init__(self, config: EfficientFormerConfig, dim: int, drop_path: float = 0.0, **kwargs): + super().__init__(**kwargs) + pool_size = config.pool_size if config.pool_size is not None else 3 + self.token_mixer = TFEfficientFormerPooling(pool_size=pool_size, name="token_mixer") + self.dim = dim + mlp_hidden_dim = int(dim * config.mlp_expansion_ratio) + self.mlp = TFEfficientFormerConvMlp( + config=config, in_features=dim, hidden_features=mlp_hidden_dim, drop=config.hidden_dropout_prob, name="mlp" + ) + + self.drop_path = ( + TFEfficientFormerDropPath(drop_path, name="drop_path") + if drop_path > 0.0 + else tf.keras.layers.Activation("linear", name="drop_path") + ) + self.config = config + + def build(self, input_shape: tf.TensorShape): + self.layer_scale_1 = None + self.layer_scale_2 = None + + if self.config.use_layer_scale: + self.layer_scale_1 = self.add_weight( + shape=(self.dim), + initializer=tf.keras.initializers.Constant(value=self.config.layer_scale_init_value), + trainable=True, + name="layer_scale_1", + ) + self.layer_scale_2 = self.add_weight( + shape=(self.dim), + initializer=tf.keras.initializers.Constant(value=self.config.layer_scale_init_value), + trainable=True, + name="layer_scale_2", + ) + super().build(input_shape) + + def call(self, hidden_states: tf.Tensor, training: bool = False) -> Tuple[tf.Tensor]: + outputs = self.token_mixer(hidden_states) + + if self.config.use_layer_scale: + layer_output = hidden_states + self.drop_path( + tf.expand_dims(tf.expand_dims(self.layer_scale_1, 0), 0) * outputs, + training=training, + ) + + layer_output = layer_output + self.drop_path( + tf.expand_dims(tf.expand_dims(self.layer_scale_2, 0), 0) + * self.mlp(hidden_state=layer_output, training=training), + training=training, + ) + + else: + layer_output = hidden_states + self.drop_path(outputs, training=training) + layer_output = layer_output + self.drop_path( + self.mlp(hidden_state=layer_output, training=training), training=training + ) + + return layer_output + + +class TFEfficientFormerMeta4DLayers(tf.keras.layers.Layer): + def __init__(self, config: EfficientFormerConfig, stage_idx: int, **kwargs): + super().__init__(**kwargs) + num_layers = ( + config.depths[stage_idx] if stage_idx != -1 else config.depths[stage_idx] - config.num_meta3d_blocks + ) + drop_paths = [ + config.drop_path_rate * (block_idx + sum(config.depths[:stage_idx])) for block_idx in range(num_layers) + ] + + self.blocks = [ + TFEfficientFormerMeta4D( + config=config, dim=config.hidden_sizes[stage_idx], drop_path=drop_paths[i], name=f"blocks.{i}" + ) + for i in range(len(drop_paths)) + ] + + def call(self, hidden_states: tf.Tensor, training: bool = False) -> Tuple[tf.Tensor]: + for layer_module in self.blocks: + hidden_states = layer_module(hidden_states=hidden_states, training=training) + return hidden_states + + +class TFEfficientFormerIntermediateStage(tf.keras.layers.Layer): + def __init__(self, config: EfficientFormerConfig, index: int, **kwargs): + super().__init__(**kwargs) + self.meta4D_layers = TFEfficientFormerMeta4DLayers(config=config, stage_idx=index, name="meta4D_layers") + + def call(self, hidden_states: tf.Tensor, training: bool = False) -> Tuple[tf.Tensor]: + hidden_states = self.meta4D_layers(hidden_states=hidden_states, training=training) + return hidden_states + + +class TFEfficientFormerLastStage(tf.keras.layers.Layer): + def __init__(self, config: EfficientFormerConfig, **kwargs): + super().__init__(**kwargs) + self.meta4D_layers = TFEfficientFormerMeta4DLayers(config=config, stage_idx=-1, name="meta4D_layers") + self.flat = TFEfficientFormerFlat(name="flat") + self.meta3D_layers = TFEfficientFormerMeta3DLayers(config, name="meta3D_layers") + + def call( + self, hidden_states: tf.Tensor, output_attentions: bool = False, training: bool = False + ) -> Tuple[tf.Tensor]: + hidden_states = self.meta4D_layers(hidden_states=hidden_states, training=training) + hidden_states = self.flat(hidden_states=hidden_states) + hidden_states = self.meta3D_layers( + hidden_states=hidden_states, output_attentions=output_attentions, training=training + ) + + return hidden_states + + +class TFEfficientFormerEncoder(tf.keras.layers.Layer): + def __init__(self, config: EfficientFormerConfig, **kwargs): + super().__init__(**kwargs) + + self.config = config + num_intermediate_stages = len(config.depths) - 1 + downsamples = [ + config.downsamples[i] or config.hidden_sizes[i] != config.hidden_sizes[i + 1] + for i in range(num_intermediate_stages) + ] + + intermediate_stages = [] + layer_count = -1 + for i in range(num_intermediate_stages): + layer_count += 1 + intermediate_stages.append( + TFEfficientFormerIntermediateStage(config, i, name=f"intermediate_stages.{layer_count}") + ) + if downsamples[i]: + layer_count += 1 + intermediate_stages.append( + TFEfficientFormerPatchEmbeddings( + config, + config.hidden_sizes[i], + config.hidden_sizes[i + 1], + name=f"intermediate_stages.{layer_count}", + ) + ) + self.intermediate_stages = intermediate_stages + self.last_stage = TFEfficientFormerLastStage(config, name="last_stage") + + def call( + self, + hidden_states: tf.Tensor, + output_hidden_states: bool, + output_attentions: bool, + return_dict: bool, + training: bool = False, + ) -> TFBaseModelOutput: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + for layer_module in self.intermediate_stages: + hidden_states = layer_module(hidden_states, training=training) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_output = self.last_stage(hidden_states, output_attentions=output_attentions, training=training) + + if output_attentions: + all_self_attentions = all_self_attentions + layer_output[1:] + + if output_hidden_states: + all_hidden_states = all_hidden_states + (layer_output[0],) + + if not return_dict: + return tuple(v for v in [layer_output[0], all_hidden_states, all_self_attentions] if v is not None) + + return TFBaseModelOutput( + last_hidden_state=layer_output[0], + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +@keras_serializable +class TFEfficientFormerMainLayer(tf.keras.layers.Layer): + config_class = EfficientFormerConfig + + def __init__(self, config: EfficientFormerConfig, **kwargs) -> None: + super().__init__(**kwargs) + self.config = config + + self.patch_embed = TFEfficientFormerConvStem(config, config.hidden_sizes[0], name="patch_embed") + self.encoder = TFEfficientFormerEncoder(config, name="encoder") + self.layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm") + + @unpack_inputs + def call( + self, + pixel_values: Optional[tf.Tensor] = None, + output_attentions: Optional[tf.Tensor] = None, + output_hidden_states: Optional[tf.Tensor] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor, ...]]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # When running on CPU, tf.keras.layers.Conv2D and tf.keras.layers.AveragePool2D do not + # support channels first NCHW format. A number of blocks contain both. + # So change the input format from (batch_size, num_channels, height, width) to + # (batch_size, height, width, num_channels) here. + # shape = (batch_size, in_height, in_width, in_channels=num_channels) + pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) + embedding_output = self.patch_embed(pixel_values, training=training) + + encoder_outputs = self.encoder( + hidden_states=embedding_output, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output, training=training) + + # Change the hidden states from (batch_size, height, width, num_channels) to + # (batch_size, num_channels, height, width). + # The hidden states are in (batch_size, height, width, num_channels) + # shape after all stages except the MB3D blocks. + if output_hidden_states: + hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1][:-1]]) + ( + encoder_outputs[1][-1], + ) + + if not return_dict: + head_outputs = (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return TFBaseModelOutput( + last_hidden_state=sequence_output, + hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class TFEfficientFormerPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = EfficientFormerConfig + base_model_prefix = "efficientformer" + main_input_name = "pixel_values" + + +EFFICIENTFORMER_START_DOCSTRING = r""" + This model is a TensorFlow + [tf.keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer). Use it as a regular + TensorFlow Module and refer to the TensorFlow documentation for all matter related to general usage and behavior. + + + Parameters: + config ([`EfficientFormerConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +EFFICIENTFORMER_INPUTS_DOCSTRING = r""" + Args: + pixel_values ((`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`EfficientFormerImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare EfficientFormer Model transformer outputting raw hidden-states without any specific head on top.", + EFFICIENTFORMER_START_DOCSTRING, +) +class TFEfficientFormerModel(TFEfficientFormerPreTrainedModel): + def __init__(self, config: EfficientFormerConfig, **kwargs) -> None: + super().__init__(config, **kwargs) + + self.efficientformer = TFEfficientFormerMainLayer(config, name="efficientformer") + + @unpack_inputs + @add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutputWithPooling, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def call( + self, + pixel_values: Optional[tf.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[Tuple, TFBaseModelOutput]: + outputs = self.efficientformer( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + return outputs + + +@add_start_docstrings( + """ + EfficientFormer Model transformer with an image classification head on top of pooled last hidden state, e.g. for + ImageNet. + """, + EFFICIENTFORMER_START_DOCSTRING, +) +class TFEfficientFormerForImageClassification(TFEfficientFormerPreTrainedModel, TFSequenceClassificationLoss): + def __init__(self, config: EfficientFormerConfig): + super().__init__(config) + + self.num_labels = config.num_labels + self.efficientformer = TFEfficientFormerMainLayer(config, name="efficientformer") + + # Classifier head + self.classifier = ( + tf.keras.layers.Dense(config.num_labels, name="classifier") + if config.num_labels > 0 + else tf.keras.layers.Activation("linear", name="classifier") + ) + + @unpack_inputs + @add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=TFImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def call( + self, + pixel_values: Optional[tf.Tensor] = None, + labels: Optional[tf.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[tf.Tensor, TFImageClassifierOutput]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.efficientformer( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = outputs[0] + + logits = self.classifier(tf.reduce_mean(sequence_output, axis=-2)) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TFImageClassifierOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) + + +@dataclass +class TFEfficientFormerForImageClassificationWithTeacherOutput(ModelOutput): + """ + Args: + Output type of [`EfficientFormerForImageClassificationWithTeacher`]. + logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): + Prediction scores as the average of the cls_logits and distillation logits. + cls_logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): + Prediction scores of the classification head (i.e. the linear layer on top of the final hidden state of the + class token). + distillation_logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): + Prediction scores of the distillation head (i.e. the linear layer on top of the final hidden state of the + distillation token). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus + the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + """ + + logits: tf.Tensor = None + cls_logits: tf.Tensor = None + distillation_logits: tf.Tensor = None + hidden_states: Optional[Tuple[tf.Tensor]] = None + attentions: Optional[Tuple[tf.Tensor]] = None + + +@add_start_docstrings( + """ + EfficientFormer Model transformer with image classification heads on top (a linear layer on top of the final hidden + state and a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet. + + .. warning:: + This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet + supported. + """, + EFFICIENTFORMER_START_DOCSTRING, +) +class TFEfficientFormerForImageClassificationWithTeacher(TFEfficientFormerPreTrainedModel): + def __init__(self, config: EfficientFormerConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.efficientformer = TFEfficientFormerMainLayer(config, name="efficientformer") + + # Classifier heads + self.classifier = ( + tf.keras.layers.Dense(config.num_labels, name="classifier") + if config.num_labels > 0 + else tf.keras.layers.Activation("linear", name="classifier") + ) + self.distillation_classifier = ( + tf.keras.layers.Dense(config.num_labels, name="distillation_classifier") + if config.num_labels > 0 + else tf.keras.layers.Activation("linear", name="distillation_classifier") + ) + + @unpack_inputs + @add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=TFEfficientFormerForImageClassificationWithTeacherOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def call( + self, + pixel_values: Optional[tf.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[tuple, TFEfficientFormerForImageClassificationWithTeacherOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if training: + raise Exception( + "This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet supported." + ) + + outputs = self.efficientformer( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = outputs[0] + + cls_logits = self.classifier(tf.reduce_mean(sequence_output, axis=-2)) + distillation_logits = self.distillation_classifier(tf.reduce_mean(sequence_output, axis=-2)) + logits = (cls_logits + distillation_logits) / 2 + + if not return_dict: + output = (logits, cls_logits, distillation_logits) + outputs[1:] + return output + + return TFEfficientFormerForImageClassificationWithTeacherOutput( + logits=logits, + cls_logits=cls_logits, + distillation_logits=distillation_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 658d7f689fce90..4a189174eeebfa 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -1099,6 +1099,37 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) +TF_EFFICIENTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class TFEfficientFormerForImageClassification(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFEfficientFormerForImageClassificationWithTeacher(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFEfficientFormerModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFEfficientFormerPreTrainedModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/models/efficientformer/test_modeling_efficientformer.py b/tests/models/efficientformer/test_modeling_efficientformer.py index f3e88b1d295980..2d951e4e5b903c 100644 --- a/tests/models/efficientformer/test_modeling_efficientformer.py +++ b/tests/models/efficientformer/test_modeling_efficientformer.py @@ -18,6 +18,7 @@ import inspect import unittest import warnings +from typing import List from transformers import EfficientFormerConfig from transformers.models.auto import get_values @@ -55,15 +56,16 @@ def __init__( self, parent, batch_size: int = 13, - image_size: int = 224, + image_size: int = 64, patch_size: int = 2, - embed_dim: int = 48, # last embed dim of stem + embed_dim: int = 3, num_channels: int = 3, is_training: bool = True, use_labels: bool = True, - hidden_size: int = 448, - num_hidden_layers: int = 7, # For the l1 - num_attention_heads: int = 8, + hidden_size: int = 128, + hidden_sizes=[16, 32, 64, 128], + num_hidden_layers: int = 7, + num_attention_heads: int = 4, intermediate_size: int = 37, hidden_act: str = "gelu", hidden_dropout_prob: float = 0.1, @@ -71,7 +73,11 @@ def __init__( type_sequence_label_size: int = 10, initializer_range: float = 0.02, encoder_stride: int = 2, - num_attention_outputs: int = 1, # For l1 + num_attention_outputs: int = 1, + dim: int = 128, + depths: List[int] = [2, 2, 2, 2], + resolution: int = 2, + mlp_expansion_ratio: int = 2, ): self.parent = parent self.batch_size = batch_size @@ -93,6 +99,11 @@ def __init__( self.num_attention_outputs = num_attention_outputs self.embed_dim = embed_dim self.seq_length = embed_dim + 1 + self.resolution = resolution + self.depths = depths + self.hidden_sizes = hidden_sizes + self.dim = dim + self.mlp_expansion_ratio = mlp_expansion_ratio def prepare_config_and_inputs(self): pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) @@ -119,6 +130,11 @@ def get_config(self): is_decoder=False, initializer_range=self.initializer_range, encoder_stride=self.encoder_stride, + resolution=self.resolution, + depths=self.depths, + hidden_sizes=self.hidden_sizes, + dim=self.dim, + mlp_expansion_ratio=self.mlp_expansion_ratio, ) def create_and_check_model(self, config, pixel_values, labels): @@ -379,6 +395,7 @@ def test_attention_outputs(self): encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) chunk_length = getattr(self.model_tester, "chunk_length", None) + if chunk_length is not None and hasattr(self.model_tester, "num_hashes"): encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes diff --git a/tests/models/efficientformer/test_modeling_tf_efficientformer.py b/tests/models/efficientformer/test_modeling_tf_efficientformer.py new file mode 100644 index 00000000000000..5301aee561b0f0 --- /dev/null +++ b/tests/models/efficientformer/test_modeling_tf_efficientformer.py @@ -0,0 +1,393 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the TensorFlow EfficientFormer model. """ + +import inspect +import unittest +from typing import List + +import numpy as np + +from transformers import EfficientFormerConfig +from transformers.testing_utils import require_tf, require_vision, slow +from transformers.utils import cached_property, is_tf_available, is_vision_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_tf_available(): + import tensorflow as tf + + from transformers import ( + TFEfficientFormerForImageClassification, + TFEfficientFormerForImageClassificationWithTeacher, + TFEfficientFormerModel, + ) + from transformers.models.efficientformer.modeling_tf_efficientformer import ( + TF_EFFICIENTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, + ) + + +if is_vision_available(): + from PIL import Image + + from transformers import EfficientFormerImageProcessor + + +class TFEfficientFormerModelTester: + def __init__( + self, + parent, + batch_size: int = 13, + image_size: int = 64, + patch_size: int = 2, + embed_dim: int = 3, + num_channels: int = 3, + is_training: bool = True, + use_labels: bool = True, + hidden_size: int = 128, + hidden_sizes=[16, 32, 64, 128], + num_hidden_layers: int = 7, + num_attention_heads: int = 4, + intermediate_size: int = 37, + hidden_act: str = "gelu", + hidden_dropout_prob: float = 0.1, + attention_probs_dropout_prob: float = 0.1, + type_sequence_label_size: int = 10, + initializer_range: float = 0.02, + encoder_stride: int = 2, + num_attention_outputs: int = 1, + dim: int = 128, + depths: List[int] = [2, 2, 2, 2], + resolution: int = 2, + mlp_expansion_ratio: int = 2, + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.is_training = is_training + self.use_labels = use_labels + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.encoder_stride = encoder_stride + self.num_attention_outputs = num_attention_outputs + self.embed_dim = embed_dim + self.seq_length = embed_dim + 1 + self.resolution = resolution + self.depths = depths + self.hidden_sizes = hidden_sizes + self.dim = dim + self.mlp_expansion_ratio = mlp_expansion_ratio + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + labels = None + if self.use_labels: + labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + + config = self.get_config() + + return config, pixel_values, labels + + def get_config(self): + return EfficientFormerConfig( + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.num_channels, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + is_decoder=False, + initializer_range=self.initializer_range, + encoder_stride=self.encoder_stride, + resolution=self.resolution, + depths=self.depths, + hidden_sizes=self.hidden_sizes, + dim=self.dim, + mlp_expansion_ratio=self.mlp_expansion_ratio, + ) + + def create_and_check_model(self, config, pixel_values, labels): + model = TFEfficientFormerModel(config=config) + result = model(pixel_values, training=False) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_for_image_classification(self, config, pixel_values, labels): + config.num_labels = self.type_sequence_label_size + model = TFEfficientFormerForImageClassification(config) + result = model(pixel_values, labels=labels, training=False) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) + + # test greyscale images + config.num_channels = 1 + model = TFEfficientFormerForImageClassification(config) + + pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) + result = model(pixel_values, labels=labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values, labels = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_tf +class TFEfficientFormerModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_tf_common.py, as EfficientFormer does not use input_ids, + inputs_embeds, attention_mask and seq_length. + """ + + all_model_classes = ( + ( + TFEfficientFormerModel, + TFEfficientFormerForImageClassificationWithTeacher, + TFEfficientFormerForImageClassification, + ) + if is_tf_available() + else () + ) + pipeline_model_mapping = ( + { + "feature-extraction": TFEfficientFormerModel, + "image-classification": ( + TFEfficientFormerForImageClassification, + TFEfficientFormerForImageClassificationWithTeacher, + ), + } + if is_tf_available() + else {} + ) + + fx_compatible = False + + test_pruning = False + test_resize_embeddings = False + test_head_masking = False + test_onnx = False + + def setUp(self): + self.model_tester = TFEfficientFormerModelTester(self) + self.config_tester = ConfigTester( + self, config_class=EfficientFormerConfig, has_text_modality=False, hidden_size=37 + ) + + def test_config(self): + self.config_tester.run_common_tests() + + @unittest.skip(reason="EfficientFormer does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="EfficientFormer does not support input and output embeddings") + def test_model_common_attributes(self): + pass + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.call) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + + outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False) + hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1 + ) + self.assertEqual(len(hidden_states), expected_num_layers) + + if hasattr(self.model_tester, "encoder_seq_length"): + seq_length = self.model_tester.encoder_seq_length + if hasattr(self.model_tester, "chunk_length") and self.model_tester.chunk_length > 1: + seq_length = seq_length * self.model_tester.chunk_length + else: + seq_length = self.model_tester.seq_length + + self.assertListEqual( + list(hidden_states[-1].shape[-2:]), + [seq_length, self.model_tester.hidden_size], + ) + + if config.is_encoder_decoder: + hidden_states = outputs.decoder_hidden_states + + self.asseretIsInstance(hidden_states, (list, tuple)) + self.assertEqual(len(hidden_states), expected_num_layers) + seq_len = getattr(self.model_tester, "seq_length", None) + decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len) + + self.assertListEqual( + list(hidden_states[-1].shape[-2:]), + [decoder_seq_length, self.model_tester.hidden_size], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + if model_class.__name__ == "TFEfficientFormerForImageClassificationWithTeacher": + del inputs_dict["labels"] + + return inputs_dict + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skip(reason="EfficientFormer does not implement masked image modeling yet") + def test_for_masked_image_modeling(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs) + + def test_for_image_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_image_classification(*config_and_inputs) + + @slow + def test_model_from_pretrained(self): + for model_name in TF_EFFICIENTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = TFEfficientFormerModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + seq_len = getattr(self.model_tester, "seq_length", None) + encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) + encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) + chunk_length = getattr(self.model_tester, "chunk_length", None) + + if chunk_length is not None and hasattr(self.model_tester, "num_hashes"): + encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + + outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_attention_outputs) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False) + + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_attention_outputs) + + if chunk_length is not None: + self.assertListEqual( + list(attentions[0].shape[-4:]), + [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length], + ) + else: + self.assertListEqual( + list(attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + + +# We will verify our results on an image of cute cats +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image + + +@require_tf +@require_vision +class EfficientFormerModelIntegrationTest(unittest.TestCase): + @cached_property + def default_image_processor(self): + return ( + EfficientFormerImageProcessor.from_pretrained("snap-research/efficientformer-l1-300") + if is_vision_available() + else None + ) + + @slow + def test_inference_image_classification_head(self): + model = TFEfficientFormerForImageClassification.from_pretrained("snap-research/efficientformer-l1-300") + image_processor = self.default_image_processor + image = prepare_img() + inputs = image_processor(images=image, return_tensors="tf") + # forward pass + outputs = model(**inputs, training=False) + # verify the logits + expected_shape = tf.TensorShape((1, 1000)) + self.assertEqual(outputs.logits.shape, expected_shape) + expected_slice = tf.constant([-0.0555, 0.4825, -0.0852]) + self.assertTrue(np.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) + + @slow + def test_inference_image_classification_head_with_teacher(self): + model = TFEfficientFormerForImageClassificationWithTeacher.from_pretrained( + "snap-research/efficientformer-l1-300" + ) + image_processor = self.default_image_processor + image = prepare_img() + inputs = image_processor(images=image, return_tensors="tf") + # forward pass + outputs = model(**inputs, training=False) + # verify the logits + expected_shape = tf.TensorShape((1, 1000)) + self.assertEqual(outputs.logits.shape, expected_shape) + expected_slice = tf.constant([-0.1312, 0.4353, -1.0499]) + self.assertTrue(np.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt index d3143b2167bf9f..5cdded64bbe8fb 100644 --- a/utils/documentation_tests.txt +++ b/utils/documentation_tests.txt @@ -82,6 +82,7 @@ src/transformers/models/dpt/modeling_dpt.py src/transformers/models/electra/configuration_electra.py src/transformers/models/electra/modeling_electra.py src/transformers/models/electra/modeling_tf_electra.py +src/transformers/models/efficientformer/modeling_tf_efficientformer.py src/transformers/models/ernie/configuration_ernie.py src/transformers/models/ernie_m/configuration_ernie_m.py src/transformers/models/ernie_m/modeling_ernie_m.py