
Conversation

@sayakpaul
Member

@sayakpaul sayakpaul commented Aug 21, 2025

What does this PR do?

Caution

Doesn't work yet.

Test code:

from diffusers import DiffusionPipeline, AutoModel, NunchakuConfig
import torch 

ckpt_id = "black-forest-labs/FLUX.1-dev"
model = AutoModel.from_pretrained(
    ckpt_id, 
    subfolder="transformer",
    torch_dtype=torch.bfloat16, 
    quantization_config=NunchakuConfig()
)
pipe = DiffusionPipeline.from_pretrained(
    ckpt_id, transformer=model, torch_dtype=torch.bfloat16
)
image = pipe(
    "A cat holding a sign that says hello world", 
    num_inference_steps=50, 
    guidance_scale=3.5,
    generator=torch.manual_seed(0),
).images[0]
image.save("nunchaku.png")

diffusers-cli env:

- 🤗 Diffusers version: 0.36.0.dev0
- Platform: Linux-6.8.0-55-generic-x86_64-with-glibc2.39
- Running on Google Colab?: No
- Python version: 3.10.12
- PyTorch version (GPU?): 2.8.0.dev20250626+cu126 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.34.4
- Transformers version: 4.53.2
- Accelerate version: 1.10.0.dev0
- PEFT version: 0.17.0
- Bitsandbytes version: 0.46.0
- Safetensors version: 0.5.3
- xFormers version: not installed
- Accelerator: NVIDIA GeForce RTX 4090, 24564 MiB
NVIDIA GeForce RTX 4090, 24564 MiB
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>

@lmxyy I am going to outline the stage we're currently at in this integration, as that will help us better understand the blockers.

Cc: @SunMarc

@sayakpaul
Member Author

Let me outline the stage we're currently at, as this will help us understand the current blockers:

This is for quantizing a pre-trained, non-quantized model checkpoint, as opposed to trying to directly load an already-quantized checkpoint.

If you have suggestions, please LMK.

@sayakpaul
Member Author

Discussed with @SunMarc internally. We will also first try to support loading pre-quantized checkpoints from https://huggingface.co/nunchaku-tech/nunchaku and see how it goes.

@sayakpaul
Member Author

Tried loading pre-quantized checkpoints a bit. The current issues are:

  • The pre-quantized checkpoint (example) has mlp_fc* keys, which aren't present in our Flux implementation. This needs to be accounted for.
  • It uses horizontal fusion for the attention projections in the checkpoints -- something our implementation doesn't support yet. This will also need to be accounted for.
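Roughly, the remapping would have to look something like the sketch below. The diffusers-side target names (ff.net.0.proj, ff.net.2) and the fused projection name (to_qkv) are assumptions on my part, and packed int4 qweights can't be split with a plain torch.chunk, so the fused quantized tensors would need to be unfused in their packed layout instead:

import torch

def remap_nunchaku_flux_keys(state_dict: dict) -> dict:
    remapped = {}
    for key, value in state_dict.items():
        if ".mlp_fc1." in key:
            # nunchaku's first MLP projection -> diffusers' feed-forward input proj (name assumed)
            remapped[key.replace(".mlp_fc1.", ".ff.net.0.proj.")] = value
        elif ".mlp_fc2." in key:
            # nunchaku's second MLP projection -> diffusers' feed-forward output proj (name assumed)
            remapped[key.replace(".mlp_fc2.", ".ff.net.2.")] = value
        elif key.endswith(".to_qkv.weight"):
            # Undo horizontal fusion: the fused projection stacks q, k, v along dim 0.
            # Only valid for plain bf16/fp16 tensors, not for packed int4 qweights.
            q, k, v = torch.chunk(value, 3, dim=0)
            prefix = key[: -len(".to_qkv.weight")]
            remapped[f"{prefix}.to_q.weight"] = q
            remapped[f"{prefix}.to_k.weight"] = k
            remapped[f"{prefix}.to_v.weight"] = v
        else:
            remapped[key] = value
    return remapped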
Code
from diffusers import DiffusionPipeline, FluxTransformer2DModel, NunchakuConfig
from nunchaku.models.linear import SVDQW4A4Linear
from safetensors import safe_open
from huggingface_hub import hf_hub_download
import torch 


def modules_without_qweight(safetensors_path: str):
    """Collect module names whose weights are stored as plain `.weight` tensors
    (i.e., not quantized into `.qweight`); these are passed as `modules_to_not_convert`."""
    no_qweight = set()
    with safe_open(safetensors_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            if key.endswith(".weight"):
                # The module name is everything except the last piece after "."
                module_name = ".".join(key.split(".")[:-1])
                no_qweight.add(module_name)
    return sorted(no_qweight)

ckpt_id = "black-forest-labs/FLUX.1-dev"
state_dict_path = hf_hub_download(repo_id="nunchaku-tech/nunchaku-flux.1-dev", filename="svdq-int4_r32-flux.1-dev.safetensors")
modules_to_not_convert = modules_without_qweight(state_dict_path)
# print(f"{modules_to_not_convert=}")

model = FluxTransformer2DModel.from_single_file(
    state_dict_path,
    config=ckpt_id, 
    subfolder="transformer",
    torch_dtype=torch.bfloat16, 
    quantization_config=NunchakuConfig(
        weight_dtype="int4",
        weight_group_size=64,
        activation_dtype="int4",
        activation_group_size=64,
        modules_to_not_convert=modules_to_not_convert
    )
).to("cuda")
has_svd = any(isinstance(module, SVDQW4A4Linear) for _, module in model.named_modules())
assert has_svd

pipe = DiffusionPipeline.from_pretrained(
    ckpt_id, transformer=model, torch_dtype=torch.bfloat16
).to("cuda")
image = pipe(
    "A cat holding a sign that says hello world", 
    num_inference_steps=50, 
    guidance_scale=3.5,
    generator=torch.manual_seed(0),
).images[0]
image.save("nunchaku.png")

Cc: @SunMarc

@lmxyy
Contributor

lmxyy commented Aug 22, 2025

The conversion can be found here: https://github.com/nunchaku-tech/nunchaku/blob/3ec299f439f9986a69ded320798cab4e258c871d/nunchaku/models/transformers/transformer_flux_v2.py#L395

@dxqb

dxqb commented Sep 25, 2025

* However, there doesn't seem to be a method in `nunchaku` that can quantize a pre-trained parameter. This is the current blocker. So, simply doing the following isn't supposed to work as expected:
  https://github.com/huggingface/diffusers/blob/nunchaku/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py#L110-L136

I raised this issue with them in nunchaku-tech/nunchaku#687, and they referred to https://github.com/nunchaku-tech/deepcompressor as having the quantization helper functions.

I have not looked into it further because there were more blocking issues, such as memory corruptions, and I am unsure whether AWQW4A16Linear is currently a general-purpose tool that can be used outside of nunchaku.
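For reference, a quantize-on-load helper would essentially need to do the SVDQuant decomposition itself: split a pre-trained weight into a low-rank branch plus a 4-bit residual. Below is a minimal sketch of just that math, assuming symmetric per-group int4, a rank of 32, and an input dimension divisible by the group size -- it does not reflect nunchaku's packed kernel layout or the actual deepcompressor APIs:

import torch

def svdquant_sketch(weight: torch.Tensor, rank: int = 32, group_size: int = 64):
    # Low-rank branch absorbs the dominant (outlier-heavy) directions of the weight.
    U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
    proj_down = Vh[:rank, :]              # (rank, in_features)
    proj_up = U[:, :rank] * S[:rank]      # (out_features, rank)
    residual = weight.float() - proj_up @ proj_down

    # Symmetric per-group int4 quantization of the residual.
    out_features, in_features = residual.shape
    groups = residual.reshape(out_features, in_features // group_size, group_size)
    scales = groups.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8) / 7.0
    q = torch.clamp(torch.round(groups / scales), -8, 7).to(torch.int8)

    # What the kernel would effectively compute at runtime: low-rank + dequantized residual.
    dequantized = (q.float() * scales).reshape(out_features, in_features)
    return proj_up, proj_down, q, scales, proj_up @ proj_down + dequantized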
