forked from luyiyun/NormAE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
config.py
154 lines (141 loc) · 5.33 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import json
import argparse
class Config:
    """Command-line configuration for NormAE training / batch-effect removal.

    Every option is declared on an ``argparse.ArgumentParser`` and the
    command line is parsed immediately on construction; the resulting
    namespace is stored in ``self.args``.
    """

    def __init__(self, args=None):
        """Declare all options and parse the command line.

        Args:
            args: optional list of argument strings to parse instead of
                ``sys.argv[1:]`` (handy for tests or programmatic use).
                ``None`` (default) keeps the original behaviour of reading
                the real command line.
        """
        self.parser = argparse.ArgumentParser()
        # task
        self.parser.add_argument(
            "--task", default="train",
            help=("task, train model (train, default) or "
                  "remove batch effects (remove)")
        )
        # dataset
        self.parser.add_argument(
            "--meta_data", help="the path of metabolomics data"
        )
        self.parser.add_argument(
            "--sample_data", help="the path of sample information"
        )
        self.parser.add_argument(
            '-td', '--train_data', default='all',
            help="the training data, subject or all (default)"
        )
        # save results
        self.parser.add_argument(
            '-s', '--save', default='./save',
            help='the path to save results, default ./save'
        )
        # architecture
        self.parser.add_argument(
            '--ae_encoder_units', default=[1000, 1000], type=int, nargs="+",
            help="the hidden units of encoder, default 1000, 1000"
        )
        self.parser.add_argument(
            '--ae_decoder_units', default=[1000, 1000], type=int, nargs="+",
            help="the hidden units of decoder, default 1000, 1000"
        )
        self.parser.add_argument(
            "--disc_b_units", default=[250, 250], type=int, nargs="+",
            help="the hidden units of disc_b, default 250, 250"
        )
        self.parser.add_argument(
            "--disc_o_units", default=[250, 250], type=int, nargs="+",
            help="the hidden units of disc_o, default 250, 250"
        )
        self.parser.add_argument(
            '--bottle_num', type=int, default=500,
            help="the number of bottle neck units, default 500"
        )
        self.parser.add_argument(
            '--dropouts', default=(0.3, 0.1, 0.3, 0.3), type=float, nargs=4,
            help=("the dropout rates of encoder, decoder, disc_b, disc_o, "
                  "default 0.3, 0.1, 0.3, 0.3")
        )
        # regularization
        self.parser.add_argument(
            '--lambda_b', type=float, default=1.0,
            help="the weight of adversarial loss for batch labels, default 1"
        )
        self.parser.add_argument(
            '--lambda_o', type=float, default=1.0,
            help=("the weight of adversarial loss for injection order,"
                  " default 1")
        )
        # training
        self.parser.add_argument(
            "--lr_rec", type=float, default=0.0002,
            help="the learning rate of AE training, default 0.0002"
        )
        self.parser.add_argument(
            "--lr_disc_b", type=float, default=0.005,
            help="the learning rate of disc_b training, default 0.005"
        )
        self.parser.add_argument(
            "--lr_disc_o", type=float, default=0.0005,
            help="the learning rate of disc_o training, default 0.0005"
        )
        self.parser.add_argument(
            '-e', '--epoch', default=(1000, 10, 700), type=int, nargs=3,
            help=("ae pretrain, disc pretrain, "
                  "iteration train epochs, default (1000, 10, 700)")
        )
        # ``type=bool`` would turn any non-empty string (including "False")
        # into True, so parse the text explicitly instead.
        self.parser.add_argument(
            '--use_batch_for_order', default=True, type=self._str2bool,
            help=("whether to compute rank loss with batch (true/false), "
                  "default True")
        )
        self.parser.add_argument(
            '-bs', '--batch_size', default=64, type=int,
            help='batch size, default 64'
        )
        # other
        self.parser.add_argument(
            "--load", default=None, type=str,
            help="load trained models, default None"
        )
        self.parser.add_argument(
            '--visdom_env', default='main',
            help="if use visdom, it is the env name, default main"
        )
        self.parser.add_argument(
            '--visdom_port', default=8097, type=int,
            help="if use visdom, it is the port, default 8097"
        )
        self.parser.add_argument(
            '-nw', '--num_workers', default=12, type=int,
            help='the number of multi cores, default 12'
        )
        self.parser.add_argument(
            "--use_log", action="store_true",
            help="use logarithm?"
        )
        self.parser.add_argument(
            "--use_batch", default=None, type=int,
            help="use part of batches? default None"
        )
        self.parser.add_argument(
            "--sample_size", default=None, type=float,
            help="use size of part of samples? default None"
        )
        self.parser.add_argument(
            "--random_seed", default=1234, type=int,
            help="random seed, default 1234."
        )
        self.parser.add_argument(
            "--device", default=None, type=str,
            choices=(None, "CPU", "GPU"), help="device"
        )
        self.args = self.parser.parse_args(args)

    @staticmethod
    def _str2bool(value):
        """Convert a command-line string such as "true"/"False"/"1" to bool.

        Raises:
            argparse.ArgumentTypeError: if the text is not a recognised
                boolean spelling (argparse turns this into a usage error).
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(
            "boolean value expected, got %r" % value
        )

    def init(self):
        """Validate cross-option constraints and return the parsed args.

        Raises:
            ValueError: if ``--task remove`` is requested without ``--load``.
        """
        if self.args.task == "remove" and self.args.load is None:
            raise ValueError("load cannot be None for remove task.")
        return self.args

    def save(self, fname):
        """Write all parsed options to ``fname`` as a JSON object."""
        # ``vars`` returns the namespace's own __dict__ (same object as before).
        self.save_dict = vars(self.args)
        with open(fname, 'w', encoding='utf-8') as f:
            json.dump(self.save_dict, f)

    def show(self):
        """Print every option as ``name: value``, one per line."""
        print('')
        print('the settings of training:')
        for k, v in vars(self.args).items():
            print('%s: %s' % (k, str(v)))
        print('')