FIX: Don't eagerly import bnb for LoftQ #1683

Merged
11 changes: 7 additions & 4 deletions src/peft/utils/loftq_utils.py
@@ -31,10 +31,6 @@
 from peft.import_utils import is_bnb_4bit_available, is_bnb_available
 
 
-if is_bnb_available():
-    import bitsandbytes as bnb
-
-
 class NFQuantizer:
     def __init__(self, num_bits=2, device="cuda", method="normal", block_size=64, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -192,6 +188,11 @@ def _low_rank_decomposition(weight, reduced_rank=32):
 
 @torch.no_grad()
 def loftq_init(weight: Union[torch.Tensor, torch.nn.Parameter], num_bits: int, reduced_rank: int, num_iter=1):
+    if is_bnb_available():
+        import bitsandbytes as bnb
+    else:
+        raise ValueError("bitsandbytes is not available, please install it to use LoftQ.")
+
     if num_bits not in [2, 4, 8]:
         raise ValueError("Only support 2, 4, 8 bits quantization")
     if num_iter <= 0:
@@ -239,6 +240,8 @@ def loftq_init(weight: Union[torch.Tensor, torch.nn.Parameter], num_bits: int, r
 
 @torch.no_grad()
 def _loftq_init_new(qweight, weight, num_bits: int, reduced_rank: int):
+    import bitsandbytes as bnb
+
     if num_bits != 4:
         raise ValueError("Only 4 bit quantization supported at the moment.")
     if not is_bnb_4bit_available():
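The fix is the standard lazy-import pattern: the bitsandbytes import moves from module scope into the functions that actually need it, so importing peft no longer triggers (or fails on) a bitsandbytes import on machines where it is not installed. Below is a minimal, self-contained sketch of the pattern, not the library's code: it assumes the availability check is done with importlib.util.find_spec (the real helpers live in peft.import_utils), and run_loftq is a hypothetical stand-in for loftq_init.

import importlib.util


def bnb_available() -> bool:
    # find_spec only looks up the module spec on the search path; it does
    # not execute the module, so the check is cheap and side-effect free.
    return importlib.util.find_spec("bitsandbytes") is not None


def run_loftq(num_bits: int = 4) -> None:
    # Deferred import: bitsandbytes is pulled in only when LoftQ is
    # actually used, never at package import time.
    if bnb_available():
        import bitsandbytes as bnb  # noqa: F401
    else:
        raise ValueError("bitsandbytes is not available, please install it to use LoftQ.")
    # ... quantization logic using bnb would follow here ...

A function-local import is re-executed on every call, but after the first call Python resolves it from sys.modules, so the repeated cost is a dictionary lookup rather than a full module load.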