From 58caf44c01a940dad7916d384cad24bac10f55d2 Mon Sep 17 00:00:00 2001
From: Edenzzzz
Date: Sun, 14 Jul 2024 14:41:15 +0800
Subject: [PATCH 1/6] fix boft mixed precision

---
 src/peft/tuners/boft/layer.py | 21 ++++++++++++---------
 1 file changed, 12 insertions(+), 9 deletions(-)

diff --git a/src/peft/tuners/boft/layer.py b/src/peft/tuners/boft/layer.py
index 97a1baaa58..a186d0b061 100644
--- a/src/peft/tuners/boft/layer.py
+++ b/src/peft/tuners/boft/layer.py
@@ -27,6 +27,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.autograd import Function
+from torch.utils.cpp_extension import load

 from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge

@@ -77,9 +78,6 @@ def get_fbd_cuda():
     if _FBD_CUDA is not None:
         return _FBD_CUDA

-    # This import initializes cuda context and should thus be local, see issue 1877
-    from torch.utils.cpp_extension import load
-
     curr_dir = os.path.dirname(__file__)
     # need ninja to build the extension
     try:
@@ -594,8 +592,8 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
         elif self.merged:
             result = self.base_layer(x, *args, **kwargs)
         else:
-            boft_rotation = torch.eye(self.in_features, device=x.device)
-            boft_scale = torch.ones((int(self.out_features), 1), device=x.device)
+            boft_rotation = torch.eye(self.in_features, device=x.device, dtype=previous_dtype)
+            boft_scale = torch.ones((int(self.out_features), 1), device=x.device, dtype=previous_dtype)

             for active_adapter in self.active_adapters:
                 if active_adapter not in self.boft_R.keys():
@@ -615,27 +613,32 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                     orth_rotate_butterfly = orth_rotate_butterfly.squeeze(0)
                     block_diagonal_butterfly = torch.block_diag(*torch.unbind(orth_rotate_butterfly))
                     block_diagonal_butterfly = block_diagonal_butterfly.unsqueeze(0)
-
-                boft_P = self.boft_P.to(block_diagonal_butterfly.device)
+
+                boft_P = self.boft_P.to(device=block_diagonal_butterfly.device, dtype=previous_dtype)
+                block_diagonal_butterfly = block_diagonal_butterfly.to(previous_dtype)
                 butterfly_oft_mat_batch = torch.bmm(block_diagonal_butterfly, boft_P.permute(0, 2, 1))
                 butterfly_oft_mat_batch = torch.bmm(boft_P, butterfly_oft_mat_batch)
                 butterfly_oft_mat = butterfly_oft_mat_batch[0]

                 for i in range(1, butterfly_oft_mat_batch.shape[0]):
                     butterfly_oft_mat = butterfly_oft_mat_batch[i] @ butterfly_oft_mat
-
                 boft_rotation = butterfly_oft_mat @ boft_rotation
                 boft_scale = boft_s * boft_scale

+            x = x.to(self.get_base_layer().weight.data.dtype)
             orig_weight = self.get_base_layer().weight.data
             orig_weight = torch.transpose(orig_weight, 0, 1)
+            boft_rotation = boft_rotation.to(previous_dtype)
+            orig_weight = orig_weight.to(previous_dtype)
             rotated_weight = torch.mm(boft_rotation, orig_weight)
             rotated_weight = torch.transpose(rotated_weight, 0, 1)

             scaled_rotated_weight = rotated_weight * boft_scale
-
+            scaled_rotated_weight = scaled_rotated_weight.to(previous_dtype)
+            if self.base_layer.bias is not None:
+                self.base_layer.bias = self.base_layer.bias.to(previous_dtype)
             result = F.linear(input=x, weight=scaled_rotated_weight, bias=self.base_layer.bias)

         result = result.to(previous_dtype)

From 8fd57a189312463d07e831aa323ab360fe652673 Mon Sep 17 00:00:00 2001
From: Edenzzzz
Date: Mon, 5 Aug 2024 07:33:16 +0000
Subject: [PATCH 2/6] add tests; fix conv2d

---
 src/peft/tuners/boft/layer.py | 20 +++++++++++++-------
 tests/test_common_gpu.py      | 16 +++++++++++++++-
 2 files changed, 28 insertions(+), 8 deletions(-)

diff --git a/src/peft/tuners/boft/layer.py b/src/peft/tuners/boft/layer.py
index a186d0b061..8a7fa0b826 100644
--- a/src/peft/tuners/boft/layer.py
+++ b/src/peft/tuners/boft/layer.py
@@ -77,7 +77,9 @@ def get_fbd_cuda():

     if _FBD_CUDA is not None:
         return _FBD_CUDA
-
+    # This import initializes cuda context and should thus be local, see issue 1877
+    from torch.utils.cpp_extension import load
+
     curr_dir = os.path.dirname(__file__)
     # need ninja to build the extension
     try:
@@ -614,17 +616,18 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                     block_diagonal_butterfly = torch.block_diag(*torch.unbind(orth_rotate_butterfly))
                     block_diagonal_butterfly = block_diagonal_butterfly.unsqueeze(0)

-                boft_P = self.boft_P.to(device=block_diagonal_butterfly.device, dtype=previous_dtype)
-                block_diagonal_butterfly = block_diagonal_butterfly.to(previous_dtype)
+                # The BOFT author's cayley_batch, dropout and FastBlockDiag ONLY return fp32 outputs.
+                boft_P = self.boft_P.to(x)
+                block_diagonal_butterfly = block_diagonal_butterfly.to(x)
                 butterfly_oft_mat_batch = torch.bmm(block_diagonal_butterfly, boft_P.permute(0, 2, 1))
                 butterfly_oft_mat_batch = torch.bmm(boft_P, butterfly_oft_mat_batch)
                 butterfly_oft_mat = butterfly_oft_mat_batch[0]

                 for i in range(1, butterfly_oft_mat_batch.shape[0]):
                     butterfly_oft_mat = butterfly_oft_mat_batch[i] @ butterfly_oft_mat
+
                 boft_rotation = butterfly_oft_mat @ boft_rotation
                 boft_scale = boft_s * boft_scale
-
             x = x.to(self.get_base_layer().weight.data.dtype)
@@ -910,9 +913,11 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
             result = self.base_layer(x, *args, **kwargs)
         else:
             boft_rotation = torch.eye(
-                self.in_features * self.base_layer.kernel_size[0] * self.base_layer.kernel_size[0], device=x.device
+                self.in_features * self.base_layer.kernel_size[0] * self.base_layer.kernel_size[0],
+                device=x.device,
+                dtype=x.dtype
             )
-            boft_scale = torch.ones((1, int(self.out_features)), device=x.device)
+            boft_scale = torch.ones((1, int(self.out_features)), device=x.device, dtype=x.dtype)

             for active_adapter in self.active_adapters:
                 if active_adapter not in self.boft_R.keys():
@@ -933,7 +938,8 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                     block_diagonal_butterfly = torch.block_diag(*torch.unbind(orth_rotate_butterfly))
                     block_diagonal_butterfly = block_diagonal_butterfly.unsqueeze(0)

-                boft_P = self.boft_P.to(block_diagonal_butterfly.device)
+                boft_P = self.boft_P.to(x)
+                block_diagonal_butterfly = block_diagonal_butterfly.to(x)
                 butterfly_oft_mat_batch = torch.bmm(block_diagonal_butterfly, boft_P.permute(0, 2, 1))
                 butterfly_oft_mat_batch = torch.bmm(boft_P, butterfly_oft_mat_batch)
                 butterfly_oft_mat = butterfly_oft_mat_batch[0]
diff --git a/tests/test_common_gpu.py b/tests/test_common_gpu.py
index 6039d7d850..81e3b3b555 100644
--- a/tests/test_common_gpu.py
+++ b/tests/test_common_gpu.py
@@ -50,7 +50,7 @@
 )
 from peft.import_utils import is_bnb_4bit_available, is_bnb_available
 from peft.tuners.lora.config import LoraRuntimeConfig
-
+from peft.tuners.boft.layer import Linear, Conv2d
 from .testing_utils import require_bitsandbytes, require_torch_gpu, require_torch_multi_gpu


@@ -1135,6 +1135,20 @@ def test_dora_ephemeral_gpu_offload(self):
         # The results should be the same
         assert torch.allclose(out_peft_model_cpu, out_peft_model_ego)

+    @require_torch_gpu
+    @pytest.mark.single_gpu_tests
+    def test_boft_half(self):
+        # Check that we can use BoFT with model loaded in half precision
+        layer = nn.Linear(160, 160).cuda()
+        layer = Linear(layer, "layer", boft_n_butterfly_factor=2).to(dtype=torch.bfloat16)
+        x = torch.randn(160, 160, device="cuda", dtype=torch.bfloat16)
+        x = layer(x)
+
+        conv = nn.Conv2d(1, 1, 4).cuda()
+        conv = Conv2d(conv, "conv", boft_n_butterfly_factor=2).to(dtype=torch.bfloat16)
+        x = torch.randn(1, 160, 160, device="cuda", dtype=torch.bfloat16)
+        x = conv(x)
+
     @require_torch_gpu
     @require_torch_multi_gpu
     @pytest.mark.multi_gpu_tests

From 8f618b5a937e0f0590cf6b0a7b6e8ae500b06800 Mon Sep 17 00:00:00 2001
From: Edenzzzz
Date: Mon, 5 Aug 2024 07:35:28 +0000
Subject: [PATCH 3/6] fix import

---
 src/peft/tuners/boft/layer.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/peft/tuners/boft/layer.py b/src/peft/tuners/boft/layer.py
index 8a7fa0b826..b27de5fa64 100644
--- a/src/peft/tuners/boft/layer.py
+++ b/src/peft/tuners/boft/layer.py
@@ -27,7 +27,6 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.autograd import Function
-from torch.utils.cpp_extension import load

 from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge

@@ -77,6 +76,7 @@ def get_fbd_cuda():

     if _FBD_CUDA is not None:
         return _FBD_CUDA
+
     # This import initializes cuda context and should thus be local, see issue 1877
     from torch.utils.cpp_extension import load

@@ -639,6 +639,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
             rotated_weight = torch.transpose(rotated_weight, 0, 1)

             scaled_rotated_weight = rotated_weight * boft_scale
+
             scaled_rotated_weight = scaled_rotated_weight.to(previous_dtype)
             if self.base_layer.bias is not None:
                 self.base_layer.bias = self.base_layer.bias.to(previous_dtype)

From 156f14649c6ca6f81c018d129d8a05605554503e Mon Sep 17 00:00:00 2001
From: Edenzzzz
Date: Mon, 5 Aug 2024 12:05:09 +0000
Subject: [PATCH 4/6] fix style

---
 src/peft/tuners/boft/layer.py | 16 ++++++++--------
 tests/test_common_gpu.py      | 16 +----------------
 tests/test_gpu_examples.py    | 22 ++++++++++++++++++++++
 3 files changed, 31 insertions(+), 23 deletions(-)

diff --git a/src/peft/tuners/boft/layer.py b/src/peft/tuners/boft/layer.py
index b27de5fa64..5a4e0208a8 100644
--- a/src/peft/tuners/boft/layer.py
+++ b/src/peft/tuners/boft/layer.py
@@ -76,10 +76,10 @@ def get_fbd_cuda():

     if _FBD_CUDA is not None:
         return _FBD_CUDA
-
+
     # This import initializes cuda context and should thus be local, see issue 1877
     from torch.utils.cpp_extension import load
-
+
     curr_dir = os.path.dirname(__file__)
     # need ninja to build the extension
     try:
@@ -594,8 +594,8 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
         elif self.merged:
             result = self.base_layer(x, *args, **kwargs)
         else:
-            boft_rotation = torch.eye(self.in_features, device=x.device, dtype=previous_dtype)
-            boft_scale = torch.ones((int(self.out_features), 1), device=x.device, dtype=previous_dtype)
+            boft_rotation = torch.eye(self.in_features, device=x.device, dtype=previous_dtype)
+            boft_scale = torch.ones((int(self.out_features), 1), device=x.device, dtype=previous_dtype)

             for active_adapter in self.active_adapters:
                 if active_adapter not in self.boft_R.keys():
@@ -615,9 +615,9 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                     orth_rotate_butterfly = orth_rotate_butterfly.squeeze(0)
                     block_diagonal_butterfly = torch.block_diag(*torch.unbind(orth_rotate_butterfly))
                     block_diagonal_butterfly = block_diagonal_butterfly.unsqueeze(0)
-
+
                 # The BOFT author's cayley_batch, dropout and FastBlockDiag ONLY return fp32 outputs.
-                boft_P = self.boft_P.to(x)
+                boft_P = self.boft_P.to(x)
                 block_diagonal_butterfly = block_diagonal_butterfly.to(x)
                 butterfly_oft_mat_batch = torch.bmm(block_diagonal_butterfly, boft_P.permute(0, 2, 1))
                 butterfly_oft_mat_batch = torch.bmm(boft_P, butterfly_oft_mat_batch)
@@ -625,7 +625,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:

                 for i in range(1, butterfly_oft_mat_batch.shape[0]):
                     butterfly_oft_mat = butterfly_oft_mat_batch[i] @ butterfly_oft_mat
-
+
                 boft_rotation = butterfly_oft_mat @ boft_rotation
                 boft_scale = boft_s * boft_scale

@@ -639,7 +639,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
             rotated_weight = torch.transpose(rotated_weight, 0, 1)

             scaled_rotated_weight = rotated_weight * boft_scale
-
+
             scaled_rotated_weight = scaled_rotated_weight.to(previous_dtype)
             if self.base_layer.bias is not None:
                 self.base_layer.bias = self.base_layer.bias.to(previous_dtype)
diff --git a/tests/test_common_gpu.py b/tests/test_common_gpu.py
index 81e3b3b555..6039d7d850 100644
--- a/tests/test_common_gpu.py
+++ b/tests/test_common_gpu.py
@@ -50,7 +50,7 @@
 )
 from peft.import_utils import is_bnb_4bit_available, is_bnb_available
 from peft.tuners.lora.config import LoraRuntimeConfig
-from peft.tuners.boft.layer import Linear, Conv2d
+
 from .testing_utils import require_bitsandbytes, require_torch_gpu, require_torch_multi_gpu


@@ -1135,20 +1135,6 @@ def test_dora_ephemeral_gpu_offload(self):
         # The results should be the same
         assert torch.allclose(out_peft_model_cpu, out_peft_model_ego)

-    @require_torch_gpu
-    @pytest.mark.single_gpu_tests
-    def test_boft_half(self):
-        # Check that we can use BoFT with model loaded in half precision
-        layer = nn.Linear(160, 160).cuda()
-        layer = Linear(layer, "layer", boft_n_butterfly_factor=2).to(dtype=torch.bfloat16)
-        x = torch.randn(160, 160, device="cuda", dtype=torch.bfloat16)
-        x = layer(x)
-
-        conv = nn.Conv2d(1, 1, 4).cuda()
-        conv = Conv2d(conv, "conv", boft_n_butterfly_factor=2).to(dtype=torch.bfloat16)
-        x = torch.randn(1, 160, 160, device="cuda", dtype=torch.bfloat16)
-        x = conv(x)
-
     @require_torch_gpu
     @require_torch_multi_gpu
     @pytest.mark.multi_gpu_tests
diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py
index df544b606c..4bb0116bbd 100644
--- a/tests/test_gpu_examples.py
+++ b/tests/test_gpu_examples.py
@@ -57,6 +57,7 @@
     prepare_model_for_kbit_training,
     replace_lora_weights_loftq,
 )
+from peft.tuners import boft
 from peft.utils import SAFETENSORS_WEIGHTS_NAME
 from peft.utils.loftq_utils import NFQuantizer
 from peft.utils.other import fsdp_auto_wrap_policy
@@ -3076,3 +3077,24 @@ def test_bnb_4bit_wrap_fsdp(self):
         init_process_group(world_size=1, rank=0)
         # check that this does not raise:
         FSDP(model, auto_wrap_policy=fsdp_auto_wrap_policy(model), use_orig_params=False, sync_module_states=True)
+
+@require_torch_gpu
+class TestBOFT:
+    """
+    Test that we can correctly use half-precision models with BOFT.
+ """ + + @pytest.mark.single_gpu_tests + def test_boft_half_linear(self): + # Check that we can use BoFT with model loaded in half precision + layer = torch.nn.Linear(160, 160).cuda() + layer = boft.Linear(layer, "layer", boft_n_butterfly_factor=2).to(dtype=torch.bfloat16) + x = torch.randn(160, 160, device="cuda", dtype=torch.bfloat16) + layer(x) # does not raise + + @pytest.mark.single_gpu_tests + def test_boft_half_conv(self): + conv = torch.nn.Conv2d(1, 1, 4).cuda() + conv = boft.Conv2d(conv, "conv", boft_n_butterfly_factor=2).to(dtype=torch.bfloat16) + x = torch.randn(1, 160, 160, device="cuda", dtype=torch.bfloat16) + conv(x) # does not raise From 18db3c6895a8ac09afe0c4a9193b974feee1c749 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Tue, 6 Aug 2024 03:51:03 +0000 Subject: [PATCH 5/6] ruff 0.4.1 --- src/peft/tuners/boft/layer.py | 2 +- tests/test_gpu_examples.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/peft/tuners/boft/layer.py b/src/peft/tuners/boft/layer.py index 5a4e0208a8..9ed7a5255a 100644 --- a/src/peft/tuners/boft/layer.py +++ b/src/peft/tuners/boft/layer.py @@ -916,7 +916,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: boft_rotation = torch.eye( self.in_features * self.base_layer.kernel_size[0] * self.base_layer.kernel_size[0], device=x.device, - dtype=x.dtype + dtype=x.dtype, ) boft_scale = torch.ones((1, int(self.out_features)), device=x.device, dtype=x.dtype) diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py index 4bb0116bbd..4b0893be7d 100644 --- a/tests/test_gpu_examples.py +++ b/tests/test_gpu_examples.py @@ -3078,6 +3078,7 @@ def test_bnb_4bit_wrap_fsdp(self): # check that this does not raise: FSDP(model, auto_wrap_policy=fsdp_auto_wrap_policy(model), use_orig_params=False, sync_module_states=True) + @require_torch_gpu class TestBOFT: """ @@ -3090,11 +3091,11 @@ def test_boft_half_linear(self): layer = torch.nn.Linear(160, 160).cuda() layer = boft.Linear(layer, "layer", boft_n_butterfly_factor=2).to(dtype=torch.bfloat16) x = torch.randn(160, 160, device="cuda", dtype=torch.bfloat16) - layer(x) # does not raise + layer(x) # does not raise @pytest.mark.single_gpu_tests def test_boft_half_conv(self): conv = torch.nn.Conv2d(1, 1, 4).cuda() conv = boft.Conv2d(conv, "conv", boft_n_butterfly_factor=2).to(dtype=torch.bfloat16) x = torch.randn(1, 160, 160, device="cuda", dtype=torch.bfloat16) - conv(x) # does not raise + conv(x) # does not raise From 9f9f7bc8f672b78692e025075296033173696b48 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Tue, 6 Aug 2024 11:44:48 +0000 Subject: [PATCH 6/6] decorator --- tests/test_gpu_examples.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py index aac42d9232..fe0775eef0 100644 --- a/tests/test_gpu_examples.py +++ b/tests/test_gpu_examples.py @@ -3083,12 +3083,12 @@ def test_bnb_4bit_wrap_fsdp(self): FSDP(model, auto_wrap_policy=fsdp_auto_wrap_policy(model), use_orig_params=False, sync_module_states=True) -@require_torch_gpu class TestBOFT: """ Test that we can correctly use half-precision models with BOFT. 
""" + @require_torch_gpu @pytest.mark.single_gpu_tests def test_boft_half_linear(self): # Check that we can use BoFT with model loaded in half precision @@ -3097,6 +3097,7 @@ def test_boft_half_linear(self): x = torch.randn(160, 160, device="cuda", dtype=torch.bfloat16) layer(x) # does not raise + @require_torch_gpu @pytest.mark.single_gpu_tests def test_boft_half_conv(self): conv = torch.nn.Conv2d(1, 1, 4).cuda()