Add support for quantized Conv2d #74

Closed · dacorvo opened this issue Jan 22, 2024 · 15 comments

dacorvo (Collaborator) commented Jan 22, 2024

This layer is required for all computer vision models.

dacorvo self-assigned this on Jan 22, 2024

dacorvo (Collaborator, Author) commented Jan 22, 2024

A first implementation here: https://github.com/huggingface/quanto/tree/qconv2d
The gradient test needs to be fixed (tricky weight gradient calculation), but otherwise it seems to work as expected.

sayakpaul (Member) commented

@dacorvo I have some findings for you (cf. internal thread).

To set the context first: unlike transformers models, diffusers deals with "pipelines" (e.g., DiffusionPipeline) that can contain multiple nn.Modules.

My setup is our internal audace machine (two 4090s, though I used only one by setting CUDA_VISIBLE_DEVICES=0). I am on a PyTorch nightly with CUDA 12.1. I installed quanto from https://github.com/huggingface/quanto/tree/qconv2d.

I tried applying quantize() to the UNet and the VAE and then benchmarked the memory and timing. Here's my script.
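Since the script itself is only linked, here is a rough sketch of the kind of measurement it performs (the prompt, step count, and helper name are illustrative, not the actual script):

import time
import torch

def benchmark(pipeline, prompt="a photo of an astronaut riding a horse"):
    # reset the peak-memory counter so we only measure this call
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.perf_counter()
    _ = pipeline(prompt, num_inference_steps=30, generator=torch.manual_seed(2024))
    torch.cuda.synchronize()
    latency = time.perf_counter() - start
    memory_gb = torch.cuda.max_memory_allocated() / 1024**3
    return latency, memory_gb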

Timing and memory:

| setting | timing (secs) | memory (GB) |
|---|---|---|
| vanilla fp16 | 4.016 | 8.958 |
| weight-only | 5.769 | 8.952 |
| weight & activations | 14.046 | 16.845 |

For the "weight & activations" setting, I changed the load_pipeline() function in the script like so, because float16 is not supported, as mentioned here:

Modified pipeline loading code

import torch
from diffusers import AutoencoderKL, DiffusionPipeline
from quanto import quantize

def load_pipeline(do_quantize):
    # note: unlike the fp16 baseline, no torch_dtype is passed here (float16 is not supported for this setting)
    pipeline = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        variant="fp16",
    ).to("cuda")
    pipeline.vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix").to("cuda")

    if do_quantize:
        quantize(pipeline.unet, weights=torch.int8, activations=torch.int8)
        quantize(pipeline.vae, weights=torch.int8, activations=torch.int8)

    pipeline.set_progress_bar_config(disable=True)
    return pipeline

The image is just random noise with "weight & activations". Maybe it needs calibration but I didn't try it out. Here's a comparison:

Visual comparison (images omitted): vanilla fp16 | weight-only | weight & activations

Note that I tried disabling activation quantization for the VAE, but that didn't help prevent the issue, either.

Now, I shifted gears to the UNet to test a single model component. Here is my script. Findings are below.

| setting | timing (secs) | memory (GB) |
|---|---|---|
| vanilla fp16 | 0.069 | 5.053 |
| weight-only | 0.170 | 5.094 |
| weight & activations | 5.094 | 10.032 |

Let me know if anything is unclear. Also, let me know if you would like me to run other tests.

@dacorvo
Copy link
Collaborator Author

dacorvo commented Jan 29, 2024

@sayakpaul Thank you very much for your tests and feedback.

The results are as expected, but you missed two important things in your tests (I will update the README to make it clearer).

  1. When quantizing activations, you need to calibrate the model with a few samples to find the correct activation ranges (otherwise [-1, 1] is assumed). This explains the garbage output you currently get. You just need to pass a few samples through the forward path:

with calibration():
    model(samples)

  2. Weights are dynamically quantized by default (i.e., stored in float and quantized at inference time) to allow fine-tuning. If you only want to run inference, you can freeze your model to convert them to integer. This saves on-device memory and significantly speeds up inference:

freeze(model)
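Putting the two together for one pipeline component, the workflow is roughly the following (a sketch combining the two snippets above; `pipeline` and `calibration_prompts` are assumed to exist on your side):

quantize(pipeline.unet, weights=torch.int8, activations=torch.int8)
with calibration():
    # pass a few representative prompts through the forward path
    pipeline(calibration_prompts, num_inference_steps=10)
freeze(pipeline.unet)
# subsequent inference runs with integer weights and calibrated activation ranges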

I would be happy to see the results you get after applying these two changes.

sayakpaul (Member) commented

Thanks much! Let me get back to you after applying these changes.

sayakpaul (Member) commented

Alright. Things seem to be better now (script here):

Timing and memory:

| setting | timing (secs) | memory (GB) |
|---|---|---|
| vanilla fp16 | 4.016 | 8.958 |
| weight-only | 4.614 | 6.729 |
| weight & activations | 11.582 | 13.375 |

Observations / questions:

  • Not sure if the timing for weight-only is expected.
  • Can the increased timing for "weight & activations" be attributed to the int8 kernels?
  • Why has the memory increased for "weight & activations"? Because we're also calibrating?

Coming to the visual quality:

Visual comparison (images omitted): vanilla fp16 | weight-only | weight & activations

I think I am not doing the calibration properly, which is why the quality is still degraded. What would you advise?

dacorvo (Collaborator, Author) commented Jan 29, 2024

@sayakpaul this is indeed much better.

Regarding the activations:

  • quality might improve by calibrating with more samples (I typically use a batch of at least 128 samples),
  • it may also be that you actually have outliers in your activations: try float8 activations (torch.float8_e4m3fn) to see if that makes a difference (see the snippet after this list),
  • I suspect the increased timing is due to the dynamic quantization of the activations, which is not compensated by the int8 matmul speedup. Maybe the matrices are too small?
  • the increased memory is probably a bug: I might be keeping references to temporary activations somewhere in the graph.
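For the float8 experiment, only the activations dtype changes in the quantize call (a sketch based on the call used earlier in this thread):

quantize(pipeline.unet, weights=torch.int8, activations=torch.float8_e4m3fn)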

sayakpaul (Member) commented

Using float8 seems to immediately improve the results:

(image: output with float8 activations)

Timing-wise, we're much worse now: 15.500 secs.

I applied more calibration to the earlier setup (non-float8) as well:

(image: output with more calibration, int8 activations)

With calibration and float8:

(image: output with more calibration and float8 activations)

I suspect this will be improved with a bit more calibration. WDYT?

Relevant calibration code

import torch
from datasets import load_dataset
from quanto import Calibration

# `args` and `pipeline` come from earlier in the script
CALIBRATION_PROMPTS = load_dataset("nateraw/parti-prompts", split="train").shuffle(seed=2024).select(range(100))

if args.act_quant:
    chunk_size = 2
    print("Calibrating")
    with Calibration():
        # feed the prompts in small chunks to keep memory in check
        for i in range(0, len(CALIBRATION_PROMPTS), chunk_size):
            _ = pipeline(
                CALIBRATION_PROMPTS[i: i + chunk_size]["Prompt"],
                num_inference_steps=10,
                generator=torch.manual_seed(2024),
            )
    print("Calibration done.")

> I suspect the increased timing is due to the dynamic quantization of the activations, which is not compensated by the int8 matmul speedup. Maybe the matrices are too small?

You may be correct. Actually, this is hinted at in https://pytorch.org/blog/accelerating-generative-ai-3/ as well. Let me try out some ideas from there to ensure we're not applying int8 to matrix multiplications that involve small matrices. I will update this issue thread after running a couple of experiments. Will this apply to float8 weight-activation quantization, too? I think yes.

Let me know if you have specific experiments for me to run.

dacorvo (Collaborator, Author) commented Jan 29, 2024

This is super interesting! First, it confirms that float8 is probably the best candidate type for quantizing activations to 8-bit, thanks to its non-linear representation and wider range.

Unfortunately, as you noticed, most hardware has no native float8 support, so it is slower. I might improve this in the future by using specific kernels.

You can pinpoint which specific parts of a model you want to quantize by passing a list of modules to quantize.

Finally, about calibration: you can try adding more samples, but I think it is more important to pass them as batches, so that they are averaged. If you pass them individually, you give too much importance to the first sample (each subsequent sample only contributes (1 - momentum) to the moving average).
If you really need to pass them one by one, you can use a lower momentum in the moving average, but I am not sure it will be better.
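For reference, a moving-average update of that form behaves like this (a generic exponential-moving-average sketch with made-up values, not the exact quanto code):

# per-sample max(abs(activation)) observations, made up for illustration
observed_ranges = [1.0, 3.0, 2.5, 2.8]
momentum = 0.9
running_range = observed_ranges[0]
for r in observed_ranges[1:]:
    # each new sample only contributes (1 - momentum) = 0.1,
    # so the first sample dominates when samples are fed one by one
    running_range = momentum * running_range + (1 - momentum) * r
print(running_range)  # ~1.48, still close to the first value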

sayakpaul (Member) commented

> You can pinpoint which specific parts of a model you want to quantize by passing a list of modules to quantize.

Could you provide an example here?

Regarding calibration batches, let's not forget we're in the image-generation space :D. Asking a model to generate four 1024x1024 images at once is simply infeasible on most consumer GPUs.

So, it appears to me that quantization still remains challenging in the diffusion world.

I will try to selectively quantize the modules with larger matrix shapes, as we did in the PyTorch post (weight-only quant), and share my findings here. A sketch of the kind of selection I have in mind is below.
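Following the PyTorch post's idea of skipping small matmuls, the selection could look roughly like this (the threshold is arbitrary, and combining the modules argument with weights in a single quantize call is an assumption on my side):

import torch
import torch.nn as nn
from quanto import quantize

def large_linears(model, min_features=2048):
    # keep only the Linear layers whose matrices are large enough for the
    # int8 matmul speed-up to outweigh the (de)quantization overhead
    return [
        m for m in model.modules()
        if isinstance(m, nn.Linear) and min(m.in_features, m.out_features) >= min_features
    ]

# `pipeline` loaded as in load_pipeline() above
quantize(pipeline.unet, weights=torch.int8, modules=large_linears(pipeline.unet))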

dacorvo (Collaborator, Author) commented Jan 29, 2024

Example of selecting modules:

quantize(model, modules=[model.lm_head])

sayakpaul (Member) commented

I enabled fusion of the QKV projection matrices to increase the matrix sizes and see if the quantization speed-up becomes more evident. However, I am running into the following error:

Traceback (most recent call last):
  File "/home/sayak/brrr_quanto_diffusers.py", line 72, in <module>
    _ = pipeline(
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/sayak/diffusers/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py", line 1216, in __call__
    noise_pred = self.unet(
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/sayak/diffusers/src/diffusers/models/unets/unet_2d_condition.py", line 1121, in forward
    sample, res_samples = downsample_block(
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/sayak/diffusers/src/diffusers/models/unets/unet_2d_blocks.py", line 1199, in forward
    hidden_states = attn(
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/sayak/diffusers/src/diffusers/models/transformers/transformer_2d.py", line 391, in forward
    hidden_states = block(
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/sayak/diffusers/src/diffusers/models/attention.py", line 329, in forward
    attn_output = self.attn1(
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/sayak/diffusers/src/diffusers/models/attention_processor.py", line 512, in forward
    return self.processor(
  File "/home/sayak/diffusers/src/diffusers/models/attention_processor.py", line 1335, in __call__
    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
  File "/home/sayak/quanto/quanto/calibrate.py", line 49, in __torch_function__
    output = func(*args, **kwargs)
  File "/home/sayak/quanto/quanto/tensor/core.py", line 291, in __torch_function__
    return func(*args, **kwargs)
  File "/home/sayak/quanto/quanto/tensor/core.py", line 309, in __torch_dispatch__
    return qdispatch.qop(*args, **kwargs)
  File "/home/sayak/quanto/quanto/tensor/ops.py", line 317, in view
    return qfallback(op, input, *shape)
  File "/home/sayak/quanto/quanto/tensor/core.py", line 160, in qfallback
    args, kwargs = pytree.tree_map_only(QTensor, lambda x: x.dequantize(), (args, kwargs or {}))
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/utils/_pytree.py", line 858, in tree_map_only
    return tree_map(map_only(__type_or_types)(func), tree)
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/utils/_pytree.py", line 732, in tree_map
    return treespec.unflatten(map(func, *flat_args))
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/utils/_pytree.py", line 599, in unflatten
    leaves = list(leaves)
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/utils/_pytree.py", line 818, in wrapped
    return func(x)
  File "/home/sayak/quanto/quanto/tensor/core.py", line 160, in <lambda>
    args, kwargs = pytree.tree_map_only(QTensor, lambda x: x.dequantize(), (args, kwargs or {}))
  File "/home/sayak/quanto/quanto/tensor/core.py", line 259, in dequantize
    return Dequantizer.apply(self)
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/sayak/quanto/quanto/tensor/core.py", line 202, in forward
    return (t._scale.to(torch.float32) * t._data).to(t._scale.dtype)
RuntimeError: The size of tensor a (1920) must match the size of tensor b (640) at non-singleton dimension 2

Is this expected?

dacorvo (Collaborator, Author) commented Jan 30, 2024

Not really: basically, the Q, K, and V tensors were quantized per-axis/per-channel with vector scales of shape (640,) corresponding to their last dim, but were later concatenated into a tensor whose last dim is 1920 = 640 x 3 (I think). This should also have triggered a concatenation of the scales, but it didn't, because the corresponding dispatched op for quantized tensors has not been updated since older versions of quanto and still assumes scalar scales.
https://github.com/huggingface/quanto/blob/5c2a4114eab4dc37208fbc9fff35143453a2017e/quanto/tensor/ops.py#L89
Sorry you hit that bug ...
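Conceptually, the shapes involved look like this (plain torch, not quanto internals):

import torch

# per-axis quantization keeps one scale per output channel
q_scale = torch.rand(640)
k_scale = torch.rand(640)
v_scale = torch.rand(640)

# fusing QKV concatenates the projections along the channel dim (640 x 3 = 1920),
# so the scales must be concatenated too; an op that still assumes scalar scales
# ends up multiplying (..., 1920) data by a (640,)-shaped scale and fails as above
qkv_scale = torch.cat([q_scale, k_scale, v_scale])
print(qkv_scale.shape)  # torch.Size([1920])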

sayakpaul (Member) commented

Ah cool. Seems like I should wait for a fix?

sayakpaul (Member) commented

This is my pipeline loading function, btw:

import torch
from diffusers import AutoencoderKL, DiffusionPipeline
from quanto import freeze, quantize

def load_pipeline(do_quantize, act_quant):
    # load in float32 when quantizing activations, since float16 is not supported there
    pipeline = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        variant="fp16",
        torch_dtype=torch.float32 if act_quant else torch.float16,
    ).to("cuda")
    pipeline.vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float32 if act_quant else torch.float16
    ).to("cuda")
    pipeline.fuse_qkv_projections()

    if do_quantize:
        quantize(pipeline.unet, weights=torch.int8, activations=torch.float8_e4m3fn if act_quant else None)
        quantize(pipeline.vae, weights=torch.int8, activations=torch.float8_e4m3fn if act_quant else None)

        freeze(pipeline.unet)
        freeze(pipeline.vae)

    pipeline.set_progress_bar_config(disable=True)
    return pipeline

I don't think changing the order in which fuse_qkv_projections() is called would make sense, would it?

dacorvo (Collaborator, Author) commented Feb 20, 2024

Closing, as this has been merged in #91.

dacorvo closed this as completed on Feb 20, 2024