Commit 4c84bc7

Merge pull request karpathy#740 from karpathy/gordicaleksa-fix_dataloader2
Gordicaleksa fix dataloader2
2 parents 1787210 + 755458d, commit 4c84bc7

6 files changed, +164 -80 lines

dev/data/README.md (+2)

@@ -6,3 +6,5 @@ The idea is that each dataset has a .py file here in the root of `dev/data`, and
 - running `python tinyshakespeare.py` will create a directory `tinyshakespeare` with its .bin files inside it
 
 And so on. This way we can nicely organize multiple datasets here, share common utilities between them, and then point the .py/.c code in the root of the project accordingly to these.
+
+Note: we support "gpt-2" and "llama" (llama 3 in particular) models and the above scripts will tokenize gpt-2 by default.
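
In practice, each script selects its tokenizer via the new -m/--model_desc flag introduced in the diffs below. These are illustrative invocations only (gpt-2 remains the default when the flag is omitted):

    python dev/data/tinyshakespeare.py --model_desc=gpt-2
    python dev/data/tinyshakespeare.py --model_desc=llama-3
    python dev/data/tinystories.py --model_desc=llama-3
    python dev/data/fineweb.py --type=edu --version=10B --model_desc=llama-3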

dev/data/data_common.py (+25 -15)

@@ -23,28 +23,38 @@ def download_file(url: str, fname: str, chunk_size=1024):
             bar.update(size)
 
 
-def write_datafile(filename, toks):
+HEADERS_INFO = {
+    "gpt-2": {
+        "magic": 20240520,
+        "version": 1,
+        "token_dtype": np.uint16,
+    },
+    "llama-3": {
+        "magic": 20240801,
+        "version": 7,
+        "token_dtype": np.uint32,
+    },
+}
+
+def write_datafile(filename, toks, model_desc="gpt-2"):
     """
     Saves token data as a .bin file, for reading in C.
     - First comes a header with 256 int32s
-    - The tokens follow, each as a uint16
+    - The tokens follow, each as uint16 (gpt-2) or uint32 (llama)
     """
     assert len(toks) < 2**31, "token count too large" # ~2.1B tokens
+    assert model_desc in ["gpt-2", "llama-3"], f"unknown model descriptor {model_desc}"
+    info = HEADERS_INFO[model_desc]
     # construct the header
-    header = np.zeros(256, dtype=np.int32)
-    header[0] = 20240520 # magic
-    header[1] = 1 # version
-    header[2] = len(toks) # number of tokens after the 256*4 bytes of header (each 2 bytes as uint16)
-    # construct the tokens numpy array, if not already
-    if not isinstance(toks, np.ndarray) or not toks.dtype == np.uint16:
-        # validate that no token exceeds a uint16
-        maxtok = 2**16
-        assert all(0 <= t < maxtok for t in toks), "token dictionary too large for uint16"
-        toks_np = np.array(toks, dtype=np.uint16)
-    else:
-        toks_np = toks
+    header = np.zeros(256, dtype=np.int32) # header is always 256 int32 values
+    header[0] = info["magic"]
+    header[1] = info["version"]
+    header[2] = len(toks) # number of tokens after the 256*4 bytes of header
+    # construct the data (numpy array of tokens)
+    toks_np = np.array(toks, dtype=info["token_dtype"])
     # write to file
-    print(f"writing {len(toks):,} tokens to {filename}")
+    num_bytes = (256 * 4) + (len(toks) * toks_np.itemsize)
+    print(f"writing {len(toks):,} tokens to {filename} ({num_bytes:,} bytes) in the {model_desc} format")
     with open(filename, "wb") as f:
         f.write(header.tobytes())
         f.write(toks_np.tobytes())
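
Since write_datafile always emits a 256-int32 header whose magic number identifies the model format, a consumer can recover the token dtype from the file itself. A minimal read-back sketch (not part of the commit; it assumes the .bin was produced by write_datafile above):

    import numpy as np

    def read_datafile(filename):
        # mirror of write_datafile: 256 int32 header, then the raw token stream
        with open(filename, "rb") as f:
            header = np.frombuffer(f.read(256 * 4), dtype=np.int32)
            magic, num_tokens = header[0], header[2]
            if magic == 20240520:        # gpt-2 file
                token_dtype = np.uint16
            elif magic == 20240801:      # llama-3 file
                token_dtype = np.uint32
            else:
                raise ValueError(f"unknown magic number {magic}")
            tokens = np.frombuffer(f.read(), dtype=token_dtype)
        assert len(tokens) == num_tokens
        return tokens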

dev/data/fineweb.py (+41 -11)

@@ -22,18 +22,22 @@
 import os
 import argparse
 import multiprocessing as mp
+
 import numpy as np
 import tiktoken
 from datasets import load_dataset
 from tqdm import tqdm
-import argparse
+
+from transformers import AutoTokenizer
+
 
 from data_common import write_datafile
 # ------------------------------------------
 
 parser = argparse.ArgumentParser(description="FineWeb and Edu-FineWeb dataset preprocessing")
 parser.add_argument("-t", "--type", type=str, default="classic", help="Fineweb type, edu|classic")
 parser.add_argument("-v", "--version", type=str, default="10B", help="Fineweb data sample size, 10B|100B")
+parser.add_argument("-m", "--model_desc", type=str, default="gpt-2", help="Model descriptor, gpt-2|llama-3")
 parser.add_argument("-s", "--shard_size", type=int, default=10**8, help="Size of each data shard in the output .bin files, in tokens")
 args = parser.parse_args()
 
@@ -60,26 +64,52 @@
     fw = load_dataset("HuggingFaceFW/fineweb-edu", name=remote_name, split="train")
     name = "edu_fineweb"
 
-# init the tokenizer
-enc = tiktoken.get_encoding("gpt2")
-eot = enc._special_tokens['<|endoftext|>'] # end of text token
-def tokenize(doc):
+def tokenize_llama(doc):
+    # tokenizes a single document and returns a numpy array of uint32 tokens
+    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
+    encode = lambda s: tokenizer.encode(s, add_special_tokens=False, verbose=False, split_special_tokens=True)
+    eot = tokenizer.encode('')[0] # by default the tokenizer adds the EOT token (128000)
+    tokens = [eot] # the special <|endoftext|> token delimits all documents
+    tokens.extend(encode(doc["text"]))
+    tokens_np = np.array(tokens)
+    assert (0 <= tokens_np).all() and (tokens_np < 2**32).all(), "token dictionary too large for uint32"
+    tokens_np_uint = tokens_np.astype(np.uint32)
+    return tokens_np_uint
+
+def tokenize_gpt2(doc):
     # tokenizes a single document and returns a numpy array of uint16 tokens
+    enc = tiktoken.get_encoding("gpt2")
+    encode = lambda s: enc.encode_ordinary(s)
+    eot = enc._special_tokens['<|endoftext|>'] # end of text token
     tokens = [eot] # the special <|endoftext|> token delimits all documents
-    tokens.extend(enc.encode_ordinary(doc["text"]))
+    tokens.extend(encode(doc["text"]))
     tokens_np = np.array(tokens)
     assert (0 <= tokens_np).all() and (tokens_np < 2**16).all(), "token dictionary too large for uint16"
-    tokens_np_uint16 = tokens_np.astype(np.uint16)
-    return tokens_np_uint16
+    tokens_np_uint = tokens_np.astype(np.uint16)
+    return tokens_np_uint
+
+token_dtype = {
+    "gpt-2": np.uint16,
+    "llama-3": np.uint32
+}[args.model_desc]
 
 # tokenize all documents and write output shards, each of shard_size tokens (last shard has remainder)
 nprocs = max(1, os.cpu_count() - 2) # don't hog the entire system
 with mp.Pool(nprocs) as pool:
     shard_index = 0
     # preallocate buffer to hold current shard
-    all_tokens_np = np.empty((args.shard_size,), dtype=np.uint16)
+    all_tokens_np = np.empty((args.shard_size,), dtype=token_dtype)
     token_count = 0
     progress_bar = None
+
+    tokenize = lambda x: None
+    if args.model_desc == "gpt-2":
+        tokenize = tokenize_gpt2
+    elif args.model_desc == "llama-3":
+        tokenize = tokenize_llama
+    else:
+        raise ValueError(f"unknown model {args.model_desc}")
+
     for tokens in pool.imap(tokenize, fw, chunksize=16):
 
         # is there enough space in the current shard for the new tokens?
@@ -99,7 +129,7 @@ def tokenize(doc):
             remainder = args.shard_size - token_count
             progress_bar.update(remainder)
             all_tokens_np[token_count:token_count+remainder] = tokens[:remainder]
-            write_datafile(filename, all_tokens_np)
+            write_datafile(filename, all_tokens_np.tolist(), args.model_desc)
             shard_index += 1
            progress_bar = None
             # populate the next shard with the leftovers of the current doc
@@ -110,4 +140,4 @@ def tokenize(doc):
     if token_count != 0:
         split = "val" if shard_index == 0 else "train"
         filename = os.path.join(DATA_CACHE_DIR, f"{name}_{split}_{shard_index:06d}.bin")
-        write_datafile(filename, all_tokens_np[:token_count])
+        write_datafile(filename, (all_tokens_np[:token_count]).tolist(), args.model_desc)
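
The per-model token_dtype above is not arbitrary: GPT-2's vocabulary has 50,257 entries, so every id (including the 50256 end-of-text delimiter) fits in a uint16, while Llama 3's vocabulary is roughly 128K ids (the delimiter alone is 128000), which overflows uint16. A small illustrative check, not part of the commit and assuming only that tiktoken is installed:

    import numpy as np
    import tiktoken

    enc = tiktoken.get_encoding("gpt2")
    eot = enc._special_tokens['<|endoftext|>']   # 50256, the largest gpt-2 id
    ids = [eot] + enc.encode_ordinary("Hello world")
    assert max(ids) < 2**16                      # uint16 is enough for gpt-2 shards
    print(np.array(ids, dtype=np.uint16))
    # llama-3 ids can exceed 2**16 (e.g. the 128000 delimiter), hence np.uint32 there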

dev/data/tinyshakespeare.py (+41 -16)

@@ -6,25 +6,32 @@
 The output is written to a newly created tinyshakespeare/ folder.
 The script prints:
 
-Saved 32768 tokens to tinyshakespeare/tiny_shakespeare_val.bin
-Saved 305260 tokens to tinyshakespeare/tiny_shakespeare_train.bin
+For GPT-2:
+$ python dev/data/tinyshakespeare.py --model=gpt-2
+writing 32,768 tokens to /home/ubuntu/llm.c/dev/data/tinyshakespeare/tiny_shakespeare_val.bin (66,560 bytes) in the gpt-2 format
+writing 305,260 tokens to /home/ubuntu/llm.c/dev/data/tinyshakespeare/tiny_shakespeare_train.bin (611,544 bytes) in the gpt-2 format
+
+For LLaMA 3:
+$ python dev/data/tinyshakespeare.py --model=llama-3
+writing 32,768 tokens to /home/ubuntu/llm.c/dev/data/tinyshakespeare/tiny_shakespeare_val.bin (132,096 bytes) in the llama-3 format
+writing 276,224 tokens to /home/ubuntu/llm.c/dev/data/tinyshakespeare/tiny_shakespeare_train.bin (1,105,920 bytes) in the llama-3 format
 
 And runs in a few seconds depending on your internet
 connection and computer. The .bin files are raw byte
-streams of int32 numbers indicating the token ids.
+streams of uint16 (gpt-2) or uint32 (llama) numbers indicating the token ids.
 """
 
+import argparse
 import os
+
 import tiktoken
-import numpy as np
+from transformers import AutoTokenizer
+
 from data_common import download_file, write_datafile
 
 # -----------------------------------------------------------------------------
 DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), "tinyshakespeare")
 
-enc = tiktoken.get_encoding("gpt2")
-encode = lambda s: enc.encode(s, allowed_special={'<|endoftext|>'})
-
 def download():
     """Downloads the TinyShakespeare dataset to DATA_CACHE_DIR"""
     os.makedirs(DATA_CACHE_DIR, exist_ok=True)
@@ -37,23 +44,41 @@ def download():
     else:
         print(f"{data_filename} already exists, skipping download...")
 
-def tokenize():
+def tokenize(model_desc):
+    if model_desc == "gpt-2":
+        enc = tiktoken.get_encoding("gpt2")
+        encode = lambda s: enc.encode_ordinary(s)
+        eot = enc._special_tokens['<|endoftext|>'] # end of text token
+    elif model_desc == "llama-3":
+        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
+        encode = lambda s: tokenizer.encode(s, add_special_tokens=False, verbose=False, split_special_tokens=True)
+        eot = tokenizer.encode('')[0] # by default the tokenizer adds the EOT token (128000)
+    else:
+        raise ValueError(f"unknown model descriptor {model_desc}")
     data_filename = os.path.join(DATA_CACHE_DIR, "tiny_shakespeare.txt")
     text = open(data_filename, 'r').read()
-    # let's treat every person's statement in the dialog as a separate document
-    text = "<|endoftext|>" + text
-    text = text.replace('\n\n', '\n\n<|endoftext|>')
-    # encode the text
-    tokens = encode(text)
+    # let's treat every individual chunk of text as a separate "document"
+    sections = text.split("\n\n")
+    tokens = []
+    for i, s in enumerate(sections):
+        tokens.append(eot)
+        # there was a mild bug where I originally intended to remove \n\n, but instead just added
+        # the EOT right after each \n\n, so I'm keeping that behavior for backwards compatibility
+        # therefore we have to here add an extra \n\n at the end of each section, except the last
+        spad = s + "\n\n" if i != len(sections) - 1 else s
+        tokens.extend(encode(spad))
     # let's take the first 32,768 tokens as the validation split (~10%)
     val_tokens = tokens[:32768]
     train_tokens = tokens[32768:]
     # save to file
     val_filename = os.path.join(DATA_CACHE_DIR, "tiny_shakespeare_val.bin")
     train_filename = os.path.join(DATA_CACHE_DIR, "tiny_shakespeare_train.bin")
-    write_datafile(val_filename, val_tokens)
-    write_datafile(train_filename, train_tokens)
+    write_datafile(val_filename, val_tokens, model_desc)
+    write_datafile(train_filename, train_tokens, model_desc)
 
 if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Tiny Shakespeare dataset preprocessing")
+    parser.add_argument("-m", "--model_desc", type=str, default="gpt-2", choices=["gpt-2", "llama-3"], help="Model type, gpt-2|llama-3")
+    args = parser.parse_args()
    download()
-    tokenize()
+    tokenize(args.model_desc)
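
To make the backwards-compatibility comment in the loop above concrete, here is a toy walk-through of the new splitting logic with stand-in values (the eot id and encode function are placeholders, not either real tokenizer):

    eot = -1
    encode = lambda s: [s]     # pretend each section encodes to a single "token"

    text = "First speech\n\nSecond speech\n\nThird speech"
    sections = text.split("\n\n")
    tokens = []
    for i, s in enumerate(sections):
        tokens.append(eot)
        # every section except the last gets its "\n\n" re-appended, so the EOT still
        # lands right after each "\n\n" exactly as in the old (slightly buggy) stream
        spad = s + "\n\n" if i != len(sections) - 1 else s
        tokens.extend(encode(spad))
    print(tokens)
    # [-1, 'First speech\n\n', -1, 'Second speech\n\n', -1, 'Third speech']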

dev/data/tinystories.py (+37 -23)

@@ -1,38 +1,45 @@
 """
 Downloads and tokenizes the TinyStories dataset.
 - The download is from HuggingFace datasets.
-- The tokenization is GPT-2 tokenizer with tiktoken
+- The tokenization is using either GPT-2 or LLaMA 3 tokenizer.
 
 The output is written to a newly created tinystories/ folder.
 The script prints:
 
+For GPT-2:
+Number of shards: 50
 Tokenizing val split...
-Saved 19043638 tokens to tinystories/TinyStories_val.bin
+writing 19,043,638 tokens to tinystories/TinyStories_val.bin
 Tokenizing train split...
-Saved 925653391 tokens to tinystories/TinyStories_train.bin
+writing 925,653,391 tokens to tinystories/TinyStories_train.bin
 
-And runs in 1-2 minutes two depending on your internet
+For LLaMA 3:
+Number of shards: 50
+Tokenizing val split...
+writing 18,660,516 tokens to tinystories/TinyStories_val.bin
+Tokenizing train split...
+writing 907,021,844 tokens to tinystories/TinyStories_train.bin
+
+And runs in few minutes two depending on your internet
 connection and computer. The .bin files are raw byte
-streams of int32 numbers indicating the token ids.
+streams of uint16 (gpt-2) or uint32 (llama) numbers indicating the token ids.
 """
 
+import argparse
 import os
 import glob
 import json
 import random
-import requests
-from tqdm import tqdm
 from concurrent.futures import ProcessPoolExecutor, as_completed
+
 import tiktoken
-import numpy as np
+from transformers import AutoTokenizer
+
 from data_common import download_file, write_datafile
 
 # -----------------------------------------------------------------------------
 DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), "tinystories")
 
-enc = tiktoken.get_encoding("gpt2")
-encode = lambda s: enc.encode_ordinary(s)
-
 def download():
     """Downloads the TinyStories dataset to DATA_CACHE_DIR"""
     os.makedirs(DATA_CACHE_DIR, exist_ok=True)
@@ -63,10 +70,20 @@ def download():
     # data = json.load(f)
     # print(f"Example story:\n{data[0]}")
 
-def process_shard(shard_index, shard_filename):
+def process_shard(shard_index, shard_filename, model_desc):
+    if model_desc == "gpt-2":
+        enc = tiktoken.get_encoding("gpt2")
+        encode = lambda s: enc.encode_ordinary(s)
+        eot = enc._special_tokens['<|endoftext|>'] # end of text token
+    elif model_desc == "llama-3":
+        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
+        encode = lambda s: tokenizer.encode(s, add_special_tokens=False, verbose=False, split_special_tokens=True)
+        eot = tokenizer.encode('')[0] # by default the tokenizer adds the EOT token (128000)
+    else:
+        raise ValueError(f"unknown model descriptor {model_desc}")
+
     with open(shard_filename, "r") as f:
         data = json.load(f)
-    eot = enc._special_tokens['<|endoftext|>'] # end of text token
     rng = random.Random(1337 + shard_index)
     rng.shuffle(data)
     all_tokens = []
@@ -78,7 +95,7 @@ def process_shard(shard_index, shard_filename):
         all_tokens.extend(tokens)
     return all_tokens
 
-def tokenize():
+def tokenize(model_desc):
     # shard 0 will be the val split, rest is train
     data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
     shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
@@ -89,20 +106,17 @@ def tokenize():
         print(f"Tokenizing {split_name} split...")
         all_tokens = []
         with ProcessPoolExecutor() as executor:
-            futures = [executor.submit(process_shard, shard_index, shard_filename)
+            futures = [executor.submit(process_shard, shard_index, shard_filename, model_desc)
                        for shard_index, shard_filename in enumerate(split_shards)]
             for future in as_completed(futures):
                 all_tokens.extend(future.result())
 
         split_filename = os.path.join(DATA_CACHE_DIR, f"TinyStories_{split_name}.bin")
-        write_datafile(split_filename, all_tokens)
+        write_datafile(split_filename, all_tokens, model_desc)
 
 if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Tiny Stories dataset preprocessing")
+    parser.add_argument("-m", "--model_desc", type=str, default="gpt-2", choices=["gpt-2", "llama-3"], help="Model type, gpt-2|llama-3")
+    args = parser.parse_args()
     download()
-    tokenize()
-
-# Prints:
-# Tokenizing val split...
-# Saved 19043638 tokens to data/TinyStories_val.bin
-# Tokenizing train split...
-# Saved 925653391 tokens to data/TinyStories_train.bin
+    tokenize(args.model_desc)
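
One detail shared by all three llama-3 branches in this commit: the document delimiter is recovered as tokenizer.encode('')[0], relying on the Hugging Face tokenizer prepending its special start token by default. A quick sanity-check sketch, assuming transformers is installed and the gated meta-llama/Meta-Llama-3.1-8B tokenizer is accessible:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
    print(tokenizer.encode(''))                            # [128000], the id the scripts use as the delimiter
    print(tokenizer.encode('', add_special_tokens=False))  # [] -- with special tokens disabled, nothing is added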
