From 2ce5c0d68a9cf1db1fa9a7861c6ed3bb117c5f8b Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sat, 20 Jan 2024 05:11:50 -0500
Subject: [PATCH] Deprecate max packed sequence len (#1141)

---
 README.md                    |   4 -
 src/axolotl/utils/config.py  |  15 +---
 src/axolotl/utils/data.py    | 137 ++++-------------------------------
 src/axolotl/utils/models.py  |   6 +-
 src/axolotl/utils/trainer.py |  21 +++---
 tests/test_validation.py     |  25 ++-----
 6 files changed, 38 insertions(+), 170 deletions(-)

diff --git a/README.md b/README.md
index 89cf1ef36..f57765290 100644
--- a/README.md
+++ b/README.md
@@ -642,10 +642,6 @@ sequence_len: 2048
 # Pad inputs so each step uses constant sized buffers
 # This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
 pad_to_sequence_len:
-# Max sequence length to concatenate training samples together up to
-# Inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
-# FutureWarning: This will soon be DEPRECATED
-max_packed_sequence_len: 1024
 # Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
 sample_packing:
 # Set to 'false' if getting errors during eval with sample_packing on.
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index beade8621..ca7d037dd 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -157,6 +157,9 @@ def normalize_config(cfg):
     if isinstance(cfg.learning_rate, str):
         cfg.learning_rate = float(cfg.learning_rate)
 
+    if isinstance(cfg.pretraining_dataset, dict):
+        cfg.pretraining_dataset = [cfg.pretraining_dataset]
+
     log_gpu_memory_usage(LOG, "baseline", cfg.device)
 
 
@@ -192,18 +195,8 @@ def validate_config(cfg):
             raise ValueError(
                 "bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
             )
-    if cfg.max_packed_sequence_len and cfg.sample_packing:
-        raise ValueError(
-            "please set only one of max_packed_sequence_len (deprecated soon) or sample_packing"
-        )
     if cfg.max_packed_sequence_len:
-        LOG.warning(
-            str(
-                PendingDeprecationWarning(
-                    "max_packed_sequence_len will be deprecated in favor of sample_packing"
-                )
-            )
-        )
+        raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
 
     if cfg.sample_packing and not cfg.pad_to_sequence_len:
         LOG.warning(
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 1eff82694..00c1fc16f 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -19,7 +19,7 @@
 from transformers import PreTrainedTokenizerBase
 
 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
-from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
+from axolotl.datasets import TokenizedPromptDataset
 from axolotl.prompt_strategies import load
 from axolotl.prompt_tokenizers import (
     AlpacaMultipleChoicePromptTokenizingStrategy,
@@ -71,9 +71,11 @@ def prepare_dataset(cfg, tokenizer):
     else:
         path = cfg.pretraining_dataset
         name = None
-        if isinstance(cfg.pretraining_dataset, dict):
-            path = cfg.pretraining_dataset["path"]
-            name = cfg.pretraining_dataset["name"]
+        if isinstance(cfg.pretraining_dataset, list) and isinstance(
+            cfg.pretraining_dataset[0], dict
+        ):
+            path = cfg.pretraining_dataset[0]["path"]
+            name = cfg.pretraining_dataset[0]["name"]
 
         train_dataset = load_pretraining_dataset(
             path,
@@ -88,11 +90,6 @@ def prepare_dataset(cfg, tokenizer):
         eval_dataset = None
         return train_dataset, eval_dataset, cfg.max_steps, prompters
 
-    with zero_first(is_main_process()):
-        train_dataset, eval_dataset = process_datasets_for_packing(
-            cfg, train_dataset, eval_dataset, tokenizer
-        )
-
     if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
         total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
         if total_eval_steps == 0:
@@ -163,6 +160,10 @@ def load_tokenized_prepared_datasets(
     else:
         LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}")
         LOG.info("Loading raw datasets...")
+        if not cfg.is_preprocess:
+            LOG.warning(
+                "Processing datasets during training can lead to VRAM instability. Please pre-process your dataset"
+            )
 
         if cfg.seed:
             seed = cfg.seed
@@ -382,6 +383,9 @@ def for_d_in_datasets(dataset_configs):
         if len(datasets) > 1:
             LOG.info("shuffle merged datasets")
             dataset = dataset.shuffle(seed=seed)
+
+        dataset, _ = process_datasets_for_packing(cfg, dataset, None, tokenizer)
+
         if cfg.local_rank == 0:
             LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
             dataset.save_to_disk(prepared_ds_path)
@@ -419,119 +423,9 @@ def load_prepare_datasets(
     cfg,
     default_dataset_prepared_path,
 ) -> Tuple[Dataset, Dataset, List[Prompter]]:
-    max_packed_sequence_len = (
-        cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
+    dataset, prompters = load_tokenized_prepared_datasets(
+        tokenizer, cfg, default_dataset_prepared_path
     )
-    max_packed_sequence_len = min(
-        max_packed_sequence_len, cfg.sequence_len
-    )  # make sure we don't accidentally set it larger than sequence_len
-
-    tokenizer_name = tokenizer.__class__.__name__
-    prompters: List[Prompter] = []
-    if cfg.max_packed_sequence_len is not None:
-        # see if we can go ahead and load the stacked dataset
-        seed = f"@{str(cfg.seed)}" if cfg.seed else ""
-        ds_hash = str(
-            md5(
-                (
-                    str(cfg.sequence_len)
-                    + "@"
-                    + str(max_packed_sequence_len)
-                    + seed
-                    + "|".join(
-                        sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
-                    )
-                    + "|"
-                    + tokenizer_name
-                )
-            )
-        )
-        prepared_ds_path = (
-            Path(cfg.dataset_prepared_path) / ds_hash
-            if cfg.dataset_prepared_path
-            else Path(default_dataset_prepared_path) / ds_hash
-        )
-
-        dataset = None
-        use_auth_token = cfg.hf_use_auth_token
-        try:
-            if cfg.push_dataset_to_hub:
-                LOG.info(
-                    f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}"
-                )
-                dataset = load_dataset(
-                    f"{cfg.push_dataset_to_hub}/{ds_hash}",
-                    token=use_auth_token,
-                )
-                dataset = dataset["train"]
-        except Exception:  # pylint: disable=broad-except  # nosec
-            pass
-
-        if dataset:
-            ...
-        elif (
-            cfg.dataset_prepared_path
-            and any(prepared_ds_path.glob("*"))
-            and not cfg.is_preprocess
-        ):
-            LOG.info(
-                f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
-            )
-            dataset = load_from_disk(str(prepared_ds_path))
-            LOG.info("Prepared packed dataset loaded from disk...")
-            if cfg.push_dataset_to_hub:
-                LOG.info(
-                    f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
-                )
-                dataset.push_to_hub(
-                    f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
-                )
-        else:
-            dataset, prompters = load_tokenized_prepared_datasets(
-                tokenizer, cfg, default_dataset_prepared_path
-            )
-
-            if cfg.seed:
-                dataset = dataset.shuffle(seed=cfg.seed)
-
-            constant_len_dataset = ConstantLengthDataset(
-                tokenizer,
-                [dataset],
-                seq_length=max_packed_sequence_len,
-            )
-            LOG.info(f"packing master dataset to len: {cfg.max_packed_sequence_len}")
-            dataset = Dataset.from_list(list(constant_len_dataset))
-
-            # filter out bad data
-            # TODO convert to dataset.filter(...)
-            dataset = Dataset.from_list(
-                [
-                    d
-                    for d in dataset
-                    if len(d["input_ids"]) <= cfg.sequence_len
-                    and len(d["input_ids"]) > 0
-                    and len(d["input_ids"]) == len(d["attention_mask"])
-                    and len(d["input_ids"]) == len(d["labels"])
-                ]
-            )
-
-            if cfg.local_rank == 0:
-                LOG.info(
-                    f"Saving packed prepared dataset to disk... {prepared_ds_path}"
-                )
-                dataset.save_to_disk(prepared_ds_path)
-                if cfg.push_dataset_to_hub:
-                    LOG.info(
-                        f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
-                    )
-                    dataset.push_to_hub(
-                        f"{cfg.push_dataset_to_hub}/{ds_hash}",
-                        private=True,
-                    )
-    else:
-        dataset, prompters = load_tokenized_prepared_datasets(
-            tokenizer, cfg, default_dataset_prepared_path
-        )
 
     if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
         LOG.info(
@@ -877,6 +771,7 @@ def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, s
     dataset = dataset.map(
         encode,
         batched=True,
+        batch_size=10_000,
        input_columns="text",
         # remove all the existing columns after mapping since they end up having
         # a different length than the encoded/tokenized column
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 034249932..afc41e1ed 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -329,11 +329,7 @@ def load_model(
             LOG.info("patching mixtral with flash attention")
             replace_mixtral_attn_with_multipack_flash_attn()
 
-    if (
-        cfg.is_llama_derived_model
-        and (cfg.max_packed_sequence_len or cfg.sample_packing)
-        and not inference
-    ):
+    if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
         from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
 
         LOG.info("patching _expand_mask")
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 44642fb30..956861b29 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -81,6 +81,15 @@ def trainer_weighted_loss(model_output, labels, shift_labels=True):
     return weighted_cross_entropy(logits, labels, weights)
 
 
+@contextmanager
+def disable_datasets_caching():
+    try:
+        set_caching_enabled(False)
+        yield
+    finally:
+        set_caching_enabled(True)
+
+
 def add_position_ids(sample):
     sample_len = len(sample["input_ids"])
     sample["position_ids"] = torch.arange(len(sample["input_ids"]))
@@ -97,15 +106,6 @@ def drop_long_seq(sample, sequence_len=2048):
     return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0
 
 
-@contextmanager
-def disable_datasets_caching():
-    try:
-        set_caching_enabled(False)
-        yield
-    finally:
-        set_caching_enabled(True)
-
-
 def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
     drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
     with zero_first(is_main_process()):
@@ -227,8 +227,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
                 sampler=RandomSampler(train_dataset),
                 batch_size=cfg.micro_batch_size,
                 drop_last=True,
-                batch_max_len=cfg.micro_batch_size
-                * (cfg.max_packed_sequence_len or cfg.sequence_len),
+                batch_max_len=cfg.micro_batch_size * cfg.sequence_len,
                 lengths=get_dataset_lengths(train_dataset),
             )
 
diff --git a/tests/test_validation.py b/tests/test_validation.py
index 41e4b1253..5201bdf46 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -324,20 +324,19 @@ def test_adamw_hyperparams(self):
 
         validate_config(cfg)
 
-    def test_packing(self):
+    def test_deprecated_packing(self):
         cfg = DictDefault(
             {
-                "max_packed_sequence_len": 2048,
+                "max_packed_sequence_len": 1024,
             }
         )
-        with self._caplog.at_level(logging.WARNING):
+        with pytest.raises(
+            DeprecationWarning,
+            match=r"`max_packed_sequence_len` is no longer supported",
+        ):
             validate_config(cfg)
-            assert any(
-                "max_packed_sequence_len will be deprecated in favor of sample_packing"
-                in record.message
-                for record in self._caplog.records
-            )
 
+    def test_packing(self):
         cfg = DictDefault(
             {
                 "sample_packing": True,
@@ -352,16 +351,6 @@ def test_packing(self):
             for record in self._caplog.records
         )
 
-        cfg = DictDefault(
-            {
-                "max_packed_sequence_len": 2048,
-                "sample_packing": True,
-            }
-        )
-        regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
-        with pytest.raises(ValueError, match=regex_exp):
-            validate_config(cfg)
-
     @pytest.mark.skipif(
         is_torch_bf16_gpu_available(),
         reason="test should only run on gpus w/o bf16 support",
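
Illustrative migration sketch (not part of the patch itself): a config that previously
relied on the now-removed `max_packed_sequence_len` key would switch to sample packing.
The key names come from the README section touched above; the concrete values are
placeholder assumptions.

    # before: validate_config() now raises DeprecationWarning for this key
    sequence_len: 2048
    max_packed_sequence_len: 1024

    # after: multipack sample packing; pad_to_sequence_len avoids the warning
    # validate_config() logs when sample_packing is enabled without it
    sequence_len: 2048
    sample_packing: true
    pad_to_sequence_len: true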