From 02dfbfdde51b3f9ad7d8caa76c62d1950e3e0c9e Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 9 Jan 2024 12:19:24 -0500
Subject: [PATCH 1/2] be more robust about checking embedding modules for lora finetunes

---
 src/axolotl/utils/config.py          | 18 ++-----
 src/axolotl/utils/lora_embeddings.py | 12 +++++
 src/axolotl/utils/models.py          | 29 +++++++++---
 tests/test_validation.py             | 70 ++++++++++++++++++++++++----
 4 files changed, 100 insertions(+), 29 deletions(-)
 create mode 100644 src/axolotl/utils/lora_embeddings.py

diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index 4d4da18ba..9a69184d8 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -151,6 +151,10 @@ def normalize_config(cfg):


 def validate_config(cfg):
+    """
+    This is a "pre-validation" step that handles the yaml configuration before we have any
+    information about the model architecture
+    """
     if is_torch_bf16_gpu_available():
         if not cfg.bf16 and not cfg.bfloat16:
             LOG.info("bf16 support detected, but not enabled for this configuration.")
@@ -443,20 +447,6 @@ def validate_config(cfg):
     if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
         raise ValueError("neftune_noise_alpha must be > 0.0")

-    if (
-        cfg.adapter
-        and cfg.tokens
-        and (
-            not cfg.lora_modules_to_save
-            or not all(
-                x in cfg.lora_modules_to_save for x in ["embed_tokens", "lm_head"]
-            )
-        )
-    ):
-        raise ValueError(
-            "lora_modules_to_save not properly set yet adding new tokens. Please add `embed_tokens` and `lm_head` to `lora_modules_to_save`."
-        )
-
     if cfg.max_memory is not None and cfg.gpu_memory_limit is not None:
         raise ValueError(
             "max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
diff --git a/src/axolotl/utils/lora_embeddings.py b/src/axolotl/utils/lora_embeddings.py
new file mode 100644
index 000000000..f9ea91727
--- /dev/null
+++ b/src/axolotl/utils/lora_embeddings.py
@@ -0,0 +1,12 @@
+"""
+helpers for lora embeddings
+"""
+
+
+def get_linear_embedding_layers(model_type):
+    """
+    returns the linear embedding layers needed for loras, dependent on the model arch
+    """
+    if model_type == "phi-msft":
+        return ["embd", "lm_head.linear"]
+    return ["lm_head", "embed_tokens"]
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 98d1e7c48..7d13709d7 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -2,7 +2,7 @@
 import logging
 import math
 import os
-from typing import Any, Optional, Tuple  # noqa: F401
+from typing import Any, Optional, Tuple, Union  # noqa: F401

 import addict
 import bitsandbytes as bnb
@@ -28,12 +28,16 @@
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.chat_templates import chat_templates
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.lora_embeddings import get_linear_embedding_layers

 LOG = logging.getLogger("axolotl")


-def check_model_config(cfg: DictDefault, model_config: AutoConfig):
-    quant_config_exists = hasattr(model_config, "quantization_config")
+def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
+    quant_config_exists = (
+        hasattr(model_config, "quantization_config")
+        and model_config.quantization_config
+    )
     quant_config_method_is_gptq = (
         quant_config_exists
         and "quant_method" in model_config.quantization_config
@@ -52,6 +56,20 @@ def check_model_config(cfg: DictDefault, model_config: AutoConfig):
             "Please use the `gptq` flag to train quantized model or point to a non-quantized model."
         )

+    lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
+    if (
+        cfg.adapter
+        and cfg.tokens
+        and (
+            not cfg.lora_modules_to_save
+            or not all(x in cfg.lora_modules_to_save for x in lora_modules_to_save)
+        )
+    ):
+        lora_modules_to_save = ", ".join(map(lambda x: f"`{x}`", lora_modules_to_save))
+        raise ValueError(
+            f"`lora_modules_to_save` not properly set when adding new tokens. Please include {lora_modules_to_save} in `lora_modules_to_save`."
+        )
+

 def load_model_config(cfg):
     model_config_name = cfg.base_model_config or cfg.base_model
@@ -139,6 +157,7 @@ def load_tokenizer(cfg):
                 setattr(tokenizer, attr_name, "<|endoftext|>")

     if cfg.special_tokens:
+        lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
         for k, val in cfg.special_tokens.items():
             # check if new special token is not already in tokenizer and
             # is adapter training to make sure lora_modules_to_save is set
@@ -149,11 +168,9 @@ def load_tokenizer(cfg):
                 and (
                     not cfg.lora_modules_to_save
                     or not all(
-                        x in cfg.lora_modules_to_save
-                        for x in ["embed_tokens", "lm_head"]
+                        x in cfg.lora_modules_to_save for x in lora_modules_to_save
                     )
                 )
-                and (model_config.model_type in ["llama", "mistral", "mixtral"])
             ):
                 raise ValueError(
                     "Please set lora_modules_to_save to ['embed_tokens', 'lm_head'] when using an adapter and changing the special tokens."
diff --git a/tests/test_validation.py b/tests/test_validation.py
index d2518a7df..c952b7fcf 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -10,12 +10,13 @@

 from axolotl.utils.config import validate_config
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.models import check_model_config
 from axolotl.utils.wandb_ import setup_wandb_env_vars


-class ValidationTest(unittest.TestCase):
+class BaseValidation(unittest.TestCase):
     """
-    Test the validation module
+    Base validation module to setup the log capture
     """

     _caplog: Optional[pytest.LogCaptureFixture] = None
@@ -24,6 +25,12 @@ class ValidationTest(unittest.TestCase):
     def inject_fixtures(self, caplog):
         self._caplog = caplog

+
+class ValidationTest(BaseValidation):
+    """
+    Test the validation module
+    """
+
     def test_load_4bit_deprecate(self):
         cfg = DictDefault(
             {
@@ -687,16 +694,23 @@ def test_warmup_step_no_conflict(self):

         validate_config(cfg)

-    def test_add_tokens_adapter(self):
+
+class ValidationCheckModelConfig(BaseValidation):
+    """
+    Test the validation for the config when the model config is available
+    """
+
+    def test_llama_add_tokens_adapter(self):
         cfg = DictDefault(
             {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
         )
+        model_config = DictDefault({"model_type": "llama"})

         with pytest.raises(
             ValueError,
-            match=r".*lora_modules_to_save not properly set yet adding new tokens*",
+            match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
         ):
-            validate_config(cfg)
+            check_model_config(cfg, model_config)

         cfg = DictDefault(
             {
@@ -709,9 +723,9 @@ def test_add_tokens_adapter(self):

         with pytest.raises(
             ValueError,
-            match=r".*lora_modules_to_save not properly set yet adding new tokens*",
+            match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
         ):
-            validate_config(cfg)
+            check_model_config(cfg, model_config)

         cfg = DictDefault(
             {
@@ -722,10 +736,48 @@ def test_add_tokens_adapter(self):
             }
         )

-        validate_config(cfg)
+        check_model_config(cfg, model_config)
+
+    def test_phi2_add_tokens_adapter(self):
+        cfg = DictDefault(
+            {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
+        )
+        model_config = DictDefault({"model_type": "phi-msft"})
+
+        with pytest.raises(
+            ValueError,
+            match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
+        ):
+            check_model_config(cfg, model_config)
+
+        cfg = DictDefault(
+            {
+                "adapter": "qlora",
+                "load_in_4bit": True,
+                "tokens": ["<|imstart|>"],
+                "lora_modules_to_save": ["embed_tokens", "lm_head"],
+            }
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=r".*`lora_modules_to_save` not properly set when adding new tokens*",
+        ):
+            check_model_config(cfg, model_config)
+
+        cfg = DictDefault(
+            {
+                "adapter": "qlora",
+                "load_in_4bit": True,
+                "tokens": ["<|imstart|>"],
+                "lora_modules_to_save": ["embd", "lm_head.linear"],
+            }
+        )
+
+        check_model_config(cfg, model_config)


-class ValidationWandbTest(ValidationTest):
+class ValidationWandbTest(BaseValidation):
     """
     Validation test for wandb
     """

From 2d8627159b4a2299d990c4289a4379f019c635dc Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 9 Jan 2024 21:37:31 -0500
Subject: [PATCH 2/2] update dynamic error message

---
 src/axolotl/utils/models.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 7d13709d7..0e7633a3b 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -172,8 +172,11 @@ def load_tokenizer(cfg):
                     )
                 )
             ):
+                lora_modules_to_save = ", ".join(
+                    [f"`{x}`" for x in lora_modules_to_save]
+                )
                 raise ValueError(
-                    "Please set lora_modules_to_save to ['embed_tokens', 'lm_head'] when using an adapter and changing the special tokens."
+                    f"Please set lora_modules_to_save to {lora_modules_to_save} when using an adapter and changing the special tokens."
                 )

     tokenizer.add_special_tokens(