save and load base model with revision #1658

Merged
merged 14 commits on May 16, 2024
8 changes: 5 additions & 3 deletions src/peft/auto.py
@@ -62,15 +62,17 @@ def from_pretrained(
adapter_name: str = "default",
is_trainable: bool = False,
config: Optional[PeftConfig] = None,
revision: Optional[str] = None,
**kwargs,
):
r"""
A wrapper around all the preprocessing steps a user needs to perform in order to load a PEFT model. The kwargs
are passed along to `PeftConfig` that automatically takes care of filtering the kwargs of the Hub methods and
the config object init.
"""
peft_config = PeftConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
peft_config = PeftConfig.from_pretrained(pretrained_model_name_or_path, revision=revision, **kwargs)
base_model_path = peft_config.base_model_name_or_path
base_model_revision = peft_config.revision

task_type = getattr(peft_config, "task_type", None)

@@ -101,7 +103,7 @@ def from_pretrained(
"Cannot infer the auto class from the config, please make sure that you are loading the correct model for your task type."
)

base_model = target_class.from_pretrained(base_model_path, **kwargs)
base_model = target_class.from_pretrained(base_model_path, revision=base_model_revision, **kwargs)
Member

I have a small concern here: in PEFT configs, the revision parameter defaults to None, but in transformers it defaults to main:

https://github.com/huggingface/transformers/blob/9fe3f585bb4ea29f209dc705d269fbe292e1128f/src/transformers/models/auto/auto_factory.py#L135

Honestly, the from_pretrained method is a bit inscrutable to me, so I don't know if this can cause any issues (or might in the future). WDYT?

Contributor Author

That's a good point. It shouldn't affect anything, but I'll change all revision defaults from None to "main" for consistency.
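
For context, a minimal sketch of why the mismatch is usually harmless: huggingface_hub falls back to "main" when no revision is given, so a None default and an explicit "main" default resolve to the same ref. Illustrative only, not part of this PR's diff:

```python
# Illustrative only (assumes huggingface_hub's documented fallback to "main"):
# a revision of None and an explicit "main" resolve to the same branch,
# so the same file is downloaded either way.
from huggingface_hub import hf_hub_download

path_none = hf_hub_download("peft-internal-testing/tiny-random-BertModel", "config.json", revision=None)
path_main = hf_hub_download("peft-internal-testing/tiny-random-BertModel", "config.json", revision="main")
assert path_none == path_main
```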


tokenizer_exists = False
if os.path.exists(os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_NAME)):
@@ -114,7 +116,7 @@ def from_pretrained(
tokenizer_exists = check_file_exists_on_hf_hub(
repo_id=pretrained_model_name_or_path,
filename=TOKENIZER_CONFIG_NAME,
revision=kwargs.get("revision", None),
revision=revision,
repo_type=kwargs.get("repo_type", None),
token=token,
)
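
With this change, `AutoPeftModelForCausalLM.from_pretrained` tracks two separate revisions: the `revision` argument selects the version of the adapter repository, while the base model is loaded at the revision recorded in the adapter's `PeftConfig`. A minimal usage sketch; the repo id below is a placeholder, not taken from this PR:

```python
# Sketch: `revision` picks the adapter repo revision; the base model revision
# is read from the adapter's saved PeftConfig and forwarded to the base model.
from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained(
    "some-org/some-lora-adapter",  # placeholder adapter repo id
    revision="v1.2.3",             # revision of the adapter repository
)
# The base model was loaded at model.peft_config["default"].revision
```
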
4 changes: 2 additions & 2 deletions src/peft/config.py
@@ -97,7 +97,7 @@ def from_peft_type(cls, **kwargs):
# TODO: this hack is needed to fix the following issue (on commit 702f937):
# if someone saves a default config and loads it back with `PeftConfig` class it yields to
# not loading the correct config class.

#
# from peft import AdaLoraConfig, PeftConfig
# peft_config = AdaLoraConfig()
# print(peft_config)
@@ -232,7 +232,7 @@ class PeftConfig(PeftConfigMixin):
base_model_name_or_path: Optional[str] = field(
default=None, metadata={"help": "The name of the base model to use."}
)
revision: Optional[str] = field(default=None, metadata={"help": "The specific model version to use."})
revision: Optional[str] = field(default=None, metadata={"help": "The specific base model version to use."})
peft_type: Optional[Union[str, PeftType]] = field(default=None, metadata={"help": "Peft type"})
task_type: Optional[Union[str, TaskType]] = field(default=None, metadata={"help": "Task type"})
inference_mode: bool = field(default=False, metadata={"help": "Whether to use inference mode"})
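
Since `revision` is a regular `PeftConfig` field, it is serialized with the rest of the adapter configuration on save. A small sketch of that behavior, assuming an arbitrary output directory:

```python
# Sketch: the revision field is a dataclass field on PeftConfig, so it is
# included in the saved adapter_config.json like any other option.
from peft import LoraConfig

config = LoraConfig(revision="v2.0.0")
print(config.to_dict()["revision"])   # "v2.0.0"
config.save_pretrained("my-adapter")  # adapter_config.json now records the revision
```
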
19 changes: 17 additions & 2 deletions src/peft/mapping.py
@@ -14,7 +14,8 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any
import warnings
from typing import TYPE_CHECKING, Any, Optional

import torch

@@ -104,7 +105,11 @@ def get_peft_config(config_dict: dict[str, Any]) -> PeftConfig:


def get_peft_model(
model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default", mixed: bool = False
model: PreTrainedModel,
peft_config: PeftConfig,
adapter_name: str = "default",
mixed: bool = False,
revision: Optional[str] = None,
) -> PeftModel | PeftMixedModel:
"""
Returns a Peft model object from a model and a config.
@@ -118,13 +123,23 @@ def get_peft_model(
The name of the adapter to be injected, if not provided, the default adapter name is used ("default").
mixed (`bool`, `optional`, defaults to `False`):
Whether to allow mixing different (compatible) adapter types.
revision (`str`, `optional`, defaults to `main`):
The revision of the base model. If this isn't set, the saved peft model will load the `main` revision for
the base model
"""
model_config = getattr(model, "config", {"model_type": "custom"})
if hasattr(model_config, "to_dict"):
model_config = model_config.to_dict()

peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None)

if revision is not None:
if peft_config.revision is not None and peft_config.revision != revision:
warnings.warn(
f"peft config has already set base model revision to {peft_config.revision}, overwriting with revision {revision}"
)
peft_config.revision = revision

if mixed:
return PeftMixedModel(model, peft_config, adapter_name=adapter_name)

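
Taken together, the new `revision` argument records which base model revision the adapter was created against, and warns when it disagrees with a revision already set on the config. A minimal usage sketch, reusing the test model and revision from the tests below:

```python
# Sketch: the revision passed to get_peft_model is stored on the PeftConfig so
# that AutoPeftModel can later reload the base model at the same version.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained(
    "peft-internal-testing/tiny-random-BertModel", revision="v2.0.0"
)
peft_model = get_peft_model(base, LoraConfig(), revision="v2.0.0")
assert peft_model.peft_config["default"].revision == "v2.0.0"
peft_model.save_pretrained("adapter-with-revision")  # persists the revision in adapter_config.json
```
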
53 changes: 52 additions & 1 deletion tests/test_hub_features.py
@@ -11,11 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import unittest

import torch
from transformers import AutoModelForCausalLM

from peft import PeftConfig, PeftModel
from peft import AutoPeftModelForCausalLM, LoraConfig, PeftConfig, PeftModel, get_peft_model


PEFT_MODELS_TO_TEST = [("peft-internal-testing/test-lora-subfolder", "test")]
@@ -35,3 +37,52 @@ def test_subfolder(self):
model = PeftModel.from_pretrained(model, model_id, subfolder=subfolder)

assert isinstance(model, PeftModel)


class TestBaseModelRevision:
def test_save_and_load_base_model_revision(self, tmp_path):
r"""
Test saving a PeftModel with a base model revision and loading with AutoPeftModel to recover the same base
model
"""
lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.0)
test_inputs = torch.arange(10).reshape(-1, 1)

base_model_id = "peft-internal-testing/tiny-random-BertModel"
revision = "v2.0.0"

base_model_revision = AutoModelForCausalLM.from_pretrained(base_model_id, revision=revision).eval()
peft_model_revision = get_peft_model(base_model_revision, lora_config, revision=revision)
output_revision = peft_model_revision(test_inputs).logits

# sanity check: the model without revision should be different
base_model_no_revision = AutoModelForCausalLM.from_pretrained(base_model_id, revision="main").eval()
# we need a copy of the config because otherwise, we are changing in-place the `revision` of the previous config and model
lora_config_no_revision = copy.deepcopy(lora_config)
lora_config_no_revision.revision = "main"
peft_model_no_revision = get_peft_model(base_model_no_revision, lora_config_no_revision, revision="main")
output_no_revision = peft_model_no_revision(test_inputs).logits
assert not torch.allclose(output_no_revision, output_revision)

# check that if we save and load the model, the output corresponds to the one with revision
peft_model_revision.save_pretrained(tmp_path / "peft_model_revision")
peft_model_revision_loaded = AutoPeftModelForCausalLM.from_pretrained(tmp_path / "peft_model_revision").eval()

assert peft_model_revision_loaded.peft_config["default"].revision == revision

output_revision_loaded = peft_model_revision_loaded(test_inputs).logits
assert torch.allclose(output_revision, output_revision_loaded)

def test_load_different_peft_and_base_model_revision(self, tmp_path):
r"""
Test loading an AutoPeftModel from the hub where the base model revision and peft revision differ
"""
base_model_id = "hf-internal-testing/tiny-random-BertModel"
base_model_revision = None
peft_model_id = "peft-internal-testing/tiny-random-BertModel-lora"
peft_model_revision = "v1.2.3"

peft_model = AutoPeftModelForCausalLM.from_pretrained(peft_model_id, revision=peft_model_revision).eval()

assert peft_model.peft_config["default"].base_model_name_or_path == base_model_id
assert peft_model.peft_config["default"].revision == base_model_revision
13 changes: 13 additions & 0 deletions tests/test_other.py
@@ -15,6 +15,7 @@
import pytest
import torch
from torch import nn
from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model

@@ -73,3 +74,15 @@ def test_modules_to_save_targets_module_dict_raises(cls):
msg = "modules_to_save cannot be applied to modules of type"
with pytest.raises(TypeError, match=msg):
get_peft_model(model=model, peft_config=peft_config)


def test_get_peft_model_revision_warning(tmp_path):
base_model_id = "peft-internal-testing/tiny-random-BertModel"
base_revision = "v2.0.0"
base_model = AutoModelForCausalLM.from_pretrained(base_model_id, revision=base_revision).eval()
lora_config = LoraConfig(revision=base_revision)

overwrite_revision = "main"
overwrite_warning = f"peft config has already set base model revision to {base_revision}, overwriting with revision {overwrite_revision}"
with pytest.warns(UserWarning, match=overwrite_warning):
_ = get_peft_model(base_model, lora_config, revision=overwrite_revision)