forked from NVlabs/dlinputs
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dli-testsplit
executable file
·41 lines (31 loc) · 1.14 KB
/
dli-testsplit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#!/usr/bin/python3
import argparse
import imp
import random as pyr
import re
import dlinputs as dli
from dlinputs import shardwriter
# Split one tar archive of samples into a training archive and a testing
# archive. Each sample is sent to the testing archive with probability
# --probability and to the training archive otherwise.
parser = argparse.ArgumentParser(
    # Passed by keyword: the first positional parameter of ArgumentParser is
    # `prog`, which would put this sentence into the usage line as the
    # program name.
    description="Split a tar file into test and training sets."
)
parser.add_argument(
    "-p", "--probability", type=float, default=0.1,
    help="probability that a sample is written to the testing set",
)
parser.add_argument("input", help="tar archive to split")
parser.add_argument(
    "training", nargs="?", default=None,
    help="output archive for training samples (derived from input if omitted)",
)
parser.add_argument(
    "testing", nargs="?", default=None,
    help="output archive for testing samples (derived from input if omitted)",
)
args = parser.parse_args()

# When an output name is omitted, derive it from the input name by inserting
# "-train" / "-test" ahead of an optional "-<digits>" part and the ".tgz"
# suffix, e.g. "data-000.tgz" -> "data-train-000.tgz".
if args.training is None:
    args.training = re.sub(r"((?:-[0-9]+)?\.tgz)$", r"-train\1", args.input)
if args.testing is None:
    args.testing = re.sub(r"((?:-[0-9]+)?\.tgz)$", r"-test\1", args.input)

print("splitting", args.input, "to", args.training, ",", args.testing)

# If the input name does not end in ".tgz" the substitutions above leave it
# unchanged, so an output could alias the input. Refuse explicitly instead of
# using `assert`, which is stripped when Python runs with -O.
if args.training == args.input:
    parser.error("training output must differ from the input: " + args.input)
if args.testing == args.input:
    parser.error("testing output must differ from the input: " + args.input)

training = shardwriter.TarWriter(args.training)
try:
    testing = shardwriter.TarWriter(args.testing)
    try:
        for i, sample in enumerate(dli.ittarreader(args.input)):
            # Progress indicator: report every 1000th sample index.
            if i % 1000 == 0:
                print(i)
            key = sample["__key__"]
            if pyr.random() < args.probability:
                testing.write(key, sample)
            else:
                training.write(key, sample)
    finally:
        # finish() presumably flushes and closes the archive (TODO confirm
        # against shardwriter.TarWriter); call it even when the loop raises.
        testing.finish()
finally:
    training.finish()