-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmain.py
60 lines (46 loc) · 2.24 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from __future__ import absolute_import, division, print_function
from absl import app, flags, logging
from runs.train_sigua import train
from runs.train_sigua import train
from runs.preprocess import preprocess
import sys
import pprint
# Command-line flags. main() reads FLAGS.run_mode / FLAGS.model and passes the
# whole FLAGS object on to the train / preprocess entry points.
FLAGS = flags.FLAGS

# dataset
flags.DEFINE_string('dataset', 'MNIST', 'the name of dataset, available: MNIST (image), TREC (text)')
flags.DEFINE_string('datapath', 'data/', 'the dataset path to be downloaded')
# Default is a float literal to match the flag's declared type.
flags.DEFINE_float('valid_ratio', 0.1, 'validation ratio out of total dataset')

# model parameters
flags.DEFINE_float('drop_rate', 0.5, 'dropout settings')

# training parameters
flags.DEFINE_integer('epochs', 30, 'the number of epochs for training')
flags.DEFINE_float('lr', 0.001, 'learning rate')
flags.DEFINE_integer('batch_size', 256, 'the number of batch for training')
flags.DEFINE_integer('num_class', 10, 'the number of class (category) in training data')
flags.DEFINE_integer('stop_patience', 3, 'the number of patience for early stopping')

# noisy parameters
flags.DEFINE_float('tau', 0.2, 'the estimated noise ratio')
# NOTE(review): the help text previously claimed "default: 15" while the actual
# default is 5; the help now states the real default — confirm which is intended.
flags.DEFINE_integer('num_gradual', 5, 'the number of gradual step (T_k = 5, 10, 15), default: 5')
flags.DEFINE_float('bad_weight', 0.001, 'the control factor for bad samples (-1.0 * bad_weight), default: 0.001')
flags.DEFINE_float('noise_prob', 0.2, 'noise probability in training data')
flags.DEFINE_string('noise_type', 'sym', 'noise type (sym, asym), default: sym')

# misc
# NOTE(review): help text is empty; document how the runs/* code uses this flag.
flags.DEFINE_bool('gpu', True, '')
# main() only implements the 'train' and 'preprocess' modes.
flags.DEFINE_string('run_mode', 'preprocess', 'current mode (train, preprocess)')
flags.DEFINE_string('model', 'sigua', 'training model type (sigua, normal), default: sigua')
flags.DEFINE_string('save_dir', 'pretrained/', 'the path of directory for trained models')
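# 94.4  NOTE(review): unexplained figure (possibly a recorded result for these settings) — document its source or remove.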
def main(argv):
    """Entry point called by ``app.run`` once the command-line flags are parsed.

    Logs the Python version and pretty-prints all flag values, then dispatches
    on ``--run_mode``:

    * ``train``: run the trainer selected by ``--model`` (``normal`` or
      ``sigua``) with ``FLAGS``.
    * ``preprocess``: run ``preprocess(FLAGS)``.

    Args:
      argv: Positional arguments left over after flag parsing; unused.

    Raises:
      app.UsageError: if ``--run_mode`` or ``--model`` has an unsupported value,
        instead of silently doing nothing or crashing with UnboundLocalError.
    """
    del argv  # Unused.
    logging.info('Running under Python {0[0]}.{0[1]}.{0[2]}'.format(sys.version_info))
    pprint.pprint(FLAGS.flag_values_dict(), indent=4)

    if FLAGS.run_mode == 'train':
        # Bind the trainer under a distinct local name. Importing it as ``train``
        # would make ``train`` local to this whole function, so an unknown
        # --model used to fail with UnboundLocalError at the call below.
        if FLAGS.model == 'normal':
            from runs.train_normal import train as train_fn
        elif FLAGS.model == 'sigua':
            from runs.train_sigua import train as train_fn
        else:
            raise app.UsageError(
                'Unknown --model {!r}; expected one of: normal, sigua'.format(FLAGS.model))
        train_fn(FLAGS)
    elif FLAGS.run_mode == 'preprocess':
        preprocess(FLAGS)
    else:
        # No other modes (e.g. 'eval') are implemented in this script; fail
        # loudly rather than exiting successfully without doing anything.
        raise app.UsageError(
            'Unsupported --run_mode {!r}; expected one of: train, preprocess'.format(FLAGS.run_mode))
if __name__ == '__main__':
    # Hand control to absl's app.run, which parses the command-line flags and
    # then calls main(argv).
    app.run(main)