add noisy embedding (axolotl-ai-cloud#721)
* add noisy embedding

* fix format

* Update README.md

* Update README.md

* linter issues

* caseus fixes

---------

Co-authored-by: Maxime <maxime@nope.no>
maximegmd and Maxime authored Oct 13, 2023
1 parent aa7ff17 commit d8f86bf
Showing 4 changed files with 105 additions and 0 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -672,6 +672,11 @@ adam_epsilon:
# Gradient clipping max norm
max_grad_norm:
# Augmentation techniques
# NEFT (https://arxiv.org/abs/2310.05914): set this to a number (the paper's default is 5) to add noise to embeddings
# currently only supported on Llama and Mistral
noisy_embedding_alpha:
# Whether to use bettertransformers
flash_optimum:
# Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
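For intuition about what the noisy_embedding_alpha option documented above controls, here is a small standalone sketch (not part of the commit; shapes are illustrative) of the NEFT noise bound alpha / sqrt(L * d), where L is the sequence length and d is the embedding dimension:

import torch

# Illustrative shapes: batch of 2 sequences, length L=16, embedding dim d=4096.
noise_alpha = 5  # the paper's default, i.e. noisy_embedding_alpha: 5
embeds = torch.randn(2, 16, 4096)

# NEFT scales uniform noise by alpha / sqrt(L * d).
dims = embeds.size(1) * embeds.size(2)
mag_norm = noise_alpha / dims**0.5  # about 0.0195 for these shapes

# Noise drawn uniformly from [-mag_norm, mag_norm], added only at train time.
noised = embeds + torch.zeros_like(embeds).uniform_(-mag_norm, mag_norm)

Leaving the option unset (or set to 0) keeps training unchanged, since load_model only applies the patch when the value is truthy.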
40 changes: 40 additions & 0 deletions src/axolotl/monkeypatch/llama_embeddings_hijack.py
@@ -0,0 +1,40 @@
"""
patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
"""

import torch
import transformers.models.llama.modeling_llama
from transformers.utils import logging

logger = logging.get_logger(__name__)


def replace_llama_embeddings_with_uniform_distribution(noise_alpha=5):
# pylint: disable=duplicate-code
def noised_embed(orig_embed, noise_alpha, model):
def new_func(input_ids):
# during training, we add noise to the embedding
# during generation, we don't add noise to the embedding
if model.training:
embed_init = orig_embed(input_ids)
dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
mag_norm = noise_alpha / torch.sqrt(dims)
return embed_init + torch.zeros_like(embed_init).uniform_(
-mag_norm, mag_norm
)
return orig_embed(input_ids)

return new_func

def post_init(orig_post_init):
def new_func(self):
orig_post_init(self)
self.embed_tokens.forward = noised_embed(
self.embed_tokens.forward, noise_alpha, self
)

return new_func

transformers.models.llama.modeling_llama.LlamaModel.post_init = post_init(
transformers.models.llama.modeling_llama.LlamaModel.post_init
)
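In axolotl this patch is applied automatically from load_model (see the src/axolotl/utils/models.py hunk below) whenever noisy_embedding_alpha is set. As a rough standalone usage sketch (the model name is illustrative and transformers is assumed to be installed), the important detail is that the function must run before the model is constructed, because it wraps LlamaModel.post_init:

from transformers import AutoModelForCausalLM

from axolotl.monkeypatch.llama_embeddings_hijack import (
    replace_llama_embeddings_with_uniform_distribution,
)

# Patch first so the wrapped post_init runs while from_pretrained builds the model.
replace_llama_embeddings_with_uniform_distribution(noise_alpha=5)

model = AutoModelForCausalLM.from_pretrained("NousResearch/Llama-2-7b-hf")

# Noise is injected only while model.training is True; model.eval() falls back
# to the plain embedding lookup because new_func checks model.training.
model.train()

The mistral_embeddings_hijack module below provides the identical hook for MistralModel.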
40 changes: 40 additions & 0 deletions src/axolotl/monkeypatch/mistral_embeddings_hijack.py
@@ -0,0 +1,40 @@
"""
patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
"""

import torch
import transformers.models.mistral.modeling_mistral
from transformers.utils import logging

logger = logging.get_logger(__name__)


def replace_mistral_embeddings_with_uniform_distribution(noise_alpha=5):
# pylint: disable=duplicate-code
def noised_embed(orig_embed, noise_alpha, model):
def new_func(input_ids):
# during training, we add noise to the embedding
# during generation, we don't add noise to the embedding
if model.training:
embed_init = orig_embed(input_ids)
dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
mag_norm = noise_alpha / torch.sqrt(dims)
return embed_init + torch.zeros_like(embed_init).uniform_(
-mag_norm, mag_norm
)
return orig_embed(input_ids)

return new_func

def post_init(orig_post_init):
def new_func(self):
orig_post_init(self)
self.embed_tokens.forward = noised_embed(
self.embed_tokens.forward, noise_alpha, self
)

return new_func

transformers.models.mistral.modeling_mistral.MistralModel.post_init = post_init(
transformers.models.mistral.modeling_mistral.MistralModel.post_init
)
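A quick sanity check of either patch (again a sketch, not part of the commit) is to build a tiny randomly initialised model after patching and confirm that the embedding output is stochastic only in training mode:

import torch
from transformers import MistralConfig, MistralModel

from axolotl.monkeypatch.mistral_embeddings_hijack import (
    replace_mistral_embeddings_with_uniform_distribution,
)

replace_mistral_embeddings_with_uniform_distribution(noise_alpha=5)

# Tiny config so construction is cheap; the sizes are arbitrary.
config = MistralConfig(
    vocab_size=128,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
)
model = MistralModel(config)
input_ids = torch.randint(0, 128, (1, 8))

model.train()
# Two calls draw different noise, so the outputs differ.
assert not torch.equal(model.embed_tokens(input_ids), model.embed_tokens(input_ids))

model.eval()
# In eval mode the original deterministic embedding lookup is used.
assert torch.equal(model.embed_tokens(input_ids), model.embed_tokens(input_ids))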
20 changes: 20 additions & 0 deletions src/axolotl/utils/models.py
@@ -180,6 +180,26 @@ def load_model(
LOG.info("patching with flash attention")
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)

if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
from axolotl.monkeypatch.llama_embeddings_hijack import (
replace_llama_embeddings_with_uniform_distribution,
)

LOG.info("patching with noisy embeddings")
replace_llama_embeddings_with_uniform_distribution(
noise_alpha=cfg.noisy_embedding_alpha
)

if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
from axolotl.monkeypatch.mistral_embeddings_hijack import (
replace_mistral_embeddings_with_uniform_distribution,
)

LOG.info("patching with noisy embeddings")
replace_mistral_embeddings_with_uniform_distribution(
noise_alpha=cfg.noisy_embedding_alpha
)

if cfg.is_llama_derived_model and cfg.xpos_rope:
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
replace_llama_rope_with_xpos_rope,
