Merge hf upstream #6

Closed
wants to merge 276 commits

276 commits
a784be2
Allow resolutions that are not multiples of 64 (#505)
jachiam Sep 30, 2022
877bec8
refactor: update ldm-bert `config.json` url closes #675 (#680)
ryanrussell Sep 30, 2022
daa2205
[docs] fix table in fp16.mdx (#683)
NouamaneTazi Sep 30, 2022
bb0f2a0
Update README.md
patrickvonplaten Sep 30, 2022
552b967
Update README.md
patrickvonplaten Sep 30, 2022
b2cfc7a
Fix slow tests (#689)
NouamaneTazi Sep 30, 2022
5156acc
Fix BibText citation (#693)
osanseviero Oct 1, 2022
2558977
Add callback parameters for Stable Diffusion pipelines (#521)
jamestiotio Oct 2, 2022
14f4af8
[dreambooth] fix applying clip_grad_norm_ (#686)
patil-suraj Oct 3, 2022
500ca5a
Forgot to add the OG!
patrickvonplaten Oct 3, 2022
249b36c
Flax: add shape argument to `set_timesteps` (#690)
pcuenca Oct 3, 2022
7d0ba59
Fix type annotations on StableDiffusionPipeline.__call__ (#682)
tasercake Oct 3, 2022
688031c
Fix import with Flax but without PyTorch (#688)
pcuenca Oct 3, 2022
b35bac4
[Support PyTorch 1.8] Remove inference mode (#707)
patrickvonplaten Oct 3, 2022
1070e1a
[CI] Speed up slow tests (#708)
anton-l Oct 3, 2022
f1484b8
[Utils] Add deprecate function and move testing_utils under utils (#659)
patrickvonplaten Oct 3, 2022
4ff4d4d
Checkpoint conversion script from Diffusers => Stable Diffusion (Comp…
jachiam Oct 4, 2022
f1b9ee7
[Docs] fix docstring for issue #709 (#710)
kashif Oct 4, 2022
09859a3
Update schedulers README.md (#694)
tmabraham Oct 4, 2022
4d1cce2
add accelerate to load models with smaller memory footprint (#361)
piEsposito Oct 4, 2022
7e92c5b
Fix typos (#718)
shirayu Oct 4, 2022
5ac1f61
Add an argument "negative_prompt" (#549)
shirayu Oct 4, 2022
215bb40
Fix import if PyTorch is not installed (#715)
pcuenca Oct 4, 2022
6b22192
Remove comments no longer appropriate (#716)
pcuenca Oct 4, 2022
14b9754
[train_unconditional] fix applying clip_grad_norm_ (#721)
patil-suraj Oct 4, 2022
7265dd8
renamed x to meaningful variable in resnet.py (#677)
i-am-epic Oct 4, 2022
a8a3a20
[Tests] Add accelerate to testing (#729)
patrickvonplaten Oct 5, 2022
08d4fb6
[dreambooth] Using already created `Path` in dataset (#681)
DrInfiniteExplorer Oct 5, 2022
b9eea06
Include CLIPTextModel parameters in conversion (#695)
kanewallmann Oct 5, 2022
60c9634
Avoid negative strides for tensors (#717)
shirayu Oct 5, 2022
726aba0
[Pytorch] pytorch only timesteps (#724)
kashif Oct 5, 2022
6b09f37
[Scheduler design] The pragmatic approach (#719)
anton-l Oct 5, 2022
3dcc75c
Removing `autocast` for `35-25% speedup`. (`autocast` considered harm…
Narsil Oct 5, 2022
78744b6
No more use_auth_token=True (#733)
patrickvonplaten Oct 5, 2022
19e559d
remove use_auth_token from remaining places (#737)
patil-suraj Oct 5, 2022
5493524
Replace messages that have empty backquotes (#738)
pcuenca Oct 5, 2022
4deb16e
[Docs] Advertise fp16 instead of autocast (#740)
patrickvonplaten Oct 5, 2022
916754e
make style
patrickvonplaten Oct 5, 2022
367a671
remove use_auth_token from for TI test (#747)
patil-suraj Oct 6, 2022
c119dc4
allow multiple generations per prompt (#741)
patil-suraj Oct 6, 2022
df9c070
Add back-compatibility to LMS timesteps (#750)
anton-l Oct 6, 2022
3383f77
update the clip guided PR according to the new API (#751)
patil-suraj Oct 6, 2022
6c64741
Raise an error when moving an fp16 pipeline to CPU (#749)
anton-l Oct 6, 2022
0883968
Better steps deprecation for LMS (#753)
anton-l Oct 6, 2022
f3128c8
Actually fix the grad ckpt test (#734)
patil-suraj Oct 6, 2022
d9c449e
Custome Pipelines (#744)
patrickvonplaten Oct 6, 2022
6613a8c
make CI happy
patrickvonplaten Oct 6, 2022
9c9462f
Python 3.7 doesn't like keys() + keys()
patrickvonplaten Oct 6, 2022
2e209c3
[v0.4.0] Temporarily remove Flax modules from the public API (#755)
anton-l Oct 6, 2022
4581f14
Update clip_guided_stable_diffusion.py
patil-suraj Oct 6, 2022
3b1d2ca
Release: v0.4.0
anton-l Oct 6, 2022
0fe59b6
Merge remote-tracking branch 'origin/main'
anton-l Oct 6, 2022
c15cda0
Bump to v0.4.1.dev0
anton-l Oct 6, 2022
970e306
Revert "[v0.4.0] Temporarily remove Flax modules from the public API …
anton-l Oct 6, 2022
435433c
Update clip_guided_stable_diffusion.py
patil-suraj Oct 6, 2022
737195d
Created using Colaboratory
patil-suraj Oct 6, 2022
9531150
Bump to v0.5.0.dev0
anton-l Oct 6, 2022
2fa55fc
Merge remote-tracking branch 'origin/main'
anton-l Oct 6, 2022
ae672d5
[Tests] Lower required memory for clip guided and fix super edge-case…
patrickvonplaten Oct 6, 2022
d3f1a4c
Revert "Bump to v0.5.0.dev0"
anton-l Oct 6, 2022
fdfa7c8
Change fp16 error to warning (#764)
apolinario Oct 7, 2022
91ddd2a
Release: v0.4.1
patrickvonplaten Oct 7, 2022
9a95414
Bump to v0.5.0dev0
patrickvonplaten Oct 7, 2022
c93a8cc
remove bogus folder
patrickvonplaten Oct 7, 2022
7258dc4
remove bogus folder no.2
patrickvonplaten Oct 7, 2022
906e410
Fix push_to_hub for dreambooth and textual_inversion (#748)
YaYaB Oct 7, 2022
75bb6d2
Fix ONNX conversion script opset argument type (#739)
justinchuby Oct 7, 2022
e0fece2
Add final latent slice checks to SD pipeline intermediate state tests…
jamestiotio Oct 7, 2022
cb0bf0b
fix(DDIM scheduler): use correct dtype for noise (#742)
keturn Oct 7, 2022
ec831b6
[schedulers] hanlde dtype in add_noise (#767)
patil-suraj Oct 7, 2022
92d7086
[img2img, inpainting] fix fp16 inference (#769)
patil-suraj Oct 7, 2022
f3983d1
[Tests] Fix tests (#774)
patrickvonplaten Oct 7, 2022
5af6eed
debug an exception (#638)
LowinLi Oct 10, 2022
a73f8b7
Clean up resnet.py file (#780)
Oct 10, 2022
feaa732
add sigmoid betas (#777)
Oct 10, 2022
fab1752
[Low CPU memory] + device map (#772)
patrickvonplaten Oct 10, 2022
22963ed
Fix gradient checkpointing test (#797)
patrickvonplaten Oct 10, 2022
71ca10c
fix typo docstring in unet2d (#798)
Oct 10, 2022
81bdbb5
DreamBooth DeepSpeed support for under 8 GB VRAM training (#735)
Ttl Oct 10, 2022
797b290
support bf16 for stable diffusion (#792)
patil-suraj Oct 11, 2022
66a5279
stable diffusion fine-tuning (#356)
patil-suraj Oct 11, 2022
a124204
Flax: Trickle down `norm_num_groups` (#789)
akash5474 Oct 11, 2022
e895952
Eventually preserve this typo? :) (#804)
spezialspezial Oct 11, 2022
757babf
Fix indentation in the code example (#802)
osanseviero Oct 11, 2022
24b8b5c
`mps`: Alternative implementation for `repeat_interleave` (#766)
pcuenca Oct 11, 2022
c1b6ea3
Update img2img.mdx
patrickvonplaten Oct 11, 2022
6bc1178
[Img2Img] Fix batch size mismatch prompts vs. init images (#793)
patrickvonplaten Oct 12, 2022
966e2fc
Minor package fixes (#809)
anton-l Oct 12, 2022
db47b1e
[Dummy imports] Better error message (#795)
patrickvonplaten Oct 12, 2022
679c77f
Add diffusers version and pipeline class to the Hub UA
anton-l Oct 12, 2022
80be074
Merge remote-tracking branch 'origin/main'
anton-l Oct 12, 2022
9659863
Revert an accidental commit
anton-l Oct 12, 2022
5afc2b6
add or fix license formatting in models directory (#808)
Oct 12, 2022
008b608
[train_text2image] Fix EMA and make it compatible with deepspeed. (#813)
patil-suraj Oct 12, 2022
60c384b
Fix fine-tuning compatibility with deepspeed (#816)
pink-red Oct 12, 2022
323a9e1
Add diffusers version and pipeline class to the Hub UA (#814)
anton-l Oct 12, 2022
f1d4289
[Flax] Add test (#824)
patrickvonplaten Oct 13, 2022
0a09af2
update flax scheduler API (#822)
patil-suraj Oct 13, 2022
e001fed
Fix dreambooth loss type with prior_preservation and fp16 (#826)
anton-l Oct 13, 2022
26c7df5
Fix type mismatch error, add tests for negative prompts (#823)
anton-l Oct 13, 2022
e713346
Give more customizable options for safety checker (#815)
patrickvonplaten Oct 13, 2022
78db11d
Flax safety checker (#825)
pcuenca Oct 13, 2022
7c22626
Align PT and Flax API - allow loading checkpoint from PyTorch configs…
patrickvonplaten Oct 13, 2022
1d51224
[Flax] Complete tests (#828)
patrickvonplaten Oct 13, 2022
0679d09
Release: 5.0.0 (#830)
anton-l Oct 13, 2022
effe9d6
[FlaxStableDiffusionPipeline] fix bug when nsfw is detected (#832)
patil-suraj Oct 13, 2022
e48ca0f
Release 0 5 1 (#833)
patrickvonplaten Oct 13, 2022
d3eb3b3
[Community] One step unet (#840)
patrickvonplaten Oct 14, 2022
b8c4d58
Remove unneeded use_auth_token (#839)
osanseviero Oct 14, 2022
52394b5
Bump to 0.6.0.dev0 (#831)
anton-l Oct 14, 2022
1d3234c
Remove the last of ["sample"] (#842)
anton-l Oct 14, 2022
93a81a3
Fix Flax pipeline: width and height are ignored #838 (#848)
camenduru Oct 14, 2022
2b7d4a5
[DeviceMap] Make sure stable diffusion can be loaded from older trans…
patrickvonplaten Oct 16, 2022
765a446
Update README.md
patrickvonplaten Oct 17, 2022
5b94450
Update README.md
patrickvonplaten Oct 17, 2022
ee9875e
Add Stable Diffusion Interpolation Example (#862)
nateraw Oct 17, 2022
ad0e9ac
Update README.md
patrickvonplaten Oct 17, 2022
146419f
All in one Stable Diffusion Pipeline (#821)
patrickvonplaten Oct 17, 2022
ed6c61c
Fix small community pipeline import bug and finish README (#869)
patrickvonplaten Oct 17, 2022
52e8fdb
Update README.md
patrickvonplaten Oct 17, 2022
4dce374
Fix training push_to_hub (unconditional image generation): models wer…
pcuenca Oct 17, 2022
dff91ee
Fix table in community README.md (#879)
nateraw Oct 17, 2022
fd26624
Add generic inference example to community pipeline readme (#874)
apolinario Oct 17, 2022
627ad6e
Rename frame filename in interpolation community example (#881)
nateraw Oct 17, 2022
cca59ce
Add Apple M1 tests (#796)
anton-l Oct 17, 2022
100e094
Fix autoencoder test (#886)
pcuenca Oct 17, 2022
728a3f3
Rename StableDiffusionOnnxPipeline -> OnnxStableDiffusionPipeline (#887)
anton-l Oct 18, 2022
a3efa43
Fix DDIM on Windows not using int64 for timesteps (#819)
hafriedlander Oct 18, 2022
fbe807b
[dreambooth] allow fine-tuning text encoder (#883)
patil-suraj Oct 18, 2022
a9908ec
Stable Diffusion image-to-image and inpaint using onnx. (#552)
zledas Oct 18, 2022
8eb9d97
Improve ONNX img2img numpy handling, temporarily fix the tests (#899)
anton-l Oct 19, 2022
bd21607
make fix copies
patrickvonplaten Oct 19, 2022
6ea8360
[Stable Diffusion Inpainting] Deprecate inpainting pipeline in favor …
patrickvonplaten Oct 19, 2022
83b696e
[Communit Pipeline] Make sure "mega" uses correct inpaint pipeline (#…
patrickvonplaten Oct 19, 2022
b35d88c
Stable diffusion inpainting. (#904)
patil-suraj Oct 19, 2022
4655712
finish tests (#909)
patrickvonplaten Oct 19, 2022
89d1249
ONNX supervised inpainting (#906)
anton-l Oct 19, 2022
8124863
Initial docs update for new in-painting pipeline (#910)
pcuenca Oct 19, 2022
ad9d7ce
Release: 0.6.0
anton-l Oct 19, 2022
2a0c823
[Community Pipelines] Long Prompt Weighting Stable Diffusion Pipeline…
SkyTNT Oct 19, 2022
83f8a5f
[Stable Diffusion] Add components function (#889)
patrickvonplaten Oct 20, 2022
4a76e5d
[PNDM Scheduler] Make sure list cannot grow forever (#882)
patrickvonplaten Oct 20, 2022
db19a9d
[DiffusionPipeline.from_pretrained] add warning when passing unused k…
patrickvonplaten Oct 20, 2022
ce7d966
DOC Dreambooth Add --sample_batch_size=1 to the 8 GB dreambooth examp…
leszekhanusz Oct 20, 2022
a5eb7f4
[Examples] add speech to image pipeline example (#897)
MikailINTech Oct 20, 2022
7674a36
[dreambooth] dont use safety check when generating prior images (#922)
patil-suraj Oct 20, 2022
4bf675f
Dreambooth class image generation: using unique names to avoid overwr…
leszekhanusz Oct 20, 2022
8be4850
fix test_components (#928)
patil-suraj Oct 20, 2022
6f6eef7
Fix Compatibility with Nvidia NGC Containers (#919)
tasercake Oct 20, 2022
ba74a8b
[Community Pipelines] Fix pad_tokens_and_weights in lpw_stable_diffus…
SkyTNT Oct 20, 2022
cc36f2e
Bump the version to 0.7.0.dev0 (#912)
anton-l Oct 20, 2022
32bf4fd
Introduce the copy mechanism (#924)
anton-l Oct 20, 2022
25dfd0f
[Tests] Move stable diffusion into their own files (#936)
patrickvonplaten Oct 21, 2022
dec18c8
[Flax] dont warn for bf16 weights (#923)
patil-suraj Oct 21, 2022
31af4d1
Support LMSDiscreteScheduler in LDMPipeline (#891)
mkshing Oct 21, 2022
2fdd094
Wildcard stable diffusion pipeline (#900)
shyamsn97 Oct 21, 2022
9bca402
[MPS] fix mps failing tests (#934)
kashif Oct 22, 2022
2d35f67
fix a small typo in pipeline_ddpm.py (#948)
chenguolin Oct 24, 2022
2c82e0c
Reorganize pipeline tests (#963)
anton-l Oct 24, 2022
8aac1f9
v1-5 docs updates (#921)
apolinario Oct 24, 2022
2fb8faf
add community pipeline docs; add minimal text to some empty doc pages…
Oct 24, 2022
8204415
Fix typo: `torch_type` -> `torch_dtype` (#972)
pcuenca Oct 25, 2022
6e099e2
add num_inference_steps arg to DDPM (#935)
tmabraham Oct 25, 2022
38ae5a2
Add Composable diffusion to community pipeline examples (#951)
MarkRich Oct 25, 2022
240abdd
[Flax] added broadcast_to_shape_from_left helper and Scheduler tests …
kashif Oct 25, 2022
28b134e
[Tests] Fix `mps` reproducibility issue when running with pytest-xdis…
anton-l Oct 25, 2022
3d02c92
mps changes for PyTorch 1.13 (#926)
pcuenca Oct 25, 2022
0b42b07
[Onnx] support half-precision and fix bugs for onnx pipelines (#932)
SkyTNT Oct 25, 2022
88fa6b7
[Dance Diffusion] Add dance diffusion (#803)
patrickvonplaten Oct 25, 2022
365ff8f
[Dance Diffusion] FP16 (#980)
patrickvonplaten Oct 25, 2022
59f0ce8
[Dance Diffusion] Better naming (#981)
patrickvonplaten Oct 25, 2022
e2243de
Fix typo in documentation title (#975)
echarlaix Oct 25, 2022
4b9f589
Add --pretrained_model_name_revision option to train_dreambooth.py (#…
shirayu Oct 25, 2022
0343d8f
Do not use torch.float64 on the mps device (#942)
pcuenca Oct 26, 2022
d9cfe32
CompVis -> diffusers script - allow converting from merged checkpoint…
patrickvonplaten Oct 26, 2022
d7d6841
fix a bug in the new version (#957)
xiaohu2015 Oct 26, 2022
cc43608
Fix typos (#978)
shirayu Oct 26, 2022
2f0fcf4
Add missing import (#979)
juliensimon Oct 26, 2022
b2e2d14
minimal stable diffusion GPU memory usage with accelerate hooks (#850)
piEsposito Oct 26, 2022
bd06dd0
[inpaint pipeline] fix bug for multiple prompts inputs (#959)
xiaohu2015 Oct 26, 2022
8332c1a
Enable multi-process DataLoader for dreambooth (#950)
skirsten Oct 26, 2022
d3d22ce
Small modification to enable usage by external scripts (#956)
briancw Oct 26, 2022
a23ad87
[Flax] Add Textual Inversion (#880)
duongna21 Oct 26, 2022
1d04e1b
Continuation of #942: additional float64 failure (#996)
pcuenca Oct 27, 2022
e92a603
fix dreambooth script. (#1017)
patil-suraj Oct 27, 2022
3be9fa9
[Accelerate model loading] Fix meta device and super low memory usage…
patrickvonplaten Oct 27, 2022
abe0582
[Flax] Add finetune Stable Diffusion (#999)
duongna21 Oct 27, 2022
4623f09
[DreamBooth] Set train mode for text encoder (#1012)
duongna21 Oct 27, 2022
90f91ad
[Flax] Add DreamBooth (#1001)
duongna21 Oct 27, 2022
fbcc383
Deprecate `init_git_repo`, refactor `train_unconditional.py` (#1022)
anton-l Oct 27, 2022
52f2128
update readme for flax examples (#1026)
patil-suraj Oct 27, 2022
eceeebd
Update train_dreambooth.py
patil-suraj Oct 27, 2022
939ec17
Probably nicer to specify dependency on tensorboard in the training e…
lukovnikov Oct 27, 2022
a6314a8
Add `--dataloader_num_workers` to the DDPM training example (#1027)
anton-l Oct 27, 2022
de00c63
Document sequential CPU offload method on Stable Diffusion pipeline (…
piEsposito Oct 27, 2022
fb38bb1
Support grayscale images in `numpy_to_pil` (#1025)
anton-l Oct 27, 2022
1e07b6b
[Flax SD finetune] Fix dtype (#1038)
duongna21 Oct 28, 2022
ab079f2
fix `F.interpolate()` for large batch sizes (#1006)
NouamaneTazi Oct 28, 2022
a80480f
[Tests] Improve unet / vae tests (#1018)
patrickvonplaten Oct 28, 2022
d2d9764
[Tests] Speed up slow tests (#1040)
patrickvonplaten Oct 28, 2022
8d6487f
Fix some failing tests (#1041)
patrickvonplaten Oct 28, 2022
c4ef1ef
[Tests] Better prints (#1043)
patrickvonplaten Oct 28, 2022
d37f08d
[Tests] no random latents anymore (#1045)
patrickvonplaten Oct 28, 2022
cbbb293
hot fix
patrickvonplaten Oct 28, 2022
ea01a4c
fix
patrickvonplaten Oct 28, 2022
a7ae808
increase tolerance
patrickvonplaten Oct 28, 2022
81b6fbf
higher precision for vae
patrickvonplaten Oct 28, 2022
6b185b6
Update training and fine-tuning docs (#1020)
pcuenca Oct 28, 2022
fc0ca47
Fix speedup ratio in fp16.mdx (#837)
mwbyeon Oct 29, 2022
12fd073
clean incomplete pages (#1008)
Oct 29, 2022
1fc2088
Add seed resizing to community pipelines (#1011)
MarkRich Oct 29, 2022
a59f999
Tests: upgrade PyTorch cuda to 11.7 to fix examples tests. (#1048)
pcuenca Oct 29, 2022
95414bd
Experimental: allow fp16 in `mps` (#961)
pcuenca Oct 29, 2022
8e4fd68
Move safety detection to model call in Flax safety checker (#1023)
jonatanklosko Oct 30, 2022
707b868
fix slow test
patrickvonplaten Oct 31, 2022
82d56cf
Merge branch 'main' of https://github.com/huggingface/diffusers into …
patrickvonplaten Oct 31, 2022
1606eb9
Fix pipelines user_agent, ignore CI requests (#1058)
anton-l Oct 31, 2022
e4d264e
[GitBot] Automatically close issues after inactivitiy (#1079)
patrickvonplaten Oct 31, 2022
bf7b0bc
Allow `safety_checker` to be `None` when using CPU offload (#1078)
pcuenca Oct 31, 2022
a1ea8c0
k-diffusion-euler (#1019)
hlky Oct 31, 2022
c18941b
[Better scheduler docs] Improve usage examples of schedulers (#890)
patrickvonplaten Oct 31, 2022
010bc4e
incorrect model id
patrickvonplaten Oct 31, 2022
17c2c06
[Tests] Fix slow tests (#1087)
patrickvonplaten Oct 31, 2022
888468d
Remove nn sequential (#1086)
patrickvonplaten Oct 31, 2022
7fb4b88
Remove some unused parameter in CrossAttnUpBlock2D (#1034)
LaurentMazare Oct 31, 2022
ab303e8
Merge branch 'main' of https://github.com/Oneflow-Inc/diffusers into …
jackalcooper Nov 1, 2022
b5118dd
update pipe
jackalcooper Nov 1, 2022
81da7cf
update pndm
jackalcooper Nov 1, 2022
9659489
rename
jackalcooper Nov 1, 2022
8cd3276
rename
jackalcooper Nov 1, 2022
653870f
update resnet
jackalcooper Nov 1, 2022
bcf37fb
update attention
jackalcooper Nov 1, 2022
cf07252
update
jackalcooper Nov 1, 2022
d2c16a2
update safety checker
jackalcooper Nov 1, 2022
2c2b3d4
Merge branch 'main' of https://github.com/Oneflow-Inc/diffusers into …
jackalcooper Nov 1, 2022
b770126
rm todo
jackalcooper Nov 1, 2022
6b967ba
tryfix
jackalcooper Nov 1, 2022
923b5d9
fix
jackalcooper Nov 1, 2022
eedeacb
rm
jackalcooper Nov 1, 2022
498f22b
fix
jackalcooper Nov 1, 2022
9807fa3
fix
jackalcooper Nov 1, 2022
cbc8df3
add
jackalcooper Nov 2, 2022
85ff49f
refine log
jackalcooper Nov 2, 2022
f6ac84a
set env
jackalcooper Nov 3, 2022
f429954
Debug sd conv gn geglu (#7)
jackalcooper Nov 7, 2022
dc01404
Add arg compile_unet (#17)
jackalcooper Nov 8, 2022
1aa67aa
stable-diffusion support multiple shapes (#13)
yuantailing Nov 9, 2022
d6c7a54
Support DPMsolver (#21)
jackalcooper Nov 9, 2022
e866174
Image to image (#25)
jackalcooper Nov 10, 2022
c057826
Merge branch 'oneflow-fork' of https://github.com/Oneflow-Inc/diffuse…
jackalcooper Nov 10, 2022
stable diffusion fine-tuning (#356)
* begin text2image script

* loading the datasets, preprocessing & transforms

* handle input features correctly

* add gradient checkpointing support

* fix output names

* run unet in train mode not text encoder

* use no_grad instead of freezing params

* default max steps None

* pad to longest

* don't pad when tokenizing

* fix encode on multi gpu

* fix stupid bug

* add random flip

* add ema

* fix ema

* put ema on cpu

* improve EMA model

* contiguous_format

* don't warp vae and text encode in accelerate

* remove no_grad

* use randn_like

* fix resize

* improve few things

* log epoch loss

* set log level

* don't log each step

* remove max_length from collate

* style

* add report_to option

* make scale_lr false by default

* add grad clipping

* add an option to use 8bit adam

* fix logging in multi-gpu, log every step

* more comments

* remove eval for now

* adress review comments

* add requirements file

* begin readme

* begin readme

* fix typo

* fix push to hub

* populate readme

* update readme

* remove use_auth_token from the script

* address some review comments

* better mixed precision support

* remove redundant to

* create ema model early

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* better description for train_data_dir

* add diffusers in requirements

* update dataset_name_mapping

* update readme

* add inference example

Co-authored-by: anton-l <anton@huggingface.co>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
3 people authored Oct 11, 2022
commit 66a5279a9422962b1cff3ad0e5747e8903ae067b
101 changes: 101 additions & 0 deletions examples/text_to_image/README.md
@@ -0,0 +1,101 @@
# Stable Diffusion text-to-image fine-tuning

The `train_text_to_image.py` script shows how to fine-tune the Stable Diffusion model on your own dataset.

___Note___:

___This script is experimental. The script fine-tunes the whole model, and oftentimes the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best result on your dataset.___


## Running locally
### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

```bash
pip install git+https://github.com/huggingface/diffusers.git
pip install -U -r requirements.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```
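
If you prefer not to answer the interactive questionnaire (for example, on a headless machine), a default configuration can also be written programmatically; a minimal sketch using a helper from `accelerate`:

```python
from accelerate.utils import write_basic_config

# Writes a default single-machine accelerate config file for this environment.
write_basic_config()
```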

### Pokemon example

You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.

You have to be a registered user on the 🤗 Hugging Face Hub, and you'll also need an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).

Run the following command to authenticate your token:

```bash
huggingface-cli login
```
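
If you're working in a notebook rather than a terminal, the same flow is available programmatically; a minimal sketch using the `huggingface_hub` package (installed alongside `diffusers`):

```python
from huggingface_hub import notebook_login

# Prompts for the access token and stores it locally for later use.
notebook_login()
```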

If you have already cloned the repo, then you won't need to go through these steps.

<br>

#### Hardware
With `gradient_checkpointing` and `mixed_precision` it should be possible to fine-tune the model on a single 24GB GPU. For a higher `batch_size` and faster training it's better to use GPUs with more than 30GB of memory.

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/pokemon-blip-captions"

accelerate launch train_text_to_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$dataset_name \
--use_ema \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--mixed_precision="fp16" \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--lr_scheduler="constant" --lr_warmup_steps=0 \
--output_dir="sd-pokemon-model"
```


To run on your own training files, prepare the dataset according to the format required by `datasets`; you can find the instructions for how to do that in [this document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata).
If you wish to use custom loading logic, you should modify the script; we have left pointers for that in the training script.
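
For reference, an `imagefolder` dataset with captions might look like the sketch below (the directory and file names are placeholders):

```
train_data_dir/
├── metadata.jsonl
├── 0001.png
└── 0002.png
```

Each line of `metadata.jsonl` pairs a file with its caption, e.g. `{"file_name": "0001.png", "text": "a cartoon drawing of a blue creature"}`. Then launch training, pointing `--train_data_dir` at your folder: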

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export TRAIN_DIR="path_to_your_dataset"

accelerate launch train_text_to_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$TRAIN_DIR \
--use_ema \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--mixed_precision="fp16" \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--lr_scheduler="constant" --lr_warmup_steps=0 \
--output_dir="sd-pokemon-model"
```

Once the training is finished, the model will be saved in the `output_dir` specified in the command; in this example it's `sd-pokemon-model`. To load the fine-tuned model for inference, just pass that path to `StableDiffusionPipeline`:


```python
import torch

from diffusers import StableDiffusionPipeline

model_path = "path_to_saved_model"
pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
pipe.to("cuda")

image = pipe(prompt="yoda").images[0]
image.save("yoda-pokemon.png")
```
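
The pipeline call is stochastic; for reproducible samples you can pass a seeded generator. A minimal sketch, assuming the model was saved to `sd-pokemon-model` as in the command above:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("sd-pokemon-model", torch_dtype=torch.float16)
pipe.to("cuda")

# Fixing the generator seed makes repeated calls return the same image.
generator = torch.Generator(device="cuda").manual_seed(0)
image = pipe(prompt="yoda", generator=generator).images[0]
image.save("yoda-pokemon-seed0.png")
```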
7 changes: 7 additions & 0 deletions examples/text_to_image/requirements.txt
@@ -0,0 +1,7 @@
diffusers==0.4.1
accelerate
torchvision
transformers>=4.21.0
ftfy
tensorboard
modelcards
621 changes: 621 additions & 0 deletions examples/text_to_image/train_text_to_image.py
@@ -0,0 +1,621 @@
import argparse
import copy
import logging
import math
import os
import random
from pathlib import Path
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from huggingface_hub import HfFolder, Repository, whoami
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer


logger = get_logger(__name__)


def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--dataset_name",
type=str,
default=None,
help=(
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
" or to a folder containing files that 🤗 Datasets can understand."
),
)
parser.add_argument(
"--dataset_config_name",
type=str,
default=None,
help="The config of the Dataset, leave as None if there's only one config.",
)
parser.add_argument(
"--train_data_dir",
type=str,
default=None,
help=(
"A folder containing the training data. Folder contents must follow the structure described in"
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
),
)
parser.add_argument(
"--image_column", type=str, default="image", help="The column of the dataset containing an image."
)
parser.add_argument(
"--caption_column",
type=str,
default="text",
help="The column of the dataset containing a caption or a list of captions.",
)
parser.add_argument(
"--max_train_samples",
type=int,
default=None,
help=(
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
),
)
parser.add_argument(
"--output_dir",
type=str,
default="sd-model-finetuned",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--cache_dir",
type=str,
default=None,
help="The directory where the downloaded models and datasets will be stored.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--center_crop",
action="store_true",
help="Whether to center crop images before resizing to resolution (if not set, random crop will be used)",
)
parser.add_argument(
"--random_flip",
action="store_true",
help="whether to randomly flip images horizontally",
)
parser.add_argument(
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
)
parser.add_argument("--num_train_epochs", type=int, default=100)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-4,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
)
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--hub_model_id",
type=str,
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default="no",
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU."
),
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
"Only applicable when `--with_tracking` is passed."
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank

# Sanity checks
if args.dataset_name is None and args.train_data_dir is None:
raise ValueError("Need either a dataset name or a training folder.")

return args


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if token is None:
token = HfFolder.get_token()
if organization is None:
username = whoami(token)["name"]
return f"{username}/{model_id}"
else:
return f"{organization}/{model_id}"


dataset_name_mapping = {
"lambdalabs/pokemon-blip-captions": ("image", "text"),
}
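# `dataset_name_mapping` provides default (image column, caption column) names
# for known Hub datasets when --image_column / --caption_column are not passed.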


class EMAModel:
"""
Exponential Moving Average of models weights
"""

def __init__(
self,
model,
decay=0.9999,
device=None,
):
self.averaged_model = copy.deepcopy(model).eval()
self.averaged_model.requires_grad_(False)

self.decay = decay

if device is not None:
self.averaged_model = self.averaged_model.to(device=device)

self.optimization_step = 0

def get_decay(self, optimization_step):
"""
Compute the decay factor for the exponential moving average.
"""
value = (1 + optimization_step) / (10 + optimization_step)
return 1 - min(self.decay, value)

@torch.no_grad()
def step(self, new_model):
ema_state_dict = self.averaged_model.state_dict()

self.optimization_step += 1
self.decay = self.get_decay(self.optimization_step)

for key, param in new_model.named_parameters():
if isinstance(param, dict):
continue
try:
ema_param = ema_state_dict[key]
except KeyError:
ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
ema_state_dict[key] = ema_param

param = param.clone().detach().to(ema_param.dtype).to(ema_param.device)

if param.requires_grad:
ema_state_dict[key].sub_(self.decay * (ema_param - param))
else:
ema_state_dict[key].copy_(param)

for key, param in new_model.named_buffers():
ema_state_dict[key] = param

self.averaged_model.load_state_dict(ema_state_dict, strict=False)
torch.cuda.empty_cache()


def main():
args = parse_args()
logging_dir = os.path.join(args.output_dir, args.logging_dir)

accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
logging_dir=logging_dir,
)

logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)

# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)

# Handle the repository creation
if accelerator.is_main_process:
if args.push_to_hub:
if args.hub_model_id is None:
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else:
repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name)

with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
gitignore.write("step_*\n")
if "epoch_*" not in gitignore:
gitignore.write("epoch_*\n")
elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)

# Load models and create wrapper for stable diffusion
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")

if args.use_ema:
ema_unet = EMAModel(unet)

# Freeze vae and text_encoder
vae.requires_grad_(False)
text_encoder.requires_grad_(False)

if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()

if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)

# Initialize the optimizer
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
)

optimizer_cls = bnb.optim.AdamW8bit
else:
optimizer_cls = torch.optim.AdamW

optimizer = optimizer_cls(
unet.parameters(),
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)

# TODO (patil-suraj): load scheduler using args
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
)

# Get the datasets: you can either provide your own training and evaluation files (see below)
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

# In distributed training, the load_dataset function guarantees that only one local process can concurrently
# download the dataset.
if args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
dataset = load_dataset(
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
)
else:
data_files = {}
if args.train_data_dir is not None:
data_files["train"] = os.path.join(args.train_data_dir, "**")
dataset = load_dataset(
"imagefolder",
data_files=data_files,
cache_dir=args.cache_dir,
)
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder

# Preprocessing the datasets.
# We need to tokenize inputs and targets.
column_names = dataset["train"].column_names

# 6. Get the column names for input/target.
dataset_columns = dataset_name_mapping.get(args.dataset_name, None)
if args.image_column is None:
image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
else:
image_column = args.image_column
if image_column not in column_names:
raise ValueError(
f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
)
if args.caption_column is None:
caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
else:
caption_column = args.caption_column
if caption_column not in column_names:
raise ValueError(
f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
)

# Preprocessing the datasets.
# We need to tokenize input captions and transform the images.
def tokenize_captions(examples, is_train=True):
captions = []
for caption in examples[caption_column]:
if isinstance(caption, str):
captions.append(caption)
elif isinstance(caption, (list, np.ndarray)):
# take a random caption if there are multiple
captions.append(random.choice(caption) if is_train else caption[0])
else:
raise ValueError(
f"Caption column `{caption_column}` should contain either strings or lists of strings."
)
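# Tokenize without padding here; captions are padded to the longest sequence
# in the batch later, inside `collate_fn`.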
inputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding="do_not_pad", truncation=True)
input_ids = inputs.input_ids
return input_ids

train_transforms = transforms.Compose(
[
transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)

def preprocess_train(examples):
images = [image.convert("RGB") for image in examples[image_column]]
examples["pixel_values"] = [train_transforms(image) for image in images]
examples["input_ids"] = tokenize_captions(examples)

return examples

with accelerator.main_process_first():
if args.max_train_samples is not None:
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
# Set the training transforms
train_dataset = dataset["train"].with_transform(preprocess_train)

def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
input_ids = [example["input_ids"] for example in examples]
padded_tokens = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt")
return {
"pixel_values": pixel_values,
"input_ids": padded_tokens.input_ids,
"attention_mask": padded_tokens.attention_mask,
}

train_dataloader = torch.utils.data.DataLoader(
train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size
)

# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True

lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
)

unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)

weight_dtype = torch.float32
if args.mixed_precision == "fp16":
weight_dtype = torch.float16
elif args.mixed_precision == "bf16":
weight_dtype = torch.bfloat16

# Move text_encoder and vae to gpu.
# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
text_encoder.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)

# Move the ema_unet to gpu (it only exists when --use_ema is passed).
if args.use_ema:
    ema_unet.averaged_model.to(accelerator.device)

# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
accelerator.init_trackers("text2image-fine-tune", config=vars(args))

# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")

# Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
global_step = 0

for epoch in range(args.num_train_epochs):
unet.train()
train_loss = 0.0
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
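# 0.18215 is the latent scaling factor from the original Stable Diffusion
# training configuration; it brings the VAE latents to roughly unit variance
# before they are fed to the UNet.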
latents = latents * 0.18215

# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()

# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0]

# Predict the noise residual and compute loss
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(noise_pred, noise, reduction="mean")

# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
train_loss += avg_loss.item() / args.gradient_accumulation_steps

# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
if args.use_ema:
ema_unet.step(unet)
progress_bar.update(1)
global_step += 1
accelerator.log({"train_loss": train_loss}, step=global_step)
train_loss = 0.0

logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)

if global_step >= args.max_train_steps:
break

# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
pipeline = StableDiffusionPipeline(
text_encoder=text_encoder,
vae=vae,
unet=accelerator.unwrap_model(ema_unet.averaged_model if args.use_ema else unet),
tokenizer=tokenizer,
scheduler=PNDMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
),
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
)
pipeline.save_pretrained(args.output_dir)

if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

accelerator.end_training()


if __name__ == "__main__":
main()