From f5ce3fa1b4bb6fe0bc16c41152545789f8047354 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Fri, 28 Feb 2025 01:01:46 +0000
Subject: [PATCH 01/15] test gptq lora

---
 src/peft/import_utils.py     | 19 ++++++---
 src/peft/tuners/lora/gptq.py | 83 +++++++++++++++++++++++++++++++++---
 src/peft/utils/other.py      |  2 +-
 3 files changed, 90 insertions(+), 14 deletions(-)

diff --git a/src/peft/import_utils.py b/src/peft/import_utils.py
index e322ab778d..a3cfb8838e 100644
--- a/src/peft/import_utils.py
+++ b/src/peft/import_utils.py
@@ -13,12 +13,14 @@
 # limitations under the License.
 import importlib
 import importlib.metadata as importlib_metadata
+import logging
 import platform
 from functools import lru_cache
 
 import packaging.version
 import torch
 
+log = logging.getLogger(__name__)
 
 @lru_cache
 def is_bnb_available() -> bool:
     return importlib.util.find_spec("bitsandbytes") is not None
@@ -50,9 +52,9 @@ def is_auto_gptq_available():
 
 
 @lru_cache
-def is_gptqmodel_available():
+def is_gptqmodel_available(prompt_install: bool = False):
     if importlib.util.find_spec("gptqmodel") is not None:
-        GPTQMODEL_MINIMUM_VERSION = packaging.version.parse("1.7.0")
+        GPTQMODEL_MINIMUM_VERSION = packaging.version.parse("1.9.0")
         OPTIMUM_MINIMUM_VERSION = packaging.version.parse("1.23.99")
         version_gptqmodel = packaging.version.parse(importlib_metadata.version("gptqmodel"))
         if GPTQMODEL_MINIMUM_VERSION <= version_gptqmodel:
@@ -62,18 +64,21 @@ def is_gptqmodel_available():
                     return True
                 else:
                     raise ImportError(
-                        f"gptqmodel requires optimum version {OPTIMUM_MINIMUM_VERSION} or higher. Found version {version_optimum}, "
-                        f"but only versions above {OPTIMUM_MINIMUM_VERSION} are supported"
+                        f"gptqmodel requires optimum version `{OPTIMUM_MINIMUM_VERSION}` or higher. Found version `{version_optimum}`, "
+                        f"but only versions above `{OPTIMUM_MINIMUM_VERSION}` are supported"
                     )
             else:
                 raise ImportError(
-                    f"gptqmodel requires optimum version {OPTIMUM_MINIMUM_VERSION} or higher to be installed."
+                    f"gptqmodel requires optimum version `{OPTIMUM_MINIMUM_VERSION}` or higher to be installed."
                 )
         else:
             raise ImportError(
-                f"Found an incompatible version of gptqmodel. Found version {version_gptqmodel}, "
-                f"but only versions above {GPTQMODEL_MINIMUM_VERSION} are supported"
+                f"Found an incompatible version of gptqmodel. Found version `{version_gptqmodel}`, "
+                f"but only versions above `{GPTQMODEL_MINIMUM_VERSION}` are supported"
             )
+    elif prompt_install:
+        log.info("Please install GPTQModel for required functionality: `pip install -U gptqmodel --no-build-isolation -v`.")
+
 
 
 @lru_cache
diff --git a/src/peft/tuners/lora/gptq.py b/src/peft/tuners/lora/gptq.py
index 4208265c4f..7b8689483c 100644
--- a/src/peft/tuners/lora/gptq.py
+++ b/src/peft/tuners/lora/gptq.py
@@ -15,6 +15,7 @@
 from typing import Any, Optional
 
 import torch
+from gptqmodel.nn_modules.qlinear import BaseQuantLinear
 
 from peft.import_utils import is_gptqmodel_available
 from peft.tuners.lora.layer import LoraLayer
@@ -109,14 +110,84 @@ def dispatch_gptq(
 
     cfg = kwargs.get("gptq_quantization_config", None)
 
-    if is_gptqmodel_available():
-        device_map = kwargs.get("device_map", None)
-        quant_linear = get_gptqmodel_quant_linear(cfg, device_map=device_map)
+    if is_gptqmodel_available(prompt_install=True):
+        from gptqmodel.nn_modules.qlinear import BaseQuantLinear
+
+        if isinstance(target_base_layer, BaseQuantLinear):
+            new_module = GPTQLoraLinear(target, adapter_name, **kwargs)
+            target.qweight = target_base_layer.qweight
     else:
         quant_linear = get_auto_gptq_quant_linear(cfg)
 
-    if quant_linear is not None and isinstance(target_base_layer, quant_linear):
-        new_module = QuantLinear(target, adapter_name, **kwargs)
-        target.qweight = target_base_layer.qweight
+        if quant_linear is not None and isinstance(target_base_layer, quant_linear):
+            new_module = QuantLinear(target, adapter_name, **kwargs)
+            target.qweight = target_base_layer.qweight
 
     return new_module
+
+
+class GPTQLoraLinear(torch.nn.Module, LoraLayer):
+    def __init__(
+        self,
+        base_layer,
+        adapter_name,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        init_lora_weights: bool = True,
+        use_rslora: bool = False,
+        use_dora: bool = False,
+        lora_bias: bool = False,
+        **kwargs,
+    ):
+        if use_dora:
+            raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")
+
+        super().__init__()
+        LoraLayer.__init__(self, base_layer)
+
+        # self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter
+        # for backwards compatibility
+        self.quant_linear_module = base_layer
+
+        self._active_adapter = adapter_name
+        self.update_layer(
+            adapter_name,
+            r,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+            init_lora_weights=init_lora_weights,
+            use_rslora=use_rslora,
+            use_dora=use_dora,
+            lora_bias=lora_bias,
+        )
+
+    def forward(self, x: torch.Tensor):
+        result = self.quant_linear_module(x)
+
+        if self.disable_adapters:
+            return result
+
+        for active_adapter in self.active_adapters:
+            if active_adapter not in self.lora_A.keys():
+                continue
+            lora_A = self.lora_A[active_adapter]
+            lora_B = self.lora_B[active_adapter]
+            dropout = self.lora_dropout[active_adapter]
+            scaling = self.scaling[active_adapter]
+
+            requires_conversion = not torch.is_autocast_enabled()
+            if requires_conversion:
+                expected_dtype = result.dtype
+                x = self._cast_input_dtype(x, lora_A.weight.dtype)
+
+            output = lora_B(lora_A(dropout(x)))
+            if requires_conversion:
+                output = output.to(expected_dtype)
+            output = output * scaling
+            result = result + output
+        return result
+
+    def __repr__(self) -> str:
+        rep = super().__repr__()
+        return "lora." + rep
\ No newline at end of file
diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py
index e03a4f171a..3d10b322bd 100644
--- a/src/peft/utils/other.py
+++ b/src/peft/utils/other.py
@@ -934,7 +934,7 @@ def get_gptqmodel_quant_linear(gptq_quantization_config, device_map=None):
     if gptq_quantization_config is None:
         return None
 
-    if not is_gptqmodel_available():
+    if not is_gptqmodel_available(prompt_install=True):
        return None
 
     from gptqmodel.utils.importer import hf_select_quant_linear

From 65f9583b2b23ce68ec82a80d3a7144add6952573 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Fri, 28 Feb 2025 05:25:41 +0000
Subject: [PATCH 02/15] skip scaling ops if scaling == 1

---
 src/peft/tuners/lora/gptq.py  |  5 ++++-
 src/peft/tuners/lora/layer.py | 12 ++++++++++--
 2 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/src/peft/tuners/lora/gptq.py b/src/peft/tuners/lora/gptq.py
index 7b8689483c..dfd7dd4090 100644
--- a/src/peft/tuners/lora/gptq.py
+++ b/src/peft/tuners/lora/gptq.py
@@ -184,7 +184,10 @@ def forward(self, x: torch.Tensor):
             output = lora_B(lora_A(dropout(x)))
             if requires_conversion:
                 output = output.to(expected_dtype)
-            output = output * scaling
+
+            if scaling != 1:
+                output = output * scaling
+
             result = result + output
         return result
 
diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py
index 493b7b1852..c61f3400f2 100644
--- a/src/peft/tuners/lora/layer.py
+++ b/src/peft/tuners/lora/layer.py
@@ -489,7 +489,12 @@ def _mixed_batch_forward(
             # getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
             # layer output
             sub_batch = x[sub_batch_indices_list[i]].to(lora_A.weight.dtype)
-            lora_output = lora_B(lora_A(dropout(sub_batch))) * scaling
+
+            if scaling == 1: # no scaling
+                lora_output = lora_B(lora_A(dropout(sub_batch)))
+            else:
+                lora_output = lora_B(lora_A(dropout(sub_batch))) * scaling
+
             result[sub_batch_indices_list[i]] += lora_output.to(torch_result_dtype)
 
         return result
@@ -721,7 +726,10 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                 x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
                 if not self.use_dora[active_adapter]:
-                    result = result + lora_B(lora_A(dropout(x))) * scaling
+                    if scaling == 1: # no scaling
+                        result = result + lora_B(lora_A(dropout(x)))
+                    else:
+                        result = result + lora_B(lora_A(dropout(x))) * scaling
                 else:
                     if isinstance(dropout, nn.Identity) or not self.training:
                         base_result = result

From 7bc48176c7053c9f2c4d16f1310c5c6b4c11fd73 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Fri, 28 Feb 2025 05:43:21 +0000
Subject: [PATCH 03/15] format

---
 src/peft/import_utils.py      | 7 +++++--
 src/peft/tuners/lora/gptq.py  | 5 ++---
 src/peft/tuners/lora/layer.py | 4 ++--
 3 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/src/peft/import_utils.py b/src/peft/import_utils.py
index a3cfb8838e..406183a568 100644
--- a/src/peft/import_utils.py
+++ b/src/peft/import_utils.py
@@ -20,8 +20,10 @@
 import packaging.version
 import torch
 
+
 log = logging.getLogger(__name__)
 
+
 @lru_cache
 def is_bnb_available() -> bool:
     return importlib.util.find_spec("bitsandbytes") is not None
@@ -77,8 +79,9 @@ def is_gptqmodel_available():
                 f"but only versions above `{GPTQMODEL_MINIMUM_VERSION}` are supported"
             )
     elif prompt_install:
-        log.info("Please install GPTQModel for required functionality: `pip install -U gptqmodel --no-build-isolation -v`.")
-
+        log.info(
+            "Please install GPTQModel for required functionality: `pip install -U gptqmodel --no-build-isolation -v`."
+        )
 
 
 @lru_cache
diff --git a/src/peft/tuners/lora/gptq.py b/src/peft/tuners/lora/gptq.py
index dfd7dd4090..de8856cb6d 100644
--- a/src/peft/tuners/lora/gptq.py
+++ b/src/peft/tuners/lora/gptq.py
@@ -15,12 +15,11 @@
 from typing import Any, Optional
 
 import torch
-from gptqmodel.nn_modules.qlinear import BaseQuantLinear
 
 from peft.import_utils import is_gptqmodel_available
 from peft.tuners.lora.layer import LoraLayer
 from peft.tuners.tuners_utils import BaseTunerLayer
-from peft.utils import get_auto_gptq_quant_linear, get_gptqmodel_quant_linear
+from peft.utils import get_auto_gptq_quant_linear
 
 
 class QuantLinear(torch.nn.Module, LoraLayer):
@@ -193,4 +192,4 @@ def forward(self, x: torch.Tensor):
 
     def __repr__(self) -> str:
         rep = super().__repr__()
-        return "lora." + rep
\ No newline at end of file
+        return "lora." + rep
diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py
index c61f3400f2..d2f72da7ea 100644
--- a/src/peft/tuners/lora/layer.py
+++ b/src/peft/tuners/lora/layer.py
@@ -490,7 +490,7 @@ def _mixed_batch_forward(
             # layer output
             sub_batch = x[sub_batch_indices_list[i]].to(lora_A.weight.dtype)
 
-            if scaling == 1: # no scaling
+            if scaling == 1:  # no scaling
                 lora_output = lora_B(lora_A(dropout(sub_batch)))
             else:
                 lora_output = lora_B(lora_A(dropout(sub_batch))) * scaling
@@ -726,7 +726,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                 x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
                 if not self.use_dora[active_adapter]:
-                    if scaling == 1: # no scaling
+                    if scaling == 1:  # no scaling
                         result = result + lora_B(lora_A(dropout(x)))
                     else:
                         result = result + lora_B(lora_A(dropout(x))) * scaling

From e0b069d75ae8300031eeee74b67235d2d935e1fb Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Fri, 28 Feb 2025 11:27:53 +0000
Subject: [PATCH 04/15] use QuantLinear and rename to GPTQLoraLinear + micro optimizations

---
 src/peft/tuners/lora/__init__.py |  4 +-
 src/peft/tuners/lora/gptq.py     | 97 +++++++-------------------------
 2 files changed, 23 insertions(+), 78 deletions(-)

diff --git a/src/peft/tuners/lora/__init__.py b/src/peft/tuners/lora/__init__.py
index 779d4eec79..70036879d4 100644
--- a/src/peft/tuners/lora/__init__.py
+++ b/src/peft/tuners/lora/__init__.py
@@ -17,7 +17,7 @@
 
 from .config import EvaConfig, LoftQConfig, LoraConfig, LoraRuntimeConfig
 from .eva import get_eva_state_dict, initialize_lora_eva_weights
-from .gptq import QuantLinear
+from .gptq import GPTQLoraLinear
 from .layer import Conv2d, Conv3d, Embedding, Linear, LoraLayer
 from .model import LoraModel
 
@@ -27,13 +27,13 @@
     "Conv3d",
     "Embedding",
     "EvaConfig",
+    "GPTQLoraLinear",
     "Linear",
     "LoftQConfig",
     "LoraConfig",
     "LoraLayer",
     "LoraModel",
     "LoraRuntimeConfig",
-    "QuantLinear",
     "get_eva_state_dict",
     "initialize_lora_eva_weights",
 ]
diff --git a/src/peft/tuners/lora/gptq.py b/src/peft/tuners/lora/gptq.py
index de8856cb6d..b70b754140 100644
--- a/src/peft/tuners/lora/gptq.py
+++ b/src/peft/tuners/lora/gptq.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
- +from functools import lru_cache from typing import Any, Optional import torch @@ -22,7 +22,7 @@ from peft.utils import get_auto_gptq_quant_linear -class QuantLinear(torch.nn.Module, LoraLayer): +class GPTQLoraLinear(torch.nn.Module, LoraLayer): def __init__( self, base_layer, @@ -57,6 +57,11 @@ def __init__( lora_bias=lora_bias, ) + @lru_cache + def _adapter_in_lora_keys(self, adapter): + # only need to check lora_A as result is same for lora_B + return adapter in self.lora_A.keys() + def forward(self, x: torch.Tensor): # note: logic differs from default Linear because merging is not supported result = self.quant_linear_module(x) @@ -65,8 +70,9 @@ def forward(self, x: torch.Tensor): return result for active_adapter in self.active_adapters: - if active_adapter not in self.lora_A.keys(): + if not self._adapter_in_lora_keys(active_adapter): continue + lora_A = self.lora_A[active_adapter] lora_B = self.lora_B[active_adapter] dropout = self.lora_dropout[active_adapter] @@ -77,10 +83,19 @@ def forward(self, x: torch.Tensor): expected_dtype = result.dtype x = self._cast_input_dtype(x, lora_A.weight.dtype) - output = lora_B(lora_A(dropout(x))) + # lora_dropout float value is not stored so we need to check for cls + if isinstance(dropout, torch.nn.Dropout): + output = lora_B(lora_A(dropout(x))) + else: + # dropout == Identity which is no-op if lora_dropout == 0.0 + output = lora_B(lora_A(x)) + if requires_conversion: output = output.to(expected_dtype) - output = output * scaling + + if scaling != 1: # skip scaling == 1 no-op + output = output * scaling + result += output return result @@ -119,77 +134,7 @@ def dispatch_gptq( quant_linear = get_auto_gptq_quant_linear(cfg) if quant_linear is not None and isinstance(target_base_layer, quant_linear): - new_module = QuantLinear(target, adapter_name, **kwargs) + new_module = GPTQLoraLinear(target, adapter_name, **kwargs) target.qweight = target_base_layer.qweight return new_module - - -class GPTQLoraLinear(torch.nn.Module, LoraLayer): - def __init__( - self, - base_layer, - adapter_name, - r: int = 0, - lora_alpha: int = 1, - lora_dropout: float = 0.0, - init_lora_weights: bool = True, - use_rslora: bool = False, - use_dora: bool = False, - lora_bias: bool = False, - **kwargs, - ): - if use_dora: - raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False") - - super().__init__() - LoraLayer.__init__(self, base_layer) - - # self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter - # for backwards compatibility - self.quant_linear_module = base_layer - - self._active_adapter = adapter_name - self.update_layer( - adapter_name, - r, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - init_lora_weights=init_lora_weights, - use_rslora=use_rslora, - use_dora=use_dora, - lora_bias=lora_bias, - ) - - def forward(self, x: torch.Tensor): - result = self.quant_linear_module(x) - - if self.disable_adapters: - return result - - for active_adapter in self.active_adapters: - if active_adapter not in self.lora_A.keys(): - continue - lora_A = self.lora_A[active_adapter] - lora_B = self.lora_B[active_adapter] - dropout = self.lora_dropout[active_adapter] - scaling = self.scaling[active_adapter] - - requires_conversion = not torch.is_autocast_enabled() - if requires_conversion: - expected_dtype = result.dtype - x = self._cast_input_dtype(x, lora_A.weight.dtype) - - output = lora_B(lora_A(dropout(x))) - if requires_conversion: - output = output.to(expected_dtype) - - if scaling 
!= 1: - output = output * scaling - - result = result + output - return result - - def __repr__(self) -> str: - rep = super().__repr__() - return "lora." + rep From c5b07bb581e0cf17d14caacf817d3aae0d20b4d0 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 28 Feb 2025 11:55:52 +0000 Subject: [PATCH 05/15] optimize --- src/peft/tuners/lora/gptq.py | 2 +- src/peft/tuners/lora/layer.py | 19 ++++++++++++++++--- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/peft/tuners/lora/gptq.py b/src/peft/tuners/lora/gptq.py index b70b754140..96edf56f1d 100644 --- a/src/peft/tuners/lora/gptq.py +++ b/src/peft/tuners/lora/gptq.py @@ -58,7 +58,7 @@ def __init__( ) @lru_cache - def _adapter_in_lora_keys(self, adapter): + def _adapter_in_lora_keys(self, adapter: str): # only need to check lora_A as result is same for lora_B return adapter in self.lora_A.keys() diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index d2f72da7ea..079f37477c 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -15,6 +15,7 @@ import math import warnings +from functools import lru_cache from typing import Any, Optional, Union import torch @@ -701,6 +702,11 @@ def get_delta_weight(self, adapter) -> torch.Tensor: return output_tensor + @lru_cache + def _adapter_in_lora_keys(self, adapter: str): + # only need to check lora_A as result is same for lora_B + return adapter in self.lora_A.keys() + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: self._check_forward_args(x, *args, **kwargs) adapter_names = kwargs.pop("adapter_names", None) @@ -717,8 +723,9 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: result = self.base_layer(x, *args, **kwargs) torch_result_dtype = result.dtype for active_adapter in self.active_adapters: - if active_adapter not in self.lora_A.keys(): + if not self._adapter_in_lora_keys(active_adapter): continue + lora_A = self.lora_A[active_adapter] lora_B = self.lora_B[active_adapter] dropout = self.lora_dropout[active_adapter] @@ -727,9 +734,15 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: if not self.use_dora[active_adapter]: if scaling == 1: # no scaling - result = result + lora_B(lora_A(dropout(x))) + if isinstance(dropout, nn.Dropout): + result = result + lora_B(lora_A(dropout(x))) + else: + result = result + lora_B(lora_A(x)) else: - result = result + lora_B(lora_A(dropout(x))) * scaling + if isinstance(dropout, nn.Dropout): + result = result + lora_B(lora_A(dropout(x))) * scaling + else: + result = result + lora_B(lora_A(x)) * scaling else: if isinstance(dropout, nn.Identity) or not self.training: base_result = result From 7421210ca3694f0767f66d771fb8a5fbc62f8384 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 28 Feb 2025 14:11:49 +0000 Subject: [PATCH 06/15] remove logger.info and add comments on scaling == 1 --- src/peft/import_utils.py | 10 +--------- src/peft/tuners/lora/gptq.py | 2 +- src/peft/tuners/lora/layer.py | 6 ++++-- src/peft/tuners/tuners_utils.py | 2 +- src/peft/utils/other.py | 2 +- 5 files changed, 8 insertions(+), 14 deletions(-) diff --git a/src/peft/import_utils.py b/src/peft/import_utils.py index 406183a568..76ca46cc91 100644 --- a/src/peft/import_utils.py +++ b/src/peft/import_utils.py @@ -13,7 +13,6 @@ # limitations under the License. 
import importlib import importlib.metadata as importlib_metadata -import logging import platform from functools import lru_cache @@ -21,9 +20,6 @@ import torch -log = logging.getLogger(__name__) - - @lru_cache def is_bnb_available() -> bool: return importlib.util.find_spec("bitsandbytes") is not None @@ -54,7 +50,7 @@ def is_auto_gptq_available(): @lru_cache -def is_gptqmodel_available(prompt_install: bool = False): +def is_gptqmodel_available(): if importlib.util.find_spec("gptqmodel") is not None: GPTQMODEL_MINIMUM_VERSION = packaging.version.parse("1.9.0") OPTIMUM_MINIMUM_VERSION = packaging.version.parse("1.23.99") @@ -78,10 +74,6 @@ def is_gptqmodel_available(prompt_install: bool = False): f"Found an incompatible version of gptqmodel. Found version `{version_gptqmodel}`, " f"but only versions above `{GPTQMODEL_MINIMUM_VERSION}` are supported" ) - elif prompt_install: - log.info( - "Please install GPTQModel for required functionality: `pip install -U gptqmodel --no-build-isolation -v`." - ) @lru_cache diff --git a/src/peft/tuners/lora/gptq.py b/src/peft/tuners/lora/gptq.py index 96edf56f1d..3d927ab8ab 100644 --- a/src/peft/tuners/lora/gptq.py +++ b/src/peft/tuners/lora/gptq.py @@ -124,7 +124,7 @@ def dispatch_gptq( cfg = kwargs.get("gptq_quantization_config", None) - if is_gptqmodel_available(prompt_install=True): + if is_gptqmodel_available(): from gptqmodel.nn_modules.qlinear import BaseQuantLinear if isinstance(target_base_layer, BaseQuantLinear): diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index 079f37477c..74a687ce5a 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -491,7 +491,8 @@ def _mixed_batch_forward( # layer output sub_batch = x[sub_batch_indices_list[i]].to(lora_A.weight.dtype) - if scaling == 1: # no scaling + # Loras such as EoRA will always be scaling == 1 so we can skip the no-op math + if scaling == 1: lora_output = lora_B(lora_A(dropout(sub_batch))) else: lora_output = lora_B(lora_A(dropout(sub_batch))) * scaling @@ -733,7 +734,8 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: x = self._cast_input_dtype(x, lora_A.weight.dtype) if not self.use_dora[active_adapter]: - if scaling == 1: # no scaling + # Loras such as EoRA will always be scaling == 1 so we can skip the no-op math + if scaling == 1: if isinstance(dropout, nn.Dropout): result = result + lora_B(lora_A(dropout(x))) else: diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index d02e161d75..591d4d738c 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -168,7 +168,7 @@ def __init__( if not hasattr(self, "peft_config"): self.peft_config = {adapter_name: peft_config} if isinstance(peft_config, PeftConfig) else peft_config else: - logger.info( + logger.warning( "Already found a `peft_config` attribute in the model. This will lead to having multiple adapters" " in the model. Make sure to know what you are doing!" 
) diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py index 3d10b322bd..e03a4f171a 100644 --- a/src/peft/utils/other.py +++ b/src/peft/utils/other.py @@ -934,7 +934,7 @@ def get_gptqmodel_quant_linear(gptq_quantization_config, device_map=None): if gptq_quantization_config is None: return None - if not is_gptqmodel_available(prompt_install=True): + if not is_gptqmodel_available(): return None from gptqmodel.utils.importer import hf_select_quant_linear From 9b778025561c28dc1ddb4aed34199382ae95cdf4 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 28 Feb 2025 14:31:31 +0000 Subject: [PATCH 07/15] remove useless dropout optimization + ping version to 1.9.99 (v2.0) or later --- src/peft/import_utils.py | 2 +- src/peft/tuners/lora/gptq.py | 7 +------ src/peft/tuners/lora/layer.py | 10 ++-------- 3 files changed, 4 insertions(+), 15 deletions(-) diff --git a/src/peft/import_utils.py b/src/peft/import_utils.py index 76ca46cc91..127752ea7d 100644 --- a/src/peft/import_utils.py +++ b/src/peft/import_utils.py @@ -52,7 +52,7 @@ def is_auto_gptq_available(): @lru_cache def is_gptqmodel_available(): if importlib.util.find_spec("gptqmodel") is not None: - GPTQMODEL_MINIMUM_VERSION = packaging.version.parse("1.9.0") + GPTQMODEL_MINIMUM_VERSION = packaging.version.parse("1.9.99") OPTIMUM_MINIMUM_VERSION = packaging.version.parse("1.23.99") version_gptqmodel = packaging.version.parse(importlib_metadata.version("gptqmodel")) if GPTQMODEL_MINIMUM_VERSION <= version_gptqmodel: diff --git a/src/peft/tuners/lora/gptq.py b/src/peft/tuners/lora/gptq.py index 3d927ab8ab..01f380927d 100644 --- a/src/peft/tuners/lora/gptq.py +++ b/src/peft/tuners/lora/gptq.py @@ -83,12 +83,7 @@ def forward(self, x: torch.Tensor): expected_dtype = result.dtype x = self._cast_input_dtype(x, lora_A.weight.dtype) - # lora_dropout float value is not stored so we need to check for cls - if isinstance(dropout, torch.nn.Dropout): - output = lora_B(lora_A(dropout(x))) - else: - # dropout == Identity which is no-op if lora_dropout == 0.0 - output = lora_B(lora_A(x)) + output = lora_B(lora_A(dropout(x))) if requires_conversion: output = output.to(expected_dtype) diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index 74a687ce5a..136eb23b14 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -736,15 +736,9 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: if not self.use_dora[active_adapter]: # Loras such as EoRA will always be scaling == 1 so we can skip the no-op math if scaling == 1: - if isinstance(dropout, nn.Dropout): - result = result + lora_B(lora_A(dropout(x))) - else: - result = result + lora_B(lora_A(x)) + result = result + lora_B(lora_A(dropout(x))) else: - if isinstance(dropout, nn.Dropout): - result = result + lora_B(lora_A(dropout(x))) * scaling - else: - result = result + lora_B(lora_A(x)) * scaling + result = result + lora_B(lora_A(dropout(x))) * scaling else: if isinstance(dropout, nn.Identity) or not self.training: base_result = result From fbf909f071d791b762b7562826339f2c67b36cb4 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 28 Feb 2025 15:00:01 +0000 Subject: [PATCH 08/15] replace lru_cache with safer keys var --- src/peft/tuners/lora/gptq.py | 9 ++------- src/peft/tuners/lora/layer.py | 10 +++------- 2 files changed, 5 insertions(+), 14 deletions(-) diff --git a/src/peft/tuners/lora/gptq.py b/src/peft/tuners/lora/gptq.py index 01f380927d..d16cfffb58 100644 --- a/src/peft/tuners/lora/gptq.py +++ 
b/src/peft/tuners/lora/gptq.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import lru_cache from typing import Any, Optional import torch @@ -57,11 +56,6 @@ def __init__( lora_bias=lora_bias, ) - @lru_cache - def _adapter_in_lora_keys(self, adapter: str): - # only need to check lora_A as result is same for lora_B - return adapter in self.lora_A.keys() - def forward(self, x: torch.Tensor): # note: logic differs from default Linear because merging is not supported result = self.quant_linear_module(x) @@ -69,8 +63,9 @@ def forward(self, x: torch.Tensor): if self.disable_adapters: return result + lora_A_keys = self.lora_A.keys() for active_adapter in self.active_adapters: - if not self._adapter_in_lora_keys(active_adapter): + if active_adapter not in lora_A_keys: continue lora_A = self.lora_A[active_adapter] diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index 136eb23b14..93e2e04d83 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -15,7 +15,6 @@ import math import warnings -from functools import lru_cache from typing import Any, Optional, Union import torch @@ -703,11 +702,6 @@ def get_delta_weight(self, adapter) -> torch.Tensor: return output_tensor - @lru_cache - def _adapter_in_lora_keys(self, adapter: str): - # only need to check lora_A as result is same for lora_B - return adapter in self.lora_A.keys() - def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: self._check_forward_args(x, *args, **kwargs) adapter_names = kwargs.pop("adapter_names", None) @@ -723,8 +717,10 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: else: result = self.base_layer(x, *args, **kwargs) torch_result_dtype = result.dtype + + lora_A_keys = self.lora_A.keys() for active_adapter in self.active_adapters: - if not self._adapter_in_lora_keys(active_adapter): + if active_adapter not in lora_A_keys: continue lora_A = self.lora_A[active_adapter] From 9caf45fe3e64584df3e20dcdfeab7f89246a4925 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 28 Feb 2025 15:15:24 +0000 Subject: [PATCH 09/15] cleanup --- src/peft/tuners/lora/gptq.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/peft/tuners/lora/gptq.py b/src/peft/tuners/lora/gptq.py index d16cfffb58..369f445fae 100644 --- a/src/peft/tuners/lora/gptq.py +++ b/src/peft/tuners/lora/gptq.py @@ -115,11 +115,8 @@ def dispatch_gptq( cfg = kwargs.get("gptq_quantization_config", None) if is_gptqmodel_available(): - from gptqmodel.nn_modules.qlinear import BaseQuantLinear - - if isinstance(target_base_layer, BaseQuantLinear): - new_module = GPTQLoraLinear(target, adapter_name, **kwargs) - target.qweight = target_base_layer.qweight + new_module = GPTQLoraLinear(target, adapter_name, **kwargs) + target.qweight = target_base_layer.qweight else: quant_linear = get_auto_gptq_quant_linear(cfg) From 5531dddbb0ad7cda516b562bc6226a911c96a78e Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 28 Feb 2025 15:18:07 +0000 Subject: [PATCH 10/15] revert --- src/peft/tuners/lora/gptq.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/peft/tuners/lora/gptq.py b/src/peft/tuners/lora/gptq.py index 369f445fae..d16cfffb58 100644 --- a/src/peft/tuners/lora/gptq.py +++ b/src/peft/tuners/lora/gptq.py @@ -115,8 +115,11 @@ def dispatch_gptq( cfg = 
kwargs.get("gptq_quantization_config", None) if is_gptqmodel_available(): - new_module = GPTQLoraLinear(target, adapter_name, **kwargs) - target.qweight = target_base_layer.qweight + from gptqmodel.nn_modules.qlinear import BaseQuantLinear + + if isinstance(target_base_layer, BaseQuantLinear): + new_module = GPTQLoraLinear(target, adapter_name, **kwargs) + target.qweight = target_base_layer.qweight else: quant_linear = get_auto_gptq_quant_linear(cfg) From aa9f98cef8724d37e406aa95f23a254fa7ed60da Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 28 Feb 2025 16:55:09 +0000 Subject: [PATCH 11/15] warnings.warn --- src/peft/tuners/tuners_utils.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index 591d4d738c..5d99be1f23 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -14,7 +14,6 @@ from __future__ import annotations import copy -import logging import os import re import textwrap @@ -46,9 +45,6 @@ from ._buffer_dict import BufferDict -logger = logging.getLogger(__name__) - - @contextmanager def onload_layer(layer): r""" @@ -168,7 +164,7 @@ def __init__( if not hasattr(self, "peft_config"): self.peft_config = {adapter_name: peft_config} if isinstance(peft_config, PeftConfig) else peft_config else: - logger.warning( + warnings.warn( "Already found a `peft_config` attribute in the model. This will lead to having multiple adapters" " in the model. Make sure to know what you are doing!" ) From 14b87b46dd0c6610d76743467bbab438b8e464ac Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Sat, 1 Mar 2025 07:49:48 +0000 Subject: [PATCH 12/15] add unittest Signed-off-by: ZX-ModelCloud --- tests/test_gptqmodel.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/test_gptqmodel.py b/tests/test_gptqmodel.py index 1eaf7f5096..f639ecb83d 100644 --- a/tests/test_gptqmodel.py +++ b/tests/test_gptqmodel.py @@ -36,6 +36,7 @@ get_peft_model, prepare_model_for_kbit_training, ) +from peft.tuners.lora import GPTQLoraLinear from peft.utils import SAFETENSORS_WEIGHTS_NAME, infer_device from .testing_utils import ( @@ -347,3 +348,30 @@ def test_non_default_adapter_name(self): # sanity check assert n_trainable_default == n_trainable_other assert n_total_default == n_total_other + + @staticmethod + def test_load_lora(): + model_id = "ModelCloud/Llama-3.2-1B-gptqmodel-ci-4bit" + adapter_id = "ModelCloud/Llama-3.2-1B-gptqmodel-ci-4bit-lora" + + model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto") + model.load_adapter(adapter_id) + + print("peft model", model) + + # assert dynamic rank + v_proj_module = model.model.layers[5].self_attn.v_proj + assert isinstance(v_proj_module, GPTQLoraLinear) + assert v_proj_module.lora_A["default"].weight.data.shape[0] == 128 + assert v_proj_module.lora_B["default"].weight.data.shape[1] == 128 + gate_proj_module = model.model.layers[5].mlp.gate_proj + assert isinstance(gate_proj_module, GPTQLoraLinear) + assert gate_proj_module.lora_A["default"].weight.data.shape[0] == 256 + assert gate_proj_module.lora_B["default"].weight.data.shape[1] == 256 + + tokenizer = AutoTokenizer.from_pretrained(model_id) + inp = tokenizer("Capital of France is", return_tensors="pt").to(model.device) + tokens = model.generate(**inp)[0] + result = tokenizer.decode(tokens) + + print("result: ", result) From 05ebad66f23fa43d35ef170b74e677457a1d3e23 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 3 Mar 2025 12:30:57 +0000 Subject: 
[PATCH 13/15] update ci

---
 src/peft/import_utils.py | 4 ++--
 tests/test_gptqmodel.py  | 7 ++-----
 2 files changed, 4 insertions(+), 7 deletions(-)

diff --git a/src/peft/import_utils.py b/src/peft/import_utils.py
index 127752ea7d..525a11e406 100644
--- a/src/peft/import_utils.py
+++ b/src/peft/import_utils.py
@@ -52,8 +52,8 @@ def is_auto_gptq_available():
 @lru_cache
 def is_gptqmodel_available():
     if importlib.util.find_spec("gptqmodel") is not None:
-        GPTQMODEL_MINIMUM_VERSION = packaging.version.parse("1.9.99")
-        OPTIMUM_MINIMUM_VERSION = packaging.version.parse("1.23.99")
+        GPTQMODEL_MINIMUM_VERSION = packaging.version.parse("2.0.0")
+        OPTIMUM_MINIMUM_VERSION = packaging.version.parse("1.24.0")
         version_gptqmodel = packaging.version.parse(importlib_metadata.version("gptqmodel"))
         if GPTQMODEL_MINIMUM_VERSION <= version_gptqmodel:
             if is_optimum_available():
diff --git a/tests/test_gptqmodel.py b/tests/test_gptqmodel.py
index f639ecb83d..6960df7b89 100644
--- a/tests/test_gptqmodel.py
+++ b/tests/test_gptqmodel.py
@@ -349,16 +349,13 @@ def test_non_default_adapter_name(self):
         assert n_trainable_default == n_trainable_other
         assert n_total_default == n_total_other
 
-    @staticmethod
-    def test_load_lora():
+    def test_load_lora(self):
         model_id = "ModelCloud/Llama-3.2-1B-gptqmodel-ci-4bit"
         adapter_id = "ModelCloud/Llama-3.2-1B-gptqmodel-ci-4bit-lora"
 
         model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
         model.load_adapter(adapter_id)
 
-        print("peft model", model)
-
         # assert dynamic rank
         v_proj_module = model.model.layers[5].self_attn.v_proj
         assert isinstance(v_proj_module, GPTQLoraLinear)
@@ -374,4 +371,4 @@ def test_load_lora(self):
         tokens = model.generate(**inp)[0]
         result = tokenizer.decode(tokens)
 
-        print("result: ", result)
+        assert "paris" in result.lower()

From 1bd6f2be6c624caf6801d10fe7286269a4c12e60 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Mon, 3 Mar 2025 22:03:09 +0000
Subject: [PATCH 14/15] fix ci adalora

---
 tests/test_gptqmodel.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tests/test_gptqmodel.py b/tests/test_gptqmodel.py
index 6960df7b89..9155d57520 100644
--- a/tests/test_gptqmodel.py
+++ b/tests/test_gptqmodel.py
@@ -203,10 +203,11 @@ def test_adalora_causalLM(self):
         model = prepare_model_for_kbit_training(model)
 
         peft_config = AdaLoraConfig(
+            total_step=40,
             init_r=6,
             target_r=4,
-            tinit=50,
-            tfinal=100,
+            tinit=10,
+            tfinal=20,
             deltaT=5,
             beta1=0.3,
             beta2=0.3,

From a22f03a7e08279af0b178c5264288867ae8c2bcb Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Tue, 4 Mar 2025 12:04:43 +0000
Subject: [PATCH 15/15] mark adalora test with singlegpu

---
 tests/test_gptqmodel.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/test_gptqmodel.py b/tests/test_gptqmodel.py
index 9155d57520..4523a87ed8 100644
--- a/tests/test_gptqmodel.py
+++ b/tests/test_gptqmodel.py
@@ -187,6 +187,7 @@ def test_causal_lm_training(self):
             # assert loss is not None
             assert trainer.state.log_history[-1]["train_loss"] is not None
 
+    @pytest.mark.single_gpu_tests
     def test_adalora_causalLM(self):
         r"""
         Tests the gptq training with adalora