
FIX GPTQModel Lora Wrapper #2404

Open · wants to merge 14 commits into base: main
Conversation

Qubitium (Contributor) commented Feb 28, 2025

PR Changes:

  • FIX: GPTQ linear layers from GPTQModel are not compatible with the PEFT LoRA wrapper.
  • Skip the scaling multiply op if scaling == 1 (we are not sure the GPU is smart enough to no-op this micro-optimization on its own, so we do it manually; see the sketch right after this list). The LoRAs we are testing for GPTQ do not use scaling, so scale = r / lora_alpha where r == lora_alpha.
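
To make the second point concrete, here is a minimal, hypothetical sketch of a LoRA wrapper over a frozen quantized linear layer that skips the scaling multiply when it equals 1. The class name, signature, and defaults are assumptions for illustration, not the actual GPTQLoraLinear code from this PR:

```python
import torch.nn as nn


class QuantLoraLinearSketch(nn.Module):
    """Hypothetical LoRA wrapper around a frozen (e.g. GPTQ-quantized) linear layer."""

    def __init__(self, base_layer, in_features, out_features, r=8, lora_alpha=8, dropout=0.0):
        super().__init__()
        self.base_layer = base_layer                   # frozen quantized linear module
        self.lora_A = nn.Linear(in_features, r, bias=False)
        self.lora_B = nn.Linear(r, out_features, bias=False)
        self.lora_dropout = nn.Dropout(dropout)
        self.scaling = lora_alpha / r                  # equals 1 when r == lora_alpha

    def forward(self, x):
        result = self.base_layer(x)
        delta = self.lora_B(self.lora_A(self.lora_dropout(x)))
        # Micro-optimization from the bullet above: skip the multiply when scaling == 1.
        if self.scaling != 1:
            delta = delta * self.scaling
        return result + delta.to(result.dtype)
```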

Notes:

  • GPTQLoraLinear copies most of its code and structure from AwqLoraLinear (an end-to-end usage sketch follows after the TODO list below).

TODO:

  • Add CI test for GPTQModel + LoRA
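
Since the point of the PR is that GPTQModel-quantized linear layers can now be wrapped by PEFT's LoRA layers, here is a hedged end-to-end usage sketch using the standard transformers/peft APIs. The target modules and hyperparameters are illustrative assumptions, and it requires a GPU with gptqmodel installed:

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# GPTQ-quantized checkpoint also referenced in the PR's CI test.
model_id = "ModelCloud/Llama-3.2-1B-gptqmodel-ci-4bit"

model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

# Attach LoRA adapters on top of the frozen GPTQ linear layers.
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,                        # r == lora_alpha, so scaling == 1 (see the note above)
    target_modules=["q_proj", "v_proj"],  # illustrative choice for a Llama-style model
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the LoRA weights are trainable
```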

Qubitium marked this pull request as draft (February 28, 2025 05:34)
Qubitium changed the title from "[WIP] FIX GPTQ Lora Wrapper" to "[WIP] FIX GPTQModel Lora Wrapper" (Feb 28, 2025)
Qubitium changed the title from "[WIP] FIX GPTQModel Lora Wrapper" to "FIX GPTQModel Lora Wrapper" (Feb 28, 2025)
Qubitium (Contributor, Author) commented Feb 28, 2025

@BenjaminBossan @SunMarc PR ready for review. My co-worker is writing up the PEFT CI test for this, but I want to get the review started early, if possible, before the test is ready. Our GPTQModel tests with PEFT and LoRA are passing.

This PR needs to pair with GPTQModel PR ModelCloud/GPTQModel#1358, which has a new CI test for LoRA. We are also rounding out the tests on our side and will merge and release v2.0 today or tomorrow.

Qubitium marked this pull request as ready for review (February 28, 2025 05:54)
SunMarc (Member) left a comment

Thanks! Left a couple of comments.

BenjaminBossan (Member) left a comment

Thanks for working on these fixes and improvements to GPTQ. I have a couple of small comments, please check.

> FIX GPTQ linear layers from GPTQModel is not compatible with PEFT Lora wrapper

What exactly was broken?

BenjaminBossan (Member) left a comment

> Ouch. No wonder my IDE can't find it. These types of dynamic code make things really tough. Python's find_usage is nearly useless here.

I agree that this type of dynamism should be avoided if possible. However, I don't really see how else this can be implemented in a generic fashion, since each PEFT method can have different relevant attributes.

HuggingFaceDocBuilderDev commented

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Qubitium (Contributor, Author) commented

@BenjaminBossan Looks like everything is cleared for merge. But please don't merge until we have added CI tests.

Qubitium requested a review from BenjaminBossan (March 3, 2025 09:40)
Qubitium (Contributor, Author) commented Mar 3, 2025

@BenjaminBossan This PR is green on our end. It needs to be paired with GPTQModel main to run. Let me know if you have any other requests. We would love to get this merged, as we may have even more PRs in the pipeline. =)

Also, is there a planned date for a new official release of PEFT on pypi?

BenjaminBossan (Member) left a comment

Thanks for the updates. Overall, this looks good, I just added a few minor comments.

A bigger issue for now is the requirement for an as yet unreleased gptqmodel version. What is the plan here, do we want to wait for that release first before proceeding with this PR?

Also a minor note, building the wheel is relatively slow, do you have prebuilt wheels?

> Also, is there a planned date for a new official release of PEFT on pypi?

We just rolled out a bigger feature which requires a few follow up changes, after that I think we can do a release. I'd say probably start of next week.

```diff
@@ -52,7 +52,7 @@ def is_auto_gptq_available():
 @lru_cache
 def is_gptqmodel_available():
     if importlib.util.find_spec("gptqmodel") is not None:
-        GPTQMODEL_MINIMUM_VERSION = packaging.version.parse("1.7.0")
+        GPTQMODEL_MINIMUM_VERSION = packaging.version.parse("1.9.99")
```
BenjaminBossan (Member) left a comment

This is currently problematic, as there is no gptqmodel release that satisfies this version requirement. Therefore, is_gptqmodel_available will always raise an error. Since is_gptqmodel_available is called by require_auto_gptq, even this check will always fail, meaning that our GPU tests cannot run at all.

I think require_auto_gptq should be adjusted to not fail if the installed gptqmodel version is too low.
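
As a rough illustration of that suggestion (not the actual PEFT code; the function name and structure are assumptions), a version gate could report unavailability instead of raising when the installed gptqmodel is too old:

```python
import importlib.metadata
import importlib.util
from functools import lru_cache

import packaging.version

GPTQMODEL_MINIMUM_VERSION = packaging.version.parse("2.0.0")


@lru_cache
def is_gptqmodel_available_sketch(raise_on_old_version: bool = False) -> bool:
    """Hypothetical variant: too-old installs return False instead of erroring out."""
    if importlib.util.find_spec("gptqmodel") is None:
        return False
    installed = packaging.version.parse(importlib.metadata.version("gptqmodel"))
    if installed < GPTQMODEL_MINIMUM_VERSION:
        if raise_on_old_version:
            raise ImportError(
                f"gptqmodel >= {GPTQMODEL_MINIMUM_VERSION} is required, found {installed}"
            )
        return False  # skip gptqmodel-dependent tests instead of failing collection
    return True
```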


```python
    @staticmethod
    def test_load_lora():
        model_id = "ModelCloud/Llama-3.2-1B-gptqmodel-ci-4bit"
```
BenjaminBossan (Member) left a comment

Do you have a smaller model that could be used here? That would reduce the risk of getting a network timeout or a full-disk error on CI. If not, we can see how this one works out.

Qubitium (Contributor, Author) left a comment

Anything below 1B has massive quantization errors and would make inference highly unstable, which would make the CI tests wildly unstable too. Let's stick with 1B unless we get errors.

Qubitium (Contributor, Author) commented Mar 3, 2025

> Thanks for the updates. Overall, this looks good, I just added a few minor comments.
>
> A bigger issue for now is the requirement for an as yet unreleased gptqmodel version. What is the plan here, do we want to wait for that release first before proceeding with this PR?
>
> Also a minor note, building the wheel is relatively slow, do you have prebuilt wheels?

I am cutting and releasing GPTQModel v2.0 tonight.

We also build and auto-download prebuilt wheels for our PyPI releases, but after the source release hits PyPI it takes about 4-6 hours to generate all the different wheels.

I will work to clear your suggestions and then we can proceed to merge. v2.0 will be released in the next 2-3 hours, followed by downloadable wheels another 3-4 hours after that.

Comment on lines +55 to +56:

```python
GPTQMODEL_MINIMUM_VERSION = packaging.version.parse("2.0.0")
OPTIMUM_MINIMUM_VERSION = packaging.version.parse("1.24.0")
```

Qubitium (Contributor, Author) left a comment

@BenjaminBossan Updated versions.

```diff
@@ -374,4 +371,4 @@ def test_load_lora():
         tokens = model.generate(**inp)[0]
         result = tokenizer.decode(tokens)

-        print("result: ", result)
+        assert "paris" in result.lower()
```
Qubitium (Contributor, Author) left a comment

@BenjaminBossan We now assert on the expected output. 1B is pretty stable, so we don't have to worry about flakiness.
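
For context, here is a rough, self-contained sketch of what such a load-and-generate test could look like. It is an assumption-laden reconstruction rather than the PR's actual test: the adapter repo name is hypothetical, and it needs a GPU with gptqmodel installed:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel


def test_load_lora_sketch():
    model_id = "ModelCloud/Llama-3.2-1B-gptqmodel-ci-4bit"
    adapter_id = "ModelCloud/Llama-3.2-1B-gptqmodel-ci-4bit-lora"  # hypothetical adapter repo

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
    model = PeftModel.from_pretrained(model, adapter_id)  # attach the trained LoRA adapter

    inp = tokenizer("What is the capital of France? The capital is", return_tensors="pt").to(model.device)
    with torch.no_grad():
        tokens = model.generate(**inp, max_new_tokens=20)[0]
    result = tokenizer.decode(tokens, skip_special_tokens=True)

    # Assert on expected content instead of printing, as discussed above.
    assert "paris" in result.lower()
```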

Qubitium (Contributor, Author) left a comment

@BenjaminBossan CI test updated with requested changes.

However, there is an AdaLoRA test failure that we inherited. We have no idea what it does. The test is failing at the AdaLoRA config stage, so there are probably changes to PEFT that make it no longer compatible.

How to fix this?

FAILED test_gptqmodel.py::PeftGPTQModelTests::test_adalora_causalLM - ValueError: AdaLoRA does not work when `total_step` is None, supply a value > 0.

BenjaminBossan (Member) commented

> CI test updated with requested changes.

Nice. Let's wait for the release and the wheels and I'll give it a spin.

> However, there is an AdaLoRA test failure that we inherited. We have no idea what it does. The test is failing at the AdaLoRA config stage, so there are probably changes to PEFT that make it no longer compatible.
>
> How to fix this?

Could you try adding total_step=6 when the AdaLoraConfig is initialized?
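
For illustration, a minimal sketch of that suggestion in code; the values here are placeholders and not necessarily the ones the PR ended up using:

```python
from peft import AdaLoraConfig

# AdaLoRA needs a schedule: total_step must be set, and tinit/tfinal must fit inside it.
config = AdaLoraConfig(
    total_step=6,   # total training steps, as suggested in the review
    tinit=2,        # steps before rank pruning starts (placeholder value)
    tfinal=2,       # final steps with a fixed rank budget (placeholder value)
    init_r=12,      # initial rank per adapted layer
    target_r=8,     # rank budget to prune down to
    lora_alpha=16,
    task_type="CAUSAL_LM",
)
```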

Qubitium (Contributor, Author) commented Mar 3, 2025

> Could you try adding total_step=6 when the AdaLoraConfig is initialized?

@BenjaminBossan Pushed the AdaLoRA fix. Had to tweak the total_step, init, and final-init values to get it to pass.

GPTQModel v2.0.0 just released. Wheels are building: https://github.com/ModelCloud/GPTQModel/actions/runs/13641508684/job/38132256878
