diff --git a/requirements.txt b/requirements.txt
index cd5690f0b2..c3a9eb2f38 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -62,5 +62,5 @@
 antlr4-python3-runtime==4.13.2
 torchao==0.7.0
 schedulefree==1.3.0
-axolotl-contribs-lgpl==0.0.3
+axolotl-contribs-lgpl==0.0.6
 axolotl-contribs-mit==0.0.3
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 178b90f7b6..1ceb5babd0 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -7,7 +7,7 @@
 import sys
 import weakref
 from pathlib import Path
-from typing import Any
+from typing import Any, Dict
 
 import torch
 import transformers.modelcard
@@ -20,7 +20,7 @@
 from transformers.trainer import Trainer
 
 from axolotl.common.datasets import TrainDatasetMeta
-from axolotl.contribs.lgpl.unsloth import (  # pylint: disable = no-name-in-module
+from axolotl.contribs.lgpl import (  # pylint: disable = no-name-in-module
     fix_untrained_tokens,
 )
 from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
@@ -382,21 +382,23 @@ def handle_untrained_tokens_fix(
     if not cfg.fix_untrained_tokens:
         return
 
+    is_ds_zero3: bool = False
+    if os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3":
+        is_ds_zero3 = True
+
     # Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
     sig = inspect.signature(fix_untrained_tokens)
+    fix_kwargs: Dict[str, Any] = {}
 
     # If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
     if "token_ids_to_fix" in sig.parameters and isinstance(
         cfg.fix_untrained_tokens, list
     ):
-        fix_untrained_tokens(
-            model,
-            tokenizer,
-            train_dataset,
-            token_ids_to_fix=cfg.fix_untrained_tokens,
-        )
-    else:
-        fix_untrained_tokens(model, tokenizer, train_dataset)
+        fix_kwargs["token_ids_to_fix"] = cfg.fix_untrained_tokens
+    if "is_ds_zero3" in sig.parameters:
+        fix_kwargs["is_ds_zero3"] = is_ds_zero3
+
+    fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)
 
     if cfg.local_rank == 0:
         model.save_pretrained(
diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py
index 0f91fe056f..60b1940907 100644
--- a/tests/e2e/multigpu/test_llama.py
+++ b/tests/e2e/multigpu/test_llama.py
@@ -750,3 +750,66 @@ def test_ds_zero1_packed(self, temp_dir, gradient_accumulation_steps, qlora):
         check_tensorboard(
             temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
         )
+
+    def test_fix_untrained_tokens(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "fix_untrained_tokens": True,
+                "sequence_len": 512,
+                "val_set_size": 0.0,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                    "bos_token": "<|custom_im_start|>",
+                    "eos_token": "<|custom_im_end|>",
+                },
+                "datasets": [
+                    {
+                        "chat_template": "jinja",
+                        "chat_template_jinja": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\n' + message['content'] + '<|custom_im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\n' }}{% endif %}",
+                        "path": "mlabonne/FineTome-100k",
+                        "type": "chat_template",
+                        "split": "train[:10%]",
+                        "field_messages": "conversations",
+                        "message_field_role": "from",
+                        "message_field_content": "value",
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 5,
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch_fused",
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+                "sample_packing": True,
+                "bf16": True,
+                "save_safetensors": True,
+                "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
+                "use_tensorboard": True,
+            }
+        )
+
+        # write cfg to yaml file
+        Path(temp_dir).mkdir(parents=True, exist_ok=True)
+        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+        execute_subprocess_async(
+            [
+                "axolotl",
+                "train",
+                str(Path(temp_dir) / "config.yaml"),
+                "--num-processes",
+                "2",
+                "--main-process-port",
+                f"{get_torch_dist_unique_port()}",
+            ]
+        )
+
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss is too high"
+        )
diff --git a/tests/e2e/test_llama.py b/tests/e2e/test_llama.py
index 77e70d8c24..6447442404 100644
--- a/tests/e2e/test_llama.py
+++ b/tests/e2e/test_llama.py
@@ -66,6 +66,54 @@ def test_fft_trust_remote_code(self, temp_dir):
         check_model_output_exists(temp_dir, cfg)
 
     def test_fix_untrained_tokens(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "fix_untrained_tokens": True,
+                "sequence_len": 512,
+                "val_set_size": 0.0,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                    "bos_token": "<|custom_im_start|>",
+                    "eos_token": "<|custom_im_end|>",
+                },
+                "datasets": [
+                    {
+                        "chat_template": "jinja",
+                        "chat_template_jinja": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|custom_im_start|>' + message['role'] + '\n' + message['content'] + '<|custom_im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|custom_im_start|>assistant\n' }}{% endif %}",
+                        "path": "mlabonne/FineTome-100k",
+                        "type": "chat_template",
+                        "split": "train[:10%]",
+                        "field_messages": "conversations",
+                        "message_field_role": "from",
+                        "message_field_content": "value",
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 5,
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_8bit",
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+                "sample_packing": True,
+                "bf16": True,
+                "save_safetensors": True,
+            }
+        )
+
+        cfg = validate_config(cfg)
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, dataset_meta=dataset_meta)
+        check_model_output_exists(temp_dir, cfg)
+
+    def test_fix_untrained_tokens_already_trained(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {