Add TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING (#18469)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
ydshieh and ydshieh authored Aug 4, 2022
1 parent 0bf1e1a commit 1492892
Showing 7 changed files with 37 additions and 12 deletions.
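For context: the new mapping gives TensorFlow users the same config-to-model lookup that MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING already provides on the PyTorch side. A minimal usage sketch (not part of this commit; it assumes a TF-enabled transformers install with the SegFormer TF port available):

```python
# Resolve the TF semantic-segmentation head registered for a given config class.
from transformers import TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING, SegformerConfig

config = SegformerConfig()
# The lazy mapping is keyed by config class; for SegformerConfig it should
# yield TFSegformerForSemanticSegmentation.
model_class = TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING[type(config)]
model = model_class(config)  # randomly initialised model for this config
print(model_class.__name__)
```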
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -2088,6 +2088,7 @@
"TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
"TF_MODEL_FOR_PRETRAINING_MAPPING",
"TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
"TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING",
"TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
@@ -4582,6 +4583,7 @@
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
TF_MODEL_FOR_PRETRAINING_MAPPING,
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
2 changes: 2 additions & 0 deletions src/transformers/models/auto/__init__.py
@@ -111,6 +111,7 @@
"TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
"TF_MODEL_FOR_PRETRAINING_MAPPING",
"TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
"TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING",
"TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
@@ -253,6 +254,7 @@
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
TF_MODEL_FOR_PRETRAINING_MAPPING,
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
@@ -1352,8 +1352,8 @@ def masked_loss(real, pred):
loss_ = loss_fct(real, pred)
mask = tf.cast(mask, dtype=loss_.dtype)
loss_ *= mask

return tf.reduce_sum(loss_) / tf.reduce_sum(mask)
reduced_masked_loss = tf.reduce_sum(loss_) / tf.reduce_sum(mask)
return tf.reshape(reduced_masked_loss, (1,))

main_loss = masked_loss(labels, upsampled_logits)
auxiliary_loss = masked_loss(labels, upsampled_auxiliary_logits)
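The loss change above returns the reduced masked loss as a tensor of shape (1,) instead of a scalar, which lines up with the `loss.shape.as_list() == [1]` branch accepted by the common `test_loss_computation` test further down in this diff. A self-contained sketch of the behaviour (the loss function and shapes are illustrative assumptions, not copied from the file):

```python
import tensorflow as tf

loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)

def masked_loss(real, pred, mask):
    loss_ = loss_fct(real, pred)
    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask
    reduced_masked_loss = tf.reduce_sum(loss_) / tf.reduce_sum(mask)
    # Shape (1,) rather than a scalar, so callers see a consistent 1-D loss.
    return tf.reshape(reduced_masked_loss, (1,))

real = tf.constant([[0, 1], [1, 2]])        # (batch, pixels) integer labels
pred = tf.random.uniform((2, 2, 3))         # (batch, pixels, num_classes) logits
mask = tf.cast(real != 2, tf.float32)       # e.g. ignore label id 2
print(masked_loss(real, pred, mask).shape)  # (1,)
```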
16 changes: 8 additions & 8 deletions src/transformers/models/segformer/modeling_tf_segformer.py
@@ -201,9 +201,9 @@ def __init__(self, config: SegformerConfig, hidden_size: int, **kwargs):
self.dense = tf.keras.layers.Dense(hidden_size, name="dense")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
return hidden_states


@@ -276,13 +276,13 @@ def __init__(
self.dense2 = tf.keras.layers.Dense(out_features, name="dense2")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

def call(self, hidden_states: tf.Tensor, height: int, width: int) -> tf.Tensor:
def call(self, hidden_states: tf.Tensor, height: int, width: int, training: bool = False) -> tf.Tensor:
hidden_states = self.dense1(hidden_states)
hidden_states = self.depthwise_convolution(hidden_states, height, width)
hidden_states = self.intermediate_act_fn(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = self.dense2(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
return hidden_states


@@ -749,7 +749,7 @@ def __init__(self, config: SegformerConfig, **kwargs):

self.config = config

def call(self, encoder_hidden_states):
def call(self, encoder_hidden_states, training: bool = False):
batch_size = shape_list(encoder_hidden_states[-1])[0]

all_hidden_states = ()
@@ -773,9 +773,9 @@ def call(self, encoder_hidden_states):
all_hidden_states += (encoder_hidden_state,)

hidden_states = self.linear_fuse(tf.concat(all_hidden_states[::-1], axis=-1))
hidden_states = self.batch_norm(hidden_states)
hidden_states = self.batch_norm(hidden_states, training=training)
hidden_states = self.activation(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)

# logits of shape (batch_size, height/4, width/4, num_labels)
logits = self.classifier(hidden_states)
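The SegFormer changes above thread an explicit `training` argument down to the Keras `Dropout` and `BatchNormalization` sublayers, so dropout is applied and batch-norm statistics are updated only while actually training. An illustrative layer (not from the repository) showing the same pattern:

```python
import tensorflow as tf

class MLPWithDropout(tf.keras.layers.Layer):
    """Toy layer forwarding `training` to sublayers whose behaviour differs
    between training and inference (Dropout, BatchNormalization)."""

    def __init__(self, hidden_size: int, dropout_prob: float = 0.1, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(hidden_size)
        self.batch_norm = tf.keras.layers.BatchNormalization()
        self.dropout = tf.keras.layers.Dropout(dropout_prob)

    def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
        hidden_states = self.dense(hidden_states)
        # Passing `training` explicitly makes train/inference behaviour
        # unambiguous instead of relying on Keras' implicit flag propagation.
        hidden_states = self.batch_norm(hidden_states, training=training)
        return self.dropout(hidden_states, training=training)

layer = MLPWithDropout(hidden_size=4)
x = tf.random.uniform((2, 8))
print(layer(x, training=False).shape)  # (2, 4); deterministic at inference
```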
3 changes: 3 additions & 0 deletions src/transformers/utils/dummy_tf_objects.py
@@ -279,6 +279,9 @@ def __init__(self, *args, **kwargs):
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = None


TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = None


TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = None


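The dummy object keeps the top-level import working when TensorFlow is not installed: `from transformers import TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING` then resolves to the placeholder defined here. A quick illustration of the fallback value:

```python
# The dummy module is importable regardless of whether TensorFlow is present;
# the mapping placeholder it defines is simply None.
from transformers.utils.dummy_tf_objects import TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING

print(TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING)  # None
```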
4 changes: 4 additions & 0 deletions tests/models/segformer/test_modeling_tf_segformer.py
@@ -27,6 +27,7 @@


if is_tf_available():
import numpy as np
import tensorflow as tf

from transformers import TFSegformerForImageClassification, TFSegformerForSemanticSegmentation, TFSegformerModel
@@ -336,6 +337,9 @@ def recursive_check(tuple_object, dict_object):
def test_dataset_conversion(self):
super().test_dataset_conversion()

def check_keras_fit_results(self, val_loss1, val_loss2, atol=2e-1, rtol=2e-1):
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=atol, rtol=rtol))

@unittest.skipIf(
not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.",
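The `check_keras_fit_results` override above relaxes the tolerances used by the common `test_keras_fit` test when it compares validation losses across equivalent `model.fit` runs, presumably because SegFormer's fit results fluctuate more than the default `atol=1e-2, rtol=1e-3` allows. An illustration of the comparison with made-up values:

```python
import numpy as np

val_loss1, val_loss2 = 0.52, 0.61  # hypothetical validation losses

# Default tolerances from the common test would reject this pair ...
print(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))  # False
# ... while the relaxed SegFormer tolerances accept it.
print(np.allclose(val_loss1, val_loss2, atol=2e-1, rtol=2e-1))  # True
```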
18 changes: 16 additions & 2 deletions tests/test_modeling_tf_common.py
@@ -62,11 +62,13 @@
from transformers import (
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
TF_MODEL_FOR_MASKED_LM_MAPPING,
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
TF_MODEL_FOR_PRETRAINING_MAPPING,
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
@@ -170,6 +172,15 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False) -> d
inputs_dict["labels"] = tf.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int32
)
elif model_class in get_values(TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING):
num_patches = self.model_tester.image_size // self.model_tester.patch_size
inputs_dict["bool_masked_pos"] = tf.zeros(
(self.model_tester.batch_size, num_patches**2), dtype=tf.int32
)
elif model_class in get_values(TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING):
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, height, width), dtype=tf.int32)

return inputs_dict

def test_initialization(self):
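The two new branches in `_prepare_for_class` build zero-valued dummy targets for the corresponding task heads: `bool_masked_pos` of shape `(batch_size, num_patches**2)` for masked image modeling, and per-pixel `labels` of shape `(batch_size, height, width)` for semantic segmentation. A quick sketch of those shapes (sizes are arbitrary examples):

```python
import tensorflow as tf

batch_size, num_channels, height, width = 2, 3, 32, 32
pixel_values = tf.random.uniform((batch_size, num_channels, height, width))

# Semantic segmentation: one integer class id per pixel.
labels = tf.zeros((batch_size, height, width), dtype=tf.int32)

# Masked image modeling: one flag per patch of the (square) input image.
patch_size = 4
num_patches = height // patch_size
bool_masked_pos = tf.zeros((batch_size, num_patches**2), dtype=tf.int32)

print(pixel_values.shape, labels.shape, bool_masked_pos.shape)
```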
@@ -1389,6 +1400,9 @@ def test_loss_computation(self):

self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])

def check_keras_fit_results(self, val_loss1, val_loss2, atol=1e-2, rtol=1e-3):
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=atol, rtol=rtol))

def test_keras_fit(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
@@ -1468,7 +1482,7 @@ def test_keras_fit(self):
val_loss2 = history2.history["val_loss"][0]
self.assertTrue(not isnan(val_loss2))
accuracy2 = {key: val[0] for key, val in history2.history.items() if key.endswith("accuracy")}
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
self.check_keras_fit_results(val_loss1, val_loss2)
self.assertEqual(history1.history.keys(), history2.history.keys())
for key in history1.history.keys():
if not key.startswith("val_"):
@@ -1494,7 +1508,7 @@
val_loss3 = history3.history["val_loss"][0]
self.assertTrue(not isnan(val_loss3))
accuracy3 = {key: val[0] for key, val in history3.history.items() if key.endswith("accuracy")}
self.assertTrue(np.allclose(val_loss1, val_loss3, atol=1e-2, rtol=1e-3))
self.check_keras_fit_results(val_loss1, val_loss3)
self.assertEqual(history1.history.keys(), history3.history.keys())
if metrics:
self.assertTrue(len(accuracy1) == len(accuracy3) > 0, "Missing metrics!")
