forked from huggingface/peft
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
FIX huggingface#2295: Warn when user reloads modified model
When modifying a model with `get_peft_model` that was already modified in the same way, even specifying a different config may not change the trainable parameter count, e.g. when specifying target modules that are only a subset of the previous target modules. With this patch a warning will be issued with a hint to `.unload()` when calling `get_peft_model` on an already modified model.
- Loading branch information
nemo
committed
Jan 6, 2025
1 parent
6a533b7
commit f32e517
Showing
2 changed files
with
55 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import pytest | ||
import torch | ||
|
||
|
||
class TestGetPeftModel: | ||
RELOAD_WARNING_EXPECTED_MATCH = r"You are trying to modify a model .*" | ||
|
||
@pytest.fixture | ||
def get_peft_model(self): | ||
from peft import get_peft_model | ||
|
||
return get_peft_model | ||
|
||
@pytest.fixture | ||
def lora_config(self): | ||
from peft import LoraConfig | ||
|
||
return LoraConfig(target_modules="0") | ||
|
||
@pytest.fixture | ||
def base_model(self): | ||
return torch.nn.Sequential(torch.nn.Linear(10, 2)) | ||
|
||
def test_get_peft_model_warns_when_reloading_model(self, get_peft_model, lora_config, base_model): | ||
get_peft_model(base_model, lora_config) | ||
|
||
with pytest.warns(UserWarning, match=self.RELOAD_WARNING_EXPECTED_MATCH): | ||
get_peft_model(base_model, lora_config) | ||
|
||
def test_get_peft_model_proposed_fix_in_warning_help(self, get_peft_model, lora_config, base_model, recwarn): | ||
peft_model = get_peft_model(base_model, lora_config) | ||
peft_model.unload() | ||
get_peft_model(base_model, lora_config) | ||
|
||
warning_checker = pytest.warns(UserWarning, match=self.RELOAD_WARNING_EXPECTED_MATCH) | ||
|
||
for warning in recwarn: | ||
if warning_checker.matches(warning): | ||
pytest.fail("Warning raised even though model was unloaded.") |