[Question] What is PixelSnail? How to Train it? #4

Open
EibrielInv opened this issue Jun 24, 2019 · 35 comments

Comments

@EibrielInv

EibrielInv commented Jun 24, 2019

Hi! I'm failing to understand the function of PixelSnail. Is it to generate a latent space similar to a GAN?

I trained the VQ-VAE successfully (until the samples were good enough):

python train_vqvae.py ./dataset_path

Then I did a test run of PixelSnail training. Is this correct?

I extracted the codes (I assume these are the encodings of each image in the dataset):

python extract_code.py --ckpt checkpoint/vqvae_241.pt --name small256 ./dataset_path

Then I trained the top hierarchy (about 30 minutes per batch; I only trained 1 batch):

python train_pixelsnail.py --hier top --batch 8 small256

Then I trained the bottom hierarchy (about 30 minutes per batch; I only trained 1 batch):

python train_pixelsnail.py --hier bottom --batch 8 small256

And finally I sampled:

python sample.py --vqvae checkpoint/vqvae_001.pt --top checkpoint/pixelsnail_top_001.pt --bottom checkpoint/pixelsnail_bottom_001.pt output.png

As expected, the output is just noise, since I only trained PixelSnail for 1 batch.

[Image: output]

If I just keep training PixelSnail will I be able to obtain good samples?

Hardware: NVIDIA 1080Ti

Thank you!

@rosinality
Owner

Yes, it generates samples of latent codes for the VQ-VAE. I have checked that it can produce some samples if you train it long enough, but you will need quite a large model.
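The two-stage idea can be illustrated with a toy, plain-Python sketch (not code from this repository). A stand-in autoregressive prior samples a grid of code indices in raster order, each draw conditioned on the codes already sampled, and a stand-in decoder maps each index to its codebook vector. In the real pipeline these roles are played by PixelSNAIL and the trained VQ-VAE decoder; `toy_prior_logits`, the codebook values, and the grid size are hypothetical.

```python
import math
import random


def softmax(logits, temperature=1.0):
    # Numerically stable softmax; temperature < 1 sharpens, > 1 flattens
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def toy_prior_logits(grid, i, j, n_codes):
    # Hypothetical stand-in for PixelSNAIL: it only looks at already-sampled
    # positions and favours repeating the code to the left (or above in column 0)
    if j > 0:
        prev = grid[i][j - 1]
    elif i > 0:
        prev = grid[i - 1][j]
    else:
        return [0.0] * n_codes  # first position: uniform
    return [3.0 if k == prev else 0.0 for k in range(n_codes)]


def sample_codes(height, width, n_codes, temperature=1.0, rng=None):
    # Ancestral sampling in raster order: one categorical draw per position
    rng = rng or random.Random(0)
    grid = [[0] * width for _ in range(height)]
    for i in range(height):
        for j in range(width):
            probs = softmax(toy_prior_logits(grid, i, j, n_codes), temperature)
            grid[i][j] = rng.choices(range(n_codes), weights=probs, k=1)[0]
    return grid


def decode(grid, codebook):
    # Stand-in for the VQ-VAE decoder: replace each index by its codebook vector
    return [[codebook[c] for c in row] for row in grid]


codebook = [[0.0, 0.0], [1.0, 0.5], [0.2, 0.9], [0.7, 0.7]]  # 4 hypothetical vectors
codes = sample_codes(4, 4, n_codes=len(codebook))
latent = decode(codes, codebook)  # a real decoder would turn this into pixels
```

Nothing in this loop takes an image as input, which is why sampling from the prior produces new code maps rather than reconstructions.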

@pclucas14

Would you mind sharing samples? Just to get an idea of what to expect.

@rosinality
Owner

[Image: sample]
Not very nice, but it comes from a somewhat smaller model than the one in the paper.

@pclucas14

that's pretty good! thanks for sharing :)

@1Konny

1Konny commented Jun 27, 2019

Looks great! Would you mind sharing your (hyper)parameter settings and the resulting accuracy of the top/bottom PixelSNAIL for this result?

@rosinality
Owner

rosinality commented Jun 27, 2019

Top
  • channel: 512
  • n_block: 4
  • n_res_block: 5
  • res_channel: 512
  • n_cond_res_block: 0
  • n_out_res_block: 5
  • attention: True
  • dropout: 0.1
  • batch size: 63 (1e-4) / 64 (1e-5)

Trained 109 epochs with lr 1e-4 and 9 epochs with lr 1e-5, and accuracy was about 48%
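If you want to reproduce this two-phase learning rate in a single run, a step schedule is one option. A minimal sketch, assuming the drop happens once after epoch 109 by a factor of 10 (the thread doesn't say whether this was done in one run or by resuming from a checkpoint with a different lr):

```python
def step_lr(epoch, base_lr=1e-4, drop_epoch=109, factor=0.1):
    # base_lr for epochs 0 .. drop_epoch - 1, then base_lr * factor afterwards
    return base_lr if epoch < drop_epoch else base_lr * factor


# 109 epochs at 1e-4 followed by 9 epochs at 1e-5, as in the top-prior run above
schedule = [step_lr(epoch) for epoch in range(109 + 9)]
```

In PyTorch, `torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[109], gamma=0.1)`, stepped once per epoch, gives the same schedule.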

Bottom
  • channel: 512
  • n_block: 4
  • n_res_block: 5
  • res_channel: 512
  • n_cond_res_block: 5
  • cond_res_channel: 512
  • n_out_res_block: 0
  • attention: False
  • dropout: 0.1
  • batch size: 64

Trained 70 epochs with lr 1e-4 and 4 epochs with lr 1e-5, and accuracy was about 21%
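The accuracy figures are presumably top-1 accuracy of next-code prediction: the fraction of latent positions where the prior's highest-scoring code equals the code extracted from the real image. That interpretation is an assumption; a minimal plain-Python sketch of such a metric:

```python
def code_accuracy(logits, targets):
    """Top-1 accuracy of code prediction.

    logits:  one list of class scores per latent position
    targets: the true code index at each position
    """
    if len(logits) != len(targets):
        raise ValueError("need one score vector per target")
    correct = 0
    for scores, target in zip(logits, targets):
        pred = max(range(len(scores)), key=scores.__getitem__)  # argmax
        correct += int(pred == target)
    return correct / len(targets)


scores = [[0.1, 2.0, 0.3], [1.0, 0.0, 0.0], [0.0, 0.0, 5.0], [3.0, 1.0, 0.0]]
print(code_accuracy(scores, [1, 0, 1, 2]))  # 0.5: positions 0 and 1 are correct
```

With tensors, the equivalent would be `(logits.argmax(dim=1) == target).float().mean()` for logits of shape (batch, n_class, H, W). For reference, uniform guessing over a 512-entry codebook would give about 0.2%.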

I think you can increase res_channel and dropout to match the hyperparameters in the paper, but I can't use that setting because of its large memory requirements.

Hope this helps.

@1Konny

1Konny commented Jun 27, 2019

it really helps! thanks for the details.

@k-eak

k-eak commented Jul 3, 2019

Thank you for sharing the details. Can you also share how many GPUs you used to train these networks?

@rosinality
Owner

@k-eak I have used 4 V100s with mixed precision training.

@ywang370

ywang370 commented Jul 9, 2019

@rosinality Thanks for sharing the details. I am curious about the mixed precision training: do you use a particular package? Does mixed precision training help you increase the batch size? Could you share more about this part?

@rosinality
Owner

@ywang370 I used NVIDIA apex amp (https://github.com/NVIDIA/apex) with opt_level O1. I think mixed precision training was quite helpful for increasing batch sizes and reducing training times. It is hard to compare directly as the GPUs are different, but mixed precision training on V100s with 2x the batch size is more than 2x faster than FP32 training on P40s.
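For readers unfamiliar with what amp adds on top of FP16 casting: with opt_level O1, apex amp uses dynamic loss scaling by default, multiplying the loss by a large factor so that small FP16 gradients don't underflow, then unscaling the gradients before the optimizer step. The following plain-Python sketch of that scaling logic is illustrative only, not apex's implementation; its defaults mirror those of PyTorch's native `torch.cuda.amp.GradScaler`.

```python
import math


class DynamicLossScaler:
    """Minimal dynamic loss scaler (illustrative, not apex's code)."""

    def __init__(self, init_scale=2.0 ** 16, growth_interval=2000,
                 growth_factor=2.0, backoff_factor=0.5):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self._good_steps = 0

    def update(self, scaled_grads):
        """Return unscaled gradients, or None if the step must be skipped."""
        if any(math.isinf(g) or math.isnan(g) for g in scaled_grads):
            # FP16 overflow: shrink the scale and skip this optimizer step
            self.scale *= self.backoff_factor
            self._good_steps = 0
            return None
        self._good_steps += 1
        # Unscale with the scale that was used for this backward pass
        grads = [g / self.scale for g in scaled_grads]
        if self._good_steps % self.growth_interval == 0:
            # A long run without overflow: try a larger scale
            self.scale *= self.growth_factor
        return grads


scaler = DynamicLossScaler(init_scale=1024.0, growth_interval=2)
print(scaler.update([float("inf")]), scaler.scale)  # None 512.0: step skipped
print(scaler.update([512.0]))                       # [1.0]
print(scaler.update([512.0]), scaler.scale)         # [1.0] 1024.0: scale grew
```

In current PyTorch, `torch.cuda.amp.autocast` together with `torch.cuda.amp.GradScaler` provides this natively.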

@phongnhhn92

@rosinality
Hi, I have 2 P100s. Is there any improvement in general if I use apex for training PixelSNAIL?
Would you mind sharing the code you used to enable the mixed precision training with apex that you mentioned above? I cannot find where apex is used in the GitHub repo.

@rosinality
Owner

@phongnhhn92 Added simple support for apex in 7a2fbda.

@zaitoun90

Hi,
I have trained the VQ-VAE and got very similar images. My dataset has 159 images.
Then I ran extract_code.py. My question here: how many checkpoints should I use in the end?

After that I tried to run train_pixelsnail.py, but every time I get an error at line 40 of dataset.py, something about "no decode".
Then I tried to check whether the lmdb file has any data. I printed env.state and got this output:
({'psize': 4096, 'depth': 0, 'branch_pages': 0, 'leaf_pages': 0, 'overflow_pages': 0, 'entries':)

I have been trying to solve it, but nothing works.

Thanks a lot.

@karamarieliu

How long did it take you per epoch (and how many iterations did you have in an epoch)? I'm finding it takes a considerable amount of time (~7 hours for 34k iterations of batch size 32).

@zaitoun90

Which one do you mean (train_vqvae.py or extract_code.py)?
For both of them it is not that much, about 15 minutes. I have a small dataset and I am using 2x GTX 1080 GPUs.

For train_pixelsnail.py I have not succeeded so far; I still have the problem described above.

I used the original parameters and only changed batch_size to 32 for train_vqvae.py.

@rosinality
Owner

@zaitoun90 Could you recheck the extract_code.py step? I think it might be an lmdb-related problem.
@karamarieliu Yes, train_pixelsnail.py requires a lot of time as PixelSNAIL model is quite large.

@zaitoun90

zaitoun90 commented Aug 20, 2019

@rosinality thanks, now it is working.

@zaitoun90

Hi, one more question. Everything runs correctly, but I am still getting samples that are similar to the original images. I thought I could generate different images. Is that possible?
@rosinality Are the samples you shared different from the original dataset?

@rosinality
Owner

@zaitoun90 Do you mean the output from ground-truth code inputs? Then it should be similar to the input images. To get samples from the model, you can use sample.py.

@zaitoun90

@rosinality Yes, I use sample.py, but the output is still similar to the input images.
After this long training of the VQ-VAE and PixelSNAIL, I expected to be able to generate different samples.

@rosinality
Owner

rosinality commented Aug 24, 2019

@zaitoun90 sample.py doesn't use image inputs. sample.py should generate samples from scratch.

@k-eak

k-eak commented Aug 29, 2019

@rosinality When I checked sample.py, I noticed that the F.one_hot function seems to take too much time (190 seconds for the top level with batch size 32). I tried to replace it with a scatter function that updates the one-hot tensor according to the previous samples, but for some reason the network now processes much more slowly. Do you have any idea why this is happening, and any suggestions on how to improve the sampling time?

@rosinality
Owner

@k-eak The current implementation is quite inefficient; for example, one_hot will operate on sequences of 16896 elements per example at the top level. Maybe you can use some kind of caching. I have also tried to implement caching, but I only got a 2x improvement...
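A rough count shows why re-running the model on the growing prefix is expensive. The sampling loop calls the model once per position, on `row[:, : i + 1, :]`, i.e. (i + 1)·W positions, so an H×W grid costs W²·H(H + 1)/2 position evaluations in total, versus H·W if a cache computed only the new position. The sketch assumes a 32×32 top-level grid (what a 256×256 input gives with 8× downsampling) and counts positions only, not FLOPs:

```python
def naive_positions(height, width):
    # One model call per (i, j); each call sees the first i + 1 rows of width `width`
    return sum((i + 1) * width for i in range(height) for _ in range(width))


def cached_positions(height, width):
    # Ideal incremental sampling: every position is computed exactly once
    return height * width


h, w = 32, 32  # assumed top-level grid for 256x256 inputs
print(naive_positions(h, w))   # 540672 == w * w * h * (h + 1) // 2
print(cached_positions(h, w))  # 1024
print(naive_positions(h, w) // cached_positions(h, w))  # 528
```

The 528x ratio is an upper bound on the savings in positions processed, not a wall-clock prediction; a working cache gave only about a 2x speed-up in practice.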

@k-eak

k-eak commented Aug 30, 2019

@rosinality Thank you for the suggestion. I replaced one_hot, and now I update the one-hot tensor after each sample with the scatter function. Although this is faster than one_hot, the network now takes longer to process, so in the end the improvement is very small. Do you think I am missing something?

Here is the changed sampling code (I removed the one_hot function from the PixelSNAIL model):

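import torch
from tqdm import tqdm

@torch.no_grad()  # no gradients are needed while sampling
def sample_model(model, device, batch, size, temperature, condition=None):
    # One-hot buffer of shape (batch, n_class, H, W); 512 is assumed to be the codebook size
    row = torch.zeros(batch, 512, *size, dtype=torch.int64).to(device)
    # Integer code indices returned to the caller
    row_sample = torch.zeros(batch, *size, dtype=torch.int64).to(device)
    cache = {}
    for i in tqdm(range(size[0])):
        for j in range(size[1]):
            # Feed only the rows generated so far (autoregressive in raster order)
            out, cache = model(row[:, :, : i + 1, :], condition=condition, cache=cache)
            prob = torch.softmax(out[:, :, i, j] / temperature, 1)
            sample = torch.multinomial(prob, 1)  # shape (batch, 1)
            # Write the sampled code into the one-hot buffer instead of calling F.one_hot
            row[:, :, i, j] = row[:, :, i, j].scatter(1, sample, 1)
            row_sample[:, i, j] = sample.squeeze(-1)
    return row_sample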

@rosinality
Owner

rosinality commented Sep 1, 2019

@k-eak Did you add torch.cuda.synchronize()? I think the speed measurement can be inaccurate because of the asynchronous nature of PyTorch. Also, the speed gain can be small, as much of the computation happens in the rest of the model.
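The point about asynchronous execution applies to any CUDA timing: kernel launches return immediately, so the clock must be read only after the device has finished. A small plain-Python helper, sketched under the assumption that you pass `torch.cuda.synchronize` as the `sync` argument when timing GPU code (`timed` is a hypothetical name, not part of the repo):

```python
import time


def timed(fn, *args, sync=None, repeats=10, **kwargs):
    """Return (average seconds per call, last result) for fn(*args, **kwargs).

    `sync` must block until all queued device work has finished; for PyTorch on
    a GPU pass torch.cuda.synchronize. Without it, asynchronous CUDA kernels
    mean the clock mostly measures kernel launches, not the work itself.
    """
    if repeats < 1:
        raise ValueError("repeats must be at least 1")
    sync = sync or (lambda: None)
    sync()  # finish any work queued before timing starts
    start = time.perf_counter()
    result = None
    for _ in range(repeats):
        result = fn(*args, **kwargs)
    sync()  # wait for the timed work itself to complete
    return (time.perf_counter() - start) / repeats, result
```

For example, `timed(model_top, codes, sync=torch.cuda.synchronize)` would average over 10 calls; `model_top` and `codes` are placeholders for your own objects.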

@k-eak

k-eak commented Sep 1, 2019

@rosinality Oh, my bad, I needed to add torch.cuda.synchronize(). So my method does not change the speed much and only saves a couple of seconds for large batches. I might try adding the caching idea from "https://github.com/PrajitR/fast-pixel-cnn/blob/master/fast_pixel_cnn_pp/fast_nn.py", but it might take some time to implement in PyTorch.

@Mut1nyJD

Mut1nyJD commented Sep 6, 2019

When I use train_pixelsnail.py accuracy immediately hits 1.0 and the loss goes to basically zero after less than 100 iterations. This feels weird to me, what is going on?

I've got these settings:

amp='O0', batch=12, channel=512, ckpt=None, dropout=0.1, epoch=200, hier='top', lr=0.0001, n_cond_res_block=0, n_out_res_block=5, n_res_block=5, n_res_channel=512

@Slimco86

Slimco86 commented Oct 2, 2019

@Mut1nyJD

I have the same issue. My guess is that in my case the dataset might be "too simple"... What about your data?

@Mut1nyJD

Mut1nyJD commented Oct 2, 2019

@Slimco86

No, I don't think it is too simple. I am using this dataset:

https://www.mut1ny.com/peoplepose20k

@Slimco86

Slimco86 commented Oct 4, 2019

@Mut1nyJD OK, I figured it out: in my case the VQ-VAE training converged to a local minimum, so the reconstructions were not good and almost identical. I retrained it after playing around with the hyperparameters, and now everything is fine.
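Collapse of this kind can be checked before spending days on PixelSNAIL: count how many codebook entries the extracted codes actually use, and compute the codebook perplexity (the exponential of the entropy of the code histogram). If nearly every position uses the same code, the prior can reach ~100% accuracy trivially, which matches the symptom reported above. A plain-Python sketch; loading the codes from the LMDB is left out, since it depends on your setup:

```python
import math
from collections import Counter


def codebook_stats(codes):
    """Usage statistics for a flat iterable of integer code indices."""
    counts = Counter(codes)
    total = sum(counts.values())
    if total == 0:
        raise ValueError("no codes given")
    entropy = -sum((n / total) * math.log(n / total) for n in counts.values())
    return {
        "used_codes": len(counts),
        "perplexity": math.exp(entropy),  # 1.0 means one code everywhere
        "most_common_share": counts.most_common(1)[0][1] / total,
    }


print(codebook_stats([7] * 1024))          # collapsed: 1 code, perplexity 1.0
print(codebook_stats([0, 1, 2, 3] * 256))  # 4 codes used equally, perplexity ~4
```

A perplexity far below the codebook size suggests that most codes are unused.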

@drtonyr

drtonyr commented Oct 5, 2019

I've seen this problem, in my case I can reproduce it easily by setting the learning rate, --lr, high.

I also get the opposite problem, that the latent goes to zero and it settles into producing a uniform colour output. This is another minima - just don't pass any information through and output a constant.

I'm guessing that it is weighting the latent cost too much. I don't know where latent_loss_weight = 0.25 comes from. I reduced the factor from 0.25 to 0.05 (complete guess) and it seemed to fix the problem.
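The latent_loss_weight question can be read against the objective from the original VQ-VAE paper, where E is the encoder, e the nearest codebook vector to E(x), D the decoder, sg the stop-gradient operator, and β the commitment weight. A plausible reading (an inference, not confirmed in this thread) is that latent_loss_weight plays the role of β, which that paper set to 0.25:

```latex
\mathcal{L}(x) = \lVert x - D(e) \rVert_2^2
  + \lVert \mathrm{sg}[E(x)] - e \rVert_2^2
  + \beta \, \lVert E(x) - \mathrm{sg}[e] \rVert_2^2
```

When the codebook is learned with exponential moving average updates, as in VQ-VAE-2, the middle term is replaced by those updates and only the reconstruction and β-weighted commitment terms remain. Lowering β, as in the 0.25 to 0.05 change above, weakens the pull of the encoder outputs toward their codes.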

Overall, I don't find that VQ-VAE-2 fits with PixelSNAIL very cleanly: there are too many things to train independently and hope they fit together at the end. There must be a cleaner approach with a single cost function.

@Mut1nyJD

Mut1nyJD commented Oct 9, 2019

@drtonyr
Okay, thanks. I was using the default --lr setting, but I'll try again with your suggestion as soon as I find some time in my training slot :). That is, first decrease the latent loss weight, and if that does not help, the lr.
But I agree this VQ-VAE-2 and PixelSNAIL combo feels suboptimal, even though VQ-VAE-2 by itself does provide good reconstructions.

@LinfengLiu98

> @zaitoun90 Could you recheck extract_code.py step? I think it might lmdb related problems.
> @karamarieliu Yes, train_pixelsnail.py requires a lot of time as PixelSNAIL model is quite large.

Hey, regarding the training time: do you count it in days or just a few hours? Mine takes really long, probably days. I'm using a single 32 GB Tesla (SXM2) GPU. Thank you!

@easonoob

easonoob commented May 5, 2023

> @zaitoun90 sample.py doesn't use image inputs. sample.py should generate samples from scratch.

@rosinality Umm, then what happens if I input an image? Will it become a whole new image? But isn't this a VAE that reconstructs images? What is PixelSnail doing? Is it a generator? Thanks!
