From fd954ce0bd3d7c0c04c4ce573ed352bf519ea34f Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 17 Jan 2024 21:09:27 -0500 Subject: [PATCH 1/6] deprecate max_packed_sequence_len, map and filter dataset during pre-processing step --- README.md | 4 -- src/axolotl/utils/config.py | 12 ---- src/axolotl/utils/data.py | 124 ++--------------------------------- src/axolotl/utils/models.py | 6 +- src/axolotl/utils/trainer.py | 21 +++--- tests/test_validation.py | 23 ------- 6 files changed, 17 insertions(+), 173 deletions(-) diff --git a/README.md b/README.md index f5f848a44..4c79d578f 100644 --- a/README.md +++ b/README.md @@ -639,10 +639,6 @@ sequence_len: 2048 # Pad inputs so each step uses constant sized buffers # This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently pad_to_sequence_len: -# Max sequence length to concatenate training samples together up to -# Inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning -# FutureWarning: This will soon be DEPRECATED -max_packed_sequence_len: 1024 # Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true' sample_packing: # Set to 'false' if getting errors during eval with sample_packing on. diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 7ea9581a9..3fd87fee5 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -186,18 +186,6 @@ def validate_config(cfg): raise ValueError( "bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above." ) - if cfg.max_packed_sequence_len and cfg.sample_packing: - raise ValueError( - "please set only one of max_packed_sequence_len (deprecated soon) or sample_packing" - ) - if cfg.max_packed_sequence_len: - LOG.warning( - str( - PendingDeprecationWarning( - "max_packed_sequence_len will be deprecated in favor of sample_packing" - ) - ) - ) if cfg.sample_packing and not cfg.pad_to_sequence_len: LOG.warning( diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 15ae8d5a5..6b896cdbd 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -19,7 +19,7 @@ from transformers import PreTrainedTokenizerBase from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH -from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset +from axolotl.datasets import TokenizedPromptDataset from axolotl.prompt_strategies import load from axolotl.prompt_tokenizers import ( AlpacaMultipleChoicePromptTokenizingStrategy, @@ -88,11 +88,6 @@ def prepare_dataset(cfg, tokenizer): eval_dataset = None return train_dataset, eval_dataset, cfg.max_steps, prompters - with zero_first(is_main_process()): - train_dataset, eval_dataset = process_datasets_for_packing( - cfg, train_dataset, eval_dataset, tokenizer - ) - if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False: total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False) if total_eval_steps == 0: @@ -382,6 +377,9 @@ def for_d_in_datasets(dataset_configs): if len(datasets) > 1: LOG.info("shuffle merged datasets") dataset = dataset.shuffle(seed=seed) + + dataset, _ = process_datasets_for_packing(cfg, dataset, None, tokenizer) + if cfg.local_rank == 0: LOG.info(f"Saving merged prepared dataset to disk... 
{prepared_ds_path}") dataset.save_to_disk(prepared_ds_path) @@ -419,119 +417,9 @@ def load_prepare_datasets( cfg, default_dataset_prepared_path, ) -> Tuple[Dataset, Dataset, List[Prompter]]: - max_packed_sequence_len = ( - cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len + dataset, prompters = load_tokenized_prepared_datasets( + tokenizer, cfg, default_dataset_prepared_path ) - max_packed_sequence_len = min( - max_packed_sequence_len, cfg.sequence_len - ) # make sure we don't accidentally set it larger than sequence_len - - tokenizer_name = tokenizer.__class__.__name__ - prompters: List[Prompter] = [] - if cfg.max_packed_sequence_len is not None: - # see if we can go ahead and load the stacked dataset - seed = f"@{str(cfg.seed)}" if cfg.seed else "" - ds_hash = str( - md5( - ( - str(cfg.sequence_len) - + "@" - + str(max_packed_sequence_len) - + seed - + "|".join( - sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]) - ) - + "|" - + tokenizer_name - ) - ) - ) - prepared_ds_path = ( - Path(cfg.dataset_prepared_path) / ds_hash - if cfg.dataset_prepared_path - else Path(default_dataset_prepared_path) / ds_hash - ) - - dataset = None - use_auth_token = cfg.hf_use_auth_token - try: - if cfg.push_dataset_to_hub: - LOG.info( - f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}" - ) - dataset = load_dataset( - f"{cfg.push_dataset_to_hub}/{ds_hash}", - token=use_auth_token, - ) - dataset = dataset["train"] - except Exception: # pylint: disable=broad-except # nosec - pass - - if dataset: - ... - elif ( - cfg.dataset_prepared_path - and any(prepared_ds_path.glob("*")) - and not cfg.is_preprocess - ): - LOG.info( - f"Loading prepared packed dataset from disk at {prepared_ds_path}..." - ) - dataset = load_from_disk(str(prepared_ds_path)) - LOG.info("Prepared packed dataset loaded from disk...") - if cfg.push_dataset_to_hub: - LOG.info( - f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}" - ) - dataset.push_to_hub( - f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True - ) - else: - dataset, prompters = load_tokenized_prepared_datasets( - tokenizer, cfg, default_dataset_prepared_path - ) - - if cfg.seed: - dataset = dataset.shuffle(seed=cfg.seed) - - constant_len_dataset = ConstantLengthDataset( - tokenizer, - [dataset], - seq_length=max_packed_sequence_len, - ) - LOG.info(f"packing master dataset to len: {cfg.max_packed_sequence_len}") - dataset = Dataset.from_list(list(constant_len_dataset)) - - # filter out bad data - # TODO convert to dataset.filter(...) - dataset = Dataset.from_list( - [ - d - for d in dataset - if len(d["input_ids"]) <= cfg.sequence_len - and len(d["input_ids"]) > 0 - and len(d["input_ids"]) == len(d["attention_mask"]) - and len(d["input_ids"]) == len(d["labels"]) - ] - ) - - if cfg.local_rank == 0: - LOG.info( - f"Saving packed prepared dataset to disk... {prepared_ds_path}" - ) - dataset.save_to_disk(prepared_ds_path) - if cfg.push_dataset_to_hub: - LOG.info( - f"Saving packed prepared dataset with push_to_hub... 
{cfg.push_dataset_to_hub}/{ds_hash}" - ) - dataset.push_to_hub( - f"{cfg.push_dataset_to_hub}/{ds_hash}", - private=True, - ) - else: - dataset, prompters = load_tokenized_prepared_datasets( - tokenizer, cfg, default_dataset_prepared_path - ) if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None: LOG.info( diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 55721f820..60ed7d12a 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -301,11 +301,7 @@ def load_model( LOG.info("patching with flash attention") replace_mixtral_attn_with_multipack_flash_attn() - if ( - cfg.is_llama_derived_model - and (cfg.max_packed_sequence_len or cfg.sample_packing) - and not inference - ): + if cfg.is_llama_derived_model and cfg.sample_packing and not inference: from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask LOG.info("patching _expand_mask") diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 3fc244605..9a9eeab4b 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -81,6 +81,15 @@ def trainer_weighted_loss(model_output, labels, shift_labels=True): return weighted_cross_entropy(logits, labels, weights) +@contextmanager +def disable_datasets_caching(): + try: + set_caching_enabled(False) + yield + finally: + set_caching_enabled(True) + + def add_position_ids(sample): sample_len = len(sample["input_ids"]) sample["position_ids"] = torch.arange(len(sample["input_ids"])) @@ -97,15 +106,6 @@ def drop_long_seq(sample, sequence_len=2048): return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0 -@contextmanager -def disable_datasets_caching(): - try: - set_caching_enabled(False) - yield - finally: - set_caching_enabled(True) - - def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer): drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len) with zero_first(is_main_process()): @@ -226,8 +226,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): sampler=RandomSampler(train_dataset), batch_size=cfg.micro_batch_size, drop_last=True, - batch_max_len=cfg.micro_batch_size - * (cfg.max_packed_sequence_len or cfg.sequence_len), + batch_max_len=cfg.micro_batch_size * cfg.sequence_len, lengths=get_dataset_lengths(train_dataset), ) diff --git a/tests/test_validation.py b/tests/test_validation.py index 41e4b1253..276b5e6a6 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -325,19 +325,6 @@ def test_adamw_hyperparams(self): validate_config(cfg) def test_packing(self): - cfg = DictDefault( - { - "max_packed_sequence_len": 2048, - } - ) - with self._caplog.at_level(logging.WARNING): - validate_config(cfg) - assert any( - "max_packed_sequence_len will be deprecated in favor of sample_packing" - in record.message - for record in self._caplog.records - ) - cfg = DictDefault( { "sample_packing": True, @@ -352,16 +339,6 @@ def test_packing(self): for record in self._caplog.records ) - cfg = DictDefault( - { - "max_packed_sequence_len": 2048, - "sample_packing": True, - } - ) - regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*" - with pytest.raises(ValueError, match=regex_exp): - validate_config(cfg) - @pytest.mark.skipif( is_torch_bf16_gpu_available(), reason="test should only run on gpus w/o bf16 support", From 4a8392f86e044b9c28f83dae9b1d538184c3c1a8 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 17 Jan 2024 21:45:36 -0500 Subject: [PATCH 2/6] increase batch size 
for packed pretraining and be more robust on data configs --- src/axolotl/utils/data.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 6b896cdbd..fc4409fc5 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -74,6 +74,11 @@ def prepare_dataset(cfg, tokenizer): if isinstance(cfg.pretraining_dataset, dict): path = cfg.pretraining_dataset["path"] name = cfg.pretraining_dataset["name"] + elif isinstance(cfg.pretraining_dataset, list) and isinstance( + cfg.pretraining_dataset[0], dict + ): + path = cfg.pretraining_dataset[0]["path"] + name = cfg.pretraining_dataset[0]["name"] train_dataset = load_pretraining_dataset( path, @@ -760,6 +765,7 @@ def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, s dataset = dataset.map( encode, batched=True, + batch_size=10_000, input_columns="text", # remove all the existing columns after mapping since they end up having # a different length than the encoded/tokenized column From 585c11d25ef5f2564584ea956df36ace2dc2363f Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 19 Jan 2024 23:21:11 -0500 Subject: [PATCH 3/6] deprecation tests --- src/axolotl/utils/config.py | 2 ++ tests/test_validation.py | 12 ++++++++++++ 2 files changed, 14 insertions(+) diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 3fd87fee5..0eaab56a0 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -186,6 +186,8 @@ def validate_config(cfg): raise ValueError( "bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above." ) + if cfg.max_packed_sequence_len: + raise DeprecationWarning("`max_packed_sequence_len` is no longer supported") if cfg.sample_packing and not cfg.pad_to_sequence_len: LOG.warning( diff --git a/tests/test_validation.py b/tests/test_validation.py index 276b5e6a6..5201bdf46 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -324,6 +324,18 @@ def test_adamw_hyperparams(self): validate_config(cfg) + def test_deprecated_packing(self): + cfg = DictDefault( + { + "max_packed_sequence_len": 1024, + } + ) + with pytest.raises( + DeprecationWarning, + match=r"`max_packed_sequence_len` is no longer supported", + ): + validate_config(cfg) + def test_packing(self): cfg = DictDefault( { From a787d57f0725e0837c7ac23110619ee245d8a09e Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 19 Jan 2024 23:29:32 -0500 Subject: [PATCH 4/6] normalize pretraining_dataset configuration --- src/axolotl/utils/config.py | 3 +++ src/axolotl/utils/data.py | 5 +---- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 0eaab56a0..be4ddf978 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -155,6 +155,9 @@ def normalize_config(cfg): if isinstance(cfg.learning_rate, str): cfg.learning_rate = float(cfg.learning_rate) + if isinstance(cfg.pretraining_dataset, dict): + cfg.pretraining_dataset = [cfg.pretraining_dataset] + log_gpu_memory_usage(LOG, "baseline", cfg.device) diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index fc4409fc5..a0fd3ea1a 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -71,10 +71,7 @@ def prepare_dataset(cfg, tokenizer): else: path = cfg.pretraining_dataset name = None - if isinstance(cfg.pretraining_dataset, dict): - path = cfg.pretraining_dataset["path"] - name = cfg.pretraining_dataset["name"] - elif 
isinstance(cfg.pretraining_dataset, list) and isinstance( + if isinstance(cfg.pretraining_dataset, list) and isinstance( cfg.pretraining_dataset[0], dict ): path = cfg.pretraining_dataset[0]["path"] From 334f02cd017ba5ac30aef33feb93b5a4eedb8f29 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 19 Jan 2024 23:32:24 -0500 Subject: [PATCH 5/6] warn about not pre-processing --- src/axolotl/utils/trainer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 9a9eeab4b..6d2f08c8e 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -107,6 +107,10 @@ def drop_long_seq(sample, sequence_len=2048): def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer): + if cfg.is_preprocess: + LOG.warning( + "Processing datasets during training can lead to VRAM instability. Please pre-process your dataset" + ) drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len) with zero_first(is_main_process()): if cfg.group_by_length: From 90d73fc439fac6df8b9b1c52b58d02d582811393 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 19 Jan 2024 23:40:02 -0500 Subject: [PATCH 6/6] log warning sooner if not pre-processed before training --- src/axolotl/utils/data.py | 4 ++++ src/axolotl/utils/trainer.py | 4 ---- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index a0fd3ea1a..c1c1fbc64 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -160,6 +160,10 @@ def load_tokenized_prepared_datasets( else: LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}") LOG.info("Loading raw datasets...") + if not cfg.is_preprocess: + LOG.warning( + "Processing datasets during training can lead to VRAM instability. Please pre-process your dataset" + ) if cfg.seed: seed = cfg.seed diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 6d2f08c8e..9a9eeab4b 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -107,10 +107,6 @@ def drop_long_seq(sample, sequence_len=2048): def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer): - if cfg.is_preprocess: - LOG.warning( - "Processing datasets during training can lead to VRAM instability. Please pre-process your dataset" - ) drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len) with zero_first(is_main_process()): if cfg.group_by_length:
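
A minimal YAML sketch of the config options touched by this series, for reference. It assumes a packing-based setup: `sample_packing` with `pad_to_sequence_len` stands in for the removed `max_packed_sequence_len` option, and the dataset path and name shown are placeholders rather than values taken from these patches.

    # max_packed_sequence_len is no longer supported; use sample packing instead
    sequence_len: 2048
    sample_packing: true
    pad_to_sequence_len: true

    # pretraining_dataset may now be given as a list of dicts (a single dict is
    # normalized into a one-element list); the path and name here are placeholders
    pretraining_dataset:
      - path: your_org/your_pretraining_corpus
        name: default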