Mixed adapter models #1163

Merged · 15 commits · Nov 30, 2023
Changes from 7 commits
31 changes: 31 additions & 0 deletions README.md
@@ -362,6 +362,8 @@ any GPU memory savings. Please refer issue [[FSDP] FSDP with CPU offload consume

## 🤗 PEFT as a utility library

### Injecting adapters directly into the model

Inject trainable adapters into any `torch` model using the `inject_adapter_in_model` method. Note that this method makes no further changes to the model.

@@ -396,6 +398,35 @@

```python
dummy_inputs = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]])
dummy_outputs = model(dummy_inputs)
```
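
The example above is cut off by the collapsed hunk. Below is a minimal, self-contained sketch of the same pattern; the `DummyModel` and the LoRA config values are illustrative:

```python
import torch
from peft import LoraConfig, inject_adapter_in_model

class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Embedding(10, 10)
        self.linear = torch.nn.Linear(10, 10)
        self.lm_head = torch.nn.Linear(10, 10)

    def forward(self, input_ids):
        x = self.embedding(input_ids)
        x = self.linear(x)
        return self.lm_head(x)

# inject a LoRA adapter into the `linear` module of the model
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["linear"])
model = inject_adapter_in_model(lora_config, DummyModel())

# the model can be called as usual; only the adapter layers were changed
dummy_inputs = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]])
dummy_outputs = model(dummy_inputs)
```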

### Mixing different adapter types

Usually, it is not possible to combine different adapter types in the same model, e.g. to combine LoRA with AdaLoRA, LoHa, or LoKr. Using a mixed model, however, this can be achieved:

```python
from transformers import AutoModelForCausalLM

from peft import PeftMixedModel

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM").eval()
peft_model = PeftMixedModel.from_pretrained(model, <path-to-adapter-0>, "adapter0")
peft_model.load_adapter(<path-to-adapter-1>, "adapter1")
peft_model.set_adapter(["adapter0", "adapter1"])
result = peft_model(**inputs)
```
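
`set_adapter` also accepts a single adapter name, so the same model can switch back to running with just one of the loaded adapters:

```python
peft_model.set_adapter("adapter0")  # only "adapter0" is active; "adapter1" stays loaded
result_adapter0 = peft_model(**inputs)
```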

The main intent is to load already trained adapters and use this model only for inference. However, it is also possible to create a PEFT model for training by passing `mixed=True` to `get_peft_model`:

```python
from peft import get_peft_model, LoraConfig, LoKrConfig

base_model = ...
config0 = LoraConfig(...)
config1 = LoKrConfig(...)
peft_model = get_peft_model(base_model, config0, "adapter0", mixed=True)
peft_model.add_adapter(config1, "adapter1")
peft_model.set_adapter(["adapter0", "adapter1"])
for batch in dataloader:
    ...
```
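
The loop body is elided above. As a minimal sketch, assuming the batches contain `labels` so that the forward pass returns a loss, a training step could look like this:

```python
import torch

optimizer = torch.optim.AdamW(peft_model.parameters(), lr=1e-4)
peft_model.train()
for batch in dataloader:
    outputs = peft_model(**batch)  # assumes `batch` includes `labels`
    outputs.loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```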

## Contributing

If you would like to contribute to PEFT, please check out our [contributing guide](https://huggingface.co/docs/peft/developer_guides/contributing).
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -34,6 +34,8 @@
    title: Working with custom models
  - local: developer_guides/low_level_api
    title: PEFT low level API
  - local: developer_guides/mixed_models
    title: Mixing different adapter types
  - local: developer_guides/contributing
    title: Contributing to PEFT
  - local: developer_guides/troubleshooting
39 changes: 39 additions & 0 deletions docs/source/developer_guides/mixed_models.md
@@ -0,0 +1,39 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Working with mixed adapter types

Normally, it is not possible to mix different adapter types in 🤗 PEFT. For example, even though it is possible to create a PEFT model that has two different LoRA adapters (which can have different config options), it is not possible to combine a LoRA adapter with a LoHa adapter. With a mixed model, however, this works as long as the adapter types are compatible.

## Loading different adapter types into a PEFT model

To load different adapter types into a PEFT model, proceed as you would when loading two adapters of the same type, but use `PeftMixedModel` instead of `PeftModel`:

```py
from peft import PeftMixedModel

base_model = ... # load the base model, e.g. from transformers
# load first adapter, which will be called "default"
peft_model = PeftMixedModel.from_pretrained(base_model, <path_to_adapter1>)
peft_model.load_adapter(<path_to_adapter2>, adapter_name="other")
peft_model.set_adapter(["default", "other"])
```

The last line is necessary if you want to activate both adapters; otherwise, only the first adapter would be active. Of course, you can add more different adapters by calling `add_adapter` or `load_adapter` repeatedly, as shown below.
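
For example, a hypothetical third trained adapter could be loaded and activated alongside the first two:

```py
peft_model.load_adapter(<path_to_adapter3>, adapter_name="third")
peft_model.set_adapter(["default", "other", "third"])
```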

Currently, the main purpose of mixed adapter types is to combine trained adapters for inference. Although it is technically also possible to train a mixed adapter model, this has not been tested and is not recommended.

## Tips

- Not all adapter types can be combined. See `peft.tuners.mixed.COMPATIBLE_TUNER_TYPES` for a list of compatible types (the snippet after this list shows how to inspect it). An error will be raised if you try to combine incompatible adapter types.
- It is possible to mix multiple adapters of the same type. This can be useful to combine adapters with very different configs.
- If you want to combine a lot of different adapters, the most performant way is to add adapters of the same type consecutively, e.g. LoRA1, LoRA2, LoHa1, LoHa2 instead of LoRA1, LoHa1, LoRA2, LoHa2. The order will make a difference for the outcome in most cases, but since no order is better a priori, it is best to choose the one that is most performant.
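
To inspect the list of compatible types programmatically:

```py
from peft.tuners.mixed import COMPATIBLE_TUNER_TYPES

print(COMPATIBLE_TUNER_TYPES)
```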
1 change: 1 addition & 0 deletions src/peft/__init__.py
@@ -35,6 +35,7 @@
    get_peft_model,
    inject_adapter_in_model,
)
from .mixed_model import PeftMixedModel
from .peft_model import (
    PeftModel,
    PeftModelForCausalLM,
19 changes: 16 additions & 3 deletions src/peft/mapping.py
@@ -20,6 +20,7 @@
import torch

from .config import PeftConfig
from .mixed_model import PeftMixedModel
from .peft_model import (
    PeftModel,
    PeftModelForCausalLM,
@@ -95,22 +96,34 @@ def get_peft_config(config_dict: Dict[str, Any]) -> PeftConfig:
    return PEFT_TYPE_TO_CONFIG_MAPPING[config_dict["peft_type"]](**config_dict)


def get_peft_model(model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default") -> PeftModel:
def get_peft_model(
    model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default", mixed: bool = False
) -> PeftModel | PeftMixedModel:
    """
    Returns a Peft model object from a model and a config.

    Args:
        model ([`transformers.PreTrainedModel`]): Model to be wrapped.
        peft_config ([`PeftConfig`]): Configuration object containing the parameters of the Peft model.
        model ([`transformers.PreTrainedModel`]):
            Model to be wrapped.
        peft_config ([`PeftConfig`]):
            Configuration object containing the parameters of the Peft model.
        adapter_name (`str`, `optional`, defaults to `"default"`):
            The name of the adapter to be injected; if not provided, the default adapter name ("default") is used.
        mixed (`bool`, `optional`, defaults to `False`):
            Whether to allow mixing different (compatible) adapter types.
    """
    model_config = getattr(model, "config", {"model_type": "custom"})
    if hasattr(model_config, "to_dict"):
        model_config = model_config.to_dict()

    peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None)

    if mixed:
        return PeftMixedModel(model, peft_config, adapter_name=adapter_name)

    if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not peft_config.is_prompt_learning:
        return PeftModel(model, peft_config, adapter_name=adapter_name)

    if peft_config.is_prompt_learning:
        peft_config = _prepare_prompt_learning_config(peft_config, model_config)
    return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config, adapter_name=adapter_name)
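
A quick sketch of the new dispatch (the checkpoint and `target_modules` below are illustrative): with `mixed=True`, `get_peft_model` returns a `PeftMixedModel` regardless of task type.

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
config = LoraConfig(target_modules=["q_proj", "v_proj"])

# returns a PeftMixedModel instead of a task-specific PeftModel subclass
peft_model = get_peft_model(base_model, config, mixed=True)
print(type(peft_model).__name__)  # PeftMixedModel
```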