fix boft mixed precision
Edenzzzz committed Jul 14, 2024
1 parent e72a96f commit 58caf44
Showing 1 changed file with 12 additions and 9 deletions.
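For context, this code path is exercised when a BOFT adapter is attached to a base model held in half precision. A minimal sketch of such a setup, assuming peft's BOFTConfig and get_peft_model API with an illustrative toy model (layer sizes, block size, and target module names are placeholders, not taken from this commit):

import torch
import torch.nn as nn
from peft import BOFTConfig, get_peft_model

# Toy fp16 base model standing in for a half-precision transformer (illustrative only).
base = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8)).half()
# Target the two Linear layers by their module names inside the Sequential ("0" and "2").
config = BOFTConfig(boft_block_size=4, target_modules=["0", "2"])
model = get_peft_model(base, config)
model.print_trainable_parameters()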
21 changes: 12 additions & 9 deletions src/peft/tuners/boft/layer.py
@@ -27,6 +27,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.autograd import Function
+from torch.utils.cpp_extension import load
 
 from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge

@@ -77,9 +78,6 @@ def get_fbd_cuda():
     if _FBD_CUDA is not None:
         return _FBD_CUDA
 
-    # This import initializes cuda context and should thus be local, see issue 1877
-    from torch.utils.cpp_extension import load
-
     curr_dir = os.path.dirname(__file__)
     # need ninja to build the extension
     try:
@@ -594,8 +592,8 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
         elif self.merged:
             result = self.base_layer(x, *args, **kwargs)
         else:
-            boft_rotation = torch.eye(self.in_features, device=x.device)
-            boft_scale = torch.ones((int(self.out_features), 1), device=x.device)
+            boft_rotation = torch.eye(self.in_features, device=x.device, dtype=previous_dtype)
+            boft_scale = torch.ones((int(self.out_features), 1), device=x.device, dtype=previous_dtype)
 
             for active_adapter in self.active_adapters:
                 if active_adapter not in self.boft_R.keys():
@@ -615,27 +613,32 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                     orth_rotate_butterfly = orth_rotate_butterfly.squeeze(0)
                     block_diagonal_butterfly = torch.block_diag(*torch.unbind(orth_rotate_butterfly))
                     block_diagonal_butterfly = block_diagonal_butterfly.unsqueeze(0)
 
-                boft_P = self.boft_P.to(block_diagonal_butterfly.device)
+                boft_P = self.boft_P.to(device=block_diagonal_butterfly.device, dtype=previous_dtype)
+                block_diagonal_butterfly = block_diagonal_butterfly.to(previous_dtype)
                 butterfly_oft_mat_batch = torch.bmm(block_diagonal_butterfly, boft_P.permute(0, 2, 1))
                 butterfly_oft_mat_batch = torch.bmm(boft_P, butterfly_oft_mat_batch)
                 butterfly_oft_mat = butterfly_oft_mat_batch[0]
 
                 for i in range(1, butterfly_oft_mat_batch.shape[0]):
                     butterfly_oft_mat = butterfly_oft_mat_batch[i] @ butterfly_oft_mat
 
                 boft_rotation = butterfly_oft_mat @ boft_rotation
                 boft_scale = boft_s * boft_scale
 
             x = x.to(self.get_base_layer().weight.data.dtype)
 
             orig_weight = self.get_base_layer().weight.data
             orig_weight = torch.transpose(orig_weight, 0, 1)
+            boft_rotation = boft_rotation.to(previous_dtype)
+            orig_weight = orig_weight.to(previous_dtype)
             rotated_weight = torch.mm(boft_rotation, orig_weight)
             rotated_weight = torch.transpose(rotated_weight, 0, 1)
 
             scaled_rotated_weight = rotated_weight * boft_scale
 
+            scaled_rotated_weight = scaled_rotated_weight.to(previous_dtype)
+            if self.base_layer.bias is not None:
+                self.base_layer.bias = self.base_layer.bias.to(previous_dtype)
             result = F.linear(input=x, weight=scaled_rotated_weight, bias=self.base_layer.bias)
 
         result = result.to(previous_dtype)
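Taken together, the change keeps boft_rotation, boft_scale, boft_P, and the butterfly factors in the input's previous_dtype, so the torch.bmm and torch.mm products and the weight handed to F.linear stay in a single dtype when the base model runs in float16 or bfloat16. A minimal sketch of the failure mode these casts guard against (a standalone torch example, not code from this file):

import torch

# torch.mm, like torch.bmm and F.linear, requires both operands to share a dtype.
rotation_fp32 = torch.eye(4, dtype=torch.float32)      # rotation built in fp32
weight_fp16 = torch.randn(4, 4, dtype=torch.float16)   # half-precision base weight
try:
    torch.mm(rotation_fp32, weight_fp16)
except RuntimeError as err:
    print(f"mixed-precision matmul failed: {err}")

# Casting both operands to the activation dtype, as the commit does via previous_dtype, avoids the error.
print(torch.mm(rotation_fp32.to(torch.float16), weight_fp16).dtype)  # torch.float16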
