-
Notifications
You must be signed in to change notification settings - Fork 921
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adafactor fused backward pass and optimizer step, lowers SDXL (@ 1024 resolution) VRAM usage to BF16(10GB)/FP32(16.4GB) #1259
Adafactor fused backward pass and optimizer step, lowers SDXL (@ 1024 resolution) VRAM usage to BF16(10GB)/FP32(16.4GB) #1259
Conversation
wow nice. yes i use it with onetrainer and 10.3 GB VRAM required for SDXL training even published a video Full Stable Diffusion SD & XL Fine Tuning Tutorial With OneTrainer On Windows & Cloud - Zero To Hero |
Interesting! Do you know if this is a pure memory saving, or if it changes the gradients at all? Although even if it just saves memory with no consequence, that would be great, especially if Stable Diffusion 3 is round the corner. I'm interested to see if fp32 training of SDXL has any benefits over bf16. At the very least, it'll stop the text encoder getting quantized to bf16 as the first step, which makes people's faces look a bit weird for a few iterations til the model rebalances itself to deal with the lost precision. But that's just a temporary effect. I hacked my own code to leave the text encoder as fp32 already, even when using full bf16 training to avoid this, but I wouldn't have to do that anymore with this update of yours. @FurkanGozukara, did you notice any quality improvements to the final results, or just reduced memory? |
Pure memory, see: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
Overall the FP32 is much better. There are known training issues with BF16, and techniques like 'stochastic rounding' that 'attempt to fix' the problems with training in BF16 by shifting the results nearer to that of having trained in FP32 instead. See here: https://storage.googleapis.com/deepmind-media/research/language-research/Training%20Gopher.pdf , check out pages 53 (starting at setion C.2) and 54 , and also this paper if you want to read about BF16 and 'stochastic rounding' a bit more: https://arxiv.org/pdf/2010.06192.pdf In a future PR to this repo the 'stochastic rounding' can be added, which should allow those that can 'only' train in BF16 (due to VRAM limits, which this PR now opens the window up of letting even 12GB GPUs train 'normal' SDXL) to have their results nearer to the results attained when training in FP32, which I have tested on my own end and the results are much better than 'normal' BF16 training, but still a bit shy of 'actual' FP32 trainings (which are now possible on 24GB GPUs via this PR with the 'fused' backward pass and optimizer step). |
Alright, you've convinced me. 😊 I'll grab your branch and see if I can notice any better results on some of my existing image sets. |
no changes only reduced speed |
@2kpr fp32 cooked my model with same LR of BF16 have you made any changes? i weren't able to research it yet |
Thank you for this. This looks really nice! I will review and merge this sooner. |
I've been running tests for a few hours now, and I'm seeing great improvements! Furkan is correct that the iterations go significantly more slowly, and so does the generation of sample images. But, I feel I'm also seeing good gains in image quality and learning ability. Fine-grained details that were being missed out on are being learned this time, which they weren't when I used bf16. And the images seem to stay 'in focus' better, without smudging so often. This is a recommended update. |
Not that I'm aware of. If you consider what is happening you might see why it would 'appear to cook' your model at that same learning rate as BF16, it is not that your model is being 'cooked' it's more like you are seeing a 'real learning' happening since FP32 is retaining all of the model update information whereas when you train with BF16 you are losing a lot of information due to the 'limited precision' of BF16. BF16 has the 'range' to match FP32, but BF16 has way less 'precision' than FP32 and therefore does not 'properly update the model' on each step when training in BF16, that is what 'stochastic rounding' tries to fix for instance, see here: https://storage.googleapis.com/deepmind-media/research/language-research/Training%20Gopher.pdf , check out pages 53 (starting at setion C.2) and 54. So, it isn't that this code or FP32 'cooked' your model, it is more apt to say that BF16 'undertrained' / 'undercooked' your model at the same LR and that BF16 was 'losing a lot of information during model updates' due to it's very limited 'precision'. Makes sense? |
NP, thanks, awesome, feel free to move things around to more 'fit in' with your existing codebase :) |
Awesome, thanks for the feedback :) Well, yes, that is just what happens when training in FP32 vs BF16 unfortunately, goes a bit slower re it/s, but the fact that the model is not losing information in updates is key, see here: #1259 (comment) |
I'm finding I can set the LR significantly lower when using fp32 training. Previously, I found that cranking the LR up actually got me better generalization, as lower LRs seemed to get stuck in local minimas. But I guess it's just that the weight updates to get out of these (apparent) minimas were too small for bf16 to represent. |
Cool, yea, that is a great way to put it :) |
You need to go w/ a much slower LR if using stochastic rounding or fp32 as numbers that were being rounded to 0 the prior method now have small values. I usually use 1e-5 for dreambooth training. In my handful of tests w/ stochastic and a couple w/ fp32 1e-6 is where I need to keep it to avoid cooking the model as you say. Run samples every 100 steps or so the and you'll see it converges much more quickly. |
Hi @2kpr, I think there is a bug somewhere in this implementation. My training runs start okay, but if I get to the first autosave point, and then stop the run and restart it from that autosave (instead of sdxl base), I get this error message:
I'm not saving out in fp16 or bf16 format, I switched to saving out in fp32. |
@2kpr - an update on the previous bug report: I only get that error message if I have --train_text_encoder set in the command line on the second run. That parameter works when starting from base, but not from an autosave checkpoint. |
I'm getting the following error upon attempting to begin a training run:
Update:
It seems the current max_grad_norm implementation is not compatible |
Looks great, is it possible to implement it on other optimizers? For example adamW |
My understanding of creating stuff from scratch in PyTorch is minimal, so just out of curiosity, why does this require implementing adafactor from scratch instead of using the base implementation and applying the additional steps as described in the documentation? I guess effectively, why is the change to sdxl_train.py not sufficient here? |
After looking through the code further and experimenting to answer my above question, I believe that while the current version of the code is functional, the logic should be made a little more generic to allow for additional fused versions of other optimizers to be incrementally added in addition to adafactor. This would likely involve:
|
Since I can't for the life of me figure out how to add a commit to this PR, I created an additional PR on 2kpr's repo with the AdamW optimizer added and the above changes made to make the logic more generic. I couldn't test the logic as it blew past my VRAM when I attempted to run it. It very well might be an issue with my implementation of AdamW, which I copied from the Transformers implementation PR Created: 2kpr#1 |
With this PR, Utilizing Adafactor and --fused_backward_pass, --mixed_precision="no" uses less VRAM than --mixed_precision="bf16", which I feel extremely suspicious about. Is anyone else having this happen? |
Have you tried w/o that mixed_precision line completely? If so your VRAM usage was the same as "no"? |
Here's what that looks like, no appreciable difference between "no" and removing the argument entirely. |
Cool, and for clarification, are you running the OPs PR or the PR created by stepfunction83 shared in the comments (2kpr#1) Thanks! |
Running with the OP PR. For the run itself these are the arguments if anyone can't reproduce it:
|
I can also confirm that I experienced the same thing when I was testing and forgot to turn off full FP16. |
Upon further inspection, the same behavior occurs in onetrainer with an independent fused adafactor implementation, so I guess that's just how it is ¯\_(ツ)_/¯ |
Not really that independent I think. Looking at the code, it looks like most of it was directly copied from my implementation. I can also explain why it's done this way, instead of the more generic implementation from here: I didn't want to create a new optimizer for every single parameter. Some optimizers (like prodigy) have some global state that's shared between all parameters. This wouldn't even work with the generic implementation. Also, I feel like there would be a lot of overhead, and saving/loading the optimizer state would also be harder. @2kpr Copying my code is fine. I made it open source for a reason. But at lease mention it somewhere 😉 btw: This implementation is missing the GradScaler check. Without it, mixed precision training with fp32 weights and fp16 autocasting will degrade a lot. |
To be fair, I didn't copy the code, for had I done so wouldn't the gradscaler been copied? :) I took the general files / code we had been sharing between ourselves in PM when I first sent my first two stochastic rounding implementations to you, after which you made your very nice and succinct SR implementation. I then took those files and in fact did look at your code in general how you implemented the fused backpass and saw you used the pytorch hook for such which I didn't know existed, after which I read the pytorch article on fusing the backward pass, and then from 'my files' hand wrote what is in this PR, and considering there is only one pytorch hook function, and we both copied/used/broke up the step method from Adafactor, of course the PR 'can look copied', but I assure you it wasn't. I also had to of course 'try' to integrate it all within the framework of sd-scripts, add the parameters, Japanese translations, and for instance add "if parameter.requires_grad:" since it was erroring otherwise, etc, aka putting in that work that shows it wasn't just some 'copy / paste' operation. Did I take inspiration from your code, sure, you were after all the first one to mention the fused backpass on discord, but in a similar vein, aka in a similar 'code inspiration sharing between you and me', I sent you my first couple stochastic rounding implementations via PM that you played with and then you developed your own 'simplified' and more 'performant' versions thereof. That all said, I will fully acknowledge that it was you in particular that recalled on the Stable Cascade discord and let everyone know about the potential benefits of stochastic rounding as a concept due to the old paper you referenced (of which I read, made my own SR implementation attempts, sent them to you, aka what I wrote above, etc), but that you were also the one to suggest the fusing of the backpass as a concept that could be used to great measure in addition to SR, etc :) So if you take any offense still to my not mentioning of you in this PR, I'm sorry, I hope the above clears things up. |
Curious. Would this affect mixed precision training in bf16 as well? |
No. The GradScaler only supports (and is only needed) when fp16 is involved. |
Thank you for this, and sorry for the delay. This works really well! I will do some more tests and merge into dev sooner. |
I'd like to add 'parameter groups' feature on #1319 to balance the memory savings and complexity of fused backward pass with the availability of an arbitrary optimizer. Any comments would be appreciated. |
Hi, I'm particularly interested in the stochastic rounding aspect. Will there be more PRs adding this feature in the future, or has the current PR already implemented it? Thank you for your work. |
Squashed commit of the following: commit 56bb81c Author: Kohya S <ykumeykume@gmail.com> Date: Wed Jun 12 21:39:35 2024 +0900 add grad_hook after restore state closes kohya-ss#1344 commit 22413a5 Merge: 3259928 18d7597 Author: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Tue Jun 11 19:52:03 2024 +0900 Merge pull request kohya-ss#1359 from kohya-ss/train_resume_step Train resume step commit 18d7597 Author: Kohya S <ykumeykume@gmail.com> Date: Tue Jun 11 19:51:30 2024 +0900 update README commit 4a44188 Merge: 4dbcef4 3259928 Author: Kohya S <ykumeykume@gmail.com> Date: Tue Jun 11 19:27:37 2024 +0900 Merge branch 'dev' into train_resume_step commit 3259928 Merge: 1a104dc 5bfe5e4 Author: Kohya S <ykumeykume@gmail.com> Date: Sun Jun 9 19:26:42 2024 +0900 Merge branch 'dev' of https://github.com/kohya-ss/sd-scripts into dev commit 1a104dc Author: Kohya S <ykumeykume@gmail.com> Date: Sun Jun 9 19:26:36 2024 +0900 make forward/backward pathes same ref kohya-ss#1363 commit 58fb648 Author: Kohya S <ykumeykume@gmail.com> Date: Sun Jun 9 19:26:09 2024 +0900 set static graph flag when DDP ref kohya-ss#1363 commit 5bfe5e4 Merge: e5bab69 4ecbac1 Author: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Thu Jun 6 21:23:24 2024 +0900 Merge pull request kohya-ss#1361 from shirayu/update/github_actions/crate-ci/typos-1.21.0 Bump crate-ci/typos from 1.19.0 to 1.21.0, fix typos, and updated _typos.toml (Close kohya-ss#1307) commit 4ecbac1 Author: Yuta Hayashibe <yuta@hayashibe.jp> Date: Wed Jun 5 16:31:44 2024 +0900 Bump crate-ci/typos from 1.19.0 to 1.21.0, fix typos, and updated _typos.toml (Close kohya-ss#1307) commit 4dbcef4 Author: Kohya S <ykumeykume@gmail.com> Date: Tue Jun 4 21:26:55 2024 +0900 update for corner cases commit 321e24d Merge: e5bab69 3eb27ce Author: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Tue Jun 4 19:30:11 2024 +0900 Merge pull request kohya-ss#1353 from KohakuBlueleaf/train_resume_step Resume correct step for "resume from state" feature. commit e5bab69 Author: Kohya S <ykumeykume@gmail.com> Date: Sun Jun 2 21:11:40 2024 +0900 fix alpha mask without disk cache closes kohya-ss#1351, ref kohya-ss#1339 commit 3eb27ce Author: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri May 31 12:24:15 2024 +0800 Skip the final 1 step commit b2363f1 Author: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri May 31 12:20:20 2024 +0800 Final implementation commit 0d96e10 Merge: ffce3b5 fc85496 Author: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Mon May 27 21:41:16 2024 +0900 Merge pull request kohya-ss#1339 from kohya-ss/alpha-masked-loss Alpha masked loss commit fc85496 Author: Kohya S <ykumeykume@gmail.com> Date: Mon May 27 21:25:06 2024 +0900 update docs for masked loss commit 2870be9 Merge: 71ad3c0 ffce3b5 Author: Kohya S <ykumeykume@gmail.com> Date: Mon May 27 21:08:43 2024 +0900 Merge branch 'dev' into alpha-masked-loss commit 71ad3c0 Author: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Mon May 27 21:07:57 2024 +0900 Update masked_loss_README-ja.md add sample images commit ffce3b5 Merge: fb12b6d d50c1b3 Author: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Mon May 27 21:00:46 2024 +0900 Merge pull request kohya-ss#1349 from rockerBOO/patch-4 Update issue link commit a4c3155 Author: Kohya S <ykumeykume@gmail.com> Date: Mon May 27 20:59:40 2024 +0900 add doc for mask loss commit 58cadf4 Merge: e8cfd4b fb12b6d Author: Kohya S <ykumeykume@gmail.com> Date: Mon May 27 20:02:32 2024 +0900 Merge branch 'dev' into alpha-masked-loss commit d50c1b3 Author: Dave Lage <rockerboo@gmail.com> Date: Mon May 27 01:11:01 2024 -0400 Update issue link commit e8cfd4b Author: Kohya S <ykumeykume@gmail.com> Date: Sun May 26 22:01:37 2024 +0900 fix to work cond mask and alpha mask commit fb12b6d Merge: febc5c5 00513b9 Author: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Sun May 26 19:45:03 2024 +0900 Merge pull request kohya-ss#1347 from rockerBOO/lora-plus-log-info Add LoRA+ LR Ratio info message to logger commit 00513b9 Author: rockerBOO <rockerboo@gmail.com> Date: Thu May 23 22:27:12 2024 -0400 Add LoRA+ LR Ratio info message to logger commit da6fea3 Author: Kohya S <ykumeykume@gmail.com> Date: Sun May 19 21:26:18 2024 +0900 simplify and update alpha mask to work with various cases commit f2dd43e Author: Kohya S <ykumeykume@gmail.com> Date: Sun May 19 19:23:59 2024 +0900 revert kwargs to explicit declaration commit db67529 Author: u-haru <40634644+u-haru@users.noreply.github.com> Date: Sun May 19 19:07:25 2024 +0900 画像のアルファチャンネルをlossのマスクとして使用するオプションを追加 (kohya-ss#1223) * Add alpha_mask parameter and apply masked loss * Fix type hint in trim_and_resize_if_required function * Refactor code to use keyword arguments in train_util.py * Fix alpha mask flipping logic * Fix alpha mask initialization * Fix alpha_mask transformation * Cache alpha_mask * Update alpha_masks to be on CPU * Set flipped_alpha_masks to Null if option disabled * Check if alpha_mask is None * Set alpha_mask to None if option disabled * Add description of alpha_mask option to docs commit febc5c5 Author: Kohya S <ykumeykume@gmail.com> Date: Sun May 19 19:03:43 2024 +0900 update README commit 4c79812 Author: Kohya S <ykumeykume@gmail.com> Date: Sun May 19 19:00:32 2024 +0900 update README commit 38e4c60 Merge: e4d9e3c fc37437 Author: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Sun May 19 18:55:50 2024 +0900 Merge pull request kohya-ss#1277 from Cauldrath/negative_learning Allow negative learning rate commit e4d9e3c Author: Kohya S <ykumeykume@gmail.com> Date: Sun May 19 17:46:07 2024 +0900 remove dependency for omegaconf #ref 1284 commit de0e0b9 Merge: c68baae 5cb145d Author: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Sun May 19 17:39:15 2024 +0900 Merge pull request kohya-ss#1284 from sdbds/fix_traincontrolnet Fix train controlnet commit c68baae Author: Kohya S <ykumeykume@gmail.com> Date: Sun May 19 17:21:04 2024 +0900 add `--log_config` option to enable/disable output training config commit 47187f7 Merge: e3ddd1f b886d0a Author: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Sun May 19 16:31:33 2024 +0900 Merge pull request kohya-ss#1285 from ccharest93/main Hyperparameter tracking commit e3ddd1f Author: Kohya S <ykumeykume@gmail.com> Date: Sun May 19 16:26:10 2024 +0900 update README and format code commit 0640f01 Merge: 2f19175 793aeb9 Author: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Sun May 19 16:23:01 2024 +0900 Merge pull request kohya-ss#1322 from aria1th/patch-1 Accelerate: fix get_trainable_params in controlnet-llite training commit 2f19175 Author: Kohya S <ykumeykume@gmail.com> Date: Sun May 19 15:38:37 2024 +0900 update README commit 146edce Author: Kohya S <ykumeykume@gmail.com> Date: Sat May 18 11:05:04 2024 +0900 support Diffusers' based SDXL LoRA key for inference commit 153764a Author: Kohya S <ykumeykume@gmail.com> Date: Wed May 15 20:21:49 2024 +0900 add prompt option '--f' for filename commit 589c2aa Author: Kohya S <ykumeykume@gmail.com> Date: Mon May 13 21:20:37 2024 +0900 update README commit 16677da Author: Kohya S <ykumeykume@gmail.com> Date: Sun May 12 22:15:07 2024 +0900 fix create_network_from_weights doesn't work commit a384bf2 Merge: 1c296f7 8db0cad Author: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Sun May 12 21:36:56 2024 +0900 Merge pull request kohya-ss#1313 from rockerBOO/patch-3 Add caption_separator to output for subset commit 1c296f7 Merge: e96a521 dbb7bb2 Author: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Sun May 12 21:33:12 2024 +0900 Merge pull request kohya-ss#1312 from rockerBOO/patch-2 Fix caption_separator missing in subset schema commit e96a521 Merge: 39b82f2 fdbb03c Author: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Sun May 12 21:14:50 2024 +0900 Merge pull request kohya-ss#1291 from frodo821/patch-1 removed unnecessary `torch` import on line 115 commit 39b82f2 Author: Kohya S <ykumeykume@gmail.com> Date: Sun May 12 20:58:45 2024 +0900 update readme commit 3701507 Author: Kohya S <ykumeykume@gmail.com> Date: Sun May 12 20:56:56 2024 +0900 raise original error if error is occured in checking latents commit 7802093 Merge: 9ddb4d7 040e26f Author: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Sun May 12 20:46:25 2024 +0900 Merge pull request kohya-ss#1278 from Cauldrath/catch_latent_error_file Display name of error latent file commit 9ddb4d7 Author: Kohya S <ykumeykume@gmail.com> Date: Sun May 12 17:55:08 2024 +0900 update readme and help message etc. commit 8d1b1ac Merge: 02298e3 64916a3 Author: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Sun May 12 17:43:44 2024 +0900 Merge pull request kohya-ss#1266 from Zovjsra/feature/disable-mmap Add "--disable_mmap_load_safetensors" parameter commit 02298e3 Merge: 1ffc0b3 4419041 Author: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Sun May 12 17:04:58 2024 +0900 Merge pull request kohya-ss#1331 from kohya-ss/lora-plus Lora plus commit 4419041 Author: Kohya S <ykumeykume@gmail.com> Date: Sun May 12 17:01:20 2024 +0900 update docs etc. commit 3c8193f Author: Kohya S <ykumeykume@gmail.com> Date: Sun May 12 17:00:51 2024 +0900 revert lora+ for lora_fa commit c6a4370 Merge: e01e148 1ffc0b3 Author: Kohya S <ykumeykume@gmail.com> Date: Sun May 12 16:18:57 2024 +0900 Merge branch 'dev' into lora-plus commit 1ffc0b3 Author: Kohya S <ykumeykume@gmail.com> Date: Sun May 12 16:18:43 2024 +0900 fix typo commit e01e148 Merge: e9f3a62 7983d3d Author: Kohya S <ykumeykume@gmail.com> Date: Sun May 12 16:17:52 2024 +0900 Merge branch 'dev' into lora-plus commit e9f3a62 Merge: 3fd8cdc c1ba0b4 Author: Kohya S <ykumeykume@gmail.com> Date: Sun May 12 16:17:27 2024 +0900 Merge branch 'dev' into lora-plus commit 7983d3d Merge: c1ba0b4 bee8cee Author: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Sun May 12 15:09:39 2024 +0900 Merge pull request kohya-ss#1319 from kohya-ss/fused-backward-pass Fused backward pass commit bee8cee Author: Kohya S <ykumeykume@gmail.com> Date: Sun May 12 15:08:52 2024 +0900 update README for fused optimizer commit f3d2cf2 Author: Kohya S <ykumeykume@gmail.com> Date: Sun May 12 15:03:02 2024 +0900 update README for fused optimizer commit 6dbc23c Merge: 607e041 c1ba0b4 Author: Kohya S <ykumeykume@gmail.com> Date: Sun May 12 14:21:56 2024 +0900 Merge branch 'dev' into fused-backward-pass commit c1ba0b4 Author: Kohya S <ykumeykume@gmail.com> Date: Sun May 12 14:21:10 2024 +0900 update readme commit 607e041 Author: Kohya S <ykumeykume@gmail.com> Date: Sun May 12 14:16:41 2024 +0900 chore: Refactor optimizer group commit 793aeb9 Author: AngelBottomless <aria1th@naver.com> Date: Tue May 7 18:21:31 2024 +0900 fix get_trainable_params in controlnet-llite training commit b56d5f7 Author: Kohya S <ykumeykume@gmail.com> Date: Mon May 6 21:35:39 2024 +0900 add experimental option to fuse params to optimizer groups commit 017b82e Author: Kohya S <ykumeykume@gmail.com> Date: Mon May 6 15:05:42 2024 +0900 update help message for fused_backward_pass commit 2a359e0 Merge: 0540c33 4f203ce Author: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Mon May 6 15:01:56 2024 +0900 Merge pull request kohya-ss#1259 from 2kpr/fused_backward_pass Adafactor fused backward pass and optimizer step, lowers SDXL (@ 1024 resolution) VRAM usage to BF16(10GB)/FP32(16.4GB) commit 3fd8cdc Author: Kohya S <ykumeykume@gmail.com> Date: Mon May 6 14:03:19 2024 +0900 fix dylora loraplus commit 7fe8150 Author: Kohya S <ykumeykume@gmail.com> Date: Mon May 6 11:09:32 2024 +0900 update loraplus on dylora/lofa_fa commit 52e64c6 Author: Kohya S <ykumeykume@gmail.com> Date: Sat May 4 18:43:52 2024 +0900 add debug log commit 58c2d85 Author: Kohya S <ykumeykume@gmail.com> Date: Fri May 3 22:18:20 2024 +0900 support block dim/lr for sdxl commit 8db0cad Author: Dave Lage <rockerboo@gmail.com> Date: Thu May 2 18:08:28 2024 -0400 Add caption_separator to output for subset commit dbb7bb2 Author: Dave Lage <rockerboo@gmail.com> Date: Thu May 2 17:39:35 2024 -0400 Fix caption_separator missing in subset schema commit 969f82a Author: Kohya S <ykumeykume@gmail.com> Date: Mon Apr 29 20:04:25 2024 +0900 move loraplus args from args to network_args, simplify log lr desc commit 834445a Merge: 0540c33 68467bd Author: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Mon Apr 29 18:05:12 2024 +0900 Merge pull request kohya-ss#1233 from rockerBOO/lora-plus Add LoRA+ support commit fdbb03c Author: frodo821 <sakaic2003@gmail.com> Date: Tue Apr 23 14:29:05 2024 +0900 removed unnecessary `torch` import on line 115 as per kohya-ss#1290 commit 040e26f Author: Cauldrath <bnjmnhanes@gmail.com> Date: Sun Apr 21 13:46:31 2024 -0400 Regenerate failed file If a latent file fails to load, print out the path and the error, then return false to regenerate it commit 5cb145d Author: 青龍聖者@bdsqlsz <qinglongshengzhe@gmail.com> Date: Sat Apr 20 21:56:24 2024 +0800 Update train_util.py commit b886d0a Author: Maatra <ccharest93@hotmail.com> Date: Sat Apr 20 14:36:47 2024 +0100 Cleaned typing to be in line with accelerate hyperparameters type resctrictions commit 4477116 Author: 青龍聖者@bdsqlsz <qinglongshengzhe@gmail.com> Date: Sat Apr 20 21:26:09 2024 +0800 fix train controlnet commit 2c9db5d Author: Maatra <ccharest93@hotmail.com> Date: Sat Apr 20 14:11:43 2024 +0100 passing filtered hyperparameters to accelerate commit fc37437 Author: Cauldrath <bnjmnhanes@gmail.com> Date: Thu Apr 18 23:29:01 2024 -0400 Allow negative learning rate This can be used to train away from a group of images you don't want As this moves the model away from a point instead of towards it, the change in the model is unbounded So, don't set it too low. -4e-7 seemed to work well. commit feefcf2 Author: Cauldrath <bnjmnhanes@gmail.com> Date: Thu Apr 18 23:15:36 2024 -0400 Display name of error latent file When trying to load stored latents, if an error occurs, this change will tell you what file failed to load Currently it will just tell you that something failed without telling you which file commit 64916a3 Author: Zovjsra <4703michael@gmail.com> Date: Tue Apr 16 16:40:08 2024 +0800 add disable_mmap to args commit 4f203ce Author: 2kpr <96332338+2kpr@users.noreply.github.com> Date: Sun Apr 14 09:56:58 2024 -0500 Fused backward pass commit 68467bd Author: rockerBOO <rockerboo@gmail.com> Date: Thu Apr 11 17:33:19 2024 -0400 Fix unset or invalid LR from making a param_group commit 75833e8 Author: rockerBOO <rockerboo@gmail.com> Date: Mon Apr 8 19:23:02 2024 -0400 Fix default LR, Add overall LoRA+ ratio, Add log `--loraplus_ratio` added for both TE and UNet Add log for lora+ commit 1933ab4 Author: rockerBOO <rockerboo@gmail.com> Date: Wed Apr 3 12:46:34 2024 -0400 Fix default_lr being applied commit c769160 Author: rockerBOO <rockerboo@gmail.com> Date: Mon Apr 1 15:43:04 2024 -0400 Add LoRA-FA for LoRA+ commit f99fe28 Author: rockerBOO <rockerboo@gmail.com> Date: Mon Apr 1 15:38:26 2024 -0400 Add LoRA+ support
Implemented a 'fused' backward pass and optimizer step when using Adafactor which massively lowers the VRAM used when training. If you want some more information on 'fusing' see here: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
For instance with SDXL (@ 1024 resolution), the VRAM used when training in BF16 drops to 10GB of VRAM, and in FP32 drops to 16.4GB of VRAM, allowing SDXL FP32 'normal' training on a 24GB GPU (which allows a batch size of up to about 16 for BF16 training and 6 for FP32 training when training on a 24GB GPU). Secondarily it allows SDXL 'normal' BF16 training on smaller 12GB/16GB GPUs, etc.
I only modified the sdxl_train.py so far so people could test out this PR and then make any changes necessary before extending this 'fused backward pass' to the other training py files in this repo, etc.
Feel free to edit accordingly, I just placed things where they seemed most logical according to the existing code in this repo.