Refactor landmark attention patch
NanoCode012 committed Jun 11, 2023
1 parent d9f713e commit 919727b
Showing 2 changed files with 20 additions and 12 deletions.
9 changes: 9 additions & 0 deletions src/axolotl/monkeypatch/llama_landmark_attn.py
@@ -1593,3 +1593,12 @@ def add_mem_tokens(example, mem_freq, mem_id):
     ret.extend(x[prev_idx:])
     # drop attention_mask
     return {"input_ids": ret}
+
+
+def patch_llama_with_landmark_attn():
+    import transformers
+
+    transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM
+    transformers.models.llama.modeling_llama.LlamaModel = LlamaModel
+    transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
+    transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
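For context, a minimal usage sketch of the new helper (not part of the commit): it assumes axolotl and transformers are installed, the checkpoint name is illustrative, and the final resize_token_embeddings call is an assumption about handling the extra token rather than something this diff does. The key point, mirroring load_model in the file below, is that the patch runs before LlamaForCausalLM is imported.

from transformers import AutoTokenizer

from axolotl.monkeypatch.llama_landmark_attn import (
    MEM_TOKEN,
    patch_llama_with_landmark_attn,
)

# Swap the stock Llama classes for the landmark-attention variants first,
# then import LlamaForCausalLM so the patched class is the one resolved.
patch_llama_with_landmark_attn()
from transformers import LlamaForCausalLM  # noqa: E402

model_name = "huggyllama/llama-7b"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})

model = LlamaForCausalLM.from_pretrained(model_name)
model.resize_token_embeddings(len(tokenizer))  # cover the newly added MEM_TOKEN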
23 changes: 11 additions & 12 deletions src/axolotl/utils/models.py
@@ -19,15 +19,6 @@
     LlamaConfig,
 )
 
-try:
-    from transformers import (  # pylint: disable=unused-import # noqa: F401
-        LlamaForCausalLM,
-    )
-except ImportError:
-    logging.warning(
-        "This version of transformers does not support Llama. Consider upgrading."
-    )
-
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
 
 if TYPE_CHECKING:
@@ -118,14 +109,15 @@ def load_model(
         logging.info("patching with sdp attention")
         hijack_llama_sdp_attention()
     elif cfg.is_llama_derived_model and cfg.landmark_attention:
-        from axolotl.monkeypatch.llama_landmark_attn import (  # pylint: disable=redefined-outer-name # noqa: F811
+        from axolotl.monkeypatch.llama_landmark_attn import (
             MEM_TOKEN,
-            LlamaForCausalLM,
+            patch_llama_with_landmark_attn,
         )
 
         logging.info("patching with landmark attention")
+        patch_llama_with_landmark_attn()
 
-        # TODO: Check if this would overwrite previous additional_special_tokens
+        # Note: This might overwrite previous additional_special_tokens
         tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
 
     if cfg.is_llama_derived_model and cfg.xpos_rope:
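The note above is worth spelling out: in transformers, passing a list under "additional_special_tokens" replaces the tokenizer's current list rather than extending it (earlier tokens stay in the vocabulary but drop out of that attribute). A hedged sketch of a merge that would preserve existing entries, reusing tokenizer and MEM_TOKEN from the surrounding code:

# Sketch only, not part of the commit.
existing = list(tokenizer.additional_special_tokens)
if MEM_TOKEN not in existing:
    tokenizer.add_special_tokens(
        {"additional_special_tokens": existing + [MEM_TOKEN]}
    )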
@@ -211,6 +203,13 @@ def load_model(
             )
             load_in_8bit = False
         elif cfg.is_llama_derived_model and "LlamaForCausalLM" in globals():
+            try:
+                from transformers import LlamaForCausalLM
+            except ImportError:
+                logging.warning(
+                    "This version of transformers does not support Llama. Consider upgrading."
+                )
+
             config = LlamaConfig.from_pretrained(base_model_config)
             model = LlamaForCausalLM.from_pretrained(
                 base_model,
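Taken together, the refactor defers the transformers import of LlamaForCausalLM until after load_model has had a chance to apply the monkeypatch, since a class imported before patching stays bound to the original implementation. A small sketch of that ordering (not from the commit), assuming axolotl and transformers are importable:

import transformers

from axolotl.monkeypatch.llama_landmark_attn import patch_llama_with_landmark_attn

patch_llama_with_landmark_attn()

# The class now resolved through transformers is the landmark-attention one,
# defined in the monkeypatch module rather than in modeling_llama.
patched = transformers.models.llama.modeling_llama.LlamaForCausalLM
print(patched.__module__)  # axolotl.monkeypatch.llama_landmark_attn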
