
Commit 7a68938

Support modules_to_save config option when using DeepSpeed ZeRO-3 with ZeRO init enabled. (huggingface#1450)

* Update other.py

* Update other.py

* fix quality

* Update other.py
pacman100 authored and BenjaminBossan committed Mar 14, 2024
1 parent cc587e9 commit 7a68938
Showing 1 changed file with 12 additions and 1 deletion.
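For context, a minimal usage sketch (not part of this commit) of the modules_to_save option that the change makes work under DeepSpeed ZeRO-3 with zero.Init: modules listed there are deep-copied per adapter and trained in full rather than via LoRA. The model name, num_labels, and the "classifier" module name are illustrative assumptions.

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForSequenceClassification

base = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
peft_config = LoraConfig(
    task_type="SEQ_CLS",
    r=8,
    lora_alpha=16,
    target_modules=["query", "value"],
    modules_to_save=["classifier"],  # deep-copied per adapter by ModulesToSaveWrapper and trained in full
)
model = get_peft_model(base, peft_config)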
13 changes: 12 additions & 1 deletion src/peft/utils/other.py
@@ -14,6 +14,7 @@
 import copy
 import inspect
 import warnings
+from contextlib import nullcontext
 from typing import Optional, Tuple
 
 import accelerate
@@ -196,7 +197,17 @@ def weight(self):
         return self.modules_to_save[self.active_adapter].weight
 
     def update(self, adapter_name):
-        self.modules_to_save.update(torch.nn.ModuleDict({adapter_name: copy.deepcopy(self.original_module)}))
+        context_manager = nullcontext()
+        for _, param in self.original_module.named_parameters():
+            num_params = param.numel()
+            # if using DS Zero 3 and the weights are initialized empty
+            if num_params == 0 and hasattr(param, "ds_numel"):
+                import deepspeed
+
+                context_manager = deepspeed.zero.GatheredParameters(self.original_module.parameters(), modifier_rank=0)
+                break
+        with context_manager:
+            self.modules_to_save.update(torch.nn.ModuleDict({adapter_name: copy.deepcopy(self.original_module)}))
 
         if hasattr(self.modules_to_save[adapter_name], "_hf_hook"):
            old_hook = self.modules_to_save[adapter_name]._hf_hook

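Why the gather is needed, as a minimal sketch (not part of this diff, and assuming a properly initialized DeepSpeed ZeRO-3 distributed job): under zero.Init, freshly created weights are partitioned across ranks, so numel() reports 0 locally while the DeepSpeed-specific attribute ds_numel records the true size. That is exactly the condition update() checks before wrapping the deepcopy in deepspeed.zero.GatheredParameters, which temporarily materializes the full weights.

import deepspeed
import torch

with deepspeed.zero.Init():           # parameters are partitioned as modules are created
    linear = torch.nn.Linear(16, 16)

print(linear.weight.numel())          # 0: only an empty placeholder lives on this rank
print(linear.weight.ds_numel)         # 256: full element count tracked by DeepSpeed

with deepspeed.zero.GatheredParameters(linear.parameters(), modifier_rank=0):
    print(linear.weight.numel())      # 256: the full weight is materialized inside the context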