Very confused by the discriminator loss #93

Closed
xesdiny opened this issue Aug 7, 2021 · 14 comments

@xesdiny

xesdiny commented Aug 7, 2021

When training the VQGAN pipeline on the FFHQ dataset, I checked disc_loss. It uses a function like vanilla_d_loss, or the hinge version below:

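import torch
import torch.nn.functional as F

def hinge_d_loss(logits_real, logits_fake):
    # Penalize real logits below +1 and fake logits above -1, then average
    loss_real = torch.mean(F.relu(1. - logits_real))
    loss_fake = torch.mean(F.relu(1. + logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss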
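To read the curve, it helps to know what the discriminator loss evaluates to at a few reference points. Below is a minimal stdlib sketch on plain lists of logits instead of tensors: hinge_d_loss mirrors the function above, and vanilla_d_loss is written in the usual softplus form (an assumption about how the vanilla variant is implemented). A "blind" discriminator that outputs the same logit c (with |c| ≤ 1) for every image scores exactly 1.0 under the hinge loss, and the vanilla loss for a blind discriminator bottoms out at ln 2 ≈ 0.693 (at c = 0). A disc_loss curve pinned near those values therefore suggests the discriminator is not separating real from reconstructed images.

```python
import math

def mean(xs):
    return sum(xs) / len(xs)

def hinge_d_loss(logits_real, logits_fake):
    # Same formula as the torch version, on plain lists of floats
    loss_real = mean([max(0.0, 1.0 - x) for x in logits_real])
    loss_fake = mean([max(0.0, 1.0 + x) for x in logits_fake])
    return 0.5 * (loss_real + loss_fake)

def softplus(x):
    # Numerically stable log(1 + exp(x))
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def vanilla_d_loss(logits_real, logits_fake):
    # Equivalent to -log(sigmoid(real)) and -log(1 - sigmoid(fake)), averaged
    return 0.5 * (mean([softplus(-x) for x in logits_real])
                  + mean([softplus(x) for x in logits_fake]))

# Blind discriminator: identical logits for real and fake images
blind = [0.5] * 4
print(hinge_d_loss(blind, blind))              # 1.0
print(vanilla_d_loss([0.0] * 4, [0.0] * 4))    # ≈ 0.693 (ln 2)

# Discriminator that separates real (+) from fake (-) with a margin
print(hinge_d_loss([1.5, 2.0], [-1.5, -2.0]))  # 0.0
```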

But the loss curve in TensorBoard looks very strange!
image

I am confused about whether this discriminator loss is actually being optimized in a way that helps the generator's training.

The discriminator loss is added to training once the step count reaches 30K. For reference, here is the discriminator loss from the start of training, extending the curve shown in the picture above:
image
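The "joins after 30K steps" behaviour is a gate on the global step. A minimal sketch, modelled on the adopt_weight helper in taming's vqperceptual.py as I understand it (gated_d_loss and the default disc_start=30000 are illustrative, chosen to match this run): before disc_start the factor is 0, so a disc_loss computed this way stays at 0 and the adversarial term contributes nothing to the generator loss until then.

```python
def adopt_weight(weight, global_step, threshold=0, value=0.0):
    # Return `value` (default 0) until `threshold` steps have passed,
    # then the configured weight
    if global_step < threshold:
        weight = value
    return weight

def gated_d_loss(raw_d_loss, global_step, disc_factor=1.0, disc_start=30000):
    # raw_d_loss: the discriminator loss (e.g. hinge) before gating
    factor = adopt_weight(disc_factor, global_step, threshold=disc_start)
    return factor * raw_d_loss

for step in (0, 29999, 30000, 60000):
    print(step, gated_d_loss(1.0, step))
# prints 0.0 for steps 0 and 29999, then 1.0 from step 30000 on
```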

@hyakuchiki

hyakuchiki commented Aug 7, 2021

A lot of people seem to have the same problem with the discriminator not being trained properly.
#73
Have you looked at the d_weight value in TensorBoard? If it is fluctuating at high values, that might be the problem.
I suspect that if disc_start is set higher, the reconstruction will settle first and d_weight will take a sensible value. The authors suggest training 3-5 epochs without the discriminator on ImageNet, so would that mean disc_start should be several million? My guess is that the discriminator should only be used once the VQVAE starts producing decent results.
#31
The default value of disc_start in custom_vqgan.yaml is 10000, which seems far too low.
I had the same problem, so I set disc_start to 50000 and disc_weight to 0.2, and I'm getting somewhat better results (although I'm worried that disc_weight is now a bit too low?).
image
image
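Turning "3-5 epochs without the discriminator" into a disc_start value is simple arithmetic, but the answer depends heavily on the effective batch size (per-GPU batch_size times the number of GPUs). A sketch assuming one global step per batch; ImageNet-1k's training split has 1,281,167 images, and the batch setups below are hypothetical examples, not the authors' settings:

```python
import math

def steps_per_epoch(num_images, batch_size, num_gpus=1):
    # One global step consumes batch_size images on each GPU
    return math.ceil(num_images / (batch_size * num_gpus))

def disc_start_for_epochs(epochs, num_images, batch_size, num_gpus=1):
    # Global steps to train before enabling the discriminator
    return epochs * steps_per_epoch(num_images, batch_size, num_gpus)

IMAGENET_TRAIN = 1_281_167
for bs, gpus in [(1, 1), (12, 1), (12, 8)]:  # hypothetical batch setups
    print(bs, gpus, disc_start_for_epochs(3, IMAGENET_TRAIN, bs, gpus))
# 1 1 3843501
# 12 1 320292
# 12 8 40038
```

So "several million" steps only holds for very small effective batches; with larger batches, 3 epochs is tens or hundreds of thousands of steps.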

@xesdiny

xesdiny commented Aug 9, 2021

Emm Yeah!
If I understand correctly, you mean the discriminator is not useful until the generator reaches a reasonable quality, so the point at which the discriminator enters training should be delayed.
d_weight is the coefficient that weights the discriminator-based (adversarial) term in total_loss.
It is computed as the ratio of the 2-norms of the gradients of rec_loss and g_loss with respect to the parameters of the model's last layer:

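    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        # Gradients of the reconstruction (nll) loss and of the generator's
        # adversarial loss w.r.t. the weights of the model's last layer
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        # Ratio of gradient norms: rescales the adversarial term so its gradient
        # has a magnitude comparable to the reconstruction term's
        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        # Clamp for stability; detach so no gradient flows through the weight
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        # Scale by disc_weight from the config
        d_weight = d_weight * self.discriminator_weight
        return d_weight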
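To see why d_weight can drift toward zero or toward large values, here is a stdlib sketch of the same computation on plain gradient vectors (the numbers are hypothetical; in the real code they come from torch.autograd.grad on the last layer's weight). The generator-side objective is assumed to have the form nll_loss + d_weight * disc_factor * g_loss + codebook_weight * codebook_loss. Before the disc_weight scaling, d_weight = 1 just means the two gradients have equal norms; it shrinks toward zero when the adversarial gradient dominates and grows when that gradient is tiny.

```python
import math

def l2_norm(v):
    return math.sqrt(sum(x * x for x in v))

def adaptive_weight(nll_grads, g_grads, discriminator_weight=1.0):
    # Ratio of gradient norms at the last layer, clamped to [0, 1e4]
    d_weight = l2_norm(nll_grads) / (l2_norm(g_grads) + 1e-4)
    d_weight = min(max(d_weight, 0.0), 1e4)
    # Scaled by disc_weight from the config
    return d_weight * discriminator_weight

def generator_loss(nll_loss, g_loss, codebook_loss, d_weight,
                   disc_factor=1.0, codebook_weight=1.0):
    # Assumed form of the autoencoder (generator) objective
    return nll_loss + d_weight * disc_factor * g_loss + codebook_weight * codebook_loss

nll_grads = [0.3, 0.4]  # norm 0.5 (hypothetical values)
print(adaptive_weight(nll_grads, [3.0, 4.0], 0.2))    # ≈ 0.02: adversarial grads dominate
print(adaptive_weight(nll_grads, [0.03, 0.04], 0.2))  # ≈ 2.0: adversarial grads are tiny
```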

The d_weight_step value in your TensorBoard is approaching zero.
I thought this value should stay stable at about 1 to guide the generator. (In practice, though, when it was hovering around 1, disc_loss did not decrease.) Maybe I didn't understand the meaning behind d_weight correctly.
Emm... I will try your suggestions on this pipeline.
Thx~

@xesdiny xesdiny closed this as completed Aug 18, 2021
@fortunechen

Hi, how are your results now? Could you please share what you learned from tuning disc_start and disc_weight?

Thx

@MaxyLee

MaxyLee commented Oct 15, 2021

I succeeded in getting good results on the CUB dataset by setting disc_start=50,000 and disc_weight=0.2:
Original images:
image
Reconstructed images:
image

@PanXiebit

@MaxyLee Congratulations! Could you share more details of your setup? How many examples are in your CUB dataset, and how many steps are in one epoch? Specifically, after how many epochs does the discriminator start?

@MaxyLee

MaxyLee commented Oct 17, 2021

Here is my config:

model:
  base_learning_rate: 4.5e-6
  target: taming.models.vqgan.VQModel
  params:
    embed_dim: 256
    n_embed: 1024
    ddconfig:
      double_z: False
      z_channels: 256
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [ 1,1,2,2,4]  # num_down = len(ch_mult)-1
      num_res_blocks: 2
      attn_resolutions: [16]
      dropout: 0.0

    lossconfig:
      target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
      params:
        disc_conditional: False
        disc_in_channels: 3
        disc_start: 50000
        disc_weight: 0.2
        codebook_weight: 1.0

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 5
    num_workers: 8
    train:
      target: taming.data.custom.CustomTrain
      params:
        training_images_list_file: /data/share/data/birds/CUB_200_2011/cub_train.txt
        size: 256
    validation:
      target: taming.data.custom.CustomTest
      params:
        test_images_list_file: /data/share/data/birds/CUB_200_2011/cub_test.txt
        size: 256

I trained this model on the CUB train split (8,855 images) using 4 GPUs, with approximately 400 steps per epoch, so the discriminator started after more than 100 epochs.
Hope this helps.
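The epoch count follows from the config above, assuming batch_size: 5 is per GPU, the 4 GPUs split the data, and there is one global step per batch:

```latex
\text{steps per epoch} = \left\lceil \frac{8855}{5 \times 4} \right\rceil = \lceil 442.75 \rceil = 443,
\qquad
\text{epochs before the discriminator} = \frac{50000}{443} \approx 112.9
```

This is close to the roughly 400 steps per epoch reported above and consistent with "more than 100 epochs".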

@PanXiebit

@MaxyLee thank u very much!!!

@PanXiebit

Hi @MaxyLee, I trained the VQGAN with your settings on my own dataset: the discriminator starts at about 100 epochs, and disc_weight is 0.2. However, I still face the same problem: the generated quality was alright, but after the discriminator started it became worse. These are my training curves:

image

image

In fact, the generated images are alright without the discriminator. In your training process, did your generated images become much better after GAN training?

@MaxyLee

MaxyLee commented Oct 19, 2021

Yes, my model performed much better once the discriminator loss was introduced. As the figure shows, my model could not generate fine-grained images without the discriminator.
image
Maybe you can try training the generator longer before adding the discriminator loss, and then select the best checkpoint.
Below are my training curves:
image

@PanXiebit

@MaxyLee thank you for your patience and kindness! I will try more experiments.

@kaihe

kaihe commented Jan 13, 2022

I think that for successful discriminator training, the fake logits should be negative and the real logits positive. But I noticed that in the training curves above, logits_fake and logits_real always look the same. Does that mean the discriminator has failed and just outputs the same value regardless of the input image? @MaxyLee, would you also share your training curves for the logits?
image
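One way to check this numerically rather than by eye is to log the gap between the mean real and fake logits, plus how often each has the expected sign. A hypothetical diagnostic helper (not part of taming), on plain lists of logits; with a patch discriminator these would be the flattened per-patch logits, and the min_gap threshold of 0.1 is an arbitrary illustrative choice:

```python
def logit_diagnostics(logits_real, logits_fake, min_gap=0.1):
    # Mean logits for real and reconstructed ("fake") images
    mean_real = sum(logits_real) / len(logits_real)
    mean_fake = sum(logits_fake) / len(logits_fake)
    # Fraction of real logits > 0 and of fake logits < 0
    acc_real = sum(x > 0 for x in logits_real) / len(logits_real)
    acc_fake = sum(x < 0 for x in logits_fake) / len(logits_fake)
    gap = mean_real - mean_fake
    return {
        "gap": gap,
        "acc_real": acc_real,
        "acc_fake": acc_fake,
        # Small or negative gap: D gives near-identical outputs for real and fake
        "collapsed": gap < min_gap,
    }

print(logit_diagnostics([0.21, 0.19, 0.20], [0.20, 0.22, 0.18]))  # blind D
print(logit_diagnostics([1.2, 0.8, 1.5], [-0.9, -1.1, 0.1]))      # separating D
```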

@MaxyLee

MaxyLee commented Jan 13, 2022

These are my training curves:
image
image

@kaihe

kaihe commented Jan 14, 2022

Thanks very much, that confirms my suspicion: a good discriminator is enough for sharp images; there is no need for GAN equilibrium.

@ThisisBillhe

Hi, how can the problem of logits_real and logits_fake being almost the same be solved?
