Skip to content

Commit

Permalink
Swin main layer (#17693)
Browse files Browse the repository at this point in the history
* Swin models call TFSwinMainLayer

* Tidy up
  • Loading branch information
amyeroberts authored Jun 14, 2022
1 parent 3960ce9 commit bd43151
Showing 1 changed file with 68 additions and 18 deletions.
86 changes: 68 additions & 18 deletions src/transformers/models/swin/modeling_tf_swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,13 @@
import tensorflow as tf

from ...activations_tf import ACT2FN
from ...modeling_tf_utils import TFPreTrainedModel, TFSequenceClassificationLoss, get_initializer, unpack_inputs
from ...modeling_tf_utils import (
TFPreTrainedModel,
TFSequenceClassificationLoss,
get_initializer,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...utils import (
ModelOutput,
Expand Down Expand Up @@ -1069,15 +1075,14 @@ def get_config(self) -> Dict[str, Any]:
return {**base_config, **config}


@add_start_docstrings(
"The bare Swin Model transformer outputting raw hidden-states without any specific head on top.",
SWIN_START_DOCSTRING,
)
class TFSwinModel(TFSwinPreTrainedModel):
@keras_serializable
class TFSwinMainLayer(tf.keras.layers.Layer):
config_class = SwinConfig

def __init__(
self, config: SwinConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, **kwargs
) -> None:
super().__init__(config, **kwargs)
super().__init__(**kwargs)
self.config = config
self.num_layers = len(config.depths)
self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
Expand All @@ -1104,15 +1109,6 @@ def get_head_mask(self, head_mask: Optional[Any]) -> List:
raise NotImplementedError
return [None] * len(self.config.depths)

@add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFSwinModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
@unpack_inputs
def call(
self,
Expand Down Expand Up @@ -1175,6 +1171,60 @@ def call(
)


@add_start_docstrings(
"The bare Swin Model transformer outputting raw hidden-states without any specific head on top.",
SWIN_START_DOCSTRING,
)
class TFSwinModel(TFSwinPreTrainedModel):
def __init__(
self, config: SwinConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, **kwargs
) -> None:
super().__init__(config, **kwargs)
self.config = config
self.swin = TFSwinMainLayer(config, name="swin")

@add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFSwinModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
@unpack_inputs
def call(
self,
pixel_values: Optional[tf.Tensor] = None,
bool_masked_pos: Optional[tf.Tensor] = None,
head_mask: Optional[tf.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
) -> Union[TFSwinModelOutput, Tuple[tf.Tensor, ...]]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if pixel_values is None:
raise ValueError("You have to specify pixel_values")

swin_outputs = self.swin(
pixel_values=pixel_values,
bool_masked_pos=bool_masked_pos,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)

return swin_outputs


class PixelShuffle(tf.keras.layers.Layer):
"""TF layer implementation of torch.nn.PixelShuffle"""

Expand Down Expand Up @@ -1238,7 +1288,7 @@ class TFSwinForMaskedImageModeling(TFSwinPreTrainedModel):
def __init__(self, config: SwinConfig):
super().__init__(config)

self.swin = TFSwinModel(config, add_pooling_layer=False, use_mask_token=True, name="swin")
self.swin = TFSwinMainLayer(config, add_pooling_layer=False, use_mask_token=True, name="swin")

self.decoder = TFSwinDecoder(config, name="decoder")

Expand Down Expand Up @@ -1350,7 +1400,7 @@ def __init__(self, config: SwinConfig):
super().__init__(config)

self.num_labels = config.num_labels
self.swin = TFSwinModel(config, name="swin")
self.swin = TFSwinMainLayer(config, name="swin")

# Classifier head
self.classifier = (
Expand Down

0 comments on commit bd43151

Please sign in to comment.