forked from karpathy/llm.c
-
Notifications
You must be signed in to change notification settings - Fork 2
/
prepro_tinystories.py
124 lines (107 loc) · 4.43 KB
/
prepro_tinystories.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
Downloads and tokenizes the TinyStories dataset.
- The download is from HuggingFace datasets.
- The tokenization is GPT-2 tokenizer with tiktoken
The output is written to a newly created data/ folder.
The script prints:
Tokenizing val split...
Saved 19043638 tokens to data/TinyStories_val.bin
Tokenizing train split...
Saved 925653391 tokens to data/TinyStories_train.bin
And runs in 1-2 minutes, depending on your internet
connection and computer. The .bin files are raw byte
streams of int32 numbers indicating the token ids.
"""
import os
import glob
import json
import random
import requests
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
import tiktoken
import numpy as np
# Directory that holds the downloaded archive, the unpacked JSON shards and
# the tokenized .bin outputs.
DATA_CACHE_DIR = "data"

# GPT-2 tokenizer, created once at module level; process_shard() reads it.
enc = tiktoken.get_encoding("gpt2")


def encode(s):
    """Return the GPT-2 token ids of string `s` (via `enc.encode_ordinary`)."""
    return enc.encode_ordinary(s)
def download_file(url: str, fname: str, chunk_size=1024):
    """Download `url` to the local path `fname`, showing a tqdm progress bar.

    The response body is streamed in `chunk_size`-byte pieces into
    `fname + ".part"`; that file is renamed to `fname` only after the whole
    body has been written, so an interrupted transfer never leaves a
    truncated file under the final name (callers use the existence of
    `fname` to decide whether to download at all).

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
    """
    tmp_fname = fname + ".part"
    # The context manager releases the connection even if writing fails.
    with requests.get(url, stream=True) as resp:
        # Fail loudly instead of saving an HTML error page as the archive.
        resp.raise_for_status()
        # Servers may omit content-length; tqdm then gets a total of 0.
        total = int(resp.headers.get("content-length", 0))
        with open(tmp_fname, "wb") as file, tqdm(
            desc=fname,
            total=total,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for data in resp.iter_content(chunk_size=chunk_size):
                size = file.write(data)
                bar.update(size)
    # Publish the file under its final name only once it is complete.
    os.replace(tmp_fname, fname)
def download():
    """Download and unpack the TinyStories dataset into DATA_CACHE_DIR.

    Both steps are skipped when their output already exists: the archive is
    fetched only if `TinyStories_all_data.tar.gz` is missing, and it is
    unpacked only if the `TinyStories_all_data` directory is missing.

    Raises:
        subprocess.CalledProcessError: if `tar` exits with a non-zero status.
        FileNotFoundError: if no `*.json` shard is present after unpacking.
    """
    # Local imports: these modules are not part of the file-level imports.
    import shutil
    import subprocess

    os.makedirs(DATA_CACHE_DIR, exist_ok=True)

    # download the TinyStories dataset, unless it's already downloaded
    data_url = "https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz"
    data_filename = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data.tar.gz")
    if not os.path.exists(data_filename):
        print(f"Downloading {data_url} to {data_filename}...")
        download_file(data_url, data_filename)
    else:
        print(f"{data_filename} already exists, skipping download...")

    # unpack the tar.gz file into all the data shards (json files)
    data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
    if not os.path.exists(data_dir):
        os.makedirs(data_dir, exist_ok=True)
        print(f"Unpacking {data_filename}...")
        try:
            # Argument list (no shell) so paths are never word-split, and
            # check=True so a failed extraction raises instead of passing
            # silently.
            subprocess.run(["tar", "-xzf", data_filename, "-C", data_dir], check=True)
        except BaseException:
            # Remove the partial directory; otherwise the existence check
            # above would skip unpacking on every later run.
            shutil.rmtree(data_dir, ignore_errors=True)
            raise
    else:
        print(f"{data_dir} already exists, skipping unpacking...")

    shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
    if not shard_filenames:
        raise FileNotFoundError(f"No *.json shards found in {data_dir}")
    print("Download done.")
    print(f"Number of shards: {len(shard_filenames)}")
def process_shard(shard_index, shard_filename):
    """Tokenize one JSON shard into a flat list of token ids.

    Args:
        shard_index: integer added to the base seed 1337, so the shuffle of
            this shard is reproducible for a given index.
        shard_filename: path to a JSON file containing a list of objects,
            each with a "story" text field.

    Returns:
        A list of ints: for every story (in shuffled order) the end-of-text
        token followed by the tokens of the stripped story text.
    """
    # Explicit UTF-8 so decoding does not depend on the platform's locale.
    with open(shard_filename, "r", encoding="utf-8") as f:
        data = json.load(f)
    eot = enc.eot_token  # end of text token, placed before each story
    rng = random.Random(1337 + shard_index)
    rng.shuffle(data)
    all_tokens = []
    for example in data:
        text = example["story"].strip()  # get rid of leading/trailing whitespace
        all_tokens.append(eot)
        all_tokens.extend(encode(text))
    return all_tokens
def tokenize():
    """Tokenize the unpacked shards and write one .bin file per split.

    The first shard (in sorted filename order) is the val split and the
    remaining shards are the train split. Each split is written to
    `DATA_CACHE_DIR/TinyStories_<split>.bin` as a raw stream of int32 token
    ids, with shards concatenated in sorted filename order so repeated runs
    produce identical files.

    Raises:
        FileNotFoundError: if the shard directory contains no `*.json` file.
    """
    # shard 0 will be the val split, rest is train
    data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
    shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
    if not shard_filenames:
        raise FileNotFoundError(f"No *.json shards found in {data_dir}")
    val_shards = shard_filenames[:1]
    train_shards = shard_filenames[1:]

    for split_name, split_shards in [("val", val_shards), ("train", train_shards)]:
        print(f"Tokenizing {split_name} split...")
        shard_arrays = []
        with ProcessPoolExecutor() as executor:
            # executor.map yields results in submission order, unlike
            # as_completed, so the output does not depend on which worker
            # happens to finish first.
            shard_indices = range(len(split_shards))
            for shard_tokens in executor.map(process_shard, shard_indices, split_shards):
                # Convert per shard so the full split is never held as one
                # giant list of Python int objects.
                shard_arrays.append(np.asarray(shard_tokens, dtype=np.int32))
        if shard_arrays:
            all_tokens_np = np.concatenate(shard_arrays)
        else:
            # np.concatenate rejects an empty sequence; write an empty file.
            all_tokens_np = np.empty(0, dtype=np.int32)
        split_filename = os.path.join(DATA_CACHE_DIR, f"TinyStories_{split_name}.bin")
        with open(split_filename, "wb") as f:
            # tofile writes the same bytes as tobytes() without the extra copy.
            all_tokens_np.tofile(f)
        print(f"Saved {len(all_tokens_np)} tokens to {split_filename}")
if __name__ == "__main__":
    # Run the pipeline stages in order: fetch and unpack the shards first,
    # then tokenize them into the .bin files.
    for stage in (download, tokenize):
        stage()
# Prints:
# Tokenizing val split...
# Saved 19043638 tokens to data/TinyStories_val.bin
# Tokenizing train split...
# Saved 925653391 tokens to data/TinyStories_train.bin