Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support SD3 #1374

Draft
wants to merge 136 commits into
base: dev
Choose a base branch
from
Draft

support SD3 #1374

wants to merge 136 commits into from

Conversation

kohya-ss
Copy link
Owner

@kohya-ss kohya-ss commented Jun 15, 2024

  • Replace SD3Tokenizer with the original CLIP-L/G/T5 tokenizers.
  • Extend the max token length to 256 for T5XXL.
  • Refactor caching for latents.
  • Refactor caching for Text Encoder outputs
  • Extract architecture-dependent parts from datasets.
  • Refactor SD/SDXL training scripts.
  • Caching attention mask etc.
  • Enable training for CLIP-L/G.
  • Add an option to use T5XXL from transformers (for fp8 quantized ver.)
  • Add attention mask for T5XXL embeds (?). https://www.reddit.com/r/StableDiffusion/comments/1e6k59c/solution_discovered_partially_implemented_for_sd3/
  • Sample images during training.
  • Cache Text Encoder outputs for sampling.
  • Update SD/SDXL sampling to use refactored Text Encoding etc.
  • Update gen_img.py to use refactored Text Encoding etc.
  • SD3 LoRA support.
  • FLUX.1 fine tuning.
  • FLUX.1 LoRA support for FLUX.
  • FLUX.1 LoRA support for CLIP-L.
  • FLUX.1 masking for attention
  • FLUX.1 Sample image generation during training.

@bghira
Copy link

bghira commented Jun 16, 2024

this is a chance to just use Diffusers modules instead of doing everything from scratch. why not take it?

@kohya-ss
Copy link
Owner Author

There are several reasons for this, but the biggest reason is that it is difficult to extend. For example, LoRA, custom ControlNet and Deep Shrink etc.

Also, considering the various processes in the training scripts, such as conditional loss, SNR, masked loss, etc., the training scripts need to be written from scratch.

@bghira
Copy link

bghira commented Jun 16, 2024

all of that is done via peft other than deepshrink but you can make a pipeline callback for that.

@bghira
Copy link

bghira commented Jun 16, 2024

i mean to use the sd3 transformer module from the diffusers project.

it is frustrating to see bespoke versions of things with unreadable comments always in this repository. can you at least leave better comments?

@kohya-ss
Copy link
Owner Author

I think transformer module should be extendable for the future. In addition, SD3 transformer is based on sd3-ref (Stability AI official repo), and modified by KBlueLeaf to support xformers etc. So it is prior to Diffusers, and not full scratch. I appreciate your understanding.

I will add better comments in future codes, including this PR.

@araleza
Copy link

araleza commented Jul 10, 2024

Hello, I have been trying out SD3 training. It seems to be working pretty well. 😊

One thing I noticed is that generation of sample images while training is not yet implemented. This made it hard for me to see how my SD3 training was going, and make adjustments.

Implementing full support for all the sample images was difficult, but I found a cheap way to get most features working, and now I have sample images working again. This code is not properly integrated with the usual sample image generation code, but if people want to use it while they wait for a real well-integrated implementation, it does the basics of what's needed.

Just go into your sd3_train.py file, and find this commented-out section:

                # sdxl_train_util.sample_images(
                #     accelerator,
                #     args,
                #     None,
                #     global_step,
                #     accelerator.device,
                #     vae,
                #     [tokenizer1, tokenizer2],
                #     [text_encoder1, text_encoder2],
                #     mmdit,
                # )

and replace that with this:

                # Generate sample images
                if args.sample_every_n_steps is not None and global_step % args.sample_every_n_steps == 0:
                    from sd3_minimal_inference import do_sample
                    from PIL import Image
                    import datetime
                    import numpy as np
                    import shlex
                    import random

                    assert args.save_t5xxl, "When generating sample images in SD3, --save_t5xxl parameter must be set"

                    with open(args.sample_prompts, 'r') as file:
                        lines = [line.strip() for line in file if line.strip()]

                    vae.to("cuda")
                    for line in lines:
                        logger.info(f"Generating image: {line}")

                        if line.find('--') != -1:
                            prompt = line[:line.find('--') - 1].strip()
                            line = line[line.find('--'):]
                        else:
                            prompt = line
                            line = ''

                        parser_s = argparse.ArgumentParser()
                        parser_s.add_argument("--w", type=int, action="store", default=1024, help="image width")
                        parser_s.add_argument("--h", type=int, action="store", default=1024, help="image height")
                        parser_s.add_argument("--s", type=int, action="store", default=30,   help="sample steps")
                        parser_s.add_argument("--l", type=int, action="store", default=4,    help="CFG")
                        parser_s.add_argument("--d", type=int, action="store", default=random.randint(0, 2**32 - 1), help="seed")
                        prompt_args = shlex.split(line)
                        args_s = parser_s.parse_args(prompt_args)

                        # prepare embeddings
                        lg_out, t5_out, pooled = sd3_utils.get_cond(prompt, sd3_tokenizer, clip_l, clip_g, t5xxl) # +'ve prompt
                        cond = torch.cat([lg_out, t5_out], dim=-2), pooled

                        lg_out, t5_out, pooled = sd3_utils.get_cond("", sd3_tokenizer, clip_l, clip_g, t5xxl) # No -'ve prompt
                        neg_cond = torch.cat([lg_out, t5_out], dim=-2), pooled

                        latent_sampled = do_sample(
                            args_s.h, args_s.w, None, args_s.d, cond, neg_cond, mmdit, args_s.s, args_s.l, weight_dtype, accelerator.device
                        )

                        # latent to image
                        with torch.no_grad():
                            image = vae.decode(latent_sampled)
                        image = image.float()
                        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
                        decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
                        decoded_np = decoded_np.astype(np.uint8)
                        out_image = Image.fromarray(decoded_np)

                        # save image
                        output_dir = os.path.join(args.output_dir, "sample")
                        os.makedirs(output_dir, exist_ok=True)
                        output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
                        out_image.save(output_path)

                    vae.to("cpu")

It supports a caption followed by the usual optional --w, --h, --s, --l, --d (for width, height, steps, cfg, and seed). It doesn't support negative captions, and it won't work right with captions longer than 75 tokens.

I'm finding sample image generation to be helpful. For example, I notice that most of my sample output images start off by looking brighter than expected (with white or bright backgrounds). Edit: Might have been my cfg of 7.5; SD3 seems to want lower cfgs. I had to push the sample count up as the cfg was lowered. Image quality still seems poor though, compared to what some people are getting out of SD3.

@araleza
Copy link

araleza commented Jul 10, 2024

Think I've found an issue that's causing the poor quality SD3 samples. The do_sample() function is not filling in the shift parameter that's required by SD3, and it's defaulting to 1.0 instead of the recommended 3.0:

class ModelSamplingDiscreteFlow:
    """Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models"""

    def __init__(self, shift=1.0):
        self.shift = shift
        timesteps = 1000
        self.sigmas = self.sigma(torch.arange(1, timesteps + 1, 1))

From sd-script's sd3_minimal_inference.py function, do_sample()

    model_sampling = sd3_utils.ModelSamplingDiscreteFlow()

From the SD3 paper:
image

The paper also seems to say that these shifts to the sigmas should be present during training. Are these maybe missing too, @kohya-ss? (Edit: No, a shift value of 3.0 is already set up correctly during training)

@kohya-ss
Copy link
Owner Author

Think I've found an issue that's causing the poor quality SD3 samples. The do_sample() function is not filling in the shift parameter that's required by SD3, and it's defaulting to 1.0 instead of the recommended 3.0:

Thank you! I fixed it. The generated images seemed to be better now.

@kohya-ss
Copy link
Owner Author

I agree that the sample image generation is really useful. In my understanding, T5XXL is on CPU, so I wonder get_cond may take a long time. How much time it takes?

I think it might be necessary to get TE's output for the sampling prompt in advance, at the same time the TE caching. However, if T5XXL works on CPU with an acceptable time, the implementation of the sample generation will be much easier (like your implementation :) .

@bghira
Copy link

bghira commented Jul 11, 2024

it takes about 30-50 seconds to run T5 XL on the CPU, i think XXL is even worse latency for each embed

@araleza
Copy link

araleza commented Jul 11, 2024

I agree that the sample image generation is really useful. In my understanding, T5XXL is on CPU, so I wonder get_cond may take a long time. How much time it takes?

@kohya-ss, the calls to get_cond() only take around 2 seconds each on my machine. The whole sample image generation takes just 16 seconds per image for me, and I am still doing 80 sample steps for the images. :D

My PC is an ordinary (but good) home PC machine with a 13th gen Intel i7, and I've got 64 GB of CPU RAM. Perhaps the people finding the T5 XL to be very slow are running out of CPU memory and swapping the T5 XL out to disk without realizing? @bghira

@FurkanGozukara
Copy link

@kohya-ss thank you for reply

even 512 fails can you check if something wrong?

full train logs : https://gist.github.com/FurkanGozukara/b13e2c263138afd5e8548eb6ae9786ce

toml : https://gist.github.com/FurkanGozukara/f01c76c4eaa2172352ebf1b8e08a395f

@bghira
Copy link

bghira commented Sep 17, 2024

no i don't think the fused backward pass works with DDP. it's an Accelerate thing

@kohya-ss
Copy link
Owner Author

even 512 fails can you check if something wrong?

There doesn't seem to be a problem with the settings, but the memory is really tight, so it may also depend on the environment. Things may be different if you have 3 or more GPUs. Here's my case:

|   0  NVIDIA RTX A6000             WDDM  |   00000000:01:00.0 Off |                  Off |
| 30%   50C    P2             99W /  180W |   48406MiB /  49140MiB |    100%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA RTX A6000             WDDM  |   00000000:05:00.0 Off |                  Off |
| 30%   50C    P2            101W /  180W |   48424MiB /  49140MiB |    100%      Default |
|                                         |                        |                  N/A |

no i don't think the fused backward pass works with DDP. it's an Accelerate thing

Hmm... I don't know the details, but it seems that removing --fused_backward_pass causes OOM.

@FurkanGozukara
Copy link

even 512 fails can you check if something wrong?

There doesn't seem to be a problem with the settings, but the memory is really tight, so it may also depend on the environment. Things may be different if you have 3 or more GPUs. Here's my case:

|   0  NVIDIA RTX A6000             WDDM  |   00000000:01:00.0 Off |                  Off |
| 30%   50C    P2             99W /  180W |   48406MiB /  49140MiB |    100%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA RTX A6000             WDDM  |   00000000:05:00.0 Off |                  Off |
| 30%   50C    P2            101W /  180W |   48424MiB /  49140MiB |    100%      Default |
|                                         |                        |                  N/A |

no i don't think the fused backward pass works with DDP. it's an Accelerate thing

Hmm... I don't know the details, but it seems that removing --fused_backward_pass causes OOM.

thank you so much this is 512px right?

i will try different torch and accelerator lets see if helps

@kohya-ss
Copy link
Owner Author

thank you so much this is 512px right?

Yes, it is 512x512 and batch size=1.

@kohya-ss
Copy link
Owner Author

kohya-ss commented Sep 18, 2024

Fixed a bug where train/eval was not called correctly in schedule free.

train()/eval() is called at every step, so if training becomes slow, please let me know. I will add another fix if that happens.

@FurkanGozukara
Copy link

@kohya-ss which would be the proper repo to open an issue regarding the DDP overhead when training FLUX? single GPU taking 25 GB becomes more than 48 GB even in 2x GPU :/

@recris
Copy link

recris commented Sep 18, 2024

Just did a LoRA test with AdamWScheduleFree, it seems to work much better than plain AdamW plus cosine or constant schedule.

@araleza
Copy link

araleza commented Sep 18, 2024

Just did a LoRA test with AdamWScheduleFree, it seems to work much better than plain AdamW plus cosine schedule.

I just did the same, and it seems much better to me too. I had Adafactor before I switched over. I'm planning to use AdamWScheduleFree as my new default optimizer.

(Edit: I also deleted --max_grad_norm 1.0 from my command line at the same time too, which might actually be the cause of part or all of the gains I saw, but it's more likely the improvements are due to the use of AdamWScheduleFree optimizer, especially as recris says that he didn't notice much difference with/without max_grad_norm)

Thanks go to @sdbds (and @kohya-ss too) for the implementation, and for recognizing that it was worth adding to sd-scripts.

@recris
Copy link

recris commented Sep 18, 2024

With AdamWScheduleFree, average_key_norm seems to grow significantly slower:
image
yellow = AdamW + wide cosine (T_max=2000)
purple = AdamScheduleFree

I am guessing less extreme learned weights result in better outputs?

With Flux I found that setting scale_weight_norms = 1.0, which worked fine in SDXL, was causing a significant dampening effect on the learning, and needed to set a much higher value to compensate.

@araleza
Copy link

araleza commented Sep 18, 2024

With AdamWScheduleFree, average_key_norm seems to grow significantly slower:

I actually continued training LoRAs that had been trained for thousands of steps with Adafactor (not schedule-free) with the new schedule-free optimizer (and also got rid of my max_grad_norm=1.0), and saw immediate image quality improvements. The key lengths probably didn't shrink way back in the first 30 or so steps before I output sample images, so I guess the improvements can't be entirely down to shorter key lengths?

@recris
Copy link

recris commented Sep 18, 2024

With AdamWScheduleFree, average_key_norm seems to grow significantly slower:

I actually continued training LoRAs that had been trained for thousands of steps with Adafactor (not schedule-free) with the new schedule-free optimizer (and also got rid of my max_grad_norm=1.0), and saw immediate image quality improvements. The key lengths probably didn't shrink way back in the first 30 or so steps before I output sample images, so I guess the improvements can't be entirely down to shorter key lengths?

Note that I mentioned the setting scale_weight_norms and not max_grad_norm, they do different things. However, I've not seen significant differences with or without max_grad_norm.

@FurkanGozukara
Copy link

Just did a LoRA test with AdamWScheduleFree, it seems to work much better than plain AdamW plus cosine schedule.

I just did the same, and it seems much better to me too. I had Adafactor before I switched over. I'm planning to use AdamWScheduleFree as my new default optimizer.

(Edit: I also deleted --max_grad_norm 1.0 from my command line at the same time too, which might actually be the cause of part or all of the gains I saw, but it's more likely the improvements are due to the use of AdamWScheduleFree optimizer, especially as recris says that he didn't notice much difference with/without max_grad_norm)

Thanks go to @sdbds (and @kohya-ss too) for the implementation, and for recognizing that it was worth adding to sd-scripts.

do you need changes in learning rate and VRAM usage changed?

@araleza
Copy link

araleza commented Sep 18, 2024

do you need changes in learning rate and VRAM usage changed?

The Facebookresearch page (which has the code for this optimizer) says a higher LR is suggested:

For AdamW, learning rates in the range 1x-10x larger than with schedule-based approaches seem to work.

In my case, I left my LoRA training rate at 8e-5 and saw immediate quality gains 30 iterations later when my first sample images were generated. But, that's picking up the LoRA from where I left off training it before. I haven't tried training a new one from scratch. Maybe that would need a higher LR?

I haven't tested the memory requirements, but the Facebookresearch page says:

Only two sequences need to be stored at a time (the third can be computed from the other two on the fly) so this method has the same memory requirements as the base optimizer (parameter buffer + momentum).

We're currently not using their 'wrapper' version of the optimizer which uses more memory.

The page I'm quoting from is here, if you want more detail:
https://github.com/facebookresearch/schedule_free

@FurkanGozukara
Copy link

@araleza thank you

which Optimizer extra arguments it needs? like weight_decay=0.01 or any other?

sadly it uses more VRAM than adafactor 27700 MB vs 29000 at fp16, but it is faster per step

@recris
Copy link

recris commented Sep 18, 2024

I've been playing with Flux for almost 3 weeks and found some interesting things I'd like others to corroborate:

  • Flux seems to like larger batch sizes. For example, training for 500 steps with batch size 6 gives a slightly better result than training 1000 steps with batch size 3. In prior models (SD1, SDXL) the difference wasn't as noticeable.
    • Update: with AdamWScheduleFree the behavior seems to be more in line with SDXL
  • Applying max norm regularization needs a higher limit than prior models. On the use-cases I tested, scale_weight_norms=1.0 doesn't negatively affect the learning in SD models, but in Flux it does. I found myself using values of 3.0 and greater. (this might depend on the choice of optimizer and needs further testing)
  • For many use cases we don't need to train at 1024px resolution - for example, for training a person likeness, 640px resolution already gives results way better than SDXL could do 1024px, and on my hardware that requires a third of compute resources (I can do batch size 3 instead of 1). Flux is able to render details learned at 640px in higher resolutions with very sharp image quality. Maybe it loses the ability to learn some of the more fine details of the training images but for me the gains in GPU performance are worth it. It also seems to learn faster (less steps) at lower resolutions too!
  • In Flux, over-fitting symptoms are a lot harder to perceive than in previous models. The model does a very good job at producing coherent images without the typical over-fitness artifacts. I've noticed that over-fitted LoRAs tend to lose prompt adherence first and only in more extreme cases visual artifacts start to show up.

@kohya-ss
Copy link
Owner Author

@kohya-ss which would be the proper repo to open an issue regarding the DDP overhead when training FLUX? single GPU taking 25 GB becomes more than 48 GB even in 2x GPU :/

This overhead of DDP seems to be expected, and more efficient training seems to require DeepSpeed ​​or FSDP etc.

@bghira
Copy link

bghira commented Sep 19, 2024

well deepspeed only works with cpu based optimiser AdamW and fsdp is all or nothing for sharding and slower than ZeRO

emcmanus and others added 3 commits September 19, 2024 14:30
Currently the alpha channel is dropped by `pil_resize()` when `--alpha_mask` is supplied and the image width does not exceed the bucket.

This codepath is entered on the last line, here:
```
def trim_and_resize_if_required(
    random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int]
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]:
    image_height, image_width = image.shape[0:2]
    original_size = (image_width, image_height)  # size before resize

    if image_width != resized_size[0] or image_height != resized_size[1]:
        # リサイズする
        if image_width > resized_size[0] and image_height > resized_size[1]:
            image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)  # INTER_AREAでやりたいのでcv2でリサイズ
        else:
            image = pil_resize(image, resized_size)
```
@FurkanGozukara
Copy link

FurkanGozukara commented Sep 20, 2024

@kohya-ss is this expected or because i used --debug?

i got 20800 images with captions

doing a FLUX fine tuning

it has been over 64100 steps and not a single cache file generated yet

image

@kohya-ss yes after removing --debug and starting again fixed :/

image

kohya-ss and others added 3 commits September 20, 2024 22:16
fix: backward compatibility for text_encoder_lr
Retain alpha in `pil_resize` for `--alpha_mask`
@sdbds
Copy link
Contributor

sdbds commented Sep 21, 2024

Thanks to everyone in the community for conducting the tests, @araleza @recris
I would suggest to consider increasing gradient_accumulation_steps to enlarge the batch size, like 6 or 8 or so.
Flux looks different from previous SD models in that it's more likely to crash first so a larger equivalent batch size might work better.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.