|
30 | 30 | from transformers.pytorch_utils import Conv1D
|
31 | 31 |
|
32 | 32 | from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND
|
33 |
| -from peft.utils.constants import DUMMY_TARGET_MODULES |
34 |
| -from peft.utils.peft_types import PeftType |
| 33 | +from peft.utils.constants import DUMMY_TARGET_MODULES, SEQ_CLS_HEAD_NAMES |
| 34 | +from peft.utils.peft_types import PeftType, TaskType |
35 | 35 |
|
36 | 36 | from ..config import PeftConfig
|
37 | 37 | from ..utils import ModulesToSaveWrapper, _get_submodules
|
@@ -812,11 +812,25 @@ def _maybe_include_all_linear_layers(peft_config: PeftConfig, model: nn.Module)
|
812 | 812 | names = name.rsplit(".", 1)[-1] # get the base name
|
813 | 813 | linear_module_names.add(names)
|
814 | 814 |
|
815 |
| - # ignore the last classification head for text generation models |
| 815 | + # Try to remove linear layers that should not be targeted as best as possible. We have to rely on convention as |
| 816 | + # there are no hard rules to detect these modules. |
| 817 | + module_names_to_exclude = set() |
816 | 818 | output_emb = model.get_output_embeddings()
|
817 | 819 | if output_emb is not None:
|
| 820 | + # ignore the last classification head for text generation models |
818 | 821 | last_module_name = [name for name, module in model.named_modules() if module is output_emb][0]
|
819 |
| - linear_module_names -= {last_module_name} |
| 822 | + module_names_to_exclude.add(last_module_name) |
| 823 | + elif peft_config.task_type == TaskType.SEQ_CLS: |
| 824 | + # ignore classifier head for classification models (issue 2027) |
| 825 | + # there is no fixed name for the classifier head, so check the common ones |
| 826 | + for name in SEQ_CLS_HEAD_NAMES: |
| 827 | + cls_head = getattr(model, name, None) |
| 828 | + if cls_head is not None: |
| 829 | + last_module_name = [name for name, module in model.named_modules() if module is cls_head][0] |
| 830 | + module_names_to_exclude.add(last_module_name) |
| 831 | + break |
| 832 | + |
| 833 | + linear_module_names -= module_names_to_exclude |
820 | 834 | peft_config.target_modules = linear_module_names
|
821 | 835 | return peft_config
|
822 | 836 |
|
|
0 commit comments