From 0ae8360b8a7dda00cb3db223dfee42d27e633907 Mon Sep 17 00:00:00 2001
From: "yuze.zyz"
Date: Fri, 10 Nov 2023 15:56:01 +0800
Subject: [PATCH] add neftune

---
 swift/tuners/__init__.py  |  2 +
 swift/tuners/mapping.py   |  3 ++
 swift/tuners/neftune.py   | 90 +++++++++++++++++++++++++++++++++++++++
 tests/tuners/test_neft.py | 40 +++++++++++++++++
 4 files changed, 135 insertions(+)
 create mode 100644 swift/tuners/neftune.py
 create mode 100644 tests/tuners/test_neft.py

diff --git a/swift/tuners/__init__.py b/swift/tuners/__init__.py
index 3112e480d..6d77bb3bf 100644
--- a/swift/tuners/__init__.py
+++ b/swift/tuners/__init__.py
@@ -9,6 +9,7 @@
     from .lora import LoRA, LoRAConfig
     from .mapping import SWIFT_MAPPING, SwiftTuners
     from .side import Side, SideConfig, SideModule
+    from .neftune import NEFTune, NEFTuneConfig
     from .longlora.longlora import LongLoRAModelType, LongLoRAConfig, LongLoRA
     from .restuning import ResTuning, ResTuningConfig, ResTuningBypassModule
     from .peft import (LoraConfig, PeftConfig, PeftModel, PeftModelForCausalLM,
@@ -29,6 +30,7 @@
         ['LongLoRAModelType', 'LongLoRAConfig', 'LongLoRA'],
         'mapping': ['SWIFT_MAPPING', 'SwiftTuners'],
         'side': ['Side', 'SideConfig', 'SideModule'],
+        'neftune': ['NEFTune', 'NEFTuneConfig'],
         'restuning': ['ResTuning', 'ResTuningConfig', 'ResTuningBypassModule'],
         'peft': [
             'LoraConfig', 'PeftConfig', 'PeftModel', 'PeftModelForCausalLM',
diff --git a/swift/tuners/mapping.py b/swift/tuners/mapping.py
index b989203d6..ed24842a7 100644
--- a/swift/tuners/mapping.py
+++ b/swift/tuners/mapping.py
@@ -3,6 +3,7 @@
 from .adapter import Adapter, AdapterConfig
 from .longlora.longlora import LongLoRA, LongLoRAConfig
 from .lora import LoRA, LoRAConfig
+from .neftune import NEFTune, NEFTuneConfig
 from .prompt import Prompt, PromptConfig
 from .restuning import ResTuning, ResTuningConfig
 from .rome import Rome, RomeConfig
@@ -17,6 +18,7 @@ class SwiftTuners:
     RESTUNING = 'RESTUNING'
     ROME = 'ROME'
     LONGLORA = 'longlora'
+    NEFTUNE = 'neftune'
 
 
 SWIFT_MAPPING = {
@@ -27,4 +29,5 @@ class SwiftTuners:
     SwiftTuners.RESTUNING: (ResTuningConfig, ResTuning),
     SwiftTuners.ROME: (RomeConfig, Rome),
     SwiftTuners.LONGLORA: (LongLoRAConfig, LongLoRA),
+    SwiftTuners.NEFTUNE: (NEFTuneConfig, NEFTune),
 }
diff --git a/swift/tuners/neftune.py b/swift/tuners/neftune.py
new file mode 100644
index 000000000..75a5061ed
--- /dev/null
+++ b/swift/tuners/neftune.py
@@ -0,0 +1,90 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import copy
+import re
+import types
+from collections import OrderedDict
+from dataclasses import dataclass, field
+from functools import partial
+from itertools import repeat
+from typing import List, Union
+
+import torch
+from torch import nn
+
+from swift.utils.logger import get_logger
+from ..utils.torch_utils import find_sub_module
+from .utils import ActivationMixin, SwiftConfig, SwiftOutput
+
+logger = get_logger()
+
+
+@dataclass
+class NEFTuneConfig(SwiftConfig):
+    """
+    The configuration class for the NEFTune module.
+
+    NEFTune adds uniform noise to the embedding output during training
+    to improve instruction fine-tuning.
+    'NEFTune: Noisy Embeddings Improve Instruction Finetuning'
+    by Jain et al.(2023)
+    See https://arxiv.org/abs/2310.05914
+
+    Args:
+        noise_alpha(`float`): The noise alpha value used for the NEFTune, default 5.0
+    """
+    noise_alpha: float = field(
+        default=5.0,
+        metadata={
+            'help': 'The noise alpha value used for the NEFTune'
+        })
+
+    def __post_init__(self):
+        from .mapping import SwiftTuners
+        self.swift_type = SwiftTuners.NEFTUNE
+
+
+class NEFTune:
+
+    @staticmethod
+    def prepare_model(model: nn.Module, config: NEFTuneConfig,
+                      adapter_name: str) -> SwiftOutput:
+        """Prepare a model with `NEFTuneConfig`"""
+        for sub_module in model.modules():
+            if isinstance(sub_module, torch.nn.Embedding):
+                def noised_embed(orig_embed, noise_alpha):
+
+                    def new_func(x):
+                        # during training, we add noise to the embedding
+                        # during generation, we don't add noise to the embedding
+                        if model.training and getattr(orig_embed, 'nef_activated', False):
+                            embed_init = orig_embed.forward_origin(x)
+                            dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
+                            mag_norm = noise_alpha / torch.sqrt(dims)
+                            return embed_init + torch.zeros_like(embed_init).uniform_(-mag_norm, mag_norm)
+                        else:
+                            return orig_embed.forward_origin(x)
+
+                    return new_func
+
+                if hasattr(sub_module, 'nef_activated'):
+                    raise ValueError('NEFTune does not support a second tuner.')
+
+                sub_module.forward_origin = sub_module.forward
+                sub_module.forward = noised_embed(sub_module, config.noise_alpha)
+                sub_module.nef_activated = True
+
+        def state_dict_callback(state_dict, adapter_name):
+            return state_dict
+
+        def mark_trainable_callback(model):
+            return
+
+        return SwiftOutput(config, state_dict_callback,
+                           mark_trainable_callback)
+
+    @staticmethod
+    def activate_adapter(module: torch.nn.Module, adapter_name: str,
+                         activate: bool):
+        for sub_module in module.modules():
+            if isinstance(sub_module, torch.nn.Embedding):
+                sub_module.nef_activated = activate
diff --git a/tests/tuners/test_neft.py b/tests/tuners/test_neft.py
new file mode 100644
index 000000000..070af8cab
--- /dev/null
+++ b/tests/tuners/test_neft.py
@@ -0,0 +1,40 @@
+import os
+import shutil
+import tempfile
+import unittest
+
+import torch
+from modelscope import Model, Preprocessor
+
+from swift import Swift
+from swift.tuners import NEFTuneConfig
+
+
+class TestNEFT(unittest.TestCase):
+
+    def setUp(self):
+        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))
+        self.tmp_dir = tempfile.TemporaryDirectory().name
+        if not os.path.exists(self.tmp_dir):
+            os.makedirs(self.tmp_dir)
+
+    def tearDown(self):
+        shutil.rmtree(self.tmp_dir)
+        super().tearDown()
+
+    def test_neft(self):
+        model = Model.from_pretrained(
+            'damo/nlp_structbert_sentence-similarity_chinese-base')
+        preprocessor = Preprocessor.from_pretrained(
+            'damo/nlp_structbert_sentence-similarity_chinese-base')
+        inputs = preprocessor('how are you')
+        config = NEFTuneConfig()
+
+        t1 = model.encoder.embeddings.word_embeddings(inputs['input_ids'])
+        model = Swift.prepare_model(model, config)
+        model.train()
+        t2 = model.encoder.embeddings.word_embeddings(inputs['input_ids'])
+        model.deactivate_adapter('default')
+        t3 = model.encoder.embeddings.word_embeddings(inputs['input_ids'])
+        self.assertTrue(torch.allclose(t1, t3))
+        self.assertFalse(torch.allclose(t1, t2))
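
A minimal usage sketch of the new tuner, mirroring the unit test above. The model id is the one used in tests/tuners/test_neft.py; the fine-tuning loop is an assumed placeholder, not part of this patch.

# Sketch only: apply NEFTune via Swift to any model that contains nn.Embedding
# layers. The model id comes from the unit test; the training loop below is a
# placeholder assumption.
from modelscope import Model

from swift import Swift
from swift.tuners import NEFTuneConfig

model = Model.from_pretrained(
    'damo/nlp_structbert_sentence-similarity_chinese-base')
model = Swift.prepare_model(model, NEFTuneConfig(noise_alpha=5.0))

model.train()  # noise is injected only while model.training is True
# ... run the usual fine-tuning loop here ...

model.eval()   # in eval/generation mode the patched forward returns the plain embedding

Following the patched forward above, each embedding element is perturbed by uniform noise drawn from [-mag_norm, mag_norm], where mag_norm = noise_alpha / sqrt(seq_len * hidden_dim), so longer sequences and wider embeddings receive proportionally smaller noise.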