[Feature] Save only selected adapters for LoRA (#705)
* v1 working for LoRA

* more checks

* fix prompt learning issues

* fix failing test

* Apply suggestions from code review

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* fixed indentation

* move the check above

* added tests for adaption prompt, enc-dec and feature extraction

---------

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
younesbelkada and BenjaminBossan authored Jul 14, 2023
1 parent 86ad5ce commit 5a0e19d
Showing 9 changed files with 202 additions and 49 deletions.
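
For context, the new `selected_adapters` argument lets a caller write only a subset of a model's adapters to disk. A minimal usage sketch, assuming a LoRA setup; the model id, checkpoint path, and adapter names are illustrative, not taken from this commit:

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")   # illustrative base model
model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"))    # registers the "default" adapter
model.add_adapter("new_adapter", LoraConfig(task_type="CAUSAL_LM"))

# Previously every adapter was always written; with this change only the listed ones are saved.
model.save_pretrained("lora-checkpoint", selected_adapters=["default"])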
44 changes: 1 addition & 43 deletions src/peft/mapping.py
@@ -35,7 +35,7 @@
PromptEncoderConfig,
PromptTuningConfig,
)
from .utils import PromptLearningConfig
from .utils import PromptLearningConfig, _prepare_prompt_learning_config


if TYPE_CHECKING:
@@ -75,48 +75,6 @@ def get_peft_config(config_dict: Dict[str, Any]):
return PEFT_TYPE_TO_CONFIG_MAPPING[config_dict["peft_type"]](**config_dict)


def _prepare_prompt_learning_config(peft_config: PeftConfig, model_config: Dict[str, Any]):
if peft_config.num_layers is None:
if "num_hidden_layers" in model_config:
num_layers = model_config["num_hidden_layers"]
elif "num_layers" in model_config:
num_layers = model_config["num_layers"]
elif "n_layer" in model_config:
num_layers = model_config["n_layer"]
else:
raise ValueError("Please specify `num_layers` in `peft_config`")
peft_config.num_layers = num_layers

if peft_config.token_dim is None:
if "hidden_size" in model_config:
token_dim = model_config["hidden_size"]
elif "n_embd" in model_config:
token_dim = model_config["n_embd"]
elif "d_model" in model_config:
token_dim = model_config["d_model"]
else:
raise ValueError("Please specify `token_dim` in `peft_config`")
peft_config.token_dim = token_dim

if peft_config.num_attention_heads is None:
if "num_attention_heads" in model_config:
num_attention_heads = model_config["num_attention_heads"]
elif "n_head" in model_config:
num_attention_heads = model_config["n_head"]
elif "num_heads" in model_config:
num_attention_heads = model_config["num_heads"]
elif "encoder_attention_heads" in model_config:
num_attention_heads = model_config["encoder_attention_heads"]
else:
raise ValueError("Please specify `num_attention_heads` in `peft_config`")
peft_config.num_attention_heads = num_attention_heads

if getattr(peft_config, "encoder_hidden_size", None) is None:
setattr(peft_config, "encoder_hidden_size", peft_config.token_dim)

return peft_config


def get_peft_model(model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default") -> PeftModel:
"""
Returns a Peft model object from a model and a config.
42 changes: 36 additions & 6 deletions src/peft/peft_model.py
@@ -20,7 +20,7 @@
import warnings
from contextlib import contextmanager
from copy import deepcopy
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, List, Optional, Union

import torch
from accelerate import dispatch_model, infer_auto_device_map
@@ -53,6 +53,7 @@
PeftType,
PromptLearningConfig,
TaskType,
_prepare_prompt_learning_config,
_set_adapter,
_set_trainable,
add_library_to_model_card,
@@ -118,7 +119,13 @@ def __init__(self, model: PreTrainedModel, peft_config: PeftConfig, adapter_name
if getattr(model, "is_gradient_checkpointing", True):
model = self._prepare_model_for_gradient_checkpointing(model)

def save_pretrained(self, save_directory: str, safe_serialization: bool = False, **kwargs: Any):
def save_pretrained(
self,
save_directory: str,
safe_serialization: bool = False,
selected_adapters: Optional[List[str]] = None,
**kwargs: Any,
):
r"""
This function saves the adapter model and the adapter configuration files to a directory, so that it can be
reloaded using the [`LoraModel.from_pretrained`] class method, and also used by the [`LoraModel.push_to_hub`]
@@ -133,10 +140,24 @@ def save_pretrained(self, save_directory: str, safe_serialization: bool = False,
"""
if os.path.isfile(save_directory):
raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")

if selected_adapters is None:
selected_adapters = list(self.peft_config.keys())
else:
if any(
selected_adapter_name not in list(self.peft_config.keys())
for selected_adapter_name in selected_adapters
):
raise ValueError(
f"You passed an invalid `selected_adapters` arguments, current supported adapter names are"
f" {list(self.peft_config.keys())} - got {selected_adapters}."
)

os.makedirs(save_directory, exist_ok=True)
self.create_or_update_model_card(save_directory)

for adapter_name, peft_config in self.peft_config.items():
for adapter_name in selected_adapters:
peft_config = self.peft_config[adapter_name]
# save only the trainable weights
output_state_dict = get_peft_model_state_dict(
self, state_dict=kwargs.get("state_dict", None), adapter_name=adapter_name
@@ -146,7 +167,9 @@ def save_pretrained(self, save_directory: str, safe_serialization: bool = False,

if safe_serialization:
safe_save_file(
output_state_dict, os.path.join(output_dir, SAFETENSORS_WEIGHTS_NAME), metadata={"format": "pt"}
output_state_dict,
os.path.join(output_dir, SAFETENSORS_WEIGHTS_NAME),
metadata={"format": "pt"},
)
else:
torch.save(output_state_dict, os.path.join(output_dir, WEIGHTS_NAME))
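
A hedged sketch of what the validation and save loop above amount to, assuming a `PeftModel` named `model` that holds adapters "default" and "new_adapter" (directory and adapter names are illustrative):

model.save_pretrained("ckpt")                                   # no argument: saves every adapter, as before
model.save_pretrained("ckpt", selected_adapters=["default"])    # saves only the "default" adapter
model.save_pretrained("ckpt", selected_adapters=["missing"])    # raises ValueError listing the valid adapter names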
@@ -234,8 +257,9 @@ def from_pretrained(

def _setup_prompt_encoder(self, adapter_name: str):
config = self.peft_config[adapter_name]
self.prompt_encoder = torch.nn.ModuleDict({})
self.prompt_tokens = {}
if not hasattr(self, "prompt_encoder"):
self.prompt_encoder = torch.nn.ModuleDict({})
self.prompt_tokens = {}
transformer_backbone = None
for name, module in self.base_model.named_children():
for param in module.parameters():
@@ -412,6 +436,12 @@ def add_adapter(self, adapter_name: str, peft_config: PeftConfig):
)
self.peft_config[adapter_name] = peft_config
if isinstance(peft_config, PromptLearningConfig):
if hasattr(self.config, "to_dict"):
dict_config = self.config.to_dict()
else:
dict_config = self.config

peft_config = _prepare_prompt_learning_config(peft_config, dict_config)
self._setup_prompt_encoder(adapter_name)
else:
self.base_model.add_adapter(adapter_name, peft_config)
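
With the `add_adapter` change above, prompt learning configs added after the first adapter also get their model-derived fields (`num_layers`, `token_dim`, `num_attention_heads`) filled in from the base model's config. A minimal sketch, assuming `base` is a transformers causal LM and the adapter name is illustrative:

from peft import PromptTuningConfig, get_peft_model

cfg = PromptTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=8)
model = get_peft_model(base, cfg)  # the first adapter is prepared inside get_peft_model
# The second prompt-learning adapter is now prepared here as well,
# instead of failing on unset num_layers / token_dim / num_attention_heads.
model.add_adapter("second", PromptTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=8))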
1 change: 1 addition & 0 deletions src/peft/utils/__init__.py
@@ -40,6 +40,7 @@
_set_adapter,
_freeze_adapter,
ModulesToSaveWrapper,
_prepare_prompt_learning_config,
)
from .hub_utils import hub_file_exists
from .save_and_load import get_peft_model_state_dict, set_peft_model_state_dict
42 changes: 42 additions & 0 deletions src/peft/utils/other.py
@@ -178,6 +178,48 @@ def _set_adapter(model, adapter_name):
module.active_adapter = adapter_name


def _prepare_prompt_learning_config(peft_config, model_config):
if peft_config.num_layers is None:
if "num_hidden_layers" in model_config:
num_layers = model_config["num_hidden_layers"]
elif "num_layers" in model_config:
num_layers = model_config["num_layers"]
elif "n_layer" in model_config:
num_layers = model_config["n_layer"]
else:
raise ValueError("Please specify `num_layers` in `peft_config`")
peft_config.num_layers = num_layers

if peft_config.token_dim is None:
if "hidden_size" in model_config:
token_dim = model_config["hidden_size"]
elif "n_embd" in model_config:
token_dim = model_config["n_embd"]
elif "d_model" in model_config:
token_dim = model_config["d_model"]
else:
raise ValueError("Please specify `token_dim` in `peft_config`")
peft_config.token_dim = token_dim

if peft_config.num_attention_heads is None:
if "num_attention_heads" in model_config:
num_attention_heads = model_config["num_attention_heads"]
elif "n_head" in model_config:
num_attention_heads = model_config["n_head"]
elif "num_heads" in model_config:
num_attention_heads = model_config["num_heads"]
elif "encoder_attention_heads" in model_config:
num_attention_heads = model_config["encoder_attention_heads"]
else:
raise ValueError("Please specify `num_attention_heads` in `peft_config`")
peft_config.num_attention_heads = num_attention_heads

if getattr(peft_config, "encoder_hidden_size", None) is None:
setattr(peft_config, "encoder_hidden_size", peft_config.token_dim)

return peft_config


def fsdp_auto_wrap_policy(model):
import functools
import os
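
For reference, a small sketch of how the relocated helper fills unset fields from a model config dict; the config values below are illustrative and the helper is an internal, underscore-prefixed utility:

from peft import PromptTuningConfig
from peft.utils import _prepare_prompt_learning_config

cfg = PromptTuningConfig(task_type="CAUSAL_LM", num_virtual_tokens=8)
model_config = {"num_hidden_layers": 12, "hidden_size": 768, "num_attention_heads": 12}
cfg = _prepare_prompt_learning_config(cfg, model_config)
assert (cfg.num_layers, cfg.token_dim, cfg.num_attention_heads) == (12, 768, 12)
# encoder_hidden_size is also defaulted to token_dim when it is left unset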
50 changes: 50 additions & 0 deletions tests/test_adaption_prompt.py
@@ -160,6 +160,56 @@ def test_save_pretrained(self) -> None:
# check if `config.json` is not present
self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json")))

def test_save_pretrained_selected_adapters(self) -> None:
seed = 420
torch.manual_seed(seed)
model = LlamaForCausalLM(self._create_test_llama_config())
config = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM")
model = get_peft_model(model, config)
model = model.to(self.torch_device)

new_adapter_config = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM")
model.add_adapter("new_adapter", new_adapter_config)

with tempfile.TemporaryDirectory() as tmp_dirname:
model.save_pretrained(tmp_dirname)

torch.manual_seed(seed)
model_from_pretrained = LlamaForCausalLM(self._create_test_llama_config())
model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname)

model_from_pretrained.load_adapter(tmp_dirname, "new_adapter")

# check if the state dicts are equal
state_dict = get_peft_model_state_dict(model)
state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained)

# check if same keys
self.assertEqual(state_dict.keys(), state_dict_from_pretrained.keys())

# Check that the number of saved parameters is 4 -- 2 layers of (tokens and gate).
self.assertEqual(len(list(state_dict.keys())), 4)

# check if tensors equal
for key in state_dict.keys():
self.assertTrue(
torch.allclose(
state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device)
)
)

# check if `adapter_model.bin` is present
self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.bin")))

# check if `adapter_config.json` is present
self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_config.json")))

# check if `pytorch_model.bin` is not present
self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "pytorch_model.bin")))

# check if `config.json` is not present
self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json")))

def test_generate(self) -> None:
model = LlamaForCausalLM(self._create_test_llama_config())
config = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM")
4 changes: 4 additions & 0 deletions tests/test_decoder_models.py
@@ -81,6 +81,10 @@ def test_prepare_for_training_parametrized(self, test_name, model_id, config_cls
def test_save_pretrained(self, test_name, model_id, config_cls, config_kwargs):
self._test_save_pretrained(model_id, config_cls, config_kwargs)

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_save_pretrained_selected_adapters(self, test_name, model_id, config_cls, config_kwargs):
self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs)

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_from_pretrained_config_construction(self, test_name, model_id, config_cls, config_kwargs):
self._test_from_pretrained_config_construction(model_id, config_cls, config_kwargs)
4 changes: 4 additions & 0 deletions tests/test_encoder_decoder_models.py
@@ -67,6 +67,10 @@ def test_prepare_for_training_parametrized(self, test_name, model_id, config_cls
def test_save_pretrained(self, test_name, model_id, config_cls, config_kwargs):
self._test_save_pretrained(model_id, config_cls, config_kwargs)

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_save_pretrained_selected_adapters(self, test_name, model_id, config_cls, config_kwargs):
self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs)

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_from_pretrained_config_construction(self, test_name, model_id, config_cls, config_kwargs):
self._test_from_pretrained_config_construction(model_id, config_cls, config_kwargs)
4 changes: 4 additions & 0 deletions tests/test_feature_extraction_models.py
@@ -84,6 +84,10 @@ def test_prepare_for_training_parametrized(self, test_name, model_id, config_cls
def test_save_pretrained(self, test_name, model_id, config_cls, config_kwargs):
self._test_save_pretrained(model_id, config_cls, config_kwargs)

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_save_pretrained_selected_adapters(self, test_name, model_id, config_cls, config_kwargs):
self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs)

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_from_pretrained_config_construction(self, test_name, model_id, config_cls, config_kwargs):
self._test_from_pretrained_config_construction(model_id, config_cls, config_kwargs)
60 changes: 60 additions & 0 deletions tests/testing_common.py
@@ -267,6 +267,66 @@ def _test_save_pretrained(self, model_id, config_cls, config_kwargs):
# check if `config.json` is not present
self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json")))

def _test_save_pretrained_selected_adapters(self, model_id, config_cls, config_kwargs):
model = self.transformers_class.from_pretrained(model_id)
config = config_cls(
base_model_name_or_path=model_id,
**config_kwargs,
)
model = get_peft_model(model, config)
model = model.to(self.torch_device)

new_adapter_config = config_cls(
base_model_name_or_path=model_id,
**config_kwargs,
)

model.add_adapter("new_adapter", new_adapter_config)

with tempfile.TemporaryDirectory() as tmp_dirname:
model.save_pretrained(tmp_dirname)

model_from_pretrained = self.transformers_class.from_pretrained(model_id)
model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname)

model_from_pretrained.load_adapter(tmp_dirname, "new_adapter")

# check if the state dicts are equal
state_dict = get_peft_model_state_dict(model)
state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained)

# check if same keys
self.assertEqual(state_dict.keys(), state_dict_from_pretrained.keys())

# check if tensors equal
for key in state_dict.keys():
self.assertTrue(
torch.allclose(
state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device)
)
)

# check if `adapter_model.bin` is present
self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.bin")))

# check if `adapter_config.json` is present
self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_config.json")))

# check if `pytorch_model.bin` is not present
self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "pytorch_model.bin")))

# check if `config.json` is not present
self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json")))

with tempfile.TemporaryDirectory() as tmp_dirname:
model.save_pretrained(tmp_dirname, selected_adapters=["default"])

model_from_pretrained = self.transformers_class.from_pretrained(model_id)
model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname)

self.assertTrue("default" in model_from_pretrained.peft_config.keys())
self.assertTrue("new_adapter" not in model_from_pretrained.peft_config.keys())

def _test_from_pretrained_config_construction(self, model_id, config_cls, config_kwargs):
model = self.transformers_class.from_pretrained(model_id)
config = config_cls(base_model_name_or_path=model_id, **config_kwargs)