Deprecate max packed sequence len (#1141)
winglian committed Jan 20, 2024
1 parent 3db5f2f commit 2ce5c0d
Showing 6 changed files with 38 additions and 170 deletions.
4 changes: 0 additions & 4 deletions README.md
@@ -642,10 +642,6 @@ sequence_len: 2048
# Pad inputs so each step uses constant sized buffers
# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
pad_to_sequence_len:
# Max sequence length to concatenate training samples together up to
# Inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
# FutureWarning: This will soon be DEPRECATED
max_packed_sequence_len: 1024
# Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
sample_packing:
# Set to 'false' if getting errors during eval with sample_packing on.
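With `max_packed_sequence_len` removed from the README, the documented way to pack samples is `sample_packing` together with `pad_to_sequence_len`. A minimal sketch of the replacement settings, expressed with the same `DictDefault`/`validate_config` helpers this repository's tests use (key names come from the README snippet above; the concrete values are illustrative only):

```python
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault

# Illustrative replacement for a config that previously set max_packed_sequence_len.
cfg = DictDefault(
    {
        "sequence_len": 2048,
        "pad_to_sequence_len": True,  # constant-size buffers, fewer OOMs when packing
        "sample_packing": True,       # multipack with block-diagonal attention
    }
)
validate_config(cfg)  # passes; setting max_packed_sequence_len instead would now raise
```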
15 changes: 4 additions & 11 deletions src/axolotl/utils/config.py
@@ -157,6 +157,9 @@ def normalize_config(cfg):
if isinstance(cfg.learning_rate, str):
cfg.learning_rate = float(cfg.learning_rate)

if isinstance(cfg.pretraining_dataset, dict):
cfg.pretraining_dataset = [cfg.pretraining_dataset]

log_gpu_memory_usage(LOG, "baseline", cfg.device)


@@ -192,18 +195,8 @@ def validate_config(cfg):
raise ValueError(
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
)
if cfg.max_packed_sequence_len and cfg.sample_packing:
raise ValueError(
"please set only one of max_packed_sequence_len (deprecated soon) or sample_packing"
)
if cfg.max_packed_sequence_len:
LOG.warning(
str(
PendingDeprecationWarning(
"max_packed_sequence_len will be deprecated in favor of sample_packing"
)
)
)
raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")

if cfg.sample_packing and not cfg.pad_to_sequence_len:
LOG.warning(
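Two behavioral notes on the config changes above: `validate_config` now hard-fails on the removed key instead of warning, and `normalize_config` wraps a dict-valued `pretraining_dataset` into a single-element list so downstream code can always index element 0. A small sketch of that normalization step in isolation (the dataset path/name values are made up):

```python
from axolotl.utils.dict import DictDefault

cfg = DictDefault({"pretraining_dataset": {"path": "some/hf-dataset", "name": "default"}})

# Mirrors the added normalize_config logic: dict -> [dict]
if isinstance(cfg.pretraining_dataset, dict):
    cfg.pretraining_dataset = [cfg.pretraining_dataset]

assert cfg.pretraining_dataset[0]["path"] == "some/hf-dataset"
```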
137 changes: 16 additions & 121 deletions src/axolotl/utils/data.py
@@ -19,7 +19,7 @@
from transformers import PreTrainedTokenizerBase

from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies import load
from axolotl.prompt_tokenizers import (
AlpacaMultipleChoicePromptTokenizingStrategy,
@@ -71,9 +71,11 @@ def prepare_dataset(cfg, tokenizer):
else:
path = cfg.pretraining_dataset
name = None
if isinstance(cfg.pretraining_dataset, dict):
path = cfg.pretraining_dataset["path"]
name = cfg.pretraining_dataset["name"]
if isinstance(cfg.pretraining_dataset, list) and isinstance(
cfg.pretraining_dataset[0], dict
):
path = cfg.pretraining_dataset[0]["path"]
name = cfg.pretraining_dataset[0]["name"]

train_dataset = load_pretraining_dataset(
path,
@@ -88,11 +90,6 @@ def prepare_dataset(cfg, tokenizer):
eval_dataset = None
return train_dataset, eval_dataset, cfg.max_steps, prompters

with zero_first(is_main_process()):
train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset, tokenizer
)

if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
if total_eval_steps == 0:
@@ -163,6 +160,10 @@ load_tokenized_prepared_datasets(
else:
LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}")
LOG.info("Loading raw datasets...")
if not cfg.is_preprocess:
LOG.warning(
"Processing datasets during training can lead to VRAM instability. Please pre-process your dataset"
)

if cfg.seed:
seed = cfg.seed
@@ -382,6 +383,9 @@ def for_d_in_datasets(dataset_configs):
if len(datasets) > 1:
LOG.info("shuffle merged datasets")
dataset = dataset.shuffle(seed=seed)

dataset, _ = process_datasets_for_packing(cfg, dataset, None, tokenizer)

if cfg.local_rank == 0:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
dataset.save_to_disk(prepared_ds_path)
@@ -419,119 +423,9 @@ def load_prepare_datasets(
cfg,
default_dataset_prepared_path,
) -> Tuple[Dataset, Dataset, List[Prompter]]:
max_packed_sequence_len = (
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
dataset, prompters = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path
)
max_packed_sequence_len = min(
max_packed_sequence_len, cfg.sequence_len
) # make sure we don't accidentally set it larger than sequence_len

tokenizer_name = tokenizer.__class__.__name__
prompters: List[Prompter] = []
if cfg.max_packed_sequence_len is not None:
# see if we can go ahead and load the stacked dataset
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
ds_hash = str(
md5(
(
str(cfg.sequence_len)
+ "@"
+ str(max_packed_sequence_len)
+ seed
+ "|".join(
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
)
+ "|"
+ tokenizer_name
)
)
)
prepared_ds_path = (
Path(cfg.dataset_prepared_path) / ds_hash
if cfg.dataset_prepared_path
else Path(default_dataset_prepared_path) / ds_hash
)

dataset = None
use_auth_token = cfg.hf_use_auth_token
try:
if cfg.push_dataset_to_hub:
LOG.info(
f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}"
)
dataset = load_dataset(
f"{cfg.push_dataset_to_hub}/{ds_hash}",
token=use_auth_token,
)
dataset = dataset["train"]
except Exception: # pylint: disable=broad-except # nosec
pass

if dataset:
...
elif (
cfg.dataset_prepared_path
and any(prepared_ds_path.glob("*"))
and not cfg.is_preprocess
):
LOG.info(
f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
)
dataset = load_from_disk(str(prepared_ds_path))
LOG.info("Prepared packed dataset loaded from disk...")
if cfg.push_dataset_to_hub:
LOG.info(
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
)
dataset.push_to_hub(
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
)
else:
dataset, prompters = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path
)

if cfg.seed:
dataset = dataset.shuffle(seed=cfg.seed)

constant_len_dataset = ConstantLengthDataset(
tokenizer,
[dataset],
seq_length=max_packed_sequence_len,
)
LOG.info(f"packing master dataset to len: {cfg.max_packed_sequence_len}")
dataset = Dataset.from_list(list(constant_len_dataset))

# filter out bad data
# TODO convert to dataset.filter(...)
dataset = Dataset.from_list(
[
d
for d in dataset
if len(d["input_ids"]) <= cfg.sequence_len
and len(d["input_ids"]) > 0
and len(d["input_ids"]) == len(d["attention_mask"])
and len(d["input_ids"]) == len(d["labels"])
]
)

if cfg.local_rank == 0:
LOG.info(
f"Saving packed prepared dataset to disk... {prepared_ds_path}"
)
dataset.save_to_disk(prepared_ds_path)
if cfg.push_dataset_to_hub:
LOG.info(
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
)
dataset.push_to_hub(
f"{cfg.push_dataset_to_hub}/{ds_hash}",
private=True,
)
else:
dataset, prompters = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path
)

if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
LOG.info(
@@ -877,6 +771,7 @@ def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, s
dataset = dataset.map(
encode,
batched=True,
batch_size=10_000,
input_columns="text",
# remove all the existing columns after mapping since they end up having
# a different length than the encoded/tokenized column
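The data-loading changes consolidate everything onto one path: the `ConstantLengthDataset` packing branch is removed, `process_datasets_for_packing` now runs while the prepared dataset is built (with a new warning nudging users to pre-process rather than tokenize during training), the pretraining `dataset.map` call tokenizes in explicit batches of 10,000 rows, and `pretraining_dataset` may now be given as a list of dicts. A rough sketch of the path/name resolution `prepare_dataset` now performs for streaming pretraining data; the helper name and dataset values are hypothetical, not part of the commit:

```python
def resolve_pretraining_dataset(pretraining_dataset):
    """Return (path, name) for a string or list-of-dicts pretraining_dataset value."""
    path, name = pretraining_dataset, None
    if isinstance(pretraining_dataset, list) and isinstance(pretraining_dataset[0], dict):
        path = pretraining_dataset[0]["path"]
        name = pretraining_dataset[0]["name"]
    return path, name


assert resolve_pretraining_dataset("my-org/my-corpus") == ("my-org/my-corpus", None)
assert resolve_pretraining_dataset([{"path": "my-org/my-corpus", "name": "en"}]) == (
    "my-org/my-corpus",
    "en",
)
```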
6 changes: 1 addition & 5 deletions src/axolotl/utils/models.py
@@ -329,11 +329,7 @@ def load_model(
LOG.info("patching mixtral with flash attention")
replace_mixtral_attn_with_multipack_flash_attn()

if (
cfg.is_llama_derived_model
and (cfg.max_packed_sequence_len or cfg.sample_packing)
and not inference
):
if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask

LOG.info("patching _expand_mask")
21 changes: 10 additions & 11 deletions src/axolotl/utils/trainer.py
@@ -81,6 +81,15 @@ def trainer_weighted_loss(model_output, labels, shift_labels=True):
return weighted_cross_entropy(logits, labels, weights)


@contextmanager
def disable_datasets_caching():
try:
set_caching_enabled(False)
yield
finally:
set_caching_enabled(True)


def add_position_ids(sample):
sample_len = len(sample["input_ids"])
sample["position_ids"] = torch.arange(len(sample["input_ids"]))
@@ -97,15 +106,6 @@ def drop_long_seq(sample, sequence_len=2048):
return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0


@contextmanager
def disable_datasets_caching():
try:
set_caching_enabled(False)
yield
finally:
set_caching_enabled(True)


def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
with zero_first(is_main_process()):
@@ -227,8 +227,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
sampler=RandomSampler(train_dataset),
batch_size=cfg.micro_batch_size,
drop_last=True,
batch_max_len=cfg.micro_batch_size
* (cfg.max_packed_sequence_len or cfg.sequence_len),
batch_max_len=cfg.micro_batch_size * cfg.sequence_len,
lengths=get_dataset_lengths(train_dataset),
)

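In trainer.py, `disable_datasets_caching` is only relocated (its behavior is unchanged), and the multipack sampler's `batch_max_len` is now derived solely from `sequence_len`. A minimal usage sketch of the context manager, assuming it is imported from `axolotl.utils.trainer` where it is defined; the toy dataset and mapped column are illustrative:

```python
from datasets import Dataset

from axolotl.utils.trainer import disable_datasets_caching

ds = Dataset.from_list([{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}])

# Caching is disabled only inside the block and restored afterwards, even on error.
with disable_datasets_caching():
    ds = ds.map(lambda ex: {"position_ids": list(range(len(ex["input_ids"])))})

print(ds[0]["position_ids"])  # [0, 1, 2]
```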
25 changes: 7 additions & 18 deletions tests/test_validation.py
@@ -324,20 +324,19 @@ def test_adamw_hyperparams(self):

validate_config(cfg)

def test_packing(self):
def test_deprecated_packing(self):
cfg = DictDefault(
{
"max_packed_sequence_len": 2048,
"max_packed_sequence_len": 1024,
}
)
with self._caplog.at_level(logging.WARNING):
with pytest.raises(
DeprecationWarning,
match=r"`max_packed_sequence_len` is no longer supported",
):
validate_config(cfg)
assert any(
"max_packed_sequence_len will be deprecated in favor of sample_packing"
in record.message
for record in self._caplog.records
)

def test_packing(self):
cfg = DictDefault(
{
"sample_packing": True,
@@ -352,16 +351,6 @@ def test_packing(self):
for record in self._caplog.records
)

cfg = DictDefault(
{
"max_packed_sequence_len": 2048,
"sample_packing": True,
}
)
regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)

@pytest.mark.skipif(
is_torch_bf16_gpu_available(),
reason="test should only run on gpus w/o bf16 support",
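The updated test confirms that setting `max_packed_sequence_len` now raises immediately rather than warning. A hypothetical migration helper (not part of this commit) showing how an old config could be rewritten before validation:

```python
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault


def migrate_packing_config(cfg: DictDefault) -> DictDefault:
    """Drop the removed key and opt into sample packing instead (illustrative only)."""
    if cfg.pop("max_packed_sequence_len", None):
        cfg["sample_packing"] = True
        cfg["pad_to_sequence_len"] = True  # recommended alongside sample packing
    return cfg


cfg = migrate_packing_config(DictDefault({"max_packed_sequence_len": 1024}))
validate_config(cfg)  # no longer raises DeprecationWarning
```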
