
Corrupted outputs with Marlin int4 kernels as parallelization increases #332

Open
dacorvo opened this issue Oct 6, 2024 · 4 comments
Labels: bug (Something isn't working), help wanted (Extra attention is needed)

Comments

@dacorvo
Collaborator

dacorvo commented Oct 6, 2024

When using MarlinInt4WeightQBitsTensor and its associated optimized gemm kernel, the weight/scale/zero-point readback becomes incorrect as parallelization increases.

The consequence is that output features beyond index 128 are corrupted once a sufficient number of inputs are processed in parallel.
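For context, here is a minimal pure-Python reference of group-wise int4 quantization and readback, i.e. the per-group math the optimized kernel has to reproduce. `quantize_group` and `dequantize_group` are hypothetical helper names (not optimum-quanto API), and the group size of 128 matches the feature boundary mentioned above:

```python
# Reference group-wise int4 quantization/readback. Hypothetical helpers,
# not optimum-quanto API; shown only to illustrate the scale/zero-point
# math an int4 gemm kernel must apply per group.

GROUP_SIZE = 128  # matches the 128-feature boundary reported above

def quantize_group(weights):
    """Map one group of floats to unsigned int4 codes plus scale/zero-point."""
    w_min, w_max = min(weights), max(weights)
    scale = (w_max - w_min) / 15.0 or 1.0  # int4 codes span 0..15
    zero_point = round(-w_min / scale)
    codes = [max(0, min(15, round(w / scale) + zero_point)) for w in weights]
    return codes, scale, zero_point

def dequantize_group(codes, scale, zero_point):
    """Readback the kernel performs: (code - zero_point) * scale."""
    return [(c - zero_point) * scale for c in codes]
```

If the kernel's readback diverges from this per-group math for some lanes at higher parallelization, exactly the kind of feature-index-dependent corruption described above would appear.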

A test reproducing the issue (currently marked as an expected failure):

@pytest.mark.xfail(reason="Bug in Marlin kernel", strict=False)

@dacorvo dacorvo added bug Something isn't working help wanted Extra attention is needed labels Oct 6, 2024
@inarikami

inarikami commented Oct 12, 2024

Could be related, but I noticed that the latest release of optimum-quanto (v0.2.5) corrupts transformer weights during qfloat8 quantization. Downgrading to 0.2.4 solved the issue. I am not sure what the exact cause is, but I will look into it.

Code that caused corruption in 0.2.5 but not earlier versions:

pipe = FluxPipeline.from_pretrained(...)

quantize(pipe.transformer, weights=qfloat8)
freeze(pipe.transformer)

quantize(pipe.text_encoder, weights=qfloat8)
freeze(pipe.text_encoder)

quantize(pipe.text_encoder_2, weights=qfloat8)
freeze(pipe.text_encoder_2)
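One way to confirm this kind of corruption independently of the generated image is to snapshot weight values before quantization and compare the dequantized values afterwards. A minimal sketch; `max_rel_diff` is a hypothetical helper (not part of the optimum-quanto API), and plain lists stand in for weight tensors:

```python
# Hypothetical debugging helper: largest element-wise relative difference
# between a pre-quantization weight snapshot and the dequantized weights.

def max_rel_diff(before, after, eps=1e-8):
    """Return max |a - b| / (|a| + eps) over paired elements."""
    return max(abs(a - b) / (abs(a) + eps) for a, b in zip(before, after))
```

Healthy float8 quantization should keep this in the low-percent range; genuinely corrupted weights show differences that are orders of magnitude larger.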

@Leommm-byte

> Could be related, but I noticed the latest release of optimum-quanto (v0.2.5) corrupts transformer weights during qfloat8 quantization. Downgrading to 0.2.4 solved this issue. [...]

Yeah, same here. I was confused at first because the generated image was pure noise, so I downgraded to this commit

https://github.com/huggingface/optimum-quanto.git@65ace79d6af6ccc27afbb3576541cc36b3e3a98b

and it worked fine (this was 0.25.0.dev0).

@dacorvo
Collaborator Author

dacorvo commented Oct 14, 2024

@inarikami @Leommm-byte this cannot be related, as the new Marlin kernel is only available for qint4 and is not used by default.
If you have a script that reproduces the corruption on 0.2.5, feel free to open a separate issue.


This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.

@github-actions github-actions bot added the Stale label Nov 14, 2024
@dacorvo dacorvo removed the Stale label Nov 14, 2024