Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GPTQ Integration #771

Merged
merged 15 commits into from
Aug 11, 2023
Merged

GPTQ Integration #771

merged 15 commits into from
Aug 11, 2023

Conversation

SunMarc
Copy link
Member

@SunMarc SunMarc commented Aug 1, 2023

What does this PR do ?

This PR adds the possibility to train lora + adalora adapters on top of GPTQ quantized model.

convert to peft model for training

causal_lm_model_id = "marcsun13/opt-350m-gptq-4bit"
tokenizer  = AutoTokenizer.from_pretrained(causal_lm_model_id)
model = AutoModelForCausalLM.from_pretrained(
    causal_lm_model_id,
    torch_dtype=torch.float16,
    device_map="auto"
)
model = prepare_model_for_kbit_training(model)
config = LoraConfig(
      r=16,
      lora_alpha=32,
      target_modules=["q_proj", "v_proj"],
      lora_dropout=0.05,
      bias="none",
      task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)

save adapters after training

model.cpu().save_pretrained(save_dir)

load saved adapters

model = AutoModelForCausalLM.from_pretrained(
    causal_lm_model_id,
    torch_dtype=torch.float16,
    device_map="auto"
)
model = PeftModel.from_pretrained(model ,save_dir)# load saved adapters

to do

  • finetune llama2
  • merge after transformers PR (the doc PR test will then be fixed)

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Aug 1, 2023

The documentation is not available anymore as the PR was closed or merged.

Copy link
Contributor

@pacman100 pacman100 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @SunMarc, Thank you for adding AutoGPTQ support 🚀. Left few comments

src/peft/tuners/lora.py Outdated Show resolved Hide resolved
src/peft/tuners/lora.py Show resolved Hide resolved
Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot. This PR looks really good, I only have minor comments. I don't have any experience with GPTQ itself (yet), so I cannot really judge the more technical parts of the implementation.

A more general question: Does GPTQ generally not work with IA³ or is it just a matter of implementing it later?

src/peft/tuners/adalora.py Outdated Show resolved Hide resolved
src/peft/tuners/ia3.py Outdated Show resolved Hide resolved
src/peft/tuners/lora.py Outdated Show resolved Hide resolved
src/peft/tuners/lora.py Outdated Show resolved Hide resolved
@SunMarc SunMarc requested a review from pacman100 August 2, 2023 14:29
Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking great, thanks a lot @SunMarc ! I left one comment
We can add docs and an example script later in a follow up PR
Thanks! 🚀

docker/peft-gpu/Dockerfile Show resolved Hide resolved
@SunMarc SunMarc requested a review from BenjaminBossan August 9, 2023 23:20
Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PR is looking already quite good from my POV. Unfortunately, the tests are still failing, I assume because they require the corresponding changes to land in transformers.

I have a few comments, but none of them are deal breakers.

src/peft/tuners/adalora.py Outdated Show resolved Hide resolved
src/peft/tuners/adalora.py Outdated Show resolved Hide resolved
src/peft/tuners/adalora.py Outdated Show resolved Hide resolved
LoraLayer.__init__(
self, in_features=quant_linear_module.infeatures, out_features=quant_linear_module.outfeatures
)
self.quant_linear_module = quant_linear_module
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an interesting deviation from how the other lora layers are implemented. Here, we pass the original layer (quant_linear_module) and use it under the hood. For the normal Linear lora layer, we don't get the layer, instead basically creating a new linear layer:

# in __init__
nn.Linear.__init__(self, in_features, out_features, **kwargs)
# in forward
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

I actually prefer the solution here but wonder if there was a specific reason why this approach was not taken originally. If so, would that same reason apply here or are we good with having two different approaches? Hopefully, the others can clarify this.

Copy link
Member Author

@SunMarc SunMarc Aug 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did that because I wanted to put the new QuantLinear class in the same place as the others for conformity. If we want to go with the same approach, we will need to create this new class in a function so that the auto_gptq import is protected ( to avoid circular import as it is also importing peft ) . LMK what you think about this solution and I will add it in another PR.

tests/test_gpu_examples.py Show resolved Hide resolved
@SunMarc
Copy link
Member Author

SunMarc commented Aug 10, 2023

Unfortunately, the tests are still failing, I assume because they require the corresponding changes to land in transformers.

I don't want to break the tests so i hardcoded the const value. I will change them back when we will have the next release of transformers

@SunMarc SunMarc merged commit a916465 into huggingface:main Aug 11, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants