-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathutil.py
142 lines (107 loc) · 4.6 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import argparse
import numpy as np
import os
import torch
import torch.nn as nn
from trixi.util import Config, GridSearch
class ConvModule(nn.Module):
    """Utility Module for more convenient weight initialization.

    Subclasses get ``init_weights`` and ``init_bias``, which walk the module
    tree (via ``nn.Module.apply``) and run an initializer on the ``weight`` /
    ``bias`` of every convolution layer listed in ``conv_types``. Layers of
    any other kind are left untouched.
    """

    # Layer classes (and their subclasses) that count as "convolutions".
    conv_types = (
        nn.Conv1d,
        nn.Conv2d,
        nn.Conv3d,
        nn.ConvTranspose1d,
        nn.ConvTranspose2d,
        nn.ConvTranspose3d,
    )

    @classmethod
    def is_conv(cls, op):
        """Return True if ``op`` is a convolution class or instance.

        Args:
            op: Either a class or an object. Classes are checked with
                ``issubclass``, objects with ``isinstance``, so subclasses of
                the entries in ``conv_types`` are recognized in both cases.

        Returns:
            bool: Whether ``op`` is (an instance of) one of ``conv_types``.
        """
        if isinstance(op, type):
            return issubclass(op, cls.conv_types)
        return isinstance(op, cls.conv_types)

    def __init__(self, *args, **kwargs):
        super(ConvModule, self).__init__(*args, **kwargs)

    def _apply_init(self, attr, init_fn, args, kwargs):
        """Run ``init_fn`` on parameter ``attr`` of every convolution submodule."""

        def _init(module):
            if not ConvModule.is_conv(module):
                return
            current = getattr(module, attr, None)
            if current is None:
                # E.g. a convolution constructed with bias=False.
                return
            result = init_fn(current, *args, **kwargs)
            if result is None:
                # Purely in-place initializer: the parameter was already
                # modified, and assigning None would delete it.
                return
            if not isinstance(result, nn.Parameter):
                # nn.Module refuses to assign a plain tensor to a registered
                # parameter name, so wrap it first.
                result = nn.Parameter(result, requires_grad=current.requires_grad)
            setattr(module, attr, result)

        self.apply(_init)

    def init_weights(self, init_fn, *args, **kwargs):
        """Initialize the ``weight`` of all convolution layers in this module.

        Args:
            init_fn: Callable invoked as ``init_fn(weight, *args, **kwargs)``.
                It may modify the tensor in place and return it (like the
                ``torch.nn.init`` functions), return ``None`` after an
                in-place update, or return a new tensor that replaces the
                weight.
            *args: Extra positional arguments forwarded to ``init_fn``.
            **kwargs: Extra keyword arguments forwarded to ``init_fn``.
        """
        self._apply_init("weight", init_fn, args, kwargs)

    def init_bias(self, init_fn, *args, **kwargs):
        """Initialize the ``bias`` of all convolution layers that have one.

        Args:
            init_fn: Callable invoked as ``init_fn(bias, *args, **kwargs)``;
                same conventions as for ``init_weights``.
            *args: Extra positional arguments forwarded to ``init_fn``.
            **kwargs: Extra keyword arguments forwarded to ``init_fn``.
        """
        self._apply_init("bias", init_fn, args, kwargs)
def get_default_experiment_parser():
    """Build the command-line parser shared by the experiment scripts.

    Returns:
        argparse.ArgumentParser: Parser with one positional argument
        (``base_dir``) and options for the config file, the default config
        name, resuming, test mode, grid search, skipping existing experiments
        and config mods.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("base_dir", type=str, help="Working directory for experiment.")
    parser.add_argument("-c", "--config", type=str, default=None, help="Path to a config file.")
    parser.add_argument("-v", "--visdomlogger", action="store_true", help="Use visdomlogger.")
    parser.add_argument("-dc", "--default_config", type=str, default="DEFAULTS", help="Select a default Config")
    parser.add_argument("--resume", type=str, default=None, help="Path to resume from")
    parser.add_argument("-ir", "--ignore_resume_config", action="store_true", help="Ignore Config in experiment we resume from.")
    parser.add_argument("--test", action="store_true", help="Run test instead of training")
    parser.add_argument("--grid", type=str, help="Path to a config for grid search")
    # Help text previously read "fpr which" (typo).
    parser.add_argument("-s", "--skip_existing", action="store_true", help="Skip configs for which an experiment exists")
    parser.add_argument("-m", "--mods", type=str, nargs="+", default=None, help="Mods are Config stubs to update only relevant parts for a certain setup.")
    return parser
def _load_existing_configs(base_dir):
    """Load the stored Config of every experiment found directly in ``base_dir``."""
    if not os.path.isdir(base_dir):
        # Nothing has been run yet, so there is nothing that could be skipped.
        return []
    found = []
    for entry in os.listdir(base_dir):
        config_path = os.path.join(base_dir, entry, "config", "config.json")
        try:
            found.append(Config(file_=config_path))
        except Exception:
            # Deliberately best-effort: entries without a readable config
            # (stray files, unfinished runs) are simply not considered.
            continue
    return found


def run_experiment(experiment, configs, args, mods=None, **kwargs):
    """Assemble the Config from ``args`` and run one experiment per grid entry.

    The Config starts from ``args.config`` (if given), is passed through
    ``update_missing`` with ``configs[args.default_config]``, then updated
    with every requested mod, and finally rebuilt with
    ``update_from_argv=True`` (presumably applying command-line overrides;
    see the trixi ``Config`` documentation).

    Args:
        experiment: Callable (usually an experiment class) invoked with
            ``config``, ``base_dir``, ``resume``, ``ignore_resume_config``,
            ``loggers`` and ``**kwargs``. The returned object must provide
            ``run()`` and ``run_test()``.
        configs: Mapping from default-config name to Config.
        args: Parsed arguments as produced by the parser from
            ``get_default_experiment_parser``.
        mods: Mapping from mod name to a Config stub. Required when
            ``args.mods`` names any mods.
        **kwargs: Forwarded unchanged to ``experiment``.

    Raises:
        TypeError: If mods were requested via ``args.mods`` but ``mods`` is
            None.
    """
    config = Config(file_=args.config) if args.config is not None else Config()
    config.update_missing(configs[args.default_config])
    if args.mods is not None:
        if args.mods and mods is None:
            raise TypeError(
                "Mods {} were requested, but no 'mods' mapping was passed "
                "to run_experiment.".format(list(args.mods))
            )
        for mod in args.mods:
            config.update(mods[mod])
    config = Config(config=config, update_from_argv=True)

    # Get existing experiments to be able to skip certain configs.
    existing_configs = _load_existing_configs(args.base_dir) if args.skip_existing else []

    grid = GridSearch().read(args.grid) if args.grid is not None else [{}]

    for combi in grid:
        config.update(combi)

        if args.skip_existing and any(
            existing.contains(config) for existing in existing_configs
        ):
            continue

        # A fresh dict per run, so experiments never share logger settings.
        loggers = {}
        if args.visdomlogger:
            loggers["visdom"] = ("visdom", {}, 1)

        exp = experiment(config=config,
                         base_dir=args.base_dir,
                         resume=args.resume,
                         ignore_resume_config=args.ignore_resume_config,
                         loggers=loggers,
                         **kwargs)

        if args.test:
            exp.run_test()
        else:
            exp.run()
def set_seeds(seed, cuda=True):
    """Seed the NumPy and PyTorch random number generators.

    Args:
        seed: A single seed used for all generators, or an indexable
            iterable whose items 0, 1 and 2 seed NumPy, torch (via
            ``torch.manual_seed``) and torch CUDA respectively.
        cuda: If True, also call ``torch.cuda.manual_seed_all``.
    """
    if hasattr(seed, "__iter__"):
        seeds = seed
    else:
        # One scalar seed is shared by all three generators.
        seeds = (seed,) * 3

    # Items are read one at a time, so index 2 is only touched when needed.
    np.random.seed(seeds[0])
    torch.manual_seed(seeds[1])
    if not cuda:
        return
    torch.cuda.manual_seed_all(seeds[2])