Getting nan losses after gradient reversal #5
Hello, sorry for the delayed response to this issue. Is there any way you could post some minimal example code? Maybe using random numbers as input instead of your data.

I think one issue for you could be that you place the RevGrad layer right before your loss. I would expect that to lead to exploding weights in some direction. The way these layers are usually used is to place the RevGrad layer after some "common" part of the network, but before the layers that predict the protected category / domain variable. Otherwise, there would be no weights at all that learn to predict colour, which would mean the loss for that branch explodes (and eventually becomes NaN when it's too big). Let me know if this solves your problem!

For example, the following works for me:

import numpy as np
from itertools import chain
import torch
import torch.nn
import torch.optim
import torch.utils.data
import pytorch_revgrad
# shared feature extractor (the "common" part of the network)
stem = torch.nn.Sequential(
    torch.nn.Linear(128, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 512),
    torch.nn.ReLU(),
    torch.nn.Linear(512, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 64),
)
# main task head
identity_classifier = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 10),
)
# adversarial head: RevGrad sits after the shared stem, before the layers that predict colour
colour_classifier = torch.nn.Sequential(
    pytorch_revgrad.RevGrad(),
    torch.nn.Linear(64, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 2),
)
# simulate a training run:
sgd = torch.optim.Adam(
    chain(stem.parameters(),
          identity_classifier.parameters(),
          colour_classifier.parameters()),
    lr=0.01,
)
input_data = torch.randn(160, 128)
identity = torch.randint(0, 9, (160,))
colour = torch.randint(0, 2, (160,))
data = torch.utils.data.TensorDataset(input_data, identity, colour)
loader = torch.utils.data.DataLoader(data, batch_size=8)
alpha = 0.1
for epoch in range(100):
    for inp, iden, col in loader:
        intermediate_features = stem(inp)
        identity_logits = identity_classifier(intermediate_features)
        colour_logits = colour_classifier(intermediate_features)
        identity_loss = torch.nn.functional.cross_entropy(identity_logits, iden)
        colour_loss = torch.nn.functional.cross_entropy(colour_logits, col)
        total_loss = identity_loss + alpha * colour_loss
        total_loss.backward()
        sgd.step()
        sgd.zero_grad()
print(f"identity loss: {identity_loss.data}", f"col loss: {colour_loss.data}") You'll see that the loss for identity eventually goes down, and the loss for colour stays roughly the same. If you move the RevGrad layer to the end of |
Closing for now, but feel free to re-open.
Hello @janfreyberg, I am facing the same issue. I have placed my gradient reversal layer after some common part of the network.
This is difficult to debug without example code.
Can these three modules (stem, identity_classifier and colour_classifier) be integrated into a single network? Thank you! For example:
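Something along these lines, as a rough sketch (the class name CombinedNet and the layer sizes are just placeholders):

import torch
import pytorch_revgrad

class CombinedNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # shared trunk
        self.stem = torch.nn.Sequential(
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
        )
        # main task head
        self.identity_classifier = torch.nn.Linear(64, 10)
        # adversarial head, with the gradient reversal in front of its layers
        self.colour_classifier = torch.nn.Sequential(
            pytorch_revgrad.RevGrad(),
            torch.nn.Linear(64, 2),
        )

    def forward(self, x):
        features = self.stem(x)
        return self.identity_classifier(features), self.colour_classifier(features)

net = CombinedNet()  # a single optimizer can then be built from net.parameters()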
Hello,
I have a feature extractor network (the trunk), which then splits into two branches: one is the identity classifier and the other a color classifier. I want to use gradient reversal on the color classifier output so that the learned features cannot predict the color correctly.
self.classifier_color = torch.nn.Sequential(nn.Linear(2048, 128, bias=False), nn.Linear(128, 7, bias=False), RevGrad())
This is my color branch, and I apply a cross-entropy loss to the output of self.classifier_color.
However, I have been getting NaN losses after 4-5 iterations: the softmax values collapse so that one class is 1 and all the others are 0. The predictions end up essentially random with respect to the labels, and the loss blows up until it becomes NaN.
Am I doing anything incorrect in how I use this gradient reversal layer?
Let me know if anything else needs to be added, and kindly help.
Thanks
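For reference, the placement suggested above would put the reversal at the front of this branch, so that the Linear layers after it can still learn to predict color while the reversed gradient reaches the trunk. A sketch reusing the sizes from the snippet above:

from torch import nn
from pytorch_revgrad import RevGrad

# reversal first, prediction layers after it (the opposite order to the snippet above)
classifier_color = nn.Sequential(
    RevGrad(),
    nn.Linear(2048, 128, bias=False),
    nn.Linear(128, 7, bias=False),
)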