-
Notifications
You must be signed in to change notification settings - Fork 13
/
prepro_tinyshakespeare.py
82 lines (71 loc) · 2.92 KB
/
prepro_tinyshakespeare.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""
Downloads and tokenizes the TinyShakespeare dataset.
- The text is downloaded from GitHub.
- The tokenization uses the GPT-2 tokenizer via tiktoken.
The output is written to a newly created data/ folder.
The script prints:
Saved 32768 tokens to data/tiny_shakespeare_val.bin
Saved 305260 tokens to data/tiny_shakespeare_train.bin
And runs in a few seconds depending on your internet
connection and computer. The .bin files are raw byte
streams of int32 numbers indicating the token ids.
"""
import os
import requests
from tqdm import tqdm
import tiktoken
import numpy as np
# Directory that holds the downloaded text and the generated .bin files.
DATA_CACHE_DIR = "data"

# GPT-2 encoding object obtained from tiktoken; shared by the whole script.
enc = tiktoken.get_encoding("gpt2")


def encode(s):
    """Return the GPT-2 token ids for the string ``s``.

    ``<|endoftext|>`` is passed in ``allowed_special`` so that occurrences of
    that marker inside ``s`` are accepted by the encoder as the special token.
    """
    return enc.encode(s, allowed_special={'<|endoftext|>'})
def download_file(url: str, fname: str, chunk_size=1024, timeout=30):
    """Download ``url`` to the local path ``fname`` with a tqdm progress bar.

    The response body is streamed in pieces of ``chunk_size`` bytes into a
    temporary ``fname + ".part"`` file, which is renamed to ``fname`` only
    after the whole transfer succeeded. An interrupted or failed download
    therefore never leaves a truncated file at ``fname``.

    Args:
        url: Address to fetch.
        fname: Destination path; replaced if it already exists.
        chunk_size: Number of bytes requested per streamed chunk.
        timeout: Value forwarded to ``requests.get`` as its ``timeout``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    part_name = fname + ".part"
    try:
        # The context manager releases the connection even if streaming fails.
        with requests.get(url, stream=True, timeout=timeout) as resp:
            # Fail loudly instead of saving an HTTP error page as the dataset.
            resp.raise_for_status()
            # A missing content-length header falls back to 0 for the bar total.
            total = int(resp.headers.get("content-length", 0))
            with open(part_name, "wb") as file, tqdm(
                desc=fname,
                total=total,
                unit="iB",
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for data in resp.iter_content(chunk_size=chunk_size):
                    size = file.write(data)
                    bar.update(size)
    except BaseException:
        # Best-effort removal of the partial file, then re-raise the real error
        # (BaseException so that Ctrl-C also cleans up).
        try:
            os.remove(part_name)
        except OSError:
            pass
        raise
    # Publish the finished download under its final name.
    os.replace(part_name, fname)
def download():
    """Fetch the TinyShakespeare text into DATA_CACHE_DIR unless it is already there."""
    os.makedirs(DATA_CACHE_DIR, exist_ok=True)

    source_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    target_path = os.path.join(DATA_CACHE_DIR, "tiny_shakespeare.txt")

    # Nothing to do when a previous run already stored the file.
    if os.path.exists(target_path):
        print(f"{target_path} already exists, skipping download...")
        return

    print(f"Downloading {source_url} to {target_path}...")
    download_file(source_url, target_path)
def tokenize():
    """Tokenize the downloaded text and write the validation/train .bin files.

    Reads ``tiny_shakespeare.txt`` from DATA_CACHE_DIR, inserts an
    ``<|endoftext|>`` marker at the start of the text and after every blank
    line, encodes the result with ``encode`` and stores the token ids as
    int32. The first 32,768 tokens go to ``tiny_shakespeare_val.bin`` and the
    remainder to ``tiny_shakespeare_train.bin``; both files contain the raw
    bytes of the int32 arrays. The number of tokens written to each file is
    printed.
    """
    data_filename = os.path.join(DATA_CACHE_DIR, "tiny_shakespeare.txt")
    # Context manager closes the handle; explicit encoding avoids depending on
    # the platform's default text encoding.
    with open(data_filename, 'r', encoding='utf-8') as f:
        text = f.read()
    # let's treat every blank-line-separated passage as a separate document
    text = "<|endoftext|>" + text
    text = text.replace('\n\n', '\n\n<|endoftext|>')
    # encode the text
    tokens = encode(text)
    tokens_np = np.array(tokens, dtype=np.int32)
    # let's take the first 32,768 tokens as the validation split (~10%)
    val_tokens_np = tokens_np[:32768]
    train_tokens_np = tokens_np[32768:]
    # save to file
    val_filename = os.path.join(DATA_CACHE_DIR, "tiny_shakespeare_val.bin")
    train_filename = os.path.join(DATA_CACHE_DIR, "tiny_shakespeare_train.bin")
    with open(val_filename, "wb") as f:
        f.write(val_tokens_np.tobytes())
    with open(train_filename, "wb") as f:
        f.write(train_tokens_np.tobytes())
    # prints
    print(f"Saved {len(val_tokens_np)} tokens to {val_filename}")
    print(f"Saved {len(train_tokens_np)} tokens to {train_filename}")
if __name__ == "__main__":
    # Script entry point: make sure the raw text is on disk, then tokenize it.
    download()
    tokenize()