FIX huggingface#2295: Warn when user reloads modified model
When modifying a model with `get_peft_model` that was already modified
in the same way, even specifying a different config may not change
the trainable parameter count, e.g. when specifying target modules that
are only a subset of the previous target modules.

With this patch a warning will be issued with a hint to `.unload()`
when calling `get_peft_model` on an already modified model.
nemo committed Jan 6, 2025
1 parent 6a533b7 commit f32e517
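As a minimal sketch of the scenario (the toy model and target module name are taken from the new test below; `get_peft_model`, `LoraConfig`, and `.unload()` are the PEFT APIs involved):

    import torch
    from peft import LoraConfig, get_peft_model

    base_model = torch.nn.Sequential(torch.nn.Linear(10, 2))
    peft_model = get_peft_model(base_model, LoraConfig(target_modules="0"))

    # base_model was mutated in place and now carries LoRA layers, so wrapping
    # it again (even with another config) may not change the trainable
    # parameter count; with this patch, such a call emits a UserWarning
    # instead of failing silently. The intended pattern is to unload first:
    peft_model.unload()  # strips the injected layers from base_model
    peft_model = get_peft_model(base_model, LoraConfig(target_modules="0"))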
Showing 2 changed files with 55 additions and 0 deletions.
16 changes: 16 additions & 0 deletions src/peft/mapping.py
@@ -31,6 +31,7 @@
    PeftModelForSeq2SeqLM,
    PeftModelForSequenceClassification,
    PeftModelForTokenClassification,
    get_layer_status,
)
from .tuners import (
    AdaLoraConfig,
@@ -181,6 +182,21 @@ def get_peft_model(
    new_name = model.__dict__.get("name_or_path", None)
    peft_config.base_model_name_or_path = new_name

    # Especially in notebook environments, a user may want to experiment with
    # different configuration values. However, a new config will likely have
    # no effect on an already initialized PEFT model. The best we can do is
    # warn the user about it.
    try:
        if len(get_layer_status(model)) > 0:
            warnings.warn(
                "You are trying to modify a model with PEFT for a "
                "second time. If you want to reload the model with a "
                "different config, make sure to call `.unload()` before."
            )
    except ValueError:
        # not a PEFT model or no adapters in use
        pass

    if (old_name is not None) and (old_name != new_name):
        warnings.warn(
            f"The PEFT config's `base_model_name_or_path` was renamed from '{old_name}' to '{new_name}'. "
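The check relies on `get_layer_status`, which raises a `ValueError` for a model without any PEFT layers, hence the try/except above. A small sketch of that detection behavior (toy model as above; assuming `get_layer_status` accepts a bare module with injected tuner layers, as the patch's own call does):

    import torch
    from peft import LoraConfig, get_layer_status, get_peft_model

    model = torch.nn.Sequential(torch.nn.Linear(10, 2))
    try:
        get_layer_status(model)  # no tuner layers yet
    except ValueError:
        print("not a PEFT model")  # the except branch taken in the patch

    get_peft_model(model, LoraConfig(target_modules="0"))
    # At least one tuner layer is now reported, so the warning path triggers.
    print(len(get_layer_status(model)) > 0)  # True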
39 changes: 39 additions & 0 deletions tests/test_mapping.py
@@ -0,0 +1,39 @@
import pytest
import torch


class TestGetPeftModel:
    RELOAD_WARNING_EXPECTED_MATCH = r"You are trying to modify a model .*"

    @pytest.fixture
    def get_peft_model(self):
        from peft import get_peft_model

        return get_peft_model

    @pytest.fixture
    def lora_config(self):
        from peft import LoraConfig

        return LoraConfig(target_modules="0")

    @pytest.fixture
    def base_model(self):
        return torch.nn.Sequential(torch.nn.Linear(10, 2))

    def test_get_peft_model_warns_when_reloading_model(self, get_peft_model, lora_config, base_model):
        get_peft_model(base_model, lora_config)

        # The second call targets a model that already carries PEFT layers
        # and must therefore trigger the new warning.
        with pytest.warns(UserWarning, match=self.RELOAD_WARNING_EXPECTED_MATCH):
            get_peft_model(base_model, lora_config)

    def test_get_peft_model_proposed_fix_in_warning_help(self, get_peft_model, lora_config, base_model, recwarn):
        # After `.unload()`, wrapping the restored base model again should
        # not trigger the reload warning.
        peft_model = get_peft_model(base_model, lora_config)
        peft_model.unload()
        get_peft_model(base_model, lora_config)

        # recwarn records every warning raised above; fail if any of them
        # matches the reload warning.
        warning_checker = pytest.warns(UserWarning, match=self.RELOAD_WARNING_EXPECTED_MATCH)

        for warning in recwarn:
            if warning_checker.matches(warning):
                pytest.fail("Warning raised even though model was unloaded.")
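The new tests can be run on their own with, e.g.:

    python -m pytest tests/test_mapping.py -q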
