Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PEFT] Allow PEFT model dict to be loaded #25721

Merged
merged 13 commits into from
Sep 15, 2023
66 changes: 44 additions & 22 deletions src/transformers/integrations/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Optional
from typing import TYPE_CHECKING, Any, Dict, Optional

from ..utils import (
check_peft_version,
find_adapter_config_file,
is_accelerate_available,
is_peft_available,
is_torch_available,
logging,
)

Expand All @@ -30,6 +31,11 @@
# Minimum PEFT version supported for the integration
MIN_PEFT_VERSION = "0.5.0"

if TYPE_CHECKING:
if is_torch_available():
import torch


logger = logging.get_logger(__name__)


Expand Down Expand Up @@ -61,14 +67,16 @@ class PeftAdapterMixin:

def load_adapter(
self,
peft_model_id: str,
peft_model_id: Optional[str] = None,
adapter_name: Optional[str] = None,
revision: Optional[str] = None,
token: Optional[str] = None,
device_map: Optional[str] = "auto",
max_memory: Optional[str] = None,
offload_folder: Optional[str] = None,
offload_index: Optional[int] = None,
peft_config: Dict[str, Any] = None,
younesbelkada marked this conversation as resolved.
Show resolved Hide resolved
adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None,
) -> None:
"""
Load adapter weights from file or remote Hub folder. If you are not familiar with adapters and PEFT methods, we
Expand All @@ -77,7 +85,7 @@ def load_adapter(
Requires peft as a backend to load the adapter weights.

Args:
peft_model_id (`str`):
peft_model_id (`str`, *optional*):
The identifier of the model to look for on the Hub, or a local path to the saved adapter config file
and adapter weights.
adapter_name (`str`, *optional*):
Expand Down Expand Up @@ -114,6 +122,12 @@ def load_adapter(
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
offload_index (`int`, `optional`):
`offload_index` argument to be passed to `accelerate.dispatch_model` method.
peft_config (`Dict[str, Any]`, *optional*):
The configuration of the adapter to add, supported adapters are non-prefix tuning and adaption prompts
methods. This argument is used in case users directly pass PEFT state dicts
adapter_state_dict (`Dict[str, torch.Tensor]`, *optional*):
The state dict of the adapter to load. This argument is used in case users directly pass PEFT state
dicts
"""
check_peft_version(min_version=MIN_PEFT_VERSION)

Expand All @@ -122,33 +136,41 @@ def load_adapter(
from peft import PeftConfig, inject_adapter_in_model, load_peft_weights
from peft.utils import set_peft_model_state_dict

if not self._hf_peft_config_loaded:
self._hf_peft_config_loaded = True
elif adapter_name in self.peft_config:
if self._hf_peft_config_loaded and adapter_name in self.peft_config:
raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")

adapter_config_file = find_adapter_config_file(
peft_model_id,
revision=revision,
token=token,
)

if adapter_config_file is None:
if peft_model_id is None and (adapter_state_dict is None and peft_config is None):
raise ValueError(
f"adapter model file not found in {peft_model_id}. Make sure you are passing the correct path to the "
"adapter model."
"You should either pass a `peft_model_id` or a `peft_config` and `adapter_state_dict` to load an adapter."
)

loaded_peft_config = PeftConfig.from_pretrained(
peft_model_id,
revision=revision,
use_auth_token=token,
)
if peft_config is None:
adapter_config_file = find_adapter_config_file(
peft_model_id,
revision=revision,
token=token,
)

if adapter_config_file is None:
raise ValueError(
f"adapter model file not found in {peft_model_id}. Make sure you are passing the correct path to the "
"adapter model."
)

peft_config = PeftConfig.from_pretrained(
peft_model_id,
revision=revision,
use_auth_token=token,
)

# Create and add fresh new adapters into the model.
inject_adapter_in_model(loaded_peft_config, self, adapter_name)
inject_adapter_in_model(peft_config, self, adapter_name)

if not self._hf_peft_config_loaded:
self._hf_peft_config_loaded = True

adapter_state_dict = load_peft_weights(peft_model_id, revision=revision, use_auth_token=token)
if peft_model_id is not None:
adapter_state_dict = load_peft_weights(peft_model_id, revision=revision, use_auth_token=token)

# We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility
processed_adapter_state_dict = {}
Expand Down
32 changes: 32 additions & 0 deletions tests/peft_integration/test_peft_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import tempfile
import unittest

from huggingface_hub import hf_hub_download

from transformers import AutoModelForCausalLM, OPTForCausalLM
from transformers.testing_utils import require_peft, require_torch, require_torch_gpu, slow, torch_device
from transformers.utils import is_torch_available
Expand Down Expand Up @@ -300,3 +302,33 @@ def test_peft_pipeline(self):
for model_id in self.peft_test_model_ids:
pipe = pipeline("text-generation", model_id)
_ = pipe("Hello")

def test_peft_add_adapter_with_state_dict(self):
"""
Simple test that tests the basic usage of PEFT model through `from_pretrained`. This test tests if
add_adapter works as expected with a state_dict being passed.
"""
from peft import LoraConfig

dummy_input = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device)

for model_id, peft_model_id in zip(self.transformers_test_model_ids, self.peft_test_model_ids):
for transformers_class in self.transformers_test_model_classes:
model = transformers_class.from_pretrained(model_id).to(torch_device)

peft_config = LoraConfig(init_lora_weights=False)

with self.assertRaises(ValueError):
model.load_adapter(peft_model_id=None)

state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")

dummy_state_dict = torch.load(state_dict_path)

model.load_adapter(adapter_state_dict=dummy_state_dict, peft_config=peft_config)
with self.assertRaises(ValueError):
model.load_adapter(model.load_adapter(adapter_state_dict=dummy_state_dict, peft_config=None))
self.assertTrue(self._check_lora_correctly_converted(model))

# dummy generation
_ = model.generate(input_ids=dummy_input)