
Commit 3a5e62f

Merge pull request axolotl-ai-cloud#41 from OpenAccess-AI-Collective/qlora-modules
attempt to find linear modules for qlora
2 parents cb5cc48 + 84196d8 · commit 3a5e62f

1 file changed: +28 −2 lines changed

src/axolotl/utils/models.py

@@ -4,6 +4,7 @@
 from pathlib import Path
 from typing import Optional, Tuple, TYPE_CHECKING

+import bitsandbytes as bnb
 import torch
 import transformers
 from torch import nn
@@ -334,6 +335,24 @@ def load_llama_adapter(model, cfg):
     return model, peft_config


+def find_all_linear_names(bits, model):
+    cls = (
+        bnb.nn.Linear4bit
+        if bits == 4
+        else (bnb.nn.Linear8bitLt if bits == 8 else torch.nn.Linear)
+    )
+    lora_module_names = set()
+    for name, module in model.named_modules():
+        if isinstance(module, cls):
+            names = name.split(".")
+            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
+
+    if "lm_head" in lora_module_names:  # needed for 16-bit
+        lora_module_names.remove("lm_head")
+
+    return list(lora_module_names)
+
+
 def load_lora(model, cfg):
     # type: (PreTrainedModel, AttrDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]

@@ -343,12 +362,19 @@ def load_lora(model, cfg):
         PeftModel,
     )

-    lora_config = None
+    bits = None
+    if cfg.load_in_4bit:
+        bits = 4
+    elif cfg.load_in_8bit:
+        bits = 8
+    linear_names = find_all_linear_names(bits, model)
+    logging.info(f"found linear modules: {repr(linear_names)}")
+    lora_target_modules = cfg.lora_target_modules + linear_names

     lora_config = LoraConfig(
         r=cfg.lora_r,
         lora_alpha=cfg.lora_alpha,
-        target_modules=cfg.lora_target_modules,
+        target_modules=lora_target_modules,
         lora_dropout=cfg.lora_dropout,
         fan_in_fan_out=cfg.lora_fan_in_fan_out,
         modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
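
For context, a minimal sketch (not part of the commit) of what the new find_all_linear_names helper does, exercised on a hypothetical toy model. It covers only the full-precision branch (bits=None, plain torch.nn.Linear); with load_in_4bit or load_in_8bit set, the same walk would match bnb.nn.Linear4bit or bnb.nn.Linear8bitLt instead. ToyBlock, ToyModel, and their layer names are invented for illustration.

    import torch
    from torch import nn

    def find_all_linear_names(bits, model):
        # Copy of the helper from the diff, reduced to the bits=None branch
        # so it runs without bitsandbytes installed.
        cls = torch.nn.Linear
        lora_module_names = set()
        for name, module in model.named_modules():
            if isinstance(module, cls):
                names = name.split(".")
                # keep only the leaf name, e.g. "layers.0.q_proj" -> "q_proj"
                lora_module_names.add(names[0] if len(names) == 1 else names[-1])
        if "lm_head" in lora_module_names:  # needed for 16-bit
            lora_module_names.remove("lm_head")
        return list(lora_module_names)

    class ToyBlock(nn.Module):  # hypothetical stand-in for a transformer block
        def __init__(self):
            super().__init__()
            self.q_proj = nn.Linear(8, 8)
            self.v_proj = nn.Linear(8, 8)

    class ToyModel(nn.Module):  # hypothetical model with a head to exclude
        def __init__(self):
            super().__init__()
            self.layers = nn.ModuleList([ToyBlock() for _ in range(2)])
            self.lm_head = nn.Linear(8, 8)

    print(sorted(find_all_linear_names(None, ToyModel())))
    # -> ['q_proj', 'v_proj']: duplicate leaf names collapse into the set,
    #    and lm_head is excluded from the LoRA target list.

In load_lora above, the returned names are appended to cfg.lora_target_modules, so the LoRA adapters attach to every linear layer the model exposes rather than only the explicitly configured ones.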
