Skip to content

HighCWu/flux-4bit

FLUX.1 with 4-bit Quantization

I want to train a LoRA for FLUX with the diffusers library on my 16 GB GPU, but training with flux-dev-fp8 on that card is difficult, so I use 4-bit weights to save VRAM.

I found that for FLUX's text_encoder_2 (T5-XXL), hqq_4bit quantization works better than bnb_nf4, while for the FLUX transformer, bnb_nf4 works better than hqq_4bit. I therefore use a different quantization method for each model: hqq_4bit for text_encoder_2 and bnb_nf4 for the transformer.
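
To make the difference between the two schemes concrete, here is a simplified, pure-Python sketch of block-wise 4-bit quantization. The NF4-style path scales each block by its absolute maximum and snaps every value to one of 16 levels placed at standard-normal quantiles; the affine path stores a per-block scale and zero point, the form that HQQ, as I understand it, then refines with its half-quadratic solver (that optimization step is not reproduced here). The level construction follows the published NF4 recipe as I understand it, so the exact values used by bitsandbytes may differ slightly, and none of the function names below come from bitsandbytes, HQQ, or this repository.

```python
from statistics import NormalDist


def linspace(start, stop, num):
    """Evenly spaced values from start to stop, inclusive."""
    return [start + (stop - start) * i / (num - 1) for i in range(num)]


def nf4_style_levels(offset=0.9677083):
    """16 levels at standard-normal quantiles, with an exact zero, scaled to [-1, 1]."""
    inv_cdf = NormalDist().inv_cdf
    positive = [inv_cdf(p) for p in linspace(offset, 0.5, 9)[:-1]]   # 8 levels > 0
    negative = [-inv_cdf(p) for p in linspace(offset, 0.5, 8)[:-1]]  # 7 levels < 0
    levels = sorted(positive + [0.0] + negative)
    top = max(levels)
    return [v / top for v in levels]


NF4_LEVELS = nf4_style_levels()


def quantize_nf4_block(block):
    """Absmax-scale one block, then store the index of the nearest level (0..15)."""
    absmax = max(abs(x) for x in block) or 1.0  # guard against an all-zero block
    codes = [
        min(range(16), key=lambda i: abs(x / absmax - NF4_LEVELS[i])) for x in block
    ]
    return codes, absmax


def dequantize_nf4_block(codes, absmax):
    return [NF4_LEVELS[c] * absmax for c in codes]


def quantize_affine_block(block, bits=4):
    """Uniform min/max quantization with one scale and zero point per block."""
    qmax = 2 ** bits - 1
    lo, hi = min(block), max(block)
    scale = (hi - lo) / qmax or 1.0  # constant block: any nonzero scale works
    zero = -lo / scale
    codes = [max(0, min(qmax, round(x / scale + zero))) for x in block]
    return codes, scale, zero


def dequantize_affine_block(codes, scale, zero):
    return [(c - zero) * scale for c in codes]
```

Which scheme suits which model is an empirical question, as the observation above shows; the sketch only illustrates the mechanics of each.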

For inference, the pipeline uses about 8.5 GB of VRAM with model CPU offload enabled and about 11 GB without it.
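
A back-of-the-envelope check of why 4-bit weights help: weight storage is roughly parameters × bits per parameter / 8 bytes. The sketch below assumes about 12B parameters for the FLUX.1-dev transformer and roughly 4.7B for the T5-XXL encoder (approximate figures, not read from the checkpoint), and it models block-wise scale overhead by assuming one 32-bit scale per block of 64 weights. It ignores activations, the VAE, the CLIP text encoder and CUDA overhead, so it will not reproduce the measured VRAM figures exactly.

```python
def effective_bits(bits, block_size=None, scale_bits=32):
    """Bits per weight, including one `scale_bits`-wide scale per block if blocked."""
    if block_size is None:
        return float(bits)
    return bits + scale_bits / block_size


def weight_memory_gb(num_params, bits_per_param):
    """Weight storage only, in decimal gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


# Approximate parameter counts (assumptions for illustration).
MODELS = {
    "transformer": 12e9,
    "text_encoder_2 (T5-XXL encoder)": 4.7e9,
}

if __name__ == "__main__":
    for name, n_params in MODELS.items():
        bf16 = weight_memory_gb(n_params, 16)
        four_bit = weight_memory_gb(n_params, effective_bits(4, block_size=64))
        print(f"{name}: bf16 ~{bf16:.1f} GB, 4-bit ~{four_bit:.1f} GB")
```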

To reduce VRAM usage further during training, consider precomputing the outputs of text_encoder_2 and storing them as a dataset first, so the text encoder does not have to stay in GPU memory while training.
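
A minimal sketch of that idea, assuming the text-encoding step can be wrapped in a callable: each unique prompt is encoded once, the results are kept in a dict, and the dict is serialized so training can load it instead of running text_encoder_2. In diffusers, `FluxPipeline.encode_prompt` returns the prompt embeddings, pooled embeddings and text ids and could back that callable (check its signature for your diffusers version). The helper names are hypothetical, and for tensors `torch.save` is the more usual on-disk format than plain `pickle`.

```python
import pickle
from typing import Any, Callable, Dict, Iterable


def build_embedding_cache(
    prompts: Iterable[str], encode_fn: Callable[[str], Any]
) -> Dict[str, Any]:
    """Encode each unique prompt once; repeated prompts reuse the cached result."""
    cache: Dict[str, Any] = {}
    for prompt in prompts:
        if prompt not in cache:
            # encode_fn should return CPU copies (e.g. tensor.cpu()) so the
            # cache does not keep GPU memory alive.
            cache[prompt] = encode_fn(prompt)
    return cache


def serialize_cache(cache: Dict[str, Any]) -> bytes:
    return pickle.dumps(cache)


def deserialize_cache(data: bytes) -> Dict[str, Any]:
    # Only unpickle data you created yourself: pickle can execute arbitrary code.
    return pickle.loads(data)
```

For example, write `Path("text_embeds.pkl").write_bytes(serialize_cache(cache))` once before training (the file name is hypothetical) and read it back with `deserialize_cache(Path("text_embeds.pkl").read_bytes())` in the training script.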

Note: I use some patch code so that the diffusers models load the quantized weights correctly.

How to use

  1. Clone the repository:

    git clone https://github.com/HighCWu/flux-4bit
    cd flux-4bit
  2. Install the requirements:

    pip install -r requirements.txt
  3. Run in Python:

    import torch
    
    from model import T5EncoderModel, FluxTransformer2DModel
    from diffusers import FluxPipeline
    
    
    text_encoder_2: T5EncoderModel = T5EncoderModel.from_pretrained(
        "HighCWu/FLUX.1-dev-4bit",
        subfolder="text_encoder_2",
        torch_dtype=torch.bfloat16,
        # hqq_4bit_compute_dtype=torch.float32,
    )
    
    transformer: FluxTransformer2DModel = FluxTransformer2DModel.from_pretrained(
        "HighCWu/FLUX.1-dev-4bit",
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    )
    
    pipe: FluxPipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        text_encoder_2=text_encoder_2,
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    )
    pipe.enable_model_cpu_offload()  # with model CPU offload: about 8.5 GB of VRAM
    # To run fully on the GPU instead (about 11 GB of VRAM), uncomment:
    # pipe.remove_all_hooks()
    # pipe = pipe.to("cuda")
    
    prompt = "realistic, best quality, extremely detailed, ray tracing, photorealistic, A blue cat holding a sign that says hello world"
    image = pipe(
        prompt,
        height=1024,
        width=1024,
        guidance_scale=3.5,
        output_type="pil",
        num_inference_steps=16,
        max_sequence_length=512,
        generator=torch.Generator("cpu").manual_seed(0)
    ).images[0]
    image.show()
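
If the same script has to run on GPUs with different amounts of memory, the choice between the two placements in the example above can be automated. The sketch below compares free VRAM against the approximate peaks quoted in this README (about 11 GB fully on the GPU, about 8.5 GB with model CPU offload); the helper names and thresholds are assumptions, and the peaks may differ for other resolutions, step counts or library versions.

```python
def choose_memory_mode(free_vram_gb: float) -> str:
    """Pick a placement from free VRAM using this README's approximate peaks
    (~11 GB fully on GPU, ~8.5 GB with model CPU offload; assumed thresholds)."""
    if free_vram_gb >= 11.0:
        return "cuda"
    if free_vram_gb >= 8.5:
        return "model_cpu_offload"
    raise ValueError(
        f"{free_vram_gb:.1f} GB free is below the ~8.5 GB measured with CPU offload"
    )


def apply_memory_mode(pipe, mode: str):
    """Apply the chosen placement to a diffusers pipeline and return the pipeline."""
    if mode == "cuda":
        return pipe.to("cuda")
    if mode == "model_cpu_offload":
        pipe.enable_model_cpu_offload()
        return pipe
    raise ValueError(f"unknown mode: {mode!r}")
```

With PyTorch, free memory can be read via `torch.cuda.mem_get_info()`, which returns `(free, total)` in bytes, e.g. `pipe = apply_memory_mode(pipe, choose_memory_mode(torch.cuda.mem_get_info()[0] / 1e9))`. This call would replace `pipe.enable_model_cpu_offload()` in the example rather than follow it.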
