Skip to content

Commit

Permalink
Introducing AutoPeftModelForxxx (#694)
Browse files Browse the repository at this point in the history
* working v1 for LMs

* added tests.

* added documentation.

* fixed ruff issues.

* added `AutoPeftModelForFeatureExtraction` .

* replace with `TypeError`

* address last comments

* added comment.
  • Loading branch information
younesbelkada authored Jul 14, 2023
1 parent fa5957f commit 0675541
Show file tree
Hide file tree
Showing 5 changed files with 314 additions and 1 deletion.
18 changes: 18 additions & 0 deletions docs/source/quicktour.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,24 @@ Easily load your model for inference using the [`~transformers.PreTrainedModel.f
'complaint'
```

## Easy loading with Auto classes

If you have saved your adapter locally or on the Hub, you can leverage the `AutoPeftModelForxxx` classes and load any PEFT model with a single line of code:

```diff
- from peft import PeftConfig, PeftModel
- from transformers import AutoModelForCausalLM
+ from peft import AutoPeftModelForCausalLM

- peft_config = PeftConfig.from_pretrained("ybelkada/opt-350m-lora")
- base_model_path = peft_config.base_model_name_or_path
- transformers_model = AutoModelForCausalLM.from_pretrained(base_model_path)
- peft_model = PeftModel.from_pretrained(transformers_model, "ybelkada/opt-350m-lora")
+ peft_model = AutoPeftModelForCausalLM.from_pretrained("ybelkada/opt-350m-lora")
```

The currently supported auto classes are: `AutoPeftModelForCausalLM`, `AutoPeftModelForSequenceClassification`, `AutoPeftModelForSeq2SeqLM`, `AutoPeftModelForTokenClassification`, `AutoPeftModelForQuestionAnswering` and `AutoPeftModelForFeatureExtraction`.

## Next steps

Now that you've seen how to train a model with one of the 🤗 PEFT methods, we encourage you to try out some of the other methods like prompt tuning. The steps are very similar to the ones shown in this quickstart: prepare a [`PeftConfig`] for a 🤗 PEFT method, and use `get_peft_model` to create a [`PeftModel`] from the configuration and base model. Then you can train it however you like!
Expand Down
8 changes: 8 additions & 0 deletions src/peft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@

__version__ = "0.4.0.dev0"

from .auto import (
AutoPeftModelForCausalLM,
AutoPeftModelForSequenceClassification,
AutoPeftModelForSeq2SeqLM,
AutoPeftModelForTokenClassification,
AutoPeftModelForQuestionAnswering,
AutoPeftModelForFeatureExtraction,
)
from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING, get_peft_config, get_peft_model
from .peft_model import (
PeftModel,
Expand Down
118 changes: 118 additions & 0 deletions src/peft/auto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Optional

from transformers import (
AutoModel,
AutoModelForCausalLM,
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
)

from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING
from .peft_model import (
PeftModelForCausalLM,
PeftModelForFeatureExtraction,
PeftModelForQuestionAnswering,
PeftModelForSeq2SeqLM,
PeftModelForSequenceClassification,
PeftModelForTokenClassification,
)
from .utils import PeftConfig


class _BaseAutoPeftModel:
_target_class = None
_target_peft_class = None

def __init__(self, *args, **kwargs):
# For consistency with transformers: https://github.com/huggingface/transformers/blob/91d7df58b6537d385e90578dac40204cb550f706/src/transformers/models/auto/auto_factory.py#L400
raise EnvironmentError(
f"{self.__class__.__name__} is designed to be instantiated "
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
f"`{self.__class__.__name__}.from_config(config)` methods."
)

@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path,
adapter_name: str = "default",
is_trainable: bool = False,
config: Optional[PeftConfig] = None,
**kwargs,
):
r"""
A wrapper around all the preprocessing steps a user needs to perform in order to load a PEFT model. The kwargs
are passed along to `PeftConfig` that automatically takes care of filtering the kwargs of the Hub methods and
the config object init.
"""
peft_config = PeftConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
base_model_path = peft_config.base_model_name_or_path

transformers_model = cls._target_class.from_pretrained(base_model_path, **kwargs)

task_type = getattr(peft_config, "task_type", None)
if task_type is not None:
expected_target_class = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[task_type]
if cls._target_peft_class.__name__ != expected_target_class.__name__:
raise ValueError(
f"Expected target PEFT class: {expected_target_class.__name__}, but you have asked for: {cls._target_peft_class.__name__ }"
" make sure that you are loading the correct model for your task type."
)

return cls._target_peft_class.from_pretrained(
transformers_model,
pretrained_model_name_or_path,
adapter_name=adapter_name,
is_trainable=is_trainable,
config=config,
**kwargs,
)


class AutoPeftModelForCausalLM(_BaseAutoPeftModel):
    """Auto loader that pairs `AutoModelForCausalLM` (base model) with `PeftModelForCausalLM` (adapter wrapper)."""

    _target_peft_class = PeftModelForCausalLM
    _target_class = AutoModelForCausalLM


class AutoPeftModelForSeq2SeqLM(_BaseAutoPeftModel):
    """Auto loader that pairs `AutoModelForSeq2SeqLM` (base model) with `PeftModelForSeq2SeqLM` (adapter wrapper)."""

    _target_peft_class = PeftModelForSeq2SeqLM
    _target_class = AutoModelForSeq2SeqLM


class AutoPeftModelForSequenceClassification(_BaseAutoPeftModel):
    """Auto loader that pairs `AutoModelForSequenceClassification` with `PeftModelForSequenceClassification`."""

    _target_peft_class = PeftModelForSequenceClassification
    _target_class = AutoModelForSequenceClassification


class AutoPeftModelForTokenClassification(_BaseAutoPeftModel):
    """Auto loader that pairs `AutoModelForTokenClassification` with `PeftModelForTokenClassification`."""

    _target_peft_class = PeftModelForTokenClassification
    _target_class = AutoModelForTokenClassification


class AutoPeftModelForQuestionAnswering(_BaseAutoPeftModel):
    """Auto loader that pairs `AutoModelForQuestionAnswering` with `PeftModelForQuestionAnswering`."""

    _target_peft_class = PeftModelForQuestionAnswering
    _target_class = AutoModelForQuestionAnswering


class AutoPeftModelForFeatureExtraction(_BaseAutoPeftModel):
    """Auto loader that pairs the generic `AutoModel` (base model) with `PeftModelForFeatureExtraction`."""

    _target_peft_class = PeftModelForFeatureExtraction
    _target_class = AutoModel
2 changes: 1 addition & 1 deletion src/peft/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None, **kwargs
else pretrained_model_name_or_path
)

hf_hub_download_kwargs, class_kwargs, other_kwargs = cls._split_kwargs(kwargs)
hf_hub_download_kwargs, class_kwargs, _ = cls._split_kwargs(kwargs)

if os.path.isfile(os.path.join(path, CONFIG_NAME)):
config_file = os.path.join(path, CONFIG_NAME)
Expand Down
169 changes: 169 additions & 0 deletions tests/test_auto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest

import torch

from peft import (
AutoPeftModelForCausalLM,
AutoPeftModelForFeatureExtraction,
AutoPeftModelForQuestionAnswering,
AutoPeftModelForSeq2SeqLM,
AutoPeftModelForSequenceClassification,
AutoPeftModelForTokenClassification,
PeftModelForCausalLM,
PeftModelForFeatureExtraction,
PeftModelForQuestionAnswering,
PeftModelForSeq2SeqLM,
PeftModelForSequenceClassification,
PeftModelForTokenClassification,
)


class PeftAutoModelTester(unittest.TestCase):
    """Checks that every ``AutoPeftModelForXxx`` class loads an adapter into the matching PEFT model class."""

    def _check_auto_class(self, auto_cls, peft_cls, model_id, get_weight):
        """Run the checks shared by all auto classes.

        Args:
            auto_cls: The ``AutoPeftModelForXxx`` class under test.
            peft_cls: The PEFT model class that `auto_cls.from_pretrained` is expected to return.
            model_id: Identifier of the adapter to load.
            get_weight: Callable taking the loaded model and returning a weight tensor whose dtype is compared
                against the requested `torch_dtype`.
        """
        model = auto_cls.from_pretrained(model_id)
        self.assertIsInstance(model, peft_cls)

        with tempfile.TemporaryDirectory() as tmp_dirname:
            model.save_pretrained(tmp_dirname)

            # Reload from the directory that was just written (not from `model_id`) so that the
            # save -> load round trip is actually exercised. This has to happen inside the `with`
            # block, while the temporary directory still exists.
            model = auto_cls.from_pretrained(tmp_dirname)
            self.assertIsInstance(model, peft_cls)

        # check if kwargs are passed correctly
        model = auto_cls.from_pretrained(model_id, torch_dtype=torch.bfloat16)
        self.assertIsInstance(model, peft_cls)
        self.assertEqual(get_weight(model).dtype, torch.bfloat16)

        adapter_name = "default"
        is_trainable = False
        # This should work: `adapter_name` and `is_trainable` are passed positionally on purpose.
        _ = auto_cls.from_pretrained(model_id, adapter_name, is_trainable, torch_dtype=torch.bfloat16)

    def test_peft_causal_lm(self):
        self._check_auto_class(
            AutoPeftModelForCausalLM,
            PeftModelForCausalLM,
            "peft-internal-testing/tiny-OPTForCausalLM-lora",
            lambda model: model.base_model.lm_head.weight,
        )

    def test_peft_seq2seq_lm(self):
        self._check_auto_class(
            AutoPeftModelForSeq2SeqLM,
            PeftModelForSeq2SeqLM,
            "peft-internal-testing/tiny_T5ForSeq2SeqLM-lora",
            lambda model: model.base_model.lm_head.weight,
        )

    def test_peft_sequence_cls(self):
        self._check_auto_class(
            AutoPeftModelForSequenceClassification,
            PeftModelForSequenceClassification,
            "peft-internal-testing/tiny_OPTForSequenceClassification-lora",
            lambda model: model.score.original_module.weight,
        )

    def test_peft_token_classification(self):
        self._check_auto_class(
            AutoPeftModelForTokenClassification,
            PeftModelForTokenClassification,
            "peft-internal-testing/tiny_GPT2ForTokenClassification-lora",
            lambda model: model.base_model.classifier.original_module.weight,
        )

    def test_peft_question_answering(self):
        self._check_auto_class(
            AutoPeftModelForQuestionAnswering,
            PeftModelForQuestionAnswering,
            "peft-internal-testing/tiny_OPTForQuestionAnswering-lora",
            lambda model: model.base_model.qa_outputs.original_module.weight,
        )

    def test_peft_feature_extraction(self):
        self._check_auto_class(
            AutoPeftModelForFeatureExtraction,
            PeftModelForFeatureExtraction,
            "peft-internal-testing/tiny_OPTForFeatureExtraction-lora",
            lambda model: model.base_model.model.decoder.embed_tokens.weight,
        )

0 comments on commit 0675541

Please sign in to comment.