Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support modules_to_save config option when using DeepSpeed ZeRO-3 with ZeRO init enabled. #1450

Merged
merged 4 commits into from
Feb 9, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/peft/utils/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import copy
import inspect
import warnings
from contextlib import nullcontext
from typing import Optional, Tuple

import accelerate
Expand Down Expand Up @@ -194,7 +195,17 @@ def weight(self):
return self.modules_to_save[self.active_adapter].weight

def update(self, adapter_name):
self.modules_to_save.update(torch.nn.ModuleDict({adapter_name: copy.deepcopy(self.original_module)}))
context_manager = nullcontext()
for _, param in self.original_module.named_parameters():
num_params = param.numel()
# if using DS Zero 3 and the weights are initialized empty
if num_params == 0 and hasattr(param, "ds_numel"):
import deepspeed

context_manager = deepspeed.zero.GatheredParameters(self.original_module.parameters(), modifier_rank=0)
break
with context_manager:
self.modules_to_save.update(torch.nn.ModuleDict({adapter_name: copy.deepcopy(self.original_module)}))

if hasattr(self.modules_to_save[adapter_name], "_hf_hook"):
old_hook = self.modules_to_save[adapter_name]._hf_hook
Expand Down
Loading