Implementation of loss function #6
We use a trick to reverse the gradient before it back-propagates from the discriminator to the feature extractor, so we do not need to multiply the discriminator loss by -1 to train the feature extractor. There is no input from the discriminator to the classification loss, so by PyTorch's autograd principle there is no gradient from the classification loss to the discriminator, even when we back-propagate from the sum of the classification loss and the transfer loss.
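The gradient-reversal trick described above can be sketched as a custom autograd function. This is a minimal illustration, not the repository's exact code; the class name `GradReverse` and the `coeff` parameter are chosen here for clarity:

```python
import torch

class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; negates (and scales) the gradient
    in the backward pass, so the feature extractor is trained to fool
    the discriminator without multiplying its loss by -1."""

    @staticmethod
    def forward(ctx, x, coeff):
        ctx.coeff = coeff
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse the gradient flowing back to the feature extractor.
        return -ctx.coeff * grad_output, None

def grl(x, coeff=1.0):
    return GradReverse.apply(x, coeff)

# Gradients through grl are negated:
x = torch.ones(3, requires_grad=True)
y = grl(x, coeff=1.0).sum()
y.backward()
print(x.grad)  # tensor([-1., -1., -1.])
```

Because the reversal happens inside `backward`, the discriminator itself still receives the ordinary (non-reversed) gradient of its loss.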
Thanks for your answer, so the trick is …
Yes.
What about the TensorFlow version? Where is the trick of reversing the gradient there?
The TensorFlow version is still under implementation.
In the DANN model, after inserting a GRL between the generator and the discriminator, the gradient of the domain loss w.r.t. the feature extractor F is multiplied by -1. But in the CDAN model, the input of the discriminator is the tensor product of the feature vector and the predicted probability vector, so during backward propagation the domain loss has gradients with respect to both the feature extractor F and the classifier G. May I ask how your algorithm computes the gradient of the domain loss w.r.t. the predicted probabilities output by classifier G? Does the grl_hook also reverse the gradient of the domain loss w.r.t. the classifier G? Thanks a lot!
In pytorch/loss.py line 22, the softmax output is detached: softmax_output = input_list[1].detach(). So the domain loss has no gradient path back to classifier G at all.
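The effect of that `.detach()` can be demonstrated with a small standalone sketch (the tensor shapes below are arbitrary stand-ins, not the repository's values). The CDAN multilinear map takes the outer product of the detached predictions and the features, so backward only reaches the feature extractor:

```python
import torch

# Stand-ins for the feature extractor output (F) and the
# classifier's softmax output (G); both are leaf tensors here.
feature = torch.randn(4, 8, requires_grad=True)
softmax_output = torch.rand(4, 3, requires_grad=True)

# CDAN-style multilinear map: batched outer product of the
# DETACHED predictions with the features -> shape (4, 3, 8).
op_out = torch.bmm(softmax_output.detach().unsqueeze(2),
                   feature.unsqueeze(1))

# Pretend this sum is the domain loss and back-propagate it.
domain_loss = op_out.sum()
domain_loss.backward()

print(feature.grad is not None)   # True: F receives a gradient
print(softmax_output.grad is None)  # True: G receives no gradient
```

This matches the answer above: because the discriminator input uses the detached predictions, the question of reversing the domain-loss gradient w.r.t. G never arises.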
I cannot understand two things, and I would appreciate it if you could explain. (1) pytorch/loss.py line 33.
I am also confused about point (1). Do you understand it now?
But generally speaking, the domain loss still needs to optimize the feature extraction network.
Thanks to Long for the implementation; there are two points that confuse me.
Hoping for your help.