Commit d297a96

add deepspeed support for adalora finetune
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
1 parent 9d67d12 commit d297a96
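
Not part of the commit, but for orientation: a rough sketch of how an AdaLoRA finetune might be driven through the HF Trainer with a DeepSpeed config, so that the deepspeed_config() / is_deepspeed_zero3_enabled() paths touched below are actually exercised. The model name, target modules, and "ds_zero3.json" path are hypothetical placeholders; AdaLoraConfig, get_peft_model, and the TrainingArguments deepspeed argument are the real APIs.

# Hypothetical sketch, not part of this commit: launching an AdaLoRA finetune
# with a DeepSpeed config so the DeepSpeed-aware code paths below are active.
# Model name, target modules, and "ds_zero3.json" are placeholders.
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments
from peft import AdaLoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder base model
peft_config = AdaLoraConfig(
    init_r=12,
    target_r=8,
    target_modules=["c_attn"],  # placeholder target modules
    task_type="CAUSAL_LM",
)
model = get_peft_model(base, peft_config)

args = TrainingArguments(
    output_dir="adalora-out",
    per_device_train_batch_size=1,
    deepspeed="ds_zero3.json",  # placeholder ZeRO-3 config file
)
trainer = Trainer(model=model, args=args, train_dataset=None)  # dataset omitted
# trainer.train()  # in practice launched via `deepspeed` or `accelerate launch`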

2 files changed (+18, -11 lines)


src/peft/tuners/adalora/layer.py

Lines changed: 7 additions & 4 deletions
@@ -15,19 +15,21 @@
 import warnings
 from typing import Any, List, Optional
 
+import packaging
 import torch
+import transformers
 from torch import nn
 
 from peft.tuners.lora import LoraLayer
 from peft.tuners.tuners_utils import check_adapters_to_merge
 from peft.utils import transpose
-import transformers
-import packaging
+
 
 if packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.33.0"):
-    from transformers.integrations import is_deepspeed_zero3_enabled, deepspeed_config
+    from transformers.integrations import deepspeed_config
 else:
-    from transformers.deepspeed import is_deepspeed_zero3_enabled, deepspeed_config
+    from transformers.deepspeed import deepspeed_config
+
 
 class AdaLoraLayer(LoraLayer):
     # List all names of layers that may contain adapter weights

@@ -262,6 +264,7 @@ def update_ipt(self, model):
                 with torch.no_grad():
                     if deepspeed_config() is not None:
                         import deepspeed
+
                         grad = deepspeed.utils.safe_get_full_grad(p)
                         self.ipt[n] = (p * grad).abs().detach()
                     else:
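
The update_ipt() change above is the core of the DeepSpeed support in this file: when a DeepSpeed config is active, gradients are partitioned or flattened across ranks, so p.grad cannot be read directly and the full gradient is fetched through deepspeed.utils.safe_get_full_grad instead. A standalone sketch of that pattern (the helper name and the surrounding loop are illustrative, not from the commit):

# Illustrative sketch of the pattern used in update_ipt() above; the helper
# name and the loop structure are not from the commit.
import torch

from transformers.integrations import deepspeed_config  # transformers >= 4.33


def importance_scores(model: torch.nn.Module) -> dict:
    scores = {}
    with torch.no_grad():
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            if deepspeed_config() is not None:
                # Under DeepSpeed the gradient may be partitioned, so read it
                # back through DeepSpeed instead of param.grad.
                import deepspeed

                grad = deepspeed.utils.safe_get_full_grad(param)
            else:
                grad = param.grad
            if grad is not None:
                # AdaLoRA-style sensitivity: |weight * gradient|
                scores[name] = (param * grad).abs().detach()
    return scores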

src/peft/tuners/adalora/model.py

Lines changed: 11 additions & 7 deletions
@@ -14,7 +14,9 @@
 
 import warnings
 
+import packaging
 import torch
+import transformers
 from transformers.pytorch_utils import Conv1D
 
 from peft.import_utils import is_bnb_4bit_available, is_bnb_available

@@ -30,14 +32,12 @@
 
 from .gptq import SVDQuantLinear
 from .layer import AdaLoraLayer, RankAllocator, SVDLinear
-import transformers
-import packaging
+
 
 if packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.33.0"):
-    from transformers.integrations import is_deepspeed_zero3_enabled, deepspeed_config
+    from transformers.integrations import is_deepspeed_zero3_enabled
 else:
-    from transformers.deepspeed import is_deepspeed_zero3_enabled, deepspeed_config
-
+    from transformers.deepspeed import is_deepspeed_zero3_enabled
 
 
 class AdaLoraModel(LoraModel):

@@ -253,9 +253,13 @@ def forward(self, *args, **kwargs):
             for n, p in self.model.named_parameters():
                 if ("lora_A" in n or "lora_B" in n) and self.trainable_adapter_name in n:
                     if is_deepspeed_zero3_enabled():
-                        import deepspeed
                         import contextlib
-                        with deepspeed.zero.GatheredParameters(p, modifier_rank=0, fwd_module=self) if p.shape==torch.Size([0]) else contextlib.nullcontext() :
+
+                        import deepspeed
+
+                        with deepspeed.zero.GatheredParameters(
+                            p, modifier_rank=0, fwd_module=self
+                        ) if p.shape == torch.Size([0]) else contextlib.nullcontext():
                             para_cov = p @ p.T if "lora_A" in n else p.T @ p
                     else:
                         para_cov = p @ p.T if "lora_A" in n else p.T @ p
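
In AdaLoraModel.forward(), the orthogonal regularization needs the full lora_A / lora_B matrices, but under ZeRO-3 a partitioned parameter is locally an empty placeholder (p.shape == torch.Size([0])). The change above therefore gathers the parameter before computing p @ p.T. A standalone sketch of that context-manager pattern (the function name is illustrative and the fwd_module argument is omitted here):

# Illustrative sketch of the ZeRO-3 gathering pattern used in forward() above;
# the function name is not from the commit and fwd_module is omitted.
import contextlib

import torch

from transformers.integrations import is_deepspeed_zero3_enabled  # transformers >= 4.33


def param_covariance(p: torch.Tensor, is_lora_a: bool) -> torch.Tensor:
    if is_deepspeed_zero3_enabled() and p.shape == torch.Size([0]):
        # Locally the tensor is an empty shard; gather the full parameter
        # before the matmul (modifier_rank=0: only rank 0 may modify it).
        import deepspeed

        ctx = deepspeed.zero.GatheredParameters(p, modifier_rank=0)
    else:
        ctx = contextlib.nullcontext()
    with ctx:
        return p @ p.T if is_lora_a else p.T @ p

On ranks other than 0 the gathered copy is read-only, which is sufficient for computing the regularizer.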
