-
Notifications
You must be signed in to change notification settings - Fork 403
/
train.py
executable file
·131 lines (99 loc) · 3.92 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""
Trains, evaluates and saves the KittiSeg model.
-------------------------------------------------
The MIT License (MIT)
Copyright (c) 2017 Marvin Teichmann
More details: https://github.com/MarvinTeichmann/KittiSeg/blob/master/LICENSE
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import commentjson
import logging
import os
import sys
import collections
# The abstract base classes live in ``collections.abc`` on Python 3 (the
# aliases in ``collections`` were removed in 3.10); Python 2 only has them
# in ``collections``.
try:
    from collections.abc import Mapping as _Mapping
except ImportError:
    from collections import Mapping as _Mapping


def dict_merge(dct, merge_dct):
    """Recursively merge ``merge_dct`` into ``dct`` in place.

    Inspired by :meth:``dict.update()``: instead of updating only top-level
    keys, dict_merge recurses down into dicts nested to an arbitrary depth,
    updating keys.

    A value is merged recursively only when the existing value in ``dct`` is
    a ``dict`` and the incoming value is a mapping; in every other case the
    incoming value replaces (or adds) the entry in ``dct``.

    :param dct: dict onto which the merge is executed (modified in place)
    :param merge_dct: dct merged into dct
    :return: None
    """
    # ``items()`` exists on both Python 2 and 3; ``iteritems()`` is
    # Python 2 only.
    for key, value in merge_dct.items():
        if (key in dct and isinstance(dct[key], dict) and
                isinstance(value, _Mapping)):
            dict_merge(dct[key], value)
        else:
            dct[key] = value
# Logging setup. The TV_IS_DEV environment variable selects the branch, but
# both branches currently apply the same settings: INFO level, written to
# stdout, with timestamp and level name prefixed to each message.
if os.environ.get('TV_IS_DEV'):
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format='%(asctime)s %(levelname)s %(message)s')
else:
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format='%(asctime)s %(levelname)s %(message)s')
# https://github.com/tensorflow/tensorflow/issues/2034#issuecomment-220820070
import numpy as np
# Shorthand handles for TensorFlow's command-line flag module and its flag
# value container.
flags = tf.app.flags
FLAGS = flags.FLAGS
# Add the relative path 'incl' to the module search path (at position 1).
# NOTE(review): presumably this is where the tensorvision package imported
# right below is located -- confirm against the repository layout.
sys.path.insert(1, 'incl')
import tensorvision.train as train
import tensorvision.utils as utils
# Command-line options; each one defaults to None when not given.
flags.DEFINE_string('name', None,
                    'Append a name Tag to run.')
# NOTE(review): the help text below is identical to that of --name, which
# looks like a copy-paste leftover -- confirm what --project is used for.
flags.DEFINE_string('project', None,
                    'Append a name Tag to run.')
# Path of the hypes file; main() refuses to run without it.
flags.DEFINE_string('hypes', None,
                    'File storing model parameters.')
# main() parses this value with ast.literal_eval and merges the result into
# the loaded hypes.
flags.DEFINE_string('mod', None,
                    'Modifier for model parameters.')
# Register the boolean 'save' flag. The TV_SAVE environment variable picks
# the branch, but both branches register the flag with a default of True; the
# two help texts differ only by a comma after 'TV_DIR_RUNS/debug'.
# NOTE(review): the help text calls --nosave the default although the default
# value passed here is True -- confirm which one is intended.
if os.environ.get('TV_SAVE'):
    tf.app.flags.DEFINE_boolean(
        'save',
        True,
        'Whether to save the run. In case --nosave (default) output will be '
        'saved to the folder TV_DIR_RUNS/debug, hence it will get '
        'overwritten by further runs.')
else:
    tf.app.flags.DEFINE_boolean(
        'save',
        True,
        'Whether to save the run. In case --nosave (default) output will be '
        'saved to the folder TV_DIR_RUNS/debug hence it will get '
        'overwritten by further runs.')
def main(_):
    """Run a training session configured through the command-line flags.

    Steps, in order:

    1. Let tensorvision select the GPUs (``utils.set_gpus_to_use``).
    2. Verify that the ``tensorvision`` and ``tensorflow_fcn`` submodules can
       be imported.
    3. Load the hypes file given by ``--hypes`` with ``commentjson`` and, if
       ``--mod`` was given, merge its literal-evaluated value into the hypes.
    4. If ``TV_DIR_RUNS`` is set, append ``KittiSeg`` to it, then let
       tensorvision set up directories and paths.
    5. Call ``train.maybe_download_and_extract``,
       ``train.initialize_training_folder`` and ``train.do_training``.

    Exits the process with status 1 when the submodules cannot be imported
    or when no ``--hypes`` file was given.

    :param _: unused; presumably the argument list handed over by
        ``tf.app.run`` (TODO confirm)
    :return: None
    """
    utils.set_gpus_to_use()

    try:
        import tensorvision.train
        import tensorflow_fcn.utils
    except ImportError:
        logging.error("Could not import the submodules.")
        logging.error("Please execute:"
                      "'git submodule update --init --recursive'")
        # sys.exit instead of the builtin exit(): the latter is injected by
        # the ``site`` module for interactive use and is not always defined.
        sys.exit(1)

    if tf.app.flags.FLAGS.hypes is None:
        logging.error("No hype file is given.")
        # NOTE(review): the example names KittiClass.json although this
        # script trains KittiSeg -- confirm the intended example file name.
        logging.info("Usage: python train.py --hypes hypes/KittiClass.json")
        sys.exit(1)

    with open(tf.app.flags.FLAGS.hypes, 'r') as f:
        logging.info("f: %s", f)
        hypes = commentjson.load(f)
    utils.load_plugins()

    if tf.app.flags.FLAGS.mod is not None:
        import ast
        # literal_eval only accepts Python literals, so the --mod string is
        # parsed without executing arbitrary code.
        mod_dict = ast.literal_eval(tf.app.flags.FLAGS.mod)
        dict_merge(hypes, mod_dict)

    if 'TV_DIR_RUNS' in os.environ:
        # Redirect the runs directory into a 'KittiSeg' subfolder.
        os.environ['TV_DIR_RUNS'] = os.path.join(os.environ['TV_DIR_RUNS'],
                                                 'KittiSeg')
    utils.set_dirs(hypes, tf.app.flags.FLAGS.hypes)

    utils._add_paths_to_sys(hypes)

    train.maybe_download_and_extract(hypes)
    logging.info("Initialize training folder")
    train.initialize_training_folder(hypes)
    logging.info("Start training")
    train.do_training(hypes)
# Script entry point: start the TensorFlow app runner only when this file is
# executed directly, not when it is imported as a module.
# NOTE(review): tf.app.run() presumably parses the command-line flags and
# then calls main() -- confirm against the TensorFlow version in use.
if __name__ == '__main__':
    tf.app.run()