-
Notifications
You must be signed in to change notification settings - Fork 29
/
config.py
76 lines (63 loc) · 1.92 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""
* Copyright (C) 2019 Zhonghui You
* If you are using this code in your research, please cite the paper:
* Gate Decorator: Global Filter Pruning Method for Accelerating Deep Convolutional Neural Networks, in NeurIPS 2019.
"""
import argparse
import json
from utils import dotdict
def make_as_dotdict(obj):
    """Recursively wrap plain ``dict`` objects in ``dotdict``.

    Plain dicts (exact type ``dict``) are rebuilt as a ``dotdict`` whose values
    are converted recursively. Lists are rebuilt with every element converted,
    so dicts nested inside lists (e.g. a JSON array of objects) also get
    attribute access. Any other value -- including an object that is already a
    ``dotdict`` or another dict subclass -- is returned unchanged. This
    function builds new containers and does not mutate *obj*.

    Args:
        obj: Any value, typically the result of ``json.load``.

    Returns:
        The converted value.
    """
    # Exact type check (not isinstance): existing dotdicts and other dict
    # subclasses are passed through untouched.
    if type(obj) is dict:
        return dotdict({key: make_as_dotdict(value) for key, value in obj.items()})
    if type(obj) is list:
        return [make_as_dotdict(item) for item in obj]
    return obj
def parse(args=None):
    """Parse ``--config`` from the command line and load that JSON file.

    Args:
        args: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (argparse's behaviour when ``None``). Callers
            that pass nothing get the original command-line behaviour.

    Returns:
        The file's JSON contents converted with ``make_as_dotdict``.

    Raises:
        SystemExit: from argparse on unrecognised arguments or ``-h``.
        OSError: if the config file cannot be opened.
        ValueError: if the file is not valid UTF-8 JSON
            (``json.JSONDecodeError`` / ``UnicodeDecodeError``).
    """
    print('Parsing config file...')
    parser = argparse.ArgumentParser(description="config")
    parser.add_argument(
        "--config",
        type=str,
        default="configs/base.json",
        help="Configuration file to use"
    )
    cli_args = parser.parse_args(args)
    # JSON text is UTF-8 by specification; don't depend on the locale's
    # default encoding (e.g. cp1252 on Windows).
    with open(cli_args.config, encoding="utf-8") as fp:
        config = make_as_dotdict(json.load(fp))
    print(json.dumps(config, indent=4, sort_keys=True))
    return config
class Singleton(object):
    """Base class whose subclasses each have at most one instance.

    ``__new__`` returns the cached instance of the concrete class being
    constructed, creating it on first use. Python still calls ``__init__`` on
    every ``Cls()`` call, so a subclass initializer re-runs each time.
    """

    # Cached instance. Only the value stored in a class's *own* namespace is
    # consulted; this None is just the declared default.
    _instance = None

    def __new__(cls, *args, **kw):
        # Read only this class's own __dict__: ``cls._instance`` would fall
        # back to a parent's cached instance and hand a Parent object back
        # from ``Child()``.
        instance = cls.__dict__.get('_instance')
        if instance is None:
            # object.__new__ rejects extra arguments when __new__ is
            # overridden; constructor arguments are for __init__ only.
            instance = super(Singleton, cls).__new__(cls)
            cls._instance = instance
        return instance
class Config(Singleton):
    """Configuration object with attribute-style access.

    Attribute reads, writes and deletes (other than ``_cfg`` itself) are
    forwarded to the ``dotdict`` stored in ``self._cfg``.
    """

    def __init__(self):
        # Start from an empty config so the object is usable even if parsing
        # fails. Note: because Singleton.__new__ hands back the cached
        # instance, calling Config() again re-runs this and replaces any
        # config loaded earlier (e.g. via copy()).
        self._cfg = dotdict({})
        try:
            self._cfg = parse()
        except (Exception, SystemExit) as err:
            # Best effort: argparse raises SystemExit on unknown arguments or
            # -h, and a missing/invalid file raises OSError/ValueError. Keep
            # the empty config but report why instead of failing silently.
            # KeyboardInterrupt is intentionally not caught.
            import warnings
            warnings.warn(
                "Config: could not parse configuration ({!r}); "
                "using an empty configuration.".format(err)
            )

    def __getattr__(self, name):
        # Only called when normal lookup fails. If '_cfg' itself is missing
        # (instance created via __new__ without __init__, e.g. while being
        # unpickled), raise AttributeError so hasattr()/getattr(default)
        # behave normally and the lookup below cannot recurse forever.
        if name == '_cfg':
            raise AttributeError(name)
        # NOTE(review): missing keys follow dotdict's __getattr__ semantics
        # (possibly returning None rather than raising) -- confirm in utils.
        return self._cfg.__getattr__(name)

    def __setattr__(self, name, val):
        # '_cfg' is the only real instance attribute; every other name is
        # stored as a config entry.
        if name == '_cfg':
            super().__setattr__(name, val)
        else:
            self._cfg.__setattr__(name, val)

    def __delattr__(self, name):
        # Removes the config entry; a missing key raises whatever
        # dotdict.__delitem__ raises (KeyError for a dict subclass).
        return self._cfg.__delitem__(name)

    def copy(self, new_config):
        """Replace the whole configuration with *new_config*.

        The value is converted with ``make_as_dotdict``. Despite its name,
        this mutates ``self`` and returns None.
        """
        self._cfg = make_as_dotdict(new_config)

    def raw(self):
        """Return the underlying ``dotdict`` itself (not a copy)."""
        return self._cfg
# Module-level configuration object, created at import time. Config.__init__
# tries to load the JSON file named by --config via parse() and keeps an empty
# dotdict if that fails; parse_from_dict() replaces its contents later.
cfg = Config()
def parse_from_dict(d):
    """Replace the contents of the module-level ``cfg`` with the dict *d*.

    Args:
        d: A plain ``dict`` (exact type). Dict subclasses are rejected because
            ``make_as_dotdict`` only converts exact ``dict`` instances.

    Raises:
        TypeError: if *d* is not a plain ``dict``. This is an explicit check
            rather than an ``assert`` so it still runs under ``python -O``.
    """
    if type(d) is not dict:
        raise TypeError(
            "parse_from_dict expects a dict, got {}".format(type(d).__name__)
        )
    # cfg is only mutated here, never rebound, so no ``global`` is needed.
    cfg.copy(d)