|
12 | 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 | 13 | # See the License for the specific language governing permissions and
|
14 | 14 | # limitations under the License.
|
15 |
| - |
16 | 15 | import copy
|
| 16 | +import inspect |
17 | 17 | import os
|
18 | 18 | import warnings
|
19 | 19 |
|
| 20 | +import accelerate |
20 | 21 | import torch
|
| 22 | +from accelerate.hooks import add_hook_to_module, remove_hook_from_module |
21 | 23 |
|
22 | 24 |
|
23 | 25 | # Add or edit model card to have `library_name: peft`
|
@@ -140,6 +142,26 @@ def __init__(self, module_to_save, adapter_name):
|
140 | 142 | def update(self, adapter_name):
|
141 | 143 | self.modules_to_save.update(torch.nn.ModuleDict({adapter_name: copy.deepcopy(self.original_module)}))
|
142 | 144 |
|
| 145 | + if hasattr(self.modules_to_save[adapter_name], "_hf_hook"): |
| 146 | + old_hook = self.modules_to_save[adapter_name]._hf_hook |
| 147 | + new_hook = self._create_new_hook(old_hook) |
| 148 | + remove_hook_from_module(self.modules_to_save[adapter_name]) |
| 149 | + add_hook_to_module(self.modules_to_save[adapter_name], new_hook) |
| 150 | + |
| 151 | + def _create_new_hook(self, old_hook): |
| 152 | + r""" |
| 153 | + Creates a new hook based on the old hook. Use it only if you know what you are doing ! |
| 154 | + """ |
| 155 | + old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__) |
| 156 | + old_hook_attr = old_hook.__dict__ |
| 157 | + filtered_old_hook_attr = {} |
| 158 | + old_hook_init_signature = inspect.signature(old_hook_cls.__init__) |
| 159 | + for k in old_hook_attr.keys(): |
| 160 | + if k in old_hook_init_signature.parameters: |
| 161 | + filtered_old_hook_attr[k] = old_hook_attr[k] |
| 162 | + new_hook = old_hook_cls(**filtered_old_hook_attr) |
| 163 | + return new_hook |
| 164 | + |
143 | 165 | def forward(self, *args, **kwargs):
|
144 | 166 | if self.disable_adapters or (self.active_adapter not in self.modules_to_save):
|
145 | 167 | return self.original_module(*args, **kwargs)
|
|
0 commit comments