+import re
+from typing import Dict, Set
+
+import torch
import torch.nn as nn
-from peft import PeftModel
+from peft import PeftModel, PeftType
+
+
+def extract_lora_layers(model: PeftModel, names: Set[str], adapter_name: str = "default"):
+    config = model.peft_config[adapter_name]
+    if config.peft_type != PeftType.LORA:
+        raise ValueError(f"Adapter {adapter_name} is not a LoRA adapter.")
+    # to_return = lora_state_dict(model, bias=model.peft_config.bias)
+    # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
+    # to be used directly with the state dict, which is necessary when using DeepSpeed or FSDP
+    bias = config.bias
+    if bias == "none":
+        to_return = {k for k in names if "lora_" in k}
+    elif bias == "all":
+        to_return = {k for k in names if "lora_" in k or "bias" in k}
+    elif bias == "lora_only":
+        to_return = set()
+        for k in names:
+            if "lora_" in k:
+                to_return.add(k)
+                bias_name = k.split("lora_")[0] + "bias"
+                if bias_name in names:
+                    to_return.add(bias_name)
+    else:
+        raise NotImplementedError
+    to_return = {k for k in to_return if (("lora_" in k and adapter_name in k) or ("bias" in k))}
+    if config.use_dora:
+        # Here we take care of a refactor of DoRA which changed lora_magnitude_vector from a ParameterDict to a
+        # ModuleDict with a DoraLayer instance. The old parameter is now the "weight" attribute of that layer. Since
+        # we want the state_dict format not to change, we remove the "weight" part.
+        new_dora_suffix = f"lora_magnitude_vector.{adapter_name}.weight"
+
+        def renamed_dora_weights(k):
+            if k.endswith(new_dora_suffix):
+                k = k[:-7]  # remove ".weight"
+            return k
+
+        to_return = {renamed_dora_weights(k) for k in to_return}
+
+    # use a raw f-string so the "\S" in the pattern is not treated as an (invalid) string escape
+    to_return = {re.sub(rf"lora_\S\.{adapter_name}\.(weight|bias)", "base_layer", k) for k in to_return}
+    return to_return
+
+
+class PeftUnwrapMixin:
+    def __init__(self, peft_model: PeftModel):
+        self.base_model = peft_model.get_base_model()
+        # peft does not affect buffers
+        self.lora_layers = extract_lora_layers(peft_model, set(n for n, p in self.base_model.named_parameters()))
+        potential_lora_weights = set()
+        for n in self.lora_layers:
+            potential_lora_weights.add(f"{n}.weight")
+            potential_lora_weights.add(f"{n}.bias")
+        # bidirectional maps between LoRA-injected names ("<module>.base_layer.weight") and the original names
+        self.lora_param_to_origin_param = {n: n.replace("base_layer.", "") for n in potential_lora_weights}
+        self.origin_param_to_lora_param = {v: k for k, v in self.lora_param_to_origin_param.items()}
+
+    def named_parameters(self):
+        for n, p in self.base_model.named_parameters():
+            if n in self.lora_param_to_origin_param:
+                n = self.lora_param_to_origin_param[n]
+            yield n, p
+
+    def named_buffers(self):
+        return self.base_model.named_buffers()
+
+    @property
+    def _modules(self):
+        return self.base_model._modules
+
+    @property
+    def _non_persistent_buffers_set(self):
+        return self.base_model._non_persistent_buffers_set
+
+    def patch_state_dict(self, state_dict: Dict[str, torch.Tensor]):
+        # rename original parameter names back to their LoRA-injected counterparts before loading
+        new_state_dict = {}
+        for k, v in state_dict.items():
+            if k in self.origin_param_to_lora_param:
+                k = self.origin_param_to_lora_param[k]
+            new_state_dict[k] = v
+        return new_state_dict
+
+    def state_dict(self):
+        # expose LoRA-wrapped base weights under their original parameter names
+        state_dict = {}
+        for k, v in self.base_model.state_dict().items():
+            if k in self.lora_param_to_origin_param:
+                k = self.lora_param_to_origin_param[k]
+            state_dict[k] = v
+        return state_dict
+
+    def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
+        state_dict = self.patch_state_dict(state_dict)
+        self.base_model.load_state_dict(state_dict, strict=strict, assign=assign)
+
+    def __hash__(self):
+        return hash(self.base_model)


class ModelWrapper(nn.Module):
@@ -23,7 +120,7 @@ def unwrap(self, unwrap_peft: bool = True):
        else:
            model = self.module
        if unwrap_peft and isinstance(model, PeftModel):
-            model = model.get_base_model()
+            model = PeftUnwrapMixin(model)
        return model

    def forward(self, *args, **kwargs):
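
For reference, a minimal, self-contained sketch of the name collapsing that `extract_lora_layers` performs for the default adapter with `bias="none"`; the parameter names below are hypothetical and only mirror PEFT's usual `lora_A`/`lora_B`/`base_layer` naming:

import re

adapter_name = "default"
# hypothetical LoRA-injected parameter names for a single q_proj module (illustrative only)
names = {
    "model.layers.0.self_attn.q_proj.base_layer.weight",
    "model.layers.0.self_attn.q_proj.lora_A.default.weight",
    "model.layers.0.self_attn.q_proj.lora_B.default.weight",
}
# bias == "none": keep only this adapter's LoRA parameters
lora_keys = {k for k in names if "lora_" in k and adapter_name in k}
# the final substitution collapses "lora_X.<adapter>.weight" onto the wrapped module's "base_layer" prefix
collapsed = {re.sub(rf"lora_\S\.{adapter_name}\.(weight|bias)", "base_layer", k) for k in lora_keys}
print(collapsed)  # {'model.layers.0.self_attn.q_proj.base_layer'}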
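
And a rough usage sketch for `PeftUnwrapMixin`, assuming peft >= 0.9 (so `LoraConfig` has `use_dora`) and that applying `get_peft_model` to a plain `nn.Module`, as in PEFT's custom-model docs, is acceptable; the `MLP` model is made up for illustration, not taken from this PR:

import torch.nn as nn
from peft import LoraConfig, get_peft_model


class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))

    def forward(self, x):
        return self.seq(x)


peft_model = get_peft_model(MLP(), LoraConfig(target_modules=["seq.0"]))
unwrapped = PeftUnwrapMixin(peft_model)

sd = unwrapped.state_dict()
# base weights of LoRA-targeted modules regain their original names ("seq.0.weight" instead of
# "seq.0.base_layer.weight"), while adapter params such as "seq.0.lora_A.default.weight" are kept as-is
assert "seq.0.weight" in sd and "seq.0.base_layer.weight" not in sd

# load_state_dict() maps the original names back to the LoRA-injected ones before loading
unwrapped.load_state_dict(sd)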