Add TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING #18469

Merged · 7 commits · Aug 4, 2022

Changes from all commits:
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -2075,6 +2075,7 @@
         "TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
         "TF_MODEL_FOR_PRETRAINING_MAPPING",
         "TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
+        "TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING",
         "TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
         "TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
         "TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",

@@ -4558,6 +4559,7 @@
         TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
         TF_MODEL_FOR_PRETRAINING_MAPPING,
         TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
+        TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING,
         TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
         TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
         TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
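For orientation, a hedged sketch (not part of this diff) of how such an auto mapping is typically consumed. It assumes TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING behaves like the other TF auto mappings, i.e. a dict-like object from config classes to TF model classes:

from transformers import TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING

# Iterate the config-class -> model-class pairs registered for the task.
for config_cls, model_cls in TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING.items():
    print(f"{config_cls.__name__} -> {model_cls.__name__}")

The test changes further down use the same mapping through get_values(...) to branch on which task a model class serves.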
2 changes: 2 additions & 0 deletions src/transformers/models/auto/__init__.py
@@ -109,6 +109,7 @@
         "TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
         "TF_MODEL_FOR_PRETRAINING_MAPPING",
         "TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
+        "TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING",
         "TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
         "TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
         "TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",

@@ -249,6 +250,7 @@
         TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
         TF_MODEL_FOR_PRETRAINING_MAPPING,
         TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
+        TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING,
         TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
         TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
         TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
src/transformers/models/data2vec/modeling_tf_data2vec_vision.py

@@ -1352,8 +1352,8 @@ def masked_loss(real, pred):
             loss_ = loss_fct(real, pred)
             mask = tf.cast(mask, dtype=loss_.dtype)
             loss_ *= mask

-            return tf.reduce_sum(loss_) / tf.reduce_sum(mask)
+            reduced_masked_loss = tf.reduce_sum(loss_) / tf.reduce_sum(mask)
+            return tf.reshape(reduced_masked_loss, (1,))

         main_loss = masked_loss(labels, upsampled_logits)
         auxiliary_loss = masked_loss(labels, upsampled_auxiliary_logits)
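To see the effect of the change, a minimal runnable sketch (shapes and the ignore mask are illustrative, not taken from the model code): the masked mean of per-position losses is a scalar, and the reshape turns it into a consistent 1-D loss tensor of shape (1,).

import tensorflow as tf

loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)

real = tf.constant([[0, 1, 1]])        # (batch, positions) integer labels
pred = tf.random.normal((1, 3, 2))     # (batch, positions, num_labels) logits
mask = tf.constant([[1.0, 1.0, 0.0]])  # pretend the last position is ignored

loss_ = loss_fct(real, pred) * mask                               # (1, 3)
reduced_masked_loss = tf.reduce_sum(loss_) / tf.reduce_sum(mask)  # scalar
print(tf.reshape(reduced_masked_loss, (1,)).shape)                # (1,)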
16 changes: 8 additions & 8 deletions src/transformers/models/segformer/modeling_tf_segformer.py
@@ -201,9 +201,9 @@ def __init__(self, config: SegformerConfig, hidden_size: int, **kwargs):
         self.dense = tf.keras.layers.Dense(hidden_size, name="dense")
         self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

-    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+    def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
         hidden_states = self.dense(hidden_states)
-        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.dropout(hidden_states, training=training)
         return hidden_states

Collaborator (author), on the training argument: We still pass this argument, right?

@sayakpaul (Member, Aug 4, 2022): We can avoid it, since the Keras engine is supposed to set it automatically during training. But I find that explicitly specifying it gives me peace of mind.

Collaborator (author): Thank you for the information!

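As an aside, a minimal runnable sketch (the layer and names are illustrative, not from the PR) of what forwarding the training flag does: Dropout is active only when training=True, and Keras sets the flag automatically inside fit()/evaluate() only when call() declares it.

import tensorflow as tf

class TinyBlock(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(4)
        self.dropout = tf.keras.layers.Dropout(0.5)

    def call(self, x, training=False):
        x = self.dense(x)
        # Explicit forwarding: no ambiguity about when dropout is applied.
        return self.dropout(x, training=training)

block = TinyBlock()
x = tf.ones((2, 4))
print(block(x, training=False))  # deterministic, dropout disabled
print(block(x, training=True))   # stochastic, dropout enabled
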
@@ -276,13 +276,13 @@ def __init__(
         self.dense2 = tf.keras.layers.Dense(out_features, name="dense2")
         self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

-    def call(self, hidden_states: tf.Tensor, height: int, width: int) -> tf.Tensor:
+    def call(self, hidden_states: tf.Tensor, height: int, width: int, training: bool = False) -> tf.Tensor:
         hidden_states = self.dense1(hidden_states)
         hidden_states = self.depthwise_convolution(hidden_states, height, width)
         hidden_states = self.intermediate_act_fn(hidden_states)
-        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.dropout(hidden_states, training=training)
         hidden_states = self.dense2(hidden_states)
-        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.dropout(hidden_states, training=training)
         return hidden_states


@@ -749,7 +749,7 @@ def __init__(self, config: SegformerConfig, **kwargs):

         self.config = config

-    def call(self, encoder_hidden_states):
+    def call(self, encoder_hidden_states, training: bool = False):
         batch_size = shape_list(encoder_hidden_states[-1])[0]

         all_hidden_states = ()

@@ -773,9 +773,9 @@ def call(self, encoder_hidden_states):
             all_hidden_states += (encoder_hidden_state,)

         hidden_states = self.linear_fuse(tf.concat(all_hidden_states[::-1], axis=-1))
-        hidden_states = self.batch_norm(hidden_states)
+        hidden_states = self.batch_norm(hidden_states, training=training)
         hidden_states = self.activation(hidden_states)
-        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.dropout(hidden_states, training=training)

         # logits of shape (batch_size, height/4, width/4, num_labels)
         logits = self.classifier(hidden_states)

Collaborator (author), on the batch_norm line: Same as above, regarding the training argument.
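BatchNormalization is the layer where this flag matters most. A short illustrative sketch (not from the PR): with training=True it normalizes with batch statistics and updates its moving averages, with training=False it uses the stored moving statistics.

import tensorflow as tf

bn = tf.keras.layers.BatchNormalization()
x = tf.random.normal((8, 4)) * 5.0 + 3.0

y_train = bn(x, training=True)   # batch mean/variance; updates moving stats
y_infer = bn(x, training=False)  # stored moving mean/variance

print(bn.moving_mean.numpy())    # no longer all zeros after the training call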
3 changes: 3 additions & 0 deletions src/transformers/utils/dummy_tf_objects.py
@@ -279,6 +279,9 @@ def __init__(self, *args, **kwargs):
 TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = None


+TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = None
+
+
 TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = None
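For context, a simplified sketch of the dummy-object pattern this file follows (an assumption-laden simplification; the real file is auto-generated and raises through transformers' requires_backends helper):

# When TensorFlow is unavailable, mappings are exposed as None and model
# classes raise on instantiation, so plain imports of TF names still succeed.
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = None

class TFSegformerForSemanticSegmentation:
    def __init__(self, *args, **kwargs):
        raise ImportError("TFSegformerForSemanticSegmentation requires TensorFlow.")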
4 changes: 4 additions & 0 deletions tests/models/segformer/test_modeling_tf_segformer.py
@@ -27,6 +27,7 @@


 if is_tf_available():
+    import numpy as np
     import tensorflow as tf

     from transformers import TFSegformerForImageClassification, TFSegformerForSemanticSegmentation, TFSegformerModel
@@ -336,6 +337,9 @@ def recursive_check(tuple_object, dict_object):
     def test_dataset_conversion(self):
         super().test_dataset_conversion()

+    def check_keras_fit_results(self, val_loss1, val_loss2, atol=2e-1, rtol=2e-1):
+        self.assertTrue(np.allclose(val_loss1, val_loss2, atol=atol, rtol=rtol))

Collaborator (author): Use a higher tolerance for TFSegformerForSemanticSegmentation. This model:
  • has a BatchNormalization layer,
  • has several dropout layers,
  • and has a TFSegformerDropPath layer, which performs a random operation during training.
Together, these factors make the moving_average and moving_variance statistics differ between runs, and we get a larger validation loss. cc @sayakpaul @amyeroberts

@sayakpaul (Member): I have found that we need larger tolerances for dense prediction tasks like semantic segmentation. It was the case for TFData2VecVisionForSemanticSegmentation as well.
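For reference, a minimal sketch of the stochastic-depth ("drop path") operation referred to above, in its standard formulation (not copied from the Segformer source):

import tensorflow as tf

def drop_path(x: tf.Tensor, drop_prob: float, training: bool) -> tf.Tensor:
    # During training, zero out each sample with probability drop_prob and
    # rescale survivors; at inference, pass the input through unchanged.
    if not training or drop_prob == 0.0:
        return x
    keep_prob = 1.0 - drop_prob
    # One Bernoulli draw per sample, broadcast over the remaining dims.
    shape = [tf.shape(x)[0]] + [1] * (len(x.shape) - 1)
    random_tensor = tf.floor(keep_prob + tf.random.uniform(shape, 0, 1))
    return (x / keep_prob) * random_tensor

This per-sample randomness, on top of dropout and BatchNormalization's moving statistics, is why two otherwise identical fit() runs can land on noticeably different validation losses.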

     @unittest.skipIf(
         not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
         reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.",
18 changes: 16 additions & 2 deletions tests/test_modeling_tf_common.py
@@ -62,11 +62,13 @@
     from transformers import (
         TF_MODEL_FOR_CAUSAL_LM_MAPPING,
         TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
+        TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
         TF_MODEL_FOR_MASKED_LM_MAPPING,
         TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
         TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
         TF_MODEL_FOR_PRETRAINING_MAPPING,
         TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
+        TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING,
         TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
         TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
         TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
@@ -169,6 +171,15 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False) -> dict:
inputs_dict["labels"] = tf.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int32
)
elif model_class in get_values(TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING):
num_patches = self.model_tester.image_size // self.model_tester.patch_size
inputs_dict["bool_masked_pos"] = tf.zeros(
(self.model_tester.batch_size, num_patches**2), dtype=tf.int32
)
elif model_class in get_values(TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING):
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, height, width), dtype=tf.int32)

return inputs_dict

def test_initialization(self):
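For reference, a hedged usage sketch of the label format the semantic-segmentation branch above prepares (the checkpoint name is only illustrative, and its segmentation head is randomly initialized, so this demonstrates shapes rather than predictions):

import tensorflow as tf
from transformers import TFSegformerForSemanticSegmentation

model = TFSegformerForSemanticSegmentation.from_pretrained("nvidia/mit-b0")

pixel_values = tf.random.uniform((1, 3, 512, 512))  # channels-first pixels
labels = tf.zeros((1, 512, 512), dtype=tf.int32)    # per-pixel class ids
outputs = model(pixel_values=pixel_values, labels=labels)
print(outputs.logits.shape)  # logits are spatially downsampled, e.g. 512/4 = 128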
@@ -1388,6 +1399,9 @@ def test_loss_computation(self):

             self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])

+    def check_keras_fit_results(self, val_loss1, val_loss2, atol=1e-2, rtol=1e-3):
+        self.assertTrue(np.allclose(val_loss1, val_loss2, atol=atol, rtol=rtol))

Collaborator (author): For TFSegformerForSemanticSegmentation we need higher tolerances; see the comment on the corresponding change in that model's test file. check_keras_fit_results is added here so that subclasses can loosen the tolerances without overriding test_keras_fit entirely.

     def test_keras_fit(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         for model_class in self.all_model_classes:
@@ -1467,7 +1481,7 @@ def test_keras_fit(self):
             val_loss2 = history2.history["val_loss"][0]
             self.assertTrue(not isnan(val_loss2))
             accuracy2 = {key: val[0] for key, val in history2.history.items() if key.endswith("accuracy")}
-            self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
+            self.check_keras_fit_results(val_loss1, val_loss2)
             self.assertEqual(history1.history.keys(), history2.history.keys())
             for key in history1.history.keys():
                 if not key.startswith("val_"):
@@ -1493,7 +1507,7 @@
             val_loss3 = history3.history["val_loss"][0]
             self.assertTrue(not isnan(val_loss3))
             accuracy3 = {key: val[0] for key, val in history3.history.items() if key.endswith("accuracy")}
-            self.assertTrue(np.allclose(val_loss1, val_loss3, atol=1e-2, rtol=1e-3))
+            self.check_keras_fit_results(val_loss1, val_loss3)
             self.assertEqual(history1.history.keys(), history3.history.keys())
             if metrics:
                 self.assertTrue(len(accuracy1) == len(accuracy3) > 0, "Missing metrics!")