Skip to content

Commit

Permalink
[PEFT] Allow PEFT model dict to be loaded (#25721)
Browse files Browse the repository at this point in the history
* Allow PEFT model dict to be loaded

* make style

* make style

* Apply suggestions from code review

* address comments

* fixup

* final change

* added tests

* fix test

* better logic for handling if adapter has been loaded

* Update tests/peft_integration/test_peft_integration.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
  • Loading branch information
4 people authored Sep 15, 2023
1 parent 8b13471 commit 0a55d9f
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 22 deletions.
66 changes: 44 additions & 22 deletions src/transformers/integrations/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Optional
from typing import TYPE_CHECKING, Any, Dict, Optional

from ..utils import (
check_peft_version,
find_adapter_config_file,
is_accelerate_available,
is_peft_available,
is_torch_available,
logging,
)

Expand All @@ -30,6 +31,11 @@
# Minimum PEFT version supported for the integration
MIN_PEFT_VERSION = "0.5.0"

if TYPE_CHECKING:
if is_torch_available():
import torch


logger = logging.get_logger(__name__)


Expand Down Expand Up @@ -61,14 +67,16 @@ class PeftAdapterMixin:

def load_adapter(
self,
peft_model_id: str,
peft_model_id: Optional[str] = None,
adapter_name: Optional[str] = None,
revision: Optional[str] = None,
token: Optional[str] = None,
device_map: Optional[str] = "auto",
max_memory: Optional[str] = None,
offload_folder: Optional[str] = None,
offload_index: Optional[int] = None,
peft_config: Dict[str, Any] = None,
adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None,
) -> None:
"""
Load adapter weights from file or remote Hub folder. If you are not familiar with adapters and PEFT methods, we
Expand All @@ -77,7 +85,7 @@ def load_adapter(
Requires peft as a backend to load the adapter weights.
Args:
peft_model_id (`str`):
peft_model_id (`str`, *optional*):
The identifier of the model to look for on the Hub, or a local path to the saved adapter config file
and adapter weights.
adapter_name (`str`, *optional*):
Expand Down Expand Up @@ -114,6 +122,12 @@ def load_adapter(
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
offload_index (`int`, `optional`):
`offload_index` argument to be passed to `accelerate.dispatch_model` method.
peft_config (`Dict[str, Any]`, *optional*):
The configuration of the adapter to add. Supported adapters are the PEFT methods other than prefix tuning
and adaption prompts. This argument is used when users pass a PEFT state dict directly.
adapter_state_dict (`Dict[str, torch.Tensor]`, *optional*):
The state dict of the adapter to load. This argument is used when users pass a PEFT state dict
directly.
"""
check_peft_version(min_version=MIN_PEFT_VERSION)

Expand All @@ -122,33 +136,41 @@ def load_adapter(
from peft import PeftConfig, inject_adapter_in_model, load_peft_weights
from peft.utils import set_peft_model_state_dict

if not self._hf_peft_config_loaded:
self._hf_peft_config_loaded = True
elif adapter_name in self.peft_config:
if self._hf_peft_config_loaded and adapter_name in self.peft_config:
raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")

adapter_config_file = find_adapter_config_file(
peft_model_id,
revision=revision,
token=token,
)

if adapter_config_file is None:
if peft_model_id is None and (adapter_state_dict is None and peft_config is None):
raise ValueError(
f"adapter model file not found in {peft_model_id}. Make sure you are passing the correct path to the "
"adapter model."
"You should either pass a `peft_model_id` or a `peft_config` and `adapter_state_dict` to load an adapter."
)

loaded_peft_config = PeftConfig.from_pretrained(
peft_model_id,
revision=revision,
use_auth_token=token,
)
if peft_config is None:
adapter_config_file = find_adapter_config_file(
peft_model_id,
revision=revision,
token=token,
)

if adapter_config_file is None:
raise ValueError(
f"adapter model file not found in {peft_model_id}. Make sure you are passing the correct path to the "
"adapter model."
)

peft_config = PeftConfig.from_pretrained(
peft_model_id,
revision=revision,
use_auth_token=token,
)

# Create and add fresh new adapters into the model.
inject_adapter_in_model(loaded_peft_config, self, adapter_name)
inject_adapter_in_model(peft_config, self, adapter_name)

if not self._hf_peft_config_loaded:
self._hf_peft_config_loaded = True

adapter_state_dict = load_peft_weights(peft_model_id, revision=revision, use_auth_token=token)
if peft_model_id is not None:
adapter_state_dict = load_peft_weights(peft_model_id, revision=revision, use_auth_token=token)

# We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility
processed_adapter_state_dict = {}
Expand Down
32 changes: 32 additions & 0 deletions tests/peft_integration/test_peft_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import tempfile
import unittest

from huggingface_hub import hf_hub_download

from transformers import AutoModelForCausalLM, OPTForCausalLM
from transformers.testing_utils import require_peft, require_torch, require_torch_gpu, slow, torch_device
from transformers.utils import is_torch_available
Expand Down Expand Up @@ -300,3 +302,33 @@ def test_peft_pipeline(self):
for model_id in self.peft_test_model_ids:
pipe = pipeline("text-generation", model_id)
_ = pipe("Hello")

def test_peft_add_adapter_with_state_dict(self):
    """
    Check that `load_adapter` accepts an in-memory adapter state dict together with a `peft_config`,
    instead of a `peft_model_id`, and that invalid argument combinations raise a `ValueError`.
    """
    from peft import LoraConfig

    dummy_input = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device)

    for model_id, peft_model_id in zip(self.transformers_test_model_ids, self.peft_test_model_ids):
        for transformers_class in self.transformers_test_model_classes:
            model = transformers_class.from_pretrained(model_id).to(torch_device)

            peft_config = LoraConfig(init_lora_weights=False)

            # Neither a model id nor a (config, state dict) pair: nothing to load from.
            with self.assertRaises(ValueError):
                model.load_adapter(peft_model_id=None)

            # Fetch real adapter weights so the state dict has the keys the LoRA layers expect.
            state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")

            dummy_state_dict = torch.load(state_dict_path)

            model.load_adapter(adapter_state_dict=dummy_state_dict, peft_config=peft_config)

            # A second load under the same (default) adapter name is expected to be rejected.
            # NOTE(review): presumably this raises because the adapter name is already registered - confirm.
            # The call is intentionally NOT nested inside another `load_adapter(...)`: with nesting, the
            # outer call would raise `ValueError` on a `None` argument even if the inner call wrongly
            # succeeded, so the assertion could never fail.
            with self.assertRaises(ValueError):
                model.load_adapter(adapter_state_dict=dummy_state_dict, peft_config=None)

            self.assertTrue(self._check_lora_correctly_converted(model))

            # dummy generation
            _ = model.generate(input_ids=dummy_input)

0 comments on commit 0a55d9f

Please sign in to comment.