Simple L2 reconstruction loss? #10

Thanks for the excellent work! Currently, it seems that we are using a somewhat obscure DiscMixLogistic reconstruction loss. Is there any guide on using a simple L2 reconstruction loss? Do I need to change the model architecture for that?
Hello @SilenceMonk, thanks for the interest you show in our work!

tl;dr: Using the L2 reconstruction loss on a VAE is sub-optimal for training and should be avoided, as it causes the VAE to train in an over-regularized regime, which makes its outputs blurry. If over-regularization is actually desirable (to encode less information in the latent space), then changing the output filters to the number of channels in the image (3 if RGB) and changing the loss function to L2 should be enough for training. During inference, there is no sampling process and the logits are the actual image.

First and foremost, it should be pointed out that a VAE shouldn't be trained with an RGB output layer and an L2 reconstruction loss (MSE). As beautifully explained in section 5.1 (page 12) of this paper, training the VAE with MSE instead of an explicit maximum likelihood objective on a model distribution pushes it into the over-regularized regime mentioned above, which is what produces the blurry outputs.

When you consider the ELBO of a (standard) VAE:

$$\log p(x) \geq \mathbb{E}_{q(z|x)}\left[\log p(x|z)\right] - \mathrm{KL}\big(q(z|x)\,\|\,p(z)\big),$$

maximizing the right hand side of the inequality (or equivalently minimizing its opposite) decomposes into a reconstruction loss term

$$\mathcal{L}_{recon} = -\mathbb{E}_{q(z|x)}\left[\log p(x|z)\right]$$

and a regularization (KL) term

$$\mathcal{L}_{KL} = \mathrm{KL}\big(q(z|x)\,\|\,p(z)\big).$$

To deal with the reconstruction term $\mathcal{L}_{recon}$, the decoder's output layer can model any distribution really: Gaussian, Logistic, etc. I will not go too deep into the details of why we tend to use a mixture of discretized logistics as the distribution when modeling pixels (or audio) in this comment, but the github answer I shared earlier talks about that. The short answer to "why use a mixture of discretized logistics" is: pixel intensities are discrete values in [0, 255], the discretized logistic assigns proper probability mass to each of those integer bins, and the mixture keeps the distribution flexible enough to capture multi-modal pixel values.
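To make the ELBO decomposition above concrete, here is a minimal sketch (in PyTorch, not the repository's actual training code) of a one-sample negative-ELBO estimate with a generic decoder distribution; the `encoder`/`decoder` interfaces, the diagonal-Gaussian posterior, and the tensor shapes are assumptions for illustration:

```python
import torch
from torch.distributions import Normal, kl_divergence

def negative_elbo(x, encoder, decoder):
    """One-sample Monte Carlo estimate of -ELBO = reconstruction + KL.

    Assumes `encoder(x)` returns the mean and log-std of a diagonal-Gaussian
    q(z|x) with flat latents (B, D), and `decoder(z)` returns a
    torch.distributions object p(x|z) over image-shaped x (B, C, H, W).
    """
    mu, log_std = encoder(x)
    q_z = Normal(mu, log_std.exp())          # q(z|x)
    z = q_z.rsample()                        # reparameterized sample

    # Reconstruction term: -E_q[log p(x|z)], estimated with one sample.
    p_x = decoder(z)
    recon = -p_x.log_prob(x).sum(dim=[1, 2, 3]).mean()

    # Regularization term: KL(q(z|x) || p(z)) with a standard-normal prior.
    p_z = Normal(torch.zeros_like(mu), torch.ones_like(mu))
    kl = kl_divergence(q_z, p_z).sum(dim=1).mean()

    return recon + kl
```

Swapping the decoder distribution (Gaussian, discretized logistic mixture, ...) only changes what `decoder(z)` returns; the rest of the objective stays the same.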
Optimizing the MSE instead of performing an explicit MLE on a model distribution is, however, a sub-optimal way of maximizing the likelihood of the data under a Gaussian distribution. In fact, the reconstruction loss of the ELBO under a decoder modeling a single Gaussian distribution can be written as:

$$\mathcal{L}_{recon} = -\log \mathcal{N}\big(x;\, \mu_\theta(z),\, \sigma^2 I\big) = \frac{1}{2\sigma^2}\,\big\|x - \mu_\theta(z)\big\|_2^2 + \frac{D}{2}\log\big(2\pi\sigma^2\big).$$

You can notice that this reconstruction loss is equivalent to the MSE loss under the assumption that the Gaussian scale $\sigma$ is a fixed constant: the $\log(2\pi\sigma^2)$ term becomes a constant, and what remains is the squared error up to a constant scale factor.

Now, if the over-regularization is still desirable (side note: over-regularization can be achieved using beta-VAE as well), then changing the output filters to the number of channels in the image (3 if RGB) and changing the loss function to L2 should be enough for training. The output layer can be a simple linear layer with 3 filters for RGB. During inference, there is no sampling process and the logits are the actual image. It is worth pointing out that the logits in this case are floating point values and should be discretized and clipped to the integer range [0, 255]. If the need also arises at any time, the output layer and MoL loss can be changed to work with any other distribution that fits the need (Gaussian, Cauchy, etc.).

I hope this long comment answers your question and also gives some extra information to help reduce the obscurity around DiscMixLogistic! :) Let me know if there are any concerns or extra questions about this topic!
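For completeness, a minimal sketch of the L2 variant described in the comment above could look like the following. This is not this repository's code: the class/function names, the 1x1 convolution output layer, and the assumption that targets are scaled to [0, 1] are all illustrative choices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class L2OutputHead(nn.Module):
    """Hypothetical replacement for the MoL output layer: predict the RGB
    image directly with 3 output channels and train with an L2 loss."""

    def __init__(self, in_channels: int):
        super().__init__()
        # 3 filters = number of image channels (RGB).
        self.proj = nn.Conv2d(in_channels, 3, kernel_size=1)

    def forward(self, decoder_features: torch.Tensor) -> torch.Tensor:
        # The "logits" here are already the (continuous) image prediction.
        return self.proj(decoder_features)

def l2_reconstruction_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Plain MSE, i.e. a Gaussian likelihood with a fixed scale (up to constants).
    return F.mse_loss(pred, target)

def to_image(pred: torch.Tensor) -> torch.Tensor:
    # At inference there is no sampling: clip and discretize to [0, 255].
    # Assumes targets were scaled to [0, 1] during training.
    return (pred.clamp(0.0, 1.0) * 255.0).round().to(torch.uint8)
```

At training time `l2_reconstruction_loss(head(features), target)` would replace the MoL term in the objective (optionally with a beta weight on the KL term), and at inference `to_image` replaces the sampling step.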
Wow, I wasn't expecting such a quick, long and detailed guide on the MoL loss! I've certainly learnt a lot from it, and I'll definitely check the GitHub answer later on. Great, thanks again!
Hi there, I have a little follow-up question on the MoL loss: what if I am dealing with some discrete data where order doesn't matter, like words? Would it be an issue to use the MoL loss in this case?
Hello again @SilenceMonk! In the case where you are dealing with discrete data, you change the model's output distribution from a mixture of (discretized) logistics to a categorical (also called Multinoulli) distribution. Luckily for us, the negative log of the Multinoulli PMF is exactly the cross entropy loss function (derivation below), which means that minimizing the cross entropy loss is equivalent to performing maximum likelihood estimation on a categorical distribution. So, when dealing with a purely categorical output (where order doesn't matter), having a simple softmax output layer and using a cross entropy loss function is sufficient. This is the common practice when dealing with categorical data (as far as I am aware).

What follows is the derivation of the cross entropy loss from the Multinoulli PMF. With $K$ classes, class probabilities $\pi_k$ (the softmax outputs) and a one-hot target $x$ (so $x_k \in \{0, 1\}$ and $\sum_k x_k = 1$):

$$p(x \mid \pi) = \prod_{k=1}^{K} \pi_k^{x_k} \quad\Longrightarrow\quad -\log p(x \mid \pi) = -\sum_{k=1}^{K} x_k \log \pi_k,$$

which is exactly the cross entropy between the one-hot target and the predicted probabilities.
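As a quick numerical check (a sketch, not code from this repository), the built-in cross entropy on raw logits matches the negative log Multinoulli likelihood written out by hand:

```python
import torch
import torch.nn.functional as F

# Fake batch of logits over a vocabulary of 10 classes and integer targets.
logits = torch.randn(4, 10)
targets = torch.randint(0, 10, (4,))

# Built-in: softmax + negative log-likelihood of a categorical distribution.
loss_builtin = F.cross_entropy(logits, targets)

# By hand: -sum_k x_k * log(pi_k), with x one-hot and pi = softmax(logits).
log_pi = F.log_softmax(logits, dim=-1)
one_hot = F.one_hot(targets, num_classes=10).float()
loss_manual = -(one_hot * log_pi).sum(dim=-1).mean()

assert torch.allclose(loss_builtin, loss_manual)
```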
Hope this answers the question :)
Wow, great, thanks again @Rayhane-mamah! I guess I'll go with softmax + cross entropy when dealing with discrete data and see what I get.