Getting nan losses after gradient reversal #5

Closed
chaitrasj opened this issue Nov 3, 2020 · 5 comments

@chaitrasj

Hello,
I have a feature extractor network as the trunk, which then splits into two branches: an identity classifier and a color classifier. I want to apply gradient reversal to the color classifier output so that the learned features cannot predict color correctly.
self.classifier_color = torch.nn.Sequential(nn.Linear(2048, 128, bias=False), nn.Linear(128, 7, bias=False), RevGrad())
This is my color branch, and I apply the cross-entropy loss to the output of this self.classifier_color.
However, I have been getting NaN losses after 4-5 iterations. The softmax outputs collapse to all 0's except for a single 1. This happens even with random labels, and that is what causes the NaN losses.
Am I doing anything incorrect in using this gradient reversal layer?
Let me know if anything else needs to be added, and kindly help.
Thanks

@janfreyberg
Owner

janfreyberg commented Nov 14, 2020

Hello,

Sorry for the delayed response to this issue. Is there any way you could post some minimal example code? Maybe using random numbers as input instead of your data.

I think one issue for you could be that you place the RevGrad layer right before your loss. I would expect that to lead to exploding weights in some direction.

The way these layers are usually used is to place the RevGrad layer after some "common" part of the network, but before the layers that predict the protected category / domain variable. Otherwise, there would be no weights at all that learn to predict colour, which would mean the loss for that branch explodes (and eventually becomes NaN when it's too big).
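(For background: a gradient reversal layer is the identity in the forward pass and negates, and optionally scales, the gradient in the backward pass. Below is a minimal sketch of that behaviour, written as a custom autograd function; it is for illustration only and not necessarily how pytorch_revgrad implements it.)

import torch

class _ReverseGrad(torch.autograd.Function):
    # identity in the forward pass, negated (scaled) gradient in the backward pass
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # the trunk upstream of this layer is pushed *away* from features
        # that help the downstream classifier
        return -ctx.alpha * grad_output, None

class GradReverse(torch.nn.Module):
    # hypothetical wrapper module; pytorch_revgrad.RevGrad plays this role
    def __init__(self, alpha=1.0):
        super().__init__()
        self.alpha = alpha

    def forward(self, x):
        return _ReverseGrad.apply(x, self.alpha)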

Let me know if this solves your problem!

For example, the following works for me:

from itertools import chain
import torch
import torch.nn
import torch.utils.data
import pytorch_revgrad

stem = torch.nn.Sequential(
    torch.nn.Linear(128, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 512),
    torch.nn.ReLU(),
    torch.nn.Linear(512, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 64),
)

identity_classifier = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 10),
)

colour_classifier = torch.nn.Sequential(
    pytorch_revgrad.RevGrad(),
    torch.nn.Linear(64, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 2),
)

# simulate a training run:
sgd = torch.optim.Adam(
    chain(stem.parameters(),
          identity_classifier.parameters(),
          colour_classifier.parameters()),
    lr=0.01
)

# random fake data: 160 samples of 128 features, 10 identity classes, 2 colours
input_data = torch.randn(160, 128)
identity = torch.randint(0, 10, (160,))
colour = torch.randint(0, 2, (160,))

data = torch.utils.data.TensorDataset(input_data, identity, colour)
loader = torch.utils.data.DataLoader(data, batch_size=8)

alpha = 0.1


for epoch in range(100):
    for inp, iden, col in loader:

        intermediate_features = stem(inp)
        identity_logits = identity_classifier(intermediate_features)
        colour_logits = colour_classifier(intermediate_features)

        identity_loss = torch.nn.functional.cross_entropy(identity_logits, iden)
        colour_loss = torch.nn.functional.cross_entropy(colour_logits, col)
        total_loss = identity_loss + alpha * colour_loss
        total_loss.backward()
        sgd.step()
        sgd.zero_grad()
    print(f"identity loss: {identity_loss.data}", f"col loss: {colour_loss.data}")

You'll see that the loss for identity eventually goes down, and the loss for colour stays roughly the same. If you move the RevGrad layer to the end of colour_classifier, the loss quickly explodes.
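For contrast, the problematic placement described above would look like this (a sketch added for illustration, reusing the layer sizes from the example):

import torch.nn
import pytorch_revgrad

# Problematic placement (sketch): with RevGrad *after* the colour head, every
# weight in this branch receives reversed gradients, so nothing ever learns to
# predict colour and its loss grows until it turns into NaN.
colour_classifier_bad = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 2),
    pytorch_revgrad.RevGrad(),  # <- reversal right before the loss, as in the original post
)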

@janfreyberg
Owner

Closing for now, but feel free to re-open.

@surajn28

Hello @janfreyberg, I am facing the same issue. I have placed my gradient reversal layer after a common part of the network.

@janfreyberg
Owner

This is difficult to debug without example code.

@junzai0215

junzai0215 commented Nov 12, 2023

Can these three parts (stem, identity_classifier and colour_classifier) be integrated into a single network? Thank you! For example:

import torch
import torch.nn
import pytorch_revgrad

class Classifiers(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = torch.nn.Sequential(
            torch.nn.Linear(128, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 64),
        )

        self.identity_classifier = torch.nn.Sequential(
            torch.nn.Linear(64, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 10),
        )

        self.colour_classifier = torch.nn.Sequential(
            pytorch_revgrad.RevGrad(),
            torch.nn.Linear(64, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 2),
        )

    def forward(self, inp):
        intermediate_features = self.stem(inp)
        identity_logits = self.identity_classifier(intermediate_features)
        colour_logits = self.colour_classifier(intermediate_features)
        return identity_logits, colour_logits

model = Classifiers()  # instantiate once, outside the training loop

for epoch in range(100):
    for inp, iden, col in loader:
        identity_logits, colour_logits = model(inp)
        identity_loss = torch.nn.functional.cross_entropy(identity_logits, iden)
        colour_loss = torch.nn.functional.cross_entropy(colour_logits, col)
        total_loss = identity_loss + alpha * colour_loss
        total_loss.backward()
        ......
  
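One convenience of the combined module (a note added here, not part of the original comment): a single optimizer over model.parameters() replaces the itertools.chain call from the earlier example, e.g.:

import torch

# assumes `model = Classifiers()` from the snippet above
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # covers the stem and both classifier heads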
