ENH Make PEFT configs forward compatible #2038

Merged
45 changes: 44 additions & 1 deletion src/peft/config.py
@@ -24,6 +24,25 @@
from .utils import CONFIG_NAME, PeftType, TaskType


# we expect at least these keys to be present in a PEFT adapter_config.json
MIN_EXPECTED_CONFIG_KEYS = {"peft_type"}


def _check_and_remove_unused_kwargs(cls, kwargs):
"""Make PEFT configs forward-compatible by removing unused kwargs that were added in later PEFT versions.

This assumes that removing the unused kwargs will not affect the default behavior.

Returns the filtered kwargs and the set of removed keys.
"""
# Not pretty, but checking against the __init__ signature is the most reliable way to find unexpected kwargs
signature_parameters = inspect.signature(cls.__init__).parameters
unexpected_kwargs = set(kwargs.keys()) - set(signature_parameters.keys())
for key in unexpected_kwargs:
del kwargs[key]
return kwargs, unexpected_kwargs
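
# A minimal usage sketch (assumes LoraConfig; `hypothetical_new_kwarg` is a made-up
# key standing in for an argument introduced by a newer PEFT release):
#
#     kwargs, removed = _check_and_remove_unused_kwargs(
#         LoraConfig, {"peft_type": "LORA", "hypothetical_new_kwarg": 1}
#     )
#     assert kwargs == {"peft_type": "LORA"}
#     assert removed == {"hypothetical_new_kwarg"}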


@dataclass
class PeftConfigMixin(PushToHubMixin):
r"""
@@ -116,7 +135,31 @@ def from_peft_type(cls, **kwargs):
else:
config_cls = cls

return config_cls(**kwargs)
try:
config = config_cls(**kwargs)
except TypeError as exc:
# Here we potentially handle forward compatibility. New arguments are sometimes added to configs, which
# makes configs saved with newer PEFT versions incompatible with older PEFT versions. We catch the
# unexpected arguments, remove them so that the program can continue, and warn the user about it.

# First check if the error is due to unexpected keyword arguments; we don't want to accidentally catch
# other TypeErrors.
if "got an unexpected keyword argument" not in str(exc):
raise exc

filtered_kwargs, unexpected_kwargs = _check_and_remove_unused_kwargs(cls, kwargs)
if not MIN_EXPECTED_CONFIG_KEYS.issubset(set(filtered_kwargs.keys())):
raise TypeError(f"The config that is trying to be loaded is not a valid {cls.__name__} config.")

warnings.warn(
f"Unexpected keyword arguments {sorted(unexpected_kwargs)} for class {cls.__name__}, these are "
"ignored. This probably means that you're loading a configuration file that was saved using a "
"higher version of the library and additional parameters have been introduced since. It is "
"highly recommended to upgrade the PEFT version before continuing (e.g. by running `pip install "
"-U peft`)."
)
config = config_cls.from_peft_type(**filtered_kwargs)
return config

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, subfolder: Optional[str] = None, **kwargs):
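
To see the new behavior end to end, here is a minimal sketch (`future_option` is a hypothetical key standing in for an argument introduced by a newer PEFT release):

from peft import PeftConfig

# kwargs as they might appear in an adapter_config.json written by a newer release;
# "future_option" is a made-up key that this PEFT version does not accept.
kwargs = {"peft_type": "LORA", "r": 8, "future_option": True}

# Emits a UserWarning about the ignored key and still returns a LoraConfig.
config = PeftConfig.from_peft_type(**kwargs)
assert "future_option" not in config.to_dict()
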
73 changes: 57 additions & 16 deletions tests/test_config.py
@@ -16,11 +16,9 @@
import os
import pickle
import tempfile
import unittest
import warnings

import pytest
from parameterized import parameterized

from peft import (
AdaLoraConfig,
@@ -34,6 +32,7 @@
LoKrConfig,
LoraConfig,
MultitaskPromptTuningConfig,
OFTConfig,
PeftConfig,
PeftType,
PolyConfig,
@@ -69,8 +68,8 @@
)


class PeftConfigTester(unittest.TestCase):
@parameterized.expand(ALL_CONFIG_CLASSES)
class TestPeftConfig:
@pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES)
def test_methods(self, config_class):
r"""
Test if all configs have the expected methods. Here we test
@@ -86,7 +85,7 @@ def test_methods(self, config_class):
assert hasattr(config, "from_pretrained")
assert hasattr(config, "from_json_file")

@parameterized.expand(ALL_CONFIG_CLASSES)
@pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES)
def test_task_type(self, config_class):
config_class(task_type="test")

@@ -102,7 +101,7 @@ def test_from_peft_type(self):
config = PeftConfig.from_peft_type(peft_type=peft_type)
assert type(config) is expected_cls

@parameterized.expand(ALL_CONFIG_CLASSES)
@pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES)
def test_from_pretrained(self, config_class):
r"""
Test if the config is correctly loaded using:
@@ -112,7 +111,7 @@ def test_from_pretrained(self, config_class):
# Test we can load config from delta
config_class.from_pretrained(model_name, revision=revision)

@parameterized.expand(ALL_CONFIG_CLASSES)
@pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES)
def test_save_pretrained(self, config_class):
r"""
Test if the config is correctly saved and loaded using
@@ -125,7 +124,7 @@ def test_save_pretrained(self, config_class):
config_from_pretrained = config_class.from_pretrained(tmp_dirname)
assert config.to_dict() == config_from_pretrained.to_dict()

@parameterized.expand(ALL_CONFIG_CLASSES)
@pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES)
def test_from_json_file(self, config_class):
config = config_class()
with tempfile.TemporaryDirectory() as tmp_dirname:
@@ -143,7 +142,7 @@ def test_from_json_file(self, config_class):
config_from_json = config_class.from_json_file(config_path)
assert config.to_dict() == config_from_json

@parameterized.expand(ALL_CONFIG_CLASSES)
@pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES)
def test_to_dict(self, config_class):
r"""
Test if the config can be correctly converted to a dict using:
@@ -152,7 +151,7 @@ def test_to_dict(self, config_class):
config = config_class()
assert isinstance(config.to_dict(), dict)

@parameterized.expand(ALL_CONFIG_CLASSES)
@pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES)
def test_from_pretrained_cache_dir(self, config_class):
r"""
Test if the config is correctly loaded with extra kwargs
@@ -170,7 +169,7 @@ def test_from_pretrained_cache_dir_remote(self):
PeftConfig.from_pretrained("ybelkada/test-st-lora", cache_dir=tmp_dirname)
assert "models--ybelkada--test-st-lora" in os.listdir(tmp_dirname)

@parameterized.expand(ALL_CONFIG_CLASSES)
@pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES)
def test_save_pretrained_with_runtime_config(self, config_class):
r"""
Test if the config correctly removes runtime config when saving
@@ -185,7 +184,7 @@ def test_save_pretrained_with_runtime_config(self, config_class):
cfg = config_class.from_pretrained(tmp_dirname)
assert not cfg.runtime_config.ephemeral_gpu_offload

@parameterized.expand(ALL_CONFIG_CLASSES)
@pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES)
def test_set_attributes(self, config_class):
# manually set attributes and check if they are correctly written
config = config_class(peft_type="test")
@@ -197,21 +196,21 @@ def test_set_attributes(self, config_class):
config_from_pretrained = config_class.from_pretrained(tmp_dirname)
assert config.to_dict() == config_from_pretrained.to_dict()

@parameterized.expand(ALL_CONFIG_CLASSES)
@pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES)
def test_config_copy(self, config_class):
# see https://github.com/huggingface/peft/issues/424
config = config_class()
copied = copy.copy(config)
assert config.to_dict() == copied.to_dict()

@parameterized.expand(ALL_CONFIG_CLASSES)
@pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES)
def test_config_deepcopy(self, config_class):
# see https://github.com/huggingface/peft/issues/424
config = config_class()
copied = copy.deepcopy(config)
assert config.to_dict() == copied.to_dict()

@parameterized.expand(ALL_CONFIG_CLASSES)
@pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES)
def test_config_pickle_roundtrip(self, config_class):
# see https://github.com/huggingface/peft/issues/424
config = config_class()
@@ -240,7 +239,9 @@ def test_prompt_encoder_warning_num_layers(self):
expected_msg = "for MLP, the argument `encoder_num_layers` is ignored. Exactly 2 MLP layers are used."
assert str(record.list[0].message) == expected_msg

@parameterized.expand([LoHaConfig, LoraConfig, IA3Config, BOFTConfig, HRAConfig, VBLoRAConfig])
@pytest.mark.parametrize(
"config_class", [LoHaConfig, LoraConfig, IA3Config, OFTConfig, BOFTConfig, HRAConfig, VBLoRAConfig]
)
def test_save_pretrained_with_target_modules(self, config_class):
# See #1041, #1045
config = config_class(target_modules=["a", "list"])
@@ -310,3 +311,43 @@ def test_adalora_config_r_warning(self):
# Test that a warning is raised when r != 8 in AdaLoraConfig
with pytest.warns(UserWarning, match="Note that `r` is not used in AdaLora and will be ignored."):
AdaLoraConfig(r=10)

@pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES)
def test_from_pretrained_forward_compatible(self, config_class, tmp_path, recwarn):
"""
Make it possible to load configs that contain unknown keys by ignoring them.

The idea is to make PEFT configs forward-compatible with future versions of the library.
"""
config = config_class()
config.save_pretrained(tmp_path)
# add spurious keys to the config
with open(tmp_path / "adapter_config.json") as f:
config_dict = json.load(f)
config_dict["foobar"] = "baz"
config_dict["spam"] = 123
with open(tmp_path / "adapter_config.json", "w") as f:
json.dump(config_dict, f)

msg = f"Unexpected keyword arguments ['foobar', 'spam'] for class {config_class.__name__}, these are ignored."
config_from_pretrained = config_class.from_pretrained(tmp_path)

assert len(recwarn) == 1
assert recwarn.list[0].message.args[0].startswith(msg)
assert "foo" not in config_from_pretrained.to_dict()
assert "spam" not in config_from_pretrained.to_dict()
assert config.to_dict() == config_from_pretrained.to_dict()
assert isinstance(config_from_pretrained, config_class)

@pytest.mark.parametrize("config_class", ALL_CONFIG_CLASSES)
def test_from_pretrained_sanity_check(self, config_class, tmp_path):
"""Following up on the previous test about forward compatibility, we *don't* want any random json to be accepted as
a PEFT config. There should be a minimum set of required keys.
"""
non_peft_json = {"foo": "bar", "baz": 123}
with open(tmp_path / "adapter_config.json", "w") as f:
json.dump(non_peft_json, f)

msg = f"The config that is trying to be loaded is not a valid {config_class.__name__} config"
with pytest.raises(TypeError, match=msg):
config_class.from_pretrained(tmp_path)