Skip to content

Commit

Permalink
Add TF ResNet model (#17427)
Browse files Browse the repository at this point in the history
* Rought TF conversion outline

* Tidy up

* Fix padding differences between layers

* Add back embedder - whoops

* Match test file to main

* Match upstream test file

* Correctly pass and assign image_size parameter

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Add in MainLayer

* Correctly name layer

* Tidy up AdaptivePooler

* Small tidy-up

More accurate type hints and remove whitespaces

* Change AdaptiveAvgPool

Use the AdaptiveAvgPool implementation by @Rocketknight1, which correctly pools if the output shape does not evenly divide by input shape c.f. https://github.com/huggingface/transformers/pull/17554/files/9e26607e22aa8d069c86b50196656012ff0ce62a#r900109509

Co-authored-by: From: matt <rocketknight1@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Use updated AdaptiveAvgPool

Co-authored-by: matt <rocketknight1@gmail.com>

* Make AdaptiveAvgPool compatible with CPU

* Remove image_size from configuration

* Fixup

* Tensorflow -> TensorFlow

* Fix pt references in tests

* Apply suggestions from code review - grammar and wording

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Add TFResNet to doc tests

* PR comments - GlobalAveragePooling and clearer comments

* Remove unused import

* Add in keepdims argument

* Add num_channels check

* grammar fix: by -> of

Co-authored-by: matt <rocketknight1@gmail.com>

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Remove transposes - keep NHWC throughout forward pass

* Fixup look sharp

* Add missing layer names

* Final tidy up - remove from_pt now weights on hub

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: matt <rocketknight1@gmail.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
  • Loading branch information
5 people authored Jul 4, 2022
1 parent 7b18702 commit 77ea513
Show file tree
Hide file tree
Showing 10 changed files with 818 additions and 8 deletions.
2 changes: 1 addition & 1 deletion docs/source/en/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ Flax), PyTorch, and/or TensorFlow.
| Reformer | | | | | |
| RegNet | | | | | |
| RemBERT | | | | | |
| ResNet | | | | | |
| ResNet | | | | | |
| RetriBERT | | | | | |
| RoBERTa | | | | | |
| RoFormer | | | | | |
Expand Down
16 changes: 14 additions & 2 deletions docs/source/en/model_doc/resnet.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ The figure below illustrates the architecture of ResNet. Taken from the [origina

<img width="600" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/resnet_architecture.png"/>

This model was contributed by [Francesco](https://huggingface.co/Francesco). The original code can be found [here](https://github.com/KaimingHe/deep-residual-networks).
This model was contributed by [Francesco](https://huggingface.co/Francesco). The TensorFlow version of this model was added by [amyeroberts](https://huggingface.co/amyeroberts). The original code can be found [here](https://github.com/KaimingHe/deep-residual-networks).

## ResNetConfig

Expand All @@ -47,4 +47,16 @@ This model was contributed by [Francesco](https://huggingface.co/Francesco). The
## ResNetForImageClassification

[[autodoc]] ResNetForImageClassification
- forward
- forward


## TFResNetModel

[[autodoc]] TFResNetModel
- call


## TFResNetForImageClassification

[[autodoc]] TFResNetForImageClassification
- call
14 changes: 14 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2380,6 +2380,14 @@
"TFRemBertPreTrainedModel",
]
)
_import_structure["models.resnet"].extend(
[
"TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFResNetForImageClassification",
"TFResNetModel",
"TFResNetPreTrainedModel",
]
)
_import_structure["models.roberta"].extend(
[
"TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
Expand Down Expand Up @@ -4721,6 +4729,12 @@
TFRemBertModel,
TFRemBertPreTrainedModel,
)
from .models.resnet import (
TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST,
TFResNetForImageClassification,
TFResNetModel,
TFResNetPreTrainedModel,
)
from .models.roberta import (
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRobertaForCausalLM,
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/modeling_tf_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class TFBaseModelOutputWithNoAttention(ModelOutput):
"""

last_hidden_state: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
hidden_states: Optional[Tuple[tf.Tensor, ...]] = None


@dataclass
Expand Down Expand Up @@ -118,7 +118,7 @@ class TFBaseModelOutputWithPoolingAndNoAttention(ModelOutput):

last_hidden_state: tf.Tensor = None
pooler_output: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
hidden_states: Optional[Tuple[tf.Tensor, ...]] = None


@dataclass
Expand Down Expand Up @@ -886,4 +886,4 @@ class TFImageClassifierOutputWithNoAttention(ModelOutput):

loss: Optional[tf.Tensor] = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
hidden_states: Optional[Tuple[tf.Tensor, ...]] = None
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_tf_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
("pegasus", "TFPegasusModel"),
("regnet", "TFRegNetModel"),
("rembert", "TFRemBertModel"),
("resnet", "TFResNetModel"),
("roberta", "TFRobertaModel"),
("roformer", "TFRoFormerModel"),
("speech_to_text", "TFSpeech2TextModel"),
Expand Down Expand Up @@ -175,6 +176,7 @@
("convnext", "TFConvNextForImageClassification"),
("data2vec-vision", "TFData2VecVisionForImageClassification"),
("regnet", "TFRegNetForImageClassification"),
("resnet", "TFResNetForImageClassification"),
("swin", "TFSwinForImageClassification"),
("vit", "TFViTForImageClassification"),
]
Expand Down
28 changes: 27 additions & 1 deletion src/transformers/models/resnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from typing import TYPE_CHECKING

# rely on isort to merge the imports
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available


_import_structure = {
Expand All @@ -38,6 +38,19 @@
"ResNetPreTrainedModel",
]

try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_resnet"] = [
"TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFResNetForImageClassification",
"TFResNetModel",
"TFResNetPreTrainedModel",
]


if TYPE_CHECKING:
from .configuration_resnet import RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ResNetConfig, ResNetOnnxConfig
Expand All @@ -55,6 +68,19 @@
ResNetPreTrainedModel,
)

try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_resnet import (
TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST,
TFResNetForImageClassification,
TFResNetModel,
TFResNetPreTrainedModel,
)


else:
import sys
Expand Down
Loading

0 comments on commit 77ea513

Please sign in to comment.