Corrupted outputs with Marlin int4 kernels as parallelization increases #332
Comments
Could be related, but I noticed that the latest release of optimum-quanto (v0.2.5) corrupts transformer weights during qfloat8 quantization. Downgrading to 0.2.4 solved the issue. I am not sure what the exact cause is, but I will look into it.

Code that caused corruption in 0.2.5 but not in earlier versions:

```python
from diffusers import FluxPipeline
from optimum.quanto import freeze, qfloat8, quantize

pipe = FluxPipeline.from_pretrained(...)
quantize(pipe.transformer, weights=qfloat8)
freeze(pipe.transformer)
quantize(pipe.text_encoder, weights=qfloat8)
freeze(pipe.text_encoder)
quantize(pipe.text_encoder_2, weights=qfloat8)
freeze(pipe.text_encoder_2)
```
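For anyone else hitting this: until the root cause is identified, pinning the previous release (e.g. `pip install optimum-quanto==0.2.4`, assuming a PyPI install) sidesteps the corruption, per the downgrade noted above.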
Yeah, same here. I was confused at first because the generated image was just pure noise, so I downgraded to this version and it worked fine. (This was with 0.25.0.dev0.)
@inarikami @Leommm-byte this cannot be related, as the new Marlin kernel is only available for int4 weights.
This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.
When using `MarlinInt4WeightQBitsTensor` and its associated optimized gemm kernel, the readback of the weights/scales/zero-points goes wrong as soon as parallelization increases. As a consequence, output features above 128 are corrupted once a sufficient number of inputs are processed in parallel.
Test to reproduce the issue:
`optimum-quanto/test/tensor/weights/optimized/test_marlin_int4_weight_qbits_tensor.py`, line 134 at commit 852bb9c
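For readers without the repository checked out, here is a minimal standalone sketch of the kind of check the test performs, using only the public API. This is an illustration, not the actual test: it assumes a CUDA device on which the Marlin int4 kernel is selected automatically for frozen qint4 weights, and the layer shapes and batch size are arbitrary choices meant to cross the 128-feature and parallelization thresholds.

```python
import torch
from optimum.quanto import freeze, qint4, quantize

torch.manual_seed(0)

# 256 output features: the reported corruption shows up above feature 128.
model = torch.nn.Sequential(
    torch.nn.Linear(256, 256, bias=False),
).to("cuda", torch.float16)

# A batch large enough to exercise the parallelized kernel path.
inputs = torch.randn(64, 256, device="cuda", dtype=torch.float16)
expected = model(inputs)

# Quantize the weights to int4 and freeze; on a suitable CUDA device this
# should route the matmul through the optimized Marlin kernel.
quantize(model, weights=qint4)
freeze(model)
actual = model(inputs)

# Quantization error should be roughly uniform across output features; with
# the bug, features >= 128 show much larger errors as the batch grows.
error = (actual - expected).abs().float().mean(dim=0)
print("mean abs error, features < 128:", error[:128].mean().item())
print("mean abs error, features >= 128:", error[128:].mean().item())
```

The actual test exercises `MarlinInt4WeightQBitsTensor` and its gemm kernel directly; the sketch above only approximates the same comparison through the module-level API.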