-
Notifications
You must be signed in to change notification settings - Fork 4
/
train.py
59 lines (46 loc) · 1.74 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""Main training script for squeezeDet
Usage:
train.py [options]
Options:
--data=<path> Path to root directory of data. [default: ./data/KITTI]
--batch_size=<size> Number of samples in a single batch [default: 20]
"""
import docopt
import mxnet as mx
import os.path
import time
from mxnet import metric
from squeezeDetMX.model import SqueezeDet
from squeezeDetMX.model import BboxError
from squeezeDetMX.model import ClassError
from squeezeDetMX.model import IOUError
from squeezeDetMX.utils import Reader
from squeezeDetMX.utils import build_module
from squeezeDetMX.utils import setup_logger
def main():
    """Train squeezeDet from the brick files under the ``--data`` directory.

    Command-line options are parsed by docopt from this module's docstring.
    ``train.brick`` and ``val.brick`` are opened with ``Reader`` using the
    requested batch size, a module is built from ``SqueezeDet().error`` on
    GPU contexts 0 through 3, and ``fit`` is run with ``num_epoch=50``.

    A ``KeyboardInterrupt`` raised during fitting is caught and the
    module's parameters are written to a ``squeezeDet-<suffix>-9999.params``
    file before the function returns.
    """
    setup_logger()
    options = docopt.docopt(__doc__)
    root = options['--data']
    samples_per_batch = int(options['--batch_size'])

    # Readers are created train-first, then validation; only the training
    # reader is wrapped in a prefetching iterator.
    train_iter = Reader(
        os.path.join(root, 'train.brick'), batch_size=samples_per_batch)
    val_iter = Reader(
        os.path.join(root, 'val.brick'), batch_size=samples_per_batch)
    prefetched_train = mx.io.PrefetchingIter([train_iter])

    net = SqueezeDet()
    net_error = net.error
    devices = [mx.gpu(index) for index in range(4)]
    # build_module receives the raw training reader, not the prefetching
    # wrapper that is later handed to fit().
    module = build_module(net_error, 'squeezeDetMX', train_iter, ctx=devices)

    try:
        speedometer = mx.callback.Speedometer(samples_per_batch, 10)
        error_metrics = metric.CompositeEvalMetric(
            metrics=[BboxError(), ClassError(), IOUError()])
        checkpoint = mx.callback.do_checkpoint('squeezeDetMX', 1)
        module.fit(
            train_data=prefetched_train,
            eval_data=val_iter,
            num_epoch=50,
            batch_end_callback=speedometer,
            eval_metric=error_metrics,
            epoch_end_callback=checkpoint)
    except KeyboardInterrupt:
        # The last five characters of the stringified timestamp give each
        # interrupted run its own file name.
        suffix = str(time.time())[-5:]
        module.save_params('squeezeDet-{}-9999.params'.format(suffix))
# Start training only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()