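# prepare_slimpajama.py
# Tokenizes the SlimPajama dataset (zstd-compressed JSONL shards) with a lit_gpt
# Tokenizer and packs the token ids into binary chunks via PackedDatasetBuilder,
# splitting the work across one worker process per CPU core. Rows tagged as
# GitHub data are skipped.
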
import json
import glob
import os
from pathlib import Path
import sys
from typing import List, Optional

import numpy as np
from tqdm import tqdm
from multiprocessing import Process, cpu_count

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

import lit_gpt.packed_dataset as packed_dataset
from lit_gpt import Tokenizer

# Filename patterns for each SlimPajama split
slimpajama_sets = {
    "train": "train/chunk*/*",
    "validation": "validation/chunk*/*",
    "test": "test/chunk*/*",
}
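
# The patterns above assume the SlimPajama directory layout as distributed on
# Hugging Face (e.g. train/chunk1/example_train_0.jsonl.zst); adjust them if
# your local copy is organized differently.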

def prepare_full(
    source_path: Path,
    tokenizer_path: Path,
    destination_path: Path,
    chunk_size: int,
    split: str = "train",
    filenames_subset: Optional[List[str]] = None,
    process_id: int = 0,
) -> None:
    import zstandard as zstd

    destination_path.mkdir(parents=True, exist_ok=True)

    tokenizer = Tokenizer(tokenizer_path)

    # Use the provided filenames_subset; fail early if it is empty
    filenames = filenames_subset
    if not filenames:
        raise RuntimeError(
            f"No files matching {slimpajama_sets[split]} found at {source_path}.\n"
            "Make sure you have downloaded the data."
        )

    builder = packed_dataset.PackedDatasetBuilder(
        outdir=destination_path,
        prefix=f"{split}_slimpajama_{process_id}",  # use process_id to differentiate builders
        chunk_size=chunk_size,
        sep_token=tokenizer.bos_id,
        dtype="auto",
        vocab_size=tokenizer.vocab_size,
    )

    for filepath in filenames:
        print(f"Processing {filepath}")
        with zstd.open(open(filepath, "rb"), "rt", encoding="utf-8") as f:
            for row in tqdm(f):
                data = json.loads(row)  # parse each JSONL row once
                if data["meta"]["redpajama_set_name"] == "RedPajamaGithub":
                    continue  # we don't want to include the GitHub data
                text_ids = tokenizer.encode(data["text"])
                builder.add_array(np.array(text_ids, dtype=builder.dtype))

    builder.write_reminder()
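
# Note (inferred from the field accesses above, not stated elsewhere in this file):
# each source file is zstd-compressed JSONL whose rows look roughly like
#   {"text": "...", "meta": {"redpajama_set_name": "RedPajamaCommonCrawl", ...}}
# The builder writes packed token chunks whose filenames start with
# "{split}_slimpajama_{process_id}" into destination_path.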

def prepare(
    source_path: Path = Path("data/RedPajama-Data-1T-Sample"),
    tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
    destination_path: Path = Path("data/red_pajama_sample"),
    chunk_size: int = 2049 * 1024,  # room for 1024 blocks of 2049 tokens each
    split: str = "train",
    percentage: float = 1.0,
) -> None:
    import time

    filenames = glob.glob(os.path.join(source_path, slimpajama_sets[split]), recursive=True)
    filenames = filenames[: int(len(filenames) * percentage)]

    # Split the file list across one worker process per CPU core
    num_processes = cpu_count()
    chunked_filenames = np.array_split(filenames, num_processes)

    processes = []
    start_time = time.time()

    for i, subset in enumerate(chunked_filenames):
        p = Process(
            target=prepare_full,
            args=(source_path, tokenizer_path, destination_path, chunk_size, split, list(subset), i),
        )
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Time taken: {elapsed_time:.2f} seconds")

if __name__ == "__main__":
    from jsonargparse import CLI

    CLI(prepare)
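
# Illustrative invocation (paths are placeholders; jsonargparse derives the CLI
# flags from prepare()'s parameters):
#   python prepare_slimpajama.py \
#       --source_path data/SlimPajama-627B \
#       --tokenizer_path checkpoints/lit-llama/tokenizer.model \
#       --destination_path data/slimpajama_packed \
#       --split validation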