[Hotfix] Fix BOFT mixed precision #1925

Merged
merged 7 commits on Aug 7, 2024
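The failure this hotfix addresses is easy to reproduce outside of PEFT. The following is a minimal illustrative sketch (not code from this PR): torch.eye and torch.ones create fp32 tensors by default, so multiplying them against a base layer held in bfloat16 raises a dtype-mismatch RuntimeError, which is exactly what the dtype arguments added in this patch prevent.

import torch

# A half-precision base layer, as produced by loading a model in bf16.
base = torch.nn.Linear(16, 16).to(torch.bfloat16)
weight_t = base.weight.data.transpose(0, 1)      # bfloat16

# Without an explicit dtype, torch.eye defaults to torch.float32.
rotation = torch.eye(16)

try:
    torch.mm(rotation, weight_t)                 # fp32 @ bf16 -> RuntimeError
except RuntimeError as err:
    print("mixed-precision failure:", err)

# The pattern of the fix: create the rotation in the compute dtype up front.
rotation = torch.eye(16, dtype=weight_t.dtype, device=weight_t.device)
torch.mm(rotation, weight_t)                     # succeeds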
22 changes: 16 additions & 6 deletions src/peft/tuners/boft/layer.py
@@ -594,8 +594,8 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
         elif self.merged:
             result = self.base_layer(x, *args, **kwargs)
         else:
-            boft_rotation = torch.eye(self.in_features, device=x.device)
-            boft_scale = torch.ones((int(self.out_features), 1), device=x.device)
+            boft_rotation = torch.eye(self.in_features, device=x.device, dtype=previous_dtype)
+            boft_scale = torch.ones((int(self.out_features), 1), device=x.device, dtype=previous_dtype)
 
             for active_adapter in self.active_adapters:
                 if active_adapter not in self.boft_R.keys():
@@ -616,7 +616,9 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                 block_diagonal_butterfly = torch.block_diag(*torch.unbind(orth_rotate_butterfly))
                 block_diagonal_butterfly = block_diagonal_butterfly.unsqueeze(0)
 
-                boft_P = self.boft_P.to(block_diagonal_butterfly.device)
+                # The BOFT author's cayley_batch, dropout and FastBlockDiag ONLY return fp32 outputs.
+                boft_P = self.boft_P.to(x)
+                block_diagonal_butterfly = block_diagonal_butterfly.to(x)
                 butterfly_oft_mat_batch = torch.bmm(block_diagonal_butterfly, boft_P.permute(0, 2, 1))
                 butterfly_oft_mat_batch = torch.bmm(boft_P, butterfly_oft_mat_batch)
                 butterfly_oft_mat = butterfly_oft_mat_batch[0]
@@ -631,11 +633,16 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
 
             orig_weight = self.get_base_layer().weight.data
             orig_weight = torch.transpose(orig_weight, 0, 1)
+            boft_rotation = boft_rotation.to(previous_dtype)
+            orig_weight = orig_weight.to(previous_dtype)
             rotated_weight = torch.mm(boft_rotation, orig_weight)
             rotated_weight = torch.transpose(rotated_weight, 0, 1)
 
             scaled_rotated_weight = rotated_weight * boft_scale
 
+            scaled_rotated_weight = scaled_rotated_weight.to(previous_dtype)
+            if self.base_layer.bias is not None:
+                self.base_layer.bias = self.base_layer.bias.to(previous_dtype)
             result = F.linear(input=x, weight=scaled_rotated_weight, bias=self.base_layer.bias)
 
         result = result.to(previous_dtype)
@@ -907,9 +914,11 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
             result = self.base_layer(x, *args, **kwargs)
         else:
             boft_rotation = torch.eye(
-                self.in_features * self.base_layer.kernel_size[0] * self.base_layer.kernel_size[0], device=x.device
+                self.in_features * self.base_layer.kernel_size[0] * self.base_layer.kernel_size[0],
+                device=x.device,
+                dtype=x.dtype,
             )
-            boft_scale = torch.ones((1, int(self.out_features)), device=x.device)
+            boft_scale = torch.ones((1, int(self.out_features)), device=x.device, dtype=x.dtype)
 
             for active_adapter in self.active_adapters:
                 if active_adapter not in self.boft_R.keys():
@@ -930,7 +939,8 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                 block_diagonal_butterfly = torch.block_diag(*torch.unbind(orth_rotate_butterfly))
                 block_diagonal_butterfly = block_diagonal_butterfly.unsqueeze(0)
 
-                boft_P = self.boft_P.to(block_diagonal_butterfly.device)
+                boft_P = self.boft_P.to(x)
+                block_diagonal_butterfly = block_diagonal_butterfly.to(x)
                 butterfly_oft_mat_batch = torch.bmm(block_diagonal_butterfly, boft_P.permute(0, 2, 1))
                 butterfly_oft_mat_batch = torch.bmm(boft_P, butterfly_oft_mat_batch)
                 butterfly_oft_mat = butterfly_oft_mat_batch[0]
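All of the hunks above lean on the same idiom: Tensor.to(x) matches both the device and the dtype of x in a single call, whereas the old .to(block_diagonal_butterfly.device) only moved the buffer between devices and left it in fp32. A standalone illustrative sketch (not code from the PR) of why both operands need the cast before torch.bmm:

import torch

x = torch.randn(2, 4, 4, dtype=torch.bfloat16)   # activations in half precision
boft_P = torch.randn(2, 4, 4)                    # fp32 buffer, as produced upstream
butterfly = torch.randn(2, 4, 4)                 # fp32 block-diagonal factor

# Tensor.to(x) copies to x's device *and* casts to x's dtype.
boft_P = boft_P.to(x)
butterfly = butterfly.to(x)

out = torch.bmm(butterfly, boft_P.permute(0, 2, 1))  # dtypes match, no RuntimeError
print(out.dtype)                                     # torch.bfloat16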
24 changes: 24 additions & 0 deletions tests/test_gpu_examples.py
@@ -57,6 +57,7 @@
     prepare_model_for_kbit_training,
     replace_lora_weights_loftq,
 )
+from peft.tuners import boft
 from peft.utils import SAFETENSORS_WEIGHTS_NAME
 from peft.utils.loftq_utils import NFQuantizer
 from peft.utils.other import fsdp_auto_wrap_policy
@@ -3080,3 +3081,26 @@ def test_bnb_4bit_wrap_fsdp(self):
         init_process_group(world_size=1, rank=0)
         # check that this does not raise:
         FSDP(model, auto_wrap_policy=fsdp_auto_wrap_policy(model), use_orig_params=False, sync_module_states=True)
+
+
+class TestBOFT:
+    """
+    Test that we can correctly use half-precision models with BOFT.
+    """
+
+    @require_torch_gpu
+    @pytest.mark.single_gpu_tests
+    def test_boft_half_linear(self):
+        # Check that we can use BOFT with a model loaded in half precision
+        layer = torch.nn.Linear(160, 160).cuda()
+        layer = boft.Linear(layer, "layer", boft_n_butterfly_factor=2).to(dtype=torch.bfloat16)
+        x = torch.randn(160, 160, device="cuda", dtype=torch.bfloat16)
+        layer(x)  # does not raise
+
+    @require_torch_gpu
+    @pytest.mark.single_gpu_tests
+    def test_boft_half_conv(self):
+        conv = torch.nn.Conv2d(1, 1, 4).cuda()
+        conv = boft.Conv2d(conv, "conv", boft_n_butterfly_factor=2).to(dtype=torch.bfloat16)
+        x = torch.randn(1, 160, 160, device="cuda", dtype=torch.bfloat16)
+        conv(x)  # does not raise
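The new tests exercise the BOFT layers directly. For context, here is a hedged sketch of the end-to-end scenario they guard: applying BOFT through the high-level PEFT API to a model loaded in bf16. BOFTConfig and its boft_block_size / boft_n_butterfly_factor / target_modules arguments are assumptions based on the PEFT BOFT documentation and may need adjusting for your version.

import torch
from transformers import AutoModelForCausalLM
from peft import BOFTConfig, get_peft_model

# Load the base model directly in half precision.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype=torch.bfloat16)
vocab_size = model.config.vocab_size

config = BOFTConfig(
    boft_block_size=4,
    boft_n_butterfly_factor=2,
    target_modules=["q_proj", "v_proj"],
)
model = get_peft_model(model, config)

# Before this hotfix, the forward pass could fail with a dtype mismatch because the
# BOFT rotation/scale tensors were created in fp32; with the fix it runs in bf16.
input_ids = torch.randint(0, vocab_size, (1, 8))
model(input_ids)  # does not raise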