Implementation of loss function #6

Open
CN-BiGLiu opened this issue Mar 13, 2019 · 10 comments
Comments

@CN-BiGLiu

Thanks to Long for the implementation. Two points confuse me:

  1. The total loss is defined as classification loss + transfer loss, which differs from Equation (3), where it is classification loss - transfer loss.
  2. The domain discriminator is updated with the total loss instead of the transfer loss alone.

Hoping for your help.
@caozhangjie
Collaborator

We use a trick that reverses the gradient before it back-propagates from the discriminator to the feature extractor, so we do not need to multiply the discriminator loss by -1 to train the feature extractor.

The discriminator contributes no input to the classification loss. By PyTorch's autograd semantics, the classification loss therefore produces no gradient for the discriminator, even when we back-propagate from the sum of the classification loss and the transfer loss.
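The point above can be checked with a minimal PyTorch sketch. It assumes toy linear layers named feat, clf and disc (hypothetical stand-ins, not the repository's modules) and a hook-based reversal with coeff = 1. Back-propagating from classification loss + transfer loss leaves the discriminator's gradient equal to that of the transfer loss alone, while the feature extractor receives the classification gradient minus the transfer gradient.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy modules: feature extractor, classifier, domain discriminator.
feat = nn.Linear(4, 3)
clf = nn.Linear(3, 2)
disc = nn.Linear(3, 1)

x = torch.randn(8, 4)
y = torch.randint(0, 2, (8,))
domain = torch.cat([torch.ones(4, 1), torch.zeros(4, 1)])  # 1 = source, 0 = target


def backward_grads(w_cls, w_transfer, reverse):
    """Return (feature-extractor grads, discriminator grads) for w_cls*L_cls + w_transfer*L_transfer."""
    for m in (feat, clf, disc):
        m.zero_grad()
    h = feat(x)
    cls_loss = F.cross_entropy(clf(h), y)
    h_d = h * 1.0  # separate node, so the hook below only touches the discriminator branch
    if reverse:
        h_d.register_hook(lambda g: -g)  # gradient reversal with coeff = 1
    transfer_loss = F.binary_cross_entropy_with_logits(disc(h_d), domain)
    (w_cls * cls_loss + w_transfer * transfer_loss).backward()
    return ([p.grad.clone() for p in feat.parameters()],
            [p.grad.clone() for p in disc.parameters()])


feat_sum, disc_sum = backward_grads(1.0, 1.0, reverse=True)   # summed loss, as in training
feat_cls, _ = backward_grads(1.0, 0.0, reverse=True)          # classification loss only
feat_tr, disc_tr = backward_grads(0.0, 1.0, reverse=False)    # transfer loss, no reversal

# The discriminator only ever sees the transfer loss.
disc_ok = all(torch.allclose(a, b) for a, b in zip(disc_sum, disc_tr))
# The feature extractor gets grad(L_cls) - grad(L_transfer).
feat_ok = all(torch.allclose(a, c - t, atol=1e-6)
              for a, c, t in zip(feat_sum, feat_cls, feat_tr))
print(disc_ok, feat_ok)  # True True
```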

@CN-BiGLiu
Author

Thanks for your answer. The trick is x.register_hook(grl_hook(coeff)), is that right?

@caozhangjie
Collaborator

Yes
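For readers unfamiliar with the pattern, a gradient-reversal hook can be written as a closure that PyTorch calls with the incoming gradient. The sketch below is an assumption about what such a hook looks like, not a copy of pytorch/network.py; coeff plays the role of the reversal weight.

```python
import torch


def grl_hook(coeff):
    """Return a tensor hook that multiplies the incoming gradient by -coeff."""
    def reverse(grad):
        return -coeff * grad
    return reverse


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
h = x * 1.0                      # non-leaf copy; the hook changes the gradient passed back to x
h.register_hook(grl_hook(0.5))
(h * h).sum().backward()         # unreversed gradient w.r.t. h would be 2 * h = [2, 4, 6]
print(x.grad)                    # tensor([-1., -2., -3.])
```

The forward pass is untouched (h equals x); only the backward signal is flipped and scaled.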

@MeLonJ10

What about the TensorFlow version? Where is the gradient-reversal trick there?
Thanks a lot!

@sy565612345
Collaborator

The TensorFlow version is still being implemented.
The gradient-reversal trick is in pytorch/network.py line 388. grl_hook inserts a gradient reversal layer (GRL) between the ResNet CNN and the domain discriminator, which lets both adversarial players be updated in a single forward and backward pass.
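An equivalent way to build a GRL, instead of a tensor hook, is a custom torch.autograd.Function that is the identity in the forward pass and scales the gradient by -coeff in the backward pass. The sketch below is illustrative only: the class name GradReverse and the toy modules standing in for the ResNet CNN and the discriminator are hypothetical. It shows one forward/backward pass plus one optimizer step updating both players.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; multiplies the gradient by -coeff in the backward pass."""

    @staticmethod
    def forward(ctx, x, coeff):
        ctx.coeff = coeff
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # coeff is a plain float, so it gets no gradient (the trailing None).
        return -ctx.coeff * grad_output, None


torch.manual_seed(0)
feature_extractor = nn.Linear(4, 3)   # hypothetical stand-in for the ResNet CNN
discriminator = nn.Linear(3, 1)       # hypothetical stand-in for the domain discriminator
optimizer = torch.optim.SGD(
    list(feature_extractor.parameters()) + list(discriminator.parameters()), lr=0.1)

x = torch.randn(8, 4)
domain = torch.cat([torch.ones(4, 1), torch.zeros(4, 1)])  # 1 = source, 0 = target

features = feature_extractor(x)
logits = discriminator(GradReverse.apply(features, 1.0))
transfer_loss = F.binary_cross_entropy_with_logits(logits, domain)

optimizer.zero_grad()
transfer_loss.backward()  # one backward pass reaches both players
optimizer.step()          # discriminator descends on the loss; the feature extractor ascends on it
```

With the same coeff, this Function and a -coeff tensor hook produce the same gradients; the Function simply makes the reversal visible as a layer in the model definition.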

@sxwawa

sxwawa commented Jun 5, 2019

In the DANN model, after inserting a GRL between the generator and the discriminator, the gradient of the domain loss w.r.t. the feature extractor F is multiplied by -1. But in the CDAN model, the input of the discriminator is the tensor product of the feature vector and the predicted probability vector, so during back-propagation the domain loss has gradients with respect to both the feature extractor F and the classifier G. How does your algorithm compute the gradient of the domain loss w.r.t. the predicted probabilities output by classifier G? Does grl_hook also reverse the gradient of the domain loss w.r.t. classifier G? Thanks a lot!

@sy565612345
Collaborator

In pytorch/loss.py line 22: softmax_output = input_list[1].detach()
This detaches G from the domain loss during back-propagation, so the domain loss is not used to update classifier G.
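The effect of .detach() on the conditioned discriminator input can be checked with a toy multilinear map. In this sketch the shapes, the tensor names and the fixed weight vector w standing in for the discriminator are all assumptions, not the repository's code. The discriminator input is the per-sample outer product of the softmax output g and the feature vector f, computed with torch.bmm; with g detached, the domain loss still sends gradient to f but none to the classifier logits.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, feat_dim, num_classes = 4, 5, 3
w = torch.randn(num_classes * feat_dim, 1)                  # hypothetical fixed discriminator weights
target = torch.cat([torch.ones(2, 1), torch.zeros(2, 1)])   # source = 1, target = 0


def backward_domain_loss(detach_softmax):
    """Back-propagate a toy domain loss; return (grad w.r.t. features, grad w.r.t. logits)."""
    f = torch.randn(batch, feat_dim, requires_grad=True)          # toy features from F
    logits = torch.randn(batch, num_classes, requires_grad=True)  # toy outputs of classifier G
    g = F.softmax(logits, dim=1)
    if detach_softmax:
        g = g.detach()  # mirrors the idea of softmax_output = input_list[1].detach()
    # Multilinear conditioning: per-sample outer product g ⊗ f, flattened to one vector.
    op = torch.bmm(g.unsqueeze(2), f.unsqueeze(1)).view(batch, -1)
    loss = F.binary_cross_entropy_with_logits(op @ w, target)
    loss.backward()
    return f.grad, logits.grad


f_grad, logits_grad = backward_domain_loss(detach_softmax=True)
print(f_grad is not None, logits_grad is None)       # True True

f_grad, logits_grad = backward_domain_loss(detach_softmax=False)
print(f_grad is not None, logits_grad is not None)   # True True
```

So detaching the probabilities removes only the G branch of the domain-loss gradient; the feature branch is still trained by it.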

@xyqfountain

I cannot understand two things and would appreciate an explanation.
(1) pytorch/loss.py line 33: entropy.register_hook(grl_hook(coeff)). Why does the entropy need this *-1 hook? The gradients passed back from the domain discriminator to the feature extractor are already reversed by x.register_hook(grl_hook(coeff)), so registering another *-1 hook on the entropy confuses me.
(2) I noticed that you use softmax_output = input_list[1].detach(), which blocks gradients from the discriminator to the classifier, but the entropy is obtained from loss_func.Entropy(softmax_output), so entropy.requires_grad=True. This means gradients can be back-propagated to the classifier through the entropy (am I right?). What is this for?

@buerzlh

buerzlh commented Jul 2, 2020

I am also puzzled by @xyqfountain's problem (1). Do you understand it now?

@buerzlh

buerzlh commented Jul 2, 2020

Replying to @sy565612345's answer that softmax_output = input_list[1].detach() in pytorch/loss.py line 22 detaches G from the domain loss, so the domain loss is not used to update classifier G:

But generally speaking, the domain loss should also be used to optimize the feature extraction network G.
