-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
FEAT: add awq suppot in PEFT #1399
Merged
Merged
Changes from 3 commits
Commits
Show all changes
24 commits
Select commit
Hold shift + click to select a range
35b1dec
add awq suppot in PEFT
younesbelkada c08b98e
fix
younesbelkada 7490057
fux
younesbelkada 9283800
Update src/peft/tuners/lora/awq.py
younesbelkada a461ffb
Merge remote-tracking branch 'origin/main' into add-awq
younesbelkada b925c17
style & fix tests
younesbelkada b7ac85f
forward contrib credits from PR14084
s4rduk4r 02d6eca
forward contrib credits from autoawq PR
s4rduk4r 616aefe
change name
younesbelkada c05feec
fix
younesbelkada 4684699
change to peft internal testing
younesbelkada fcd51b9
fix
younesbelkada 87b677f
fix
younesbelkada 4f22260
add multi-GPU tests
younesbelkada 1d52a45
add to dockerfile
younesbelkada 94fe0b1
Merge branch 'add-awq' of https://github.com/huggingface/peft into ad…
younesbelkada 07c0486
fix todo
younesbelkada 2f93c82
raise error only at the dispatch level
younesbelkada ec37ff4
quality
younesbelkada ec422d1
fix test
younesbelkada 198f564
fix dockerfile
younesbelkada f5123b5
fix
younesbelkada 33d8f11
fix
younesbelkada 47f006d
update dockerfile and tests
younesbelkada File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
# coding=utf-8 | ||
# Copyright 2024-present the HuggingFace Inc. team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Any, Optional | ||
|
||
import torch | ||
|
||
from peft.import_utils import is_auto_awq_available | ||
from peft.tuners.lora.layer import LoraLayer | ||
from peft.tuners.tuners_utils import BaseTunerLayer | ||
|
||
|
||
if is_auto_awq_available(): | ||
from awq.modules.linear import WQLinear_GEMM as AWQ_WQLinear_GEMM | ||
|
||
|
||
class WQLinear_GEMM(torch.nn.Module, LoraLayer): | ||
younesbelkada marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def __init__( | ||
self, | ||
base_layer, | ||
adapter_name, | ||
r: int = 0, | ||
lora_alpha: int = 1, | ||
lora_dropout: float = 0.0, | ||
init_lora_weights: bool = True, | ||
use_rslora: bool = False, | ||
**kwargs, | ||
): | ||
super().__init__() | ||
LoraLayer.__init__(self, base_layer) | ||
|
||
# self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter | ||
# for backwards compatibility | ||
self.quant_linear_module = base_layer | ||
|
||
self._active_adapter = adapter_name | ||
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora) | ||
|
||
def forward(self, x: torch.Tensor): | ||
result = self.quant_linear_module(x) | ||
|
||
if self.disable_adapters: | ||
return result | ||
|
||
for active_adapter in self.active_adapters: | ||
if active_adapter not in self.lora_A.keys(): | ||
continue | ||
lora_A = self.lora_A[active_adapter] | ||
lora_B = self.lora_B[active_adapter] | ||
dropout = self.lora_dropout[active_adapter] | ||
scaling = self.scaling[active_adapter] | ||
|
||
requires_conversion = not torch.is_autocast_enabled() | ||
if requires_conversion: | ||
expected_dtype = result.dtype | ||
x = x.to(lora_A.weight.dtype) | ||
|
||
output = lora_B(lora_A(dropout(x))) | ||
if requires_conversion: | ||
output = output.to(expected_dtype) | ||
output = output * scaling | ||
result += output | ||
younesbelkada marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return result | ||
|
||
def __repr__(self) -> str: | ||
rep = super().__repr__() | ||
return "lora." + rep | ||
|
||
|
||
def dispatch_awq( | ||
target: torch.nn.Module, | ||
adapter_name: str, | ||
**kwargs: Any, | ||
) -> Optional[torch.nn.Module]: | ||
new_module = None | ||
|
||
if isinstance(target, BaseTunerLayer): | ||
target_base_layer = target.get_base_layer() | ||
else: | ||
target_base_layer = target | ||
|
||
if isinstance(target_base_layer, AWQ_WQLinear_GEMM): | ||
new_module = WQLinear_GEMM(target, adapter_name, **kwargs) | ||
target.qweight = target_base_layer.qweight | ||
|
||
return new_module |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not quite clear to me: change it when?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I'll remove it