Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING #18469

Merged
merged 7 commits into from
Aug 4, 2022

Conversation

ydshieh
Copy link
Collaborator

@ydshieh ydshieh commented Aug 4, 2022

What does this PR do?

The original goal is to fix TFSegformerModelTest.test_keras_fit, but it ends up the following

  • Add TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING to some __init__ files.
  • Add training arguments in a few layers for TFSegformerModel
  • Update _prepare_for_class to deal with 2 more image tasks
  • Fix TFData2VecVisionForSemanticSegmentation loss: we need batch dimension (without this, test_dataset_conversion failed - this was previously skipped due to the lack of labels)

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Aug 4, 2022

The documentation is not available anymore as the PR was closed or merged.

@ydshieh ydshieh changed the title Add TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING Add TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING Aug 4, 2022
@@ -1388,6 +1399,9 @@ def test_loss_computation(self):

self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])

def check_keras_fit_results(self, val_loss1, val_loss2, atol=1e-2, rtol=1e-3):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For TFSegformerForSemanticSegmentation, we need higher tolerances. See the comment for the change in that model file.

Adding check_keras_fit_results here to avoid overwrite test_keras_fit entirely.

@@ -336,6 +337,9 @@ def recursive_check(tuple_object, dict_object):
def test_dataset_conversion(self):
super().test_dataset_conversion()

def check_keras_fit_results(self, val_loss1, val_loss2, atol=2e-1, rtol=2e-1):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use higher tolerance for TFSegformerForSemanticSegmentation: this model

-has BatchNormalization layer,

  • also have several dropout layers,
  • as well as a layer TFSegformerDropPath which has random operation during training.

These factors together cause the statistic of moving_average and moving_variance different, and we have larger validation loss.

cc @sayakpaul @amyeroberts

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have found that we need larger tolerances for dense prediction tasks like semantic segmentation. It was the case for TFData2VecVisionForSemanticSegmentation as well.

hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We still pass this argument, right?

Copy link
Member

@sayakpaul sayakpaul Aug 4, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can avoid it since these are supposed to be set automatically during training by the Keras engine. But I find that explicitly specifying it gives me mental peace.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the information!

@@ -773,9 +773,9 @@ def call(self, encoder_hidden_states):
all_hidden_states += (encoder_hidden_state,)

hidden_states = self.linear_fuse(tf.concat(all_hidden_states[::-1], axis=-1))
hidden_states = self.batch_norm(hidden_states)
hidden_states = self.batch_norm(hidden_states, training=training)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same

@ydshieh
Copy link
Collaborator Author

ydshieh commented Aug 4, 2022

Test failures are ValueError: Connection error - irrelevant.

@sayakpaul
Copy link
Member

Thank you, @ydshieh for this. I appreciate the help.

Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for fixing this!

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! Thanks for the fix ❤️

Copy link
Member

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👀 this needed more changes than I expected. Thank you for looking into it, @ydshieh! 🙏

@ydshieh ydshieh merged commit 1492892 into huggingface:main Aug 4, 2022
@ydshieh ydshieh deleted the add_seg_label_to_test branch August 4, 2022 18:41
oneraghavan pushed a commit to oneraghavan/transformers that referenced this pull request Sep 26, 2022
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants