From 6a237208115c8c1dc2ffe6a46c9873bb770269fb Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 12 Jan 2024 18:50:36 -0500 Subject: [PATCH 01/23] cleanup dpo to be a little more extensible, add zephyr/nectar strategy --- src/axolotl/cli/__init__.py | 79 +------------------ src/axolotl/prompt_strategies/dpo/__init__.py | 17 ++++ src/axolotl/prompt_strategies/dpo/chatml.py | 77 ++++++++++++++++++ src/axolotl/prompt_strategies/dpo/zephyr.py | 21 +++++ src/axolotl/utils/data.py | 25 +++++- 5 files changed, 141 insertions(+), 78 deletions(-) create mode 100644 src/axolotl/prompt_strategies/dpo/__init__.py create mode 100644 src/axolotl/prompt_strategies/dpo/chatml.py create mode 100644 src/axolotl/prompt_strategies/dpo/zephyr.py diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 15da78b09..9aab7b39f 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -17,7 +17,6 @@ # add src to the pythonpath so we don't need to pip install this from accelerate.commands.config import config_args from art import text2art -from datasets import concatenate_datasets, load_dataset from huggingface_hub import HfApi from huggingface_hub.utils import LocalTokenNotFoundError from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer @@ -30,7 +29,7 @@ normalize_config, validate_config, ) -from axolotl.utils.data import prepare_dataset +from axolotl.utils.data import load_prepare_dpo_datasets, prepare_dataset from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_main_process from axolotl.utils.mlflow_ import setup_mlflow_env_vars @@ -343,81 +342,7 @@ def load_rl_datasets( cfg: DictDefault, cli_args: TrainerCliArgs, # pylint: disable=unused-argument ) -> TrainDatasetMeta: - train_datasets: List[Any] = [] - for i, ds_cfg in enumerate(cfg.datasets): - train_datasets.insert(i, load_dataset(ds_cfg["path"], split=ds_cfg["split"])) - # eval_dataset = load_dataset( - # cfg.test_datasets[0]["path"], split=cfg.test_datasets[0]["split"] - # ) - eval_dataset = None - - def argilla_apply_chatml(sample): # pylint: disable=possibly-unused-variable - if "system" in sample and sample["system"]: - sample["prompt"] = ( - f"<|im_start|>system\n{sample['system']}<|im_end|>\n" - f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n" - ) - else: - sample[ - "prompt" - ] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n" - sample["chosen"] = f"{sample['chosen_response']}<|im_end|>" - sample["rejected"] = f"{sample['rejected_response']}<|im_end|>" - return sample - - def intel_apply_chatml(sample): # pylint: disable=possibly-unused-variable - if "system" in sample and sample["system"]: - sample["prompt"] = ( - f"<|im_start|>system\n{sample['system']}<|im_end|>\n" - f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" - ) - else: - sample[ - "prompt" - ] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" - sample["chosen"] = f"{sample['chosen']}<|im_end|>" - sample["rejected"] = f"{sample['rejected']}<|im_end|>" - return sample - - def apply_chatml(sample): # pylint: disable=possibly-unused-variable - if "system" in sample and sample["system"]: - sample["prompt"] = ( - f"<|im_start|>system\n{sample['system']}<|im_end|>\n" - f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" - ) - else: - sample[ - "prompt" - ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" - sample["chosen"] = 
f"{sample['chosen']}<|im_end|>" - sample["rejected"] = f"{sample['rejected']}<|im_end|>" - return sample - - def ultra_apply_chatml(sample): # pylint: disable=possibly-unused-variable - if "system" in sample and sample["system"]: - sample["prompt"] = ( - f"<|im_start|>system\n{sample['system']}<|im_end|>\n" - f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" - ) - else: - sample[ - "prompt" - ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" - sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>" - sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>" - return sample - - for i, data_set in enumerate(train_datasets): - _type = cfg.datasets[i]["type"] - ds_type_fn = locals()[_type] - train_datasets[i] = data_set.map( - ds_type_fn, - desc="Mapping RL Dataset", - ) - train_dataset = concatenate_datasets(train_datasets) - - # eval_dataset = eval_dataset.map(intel_apply_chatml) - + train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg) total_num_steps = int( math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) ) diff --git a/src/axolotl/prompt_strategies/dpo/__init__.py b/src/axolotl/prompt_strategies/dpo/__init__.py new file mode 100644 index 000000000..51be84866 --- /dev/null +++ b/src/axolotl/prompt_strategies/dpo/__init__.py @@ -0,0 +1,17 @@ +""" +module for DPO style dataset transform strategies +""" + +import importlib + + +def load(strategy, cfg): + try: + load_fn = strategy.split(".")[-1] + strategy = ".".join(strategy.split(".")[:-1]) + mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo") + func = getattr(mod, load_fn) + load_kwargs = {} + return func(cfg, **load_kwargs) + except Exception: # pylint: disable=broad-exception-caught + return None diff --git a/src/axolotl/prompt_strategies/dpo/chatml.py b/src/axolotl/prompt_strategies/dpo/chatml.py new file mode 100644 index 000000000..12e869cf7 --- /dev/null +++ b/src/axolotl/prompt_strategies/dpo/chatml.py @@ -0,0 +1,77 @@ +""" +DPO strategies for chatml +""" + + +def argilla_apply_chatml( + cfg, +): # pylint: disable=possibly-unused-variable,unused-argument + def transform_fn(sample): + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|im_start|>system\n{sample['system']}<|im_end|>\n" + f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n" + ) + else: + sample[ + "prompt" + ] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n" + sample["chosen"] = f"{sample['chosen_response']}<|im_end|>" + sample["rejected"] = f"{sample['rejected_response']}<|im_end|>" + return sample + + return transform_fn + + +def intel_apply_chatml(cfg): # pylint: disable=possibly-unused-variable,unused-argument + def transform_fn(sample): + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|im_start|>system\n{sample['system']}<|im_end|>\n" + f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" + ) + else: + sample[ + "prompt" + ] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" + sample["chosen"] = f"{sample['chosen']}<|im_end|>" + sample["rejected"] = f"{sample['rejected']}<|im_end|>" + return sample + + return transform_fn + + +def apply_chatml(cfg): # pylint: disable=possibly-unused-variable,unused-argument + def transform_fn(sample): + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|im_start|>system\n{sample['system']}<|im_end|>\n" + 
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + ) + else: + sample[ + "prompt" + ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + sample["chosen"] = f"{sample['chosen']}<|im_end|>" + sample["rejected"] = f"{sample['rejected']}<|im_end|>" + return sample + + return transform_fn + + +def ultra_apply_chatml(cfg): # pylint: disable=possibly-unused-variable,unused-argument + def transform_fn(sample): + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|im_start|>system\n{sample['system']}<|im_end|>\n" + f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + ) + else: + sample[ + "prompt" + ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>" + sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>" + return sample + + return transform_fn diff --git a/src/axolotl/prompt_strategies/dpo/zephyr.py b/src/axolotl/prompt_strategies/dpo/zephyr.py new file mode 100644 index 000000000..71a920beb --- /dev/null +++ b/src/axolotl/prompt_strategies/dpo/zephyr.py @@ -0,0 +1,21 @@ +""" +DPO strategies for zephyr +""" + + +def nectar(cfg): # pylint: disable=possibly-unused-variable,unused-argument + def transform_fn(sample): + data = {} + data["prompt"] = ( + "<|system|>\n\n" + "<|user|>\n" + f"{sample['prompt']}<\\s>\n" + "<|assistant|>\n" + ) + answers = sorted(sample["answers"], key=lambda x: x["rank"]) + data["chosen"] = answers[-1]["answer"] + data["rejected"] = answers[-2]["answer"] + + return data + + return transform_fn diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 5c4cd148b..422a1adeb 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -4,7 +4,7 @@ import logging from collections import defaultdict from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from datasets import ( @@ -21,6 +21,7 @@ from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.datasets import TokenizedPromptDataset from axolotl.prompt_strategies import load +from axolotl.prompt_strategies.dpo import load as load_dpo from axolotl.prompt_tokenizers import ( AlpacaMultipleChoicePromptTokenizingStrategy, AlpacaPromptTokenizingStrategy, @@ -850,3 +851,25 @@ def encode_packed_pretraining( chunked_data[feature].append(collated_features[feature].squeeze(0)) return chunked_data + + +def load_prepare_dpo_datasets(cfg): + train_datasets: List[Any] = [] + eval_dataset = None + + for i, ds_cfg in enumerate(cfg.datasets): + ds = load_dataset( # pylint: disable=invalid-name + ds_cfg["path"], split=ds_cfg["split"] + ) + train_datasets.insert(i, ds) + + for i, data_set in enumerate(train_datasets): + _type = cfg.datasets[i]["type"] + ds_transform_fn = load_dpo(_type, cfg) + train_datasets[i] = data_set.map( + ds_transform_fn, + desc="Mapping RL Dataset", + ) + train_dataset = concatenate_datasets(train_datasets) + + return train_dataset, eval_dataset From 2f7242d1f5e3589bca6f051aa86f6e23503daa9e Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 12 Jan 2024 21:25:43 -0500 Subject: [PATCH 02/23] fix eos slash --- src/axolotl/prompt_strategies/dpo/zephyr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/prompt_strategies/dpo/zephyr.py b/src/axolotl/prompt_strategies/dpo/zephyr.py index 71a920beb..02bce8a33 100644 --- 
a/src/axolotl/prompt_strategies/dpo/zephyr.py +++ b/src/axolotl/prompt_strategies/dpo/zephyr.py @@ -9,7 +9,7 @@ def transform_fn(sample): data["prompt"] = ( "<|system|>\n\n" "<|user|>\n" - f"{sample['prompt']}<\\s>\n" + f"{sample['prompt']}\n" "<|assistant|>\n" ) answers = sorted(sample["answers"], key=lambda x: x["rank"]) From 74cbe2a1232a0101e1b40ac1309b3e23aa94bba7 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 13 Jan 2024 06:32:41 -0500 Subject: [PATCH 03/23] support for eval split --- src/axolotl/core/trainer_builder.py | 11 +++++---- src/axolotl/utils/data.py | 37 +++++++++++++++++------------ 2 files changed, 29 insertions(+), 19 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index c3b01e6c6..fab2cf353 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -949,14 +949,17 @@ def build_training_arguments(self, total_num_steps): ]: if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None: training_args_kwargs[arg] = getattr(self.cfg, arg) + if self.eval_dataset: + training_args_kwargs["evaluation_strategy"] = "steps" + training_args_kwargs["eval_steps"] = self.cfg.eval_steps + else: + training_args_kwargs["evaluation_strategy"] = "no" training_args = TrainingArguments( per_device_train_batch_size=self.cfg.micro_batch_size, max_steps=total_num_steps, remove_unused_columns=False, gradient_accumulation_steps=self.cfg.gradient_accumulation_steps, learning_rate=self.cfg.learning_rate, - evaluation_strategy="no", - # eval_steps=self.cfg.eval_steps, save_strategy="steps", save_steps=self.cfg.save_steps, output_dir=self.cfg.output_dir, @@ -982,14 +985,14 @@ def build(self, total_num_steps): dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing elif self.cfg.rl == "kto_pair": dpo_trainer_kwargs["loss_type"] = "kto_pair" - + if self.eval_dataset: + dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset dpo_trainer = DPOTrainer( self.model, self.model_ref, args=training_args, beta=self.cfg.dpo_beta or 0.1, train_dataset=self.train_dataset, - # eval_dataset=self.eval_dataset, eval_dataset=None, tokenizer=self.tokenizer, max_length=self.cfg.sequence_len, diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 422a1adeb..5d154bb10 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -854,22 +854,29 @@ def encode_packed_pretraining( def load_prepare_dpo_datasets(cfg): - train_datasets: List[Any] = [] - eval_dataset = None + def load_split(dataset_cfgs, _cfg): + split_datasets: List[Any] = [] + for i, ds_cfg in enumerate(dataset_cfgs): + ds = load_dataset( # pylint: disable=invalid-name + ds_cfg["path"], + split=ds_cfg["split"], + desc="Mapping RL Dataset", + ) + split_datasets.insert(i, ds) + + for i, data_set in enumerate(split_datasets): + _type = dataset_cfgs[i]["type"] + ds_transform_fn = load_dpo(_type, _cfg) + split_datasets[i] = data_set.map( + ds_transform_fn, + desc="Mapping RL Dataset", + ) - for i, ds_cfg in enumerate(cfg.datasets): - ds = load_dataset( # pylint: disable=invalid-name - ds_cfg["path"], split=ds_cfg["split"] - ) - train_datasets.insert(i, ds) + return concatenate_datasets(split_datasets) - for i, data_set in enumerate(train_datasets): - _type = cfg.datasets[i]["type"] - ds_transform_fn = load_dpo(_type, cfg) - train_datasets[i] = data_set.map( - ds_transform_fn, - desc="Mapping RL Dataset", - ) - train_dataset = concatenate_datasets(train_datasets) + train_dataset = load_split(cfg.datasets, cfg) + eval_dataset = 
load_split(cfg.test_datasets, cfg) + if not eval_dataset: + eval_dataset = None return train_dataset, eval_dataset From 1f2832732f84db0461f7acb56dc2fa6e7aa4dd37 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 13 Jan 2024 06:36:14 -0500 Subject: [PATCH 04/23] fix kwargs --- src/axolotl/core/trainer_builder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index fab2cf353..f17b85837 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -993,7 +993,6 @@ def build(self, total_num_steps): args=training_args, beta=self.cfg.dpo_beta or 0.1, train_dataset=self.train_dataset, - eval_dataset=None, tokenizer=self.tokenizer, max_length=self.cfg.sequence_len, max_target_length=None, From cb2f774ea1017f7a0052a9a13fd847f3ff9ed489 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 13 Jan 2024 08:28:54 -0500 Subject: [PATCH 05/23] handle empty evals --- src/axolotl/utils/data.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 5d154bb10..b89edaa40 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -875,7 +875,10 @@ def load_split(dataset_cfgs, _cfg): return concatenate_datasets(split_datasets) train_dataset = load_split(cfg.datasets, cfg) - eval_dataset = load_split(cfg.test_datasets, cfg) + + eval_dataset = None + if cfg.test_datasets: + eval_dataset = load_split(cfg.test_datasets, cfg) if not eval_dataset: eval_dataset = None From a91e0cb318d8ea3b1d690633e45010a44c430e8f Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 13 Jan 2024 11:20:10 -0500 Subject: [PATCH 06/23] don't load peft model for dpo --- src/axolotl/core/trainer_builder.py | 11 +++++++++++ src/axolotl/train.py | 7 ++++++- src/axolotl/utils/models.py | 14 +++++++++++--- src/axolotl/utils/trainer.py | 3 ++- 4 files changed, 30 insertions(+), 5 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index f17b85837..76a6b36d1 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -460,6 +460,7 @@ class TrainerBuilderBase(abc.ABC): _train_dataset = None _eval_dataset = None _model_ref = None + _peft_config = None def __init__(self, cfg, model, tokenizer): self.cfg = cfg @@ -490,6 +491,14 @@ def eval_dataset(self): def eval_dataset(self, dataset): self._eval_dataset = dataset + @property + def peft_config(self): + return self._peft_config + + @peft_config.setter + def peft_config(self, peft_config): + self._peft_config = peft_config + @abstractmethod def build(self, total_num_steps): pass @@ -987,6 +996,8 @@ def build(self, total_num_steps): dpo_trainer_kwargs["loss_type"] = "kto_pair" if self.eval_dataset: dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset + if self.cfg.adapter and self.peft_config: + dpo_trainer_kwargs["peft_config"] = self.peft_config dpo_trainer = DPOTrainer( self.model, self.model_ref, diff --git a/src/axolotl/train.py b/src/axolotl/train.py index d68ae46b1..79b880234 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -96,7 +96,12 @@ def train( freeze_parameters_except(model, cfg.unfrozen_parameters) trainer = setup_trainer( - cfg, train_dataset, eval_dataset, (model, model_ref), tokenizer, total_num_steps + cfg, + train_dataset, + eval_dataset, + (model, model_ref, peft_config), + tokenizer, + total_num_steps, ) if hasattr(model, "config"): diff --git a/src/axolotl/utils/models.py 
b/src/axolotl/utils/models.py index d75926952..c8da7b514 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -682,7 +682,12 @@ def load_model( lora_config = None if not reference_model or cfg.lora_model_dir: - model, lora_config = load_adapter(model, cfg, cfg.adapter) + # if we're not loading the reference model, then we're loading the model for training + # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config + if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto_pair"]: + _, lora_config = load_lora(model, cfg, inference=False, config_only=True) + else: + model, lora_config = load_adapter(model, cfg, cfg.adapter) if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit): model.to(f"cuda:{cfg.local_rank}") @@ -770,8 +775,8 @@ def find_all_linear_names(model): return list(lora_module_names) -def load_lora(model, cfg, inference=False): - # type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]] +def load_lora(model, cfg, inference=False, config_only=False): + # type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]] from peft import LoraConfig, PeftModel, get_peft_model @@ -794,6 +799,9 @@ def load_lora(model, cfg, inference=False): task_type="CAUSAL_LM", ) + if config_only: + return None, lora_config + if cfg.lora_model_dir: LOG.debug("Loading pretained PEFT - LoRA") model_kwargs: Any = {} diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index dfd3385b7..2e9d782c7 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -316,9 +316,10 @@ def prepare_optim_env(cfg): def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): - if cfg.rl: + if cfg.rl in ["dpo", "ipo", "kto_pair"]: trainer_builder = HFDPOTrainerBuilder(cfg, model[0], tokenizer) trainer_builder.model_ref = model[1] + trainer_builder.peft_config = model[2] else: trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer) From bcae290b2d2a70ffa4c9ffdd29f1d29e893ca555 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 13 Jan 2024 11:29:47 -0500 Subject: [PATCH 07/23] ensure dpo traning args gets bf16 for peft if applicable --- src/axolotl/core/trainer_builder.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 76a6b36d1..6bf91e544 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -963,6 +963,8 @@ def build_training_arguments(self, total_num_steps): training_args_kwargs["eval_steps"] = self.cfg.eval_steps else: training_args_kwargs["evaluation_strategy"] = "no" + if self.cfg.bf16 or self.cfg.bfloat16: + training_args_kwargs["bf16"] = True training_args = TrainingArguments( per_device_train_batch_size=self.cfg.micro_batch_size, max_steps=total_num_steps, From a24c756f6e5c610fea9a486c9cb40c2ac7680172 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 13 Jan 2024 11:33:07 -0500 Subject: [PATCH 08/23] fix duplicate kwargs for bf16 --- src/axolotl/core/trainer_builder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 6bf91e544..9d26bcb91 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -975,7 +975,6 @@ def build_training_arguments(self, total_num_steps): save_steps=self.cfg.save_steps, output_dir=self.cfg.output_dir, 
warmup_steps=self.cfg.warmup_steps, - bf16=True, gradient_checkpointing=self.cfg.gradient_checkpointing, gradient_checkpointing_kwargs={"use_reentrant": False}, logging_first_step=True, From d5e12ddb8a76babefb5046a8bfa7011e87427d92 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 15 Jan 2024 16:25:48 -0500 Subject: [PATCH 09/23] make sure to respect the configured lr scheduler --- src/axolotl/core/trainer_builder.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 9d26bcb91..4e824bbf4 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -958,6 +958,7 @@ def build_training_arguments(self, total_num_steps): ]: if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None: training_args_kwargs[arg] = getattr(self.cfg, arg) + if self.eval_dataset: training_args_kwargs["evaluation_strategy"] = "steps" training_args_kwargs["eval_steps"] = self.cfg.eval_steps @@ -965,6 +966,14 @@ def build_training_arguments(self, total_num_steps): training_args_kwargs["evaluation_strategy"] = "no" if self.cfg.bf16 or self.cfg.bfloat16: training_args_kwargs["bf16"] = True + + training_args_kwargs["lr_scheduler_type"] = ( + self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine" + ) + training_args_kwargs["lr_scheduler_kwargs"] = ( + self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {} + ) + training_args = TrainingArguments( per_device_train_batch_size=self.cfg.micro_batch_size, max_steps=total_num_steps, From 9a746fa698a5ed819cb630ac2c4e564d41232ef1 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 16 Jan 2024 07:01:00 -0500 Subject: [PATCH 10/23] supprt trainer callback to push config to wandb --- src/axolotl/core/trainer_builder.py | 44 +++++++++++++++++++---------- 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 4e824bbf4..269a760dc 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -12,14 +12,19 @@ from dataclasses import dataclass, field from functools import wraps from pathlib import Path -from typing import Optional, Type, Union +from typing import List, Optional, Type, Union import torch import transformers from datasets import Dataset from torch.optim.lr_scheduler import OneCycleLR from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler -from transformers import EarlyStoppingCallback, Trainer, TrainingArguments +from transformers import ( + EarlyStoppingCallback, + Trainer, + TrainerCallback, + TrainingArguments, +) from transformers.trainer_utils import seed_worker from trl import DPOTrainer @@ -503,9 +508,14 @@ def peft_config(self, peft_config): def build(self, total_num_steps): pass - @abstractmethod - def get_callbacks(self): - pass + def get_callbacks(self) -> List[TrainerCallback]: + callbacks = [] + if self.cfg.use_wandb: + callbacks.append( + SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path) + ) + + return callbacks @abstractmethod def get_post_trainer_create_callbacks(self, trainer): @@ -513,12 +523,6 @@ def get_post_trainer_create_callbacks(self, trainer): Callbacks added after the trainer is created, usually b/c these need access to the trainer """ - -class HFCausalTrainerBuilder(TrainerBuilderBase): - """ - Build the HuggingFace training args/trainer for Causal models - """ - def hook_pre_create_training_args(self, training_arguments_kwargs): # TODO 
return training_arguments_kwargs @@ -535,10 +539,16 @@ def hook_post_create_trainer(self, trainer): # TODO return trainer + +class HFCausalTrainerBuilder(TrainerBuilderBase): + """ + Build the HuggingFace training args/trainer for Causal models + """ + def get_callbacks(self): - callbacks = [] + callbacks = super().get_callbacks() callbacks.append(GPUStatsCallback(self.cfg)) - callbacks.append(EvalFirstStepCallback) + callbacks.append(EvalFirstStepCallback()) if self.cfg.relora_steps: callbacks.append(ReLoRACallback(self.cfg)) @@ -547,7 +557,7 @@ def get_callbacks(self): hasattr(self.model, "use_bettertransformer") and self.model.use_bettertransformer is True ): - callbacks.append(SaveBetterTransformerModelCallback) + callbacks.append(SaveBetterTransformerModelCallback()) if self.cfg.use_wandb: callbacks.append( @@ -940,7 +950,7 @@ class HFDPOTrainerBuilder(TrainerBuilderBase): """ def get_callbacks(self): - callbacks = [] + callbacks = super().get_callbacks() return callbacks def get_post_trainer_create_callbacks(self, trainer): @@ -1019,8 +1029,12 @@ def build(self, total_num_steps): max_target_length=None, max_prompt_length=self.cfg.sequence_len, generate_during_eval=True, + callbacks=self.get_callbacks(), **dpo_trainer_kwargs, ) + dpo_trainer = self.hook_post_create_trainer(dpo_trainer) + for callback in self.get_post_trainer_create_callbacks(dpo_trainer): + dpo_trainer.add_callback(callback) return dpo_trainer From 60f566c2a70b9f91bc36fb4c43014d54de0287eb Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 16 Jan 2024 07:29:19 -0500 Subject: [PATCH 11/23] set dataloader preload args --- src/axolotl/core/trainer_builder.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 269a760dc..bad3f96a0 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -969,6 +969,18 @@ def build_training_arguments(self, total_num_steps): if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None: training_args_kwargs[arg] = getattr(self.cfg, arg) + if self.cfg.hub_model_id: + training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id + training_args_kwargs["push_to_hub"] = True + training_args_kwargs["hub_private_repo"] = True + training_args_kwargs["hub_always_push"] = True + + if self.cfg.hub_strategy: + training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy + + if self.cfg.save_safetensors is not None: + training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors + if self.eval_dataset: training_args_kwargs["evaluation_strategy"] = "steps" training_args_kwargs["eval_steps"] = self.cfg.eval_steps @@ -984,6 +996,19 @@ def build_training_arguments(self, total_num_steps): self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {} ) + if self.cfg.dataloader_pin_memory is not None: + training_args_kwargs[ + "dataloader_pin_memory" + ] = self.cfg.dataloader_pin_memory + if self.cfg.dataloader_num_workers is not None: + training_args_kwargs[ + "dataloader_num_workers" + ] = self.cfg.dataloader_num_workers + if self.cfg.dataloader_prefetch_factor is not None: + training_args_kwargs[ + "dataloader_prefetch_factor" + ] = self.cfg.dataloader_prefetch_factor + training_args = TrainingArguments( per_device_train_batch_size=self.cfg.micro_batch_size, max_steps=total_num_steps, From c41391d87af6a761622b512d9aeec25f26f097fc Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 17 Jan 2024 13:08:05 -0500 Subject: [PATCH 12/23] ensure that we are 
loading the lora when merging --- src/axolotl/utils/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index c8da7b514..6ba1e3704 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -684,7 +684,7 @@ def load_model( if not reference_model or cfg.lora_model_dir: # if we're not loading the reference model, then we're loading the model for training # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config - if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto_pair"]: + if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto_pair"] and not cfg.merge_lora: _, lora_config = load_lora(model, cfg, inference=False, config_only=True) else: model, lora_config = load_adapter(model, cfg, cfg.adapter) From 1cfd1796615dd2ffda20b938ad8e08f77d873ff5 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 18 Jan 2024 17:36:02 -0500 Subject: [PATCH 13/23] Update src/axolotl/utils/data.py Co-authored-by: Agus --- src/axolotl/utils/data.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index b89edaa40..bb4f52b38 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -866,11 +866,16 @@ def load_split(dataset_cfgs, _cfg): for i, data_set in enumerate(split_datasets): _type = dataset_cfgs[i]["type"] - ds_transform_fn = load_dpo(_type, _cfg) - split_datasets[i] = data_set.map( - ds_transform_fn, - desc="Mapping RL Dataset", - ) + if _type: + ds_transform_fn = load_dpo(_type, _cfg) + split_datasets[i] = data_set.map( + ds_transform_fn, + desc="Mapping RL Dataset", + ) + else: + # If no `type` is provided, assume the dataset is already in the expected format with + # "prompt", "chosen" and "rejected" already preprocessed + split_datasets[i] = data_set return concatenate_datasets(split_datasets) From 02fc8f925d47ddb7db6543254c8556fa9fce343b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 22 Jan 2024 20:00:52 -0500 Subject: [PATCH 14/23] support local datasets for dpo Co-authored-by: Agus --- src/axolotl/utils/data.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index bb4f52b38..264883dd9 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -857,12 +857,23 @@ def load_prepare_dpo_datasets(cfg): def load_split(dataset_cfgs, _cfg): split_datasets: List[Any] = [] for i, ds_cfg in enumerate(dataset_cfgs): - ds = load_dataset( # pylint: disable=invalid-name - ds_cfg["path"], - split=ds_cfg["split"], - desc="Mapping RL Dataset", - ) - split_datasets.insert(i, ds) + if ds_cfg["ds_type"] == "json": + for data_file in ds_cfg["data_files"]: + data_files = {ds_cfg["split"]: data_file} + ds = load_dataset( + "json", + data_files=data_files, + split=ds_cfg["split"], + desc="Mapping RL Dataset", + ) + split_datasets.insert(i, ds) + else: + ds = load_dataset( # pylint: disable=invalid-name + ds_cfg["path"], + split=ds_cfg["split"], + desc="Mapping RL Dataset", + ) + split_datasets.insert(i, ds) for i, data_set in enumerate(split_datasets): _type = dataset_cfgs[i]["type"] From 7141fd12bb20e9b9ce8d098470d6d6a8ae418107 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 22 Jan 2024 20:12:41 -0500 Subject: [PATCH 15/23] chore: lint --- src/axolotl/utils/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/utils/data.py 
b/src/axolotl/utils/data.py index 264883dd9..f3919fbb6 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -860,7 +860,7 @@ def load_split(dataset_cfgs, _cfg): if ds_cfg["ds_type"] == "json": for data_file in ds_cfg["data_files"]: data_files = {ds_cfg["split"]: data_file} - ds = load_dataset( + ds = load_dataset( # pylint: disable=invalid-name "json", data_files=data_files, split=ds_cfg["split"], From 44a6f2db21042986d6bceaaaa125db412f5ed4da Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 22 Jan 2024 21:19:43 -0500 Subject: [PATCH 16/23] dpo/kto/ipo smoke tests w lora, simplify dpo dataset type names --- src/axolotl/core/trainer_builder.py | 2 +- src/axolotl/prompt_strategies/dpo/chatml.py | 16 ++- tests/e2e/test_dpo.py | 145 ++++++++++++++++++++ 3 files changed, 158 insertions(+), 5 deletions(-) create mode 100644 tests/e2e/test_dpo.py diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index bad3f96a0..75da7905f 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1011,7 +1011,7 @@ def build_training_arguments(self, total_num_steps): training_args = TrainingArguments( per_device_train_batch_size=self.cfg.micro_batch_size, - max_steps=total_num_steps, + max_steps=self.cfg.max_steps or total_num_steps, remove_unused_columns=False, gradient_accumulation_steps=self.cfg.gradient_accumulation_steps, learning_rate=self.cfg.learning_rate, diff --git a/src/axolotl/prompt_strategies/dpo/chatml.py b/src/axolotl/prompt_strategies/dpo/chatml.py index 12e869cf7..e0840f762 100644 --- a/src/axolotl/prompt_strategies/dpo/chatml.py +++ b/src/axolotl/prompt_strategies/dpo/chatml.py @@ -3,7 +3,7 @@ """ -def argilla_apply_chatml( +def argilla( cfg, ): # pylint: disable=possibly-unused-variable,unused-argument def transform_fn(sample): @@ -23,7 +23,11 @@ def transform_fn(sample): return transform_fn -def intel_apply_chatml(cfg): # pylint: disable=possibly-unused-variable,unused-argument +def intel(cfg): # pylint: disable=possibly-unused-variable,unused-argument + """ + For Intel Orca DPO Pairs + """ + def transform_fn(sample): if "system" in sample and sample["system"]: sample["prompt"] = ( @@ -41,7 +45,7 @@ def transform_fn(sample): return transform_fn -def apply_chatml(cfg): # pylint: disable=possibly-unused-variable,unused-argument +def prompt_pairs(cfg): # pylint: disable=possibly-unused-variable,unused-argument def transform_fn(sample): if "system" in sample and sample["system"]: sample["prompt"] = ( @@ -59,7 +63,11 @@ def transform_fn(sample): return transform_fn -def ultra_apply_chatml(cfg): # pylint: disable=possibly-unused-variable,unused-argument +def ultra(cfg): # pylint: disable=possibly-unused-variable,unused-argument + """ + for ultrafeedback binarized conversations + """ + def transform_fn(sample): if "system" in sample and sample["system"]: sample["prompt"] = ( diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py new file mode 100644 index 000000000..88dd1540a --- /dev/null +++ b/tests/e2e/test_dpo.py @@ -0,0 +1,145 @@ +""" +E2E tests for lora llama +""" + +import logging +import os +import unittest +from pathlib import Path + +from axolotl.cli import load_rl_datasets +from axolotl.common.cli import TrainerCliArgs +from axolotl.train import train +from axolotl.utils.config import normalize_config +from axolotl.utils.dict import DictDefault + +from .utils import with_temp_dir + +LOG = logging.getLogger("axolotl.tests.e2e") +os.environ["WANDB_DISABLED"] = "true" + + +class 
TestDPOLlamaLora(unittest.TestCase): + """ + Test case for DPO Llama models using LoRA + """ + + @with_temp_dir + def test_dpo_lora(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "tokenizer_type": "LlamaTokenizer", + "sequence_len": 1024, + "load_in_8bit": True, + "adapter": "lora", + "lora_r": 64, + "lora_alpha": 32, + "lora_dropout": 0.1, + "lora_target_linear": True, + "special_tokens": {}, + "rl": "dpo", + "datasets": [ + { + "path": "Intel/orca_dpo_pairs", + "type": "chatml.intel", + }, + ], + "num_epochs": 1, + "micro_batch_size": 4, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "paged_adamw_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "adapter_model.bin").exists() + + @with_temp_dir + def test_kto_pair_lora(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "tokenizer_type": "LlamaTokenizer", + "sequence_len": 1024, + "load_in_8bit": True, + "adapter": "lora", + "lora_r": 64, + "lora_alpha": 32, + "lora_dropout": 0.1, + "lora_target_linear": True, + "special_tokens": {}, + "rl": "kto_pair", + "datasets": [ + { + "path": "Intel/orca_dpo_pairs", + "type": "chatml.intel", + }, + ], + "num_epochs": 1, + "micro_batch_size": 4, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "paged_adamw_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "adapter_model.bin").exists() + + @with_temp_dir + def test_ipo_lora(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "tokenizer_type": "LlamaTokenizer", + "sequence_len": 1024, + "load_in_8bit": True, + "adapter": "lora", + "lora_r": 64, + "lora_alpha": 32, + "lora_dropout": 0.1, + "lora_target_linear": True, + "special_tokens": {}, + "rl": "ipo", + "datasets": [ + { + "path": "Intel/orca_dpo_pairs", + "type": "chatml.intel", + }, + ], + "num_epochs": 1, + "micro_batch_size": 4, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "paged_adamw_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "adapter_model.bin").exists() From 7f3b7cef4c05bf7f7f6e28a2ecd871dc44fae1f0 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 22 Jan 2024 21:39:39 -0500 Subject: [PATCH 17/23] add split to dpo tests --- tests/e2e/test_dpo.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py index 88dd1540a..0ab042eea 100644 --- a/tests/e2e/test_dpo.py +++ b/tests/e2e/test_dpo.py @@ -44,6 +44,7 @@ def test_dpo_lora(self, temp_dir): { "path": "Intel/orca_dpo_pairs", "type": "chatml.intel", + "split": "train", }, ], "num_epochs": 1, @@ -84,6 +85,7 @@ def 
test_kto_pair_lora(self, temp_dir): { "path": "Intel/orca_dpo_pairs", "type": "chatml.intel", + "split": "train", }, ], "num_epochs": 1, @@ -124,6 +126,7 @@ def test_ipo_lora(self, temp_dir): { "path": "Intel/orca_dpo_pairs", "type": "chatml.intel", + "split": "train", }, ], "num_epochs": 1, From 52a227dcb46f02a9f7e52d5b0cd4627fd771b3e9 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 22 Jan 2024 22:31:03 -0500 Subject: [PATCH 18/23] fix rebase/merging error --- src/axolotl/utils/data.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index f3919fbb6..e5ea3549f 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -864,14 +864,12 @@ def load_split(dataset_cfgs, _cfg): "json", data_files=data_files, split=ds_cfg["split"], - desc="Mapping RL Dataset", ) split_datasets.insert(i, ds) else: ds = load_dataset( # pylint: disable=invalid-name ds_cfg["path"], split=ds_cfg["split"], - desc="Mapping RL Dataset", ) split_datasets.insert(i, ds) From 064b20e38ad1340a212d92e0b4526bc0f4c1743d Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 22 Jan 2024 22:49:01 -0500 Subject: [PATCH 19/23] handle edge case w logging --- src/axolotl/train.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 79b880234..5fb873edd 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -47,10 +47,14 @@ def train( *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta ) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]: # load the tokenizer first - LOG.debug( - f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}", - main_process_only=True, - ) + try: + LOG.debug( + f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}", + main_process_only=True, + ) + except RuntimeError: + # sometimes Accelerator() needs to be called un-necessarily before using logging + pass tokenizer = load_tokenizer(cfg) train_dataset = dataset_meta.train_dataset From c49315f9ef7288bd2faeb6e1d33ba17b2ac82560 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 22 Jan 2024 23:04:55 -0500 Subject: [PATCH 20/23] use accelerator for dpo datasets so it doesn't break the logger --- src/axolotl/train.py | 12 ++++-------- src/axolotl/utils/data.py | 11 ++++++----- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 5fb873edd..79b880234 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -47,14 +47,10 @@ def train( *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta ) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]: # load the tokenizer first - try: - LOG.debug( - f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}", - main_process_only=True, - ) - except RuntimeError: - # sometimes Accelerator() needs to be called un-necessarily before using logging - pass + LOG.debug( + f"loading tokenizer... 
{cfg.tokenizer_config or cfg.base_model_config}", + main_process_only=True, + ) tokenizer = load_tokenizer(cfg) train_dataset = dataset_meta.train_dataset diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index e5ea3549f..fb2eb9bc4 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -888,12 +888,13 @@ def load_split(dataset_cfgs, _cfg): return concatenate_datasets(split_datasets) - train_dataset = load_split(cfg.datasets, cfg) + with zero_first(is_main_process()): + train_dataset = load_split(cfg.datasets, cfg) - eval_dataset = None - if cfg.test_datasets: - eval_dataset = load_split(cfg.test_datasets, cfg) - if not eval_dataset: eval_dataset = None + if cfg.test_datasets: + eval_dataset = load_split(cfg.test_datasets, cfg) + if not eval_dataset: + eval_dataset = None return train_dataset, eval_dataset From 72fb877b23859ffab560adcae3cd3bc6465f3d34 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 22 Jan 2024 23:24:58 -0500 Subject: [PATCH 21/23] missing args --- src/axolotl/core/trainer_builder.py | 3 ++- tests/e2e/test_dpo.py | 9 +++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 75da7905f..e109db7f8 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1020,7 +1020,8 @@ def build_training_arguments(self, total_num_steps): output_dir=self.cfg.output_dir, warmup_steps=self.cfg.warmup_steps, gradient_checkpointing=self.cfg.gradient_checkpointing, - gradient_checkpointing_kwargs={"use_reentrant": False}, + gradient_checkpointing_kwargs=self.cfg.gradient_checkpointing_kwargs + or {"use_reentrant": False}, logging_first_step=True, logging_steps=1, optim=self.cfg.optimizer, diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py index 0ab042eea..0fd4e8b52 100644 --- a/tests/e2e/test_dpo.py +++ b/tests/e2e/test_dpo.py @@ -56,6 +56,9 @@ def test_dpo_lora(self, temp_dir): "lr_scheduler": "cosine", "max_steps": 20, "save_steps": 10, + "warmup_steps": 5, + "gradient_checkpointing": True, + "gradient_checkpointing_kwargs": {"use_reentrant": True}, } ) normalize_config(cfg) @@ -97,6 +100,9 @@ def test_kto_pair_lora(self, temp_dir): "lr_scheduler": "cosine", "max_steps": 20, "save_steps": 10, + "warmup_steps": 5, + "gradient_checkpointing": True, + "gradient_checkpointing_kwargs": {"use_reentrant": True}, } ) normalize_config(cfg) @@ -138,6 +144,9 @@ def test_ipo_lora(self, temp_dir): "lr_scheduler": "cosine", "max_steps": 20, "save_steps": 10, + "warmup_steps": 5, + "gradient_checkpointing": True, + "gradient_checkpointing_kwargs": {"use_reentrant": True}, } ) normalize_config(cfg) From 29663d8c23dd87198f2fd136cbfb625d8fa4cd04 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 23 Jan 2024 00:10:48 -0500 Subject: [PATCH 22/23] validate checkpoint is an adapter for now --- tests/e2e/test_dpo.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py index 0fd4e8b52..ac3c6d069 100644 --- a/tests/e2e/test_dpo.py +++ b/tests/e2e/test_dpo.py @@ -66,7 +66,7 @@ def test_dpo_lora(self, temp_dir): dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() + assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() @with_temp_dir def test_kto_pair_lora(self, temp_dir): @@ -110,7 +110,7 @@ def test_kto_pair_lora(self, temp_dir): 
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() + assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() @with_temp_dir def test_ipo_lora(self, temp_dir): @@ -154,4 +154,4 @@ def test_ipo_lora(self, temp_dir): dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "adapter_model.bin").exists() + assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() From 76b5c2dec3559c8e32aefb7c68333b49857d44f2 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 23 Jan 2024 00:37:51 -0500 Subject: [PATCH 23/23] log warning when dataset strategy is not loadable --- src/axolotl/prompt_strategies/dpo/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/axolotl/prompt_strategies/dpo/__init__.py b/src/axolotl/prompt_strategies/dpo/__init__.py index 51be84866..3c1c80800 100644 --- a/src/axolotl/prompt_strategies/dpo/__init__.py +++ b/src/axolotl/prompt_strategies/dpo/__init__.py @@ -3,6 +3,9 @@ """ import importlib +import logging + +LOG = logging.getLogger("axolotl") def load(strategy, cfg): @@ -14,4 +17,5 @@ def load(strategy, cfg): load_kwargs = {} return func(cfg, **load_kwargs) except Exception: # pylint: disable=broad-exception-caught + LOG.warning(f"unable to load strategy {strategy}") return None
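
For readers following the final state of this series: a dataset entry's `type` (e.g. `chatml.intel` in the e2e tests above) is resolved by `axolotl.prompt_strategies.dpo.load`, which imports the named module under `axolotl.prompt_strategies.dpo` and calls the matching factory to obtain a per-sample transform. The sketch below (not part of the patches) exercises that path directly on a tiny illustrative in-memory dataset; the sample columns mirror the Intel Orca DPO pairs layout that `chatml.intel` expects, and the empty `DictDefault` config is a stand-in, since the chatml strategies currently ignore `cfg`.

    # Minimal sketch: resolve a DPO strategy by its config "type" string and
    # map it over a toy dataset. The sample values here are illustrative only.
    from datasets import Dataset

    from axolotl.prompt_strategies.dpo import load as load_dpo
    from axolotl.utils.dict import DictDefault

    cfg = DictDefault({})  # stand-in config; the chatml strategies don't read it
    transform_fn = load_dpo("chatml.intel", cfg)  # -> intel(cfg) from dpo/chatml.py

    ds = Dataset.from_dict(
        {
            "system": ["You are a helpful assistant."],
            "question": ["What is 2 + 2?"],
            "chosen": ["4"],
            "rejected": ["5"],
        }
    )
    ds = ds.map(transform_fn, desc="Mapping RL Dataset")

    print(ds[0]["prompt"])    # ChatML prompt with system + user turns
    print(ds[0]["chosen"])    # "4<|im_end|>"
    print(ds[0]["rejected"])  # "5<|im_end|>"

This is the same transform that `load_prepare_dpo_datasets` applies to each configured dataset before concatenating the splits and handing them to the DPO trainer.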