save and load base model with revision #1658

Merged
14 commits merged on May 16, 2024
src/peft/auto.py (2 additions, 1 deletion)

@@ -71,6 +71,7 @@ def from_pretrained(
"""
peft_config = PeftConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
base_model_path = peft_config.base_model_name_or_path
base_model_revision = peft_config.revision

task_type = getattr(peft_config, "task_type", None)

@@ -101,7 +102,7 @@ def from_pretrained(
"Cannot infer the auto class from the config, please make sure that you are loading the correct model for your task type."
)

base_model = target_class.from_pretrained(base_model_path, **kwargs)
base_model = target_class.from_pretrained(base_model_path, revision=base_model_revision, **kwargs)
Member:
I have a small concern here: in PEFT configs, the revision parameter defaults to None, but in transformers it defaults to "main":

https://github.com/huggingface/transformers/blob/9fe3f585bb4ea29f209dc705d269fbe292e1128f/src/transformers/models/auto/auto_factory.py#L135

Honestly, the from_pretrained method is a bit inscrutable to me, so I don't know whether this could cause any issues now or in the future. WDYT?

Contributor (author):

That's a good point. It shouldn't affect anything, but I'll change all revision defaults from None to "main" for consistency.
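
For reference, a minimal, hypothetical sketch of what that proposed change would look like as a dataclass field; this is not the code merged in this PR (the real field lives on `PeftConfig`, see the `src/peft/config.py` diff below):

```python
from dataclasses import dataclass, field
from typing import Optional


# Hypothetical sketch: align the PEFT default with transformers' "main".
# The actual definition is the `revision` field on PeftConfig in src/peft/config.py.
@dataclass
class ConfigSketch:
    revision: Optional[str] = field(
        default="main", metadata={"help": "The specific base model version to use."}
    )
```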


tokenizer_exists = False
if os.path.exists(os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_NAME)):
src/peft/config.py (1 addition, 2 deletions)

@@ -97,7 +97,6 @@ def from_peft_type(cls, **kwargs):
# TODO: this hack is needed to fix the following issue (on commit 702f937):
# if someone saves a default config and loads it back with `PeftConfig` class it yields to
# not loading the correct config class.

# from peft import AdaLoraConfig, PeftConfig
# peft_config = AdaLoraConfig()
# print(peft_config)
@@ -232,7 +231,7 @@ class PeftConfig(PeftConfigMixin):
base_model_name_or_path: Optional[str] = field(
default=None, metadata={"help": "The name of the base model to use."}
)
revision: Optional[str] = field(default=None, metadata={"help": "The specific model version to use."})
revision: Optional[str] = field(default=None, metadata={"help": "The specific base model version to use."})
peft_type: Optional[Union[str, PeftType]] = field(default=None, metadata={"help": "Peft type"})
task_type: Optional[Union[str, TaskType]] = field(default=None, metadata={"help": "Task type"})
inference_mode: bool = field(default=False, metadata={"help": "Whether to use inference mode"})
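Because the base-model revision is an ordinary config field, it can also be pinned by hand when needed. A small usage sketch; the model ID and revision tag are placeholders:

```python
# Hypothetical example: pin the base model revision explicitly on the adapter config.
from peft import LoraConfig

config = LoraConfig(
    base_model_name_or_path="some-org/some-model",  # placeholder model ID
    revision="v2.0.0",  # placeholder revision tag
    r=8,
    lora_alpha=16,
)
print(config.revision)  # -> "v2.0.0"
```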
src/peft/mapping.py (1 addition)

@@ -124,6 +124,7 @@ def get_peft_model(
model_config = model_config.to_dict()

peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None)
peft_config.revision = model.__dict__.get("revision", None)

if mixed:
return PeftMixedModel(model, peft_config, adapter_name=adapter_name)
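A hedged usage sketch of what the new line does: if the base model object exposes a `revision` attribute (transformers models may not always set one), `get_peft_model` now copies it into the adapter config; otherwise it stays None and can be set manually as shown above. The model ID and revision below are placeholders:

```python
# Hypothetical sketch; "some-org/some-model" and "v2.0.0" are placeholders.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("some-org/some-model", revision="v2.0.0")
peft_model = get_peft_model(base, LoraConfig(r=8, lora_alpha=16))

# name_or_path and (when present) revision are copied from the base model
# into the adapter config under the default adapter name.
print(peft_model.peft_config["default"].base_model_name_or_path)
print(peft_model.peft_config["default"].revision)
```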
src/peft/peft_model.py (6 additions)

@@ -264,6 +264,12 @@ def save_pretrained(
if peft_config.is_prompt_learning
else self.base_model.model.__dict__.get("name_or_path", None)
)
if peft_config.revision is None:
peft_config.revision = (
self.base_model.__dict__.get("revision", None)
if peft_config.is_prompt_learning
else self.base_model.model.__dict__.get("revision", None)
)
inference_mode = peft_config.inference_mode
peft_config.inference_mode = True

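With `save_pretrained` falling back to the base model's `revision` when the config has none, a saved adapter can be reloaded against the same base-model revision. A hedged round-trip sketch (it repeats the setup from the previous sketch so it is self-contained; model ID, revision, and output path are placeholders):

```python
# Hypothetical round-trip sketch; model ID, revision, and path are placeholders.
from transformers import AutoModelForCausalLM
from peft import AutoPeftModelForCausalLM, LoraConfig, PeftConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("some-org/some-model", revision="v2.0.0")
peft_model = get_peft_model(base, LoraConfig(r=8, lora_alpha=16))
peft_model.save_pretrained("my-adapter")

# The saved adapter_config.json now records base_model_name_or_path and,
# when available, the base model revision ...
cfg = PeftConfig.from_pretrained("my-adapter")
print(cfg.base_model_name_or_path, cfg.revision)

# ... and AutoPeftModelForCausalLM passes that revision along when it
# re-creates the base model before attaching the adapter.
reloaded = AutoPeftModelForCausalLM.from_pretrained("my-adapter")
```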
tests/test_hub_features.py (35 additions, 1 deletion)

@@ -13,13 +13,16 @@
# limitations under the License.
import unittest

import torch
from transformers import AutoModelForCausalLM

from peft import PeftConfig, PeftModel
from peft import AutoPeftModelForCausalLM, LoraConfig, PeftConfig, PeftModel, get_peft_model


PEFT_MODELS_TO_TEST = [("peft-internal-testing/test-lora-subfolder", "test")]

BASE_REVISION_MODELS_TO_TEST = [("peft-internal-testing/tiny-random-BertModel", "v2.0.0")]


class PeftHubFeaturesTester(unittest.TestCase):
def test_subfolder(self):
@@ -35,3 +38,34 @@ def test_subfolder(self):
model = PeftModel.from_pretrained(model, model_id, subfolder=subfolder)

assert isinstance(model, PeftModel)


class TestBaseModelRevision:
def test_save_and_load_base_model_revision(self, tmp_path):
r"""
Test that the base model revision is saved with the adapter config and restored on load.
"""
lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.0, init_lora_weights=False)
test_inputs = torch.arange(10).reshape(-1, 1)

for model_id, revision in BASE_REVISION_MODELS_TO_TEST:
original_base_model = AutoModelForCausalLM.from_pretrained(model_id, revision="main").eval()
original_peft_model = get_peft_model(original_base_model, lora_config)
original_peft_sum = original_peft_model(test_inputs).logits.sum()

revised_base_model = AutoModelForCausalLM.from_pretrained(model_id, revision=revision).eval()
revised_peft_model = get_peft_model(revised_base_model, lora_config)
revised_peft_sum = revised_peft_model(test_inputs).logits.sum()

assert not torch.eq(
original_peft_sum, revised_peft_sum
), f"outputs of revision 'main' and {revision} of base model {model_id} must differ"

revised_peft_model.save_pretrained(tmp_path / f"base_{revision}_model")

reload_revised_peft_model = AutoPeftModelForCausalLM.from_pretrained(
tmp_path / f"base_{revision}_model"
).eval()
reload_revised_sum = reload_revised_peft_model(test_inputs).logits.sum()

assert torch.eq(revised_peft_sum, reload_revised_sum), "output of the reloaded PEFT model must match the output before saving"