
Commit

FIX low_cpu_mem_usage=True with 8bit bitsandbytes (huggingface#2325)
There was a bug in PEFT that occurred when trying to use the
low_cpu_mem_usage=True option with 8-bit bitsandbytes quantized models.
This bug is now fixed.
BenjaminBossan authored Jan 14, 2025
1 parent 1e8bc60 commit 3289134
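
For context, this is the kind of call pattern the fix unblocks: loading an 8-bit bitsandbytes-quantized base model and attaching a trained LoRA adapter with low_cpu_mem_usage=True. A minimal sketch, not part of the commit, assuming a CUDA GPU with bitsandbytes installed; the adapter path is a hypothetical placeholder:

    from transformers import AutoModelForCausalLM, BitsAndBytesConfig
    from peft import PeftModel

    base_model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
    adapter_path = "path/to/trained-lora-adapter"  # hypothetical placeholder

    # Quantize the base model to 8 bit with bitsandbytes.
    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    model = AutoModelForCausalLM.from_pretrained(base_model_id, quantization_config=quantization_config)

    # low_cpu_mem_usage=True first creates the adapter weights on the meta device
    # and materializes them only when the adapter state_dict is loaded; with 8-bit
    # bnb base models this used to raise.
    model = PeftModel.from_pretrained(model, adapter_path, low_cpu_mem_usage=True)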
Showing 2 changed files with 32 additions and 9 deletions.
7 changes: 0 additions & 7 deletions src/peft/tuners/lora/model.py
@@ -247,13 +247,6 @@ def _replace_module(self, parent, child_name, new_module, child):
         if hasattr(child, "base_layer"):
             child = child.base_layer

-        if getattr(child, "state", None) is not None:
-            if hasattr(new_module, "base_layer"):
-                new_module.base_layer.state = child.state
-            else:
-                new_module.state = child.state
-            new_module.to(child.weight.device)
-
         meta = torch.device("meta")
         # dispatch to correct device
         for name, module in new_module.named_modules():
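
The removed block is the likely culprit: bitsandbytes 8-bit layers carry a state attribute, so 8-bit models always entered this path, and the unconditional new_module.to(child.weight.device) cannot run while the fresh LoRA parameters still sit on the meta device, which is where low_cpu_mem_usage=True puts them. A minimal sketch of that meta-device behavior, independent of PEFT:

    import torch
    import torch.nn as nn

    # With low_cpu_mem_usage=True, new adapter layers are created on the meta
    # device: shapes and dtypes exist, but no storage is allocated.
    with torch.device("meta"):
        lora_A = nn.Linear(16, 8, bias=False)

    print(lora_A.weight.device)  # meta

    # .to() has no data to copy, so it raises NotImplementedError -- the same
    # kind of call the removed block made unconditionally.
    try:
        lora_A.to("cpu")
    except NotImplementedError as exc:
        print(exc)  # Cannot copy out of meta tensor; no data! ...

    # Materialization must instead go through to_empty(), after which real
    # weights can be loaded via load_state_dict.
    lora_A = lora_A.to_empty(device="cpu")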
34 changes: 32 additions & 2 deletions tests/test_gpu_examples.py
@@ -3870,15 +3870,17 @@ def test_p_tuning_exactly_reproducible_after_loading(self, tmp_path):
 class TestLowCpuMemUsageDifferentDevices:
     """Test for the low CPU memory usage option for loading PEFT models.

-    There are already tests for this in test_initialization.py but here we want to specifically test diverging devices
-    for the model and state_dict.
+    There are already tests for low_cpu_mem_usage=True in test_initialization.py but here we want to run tests that
+    require a GPU.
     """

     model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
     device = infer_device()

     @pytest.mark.parametrize("device_model, device_sd", [("cpu", "cuda"), ("cuda", "cpu")])
     def test_low_cpu_mem_usage_model_model_on_gpu_state_dict_on_cpu_works(self, device_model, device_sd):
+        # specifically test diverging devices for the model and state_dict
         inputs = {"input_ids": torch.randint(0, 100, (1, 10)), "attention_mask": torch.ones(1, 10)}
         inputs = {k: v.to(device_model) for k, v in inputs.items()}
@@ -3912,6 +3914,34 @@ def test_low_cpu_mem_usage_model_model_on_gpu_state_dict_on_cpu_works(self, device_model, device_sd):
         assert torch.allclose(logits_low_cpu_mem, logits_not_low_cpu_mem)
         assert {p.device.type for p in model.parameters()} == {device_model}

+    @pytest.mark.parametrize("quantization_method", ["bnb-4bit", "bnb-8bit"])
+    def test_low_cpu_mem_usage_with_quantization(self, quantization_method):
+        # Ensure that low_cpu_mem_usage works with quantization
+        # See also https://github.com/huggingface/diffusers/issues/10550
+        if quantization_method == "bnb-4bit":
+            quantization_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_compute_dtype=torch.float32,
+                bnb_4bit_quant_storage=torch.float32,
+                bnb_4bit_use_double_quant=True,
+            )
+        elif quantization_method == "bnb-8bit":
+            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+        else:
+            raise ValueError(f"Unknown quantization method {quantization_method}")
+
+        model = AutoModelForCausalLM.from_pretrained(self.model_id, quantization_config=quantization_config)
+        if model.device.type != self.device:
+            # calling model.to("cuda") with 8 bit bnb raises an error, thus guard against it
+            model = model.to(self.device)
+
+        lora_config = LoraConfig(init_lora_weights=False, target_modules="all-linear")
+
+        # We use get_peft_model with low_cpu_mem_usage=True here. This is not typically done in practice (the option is
+        # mostly interesting for loading trained adapters), but it does the job for testing purposes.
+        model = get_peft_model(model, lora_config, low_cpu_mem_usage=True)  # this should not raise
+        assert {p.device.type for p in model.parameters()} == {self.device, "meta"}
+

 class TestEvaInitializationGPU:
     """GPU tests for the Eva initialization method."""
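
To exercise the new test locally, something like pytest tests/test_gpu_examples.py -k test_low_cpu_mem_usage_with_quantization should cover both parametrizations, assuming a CUDA GPU and bitsandbytes are installed. Note the final assertion: the quantized base weights end up on the accelerator while the freshly created, not-yet-materialized LoRA weights stay on the meta device, hence the expected device set {self.device, "meta"}.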
