| 
14 | 14 | 
 
  | 
15 | 15 | import warnings  | 
16 | 16 | 
 
  | 
 | 17 | +import packaging  | 
17 | 18 | import torch  | 
 | 19 | +import transformers  | 
18 | 20 | from transformers.pytorch_utils import Conv1D  | 
19 | 21 | 
 
  | 
20 | 22 | from peft.import_utils import is_bnb_4bit_available, is_bnb_available  | 
 | 
30 | 32 | 
 
  | 
31 | 33 | from .gptq import SVDQuantLinear  | 
32 | 34 | from .layer import AdaLoraLayer, RankAllocator, SVDLinear  | 
33 |  | -import transformers  | 
34 |  | -import packaging  | 
 | 35 | + | 
35 | 36 | 
 
  | 
36 | 37 | if packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.33.0"):  | 
37 |  | -    from transformers.integrations import is_deepspeed_zero3_enabled, deepspeed_config  | 
 | 38 | +    from transformers.integrations import is_deepspeed_zero3_enabled  | 
38 | 39 | else:  | 
39 |  | -    from transformers.deepspeed import is_deepspeed_zero3_enabled, deepspeed_config  | 
40 |  | - | 
 | 40 | +    from transformers.deepspeed import is_deepspeed_zero3_enabled  | 
41 | 41 | 
 
  | 
42 | 42 | 
 
  | 
43 | 43 | class AdaLoraModel(LoraModel):  | 
@@ -253,9 +253,13 @@ def forward(self, *args, **kwargs):  | 
253 | 253 |             for n, p in self.model.named_parameters():  | 
254 | 254 |                 if ("lora_A" in n or "lora_B" in n) and self.trainable_adapter_name in n:  | 
255 | 255 |                     if is_deepspeed_zero3_enabled():  | 
256 |  | -                        import deepspeed  | 
257 | 256 |                         import contextlib  | 
258 |  | -                        with deepspeed.zero.GatheredParameters(p, modifier_rank=0, fwd_module=self) if p.shape==torch.Size([0]) else contextlib.nullcontext() :  | 
 | 257 | + | 
 | 258 | +                        import deepspeed  | 
 | 259 | + | 
 | 260 | +                        with deepspeed.zero.GatheredParameters(  | 
 | 261 | +                            p, modifier_rank=0, fwd_module=self  | 
 | 262 | +                        ) if p.shape == torch.Size([0]) else contextlib.nullcontext():  | 
259 | 263 |                             para_cov = p @ p.T if "lora_A" in n else p.T @ p  | 
260 | 264 |                     else:  | 
261 | 265 |                         para_cov = p @ p.T if "lora_A" in n else p.T @ p  | 
 | 
0 commit comments