Fix pretraining with iterable/streaming Dataset (#556)
* return without packing prep/len

* fix remove columns

* fix encode arguments

* add error when max steps not set

* fix test

---------

Co-authored-by: Jan Philipp Harries <jphme@users.noreply.github.com>
jphme authored Sep 13, 2023
1 parent 9845c5e commit 2f586d1
Showing 3 changed files with 19 additions and 6 deletions.
4 changes: 4 additions & 0 deletions src/axolotl/utils/config.py
@@ -191,6 +191,10 @@ def validate_config(cfg):
         LOG.warning(
             "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
         )
+    if cfg.pretraining_dataset and not cfg.max_steps:
+        raise ValueError(
+            "max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!"
+        )
 
     if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
         not cfg.optimizer or "adamw" not in cfg.optimizer
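
For reference, the added guard amounts to the minimal sketch below. The SimpleNamespace stand-in and the check_pretraining_config helper are illustrative only, not axolotl code; the point is that a streamed IterableDataset has no length, so the Trainer cannot derive an optimizer/learning-rate schedule unless max_steps is set explicitly.

    from types import SimpleNamespace

    def check_pretraining_config(cfg):
        # Mirrors the added validation: a streamed IterableDataset has no __len__,
        # so the Trainer needs an explicit max_steps to build a schedule.
        if cfg.pretraining_dataset and not cfg.max_steps:
            raise ValueError(
                "max_steps must be set when using iterable pretraining_dataset, "
                "Trainer can't infer length and schedule optimizer/learning rate without it!"
            )

    # Placeholder values for illustration only.
    cfg = SimpleNamespace(pretraining_dataset="some-org/streaming-corpus", max_steps=None)
    try:
        check_pretraining_config(cfg)
    except ValueError as err:
        print(err)  # raised because max_steps is unset
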
19 changes: 14 additions & 5 deletions src/axolotl/utils/data.py
@@ -3,7 +3,7 @@
 import hashlib
 import logging
 from pathlib import Path
-from typing import Tuple, Union
+from typing import Dict, List, Tuple, Union
 
 import torch
 from datasets import (
@@ -74,6 +74,7 @@ def prepare_dataset(cfg, tokenizer):
         # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
         train_dataset = train_dataset.with_format("torch")
         eval_dataset = None
+        return train_dataset, eval_dataset, cfg.max_steps
 
     with zero_first(is_main_process()):
         train_dataset, eval_dataset = process_datasets_for_packing(
@@ -527,9 +528,11 @@ def load_prepare_datasets(
     return train_dataset, eval_dataset
 
 
-def encode_pretraining(tokenizer, max_tokens, examples):
+def encode_pretraining(
+    tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
+) -> Dict[str, List]:
     res = tokenizer(
-        examples["text"],
+        examples,
         truncation=True,
         max_length=max_tokens - 2,
         add_special_tokens=True,
@@ -637,6 +640,12 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
     encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
     dataset = load_dataset(path, streaming=True, split="train")
     dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
-    # TODO dynamically figure out which columns/features to remove
-    dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"])
+    dataset = dataset.map(
+        encode,
+        batched=True,
+        input_columns="text",
+        remove_columns=[
+            "text",
+        ],
+    )
     return dataset
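
The key change above is passing input_columns="text", so the batched map hands the encode function a plain list of strings instead of a column dict, and remove_columns no longer hard-codes a "meta" column that not every streaming dataset provides. A rough standalone sketch of the same pattern follows; the gpt2 tokenizer and the dataset path are placeholders, not anything the commit prescribes.

    from datasets import load_dataset
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer

    def encode(examples, max_tokens=2048):
        # With input_columns="text", `examples` arrives as List[str] -- the reason
        # encode_pretraining's signature changed from a column dict to a list.
        return tokenizer(
            examples, truncation=True, max_length=max_tokens - 2, add_special_tokens=True
        )

    dataset = load_dataset("some-org/streaming-corpus", streaming=True, split="train")  # placeholder path
    dataset = dataset.shuffle(seed=42, buffer_size=10_000)
    dataset = dataset.map(
        encode,
        batched=True,
        input_columns="text",     # hand only the raw text column to encode()
        remove_columns=["text"],  # drop the source column; tokenized fields remain
    )
    print(next(iter(dataset)))  # streams one tokenized example without downloading the corpus
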
2 changes: 1 addition & 1 deletion tests/test_data.py
@@ -35,7 +35,7 @@ def test_encode_pretraining(self):
                 "hello, hello",
             ]
         }
-        result = encode_pretraining(self.tokenizer, self.max_tokens, examples)
+        result = encode_pretraining(self.tokenizer, self.max_tokens, examples["text"])
 
         self.assertEqual(len(result["input_ids"]), 3)
