@@ -4,6 +4,7 @@
from pathlib import Path
from typing import Optional, Tuple, TYPE_CHECKING

+import bitsandbytes as bnb
import torch
import transformers
from torch import nn
@@ -334,6 +335,24 @@ def load_llama_adapter(model, cfg):
    return model, peft_config


+def find_all_linear_names(bits, model):
+    cls = (
+        bnb.nn.Linear4bit
+        if bits == 4
+        else (bnb.nn.Linear8bitLt if bits == 8 else torch.nn.Linear)
+    )
+    lora_module_names = set()
+    for name, module in model.named_modules():
+        if isinstance(module, cls):
+            names = name.split(".")
+            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
+
+    if "lm_head" in lora_module_names:  # needed for 16-bit
+        lora_module_names.remove("lm_head")
+
+    return list(lora_module_names)
+
+
def load_lora(model, cfg):
    # type: (PreTrainedModel, AttrDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]

@@ -343,12 +362,19 @@ def load_lora(model, cfg):
        PeftModel,
    )

-    lora_config = None
+    bits = None
+    if cfg.load_in_4bit:
+        bits = 4
+    elif cfg.load_in_8bit:
+        bits = 8
+    linear_names = find_all_linear_names(bits, model)
+    logging.info(f"found linear modules: {repr(linear_names)}")
+    lora_target_modules = cfg.lora_target_modules + linear_names

    lora_config = LoraConfig(
        r=cfg.lora_r,
        lora_alpha=cfg.lora_alpha,
-        target_modules=cfg.lora_target_modules,
+        target_modules=lora_target_modules,
        lora_dropout=cfg.lora_dropout,
        fan_in_fan_out=cfg.lora_fan_in_fan_out,
        modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
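
As a usage sketch (not part of this commit): assuming the find_all_linear_names helper from the diff above is in scope, it can be run against any bitsandbytes-quantized model to see which module names the LoRA config will target. The model id and BitsAndBytesConfig setup below are illustrative assumptions, and the printed names are only an example of what a LLaMA-style model would yield.

    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    # Placeholder model id (assumption); any HF causal LM loaded in 4-bit works.
    model = AutoModelForCausalLM.from_pretrained(
        "huggyllama/llama-7b",
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
        device_map="auto",
    )

    # With bits=4 the helper collects the name of every bnb.nn.Linear4bit module,
    # dropping lm_head, so LoRA adapters end up on all linear projection layers.
    print(find_all_linear_names(4, model))
    # e.g. ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']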