
fix: keras fit tests for segformer tf and minor refactors. #18412

Merged
merged 3 commits into huggingface:main on Aug 3, 2022

Conversation

@sayakpaul (Member) commented Aug 2, 2022

Fixes the issues noticed in: https://github.com/huggingface/transformers/runs/7485048615?check_suite_focus=true.

I don't have access to a multi-GPU instance at the moment, but I figured out the root cause of the issue.

model_weights = model.get_weights()

^ I wasn't calling the model on some sample inputs first, which is why the list of weights retrieved from get_weights() was empty (the variables hadn't been built yet). That has been fixed in this PR.
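For illustration, a minimal standalone sketch (not the actual SegFormer code) of why the model has to be called once before get_weights() returns anything: a subclassed Keras model only creates its variables on the first forward pass.

import tensorflow as tf

class TinyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(4)

    def call(self, inputs):
        return self.dense(inputs)

model = TinyModel()
print(len(model.get_weights()))  # 0 -- no variables built yet

model(tf.zeros((1, 8)))          # first call builds the layers
print(len(model.get_weights()))  # 2 -- kernel and bias now exist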

I tested it locally in isolation with the following snippet (I acknowledge that it's not super clean):

from transformers import TFSegformerForImageClassification, TFSegformerForSemanticSegmentation, SegformerConfig

import tensorflow as tf

from tests.test_modeling_tf_common import floats_tensor, ids_tensor
import numpy as np

batch_size = 13
image_size = 64
num_channels = 3
num_encoder_blocks = 4
depths = [2, 2, 2, 2]
sr_ratios = [8, 4, 2, 1]
hidden_sizes = [16, 32, 64, 128]
downsampling_rates = [1, 4, 8, 16]
num_attention_heads = [1, 2, 4, 8]
is_training = True
use_labels = True
hidden_act = "gelu"
hidden_dropout_prob = 0.1
attention_probs_dropout_prob = 0.1
initializer_range = 0.02
num_labels = 3


def get_config():
    return SegformerConfig(
        image_size=image_size,
        num_channels=num_channels,
        num_encoder_blocks=num_encoder_blocks,
        depths=depths,
        hidden_sizes=hidden_sizes,
        num_attention_heads=num_attention_heads,
        hidden_act=hidden_act,
        hidden_dropout_prob=hidden_dropout_prob,
        attention_probs_dropout_prob=attention_probs_dropout_prob,
        initializer_range=initializer_range,
        num_labels=num_labels
    )

def prepare_config_and_inputs(for_semseg=True):
    pixel_values = floats_tensor([batch_size, num_channels, image_size, image_size])

    if for_semseg:
        labels = ids_tensor([batch_size, image_size, image_size], num_labels)
    else:
        labels = tf.zeros((batch_size,))  # one dummy class id per image

    config = get_config()
    return config, pixel_values, labels


model_classes = (TFSegformerForImageClassification, TFSegformerForSemanticSegmentation)

for model_class in model_classes:
    if model_class == TFSegformerForSemanticSegmentation:
        config, pixel_values, labels = prepare_config_and_inputs(for_semseg=True)
    else:
        config, pixel_values, labels = prepare_config_and_inputs(for_semseg=False)
    
    input_for_model_fit = {"pixel_values": pixel_values, "labels": labels}

    model = model_class(config)
    model(model.dummy_inputs)  # build the model so get_weights() returns the initialized variables
    model_weights = model.get_weights()
    
    model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True)  # LR 0: gradients shouldn't change the weights
    
    history1 = model.fit(
        input_for_model_fit,
        validation_data=input_for_model_fit,
        steps_per_epoch=1,
        validation_steps=1,
        shuffle=False,
    )
    val_loss1 = history1.history["val_loss"][0]

    label_names = {"labels"}
    
    labels = {key: val for key, val in input_for_model_fit.items() if key in label_names}
    inputs_minus_labels = {key: val for key, val in input_for_model_fit.items() if key not in label_names}

    # We reinitialize the model here even though our learning rate was zero
    # because BatchNorm updates weights by means other than gradient descent.
    model.set_weights(model_weights)
    history2 = model.fit(
        inputs_minus_labels,
        labels,
        validation_data=(inputs_minus_labels, labels),
        steps_per_epoch=1,
        validation_steps=1,
        shuffle=False,
    )
    val_loss2 = history2.history["val_loss"][0]

    print(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))

@amyeroberts @Rocketknight1 @sgugger

@HuggingFaceDocBuilderDev commented Aug 2, 2022

The documentation is not available anymore as the PR was closed or merged.

@sgugger (Collaborator) left a comment

Looks okay to me but will defer to the TensorFlow experts :-)
Thanks for fixing!

@Rocketknight1 (Member)

Pinging @gante as this week's TF reviewer!

@gante (Member) left a comment

LGTM 👍

(question: why is test_keras_fit entirely overwritten?)

@sayakpaul (Member, Author)

(question: why is test_keras_fit entirely overwritten?)

  1. The TFSegformerModel class doesn't support the fit test, since we can't compute a loss on embeddings.
  2. The other two classes (semantic segmentation and image classification) expect labels with different shapes.

So, it made sense to test them in isolation.
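For reference, a minimal sketch of the shape difference (values mirror the snippet above):

import tensorflow as tf

seg_labels = tf.zeros((13, 64, 64), dtype=tf.int32)  # per-pixel class ids: (batch, height, width)
cls_labels = tf.zeros((13,), dtype=tf.int32)         # one class id per image: (batch,)
print(seg_labels.shape, cls_labels.shape)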

@ydshieh (Collaborator) commented Aug 2, 2022

(question: why is test_keras_fit entirely overwritten?)

  1. The TFSegFormerModel class doesn't support the fit test since we can't compute loss on embeddings.

The line if getattr(model, "hf_compute_loss", None): should already take care of this case, I think.
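Roughly, a minimal sketch of what that guard does (illustrative only, not the test's exact code): the bare backbone doesn't mix in a loss, so getattr returns None and the common test skips it.

from transformers import SegformerConfig, TFSegformerModel

model = TFSegformerModel(SegformerConfig())  # bare backbone, returns embeddings only
if getattr(model, "hf_compute_loss", None):
    print("model defines a loss; the keras-fit test applies")
else:
    print("no loss head; the common test skips it")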

  1. The labels for the rest of the two classes (semantic segmentation and image classification) have different label shapes.

Does the main issue come from the fact that _prepare_for_class in tests/test_modeling_tf_common.py lacks the label preparation for segmentation?

@sayakpaul (Member, Author)

Does the main issue come from the fact that _prepare_for_class in tests/test_modeling_tf_common.py lack the label preparation for segmentation?

I think so, yes.
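If that's the fix, here's a hedged sketch of the labels _prepare_for_class would need to produce for segmentation (illustrative only, not the actual diff):

import tensorflow as tf

def prepare_semseg_labels(batch_size: int, image_size: int) -> tf.Tensor:
    # Per-pixel integer labels of shape (batch, height, width); all-zero
    # dummy labels are enough for a fit smoke test.
    return tf.zeros((batch_size, image_size, image_size), dtype=tf.int32)

print(prepare_semseg_labels(batch_size=13, image_size=64).shape)  # (13, 64, 64)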

@sayakpaul sayakpaul requested a review from gante August 3, 2022 05:08
@sayakpaul (Member, Author)

Looks like the new test_keras_fit() in the base test_modeling_tf_common takes care of the nuances I faced when I was overriding test_keras_fit() (at the time of writing modeling_tf_segformer.py).

So, I incorporated the latest changes, avoiding the complete rewrite.

@ydshieh @amyeroberts @gante up for another review.

@ydshieh (Collaborator) left a comment

Thank you for double-checking, @sayakpaul. Happy to see we don't have to completely rewrite the test.

(I only looked at the latest change in test_keras_fit.)

@@ -331,64 +329,26 @@ def recursive_check(tuple_object, dict_object):

# todo: incorporate label support for semantic segmentation in `test_modeling_tf_common.py`.

@unittest.skipIf(
    not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
    reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.",
A Collaborator commented on this diff:

I guess we haven't verified this with TF 2.9?

Once a new TF version supports this op on CPU, it would be good to add a version check inside skipIf.
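For example, a rough sketch of such a version check (the 2.9 cutoff is an assumption and would need to be verified before landing anything like this):

import unittest

import tensorflow as tf
from packaging import version

# Assumed cutoff: skip only when the installed TF is too old AND no GPU is available.
TF_GROUPED_CONV_CPU_OK = version.parse(tf.__version__) >= version.parse("2.9")

class TFSegformerModelTest(unittest.TestCase):
    @unittest.skipIf(
        not TF_GROUPED_CONV_CPU_OK and len(tf.config.list_physical_devices("GPU")) == 0,
        reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.",
    )
    def test_keras_fit(self):
        ...  # actual test body lives in the common test mixin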

@sayakpaul (Member, Author):

👍

@amyeroberts (Collaborator) left a comment

LGTM! Thanks for digging into this and fixing 🔧

@sayakpaul (Member, Author)

Thanks for flagging this to me!

@ydshieh @gante okay to merge?

@ydshieh (Collaborator) commented Aug 3, 2022

Let gante push the final approval button 😄

@gante (Member) left a comment

Fewer lines = <3

Thank you for having a look at the test @sayakpaul!

@gante gante merged commit be41eaf into huggingface:main Aug 3, 2022
@gante (Member) commented Aug 3, 2022

@sayakpaul our CI failed in the reworked test -- can you confirm that it runs correctly? :)

https://github.com/huggingface/transformers/runs/7655675934?check_suite_focus=true

@sayakpaul
Copy link
Member Author

@gante taking a quick look, it seems like it's happening because of the second point above (the label-shape difference for segmentation). If this is the case, I will sync with @ydshieh to add support for segmentation labels in the necessary places.

Sounds good?

oneraghavan pushed a commit to oneraghavan/transformers that referenced this pull request on Sep 26, 2022:

* fix: keras fit tests for segformer tf and minor refactors.

* refactor: test_keras_fit to make it simpler using the existing one.

* fix: styling issues.