add neftune
tastelikefeet committed Nov 10, 2023
1 parent 5baf263 commit 0ae8360
Showing 4 changed files with 135 additions and 0 deletions.
2 changes: 2 additions & 0 deletions swift/tuners/__init__.py
@@ -9,6 +9,7 @@
from .lora import LoRA, LoRAConfig
from .mapping import SWIFT_MAPPING, SwiftTuners
from .side import Side, SideConfig, SideModule
from .neftune import NEFTune, NEFTuneConfig
from .longlora.longlora import LongLoRAModelType, LongLoRAConfig, LongLoRA
from .restuning import ResTuning, ResTuningConfig, ResTuningBypassModule
from .peft import (LoraConfig, PeftConfig, PeftModel, PeftModelForCausalLM,
@@ -29,6 +30,7 @@
['LongLoRAModelType', 'LongLoRAConfig', 'LongLoRA'],
'mapping': ['SWIFT_MAPPING', 'SwiftTuners'],
'side': ['Side', 'SideConfig', 'SideModule'],
'neftune': ['NEFTune', 'NEFTuneConfig'],
'restuning': ['ResTuning', 'ResTuningConfig', 'ResTuningBypassModule'],
'peft': [
'LoraConfig', 'PeftConfig', 'PeftModel', 'PeftModelForCausalLM',
3 changes: 3 additions & 0 deletions swift/tuners/mapping.py
@@ -3,6 +3,7 @@
from .adapter import Adapter, AdapterConfig
from .longlora.longlora import LongLoRA, LongLoRAConfig
from .lora import LoRA, LoRAConfig
from .neftune import NEFTune, NEFTuneConfig
from .prompt import Prompt, PromptConfig
from .restuning import ResTuning, ResTuningConfig
from .rome import Rome, RomeConfig
@@ -17,6 +18,7 @@ class SwiftTuners:
RESTUNING = 'RESTUNING'
ROME = 'ROME'
LONGLORA = 'longlora'
NEFTUNE = 'neftune'


SWIFT_MAPPING = {
@@ -27,4 +29,5 @@ class SwiftTuners:
SwiftTuners.RESTUNING: (ResTuningConfig, ResTuning),
SwiftTuners.ROME: (RomeConfig, Rome),
SwiftTuners.LONGLORA: (LongLoRAConfig, LongLoRA),
SwiftTuners.NEFTUNE: (NEFTuneConfig, NEFTune),
}
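As a quick sanity check that the new entry is wired up, the tuner can be resolved from the registry; a minimal sketch, assuming the public imports exposed in __init__.py above:

from swift.tuners import SWIFT_MAPPING, SwiftTuners

# the mapping added above pairs each tuner type with its config and implementation
config_cls, tuner_cls = SWIFT_MAPPING[SwiftTuners.NEFTUNE]
print(config_cls.__name__, tuner_cls.__name__)  # NEFTuneConfig NEFTune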
90 changes: 90 additions & 0 deletions swift/tuners/neftune.py
@@ -0,0 +1,90 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import re
import types
from collections import OrderedDict
from dataclasses import dataclass, field
from functools import partial
from itertools import repeat
from typing import List, Union

import torch
from torch import nn

from swift.utils.logger import get_logger
from ..utils.torch_utils import find_sub_module
from .utils import ActivationMixin, SwiftConfig, SwiftOutput

logger = get_logger()


@dataclass
class NEFTuneConfig(SwiftConfig):
"""
    The configuration class for NEFTune.
    NEFTune adds uniform noise to the token embedding output during training,
    which improves instruction fine-tuning.
    'NEFTune: Noisy Embeddings Improve Instruction Finetuning'
    by Jain et al.(2023)
    See https://arxiv.org/abs/2310.05914
    Args:
        noise_alpha(`float`): The noise alpha value used for NEFTune, default 5.0
"""
noise_alpha: float = field(
default=5.0,
metadata={
'help':
            'The noise alpha value used for NEFTune'
})

def __post_init__(self):
from .mapping import SwiftTuners
self.swift_type = SwiftTuners.NEFTUNE


class NEFTune:

@staticmethod
def prepare_model(model: nn.Module, config: NEFTuneConfig,
adapter_name: str) -> SwiftOutput:
"""Prepare a model with `NEFTuneConfig`"""
for sub_module in model.modules():
if isinstance(sub_module, torch.nn.Embedding):
def noised_embed(orig_embed, noise_alpha):
def new_func(x):
# during training, we add noise to the embedding
# during generation, we don't add noise to the embedding
if model.training and getattr(orig_embed, 'nef_activated'):
embed_init = orig_embed.forward_origin(x)
dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
mag_norm = noise_alpha / torch.sqrt(dims)
return embed_init + torch.zeros_like(embed_init).uniform_(-mag_norm, mag_norm)
else:
return orig_embed.forward_origin(x)

return new_func

if hasattr(sub_module, 'nef_activated'):
                    raise ValueError('NEFTune does not support a second tuner.')

sub_module.forward_origin = sub_module.forward
sub_module.forward = noised_embed(sub_module, config.noise_alpha)
sub_module.nef_activated = True

def state_dict_callback(state_dict, adapter_name):
return state_dict

def mark_trainable_callback(model):
return

return SwiftOutput(config, state_dict_callback,
mark_trainable_callback)

@staticmethod
def activate_adapter(module: torch.nn.Module, adapter_name: str,
activate: bool):
for sub_module in module.modules():
if isinstance(sub_module, torch.nn.Embedding):
sub_module.nef_activated = activate
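The noise injected above has magnitude noise_alpha / sqrt(seq_len * embed_dim), following the NEFTune paper. A minimal standalone sketch of that rule, with an illustrative tensor shape rather than anything taken from this commit:

import math

import torch

noise_alpha = 5.0  # default from NEFTuneConfig
embed = torch.randn(2, 16, 768)  # (batch, seq_len, embed_dim) -- assumed shape

# mag_norm = alpha / sqrt(seq_len * embed_dim), as in noised_embed() above
mag_norm = noise_alpha / math.sqrt(embed.size(1) * embed.size(2))

# uniform noise in [-mag_norm, mag_norm]; NEFTune adds it only during training
noised = embed + torch.zeros_like(embed).uniform_(-mag_norm, mag_norm)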
40 changes: 40 additions & 0 deletions tests/tuners/test_neft.py
@@ -0,0 +1,40 @@
import os
import shutil
import tempfile
import unittest

import torch
from modelscope import AutoTokenizer, Model, Preprocessor

from swift import Swift
from swift.tuners import NEFTuneConfig


class TestNEFT(unittest.TestCase):

def setUp(self):
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))
self.tmp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)

def tearDown(self):
shutil.rmtree(self.tmp_dir)
super().tearDown()

def test_neft(self):
model = Model.from_pretrained(
'damo/nlp_structbert_sentence-similarity_chinese-base')
preprocessor = Preprocessor.from_pretrained(
'damo/nlp_structbert_sentence-similarity_chinese-base')
inputs = preprocessor('how are you')
config = NEFTuneConfig()

t1 = model.encoder.embeddings.word_embeddings(inputs['input_ids'])
model = Swift.prepare_model(model, config)
model.train()
t2 = model.encoder.embeddings.word_embeddings(inputs['input_ids'])
model.deactivate_adapter('default')
t3 = model.encoder.embeddings.word_embeddings(inputs['input_ids'])
self.assertTrue(torch.allclose(t1, t3))
self.assertFalse(torch.allclose(t1, t2))
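Beyond the unit test, typical usage follows the same pattern; a hedged sketch reusing the model id from the test above (any model containing an nn.Embedding is handled the same way):

from modelscope import Model

from swift import Swift
from swift.tuners import NEFTuneConfig

model = Model.from_pretrained(
    'damo/nlp_structbert_sentence-similarity_chinese-base')
model = Swift.prepare_model(model, NEFTuneConfig(noise_alpha=5.0))

model.train()  # noise is injected only while model.training is True
# ... fine-tune as usual; call model.eval() or model.deactivate_adapter('default')
# before generation to get noise-free embeddings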
