ENH: Updates for upcoming BNB Int8 release #2245

Merged 2 commits on Dec 5, 2024
1 change: 0 additions & 1 deletion src/peft/tuners/adalora/model.py
@@ -174,7 +174,6 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
     kwargs.update(
         {
             "has_fp16_weights": target_base_layer.state.has_fp16_weights,
-            "memory_efficient_backward": target_base_layer.state.memory_efficient_backward,
             "threshold": target_base_layer.state.threshold,
             "index": target_base_layer.index,
         }
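Note: the same one-line deletion recurs in src/peft/tuners/ia3/model.py, src/peft/tuners/lora/bnb.py, and src/peft/tuners/vera/model.py below. The upcoming bitsandbytes int8 release drops memory_efficient_backward from the Linear8bitLt state, so PEFT stops forwarding it. For comparison, a minimal sketch of a version-tolerant way to build these kwargs, assuming only that the attribute may or may not exist (hypothetical, not part of this PR):

state = target_base_layer.state
kwargs.update(
    {
        "has_fp16_weights": state.has_fp16_weights,
        "threshold": state.threshold,
        "index": target_base_layer.index,
    }
)
# Forward the deprecated flag only on older bitsandbytes versions
# where the state object still exposes it.
if hasattr(state, "memory_efficient_backward"):
    kwargs["memory_efficient_backward"] = state.memory_efficient_backward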
1 change: 0 additions & 1 deletion src/peft/tuners/ia3/model.py
@@ -103,7 +103,6 @@ def _create_new_module(ia3_config, adapter_name, target, **kwargs):
     eightbit_kwargs.update(
         {
             "has_fp16_weights": target_base_layer.state.has_fp16_weights,
-            "memory_efficient_backward": target_base_layer.state.memory_efficient_backward,
             "threshold": target_base_layer.state.threshold,
             "index": target_base_layer.index,
         }
1 change: 0 additions & 1 deletion src/peft/tuners/lora/bnb.py
@@ -288,7 +288,6 @@ def dispatch_bnb_8bit(target: torch.nn.Module, adapter_name: str, **kwargs):
     eightbit_kwargs.update(
         {
             "has_fp16_weights": target.state.has_fp16_weights,
-            "memory_efficient_backward": target.state.memory_efficient_backward,
             "threshold": target.state.threshold,
             "index": target.index,
         }
1 change: 0 additions & 1 deletion src/peft/tuners/vera/model.py
@@ -314,7 +314,6 @@ def _create_new_module(vera_config, vera_A, vera_B, adapter_name, target, **kwargs):
     eightbit_kwargs.update(
         {
             "has_fp16_weights": target_base_layer.state.has_fp16_weights,
-            "memory_efficient_backward": target_base_layer.state.memory_efficient_backward,
             "threshold": target_base_layer.state.threshold,
             "index": target_base_layer.index,
         }
14 changes: 7 additions & 7 deletions src/peft/utils/integrations.py
@@ -101,13 +101,13 @@ def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
     if state.SCB is None:
         state.SCB = weight.SCB

-    im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
-    im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
-    im, Sim = bnb.functional.transform(im, "col32")
-    if state.CxB is None:
-        state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
-    out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
-    dequantized = bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
+    if hasattr(bnb.functional, "int8_vectorwise_dequant"):
+        # Use bitsandbytes API if available (requires v0.45.0+)
+        dequantized = bnb.functional.int8_vectorwise_dequant(weight.data, state.SCB)
+    else:
+        # Multiply by (scale/127) to dequantize.
+        dequantized = weight.data * state.SCB.view(-1, 1) * 7.874015718698502e-3

     if is_cpu:
         dequantized = dequantized.to(device)
     return dequantized
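The fallback branch makes the int8 scheme explicit: LLM.int8 stores each weight row as int8 codes together with a per-row absmax scale in state.SCB, so dequantization is weight * SCB / 127; the constant 7.874015718698502e-3 is 1/127 rounded to float32 precision. A minimal self-contained sketch of that arithmetic (illustrative shapes and values, not the PR's code):

import torch

# Per-row int8 dequantization: codes in [-127, 127] scaled by the
# row's absmax / 127.
int8_weight = torch.randint(-127, 128, (4, 8), dtype=torch.int8)
SCB = torch.rand(4) * 10.0  # per-row absmax scales (illustrative)

dequantized = int8_weight * SCB.view(-1, 1) * 7.874015718698502e-3
reference = int8_weight.float() * SCB.view(-1, 1) / 127.0
assert torch.allclose(dequantized, reference)  # equal up to float32 rounding

The hasattr check keeps older bitsandbytes installations working, while v0.45.0+ routes through the library's own int8_vectorwise_dequant path instead of the hand-rolled igemmlt round trip.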
17 changes: 8 additions & 9 deletions tests/test_common_gpu.py
@@ -767,8 +767,8 @@ def test_8bit_merge_lora(self):
         with torch.inference_mode():
             out_after_merge = F.softmax(model(random_input).logits, dim=-1)

-        atol = 0.01
-        rtol = 10
+        atol = 1e-3
+        rtol = 1
         assert not torch.allclose(out_base, out_before_merge, atol=atol, rtol=rtol)
         assert torch.allclose(out_before_merge, out_after_merge, atol=atol, rtol=rtol)
         assert isinstance(model, PeftModel)
@@ -803,8 +803,8 @@ def test_8bit_merge_and_disable_lora(self):
         with torch.inference_mode():
             out_after = F.softmax(model(random_input).logits, dim=-1)

-        atol = 0.01
-        rtol = 10
+        atol = 1e-3
+        rtol = 1
         assert not torch.allclose(out_base, out_before, atol=atol, rtol=rtol)
         assert torch.allclose(out_base, out_after, atol=atol, rtol=rtol)
         assert isinstance(model, PeftModel)
@@ -838,8 +838,8 @@ def test_8bit_merge_lora_with_bias(self):
         with torch.inference_mode():
             out_after_merge = F.softmax(model(random_input).logits, dim=-1)

-        atol = 0.01
-        rtol = 10
+        atol = 1e-3
+        rtol = 1
         assert not torch.allclose(out_base, out_before_merge, atol=atol, rtol=rtol)
         assert torch.allclose(out_before_merge, out_after_merge, atol=atol, rtol=rtol)

@@ -1294,9 +1294,8 @@ def test_8bit_dora_merging(self):
         model = model.merge_and_unload()
         out_unloaded = F.softmax(model(random_input).logits, dim=-1)

-        # 8bit merging less precise than 4bit
-        atol = 0.01
-        rtol = 10
+        atol = 1e-3
+        rtol = 1
         # sanity check that using DoRA changes the results
         assert not torch.allclose(out_base, out_dora, atol=atol, rtol=rtol)
         assert torch.allclose(out_dora, out_merged, atol=atol, rtol=rtol)
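A note on the tolerance changes above: torch.allclose(a, b, rtol, atol) checks |a - b| <= atol + rtol * |b| elementwise, so the old pair (atol=0.01, rtol=10) allowed a deviation of 0.01 plus ten times the reference value, which is nearly vacuous for softmax outputs in [0, 1]. The new pair (atol=1e-3, rtol=1) is markedly stricter, which the more accurate int8 dequantization now supports. A small illustration (hypothetical values):

import torch

a = torch.tensor([0.990, 0.010])  # e.g. softmax probabilities
b = torch.tensor([0.998, 0.002])

# Old tolerances pass on every element: 0.008 <= 0.01 + 10 * |b|.
print(torch.allclose(a, b, atol=0.01, rtol=10))  # True
# New tolerances reject the small entry: 0.008 > 1e-3 + 1 * 0.002.
print(torch.allclose(a, b, atol=1e-3, rtol=1))   # False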