Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds support for 4bit (nf4) and 8bit bitsandbytes quantization (3/3) #151

Open
wants to merge 15 commits into
base: main
Choose a base branch
from

Conversation

Rypo
Copy link

@Rypo Rypo commented Nov 28, 2024

Changes

  • Adds a quantization_config arg to from_pretrained to OmniGen. Adds model, vae, processor kwargs to OmniGenPipeline.
  • OmniGen expects a transformers.BitsAndBytesConfig which can then be passed to OmniGenPipeline
  • Adds a new requirement bitsandbytes==0.44.1

Usage

from transformers import BitsAndBytesConfig
from OmniGen import OmniGenPipeline, OmniGen

quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4')
# quantization_config = BitsAndBytesConfig(load_in_8bit=True) # for 8-bit

model = OmniGen.from_pretrained("Shitao/OmniGen-v1", quantization_config=quantization_config)
pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1", model=model)

or to use pre-quantized weights:

from OmniGen import OmniGenPipeline, OmniGen

# for 8-bit: 'gryan/OmniGen-v1-bnb-8bit'
model = OmniGen.from_pretrained('gryan/OmniGen-v1-bnb-4bit')
pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1", model=model)

Important

If you are using Google Colab free-tier or have an older GPU use the float16 weights. You'll go OOM or get errors using the default bfloat16 weights.

model = OmniGen.from_pretrained('gryan/OmniGen-v1-fp16-bnb-4bit', dtype=torch.float16)
pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1", model = model)

For use with app.py you can pass a cli arg --nbits or -b

python app.py --nbits 4

Ideally, this would be a gradio radio button component or something, but that's a task for another day.

Results

Following a similar format to the Different inference settings table.

For 4bit-nf4 quantized model on RTX 3090 GPU(24G):

Settings Only Text Text + Single Image Text + Two Images
use_kv_cache=False 6.8G, 1m16s 7.2G, 3m30s 7.7G, 5m47s
use_kv_cache=True 9.9G, 1m14s 20.4G†, 8m5s OOM (36.7G†, >1h10m)
use_kv_cache,offload_kv_cache 6.8G, 1m16s 7.2G, 2m49s 8.4G, 4m3s
use_kv_cache,offload_kv_cache,separate_cfg_infer 6.8G, 1m20s 7.0G, 2m31s 7.4G, 3m31s
use_kv_cache,offload_kv_cache,offload_model* 5.0G, 1m35s 6.0G, 3m7s 8.0G, 4m21s
use_kv_cache,offload_kv_cache,separate_cfg_infer,offload_model* 5.0G, 1m58s 5.3G, 3m29s 5.6G, 4m19s
  • † - memory_reserved > 24gb, RAM spillover
  • * - only VAE offload. Model loaded in 4bit cannot be offloaded.

Testing setup

  • Timings are reported on the first 3 examples from the inference.ipynb notebook.
    • For "Only Text", just the first prompt is used.
  • vRAM is what's reported by max_memory_allocated(), max_memory_reserved() was typically ~3GB higher.

Image Comparisons

prompt = "A vintage camera placed on the ground, ejecting a swirling cloud of Polaroid-style photographs into the air. The photos, showing landscapes, wildlife, and travel scenes, seem to defy gravity, floating upward in a vortex of motion. The camera emits a glowing, smoky light from within, enhancing the magical, surreal atmosphere. The dark background contrasts with the illuminated photos and camera, creating a dreamlike, nostalgic scene filled with vibrant colors and dynamic movement. Scattered photos are visible on the ground, further contributing to the idea of an explosion of captured memories."
pipe(prompt=prompt, height=1024, width=1024, guidance_scale=2.5,  seed=0, ...)

text_only_1111_4bit_bf16

prompt="The woman in <img><|image_1|></img> waves her hand happily in the crowd"
input_images=["./imgs/test_cases/zhang.png"]
pipe(prompt=prompt, input_images=input_images, height=1024, width=1024, guidance_scale=2.5, img_guidance_scale=1.8,  seed=42, ...)

single_img_1111_4bit_bf16

prompt = "Two woman are raising fried chicken legs in a bar. A woman is <img><|image_1|></img>. Another woman is <img><|image_2|></img>."
input_images = ["./imgs/test_cases/mckenna.jpg", "./imgs/test_cases/Amanda.jpg"]
pipe(prompt=prompt, input_images=input_images, height=1024, width=1024,guidance_scale=2.5, img_guidance_scale=1.8, max_input_image_size=1024, seed=168, ...)

double_img_1111_4bit_bf16

8-bit

I didn't spend much time testing 8-bit, but without cache offloading -> OOM. Here's a couple samples otherwise.

For bnb 8bit quantized model on RTX 3090 GPU(24G):

Settings Only Text Text + Single Image Text + Two Images
use_kv_cache,offload_kv_cache,separate_cfg_infer 8.6G, 1m17s 8.8G, 2m39s 9.2G, 3m42s
use_kv_cache,offload_kv_cache,separate_cfg_infer,offload_model 8.2G, 2m28s 8.4G, 4m17s 8.8G, 5m10s

Images

Same prompts + settings as above, all with 8-bit quantization.

to_si_di_1110_8bit

Additional Considerations

  • I briefly experimented with casting the VAE to bfloat16. The outputs appeared identical from my tests. There are numerical stability issues that can arise from reduced precision without careful handling, however. See diffusers. That said, it can save 1-2GB if you're willing to risk black outputs.
  • Tried bnb_4bit_quant_type='fp4' - same vRAM, same timings, worse quality images.
  • Tried bnb_4bit_compute_dtype=torch.bfloat16 - very poor quality images.

_This is the third of 3 PRs I'm issuing to improve performance/fix errors. I've tried to keep each incremental change as small in scope as possible. PRs: 1. #149, 2. #150, 3. This

Update (2024-12-02):

Update (2024-12-05):

  • Usage information changed significantly

Update (2024-12-12):

  • Added float16 quant usage information

Rypo added 3 commits November 25, 2024 19:39
Prevents slow CPU initialization of model weights on load by using accelerate `init_empty_weights`.

Completely compatible with from_pretrained since weights will always be overwritten by state_dict

fixes VectorSpaceLab#72
@staoxiao
Copy link
Contributor

@Rypo, thank you very much for your contribution!!! These are incredibly useful features.
For this PR: #150, I previously used this method self.prefetch_stream.synchronize(), but encountered an error. I'll need to take some time to check it again.

@staoxiao
Copy link
Contributor

I find that removing non_blocking=True will result in a longer inference time. When using two images as input and offloading the model, the generation process will take 10 minutes. But if add non_blocking=True, this only takes 2.5 mins.
So I suggest still using non_blocking=True for .to("cpu") calls.

@Pevernow
Copy link

Pevernow commented Dec 1, 2024

@Rypo
This PR is excellent!
Unfortunately, it still cannot run in a RAM-constrained environment like colab (<12gb).
Can you provide a pre-quantized model and corresponding loading solution on huggingface?
Thank you very much

@Rypo
Copy link
Author

Rypo commented Dec 2, 2024

@staoxiao Absolutely, happy to lend a hand!

I dug a little deeper into the torch.cuda.synchronize(self.prefetch_stream) statement and have determined that it is functionally equivalent to torch.cuda.synchronize(None). If synchronize isn't passed a str, int, or torch.device, torch will default to synchronizing all streams on the current device.

I assume that isn't the behavior you're after here, but at worst it would probably just cause a small performance hit. Not exactly sure why self.prefetch_stream.synchronize() would cause errors on your end though.

As for the non_blocking issue, that's fair. I'll separate #150 so it does not interfere with the other PRs. I will caution, however, that WSL users may not be able to use Omnigen in that case. I suspect that #90 and #117 are both instances of this. So I'd recommend leaving that PR open for the time being for visibility.

You'll also need to take care to always synchronize before accessing any tensors moved via to("cpu", non_blocking=True). Unlike .to(cuda, non_blocking), there's no read time validation. If the data is accessed without explicit synchronization before the transfer completes, you may get garbage outputs.

@Rypo
Copy link
Author

Rypo commented Dec 2, 2024

@Pevernow I uploaded the weights to the hub (4bit, 8bit). They will not work out of the box yet. I have a very rough prototype that I'm working on cleaning up. It should be ready in the coming days.

Note: These links may change depending on the specifics of the final implementation. I'll update this message if so.


Update: they work out of the box with this PR now. See updated "Usage" section above.

Rypo added 2 commits December 2, 2024 16:25
Add a quantization utility for HFQuantizers.
Modify pipelines to accept quantization_config.
Sets ground work for allow bf16 vae.
Update requirements to include bitsandbytes.

closes VectorSpaceLab#45, closes VectorSpaceLab#64
@Rypo Rypo force-pushed the bnb_quantization branch from accf137 to 8ea2d6d Compare December 2, 2024 22:48
@nitinmukesh
Copy link

@Rypo

Looking forward to integration of quantized weights. Thank you

@Rypo
Copy link
Author

Rypo commented Dec 6, 2024

Alright, I think it's in an acceptable state at this point. Barring any glaring issues I missed, I'm calling it a wrap on this PR.

New Changes Recap

  1. Fix RuntimeError: CUDA error: out of memory on CPU transfer (2/3) #150 was pulled out of the chain. Absolutely nothing on this branch touches a Stream or a non_blocking.
  2. Created HF repos for pre-quantized weights. gryan/OmniGen-v1-bnb-4bit and gryan/OmniGen-v1-bnb-8bit.
  3. Supports loading pre-quantized weights straight from the hub! See updated Usage section above or either of the HF repos.

Final Remarks

  • I focused my changes exclusively on inference. The only thing I did with regard to training is run the LoRA fine-tuning examples to make sure I didn't break anything.

  • These changes should be considered transitory. If the ultimate goal is full-fledged integration with Diffusers, there's nothing I did that could not be done simpler/better by subclassing from diffusers/transformers. I mirrored the syntax wherever possible to hopefully ease future migration attempts.

  • To preview all 3 PRs merged into main you can install from my 'nightly' branch (pip install git+https://github.com/Rypo/OmniGen.git@nightly or replace 'nightly' with 'bnb_quantization' for just this PR).

    • Installing from this repo, I was able to run the double image example at 768x768 using the pre-quantized 4bit model on free-tier google colab. It was painfully slow, but it did run to completion.

If anyone finds an issue let me know, otherwise enjoy!


Update: I failed to call it a wrap. Colab painfully slow -> painfully reasonably slow. See notes on float16 update below.

@nitinmukesh
Copy link

@Rypo

Appreciate your efforts.

@Pevernow
Copy link

Pevernow commented Dec 7, 2024

{D780023B-C4AF-452A-ABB0-4D7AB41C6A20}
slow but work
thx!

Adds a small utility to scheduler to find the minimum clip bound to prevent NaNs from popping out of the decoder layers. Search over hardcoded buffer to discard as little information as possible.

Phi3Transformer now raises OverflowError when NaNs encountered.

Initialize model dtype based on actual weight value to avoid bad casts when quantized.
@Rypo
Copy link
Author

Rypo commented Dec 12, 2024

Update

Some good news for the GPU poor!

My most recent commit (d75af76) appears to have patched the float16 issue #108. Turns out the decoder layers were operating on values outside the bounds of what fp16 can handle causing numerical overflow. Luckily, clipping the values into an operable range doesn't seem to degrade the quality too much.

Why does float16 matter?

  • Older GPU may not support bfloat16
  • Colab free tier T4s only support emulated bf16. This results in significantly higher vram usage and slower computation.

The Goods

I uploaded fp16-compatible 4-bit weights to the hub: gryan/OmniGen-v1-fp16-bnb-4bit

With these weights, you can comfortably run double 1024x1024 images on free-tier Colab, without model offloading.

model = OmniGen.from_pretrained('gryan/OmniGen-v1-fp16-bnb-4bit', dtype=torch.float16)
pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1", model = model)

For maximum comfort

  • Install from my 'nightly' branch (pip install git+https://github.com/Rypo/OmniGen.git@nightly) to disable non_blocking CPU transfers.
  • It takes ~11 minutes for a double 1024x1024, you'll have a solid 6GB of vRAM and 5GB of RAM clearance on Colab free tier.

For maximum speed

  • Install from this PR or my 'bnb_quantization' branch
  • It takes ~9 minutes for a double, you'll have about the same vRAM clearance but you'll be dangerously close to exhausting your RAM, ~200mb clearance.

It's still not fast, but it's a decent step up from the 25-60 min it took for a double 768x768 previously. Enjoy!

Rypo added 2 commits December 16, 2024 19:11
Start search with minimal clipping value found through testing (2^16 - 3*32). This value was sufficient for all tested inputs. Further analysis still required to guarantee that it will always be  sufficient in all cases.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants