-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrainTestSplit.py
34 lines (27 loc) · 998 Bytes
/
trainTestSplit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import os
import random
import shutil
def make_test_set(train_dir: str, test_dir: str) -> None:
    """Build a test set by copying a random 20% of each digit class.

    For every digit ``0``-``9`` the regular files in ``train_dir/<digit>``
    are listed, and one file per five (rounded up) is drawn at random
    *without replacement* and copied to ``test_dir/<digit>``.

    Files are copied, not removed, so the train set still contains every
    selected file after this function returns.

    Args:
        train_dir: The directory containing the train set, with one
            sub-directory per digit named ``"0"`` through ``"9"``.
        test_dir: The directory to store the test set. It and its
            per-digit sub-directories are created if missing.

    Raises:
        FileNotFoundError: If ``train_dir/<digit>`` does not exist for
            some digit.
    """
    os.makedirs(test_dir, exist_ok=True)

    for digit in range(10):
        digit_dir = os.path.join(train_dir, str(digit))
        test_digit_dir = os.path.join(test_dir, str(digit))
        os.makedirs(test_digit_dir, exist_ok=True)

        # Only regular files can be copied; sorting makes the selection
        # reproducible when the caller seeds ``random``.
        files = sorted(
            name
            for name in os.listdir(digit_dir)
            if os.path.isfile(os.path.join(digit_dir, name))
        )

        # One pick per started group of five files (indices 0, 5, 10, ...).
        sample_size = (len(files) + 4) // 5

        # ``random.sample`` never repeats an element, so the test set holds
        # exactly ``sample_size`` distinct files.
        for file_to_copy in random.sample(files, sample_size):
            # Copy through the standard library instead of a shell command so
            # names with spaces or shell metacharacters are handled safely
            # and failures raise instead of being silently ignored.
            shutil.copy(
                os.path.join(digit_dir, file_to_copy),
                os.path.join(test_digit_dir, file_to_copy),
            )
            print(f"Copied {file_to_copy} to {test_digit_dir}")
if __name__ == "__main__":
    # Script entry point: split the default dataset layout, reading the
    # train set from data/train and writing the test set to data/test.
    train_dir, test_dir = "data/train", "data/test"
    make_test_set(train_dir=train_dir, test_dir=test_dir)
    print("Done!")