FIX GPTQModel Lora Wrapper #2404

Open · wants to merge 14 commits into base: main
12 changes: 6 additions & 6 deletions src/peft/import_utils.py
@@ -52,7 +52,7 @@ def is_auto_gptq_available():
@lru_cache
def is_gptqmodel_available():
if importlib.util.find_spec("gptqmodel") is not None:
GPTQMODEL_MINIMUM_VERSION = packaging.version.parse("1.7.0")
GPTQMODEL_MINIMUM_VERSION = packaging.version.parse("1.9.99")
Review comment (Member):
This is currently problematic, as there is no gptqmodel release that satisfies this version requirement. Therefore, `is_gptqmodel_available` will always raise an error. Since `is_gptqmodel_available` is called by `require_auto_gptq`, that check will also always fail, meaning that our GPU tests cannot run at all.

I think `require_auto_gptq` should be adjusted so that it does not fail if the installed gptqmodel version is too low.
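A minimal sketch of that kind of adjustment, assuming `require_auto_gptq` is a `unittest` skip decorator in the test utilities (the actual helper in PEFT may look different): catching the version `ImportError` lets the tests be skipped instead of erroring out.

import unittest

from peft.import_utils import is_auto_gptq_available, is_gptqmodel_available


def require_auto_gptq(test_case):
    """Skip the test unless a usable auto-gptq or gptqmodel installation is present."""
    try:
        gptqmodel_ok = is_gptqmodel_available()
    except ImportError:
        # gptqmodel (or optimum) is installed but too old -> skip instead of raising
        gptqmodel_ok = False
    return unittest.skipUnless(
        gptqmodel_ok or is_auto_gptq_available(), "test requires auto-gptq or gptqmodel"
    )(test_case)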

OPTIMUM_MINIMUM_VERSION = packaging.version.parse("1.23.99")
version_gptqmodel = packaging.version.parse(importlib_metadata.version("gptqmodel"))
if GPTQMODEL_MINIMUM_VERSION <= version_gptqmodel:
@@ -62,17 +62,17 @@ def is_gptqmodel_available():
return True
else:
raise ImportError(
f"gptqmodel requires optimum version {OPTIMUM_MINIMUM_VERSION} or higher. Found version {version_optimum}, "
f"but only versions above {OPTIMUM_MINIMUM_VERSION} are supported"
f"gptqmodel requires optimum version `{OPTIMUM_MINIMUM_VERSION}` or higher. Found version `{version_optimum}`, "
f"but only versions above `{OPTIMUM_MINIMUM_VERSION}` are supported"
)
else:
raise ImportError(
f"gptqmodel requires optimum version {OPTIMUM_MINIMUM_VERSION} or higher to be installed."
f"gptqmodel requires optimum version `{OPTIMUM_MINIMUM_VERSION}` or higher to be installed."
)
else:
raise ImportError(
f"Found an incompatible version of gptqmodel. Found version {version_gptqmodel}, "
f"but only versions above {GPTQMODEL_MINIMUM_VERSION} are supported"
f"Found an incompatible version of gptqmodel. Found version `{version_gptqmodel}`, "
f"but only versions above `{GPTQMODEL_MINIMUM_VERSION}` are supported"
)


4 changes: 2 additions & 2 deletions src/peft/tuners/lora/__init__.py
@@ -17,7 +17,7 @@

from .config import EvaConfig, LoftQConfig, LoraConfig, LoraRuntimeConfig
from .eva import get_eva_state_dict, initialize_lora_eva_weights
from .gptq import QuantLinear
from .gptq import GPTQLoraLinear
from .layer import Conv2d, Conv3d, Embedding, Linear, LoraLayer
from .model import LoraModel

@@ -27,13 +27,13 @@
"Conv3d",
"Embedding",
"EvaConfig",
"GPTQLoraLinear",
"Linear",
"LoftQConfig",
"LoraConfig",
"LoraLayer",
"LoraModel",
"LoraRuntimeConfig",
"QuantLinear",
"get_eva_state_dict",
"initialize_lora_eva_weights",
]
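For downstream code that imported the old name, a small compatibility shim along these lines can bridge the rename (a sketch, not part of this PR):

try:
    from peft.tuners.lora import GPTQLoraLinear  # name introduced by this PR
except ImportError:
    # older PEFT releases expose the GPTQ LoRA layer under its previous name
    from peft.tuners.lora import QuantLinear as GPTQLoraLinear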
28 changes: 18 additions & 10 deletions src/peft/tuners/lora/gptq.py
@@ -11,18 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional

import torch

from peft.import_utils import is_gptqmodel_available
from peft.tuners.lora.layer import LoraLayer
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import get_auto_gptq_quant_linear, get_gptqmodel_quant_linear
from peft.utils import get_auto_gptq_quant_linear


class QuantLinear(torch.nn.Module, LoraLayer):
class GPTQLoraLinear(torch.nn.Module, LoraLayer):
def __init__(
self,
base_layer,
@@ -64,9 +63,11 @@ def forward(self, x: torch.Tensor):
if self.disable_adapters:
return result

lora_A_keys = self.lora_A.keys()
for active_adapter in self.active_adapters:
if active_adapter not in self.lora_A.keys():
if active_adapter not in lora_A_keys:
continue

lora_A = self.lora_A[active_adapter]
lora_B = self.lora_B[active_adapter]
dropout = self.lora_dropout[active_adapter]
@@ -78,9 +79,13 @@ def forward(self, x: torch.Tensor):
x = self._cast_input_dtype(x, lora_A.weight.dtype)

output = lora_B(lora_A(dropout(x)))

if requires_conversion:
output = output.to(expected_dtype)
output = output * scaling

if scaling != 1:  # skip the multiplication when scaling == 1, since it is a no-op
output = output * scaling

result += output
return result

@@ -110,13 +115,16 @@ def dispatch_gptq(
cfg = kwargs.get("gptq_quantization_config", None)

if is_gptqmodel_available():
device_map = kwargs.get("device_map", None)
quant_linear = get_gptqmodel_quant_linear(cfg, device_map=device_map)
from gptqmodel.nn_modules.qlinear import BaseQuantLinear

if isinstance(target_base_layer, BaseQuantLinear):
new_module = GPTQLoraLinear(target, adapter_name, **kwargs)
target.qweight = target_base_layer.qweight
else:
quant_linear = get_auto_gptq_quant_linear(cfg)

if quant_linear is not None and isinstance(target_base_layer, quant_linear):
new_module = QuantLinear(target, adapter_name, **kwargs)
target.qweight = target_base_layer.qweight
if quant_linear is not None and isinstance(target_base_layer, quant_linear):
new_module = GPTQLoraLinear(target, adapter_name, **kwargs)
target.qweight = target_base_layer.qweight

return new_module
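A rough usage sketch of the dispatch path above (the checkpoint id is a placeholder and the exact attribute path depends on the architecture): layers produced by gptqmodel subclass `BaseQuantLinear`, so attaching a LoRA adapter should route them through `dispatch_gptq` into `GPTQLoraLinear`.

from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model

# placeholder id: any GPTQ-quantized checkpoint that gptqmodel/transformers can load
base = AutoModelForCausalLM.from_pretrained("some-org/some-model-4bit-gptq", device_map="auto")
peft_model = get_peft_model(base, LoraConfig(r=16, target_modules=["q_proj", "v_proj"]))

# the quantized attention projections should now be wrapped in GPTQLoraLinear
print(type(peft_model.base_model.model.model.layers[0].self_attn.v_proj))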
19 changes: 16 additions & 3 deletions src/peft/tuners/lora/layer.py
@@ -489,7 +489,13 @@ def _mixed_batch_forward(
# getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
# layer output
sub_batch = x[sub_batch_indices_list[i]].to(lora_A.weight.dtype)
lora_output = lora_B(lora_A(dropout(sub_batch))) * scaling

# LoRA variants such as EoRA always use scaling == 1, so we can skip the no-op multiplication
if scaling == 1:
lora_output = lora_B(lora_A(dropout(sub_batch)))
else:
lora_output = lora_B(lora_A(dropout(sub_batch))) * scaling

result[sub_batch_indices_list[i]] += lora_output.to(torch_result_dtype)

return result
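A standalone check of why the fast path is safe (illustration only, not PEFT code): multiplying by a scaling of 1 is an exact no-op for floating-point tensors, so skipping it cannot change the result.

import torch

lora_A = torch.nn.Linear(16, 4, bias=False)
lora_B = torch.nn.Linear(4, 16, bias=False)
x = torch.randn(2, 16)

scaling = 1
fast = lora_B(lora_A(x))            # scaling == 1 path: multiplication skipped
full = lora_B(lora_A(x)) * scaling  # original path
assert torch.equal(fast, full)      # bit-identical outputs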
@@ -711,17 +717,24 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
else:
result = self.base_layer(x, *args, **kwargs)
torch_result_dtype = result.dtype

lora_A_keys = self.lora_A.keys()
for active_adapter in self.active_adapters:
if active_adapter not in self.lora_A.keys():
if active_adapter not in lora_A_keys:
continue

lora_A = self.lora_A[active_adapter]
lora_B = self.lora_B[active_adapter]
dropout = self.lora_dropout[active_adapter]
scaling = self.scaling[active_adapter]
x = self._cast_input_dtype(x, lora_A.weight.dtype)

if not self.use_dora[active_adapter]:
result = result + lora_B(lora_A(dropout(x))) * scaling
# LoRA variants such as EoRA always use scaling == 1, so we can skip the no-op multiplication
if scaling == 1:
result = result + lora_B(lora_A(dropout(x)))
else:
result = result + lora_B(lora_A(dropout(x))) * scaling
else:
if isinstance(dropout, nn.Identity) or not self.training:
base_result = result
6 changes: 1 addition & 5 deletions src/peft/tuners/tuners_utils.py
@@ -14,7 +14,6 @@
from __future__ import annotations

import copy
import logging
import os
import re
import textwrap
@@ -46,9 +45,6 @@
from ._buffer_dict import BufferDict


logger = logging.getLogger(__name__)


@contextmanager
def onload_layer(layer):
r"""
@@ -168,7 +164,7 @@ def __init__(
if not hasattr(self, "peft_config"):
self.peft_config = {adapter_name: peft_config} if isinstance(peft_config, PeftConfig) else peft_config
else:
logger.info(
warnings.warn(
"Already found a `peft_config` attribute in the model. This will lead to having multiple adapters"
" in the model. Make sure to know what you are doing!"
)
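A small standalone illustration of why the switch from `logger.info` to `warnings.warn` matters (not PEFT code): with default settings an info-level log record is swallowed, whereas a warning is printed to stderr once per call site.

import logging
import warnings

logging.getLogger(__name__).info("hidden unless logging is configured at INFO or below")
warnings.warn("shown by default, once per location")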
28 changes: 28 additions & 0 deletions tests/test_gptqmodel.py
@@ -36,6 +36,7 @@
get_peft_model,
prepare_model_for_kbit_training,
)
from peft.tuners.lora import GPTQLoraLinear
from peft.utils import SAFETENSORS_WEIGHTS_NAME, infer_device

from .testing_utils import (
@@ -347,3 +348,30 @@ def test_non_default_adapter_name(self):
# sanity check
assert n_trainable_default == n_trainable_other
assert n_total_default == n_total_other

@staticmethod
def test_load_lora():
model_id = "ModelCloud/Llama-3.2-1B-gptqmodel-ci-4bit"
Review comment (Member):
Do you have a smaller model that could be used here? That would reduce the risk of a network timeout or a full-disk error on CI. If not, we can see how this one works out.

Reply (Contributor Author):
Anything below 1B has massive quantization errors and makes inference highly unstable, which would make the CI tests wildly unstable too. Let's stick with 1B unless we run into errors.

adapter_id = "ModelCloud/Llama-3.2-1B-gptqmodel-ci-4bit-lora"

model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
model.load_adapter(adapter_id)

print("peft model", model)

# assert dynamic rank
v_proj_module = model.model.layers[5].self_attn.v_proj
assert isinstance(v_proj_module, GPTQLoraLinear)
assert v_proj_module.lora_A["default"].weight.data.shape[0] == 128
assert v_proj_module.lora_B["default"].weight.data.shape[1] == 128
gate_proj_module = model.model.layers[5].mlp.gate_proj
assert isinstance(gate_proj_module, GPTQLoraLinear)
assert gate_proj_module.lora_A["default"].weight.data.shape[0] == 256
assert gate_proj_module.lora_B["default"].weight.data.shape[1] == 256

tokenizer = AutoTokenizer.from_pretrained(model_id)
inp = tokenizer("Capital of France is", return_tensors="pt").to(model.device)
tokens = model.generate(**inp)[0]
result = tokenizer.decode(tokens)

print("result: ", result)