Add TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
#18469
Changes from all commits: 1d8400e, 4c99ddc, b9332ba, 739712d, 22bc799, f02c9f1, 3ef8b15
File 1: TF Segformer modeling code

```diff
@@ -201,9 +201,9 @@ def __init__(self, config: SegformerConfig, hidden_size: int, **kwargs):
         self.dense = tf.keras.layers.Dense(hidden_size, name="dense")
         self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

-    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
+    def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
         hidden_states = self.dense(hidden_states)
-        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.dropout(hidden_states, training=training)
         return hidden_states
```
```diff
@@ -276,13 +276,13 @@ def __init__(
         self.dense2 = tf.keras.layers.Dense(out_features, name="dense2")
         self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

-    def call(self, hidden_states: tf.Tensor, height: int, width: int) -> tf.Tensor:
+    def call(self, hidden_states: tf.Tensor, height: int, width: int, training: bool = False) -> tf.Tensor:
         hidden_states = self.dense1(hidden_states)
         hidden_states = self.depthwise_convolution(hidden_states, height, width)
         hidden_states = self.intermediate_act_fn(hidden_states)
-        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.dropout(hidden_states, training=training)
         hidden_states = self.dense2(hidden_states)
-        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.dropout(hidden_states, training=training)
         return hidden_states
```
```diff
@@ -749,7 +749,7 @@ def __init__(self, config: SegformerConfig, **kwargs):

         self.config = config

-    def call(self, encoder_hidden_states):
+    def call(self, encoder_hidden_states, training: bool = False):
         batch_size = shape_list(encoder_hidden_states[-1])[0]

         all_hidden_states = ()
```
```diff
@@ -773,9 +773,9 @@ def call(self, encoder_hidden_states):
             all_hidden_states += (encoder_hidden_state,)

         hidden_states = self.linear_fuse(tf.concat(all_hidden_states[::-1], axis=-1))
-        hidden_states = self.batch_norm(hidden_states)
+        hidden_states = self.batch_norm(hidden_states, training=training)
         hidden_states = self.activation(hidden_states)
-        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.dropout(hidden_states, training=training)

         # logits of shape (batch_size, height/4, width/4, num_labels)
         logits = self.classifier(hidden_states)
```

Review comment on the `batch_norm` call: Same.
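The pattern in all four hunks is the same: a subclassed Keras layer accepts a `training` flag in `call` and forwards it to sub-layers such as `Dropout` and `BatchNormalization`, whose behavior differs between training and inference. A minimal standalone sketch of the idea (the layer and its names below are illustrative, not taken from the PR):

```python
import tensorflow as tf


class MLPBlock(tf.keras.layers.Layer):
    """Illustrative layer that forwards the Keras `training` flag."""

    def __init__(self, hidden_size: int, dropout_prob: float, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(hidden_size, name="dense")
        self.dropout = tf.keras.layers.Dropout(dropout_prob)

    def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
        hidden_states = self.dense(hidden_states)
        # Without `training=training`, dropout would fall back on the global
        # Keras learning phase, which is not always set when one layer calls
        # another directly from its own `call`.
        return self.dropout(hidden_states, training=training)


layer = MLPBlock(hidden_size=8, dropout_prob=0.5)
x = tf.ones((2, 4, 8))
y_eval = layer(x, training=False)  # deterministic: dropout disabled
y_train = layer(x, training=True)  # stochastic: dropout active
```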
File 2: TF Segformer tests
```diff
@@ -27,6 +27,7 @@

 if is_tf_available():
+    import numpy as np
     import tensorflow as tf

     from transformers import TFSegformerForImageClassification, TFSegformerForSemanticSegmentation, TFSegformerModel
```
```diff
@@ -336,6 +337,9 @@ def recursive_check(tuple_object, dict_object):
     def test_dataset_conversion(self):
         super().test_dataset_conversion()

+    def check_keras_fit_results(self, val_loss1, val_loss2, atol=2e-1, rtol=2e-1):
+        self.assertTrue(np.allclose(val_loss1, val_loss2, atol=atol, rtol=rtol))
+
     @unittest.skipIf(
         not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
         reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.",
```

Review comment: Use higher tolerance for … These factors together cause the statistic of …

Reply: I have found that we need larger tolerances for dense prediction tasks like semantic segmentation. It was the case for …
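To see concretely what the looser tolerance buys, here is a hypothetical pair of validation losses (the numbers are made up purely for illustration):

```python
import numpy as np

val_loss1, val_loss2 = 1.00, 1.12  # hypothetical fit/refit losses

# The common-test default (atol=1e-2, rtol=1e-3) rejects this pair:
print(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))  # False

# The Segformer override (atol=2e-1, rtol=2e-1) accepts it:
print(np.allclose(val_loss1, val_loss2, atol=2e-1, rtol=2e-1))  # True
```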
File 3: common TF modeling tests
@@ -62,11 +62,13 @@ | |
from transformers import ( | ||
TF_MODEL_FOR_CAUSAL_LM_MAPPING, | ||
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, | ||
TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING, | ||
TF_MODEL_FOR_MASKED_LM_MAPPING, | ||
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, | ||
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, | ||
TF_MODEL_FOR_PRETRAINING_MAPPING, | ||
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING, | ||
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING, | ||
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, | ||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, | ||
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, | ||
|
```diff
@@ -169,6 +171,15 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False) -> dict:
             inputs_dict["labels"] = tf.zeros(
                 (self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int32
             )
+        elif model_class in get_values(TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING):
+            num_patches = self.model_tester.image_size // self.model_tester.patch_size
+            inputs_dict["bool_masked_pos"] = tf.zeros(
+                (self.model_tester.batch_size, num_patches**2), dtype=tf.int32
+            )
+        elif model_class in get_values(TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING):
+            batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
+            inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, height, width), dtype=tf.int32)

         return inputs_dict

     def test_initialization(self):
```
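The semantic-segmentation branch assigns every pixel an integer class id, so the dummy labels share the input's spatial dimensions rather than having one label per image. A small sketch with assumed shapes:

```python
import tensorflow as tf

# Hypothetical NCHW input, matching the unpacking in the diff above.
batch_size, num_channels, height, width = 2, 3, 64, 64
pixel_values = tf.zeros((batch_size, num_channels, height, width))

# Per-pixel class ids: one int32 label for each spatial position.
labels = tf.zeros((batch_size, height, width), dtype=tf.int32)
print(labels.shape)  # (2, 64, 64)
```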
```diff
@@ -1388,6 +1399,9 @@ def test_loss_computation(self):

         self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])

+    def check_keras_fit_results(self, val_loss1, val_loss2, atol=1e-2, rtol=1e-3):
+        self.assertTrue(np.allclose(val_loss1, val_loss2, atol=atol, rtol=rtol))
+
     def test_keras_fit(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         for model_class in self.all_model_classes:
```

Review comment: For … Adding …
```diff
@@ -1467,7 +1481,7 @@ def test_keras_fit(self):
             val_loss2 = history2.history["val_loss"][0]
             self.assertTrue(not isnan(val_loss2))
             accuracy2 = {key: val[0] for key, val in history2.history.items() if key.endswith("accuracy")}
-            self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
+            self.check_keras_fit_results(val_loss1, val_loss2)
             self.assertEqual(history1.history.keys(), history2.history.keys())
             for key in history1.history.keys():
                 if not key.startswith("val_"):
```
```diff
@@ -1493,7 +1507,7 @@ def test_keras_fit(self):
             val_loss3 = history3.history["val_loss"][0]
             self.assertTrue(not isnan(val_loss3))
             accuracy3 = {key: val[0] for key, val in history3.history.items() if key.endswith("accuracy")}
-            self.assertTrue(np.allclose(val_loss1, val_loss3, atol=1e-2, rtol=1e-3))
+            self.check_keras_fit_results(val_loss1, val_loss3)
             self.assertEqual(history1.history.keys(), history3.history.keys())
             if metrics:
                 self.assertTrue(len(accuracy1) == len(accuracy3) > 0, "Missing metrics!")
```
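Factoring the `np.allclose` assertion into an overridable `check_keras_fit_results` hook is what lets the Segformer suite loosen the tolerance without copying the whole `test_keras_fit` body. A stripped-down sketch of the pattern (class names are illustrative, not the library's):

```python
import numpy as np


class CommonTesterMixin:
    # Shared default: tight tolerances, as in the common test suite.
    def check_keras_fit_results(self, val_loss1, val_loss2, atol=1e-2, rtol=1e-3):
        assert np.allclose(val_loss1, val_loss2, atol=atol, rtol=rtol)


class SegformerLikeTester(CommonTesterMixin):
    # Model-specific override: looser tolerances for dense prediction,
    # delegating the actual check to the base implementation.
    def check_keras_fit_results(self, val_loss1, val_loss2, atol=2e-1, rtol=2e-1):
        super().check_keras_fit_results(val_loss1, val_loss2, atol=atol, rtol=rtol)


SegformerLikeTester().check_keras_fit_results(1.00, 1.12)  # passes only with the override
```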
Review comment: We still pass this argument, right?

Reply: We can avoid it since these are supposed to be set automatically during training by the Keras engine. But I find that explicitly specifying it gives me mental peace.

Review comment: Thank you for the information!
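For reference, a tiny sketch of the behavior discussed here: the Keras engine calls layers with `training=True` inside `model.fit` and with `training=False` inside `model.predict`, so the flag does not strictly have to be forwarded by hand; forwarding it explicitly just removes ambiguity when layers invoke one another directly:

```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(8,))
outputs = tf.keras.layers.Dropout(0.5)(inputs)  # respects the engine's flag
model = tf.keras.Model(inputs, outputs)

x = tf.ones((4, 8))
# Explicit flag: dropout off, output is all ones.
print(model(x, training=False))
# model.predict also runs with training=False under the hood.
print(model.predict(x))
```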