Skip to content

Commit e27e883

Browse files
[ModulesToSave] add correct hook management for modules to save (#755)
* add correct hook management for modules to save
* forward contrib credits from finding the solution
* add nice GPU tests
* quality

---------

Co-authored-by: BenjaminBossan <BenjaminBossan@users.noreply.github.com>
1 parent ffbb6bc commit e27e883

File tree

2 files changed

+61
-1
lines changed

2 files changed

+61
-1
lines changed

src/peft/utils/other.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
1615
import copy
16+
import inspect
1717
import os
1818
import warnings
1919

20+
import accelerate
2021
import torch
22+
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
2123

2224

2325
# Add or edit model card to have `library_name: peft`
@@ -140,6 +142,26 @@ def __init__(self, module_to_save, adapter_name):
140142
def update(self, adapter_name):
    r"""
    Register a deep copy of `self.original_module` in `self.modules_to_save` under `adapter_name`.

    If the copied module carries an `_hf_hook` attribute, that hook is detached and replaced by a new hook built
    from it with `self._create_new_hook`.

    Args:
        adapter_name: Key under which the copied module is stored in `self.modules_to_save`.
    """
    module_copy = copy.deepcopy(self.original_module)
    self.modules_to_save.update(torch.nn.ModuleDict({adapter_name: module_copy}))

    saved_module = self.modules_to_save[adapter_name]
    if not hasattr(saved_module, "_hf_hook"):
        # Nothing to re-attach when the copied module has no hook.
        return

    # NOTE(review): presumably the deep-copied hook must not be kept as-is on the copy; the replacement is
    # created before the old hook is removed because it is built from the old hook's attributes.
    replacement_hook = self._create_new_hook(saved_module._hf_hook)
    remove_hook_from_module(saved_module)
    add_hook_to_module(saved_module, replacement_hook)
150+
151+
def _create_new_hook(self, old_hook):
    r"""
    Creates a new hook of the same class as `old_hook`, re-using its constructor arguments. Use it only if you know
    what you are doing!

    Only the instance attributes of `old_hook` whose names match a parameter of the hook class' `__init__` are
    forwarded to the constructor; any other state stored on the old hook is not carried over. The forwarded values
    are passed as-is (they are not copied).

    Args:
        old_hook: The hook instance to re-create.

    Returns:
        A new instance of `type(old_hook)`.
    """
    # Take the class directly from the instance rather than resolving it by name inside `accelerate.hooks`: the
    # name lookup raises `AttributeError` for hook classes defined outside that module and would pick the wrong
    # class for a subclass that shares its name with an accelerate hook.
    hook_cls = type(old_hook)
    init_params = inspect.signature(hook_cls.__init__).parameters
    init_kwargs = {name: value for name, value in vars(old_hook).items() if name in init_params}
    return hook_cls(**init_kwargs)
164+
143165
def forward(self, *args, **kwargs):
144166
if self.disable_adapters or (self.active_adapter not in self.modules_to_save):
145167
return self.original_module(*args, **kwargs)

tests/test_common_gpu.py

+38
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from transformers import (
2121
AutoModelForCausalLM,
2222
AutoModelForSeq2SeqLM,
23+
AutoModelForSequenceClassification,
2324
AutoTokenizer,
2425
BitsAndBytesConfig,
2526
LlamaForCausalLM,
@@ -316,3 +317,40 @@ def test_print_4bit_expected(self):
316317

317318
self.assertEqual(trainable_params, EXPECTED_TRAINABLE_PARAMS)
318319
self.assertEqual(all_params, EXPECTED_ALL_PARAMS)
320+
321+
@require_torch_gpu
@pytest.mark.single_gpu_tests
@require_bitsandbytes
def test_modules_to_save_grad(self):
    r"""
    Check that, for a 4-bit sequence classification model wrapped with LoRA, a backward pass through the wrapped
    `score` module produces gradients on the `modules_to_save` copy only and not on the original module.
    """
    model_id = "bigscience/bloomz-560m"

    model = AutoModelForSequenceClassification.from_pretrained(
        model_id,
        load_in_4bit=True,
        torch_dtype=torch.float32,
    )
    model = prepare_model_for_kbit_training(model)

    config = LoraConfig(
        r=16,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="SEQ_CLS",
    )
    peft_model = get_peft_model(model, config)

    # The `score` module exposes both the original module and the per-adapter copy.
    score_head = peft_model.base_model.model.score
    original_module = score_head.original_module
    modules_to_save = score_head.modules_to_save.default

    inputs = torch.randn(1024)
    o1 = score_head(inputs)
    o1.mean().backward()

    # Dedicated assertions report the offending value on failure, unlike `assertTrue(x is ...)`.
    self.assertIs(modules_to_save.weight.requires_grad, True)
    self.assertIsNone(original_module.weight.grad)
    self.assertIsNotNone(modules_to_save.weight.grad)

0 commit comments

Comments
 (0)