diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 1346414bed97..1bfb8cfa4cc9 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -271,7 +271,7 @@ Flax), PyTorch, and/or TensorFlow. | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | | RegNet | ❌ | ❌ | ✅ | ✅ | ❌ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | -| ResNet | ❌ | ❌ | ✅ | ❌ | ❌ | +| ResNet | ❌ | ❌ | ✅ | ✅ | ❌ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | | RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ | | RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ | diff --git a/docs/source/en/model_doc/resnet.mdx b/docs/source/en/model_doc/resnet.mdx index 88131c24ba1e..3c8af6227d19 100644 --- a/docs/source/en/model_doc/resnet.mdx +++ b/docs/source/en/model_doc/resnet.mdx @@ -31,7 +31,7 @@ The figure below illustrates the architecture of ResNet. Taken from the [origina -This model was contributed by [Francesco](https://huggingface.co/Francesco). The original code can be found [here](https://github.com/KaimingHe/deep-residual-networks). +This model was contributed by [Francesco](https://huggingface.co/Francesco). The TensorFlow version of this model was added by [amyeroberts](https://huggingface.co/amyeroberts). The original code can be found [here](https://github.com/KaimingHe/deep-residual-networks). ## ResNetConfig @@ -47,4 +47,16 @@ This model was contributed by [Francesco](https://huggingface.co/Francesco). The ## ResNetForImageClassification [[autodoc]] ResNetForImageClassification - - forward \ No newline at end of file + - forward + + +## TFResNetModel + +[[autodoc]] TFResNetModel + - call + + +## TFResNetForImageClassification + +[[autodoc]] TFResNetForImageClassification + - call diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index dabc7e2ea81e..d30cab3a5649 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2369,6 +2369,14 @@ "TFRemBertPreTrainedModel", ] ) + _import_structure["models.resnet"].extend( + [ + "TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFResNetForImageClassification", + "TFResNetModel", + "TFResNetPreTrainedModel", + ] + ) _import_structure["models.roberta"].extend( [ "TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -4699,6 +4707,12 @@ TFRemBertModel, TFRemBertPreTrainedModel, ) + from .models.resnet import ( + TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST, + TFResNetForImageClassification, + TFResNetModel, + TFResNetPreTrainedModel, + ) from .models.roberta import ( TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, TFRobertaForCausalLM, diff --git a/src/transformers/modeling_tf_outputs.py b/src/transformers/modeling_tf_outputs.py index 5aedace141c1..30277f925923 100644 --- a/src/transformers/modeling_tf_outputs.py +++ b/src/transformers/modeling_tf_outputs.py @@ -62,7 +62,7 @@ class TFBaseModelOutputWithNoAttention(ModelOutput): """ last_hidden_state: tf.Tensor = None - hidden_states: Optional[Tuple[tf.Tensor]] = None + hidden_states: Optional[Tuple[tf.Tensor, ...]] = None @dataclass @@ -118,7 +118,7 @@ class TFBaseModelOutputWithPoolingAndNoAttention(ModelOutput): last_hidden_state: tf.Tensor = None pooler_output: tf.Tensor = None - hidden_states: Optional[Tuple[tf.Tensor]] = None + hidden_states: Optional[Tuple[tf.Tensor, ...]] = None @dataclass @@ -886,4 +886,4 @@ class TFImageClassifierOutputWithNoAttention(ModelOutput): loss: Optional[tf.Tensor] = None logits: tf.Tensor = None - hidden_states: Optional[Tuple[tf.Tensor]] = None + hidden_states: Optional[Tuple[tf.Tensor, ...]] = None diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index aaf91bdb24ec..530de3b422e4 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -64,6 +64,7 @@ ("pegasus", "TFPegasusModel"), ("regnet", "TFRegNetModel"), ("rembert", "TFRemBertModel"), + ("resnet", "TFResNetModel"), ("roberta", "TFRobertaModel"), ("roformer", "TFRoFormerModel"), ("speech_to_text", "TFSpeech2TextModel"), @@ -175,6 +176,7 @@ ("convnext", "TFConvNextForImageClassification"), ("data2vec-vision", "TFData2VecVisionForImageClassification"), ("regnet", "TFRegNetForImageClassification"), + ("resnet", "TFResNetForImageClassification"), ("swin", "TFSwinForImageClassification"), ("vit", "TFViTForImageClassification"), ] diff --git a/src/transformers/models/resnet/__init__.py b/src/transformers/models/resnet/__init__.py index e1c0a9ec84d6..f62c2999671d 100644 --- a/src/transformers/models/resnet/__init__.py +++ b/src/transformers/models/resnet/__init__.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING # rely on isort to merge the imports -from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available _import_structure = { @@ -38,6 +38,19 @@ "ResNetPreTrainedModel", ] +try: + if not is_tf_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_tf_resnet"] = [ + "TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFResNetForImageClassification", + "TFResNetModel", + "TFResNetPreTrainedModel", + ] + if TYPE_CHECKING: from .configuration_resnet import RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ResNetConfig, ResNetOnnxConfig @@ -55,6 +68,19 @@ ResNetPreTrainedModel, ) + try: + if not is_tf_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_tf_resnet import ( + TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST, + TFResNetForImageClassification, + TFResNetModel, + TFResNetPreTrainedModel, + ) + else: import sys diff --git a/src/transformers/models/resnet/modeling_tf_resnet.py b/src/transformers/models/resnet/modeling_tf_resnet.py new file mode 100644 index 000000000000..c7c6c95fb818 --- /dev/null +++ b/src/transformers/models/resnet/modeling_tf_resnet.py @@ -0,0 +1,479 @@ +# coding=utf-8 +# Copyright 2022 Microsoft Research, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" TensorFlow ResNet model.""" + +from typing import Dict, Optional, Tuple, Union + +import tensorflow as tf + +from ...activations_tf import ACT2FN +from ...modeling_tf_outputs import ( + TFBaseModelOutputWithNoAttention, + TFBaseModelOutputWithPoolingAndNoAttention, + TFImageClassifierOutputWithNoAttention, +) +from ...modeling_tf_utils import TFPreTrainedModel, TFSequenceClassificationLoss, keras_serializable, unpack_inputs +from ...tf_utils import shape_list +from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_resnet import ResNetConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "ResNetConfig" +_FEAT_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor" + +# Base docstring +_CHECKPOINT_FOR_DOC = "microsoft/resnet-50" +_EXPECTED_OUTPUT_SHAPE = [1, 2048, 7, 7] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "microsoft/resnet-50" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tiger cat" + +TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "microsoft/resnet-50", + # See all resnet models at https://huggingface.co/models?filter=resnet +] + + +class TFResNetConvLayer(tf.keras.layers.Layer): + def __init__( + self, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: str = "relu", **kwargs + ) -> None: + super().__init__(**kwargs) + self.pad_value = kernel_size // 2 + self.conv = tf.keras.layers.Conv2D( + out_channels, kernel_size=kernel_size, strides=stride, padding="valid", use_bias=False, name="convolution" + ) + # Use same default momentum and epsilon as PyTorch equivalent + self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization") + self.activation = ACT2FN[activation] if activation is not None else tf.keras.layers.Activation("linear") + + def convolution(self, hidden_state: tf.Tensor) -> tf.Tensor: + # Pad to match that done in the PyTorch Conv2D model + height_pad = width_pad = (self.pad_value, self.pad_value) + hidden_state = tf.pad(hidden_state, [(0, 0), height_pad, width_pad, (0, 0)]) + hidden_state = self.conv(hidden_state) + return hidden_state + + def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_state = self.convolution(hidden_state) + hidden_state = self.normalization(hidden_state, training=training) + hidden_state = self.activation(hidden_state) + return hidden_state + + +class TFResNetEmbeddings(tf.keras.layers.Layer): + """ + ResNet Embeddings (stem) composed of a single aggressive convolution. + """ + + def __init__(self, config: ResNetConfig, **kwargs) -> None: + super().__init__(**kwargs) + self.embedder = TFResNetConvLayer( + config.embedding_size, + kernel_size=7, + stride=2, + activation=config.hidden_act, + name="embedder", + ) + self.pooler = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding="valid", name="pooler") + self.num_channels = config.num_channels + + def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor: + _, _, _, num_channels = shape_list(pixel_values) + if tf.executing_eagerly() and num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + hidden_state = pixel_values + hidden_state = self.embedder(hidden_state) + hidden_state = tf.pad(hidden_state, [[0, 0], [1, 1], [1, 1], [0, 0]]) + hidden_state = self.pooler(hidden_state) + return hidden_state + + +class TFResNetShortCut(tf.keras.layers.Layer): + """ + ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to + downsample the input using `stride=2`. + """ + + def __init__(self, out_channels: int, stride: int = 2, **kwargs) -> None: + super().__init__(**kwargs) + self.convolution = tf.keras.layers.Conv2D( + out_channels, kernel_size=1, strides=stride, use_bias=False, name="convolution" + ) + # Use same default momentum and epsilon as PyTorch equivalent + self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization") + + def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_state = x + hidden_state = self.convolution(hidden_state) + hidden_state = self.normalization(hidden_state, training=training) + return hidden_state + + +class TFResNetBasicLayer(tf.keras.layers.Layer): + """ + A classic ResNet's residual layer composed by two `3x3` convolutions. + """ + + def __init__( + self, in_channels: int, out_channels: int, stride: int = 1, activation: str = "relu", **kwargs + ) -> None: + super().__init__(**kwargs) + should_apply_shortcut = in_channels != out_channels or stride != 1 + self.conv1 = TFResNetConvLayer(out_channels, stride=stride, name="layer.0") + self.conv2 = TFResNetConvLayer(out_channels, activation=None, name="layer.1") + self.shortcut = ( + TFResNetShortCut(out_channels, stride=stride, name="shortcut") + if should_apply_shortcut + else tf.keras.layers.Activation("linear", name="shortcut") + ) + self.activation = ACT2FN[activation] + + def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: + residual = hidden_state + hidden_state = self.conv1(hidden_state, training=training) + hidden_state = self.conv2(hidden_state, training=training) + residual = self.shortcut(residual, training=training) + hidden_state += residual + hidden_state = self.activation(hidden_state) + return hidden_state + + +class TFResNetBottleNeckLayer(tf.keras.layers.Layer): + """ + A classic ResNet's bottleneck layer composed by three `3x3` convolutions. + + The first `1x1` convolution reduces the input by a factor of `reduction` in order to make the second `3x3` + convolution faster. The last `1x1` convolution remaps the reduced features to `out_channels`. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int = 1, + activation: str = "relu", + reduction: int = 4, + **kwargs + ) -> None: + super().__init__(**kwargs) + should_apply_shortcut = in_channels != out_channels or stride != 1 + reduces_channels = out_channels // reduction + self.conv0 = TFResNetConvLayer(reduces_channels, kernel_size=1, name="layer.0") + self.conv1 = TFResNetConvLayer(reduces_channels, stride=stride, name="layer.1") + self.conv2 = TFResNetConvLayer(out_channels, kernel_size=1, activation=None, name="layer.2") + self.shortcut = ( + TFResNetShortCut(out_channels, stride=stride, name="shortcut") + if should_apply_shortcut + else tf.keras.layers.Activation("linear", name="shortcut") + ) + self.activation = ACT2FN[activation] + + def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: + residual = hidden_state + hidden_state = self.conv0(hidden_state, training=training) + hidden_state = self.conv1(hidden_state, training=training) + hidden_state = self.conv2(hidden_state, training=training) + residual = self.shortcut(residual, training=training) + hidden_state += residual + hidden_state = self.activation(hidden_state) + return hidden_state + + +class TFResNetStage(tf.keras.layers.Layer): + """ + A ResNet stage composed of stacked layers. + """ + + def __init__( + self, config: ResNetConfig, in_channels: int, out_channels: int, stride: int = 2, depth: int = 2, **kwargs + ) -> None: + super().__init__(**kwargs) + + layer = TFResNetBottleNeckLayer if config.layer_type == "bottleneck" else TFResNetBasicLayer + + layers = [layer(in_channels, out_channels, stride=stride, activation=config.hidden_act, name="layers.0")] + layers += [ + layer(out_channels, out_channels, activation=config.hidden_act, name=f"layers.{i + 1}") + for i in range(depth - 1) + ] + self.stage_layers = layers + + def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: + for layer in self.stage_layers: + hidden_state = layer(hidden_state, training=training) + return hidden_state + + +class TFResNetEncoder(tf.keras.layers.Layer): + def __init__(self, config: ResNetConfig, **kwargs) -> None: + super().__init__(**kwargs) + # based on `downsample_in_first_stage` the first layer of the first stage may or may not downsample the input + self.stages = [ + TFResNetStage( + config, + config.embedding_size, + config.hidden_sizes[0], + stride=2 if config.downsample_in_first_stage else 1, + depth=config.depths[0], + name="stages.0", + ) + ] + for i, (in_channels, out_channels, depth) in enumerate( + zip(config.hidden_sizes, config.hidden_sizes[1:], config.depths[1:]) + ): + self.stages.append(TFResNetStage(config, in_channels, out_channels, depth=depth, name=f"stages.{i + 1}")) + + def call( + self, + hidden_state: tf.Tensor, + output_hidden_states: bool = False, + return_dict: bool = True, + training: bool = False, + ) -> TFBaseModelOutputWithNoAttention: + hidden_states = () if output_hidden_states else None + + for stage_module in self.stages: + if output_hidden_states: + hidden_states = hidden_states + (hidden_state,) + + hidden_state = stage_module(hidden_state, training=training) + + if output_hidden_states: + hidden_states = hidden_states + (hidden_state,) + + if not return_dict: + return tuple(v for v in [hidden_state, hidden_states] if v is not None) + + return TFBaseModelOutputWithNoAttention( + last_hidden_state=hidden_state, + hidden_states=hidden_states, + ) + + +class TFResNetPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = ResNetConfig + base_model_prefix = "resnet" + main_input_name = "pixel_values" + + @property + def dummy_inputs(self) -> Dict[str, tf.Tensor]: + """ + Dummy inputs to build the network. Returns: + `Dict[str, tf.Tensor]`: The dummy inputs. + """ + VISION_DUMMY_INPUTS = tf.random.uniform(shape=(3, self.config.num_channels, 224, 224), dtype=tf.float32) + return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)} + + +RESNET_START_DOCSTRING = r""" + This model is a TensorFlow + [tf.keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) sub-class. Use it as a + regular TensorFlow Module and refer to the TensorFlow documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`ResNetConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +RESNET_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See + [`AutoFeatureExtractor.__call__`] for details. + + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@keras_serializable +class TFResNetMainLayer(tf.keras.layers.Layer): + config_class = ResNetConfig + + def __init__(self, config: ResNetConfig, **kwargs) -> None: + super().__init__(**kwargs) + self.config = config + self.embedder = TFResNetEmbeddings(config, name="embedder") + self.encoder = TFResNetEncoder(config, name="encoder") + self.pooler = tf.keras.layers.GlobalAveragePooling2D(keepdims=True) + + @unpack_inputs + def call( + self, + pixel_values: tf.Tensor, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[Tuple[tf.Tensor], TFBaseModelOutputWithPoolingAndNoAttention]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # TF 2.0 image layers can't use NCHW format when running on CPU. + # We transpose to NHWC format and then transpose back after the full forward pass. + # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels) + pixel_values = tf.transpose(pixel_values, perm=[0, 2, 3, 1]) + embedding_output = self.embedder(pixel_values, training=training) + + encoder_outputs = self.encoder( + embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training + ) + + last_hidden_state = encoder_outputs[0] + + pooled_output = self.pooler(last_hidden_state) + + # Transpose all the outputs to the NCHW format + # (batch_size, height, width, num_channels) -> (batch_size, num_channels, height, width) + last_hidden_state = tf.transpose(last_hidden_state, (0, 3, 1, 2)) + pooled_output = tf.transpose(pooled_output, (0, 3, 1, 2)) + hidden_states = () + for hidden_state in encoder_outputs[1:]: + hidden_states = hidden_states + tuple(tf.transpose(h, (0, 3, 1, 2)) for h in hidden_state) + + if not return_dict: + return (last_hidden_state, pooled_output) + hidden_states + + hidden_states = hidden_states if output_hidden_states else None + + return TFBaseModelOutputWithPoolingAndNoAttention( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=hidden_states, + ) + + +@add_start_docstrings( + "The bare ResNet model outputting raw features without any specific head on top.", + RESNET_START_DOCSTRING, +) +class TFResNetModel(TFResNetPreTrainedModel): + def __init__(self, config: ResNetConfig, **kwargs) -> None: + super().__init__(config, **kwargs) + self.resnet = TFResNetMainLayer(config=config, name="resnet") + + @add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutputWithPoolingAndNoAttention, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + @unpack_inputs + def call( + self, + pixel_values: tf.Tensor, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[Tuple[tf.Tensor], TFBaseModelOutputWithPoolingAndNoAttention]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + resnet_outputs = self.resnet( + pixel_values=pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + return resnet_outputs + + +@add_start_docstrings( + """ + ResNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for + ImageNet. + """, + RESNET_START_DOCSTRING, +) +class TFResNetForImageClassification(TFResNetPreTrainedModel, TFSequenceClassificationLoss): + def __init__(self, config: ResNetConfig, **kwargs) -> None: + super().__init__(config, **kwargs) + self.num_labels = config.num_labels + self.resnet = TFResNetMainLayer(config, name="resnet") + # classification head + self.classifier_layer = ( + tf.keras.layers.Dense(config.num_labels, name="classifier.1") + if config.num_labels > 0 + else tf.keras.layers.Activation("linear", name="classifier.1") + ) + + def classifier(self, x: tf.Tensor) -> tf.Tensor: + x = tf.keras.layers.Flatten()(x) + logits = self.classifier_layer(x) + return logits + + @add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=TFImageClassifierOutputWithNoAttention, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + @unpack_inputs + def call( + self, + pixel_values: tf.Tensor = None, + labels: tf.Tensor = None, + output_hidden_states: bool = None, + return_dict: bool = None, + training: bool = False, + ) -> Union[Tuple[tf.Tensor], TFImageClassifierOutputWithNoAttention]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.resnet( + pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training + ) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + logits = self.classifier(pooled_output) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return (loss,) + output if loss is not None else output + + return TFImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states) diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 01412ba82efa..6ce1692fb3dc 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -1786,6 +1786,30 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) +TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class TFResNetForImageClassification(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFResNetModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFResNetPreTrainedModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/models/resnet/test_modeling_tf_resnet.py b/tests/models/resnet/test_modeling_tf_resnet.py new file mode 100644 index 000000000000..5f4eead8661c --- /dev/null +++ b/tests/models/resnet/test_modeling_tf_resnet.py @@ -0,0 +1,252 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the Tensorflow ResNet model. """ + + +import inspect +import unittest + +import numpy as np + +from transformers import ResNetConfig +from transformers.testing_utils import require_tf, require_vision, slow +from transformers.utils import cached_property, is_tf_available, is_vision_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor + + +if is_tf_available(): + import tensorflow as tf + + from transformers import TFResNetForImageClassification, TFResNetModel + from transformers.models.resnet.modeling_tf_resnet import TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST + + +if is_vision_available(): + from PIL import Image + + from transformers import AutoFeatureExtractor + + +class ResNetModelTester: + def __init__( + self, + parent, + batch_size=3, + image_size=32, + num_channels=3, + embeddings_size=10, + hidden_sizes=[10, 20, 30, 40], + depths=[1, 1, 2, 1], + is_training=True, + use_labels=True, + hidden_act="relu", + num_labels=3, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.num_channels = num_channels + self.embeddings_size = embeddings_size + self.hidden_sizes = hidden_sizes + self.depths = depths + self.is_training = is_training + self.use_labels = use_labels + self.hidden_act = hidden_act + self.num_labels = num_labels + self.scope = scope + self.num_stages = len(hidden_sizes) + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + labels = None + if self.use_labels: + labels = ids_tensor([self.batch_size], self.num_labels) + + config = self.get_config() + + return config, pixel_values, labels + + def get_config(self): + return ResNetConfig( + num_channels=self.num_channels, + embeddings_size=self.embeddings_size, + hidden_sizes=self.hidden_sizes, + depths=self.depths, + hidden_act=self.hidden_act, + num_labels=self.num_labels, + image_size=self.image_size, + ) + + def create_and_check_model(self, config, pixel_values, labels): + model = TFResNetModel(config=config) + result = model(pixel_values) + # expected last hidden states: B, C, H // 32, W // 32 + self.parent.assertEqual( + result.last_hidden_state.shape, + (self.batch_size, self.hidden_sizes[-1], self.image_size // 32, self.image_size // 32), + ) + + def create_and_check_for_image_classification(self, config, pixel_values, labels): + config.num_labels = self.num_labels + model = TFResNetForImageClassification(config) + result = model(pixel_values, labels=labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values, labels = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_tf +class ResNetModelTest(TFModelTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as ResNet does not use input_ids, inputs_embeds, + attention_mask and seq_length. + """ + + all_model_classes = (TFResNetModel, TFResNetForImageClassification) if is_tf_available() else () + + test_pruning = False + test_resize_embeddings = False + test_head_masking = False + test_onnx = False + has_attentions = False + + def setUp(self): + self.model_tester = ResNetModelTester(self) + self.config_tester = ConfigTester(self, config_class=ResNetConfig, has_text_modality=False) + + def test_config(self): + self.create_and_test_config_common_properties() + self.config_tester.create_and_test_config_to_json_string() + self.config_tester.create_and_test_config_to_json_file() + self.config_tester.create_and_test_config_from_and_save_pretrained() + self.config_tester.create_and_test_config_with_num_labels() + self.config_tester.check_config_can_be_init_without_params() + self.config_tester.check_config_arguments_init() + + def create_and_test_config_common_properties(self): + return + + @unittest.skip(reason="ResNet does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="ResNet does not output attentions") + def test_attention_outputs(self): + pass + + @unittest.skip(reason="ResNet does not support input and output embeddings") + def test_model_common_attributes(self): + pass + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.call) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states + + expected_num_stages = self.model_tester.num_stages + self.assertEqual(len(hidden_states), expected_num_stages + 1) + + # ResNet's feature maps are of shape (batch_size, num_channels, height, width) + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [self.model_tester.image_size // 4, self.model_tester.image_size // 4], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + layers_type = ["basic", "bottleneck"] + for model_class in self.all_model_classes: + for layer_type in layers_type: + config.layer_type = layer_type + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + def test_for_image_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_image_classification(*config_and_inputs) + + @slow + def test_model_from_pretrained(self): + for model_name in TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = TFResNetModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + +# We will verify our results on an image of cute cats +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image + + +@require_tf +@require_vision +class ResNetModelIntegrationTest(unittest.TestCase): + @cached_property + def default_feature_extractor(self): + return ( + AutoFeatureExtractor.from_pretrained(TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST[0]) + if is_vision_available() + else None + ) + + @slow + def test_inference_image_classification_head(self): + model = TFResNetForImageClassification.from_pretrained(TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST[0]) + + feature_extractor = self.default_feature_extractor + image = prepare_img() + inputs = feature_extractor(images=image, return_tensors="tf") + + # forward pass + outputs = model(**inputs) + + # verify the logits + expected_shape = tf.TensorShape((1, 1000)) + self.assertEqual(outputs.logits.shape, expected_shape) + + expected_slice = tf.constant([-11.1069, -9.7877, -8.3777]) + + self.assertTrue(np.allclose(outputs.logits[0, :3].numpy(), expected_slice, atol=1e-4)) diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt index a977f47783f5..b1838b2b8ec2 100644 --- a/utils/documentation_tests.txt +++ b/utils/documentation_tests.txt @@ -53,6 +53,7 @@ src/transformers/models/reformer/modeling_reformer.py src/transformers/models/regnet/modeling_regnet.py src/transformers/models/regnet/modeling_tf_regnet.py src/transformers/models/resnet/modeling_resnet.py +src/transformers/models/resnet/modeling_tf_resnet.py src/transformers/models/roberta/modeling_roberta.py src/transformers/models/roberta/modeling_tf_roberta.py src/transformers/models/segformer/modeling_segformer.py @@ -75,5 +76,5 @@ src/transformers/models/wav2vec2/modeling_wav2vec2.py src/transformers/models/wav2vec2/tokenization_wav2vec2.py src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py -src/transformers/models/wavlm/modeling_wavlm.py +src/transformers/models/wavlm/modeling_wavlm.py src/transformers/models/yolos/modeling_yolos.py