Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Quantization] Add quantization support for bitsandbytes #9213

Merged
merged 119 commits into from
Oct 21, 2024

Conversation

sayakpaul
Copy link
Member

@sayakpaul sayakpaul commented Aug 19, 2024

What does this PR do?

Come back later.

  • Quantization config class (base and bitsandbytes)
  • Quantizer class (base and bitsandbytes)
  • Utilities related to bitsandbytes
  • from_pretrained() at the ModelMixin level and related changes
  • save_pretrained()
  • NF4 tests
  • INT8 (llm.int8()) tests
  • Docs

Notes

  • Even though I alluded to having a separate QuantizationLoaderMixin in [Quantization] bring quantization to diffusers core #9174, I realized that is not an approach we can take because loading and saving a quantized model is very much baked into the arguments of ModelMixin.save_pretrained() and ModelMixin.from_pretrained(). It is deeply entangled.
  • For the initial quantization support, I think it's okay to not allow passing device_map, because for a pipeline, multiple device_maps can get ugly. This will be dealt with in a follow-up PR by @SunMarc and myself.
  • For the point above, for checkpoints that are found to be sharded (Flux, for example), I have decided to merge them on CPU to simplify the implementation. This will be dealt with in a follow-up PR by @SunMarc.
  • The PR has an extensive testing suite covering training, too. However, I have decided not to add it to our CI yet. We should first let this feature flow into the community and then add the tests to our nightly CI.

No-frills code snippets

Serialization
import torch 
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel, FluxPipeline
from accelerate.utils import compute_module_sizes

model_id = "black-forest-labs/FLUX.1-dev"

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model_nf4 = FluxTransformer2DModel.from_pretrained(
    model_id, subfolder="transformer", quantization_config=nf4_config, torch_dtype=torch.bfloat16
)
assert model_nf4.dtype == torch.uint8, model_nf4.dtype
print(model_nf4.dtype)
print(model_nf4.config.quantization_config)
print(compute_module_sizes(model_nf4)[""] / 1024 / 1024)

push_id = "sayakpaul/flux.1-dev-nf4-with-bnb-integration"
model_nf4.push_to_hub(push_id)

Serialized checkpoint: https://huggingface.co/sayakpaul/flux.1-dev-nf4-with-bnb-integration.

NF4 checkpoints of Flux transformer and T5: https://huggingface.co/sayakpaul/flux.1-dev-nf4-pkg (has Colab Notebooks, too).

Inference
import torch
from diffusers import FluxTransformer2DModel, FluxPipeline

model_id = "black-forest-labs/FLUX.1-dev"
nf4_id = "sayakpaul/flux.1-dev-nf4-with-bnb-integration"
model_nf4 = FluxTransformer2DModel.from_pretrained(nf4_id, torch_dtype=torch.bfloat16)
print(model_nf4.dtype)
print(model_nf4.config.quantization_config)

pipe = FluxPipeline.from_pretrained(model_id, transformer=model_nf4, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

prompt = "A mystic cat with a sign that says hello world!"
image = pipe(prompt, guidance_scale=3.5, num_inference_steps=50, generator=torch.manual_seed(0)).images[0]
image.save("flux-nf4-dev-loaded.png")

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Member

@SunMarc SunMarc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this ! I see that you used a lot of things from transformers. Do you think it is possible to import these (or inherit) from transformers ? This will help reducing the maintenance. I'm fine also doing that since there are not too many follow-up PR after a quantizer has been added. About the HfQuantizer class, there are a lot of methods that were created to fit transformers structure. I'm not sure we will need eveyone of these methods in diffusers. Ofc, we can still do a follow-up PR to clean up.

src/diffusers/quantizers/base.py Outdated Show resolved Hide resolved
@sayakpaul
Copy link
Member Author

sayakpaul commented Aug 20, 2024

@SunMarc I am guilty as charged but we don’t have transformers as a hard dependency for loading models in Diffusers. Pinging @DN6 to seek his opinion.

Update: Chatted with @DN6 as well. We think it's better to redefine inside diffusers without the transformers specific bits which we can clean in this PR.

@sayakpaul
Copy link
Member Author

@SunMarc I think this PR is ready for another review.

Copy link
Member

@SunMarc SunMarc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this @sayakpaul !

src/diffusers/quantizers/base.py Show resolved Hide resolved
Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it makes sense to have this as a separate PR to add a base class because it's hard to understand what methods are needed - we should only introduce a minimum base class and gradually add functionalities as needed

can we have a PR with a minimum example working?

@sayakpaul
Copy link
Member Author

sayakpaul commented Aug 22, 2024

Okay, so, do you want me to add everything needed for bitsandbytes integration in this PR? But do note that this won’t be very different from what we have in transformers.

@yiyixuxu
Copy link
Collaborator

yiyixuxu commented Aug 22, 2024

@sayakpaul
I think so because:

  1. it is better to review that way
  2. we don't need this class in diffusers on its own because it cannot be used yet, no?

@bghira
Copy link
Contributor

bghira commented Aug 22, 2024

sometimes we can make a feature branch where a bunch of PRs can be merged into before one big honkin' PR is pushed to main at the end. and the pieces are all individually reviewed and can be tested. is this a viable approach for including quantisation?

@sayakpaul
Copy link
Member Author

Okay I will update this branch. @yiyixuxu

@SunMarc
Copy link
Member

SunMarc commented Aug 23, 2024

cc @MekkCyber for visibility

@DN6
Copy link
Collaborator

DN6 commented Aug 28, 2024

Just a few considerations for the quantization design.

I would say the initial design should start loading/inference at just the model level and then proceed to add functionality (pipeline level loading etc).

The feature needs to perform the following functions

  1. Perform on the fly quantization of large models so that they can be loaded in a low-memory dtype
    1. with from_pretrained
    2. with from_single_file
  2. Dynamically upcast to the appropriate compute dtype when running inference
  3. Save/Load already quantized versions of these large models (FP8, NF4)
  4. Allow loading/inference with LoRAs in these quantized models. (This we have to figure out in more detail)

At the moment, the most common ask seems to be the ability to load models into GPU using the FP8 dtype and run inference in a supported dtype by dynamically upcasting the necessary layers. NF4 is another format that's gaining attention.

So perhaps we should focus on this first. This mostly applies to the DiT models but large models like CogVideo might also benefit with this approach.

Some example quantized versions of models that have been doing the rounds

To cover these initial cases, we can rely on Quanto (FP8) and BitsandBytes (NF4).

Example API:

from diffusers import FluxPipeline, FluxTransformer2DModel, DiffusersQuantoConfig

# Load model in FP8 with Quanto and perform compute in configured dtype. 

quantization_config = DiffusersQuantoConfig(weights="float8", compute_dtype=torch.bfloat16)

FluxTransformer2DModel.from_pretrained("<either diffusers format or quanto format weights>", quantization_config=quantization_config)

pipe = FluxPipeline.from_pretrained("...", transformer=transformer)

The quantization config should probably take the following arguments

DiffusersQuantoConfig(
	weights_dtype="", # dtype to store weights
	compute_dtype="", # dtype to perform inference
	skip_quantize_modules=["ResBlock"]
)

I think initially we can rely on the dynamic upcasting operations performed by Quanto and BnB under the hood to start and then expand on them if needed.

Some other considerations

  1. Since we have transformers models in diffusers that can also benefit from quantized loading, we might want to consider adding a Diffusers prefix to the quantization configs. e.g DiffusersQuantoConfig so that when we import quantization configs from transformers there aren't any conflicts.
  2. For saving and loading models we can start with models saved in Quanto/BnB format.
  3. One possible challenge with Pipeline level quantized loading is that we have a mix of transformers/diffusers models. So a single config to quantize/load both types might not be possible.
  4. Single file loading has it's own set of issues, such as dealing with checkpoints that have been naively quantized. This applies to some of the Flux single file checkpoints. e.g. safetensors.torch.save_file(model.to(torch.float8_e4m3fn), "model-fp8.safetensors) and loading full pipeline single file checkpoints. But we can address these later.

@sayakpaul
Copy link
Member Author

sayakpaul commented Aug 28, 2024

This PR will be at the model-level itself. And we should not add multiple backends in a single PR. This PR aims to add bitsandbytes. We can do other backends taking this PR as a reference. I would like us to mutually agree on this before I start making progress on this PR.

Concretely, I would like to stick to the outline of the changes laid out in #9174 (along with anything related) for this PR.

The feature needs to perform the following functions

I won't advocate doing all of that in a single PR because it makes things very hard to review. We would rather want to move faster with something more minimal, confirming their effectiveness.

Allow loading/inference with LoRAs in these quantized models. (This we have to figure out in more detail)

Well, note that if the underlying LoRA wasn't trained with the base quantization precision, it might not perform as expected.

So perhaps we should focus on this first. This mostly applies to the DiT models but large models like CogVideo might also benefit with this approach.

Please note that bitsandbytes related quantization mostly applies to nn.linear whereas quanto is broader in their scopes (i.e, quanto can be applied to an nn.Conv2D as well).

@DN6
Copy link
Collaborator

DN6 commented Aug 28, 2024

This PR will be at the model-level itself. And we should not add multiple backends in a single PR. This PR aims to add bitsandbytes. We can do other backends taking this PR as a reference. I would like us to mutually agree on this before I start making progress on this PR.

Sounds good to me.

For this PR lets do

  1. from_pretrained only
  2. bnb quantization.

@sayakpaul
Copy link
Member Author

Very insightful comments, @yiyixuxu! I think I have resolved them all. LMK.

}


class DiffusersAutoQuantizationConfig:
Copy link
Collaborator

@DN6 DN6 Oct 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see this is similar to transformers, but I think the DiffusersAutoQuantConfig class is probably not needed.

This is just a simple mapping to a specific quantization config object. The from_pretrained method in the AutoQuantizer is just wrapping the AutoConfig from_pretrained.

I think we can just move these methods/logic directly into the AutoQuantizer.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is not a must-have, could do this in a follow-up PR.

src/diffusers/pipelines/pipeline_utils.py Outdated Show resolved Hide resolved
src/diffusers/pipelines/pipeline_utils.py Outdated Show resolved Hide resolved
src/diffusers/models/modeling_utils.py Show resolved Hide resolved
src/diffusers/models/modeling_utils.py Outdated Show resolved Hide resolved
@ariG23498
Copy link
Contributor

Hi folks!

Thanks for working on this. I was able to run the following script on this branch and generate images on my 8 gigs VRAM laptop

Screenshot from 2024-10-20 13-52-20

from diffusers import FluxPipeline, FluxTransformer2DModel
from transformers import T5EncoderModel
import torch
import gc


def flush():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()


def bytes_to_giga_bytes(bytes):
    return bytes / 1024 / 1024 / 1024


flush()

ckpt_id = "black-forest-labs/FLUX.1-dev"
ckpt_4bit_id = "sayakpaul/flux.1-dev-nf4-pkg"
prompt = "a billboard on highway with 'FLUX under 8' written on it"

text_encoder_2_4bit = T5EncoderModel.from_pretrained(
    ckpt_4bit_id,
    subfolder="text_encoder_2",
)

pipeline = FluxPipeline.from_pretrained(
    ckpt_id,
    text_encoder_2=text_encoder_2_4bit,
    transformer=None,
    vae=None,
    torch_dtype=torch.float16,
)
pipeline.enable_model_cpu_offload()


with torch.no_grad():
    print("Encoding prompts.")
    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
        prompt=prompt, prompt_2=None, max_sequence_length=256
    )


pipeline = pipeline.to("cpu")
del pipeline

flush()


transformer_4bit = FluxTransformer2DModel.from_pretrained(ckpt_4bit_id, subfolder="transformer")
pipeline = FluxPipeline.from_pretrained(
    ckpt_id,
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    transformer=transformer_4bit,
    torch_dtype=torch.float16,
)
pipeline.enable_model_cpu_offload()

print("Running denoising.")
height, width = 512, 768
images = pipeline(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    num_inference_steps=50,
    guidance_scale=5.5,
    height=height,
    width=width,
    output_type="pil",
).images
images[0].save("output.png")

output

Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's merge this!

I asked @DN6 to open a follow-up PR for this #9213 (comment),

@sayakpaul
Copy link
Member Author

PR merge contingent on #9720.

@sayakpaul sayakpaul merged commit b821f00 into main Oct 21, 2024
18 checks passed
@sayakpaul sayakpaul deleted the quantization-config branch October 21, 2024 04:42


@dataclass
class BitsAndBytesConfig(QuantizationConfigMixin):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something to consider. Let's assume you want to use a quantized transformer model in your code. With this naming, you would always need to set up imports in the following way.

from transformers import BitsAndBytesConfig
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig

Not a huge issue. Just giving a heads up incase you want to consider renaming the config to something like DiffusersBitsAndBytesConfig

set_module_kwargs["dtype"] = dtype

# bnb params are flattened.
if not is_quant_method_bnb and empty_state_dict[param_name].shape != param.shape:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this situation, aren't we skipping parameter shape checks for bnb loaded weights entirely? What happens when one attempts to load bnb weights but the flattened shape is incorrect?

Perhaps we add a check_quantized_param_shape method to the DiffusersQuantizer base class. And in the BnBQuantizer we can check if the shape matches the rule here:
https://github.com/bitsandbytes-foundation/bitsandbytes/blob/18e827d666fa2b70a12d539ccedc17aa51b2c97c/bitsandbytes/functional.py#L816

Comment on lines +220 to +229
if not is_quantized or (
not hf_quantizer.check_quantized_param(model, param, param_name, state_dict, param_device=device)
):
if accepts_dtype:
set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
else:
set_module_tensor_to_device(model, param_name, device, value=param)
else:
set_module_tensor_to_device(model, param_name, device, value=param)
hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small nit. IMO this is a bit more readable

        if is_quantized or hf_quantizer.check_quantized_param(
            model, param, param_name, state_dict, param_device=device
        ):
            hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
        else:
            if accepts_dtype:
                set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
            else:
                set_module_tensor_to_device(model, param_name, device, value=param)

"""adjust max_memory argument for infer_auto_device_map() if extra memory is needed for quantization"""
return max_memory

def check_quantized_param(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO check_is_quantized_param or check_if_quantized_param more explicitly conveys what this method does.



class BnB4BitBasicTests(Base4bitTests):
def setUp(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would clear cache on setup as well.

@Ednaordinary
Copy link

It would be useful to rename llm_int8_skip_modules or otherwise make it more clear that it is respected in both 4bit and 8bit mode, as currently the docs sound like skipped modules are only respected in 8 bit mode while the actual implementation suggests otherwise

if self.quantization_config.llm_int8_skip_modules is not None:

image

@sayakpaul
Copy link
Member Author

Yeah I think the documentation should reflect this. I guess this is safe to do @SunMarc?

@SunMarc
Copy link
Member

SunMarc commented Nov 4, 2024

Yeah we should do that, would you like to update this @Ednaordinary ? We should also do it in transformers when it gets merged.

@Ednaordinary
Copy link

Sure, @SunMarc. I'll make a PR when I'm able. Should I refactor the parameter name and include a deprecation notice, or just include a note in the docs?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.