FIX: error using deepspeed zero2 + load_in_8bit + lora (#874)
Fix an issue in the (Ada)LoRA forward of bnb layers when using bf16 + LoRA + load_in_8bit.
tmm1 authored Aug 31, 2023
1 parent 4338100 commit f113af0
Showing 2 changed files with 10 additions and 8 deletions.
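For context, a minimal sketch (not taken from the original report) of the kind of setup in which the error surfaces: a base model loaded in 8-bit via bitsandbytes with bf16 for the non-quantized parameters, wrapped with a LoRA adapter. The model id and LoRA hyperparameters are illustrative, and the DeepSpeed ZeRO-2 launcher from the issue title is omitted.

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Illustrative setup: 8-bit base weights + bf16 compute + a LoRA adapter.
# Before this fix, the LoRA input was cast to float32 even when the adapter
# weights were bf16, triggering a dtype-mismatch error in the forward pass.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",         # placeholder model id
    load_in_8bit=True,           # bitsandbytes Linear8bitLt layers
    torch_dtype=torch.bfloat16,  # bf16 for the non-quantized parameters
)
model = get_peft_model(model, LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05))
```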
13 changes: 7 additions & 6 deletions src/peft/tuners/adalora/bnb.py
@@ -132,19 +132,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             # sure.
             result = result.clone()
 
-            requires_conversion = not torch.is_autocast_enabled()
-            if requires_conversion:
-                expected_dtype = result.dtype
-                if x.dtype != torch.float32:
-                    x = x.float()
-
             lora_A = self.lora_A[self.active_adapter]
             lora_B = self.lora_B[self.active_adapter]
             lora_E = self.lora_E[self.active_adapter]
             dropout = self.lora_dropout[self.active_adapter]
             scaling = self.scaling[self.active_adapter]
             ranknum = self.ranknum[self.active_adapter] + 1e-5
 
+            requires_conversion = not torch.is_autocast_enabled()
+            if requires_conversion:
+                expected_dtype = result.dtype
+                compute_dtype = lora_A.weight.dtype
+                if x.dtype != compute_dtype:
+                    x = x.to(compute_dtype)
+
             output = dropout(x) @ (lora_A * lora_E).T @ lora_B.T
             if requires_conversion:
                 output = output.to(expected_dtype)
5 changes: 3 additions & 2 deletions src/peft/tuners/lora/bnb.py
@@ -75,8 +75,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             requires_conversion = not torch.is_autocast_enabled()
             if requires_conversion:
                 expected_dtype = result.dtype
-                if x.dtype != torch.float32:
-                    x = x.float()
+                compute_dtype = lora_A.weight.dtype
+                if x.dtype != compute_dtype:
+                    x = x.to(compute_dtype)
 
             output = lora_B(lora_A(dropout(x)))
             if requires_conversion:
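The pattern both hunks apply, shown in isolation as a minimal, hedged sketch (plain torch, not the PEFT API; `lora_delta`, the layer sizes, and the scaling value are made up): outside autocast, compute the LoRA update in the adapter's own dtype and cast the result back afterwards, rather than forcing float32.

```python
import torch
import torch.nn as nn

def lora_delta(x, lora_A, lora_B, dropout, scaling):
    # Mirrors the conversion logic added above: when autocast is disabled,
    # move the input to the adapter's compute dtype (e.g. bf16) ...
    requires_conversion = not torch.is_autocast_enabled()
    if requires_conversion:
        expected_dtype = x.dtype
        compute_dtype = lora_A.weight.dtype
        if x.dtype != compute_dtype:
            x = x.to(compute_dtype)
    output = lora_B(lora_A(dropout(x)))
    if requires_conversion:
        # ... and cast the result back to the caller's dtype afterwards.
        output = output.to(expected_dtype)
    return output * scaling

# A bf16 adapter applied to float32 activations -- the case the old
# float32 cast broke when the adapter itself was not float32.
lora_A = nn.Linear(16, 4, bias=False).to(torch.bfloat16)
lora_B = nn.Linear(4, 16, bias=False).to(torch.bfloat16)
x = torch.randn(2, 16)
print(lora_delta(x, lora_A, lora_B, nn.Identity(), scaling=0.5).dtype)  # torch.float32
```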
