[Question] What is PixelSnail? How to Train it? #4
Yes, it generates samples of latent codes for VQ-VAE. I have checked that it can produce reasonable samples if you train it long enough, but you will need a quite large model.
Would you mind sharing samples? Just to get an idea of what to expect.
That's pretty good! Thanks for sharing :)
Looks great! Would you mind sharing your (hyper-)parameter settings and the resulting accuracy of the top/bottom PixelSNAIL for this result?
Top: trained 109 epochs with lr 1e-4 and 9 epochs with lr 1e-5; accuracy was about 48%.
Bottom: trained 70 epochs with lr 1e-4 and 4 epochs with lr 1e-5; accuracy was about 21%.
I think you can increase res_channel & dropout to match the hyperparameters in the paper, but I couldn't use that setting because of its large memory requirements. Hope this helps.
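In command form, that schedule would look roughly like the following; the flag names come from a settings dump quoted later in this thread, the dataset name is a placeholder, and the lr drop is done by resuming from a checkpoint via --ckpt (a sketch, not the exact invocation used):

```
python train_pixelsnail.py --hier top --lr 1e-4 [dataset_name]
python train_pixelsnail.py --hier top --lr 1e-5 --ckpt checkpoint/pixelsnail_top_109.pt [dataset_name]
```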
It really helps! Thanks for the details.
Thank you for sharing the details. Can you also share how many GPUs you used to train these networks?
@k-eak I have used 4 V100s with mixed precision training.
@rosinality Thanks for sharing the details. I am curious about the mixed precision training: do you use some package? Does mixed precision training help you increase the batch size? Is it possible to share more on this part?
@ywang370 I have used NVIDIA apex amp (https://github.com/NVIDIA/apex) with opt_level O1. I think mixed precision training was quite helpful for increasing batch sizes and reducing training times. It is hard to compare directly as the GPUs are different, but mixed precision training on V100s was more than 2x faster, with 2x the batch size, compared to FP32 training on P40s.
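As a minimal sketch of what that looks like in a training loop (the model, optimizer, and data here are illustrative stand-ins, not this repo's code):

```python
import torch
from torch import nn, optim
from apex import amp

model = nn.Linear(10, 10).cuda()                 # stand-in for the real model
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# O1 patches common ops to run in FP16 where it is numerically safe
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

x = torch.randn(8, 10).cuda()                    # stand-in batch
loss = model(x).pow(2).mean()                    # stand-in loss
optimizer.zero_grad()
# scale the loss so small FP16 gradients don't underflow
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
```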
@rosinality
@phongnhhn92 Added simple support for apex at 7a2fbda.
Hi, after that I tried to run train_pixelsnail.py (every time I get a problem at line 40 in dataset.py; it is about "no decode"). I am trying to solve it but it is not working. Thanks a lot.
How long did it take you per epoch (and how many iterations did you have in an epoch)? I'm finding it takes a considerable amount of time (~7 hours for 34k iterations at batch size 32).
Which one did you mean (train_vqvae.py or extract_code.py)? For train_pixelsnail.py I have not succeeded yet; I have the above problem. I used the original parameters and changed the batch_size to 32 for train_vqvae.py.
@zaitoun90 Could you recheck the extract_code.py step? I think it might be an lmdb-related problem.
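(For reference, lmdb keys and values are raw bytes, so the dataset code has to encode keys and deserialize values explicitly. A rough sketch of that access pattern, with the key scheme and pickling assumed rather than copied from dataset.py:)

```python
import lmdb
import pickle

env = lmdb.open("path_to_code_lmdb", readonly=True, lock=False)
with env.begin() as txn:
    key = str(0).encode("utf-8")       # lmdb keys must be bytes, hence encode()
    row = pickle.loads(txn.get(key))   # txn.get() returns None if the key is missing
```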
@rosinality Thanks, now it is working.
Hi, one more question: I ran everything correctly, but I am still getting samples similar to the original ones. I thought I could generate different images. Could that be?
@zaitoun90 Do you mean the output from ground-truth code input? Then it should be similar to the input images. To get samples from the model, you can use sample.py.
@rosinality Yes, I used sample.py, but the output is still similar to the input images.
@zaitoun90 sample.py doesn't use image inputs; it should generate samples from scratch.
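To make that concrete: PixelSNAIL is an autoregressive prior over the discrete latent codes, so sampling fills in a code map one position at a time, and the result is then decoded by the VQ-VAE. A minimal sketch of such a loop (the model interface and shapes are assumed for illustration, not copied from sample.py):

```python
import torch

@torch.no_grad()
def sample_codes(model, batch, height, width, device="cuda"):
    # start from an all-zero code map and fill it in raster order
    codes = torch.zeros(batch, height, width, dtype=torch.int64, device=device)
    for i in range(height):
        for j in range(width):
            logits = model(codes)  # assumed output: [batch, n_class, height, width]
            probs = torch.softmax(logits[:, :, i, j], dim=1)
            codes[:, i, j] = torch.multinomial(probs, 1).squeeze(-1)
    return codes  # pass these codes through the VQ-VAE decoder to get images
```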
@rosinality When I checked sample.py, I noticed that the F.one_hot function seems to take too much time (190 seconds for the top level with batch size 32). I tried to replace it with a scatter function that updates according to the previous samples, but for some reason the network is now processing much slower. Do you have any idea why this is happening, and do you have any suggestions on how to improve the sampling time?
@k-eak The current implementation is quite inefficient; for example, one_hot will operate on sequences of 16896 elements per example at the top level. Maybe you can use some kind of caching. I have also tried to implement caching, but I only got a 2x improvement...
@rosinality Thank you for the suggestion. I replaced the one-hot and now update it after each sample with the scatter function. Although this improved the speed compared to one_hot, the network now takes longer to process, and in the end the improvement is very small. Do you think I am missing something? Here is the changed sampling code (I removed the one_hot function in pixelsnail):
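(The snippet itself wasn't preserved in this transcript. As a sketch of the idea being described, keeping a persistent one-hot buffer and updating it in place with scatter_ after each sampled position avoids recomputing F.one_hot over the whole sequence; shapes and names below are illustrative:)

```python
import torch

batch, n_class, height, width = 32, 512, 32, 32
one_hot = torch.zeros(batch, n_class, height, width)

def update_one_hot(one_hot, sample, i, j):
    # clear the class column at position (i, j), then set the sampled class to 1
    one_hot[:, :, i, j] = 0
    one_hot[:, :, i, j].scatter_(1, sample.unsqueeze(1), 1)

sample = torch.randint(n_class, (batch,))  # stand-in for a newly sampled code
update_one_hot(one_hot, sample, i=0, j=0)
```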
@k-eak Did you add torch.cuda.synchronize()? I think the speed measurement can be inaccurate because of the asynchronous nature of PyTorch. Also, the speed gain can be small, as much of the computation occurs in the rest of the model.
@rosinality Oh, my bad, I needed to add torch.cuda.synchronize(). So my method does not change the speed that much and mostly saves a couple of seconds for large batches. I might try adding the caching idea from https://github.com/PrajitR/fast-pixel-cnn/blob/master/fast_pixel_cnn_pp/fast_nn.py, but it might take some time to implement in PyTorch.
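(For anyone timing GPU code the same way: CUDA kernels launch asynchronously, so the clock should only be read after a synchronize on both sides of the measured region. A minimal pattern, with the model and input as stand-ins:)

```python
import time
import torch
from torch import nn

model = nn.Conv2d(1, 512, 3, padding=1).cuda()  # stand-in for the sampling step
codes = torch.randn(32, 1, 32, 32).cuda()       # stand-in input

torch.cuda.synchronize()            # drain any pending kernels first
start = time.perf_counter()
out = model(codes)                  # the operation being timed
torch.cuda.synchronize()            # wait for the queued kernels to finish
print(f"{time.perf_counter() - start:.3f} s")
```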
When I use train_pixelsnail.py, accuracy immediately hits 1.0 and the loss goes to basically zero after fewer than 100 iterations. This feels weird to me; what is going on? I've got these settings: amp='O0', batch=12, channel=512, ckpt=None, dropout=0.1, epoch=200, hier='top', lr=0.0001, n_cond_res_block=0, n_out_res_block=5, n_res_block=5, n_res_channel=512
I have the same issue. In my case the dataset might be "too simple"; this is just my guess... What about your data?
No, I don't think it is too simple. I am using this one here:
@Mut1nyJD OK, I figured it out. In my case the VQ-VAE training converged to a local minimum, so the reconstruction samples were not good and almost identical. I retrained it, playing around with the hyperparameters, and now everything is fine.
I've seen this problem; in my case I can reproduce it easily by setting the learning rate, --lr, high. I also get the opposite problem, where the latent goes to zero and it settles into producing a uniform colour output. This is another minimum: just don't pass any information through and output a constant. I'm guessing that it is weighting the latent cost too much, though I don't know where. Overall, I don't find that VQ-VAE-2 fits with PixelSNAIL very cleanly; there are too many things to train independently and hope they fit together at the end. A single cost function would be cleaner.
@drtonyr |
Hey, regarding the training time, do you count it in days or just a few hours? Mine takes really long, and I think probably days. I'm using one 32G tesla-smx2 GPU. Thank you!
@rosinality Umm, then what happens if I input an image? Will it become a whole new image? But this is a VAE that reconstructs images, right? What is the PixelSNAIL doing? Is it a generator? Thanks!
Hi! I'm failing to understand the function of PixelSNAIL. Is it to generate a latent space, similar to a GAN?
I trained the VQ-VAE correctly (until the samples were good enough):
python train_vqvae.py ./dataset_path
Then I performed a test run to train PixelSNAIL. Is this correct?
Extracted the codes (I assume these are the encodings for each image in the dataset):
python extract_code.py --ckpt checkpoint/vqvae_241.pt --name small256 ./dataset_path
Then I trained the top hierarchy (about 30 minutes per batch; only trained 1 batch):
python train_pixelsnail.py --hier top --batch 8 small256
Then I trained the bottom hierarchy (about 30 minutes per batch; only trained 1 batch):
python train_pixelsnail.py --hier bottom --batch 8 small256
And finally I sampled:
python sample.py --vqvae checkpoint/vqvae_001.pt --top checkpoint/pixelsnail_top_001.pt --bottom checkpoint/pixelsnail_bottom_001.pt output.png
The output, as expected, is just noise, since I only trained 1 batch on PixelSNAIL.
If I just keep training PixelSNAIL, will I be able to obtain good samples?
Hardware: NVIDIA 1080Ti
Thank you!