
How to train ADM-IP.pt on CIFAR10 #25

Open
TT-RAY opened this issue Apr 8, 2024 · 21 comments

TT-RAY commented Apr 8, 2024

Thank you for your work. I'm curious how to train ADM-IP on the CIFAR-10 dataset: the FID values I obtained during training were significantly higher than the ones reported in your paper. At 70k iterations the FID was 3.66, at 230k iterations it was 10.61, and at 460k iterations it reached 17.17. Below is the training script I used; is there anything wrong with it? I would appreciate your advice. Thanks.
```bash
mpiexec -n 4 python scripts/image_train.py --input_pertub 0.15 \
  --data_dir datasets/cifar_train \
  --image_size 32 --use_fp16 True --num_channels 128 --num_head_channels 32 --num_res_blocks 3 \
  --attention_resolutions 16,8 --resblock_updown True --use_new_attention_order True \
  --learn_sigma True --dropout 0.3 --diffusion_steps 1000 --noise_schedule cosine --use_scale_shift_norm True \
  --rescale_learned_sigmas True --schedule_sampler loss-second-moment --lr 1e-4 --batch_size 64
```

forever208 (Owner) commented Apr 8, 2024

Hi @TT-RAY, the only difference is that you used an effective batch size of 256 (4 GPUs × 64) where I used 128 in my experiments. But I do not think it affects the FID too much.

Another detail: make sure you use the same CIFAR-10 data for training and FID evaluation. I used PNG images for CIFAR-10 training and packed those PNG files into a single npz file for FID evaluation.

The last thing you can check: try to replicate the ADM baseline on CIFAR-10 and see whether the step-FID curve matches the one reported in the paper.
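For reference, here is a minimal sketch of the PNG-to-npz packing step mentioned above, assuming the training PNGs live in datasets/cifar_train and that the reference batch uses the arr_0 key; the repo's datasets/cifar10.py may do this differently.

```python
# Minimal sketch (not the repo's script): pack a folder of CIFAR-10 training PNGs
# into a single npz file to use as the FID reference batch.
# Assumptions: PNGs are in datasets/cifar_train/, reference key is "arr_0".
import glob

import numpy as np
from PIL import Image

paths = sorted(glob.glob("datasets/cifar_train/*.png"))
images = np.stack(
    [np.asarray(Image.open(p).convert("RGB"), dtype=np.uint8) for p in paths]
)
print(images.shape)  # expect (50000, 32, 32, 3) for the CIFAR-10 training split
np.savez("cifar10_train.npz", arr_0=images)
```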

TT-RAY (Author) commented Apr 8, 2024

Hello, I generated the cifar10_train.npz file and the cifar_train dataset using the datasets/cifar10.py file, and I trained the model on the datasets/cifar_train dataset. Are there any other potential issues that could be causing the problem?

forever208 (Owner) commented Apr 8, 2024

> Hello, I generated the cifar10_train.npz file and the cifar_train dataset using the datasets/cifar10.py file, and I trained the model on the datasets/cifar_train dataset. Are there any other potential issues that could be causing the problem?

Your dataset implementations are correct.

Can you take a look at the other issues, where people also asked about the CIFAR-10 reproduction results and eventually solved their problems? You might have implemented something wrong unintentionally; for example, make sure you use the EMA checkpoint for sampling. Someone even switched to a different GPU server and got the correct results.

Let me know whether you figure it out.

TT-RAY (Author) commented Apr 8, 2024

Hello, here is my sampling command. Could you please take a look and see if there are any issues? Also, I have 4 A800 GPUs, and I'm not sure whether different GPU hardware can have an impact.
```bash
mpiexec -n 4 python /home/caobr/DDPM-IP/scripts/image_sample.py \
  --image_size 32 --timestep_respacing 100 \
  --model_path /home/caobr/DDPM-IP/mode_loop/ema_0.9999_460000.pt \
  --num_channels 128 --num_head_channels 32 --num_res_blocks 3 --attention_resolutions 16,8 \
  --resblock_updown True --use_new_attention_order True --learn_sigma True --dropout 0.3 \
  --diffusion_steps 1000 --noise_schedule cosine --use_scale_shift_norm True --batch_size 256 --num_samples 50000
```

forever208 (Owner) commented:

@TT-RAY your sampling command is OK. Our code was tested on V100 and A100 GPUs.

What I suggest is to first compute the FID using our released checkpoint and see whether you get the correct FID; that way you can make sure your FID computation is right.

TT-RAY (Author) commented Apr 9, 2024

> @TT-RAY your sampling command is OK. Our code was tested on V100 and A100 GPUs.
>
> What I suggest is to first compute the FID using our released checkpoint and see whether you get the correct FID; that way you can make sure your FID computation is right.

Okay, thank you very much for your suggestion. The FID of your released checkpoint is 2.37, and I will give training another try.

TT-RAY (Author) commented Apr 13, 2024

@forever208 Hello, I re-downloaded your code and trained it for 500k steps using 2 V100 GPUs. However, I noticed that at 460k steps, the FID value is 9.8, which is significantly higher than the expected value of 2.3. Here are my settings. Could you please take a look and let me know if there's any issue?
```bash
mpiexec -n 2 python scripts/image_train.py --input_pertub 0.15 \
  --data_dir datasets/cifar_train \
  --image_size 32 --use_fp16 True --num_channels 128 --num_head_channels 32 --num_res_blocks 3 \
  --attention_resolutions 16,8 --resblock_updown True --use_new_attention_order True \
  --learn_sigma True --dropout 0.3 --diffusion_steps 1000 --noise_schedule cosine --use_scale_shift_norm True \
  --rescale_learned_sigmas True --schedule_sampler loss-second-moment --lr 1e-4 --batch_size 64
```
[Screenshot 2024-04-13 162057]

forever208 (Owner) commented:

@TT-RAY can you first try with --input_pertub 0, which corresponds to the ADM baseline, and see what the FID is at 500k steps?
Also, have you computed the FID using our released checkpoint for CIFAR-10?

TT-RAY (Author) commented Apr 18, 2024

@forever208 I followed your advice and set --input_pertub 0; at 450k steps I obtained an FID of 7.61, which I believe is much higher than the FID reported in the paper. Additionally, when I used the ADM_IP_015 model released by your team, I obtained an FID of 2.3, which seems normal. I'm not sure what went wrong.

forever208 (Owner) commented Apr 18, 2024

@TT-RAY since you got FID 2.3 using the released checkpoint, your sampling and FID computation are correct, so the problem is in training. To be honest, I do not know the exact reason: you use the same code and the same configuration as me, but the FID converges differently. Also, --input_pertub 0 corresponds to the ADM baseline, and your FID on ADM is also much higher than the normal FID on CIFAR-10; DDPM/iDDPM/ADM should get an FID of 3.0–3.5 on CIFAR-10.

Can you take a look at your training loss curve? Then I can compare it with my training loss.

TT-RAY (Author) commented Apr 20, 2024

@forever208 The following picture shows my training loss curve:
[Screenshot 2024-04-20 161303: training loss curve]

forever208 (Owner) commented:

@TT-RAY Hi, I think your loss is much higher than the one I got; for example, on the ADM baseline I got 0.055 at 420k steps.

TT-RAY (Author) commented Apr 22, 2024

@forever208 sorry, the loss curve picture above is for --input_pertub 0.15. The following table is the ADM baseline:
[Screenshot 2024-04-22 101454: ADM baseline loss]

forever208 (Owner) commented:

@TT-RAY OK, that makes sense; your loss is also fine then. I currently have no idea why the FID is so much higher.

TT-RAY (Author) commented Apr 22, 2024

@forever208 Sure, thank you very much! I'll think of other solutions.


weigerzan commented Jun 23, 2024

@forever208 @TT-RAY Have you obtained reasonable results? I encountered a similar situation: I got FID 4.36 at 180k steps and FID 5.58 at 250k steps with --input_pertub=0.15. I trained the model on a single A800 GPU and I believe I kept the total batch size at 128. My training setting is as follows:
```bash
CUDA_VISIBLE_DEVICES=1 python scripts/image_train.py --input_pertub 0.15 --data_dir cifar_train \
  --image_size 32 --use_fp16 True --num_channels 128 --num_head_channels 32 --num_res_blocks 3 \
  --attention_resolutions 16,8 --resblock_updown True --use_new_attention_order True \
  --learn_sigma True --dropout 0.3 --diffusion_steps 1000 --noise_schedule cosine \
  --use_scale_shift_norm True --rescale_learned_sigmas True --schedule_sampler loss-second-moment \
  --lr 1e-4 --batch_size 128
```
And the sampling command is:
```bash
python scripts/image_sample.py --image_size 32 --timestep_respacing 100 \
  --model_path /tmp/openai-2024-06-23-00-34-46-696387/ema_0.9999_250000.pt \
  --num_channels 128 --num_head_channels 32 --num_res_blocks 3 --attention_resolutions 16,8 \
  --resblock_updown True --use_new_attention_order True --learn_sigma True --dropout 0.3 \
  --diffusion_steps 1000 --noise_schedule cosine --use_scale_shift_norm True \
  --batch_size 256 --num_samples 50000
```
The FID is expected to be at least below 3.5 at around 200k steps, but it seems to diverge with more training steps. Should I decrease the learning rate or make other modifications?

forever208 (Owner) commented:

@weigerzan @TT-RAY Since both of you encountered the same issue, I will re-run the code on my server and let you know the CIFAR-10 results.

weigerzan commented:

@forever208 Really appreciate that!

forever208 (Owner) commented:

@weigerzan @TT-RAY sorry, I have been too busy these days; I will re-run the experiment next week and let you know then.

forever208 (Owner) commented:

@weigerzan @TT-RAY I just finished retraining ADM-IP and tested the FID; the results are consistent with my previous implementation. The FID of ADM-IP (trained for 450k steps with batch size 128) under 100-step sampling is 2.33.


I suspect the only remaining possible issue in your implementation is the FID computation. Instead of using the FID evaluation provided by the ADM repo, can you try using pytorch-fid to compute the FID?
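For what it's worth, a minimal sketch of the pytorch-fid route (pip install pytorch-fid), assuming the 50k generated samples and the reference CIFAR-10 images are both stored as folders of PNGs; the folder names are placeholders, the exact API may differ slightly across pytorch-fid versions, and the equivalent command line is python -m pytorch_fid <generated_dir> <reference_dir>.

```python
# Minimal sketch using pytorch-fid; "samples_png" and "datasets/cifar_train"
# are placeholder folders of PNG images (generated samples vs. reference set).
import torch
from pytorch_fid.fid_score import calculate_fid_given_paths

fid = calculate_fid_given_paths(
    ["samples_png", "datasets/cifar_train"],
    batch_size=50,
    device="cuda" if torch.cuda.is_available() else "cpu",
    dims=2048,  # default InceptionV3 pool3 feature dimension
)
print(f"FID: {fid:.2f}")
```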

TT-RAY (Author) commented Jul 24, 2024

> @weigerzan @TT-RAY I just finished retraining ADM-IP and tested the FID; the results are consistent with my previous implementation. The FID of ADM-IP (trained for 450k steps with batch size 128) under 100-step sampling is 2.33.
>
> I suspect the only remaining possible issue in your implementation is the FID computation. Instead of using the FID evaluation provided by the ADM repo, can you try using pytorch-fid to compute the FID?

OK! Thanks a lot! I will give it a try in the coming days.
