Add support for quantized Conv2d #74
This layer is required for all computer vision models.
A first implementation here: https://github.com/huggingface/quanto/tree/qconv2d
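For context on how the new layer is exercised from the user side, here is a minimal weight-only sketch using the `quantize`/`freeze` API that appears later in this thread (the toy model below is an illustration, not code from the branch):

```python
import torch
import torch.nn as nn
from quanto import freeze, quantize  # package later renamed to optimum-quanto

# Toy CNN with Conv2d layers (illustrative; not a model from this thread).
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
)

# Weight-only int8 quantization: Conv2d (and Linear) modules are swapped for
# quantized equivalents; freeze() then materializes the int8 weights.
quantize(model, weights=torch.int8)
freeze(model)

with torch.no_grad():
    print(model(torch.randn(1, 3, 64, 64)).shape)  # torch.Size([1, 32, 64, 64])
```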
@dacorvo I have got some findings for you (cf. internal thread). To set the context first: my setup is our internal SDXL benchmarking script, and I tried applying quantization in both the "weight-only" and "weight & activations" settings.

Timing and memory:
For the "weight & activations" setting, I changed the Modified pipeline loading codedef load_pipeline(do_quantize):
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
variant="fp16",
).to("cuda")
pipeline.vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix").to("cuda")
if do_quantize:
quantize(pipeline.unet, weights=torch.int8, activations=torch.int8)
quantize(pipeline.vae, weights=torch.int8, activations=torch.int8)
pipeline.set_progress_bar_config(disable=True)
return pipeline The image is just random noise with "weight & activations". Maybe it needs calibration but I didn't try it out. Here's a comparison: Note that I tried disabling activation quantization for the VAE, but that didn't help prevent the issue, either. Now, I shifted gears to the UNet to test a single model component. Here is my script. Findings are below.
Let me know if anything is unclear. Also, let me know if you would like me to run other tests.
@sayakpaul Thank you very much for your tests and feedback. The results are as expected, but you missed two important things in your tests (I will update the README to make it clearer):

- When quantizing activations, the model needs to be calibrated so that the activation scales are recorded. You just need to pass a few samples through forward:

```python
with calibration():
    model(samples)
```

- After quantizing (and calibrating), the model needs to be frozen to actually convert its weights to quantized tensors:

```python
freeze(model)
```

I would be happy to find out the results you get after applying these two changes.
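Putting the two changes together, a minimal end-to-end sketch on a toy model (the model, sample shapes, and the exact `Calibration` import are assumptions; the thread applies the same steps to the SDXL UNet and VAE):

```python
import torch
import torch.nn as nn
from quanto import Calibration, freeze, quantize  # spelled `calibration` in some quanto versions

# Toy model standing in for the UNet/VAE used in the thread.
model = nn.Sequential(nn.Linear(64, 256), nn.ReLU(), nn.Linear(256, 64))

# 1. Swap Linear/Conv2d modules for quantized versions (weights and activations).
quantize(model, weights=torch.int8, activations=torch.int8)

# 2. Record activation scales by running a few representative samples.
samples = torch.randn(16, 64)
with Calibration():
    model(samples)

# 3. Freeze to materialize the int8 weights before inference.
freeze(model)
```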
Thanks much! Let me get back to you after applying these changes.
Alright. Things seem to be better now (script here).

Timing and memory:
Observations / questions:
Coming to the visual quality:
I think I am not doing the calibration properly, which is why the quality is still degraded. What would you advise?
@sayakpaul this is indeed much better. Regarding the activations:
Using float8 activations, timing-wise we're much worse now: 15.500 secs.

I applied more calibration to the earlier setup (non-float8) as well. With calibration:

I suspect this will be improved with a bit more calibration. WDYT?

Relevant calibration code:

```python
import torch
from datasets import load_dataset

# `args`, `pipeline`, and `Calibration` (from quanto) are defined/imported elsewhere in the script.
CALIBRATION_PROMPTS = load_dataset("nateraw/parti-prompts", split="train").shuffle(seed=2024).select(range(100))

if args.act_quant:
    chunk_size = 2
    print("Calibrating")
    with Calibration():
        for i in range(0, len(CALIBRATION_PROMPTS), chunk_size):
            _ = pipeline(
                CALIBRATION_PROMPTS[i: i + chunk_size]["Prompt"],
                num_inference_steps=10,
                generator=torch.manual_seed(2024),
            )
    print("Calibration done.")
```
You may be correct. Actually, this is hinted at in https://pytorch.org/blog/accelerating-generative-ai-3/ as well. Let me try out some stuff from there to ensure we're not applying int8 to matrix multiplications that involve small matrices. Will update this issue thread after running a couple of experiments. Will this apply to float8 weight-activation quantization, too? I think yes. Let me know if you have specific experiments for me to run.
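As a starting point for that, here is a quick way to see which linear layers in the UNet are small; the size threshold and the use of `pipeline.unet` are my assumptions, not code from the thread:

```python
from collections import Counter

import torch

# Count UNet Linear layers by (in_features, out_features) to see how many
# matmuls are small enough that int8 overhead may outweigh the gain.
# `pipeline` is the SDXL pipeline loaded earlier in the thread.
shape_counts = Counter(
    (m.in_features, m.out_features)
    for m in pipeline.unet.modules()
    if isinstance(m, torch.nn.Linear)
)
for (fan_in, fan_out), n in sorted(shape_counts.items()):
    tag = "small" if min(fan_in, fan_out) < 1024 else "large"  # illustrative threshold
    print(f"{n:4d} x Linear({fan_in}, {fan_out})  [{tag}]")
```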
This is super interesting! First, it confirms that …

Unfortunately, as you noticed, there is no hardware support for float8, which is why quantizing the activations to float8 is slower rather than faster.

You can pinpoint which specific parts of a model you want to quantize by passing a list of modules to quantize.

Finally, about calibration: you can try adding more samples, but I think it is more important to pass them as batches, so that they are averaged. If you pass them individually, you give too much importance to the first sample (the next samples always contribute only as (1 - momentum) to the moving average).
Could you provide an example here? Regarding calibration batches, let's not forget we're in the image-generation space :D. Asking a model to generate four 1024x1024 images at once is just infeasible for most consumer GPUs. So, it appears to me that quantization still remains a challenging thing in the diffusion world. I will try to selectively quantize some modules where the matrix shapes are larger, as we did in the PyTorch post (weight-only quant), and share my findings here.
Example of selecting modules:
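A minimal sketch of that kind of selective quantization, assuming `quantize()` and `freeze()` can be applied to arbitrary sub-modules; the block choice below is my own illustration, not the example originally given:

```python
import torch
from quanto import freeze, quantize  # package later renamed to optimum-quanto

# Quantize only the transformer-heavy blocks of the SDXL UNet, leaving the
# smaller layers alone so int8 overhead does not dominate their matmuls.
# `pipeline` is the SDXL pipeline loaded earlier in the thread.
for block in [pipeline.unet.mid_block, *pipeline.unet.down_blocks[-2:]]:
    quantize(block, weights=torch.int8)
freeze(pipeline.unet)
```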
I enabled fusion of the QKV projection matrices to increase the matrix sizes and see whether the quantization speed-up becomes more evident. However, I am running into:
Is this expected?
Not really: basically, the Q, K and V projections were quantized per-axis/channel with vector scales of shape (640) corresponding to their last dim, but were later concatenated into a tensor whose last dim is (1920) = (640) x 3 (I think). This should also have triggered a concatenation of the scales, but it didn't, because the corresponding dispatched op for quantized tensors has not been updated since older versions of quanto and still assumes scalar scales.
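To illustrate the failure mode in plain PyTorch (an illustration, not quanto's actual tensor layout or dispatch code): per-channel quantization keeps one scale per channel, so fusing Q/K/V has to concatenate the scale vectors along with the int8 weights.

```python
import torch

def quantize_per_channel(w: torch.Tensor):
    # One int8 scale per output channel (row) of the weight matrix.
    scale = w.abs().amax(dim=1, keepdim=True) / 127.0
    q = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
    return q, scale

# Separate Q, K, V projection weights of an attention layer (sizes illustrative).
wq, wk, wv = (torch.randn(640, 640) for _ in range(3))
(qq, sq), (qk, sk), (qv, sv) = (quantize_per_channel(w) for w in (wq, wk, wv))

# Fusing QKV concatenates the int8 weights into a (1920, 640) tensor...
q_fused = torch.cat([qq, qk, qv], dim=0)
# ...so the per-channel scales must be concatenated too; reusing a single
# projection's (640, 1) scale for the fused weight would dequantize wrongly.
s_fused = torch.cat([sq, sk, sv], dim=0)   # shape (1920, 1)
w_fused = q_fused.float() * s_fused        # correct dequantization
print(w_fused.shape, torch.allclose(w_fused[:640], qq.float() * sq))
```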
Ah cool. Seems like I should wait for a fix?
This is my pipeline loading function, btw:

```python
import torch
from diffusers import AutoencoderKL, DiffusionPipeline
from quanto import freeze, quantize  # package later renamed to optimum-quanto


def load_pipeline(do_quantize, act_quant):
    pipeline = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        variant="fp16",
        torch_dtype=torch.float32 if act_quant else torch.float16,
    ).to("cuda")
    pipeline.vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float32 if act_quant else torch.float16
    ).to("cuda")
    pipeline.fuse_qkv_projections()
    if do_quantize:
        quantize(pipeline.unet, weights=torch.int8, activations=torch.float8_e4m3fn if act_quant else None)
        quantize(pipeline.vae, weights=torch.int8, activations=torch.float8_e4m3fn if act_quant else None)
        freeze(pipeline.unet)
        freeze(pipeline.vae)
    pipeline.set_progress_bar_config(disable=True)
    return pipeline
```

I don't think changing the order in which …
Closing as it has been merged in #91.