-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfig.py
147 lines (122 loc) · 6.85 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
""" Config class for search/augment """
import argparse
import os
import genotypes as gt
from functools import partial
import torch
def get_parser(name):
    """Create an ``ArgumentParser`` whose help text always lists defaults.

    Args:
        name: program name shown in the usage/help output.

    Returns:
        An ``argparse.ArgumentParser`` using ``ArgumentDefaultsHelpFormatter``.
        Its ``add_argument`` is wrapped so every argument gets ``help=' '``
        unless the caller passes an explicit ``help``.
    """
    formatter = argparse.ArgumentDefaultsHelpFormatter
    parser = argparse.ArgumentParser(name, formatter_class=formatter)
    # The defaults formatter only appends "(default: ...)" to arguments that have
    # a help string, so give every argument a placeholder one; an explicit
    # help= keyword at the call site overrides the partial's keyword.
    original_add_argument = parser.add_argument
    parser.add_argument = partial(original_add_argument, help=' ')
    return parser
def parse_gpus(gpus):
    """Turn a ``--gpus`` option string into a list of device indices.

    Args:
        gpus: either the literal ``'all'`` or comma-separated integers
            such as ``'0,1,3'``.

    Returns:
        ``list(range(torch.cuda.device_count()))`` for ``'all'``; otherwise
        each comma-separated field converted with ``int``.
    """
    if gpus != 'all':
        return list(map(int, gpus.split(',')))
    return list(range(torch.cuda.device_count()))
class BaseConfig(argparse.Namespace):
    """``argparse.Namespace`` with helpers to dump all attributes."""

    def print_params(self, prtf=print):
        """Emit each attribute as ``NAME=value``, sorted by attribute name.

        Output is framed by an empty line, a ``Parameters:`` heading and a
        trailing empty line.

        Args:
            prtf: callable receiving one string per line (default ``print``).
        """
        prtf("")
        prtf("Parameters:")
        params = vars(self)
        for key in sorted(params):
            prtf("{}={}".format(key.upper(), params[key]))
        prtf("")

    def as_markdown(self):
        """Return all attributes as a two-column markdown table, sorted by name."""
        rows = ["|name|value| \n|-|-| \n"]
        rows.extend("|{}|{}| \n".format(key, val) for key, val in sorted(vars(self).items()))
        return "".join(rows)
class SearchConfig(BaseConfig):
    """Options for the architecture-search stage, parsed from the command line.

    After parsing it sets:
    - ``data_path``: ``'./data/'``
    - ``path``: ``'searchs/<name>'``
    - ``plot_path``: ``'<path>/plots'``

    It also converts ``gpus`` to a list of ints via ``parse_gpus``.
    """

    def build_parser(self):
        """Return the parser describing every search-stage option."""
        parser = get_parser("Search config")
        add = parser.add_argument
        add('--name', required=True)
        add('--dataset', required=True, help='CIFAR10 / MNIST / FashionMNIST')
        add('--batch_size', type=int, default=64, help='batch size')
        # optimizer settings for the network weights
        add('--w_lr', type=float, default=0.025, help='lr for weights')
        add('--w_lr_min', type=float, default=0.001, help='minimum lr for weights')
        add('--w_momentum', type=float, default=0.9, help='momentum for weights')
        add('--w_weight_decay', type=float, default=3e-4, help='weight decay for weights')
        add('--w_grad_clip', type=float, default=5.0, help='gradient clipping for weights')
        # runtime / bookkeeping
        add('--print_freq', type=int, default=50, help='print frequency')
        add('--gpus', default='0',
            help='gpu device ids separated by comma. '
                 '`all` indicates use all gpus.')
        add('--epochs', type=int, default=50, help='# of training epochs')
        # network size
        add('--init_channels', type=int, default=16)
        add('--layers', type=int, default=8, help='# of layers')
        add('--seed', type=int, default=2, help='random seed')
        add('--workers', type=int, default=4, help='# of workers')
        # optimizer settings for the architecture parameters (alpha)
        add('--alpha_lr', type=float, default=3e-4, help='lr for alpha')
        add('--alpha_weight_decay', type=float, default=1e-3, help='weight decay for alpha')
        return parser

    def __init__(self):
        # parse_args() with no argument list reads sys.argv[1:]
        parsed = self.build_parser().parse_args()
        super().__init__(**vars(parsed))
        self.data_path = './data/'
        self.path = os.path.join('searchs', self.name)
        self.plot_path = os.path.join(self.path, 'plots')
        self.gpus = parse_gpus(self.gpus)
class AugmentConfig(BaseConfig):
    """Options for training a searched cell architecture, parsed from the command line.

    After parsing it sets:
    - ``data_path``: ``'./data/'``
    - ``path``: ``'augments/<name>'``

    It also replaces ``genotype`` with the result of ``gt.from_str`` on the
    ``--genotype`` string (presumably a genotype object; see the project's
    ``genotypes`` module). ``gpus`` becomes a list of ints via ``parse_gpus``.
    """

    def build_parser(self):
        """Return the parser describing every augment-stage option."""
        parser = get_parser("Augment config")
        add = parser.add_argument
        add('--name', required=True)
        add('--dataset', required=True, help='CIFAR100 / CIFAR10 / MNIST / FashionMNIST')
        add('--batch_size', type=int, default=96, help='batch size')
        # optimizer settings
        add('--lr', type=float, default=0.025, help='lr for weights')
        add('--momentum', type=float, default=0.9, help='momentum')
        add('--weight_decay', type=float, default=3e-4, help='weight decay')
        add('--grad_clip', type=float, default=5.0, help='gradient clipping for weights')
        # runtime / bookkeeping
        add('--print_freq', type=int, default=200, help='print frequency')
        add('--gpus', default='0',
            help='gpu device ids separated by comma. '
                 '`all` indicates use all gpus.')
        add('--epochs', type=int, default=600, help='# of training epochs')
        # network size
        add('--init_channels', type=int, default=36)
        add('--layers', type=int, default=20, help='# of layers')
        add('--seed', type=int, default=2, help='random seed')
        add('--workers', type=int, default=4, help='# of workers')
        # regularization
        add('--aux_weight', type=float, default=0.4, help='auxiliary loss weight')
        add('--cutout_length', type=int, default=16, help='cutout length')
        add('--drop_path_prob', type=float, default=0.2, help='drop path prob')
        add('--genotype', required=True, help='Cell genotype')
        return parser

    def __init__(self):
        # parse_args() with no argument list reads sys.argv[1:]
        parsed = self.build_parser().parse_args()
        super().__init__(**vars(parsed))
        self.data_path = './data/'
        self.path = os.path.join('augments', self.name)
        self.genotype = gt.from_str(self.genotype)
        self.gpus = parse_gpus(self.gpus)
class ResnetConfig(BaseConfig):
    """Options for training a ResNet baseline, parsed from the command line.

    Shares the training options of the augment stage but defines no
    cell-architecture options: no ``--genotype``, ``--init_channels``,
    ``--layers``, ``--aux_weight`` or ``--drop_path_prob``.

    After parsing it sets:
    - ``data_path``: ``'./data/'``
    - ``path``: ``'augments/<name>'``

    It also converts ``gpus`` to a list of ints via ``parse_gpus``.
    """

    def build_parser(self):
        """Return the parser describing every ResNet-training option."""
        # The title is shown in usage/--help output, so it names this config.
        parser = get_parser("Resnet config")
        parser.add_argument('--name', required=True)
        parser.add_argument('--dataset', required=True, help='CIFAR100 / CIFAR10 / MNIST / FashionMNIST')
        parser.add_argument('--batch_size', type=int, default=96, help='batch size')
        parser.add_argument('--lr', type=float, default=0.025, help='lr for weights')
        parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
        parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
        parser.add_argument('--grad_clip', type=float, default=5.,
                            help='gradient clipping for weights')
        parser.add_argument('--print_freq', type=int, default=200, help='print frequency')
        parser.add_argument('--gpus', default='0', help='gpu device ids separated by comma. '
                            '`all` indicates use all gpus.')
        parser.add_argument('--epochs', type=int, default=600, help='# of training epochs')
        parser.add_argument('--seed', type=int, default=2, help='random seed')
        parser.add_argument('--workers', type=int, default=4, help='# of workers')
        parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
        return parser

    def __init__(self):
        parser = self.build_parser()
        # parse_args() with no argument list reads sys.argv[1:]
        args = parser.parse_args()
        super().__init__(**vars(args))
        self.data_path = './data/'
        self.path = os.path.join('augments', self.name)
        # No genotype conversion here: this parser defines no --genotype option,
        # so reading self.genotype would raise AttributeError on every run.
        self.gpus = parse_gpus(self.gpus)