Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TensorFlow implementation of EfficientFormer #22620

Merged
merged 32 commits into from
May 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
1ae220c
Add tf code for efficientformer
D-Roberts Apr 5, 2023
9bc3b65
Fix return dict bug - return last hidden state after last stage
D-Roberts May 13, 2023
3d0e0d8
Fix corresponding return dict bug
D-Roberts May 13, 2023
fc7c593
Override test tol
D-Roberts May 13, 2023
73fa8a1
Change default values of training to False
D-Roberts May 14, 2023
957089e
Set training to default False X3
D-Roberts May 14, 2023
d01293a
Rm axis from ln
D-Roberts May 14, 2023
5fc79df
Set init in dense projection
D-Roberts May 14, 2023
3a300e3
Rm debug stuff
D-Roberts May 14, 2023
2a08237
Make style; all tests pass.
D-Roberts May 14, 2023
6b00027
Modify year to 2023
D-Roberts May 14, 2023
b02b574
Fix attention biases codes
D-Roberts May 20, 2023
5e266f6
Update the shape list logic
D-Roberts May 20, 2023
6f95124
Add a batch norm eps config
D-Roberts May 20, 2023
3f13aa1
Remove extract comments in test files
D-Roberts May 20, 2023
427d02e
Add conditional attn and hidden states return for serving output
D-Roberts May 20, 2023
2827456
Change channel dim checking logic
D-Roberts May 20, 2023
046a2eb
Add exception for withteacher model in training mode
D-Roberts May 20, 2023
c6e56dd
Revert layer count for now
D-Roberts May 20, 2023
919eaea
Add layer count for conditional layer naming
D-Roberts May 21, 2023
683afc9
Transpose for conv happens only in main layer
D-Roberts May 21, 2023
ef03fe5
Make tests smaller
D-Roberts May 21, 2023
289f0e8
Make style
D-Roberts May 21, 2023
7b6aaf8
Update doc
D-Roberts May 23, 2023
e7232d9
Rm from_pt
D-Roberts May 25, 2023
812218a
Change to actual expect image class label
D-Roberts May 26, 2023
25e0b77
Remove stray print in tests
D-Roberts May 29, 2023
f145693
Update image processor test
D-Roberts May 29, 2023
56436b9
Remove the old serving output logic
D-Roberts May 29, 2023
c49d0f8
Make style
D-Roberts May 29, 2023
2319363
Make style
D-Roberts May 29, 2023
a2b9995
Complete test
D-Roberts May 29, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/en/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ Flax), PyTorch, and/or TensorFlow.
| DonutSwin | ❌ | ❌ | ✅ | ❌ | ❌ |
| DPR | ✅ | ✅ | ✅ | ✅ | ❌ |
| DPT | ❌ | ❌ | ✅ | ❌ | ❌ |
| EfficientFormer | ❌ | ❌ | ✅ | | ❌ |
| EfficientFormer | ❌ | ❌ | ✅ | | ❌ |
| EfficientNet | ❌ | ❌ | ✅ | ❌ | ❌ |
| ELECTRA | ✅ | ✅ | ✅ | ✅ | ✅ |
| Encoder decoder | ❌ | ❌ | ✅ | ✅ | ✅ |
Expand Down
17 changes: 16 additions & 1 deletion docs/source/en/model_doc/efficientformer.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ EfficientFormer-L7, obtains 83.3% accuracy with only 7.0 ms latency. Our work pr
reach extremely low latency on mobile devices while maintaining high performance.*

This model was contributed by [novice03](https://huggingface.co/novice03) and [Bearnardd](https://huggingface.co/Bearnardd).
The original code can be found [here](https://github.com/snap-research/EfficientFormer).
The original code can be found [here](https://github.com/snap-research/EfficientFormer). The TensorFlow version of this model was added by [D-Roberts](https://huggingface.co/D-Roberts).

## Documentation resources

Expand Down Expand Up @@ -66,3 +66,18 @@ The original code can be found [here](https://github.com/snap-research/Efficient

[[autodoc]] EfficientFormerForImageClassificationWithTeacher
- forward

## TFEfficientFormerModel

[[autodoc]] TFEfficientFormerModel
- call

## TFEfficientFormerForImageClassification

[[autodoc]] TFEfficientFormerForImageClassification
- call

## TFEfficientFormerForImageClassificationWithTeacher

[[autodoc]] TFEfficientFormerForImageClassificationWithTeacher
- call
16 changes: 16 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3142,6 +3142,15 @@
"TFDPRReader",
]
)
_import_structure["models.efficientformer"].extend(
[
"TF_EFFICIENTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFEfficientFormerForImageClassification",
"TFEfficientFormerForImageClassificationWithTeacher",
"TFEfficientFormerModel",
"TFEfficientFormerPreTrainedModel",
]
)
_import_structure["models.electra"].extend(
[
"TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST",
Expand Down Expand Up @@ -6471,6 +6480,13 @@
TFDPRQuestionEncoder,
TFDPRReader,
)
from .models.efficientformer import (
TF_EFFICIENTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFEfficientFormerForImageClassification,
TFEfficientFormerForImageClassificationWithTeacher,
TFEfficientFormerModel,
TFEfficientFormerPreTrainedModel,
)
from .models.electra import (
TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFElectraForMaskedLM,
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/models/auto/modeling_tf_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
("deit", "TFDeiTModel"),
("distilbert", "TFDistilBertModel"),
("dpr", "TFDPRQuestionEncoder"),
("efficientformer", "TFEfficientFormerModel"),
("electra", "TFElectraModel"),
("esm", "TFEsmModel"),
("flaubert", "TFFlaubertModel"),
Expand Down Expand Up @@ -202,6 +203,10 @@
("cvt", "TFCvtForImageClassification"),
("data2vec-vision", "TFData2VecVisionForImageClassification"),
("deit", ("TFDeiTForImageClassification", "TFDeiTForImageClassificationWithTeacher")),
(
"efficientformer",
("TFEfficientFormerForImageClassification", "TFEfficientFormerForImageClassificationWithTeacher"),
),
("mobilevit", "TFMobileViTForImageClassification"),
("regnet", "TFRegNetForImageClassification"),
("resnet", "TFResNetForImageClassification"),
Expand Down
35 changes: 34 additions & 1 deletion src/transformers/models/efficientformer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_tf_available,
is_torch_available,
is_vision_available,
)


_import_structure = {
Expand Down Expand Up @@ -45,6 +51,20 @@
"EfficientFormerPreTrainedModel",
]

try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_efficientformer"] = [
"TF_EFFICIENTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFEfficientFormerForImageClassification",
"TFEfficientFormerForImageClassificationWithTeacher",
"TFEfficientFormerModel",
"TFEfficientFormerPreTrainedModel",
]

if TYPE_CHECKING:
from .configuration_efficientformer import EFFICIENTFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, EfficientFormerConfig

Expand All @@ -69,6 +89,19 @@
EfficientFormerModel,
EfficientFormerPreTrainedModel,
)
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_efficientformer import (
TF_EFFICIENTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFEfficientFormerForImageClassification,
TFEfficientFormerForImageClassificationWithTeacher,
TFEfficientFormerModel,
TFEfficientFormerPreTrainedModel,
)

else:
import sys
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class EfficientFormerConfig(PretrainedConfig):
The size of the key in meta3D block.
attention_ratio (`int`, *optional*, defaults to 4):
Ratio of the dimension of the query and value to the dimension of the key in MSHA block
resolution (`int`, *optional*, defaults to 5)
resolution (`int`, *optional*, defaults to 7)
D-Roberts marked this conversation as resolved.
Show resolved Hide resolved
Size of each patch
num_hidden_layers (`int`, *optional*, defaults to 5):
Number of hidden layers in the Transformer encoder.
Expand Down Expand Up @@ -91,6 +91,8 @@ class EfficientFormerConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers.
image_size (`int`, *optional*, defaults to `224`):
The size (resolution) of each image.

Example:

Expand Down Expand Up @@ -136,6 +138,8 @@ def __init__(
hidden_act: str = "gelu",
initializer_range: float = 0.02,
layer_norm_eps: float = 1e-12,
image_size: int = 224,
batch_norm_eps: float = 1e-05,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand Down Expand Up @@ -165,3 +169,5 @@ def __init__(
self.distillation = distillation
self.use_layer_scale = use_layer_scale
self.layer_scale_init_value = layer_scale_init_value
self.image_size = image_size
self.batch_norm_eps = batch_norm_eps
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@

# Base docstring
_CHECKPOINT_FOR_DOC = "snap-research/efficientformer-l1-300"
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
_EXPECTED_OUTPUT_SHAPE = [1, 49, 448]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "snap-research/efficientformer-l1-300"
Expand Down Expand Up @@ -73,7 +73,7 @@ def __init__(self, config: EfficientFormerConfig, num_channels: int, embed_dim:
stride=config.downsample_stride,
padding=config.downsample_pad,
)
self.norm = nn.BatchNorm2d(embed_dim) if apply_norm else nn.Identity()
self.norm = nn.BatchNorm2d(embed_dim, eps=config.batch_norm_eps) if apply_norm else nn.Identity()

def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
Expand Down Expand Up @@ -157,10 +157,10 @@ def __init__(self, config: EfficientFormerConfig, out_channels: int):
super().__init__()

self.convolution1 = nn.Conv2d(config.num_channels, out_channels // 2, kernel_size=3, stride=2, padding=1)
self.batchnorm_before = nn.BatchNorm2d(out_channels // 2)
self.batchnorm_before = nn.BatchNorm2d(out_channels // 2, eps=config.batch_norm_eps)

self.convolution2 = nn.Conv2d(out_channels // 2, out_channels, kernel_size=3, stride=2, padding=1)
self.batchnorm_after = nn.BatchNorm2d(out_channels)
self.batchnorm_after = nn.BatchNorm2d(out_channels, eps=config.batch_norm_eps)

self.activation = nn.ReLU()

Expand Down Expand Up @@ -224,24 +224,24 @@ def __init__(
hidden_features = hidden_features or in_features

self.convolution1 = nn.Conv2d(in_features, hidden_features, 1)
self.actvation = ACT2FN[config.hidden_act]
self.activation = ACT2FN[config.hidden_act]
self.convolution2 = nn.Conv2d(hidden_features, out_features, 1)
self.dropout = nn.Dropout(drop)

self.batchnorm_before = nn.BatchNorm2d(hidden_features)
self.batchnorm_after = nn.BatchNorm2d(out_features)
self.batchnorm_before = nn.BatchNorm2d(hidden_features, eps=config.batch_norm_eps)
self.batchnorm_after = nn.BatchNorm2d(out_features, eps=config.batch_norm_eps)

def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
hidden_state = self.convolution1(hidden_state)
hidden_state = self.batchnorm_before(hidden_state)

hidden_state = self.actvation(hidden_state)
hidden_state = self.activation(hidden_state)
hidden_state = self.dropout(hidden_state)
hidden_state = self.convolution2(hidden_state)

hidden_state = self.batchnorm_after(hidden_state)

hidden_state = self.dropout(hidden_state)

return hidden_state


Expand All @@ -266,7 +266,7 @@ def drop_path(input, drop_prob: float = 0.0, training: bool = False):
return output


# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Bit
# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->EfficientFormer
class EfficientFormerDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

Expand Down Expand Up @@ -301,8 +301,10 @@ def __init__(self, config: EfficientFormerConfig, dim: int, drop_path: float = 0
attention_ratio=config.attention_ratio,
resolution=config.resolution,
)
self.layernorm1 = nn.LayerNorm(dim)
self.layernorm2 = nn.LayerNorm(dim)

self.layernorm1 = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.layernorm2 = nn.LayerNorm(dim, eps=config.layer_norm_eps)

mlp_hidden_dim = int(dim * config.mlp_expansion_ratio)
self.mlp = EfficientFormerDenseMlp(config, in_features=dim, hidden_features=mlp_hidden_dim)

Expand Down Expand Up @@ -346,15 +348,20 @@ def __init__(self, config: EfficientFormerConfig):

def forward(self, hidden_states: torch.Tensor, output_attentions: bool = False) -> Tuple[torch.Tensor]:
all_attention_outputs = () if output_attentions else None

for layer_module in self.blocks:
if isinstance(hidden_states, tuple):
hidden_states = hidden_states[0]

hidden_states = layer_module(hidden_states, output_attentions)

if output_attentions:
all_attention_outputs = all_attention_outputs + (hidden_states[1],)

if output_attentions:
outputs = (hidden_states[0],) + all_attention_outputs
return outputs

return hidden_states


Expand All @@ -379,6 +386,7 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]:

if self.use_layer_scale:
layer_output = hidden_states + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * outputs)

layer_output = layer_output + self.drop_path(
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(layer_output)
)
Expand All @@ -398,6 +406,7 @@ def __init__(self, config: EfficientFormerConfig, stage_idx: int):
drop_paths = [
config.drop_path_rate * (block_idx + sum(config.depths[:stage_idx])) for block_idx in range(num_layers)
]

self.blocks = nn.ModuleList(
[
EfficientFormerMeta4D(config, config.hidden_sizes[stage_idx], drop_path=drop_path)
Expand Down Expand Up @@ -446,6 +455,7 @@ def __init__(self, config: EfficientFormerConfig):
for i in range(num_intermediate_stages)
]
intermediate_stages = []

for i in range(num_intermediate_stages):
intermediate_stages.append(EfficientFormerIntermediateStage(config, i))
if downsamples[i]:
Expand Down Expand Up @@ -475,14 +485,15 @@ def forward(
all_hidden_states = all_hidden_states + (hidden_states,)

layer_output = self.last_stage(hidden_states, output_attentions=output_attentions)

if output_attentions:
all_self_attentions = all_self_attentions + layer_output[1:]

if output_hidden_states:
all_hidden_states = all_hidden_states + (layer_output[0],)

if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return tuple(v for v in [layer_output[0], all_hidden_states, all_self_attentions] if v is not None)
D-Roberts marked this conversation as resolved.
Show resolved Hide resolved

return BaseModelOutput(
last_hidden_state=layer_output[0],
Expand Down
Loading