-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathboilerplate.py
147 lines (111 loc) · 4.28 KB
/
boilerplate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""TensorFlow Boilerplate main module."""
from collections import namedtuple
import json
import os
import sys
import tensorflow as tf
def Hyperparameters(value):
    """Turn a dict of hyperparameters into a namedtuple.

    If `value` is already a namedtuple it is returned unchanged, so applying
    this function to its own output is harmless.

    Args:
        value: A mapping from hyperparameter name to value, or a namedtuple.
            Mapping keys become namedtuple field names, so they must be valid
            Python identifiers that do not start with an underscore.

    Returns:
        A namedtuple instance whose fields are the keys of `value`, or `value`
        itself if it already is a namedtuple.
    """
    # Don't transform `value` if it's a namedtuple, i.e. a direct `tuple`
    # subclass whose `_fields` is a tuple of strings.
    # https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
    value_type = type(value)
    if value_type.__bases__ == (tuple,):
        fields = getattr(value_type, "_fields", None)
        if isinstance(fields, tuple) and all(isinstance(name, str) for name in fields):
            return value
    _Hyperparameters = namedtuple("Hyperparameters", value.keys())
    return _Hyperparameters(**value)
class Model(tf.keras.Model):
    """Keras model with hyperparameter parsing and a few other utilities.

    Hyperparameters are merged from, in increasing order of precedence: the
    class-level ``default_hparams``, an ``hparams.json`` previously written to
    ``save_dir`` (if one exists), and the keyword arguments passed to the
    constructor. The merged result is exposed as a namedtuple via ``hparams``.

    Args:
        save_dir: Directory holding ``hparams.json``, checkpoints and
            TensorBoard logs. If None, hyperparameters are neither loaded nor
            written, and ``save``/``restore``/``make_summary_writer`` raise
            ``ValueError``.
        method: Value stored as-is and exposed through the ``method`` property.
        **hparams: Hyperparameter overrides.
    """

    default_hparams = {}
    # Maps names of methods marked with `runnable` to the functions
    # themselves; filled in by `default_export`.
    _methods = {}

    def __init__(self, save_dir=None, method=None, **hparams):
        super().__init__()
        self._save_dir = save_dir
        self._method = method
        self.hparams = {**self.default_hparams, **hparams}
        self._ckpt = None
        self._manager = None

        # Without a save directory there is nowhere to load hyperparameters
        # from or persist them to.
        if save_dir is None:
            return

        hparams_path = os.path.join(save_dir, "hparams.json")
        if os.path.isfile(hparams_path):
            # Saved values override the class defaults but are themselves
            # overridden by constructor kwargs. Keeping the class defaults as
            # the bottom layer means keys added to `default_hparams` after the
            # file was written are still present.
            with open(hparams_path, encoding="utf-8") as f:
                saved_hparams = json.load(f)
            self.hparams = {**self.default_hparams, **saved_hparams, **hparams}
        else:
            os.makedirs(save_dir, exist_ok=True)
            with open(hparams_path, "w", encoding="utf-8") as f:
                json.dump(self.hparams._asdict(), f, indent=4, sort_keys=True)

    @property
    def method(self):
        """The ``method`` value passed to the constructor (may be None)."""
        return self._method

    @property
    def hparams(self):
        """The model's hyperparameters as a namedtuple."""
        return self._hparams

    @hparams.setter
    def hparams(self, value):
        # Dicts are converted to a namedtuple; namedtuples are stored as-is.
        self._hparams = Hyperparameters(value)

    @property
    def save_dir(self):
        """The directory passed to the constructor (may be None)."""
        return self._save_dir

    def _require_save_dir(self):
        """Return ``save_dir``, raising ``ValueError`` if it was not set."""
        if self.save_dir is None:
            raise ValueError(f"{type(self).__name__} was created without a save_dir.")
        return self.save_dir

    def _get_checkpoint_manager(self):
        """Return the checkpoint manager, creating it on first use.

        The manager keeps only the most recent checkpoint (``max_to_keep=1``).
        """
        if self._ckpt is None:
            save_dir = self._require_save_dir()
            self._ckpt = tf.train.Checkpoint(model=self)
            self._manager = tf.train.CheckpointManager(
                self._ckpt, directory=save_dir, max_to_keep=1
            )
        return self._manager

    def save(self):
        """Save the model's weights as a checkpoint in ``save_dir``.

        NOTE: this overrides ``tf.keras.Model.save`` with a different,
        argument-free signature.
        """
        self._get_checkpoint_manager().save()

    def restore(self):
        """Restore the model's latest saved weights from ``save_dir``."""
        manager = self._get_checkpoint_manager()
        # NOTE(review): `latest_checkpoint` is None when nothing has been saved
        # yet; confirm that Checkpoint.restore(None) is the intended behavior.
        self._ckpt.restore(manager.latest_checkpoint)

    def make_summary_writer(self, dirname):
        """Create a TensorBoard summary writer under ``save_dir/dirname``."""
        return tf.summary.create_file_writer(
            os.path.join(self._require_save_dir(), dirname)
        )
class DataLoader:
    """Data loader class akin to `Model`.

    Subclasses set `default_hparams` and override `call`. Instances are
    callable and forward all arguments to `call`.

    Args:
        method: Value stored as-is and exposed through the ``method`` property.
        **hparams: Hyperparameter overrides merged over `default_hparams`.
    """

    default_hparams = {}

    def __init__(self, method=None, **hparams):
        self._method = method
        self.hparams = {**self.default_hparams, **hparams}

    @property
    def method(self):
        """The ``method`` value passed to the constructor (may be None)."""
        return self._method

    @property
    def hparams(self):
        """The loader's hyperparameters as a namedtuple."""
        return self._hparams

    @hparams.setter
    def hparams(self, value):
        # Dicts are converted to a namedtuple; namedtuples are stored as-is.
        self._hparams = Hyperparameters(value)

    def __call__(self, *a, **kw):
        return self.call(*a, **kw)

    def call(self, *args, **kwargs):
        """Produce data; subclasses must override this.

        Accepts arbitrary arguments so that calling an instance whose class
        forgot to override `call` raises ``NotImplementedError`` rather than a
        ``TypeError`` about the number of arguments.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement call().")
def runnable(f):
    """Decorator flagging `f` as runnable from `run.py`.

    The same function object is handed back, with a `_runnable` attribute
    set to True so it can be discovered later.
    """
    f._runnable = True
    return f
def default_export(cls):
    """Make `cls` the imported object of its module and register its runnables.

    The module's entry in `sys.modules` is replaced by `cls`. Attributes
    defined directly on `cls` that carry a truthy `_runnable` flag (as set by
    `runnable`) are collected into `cls._methods`, keyed by attribute name.

    Each class gets its own `_methods` dict, seeded with the entries it
    inherits. This keeps one class's runnables from leaking into a base
    class's dict that is shared by every sibling subclass. Classes without
    any `_methods` attribute are also supported.

    Returns:
        `cls` itself.
    """
    sys.modules[cls.__module__] = cls
    # Copy rather than mutate: a `_methods` dict found through inheritance
    # belongs to a base class and is shared with every other subclass.
    registry = dict(getattr(cls, "_methods", {}))
    for name, attr in cls.__dict__.items():
        if getattr(attr, "_runnable", False):
            registry[name] = attr
    cls._methods = registry
    return cls
def get_model(module_str):
    """Import ``models.<module_str>`` and return what it exposes.

    ``__import__`` on a dotted name (with no fromlist) returns the top-level
    ``models`` package, so the requested submodule is then read off that
    package as an attribute. `default_export` is meant to make that attribute
    the exported class rather than the module object.
    """
    models_package = __import__(f"models.{module_str}")
    return getattr(models_package, module_str)
def get_data_loader(module_str):
    """Import ``data_loaders.<module_str>`` and return what it exposes.

    ``__import__`` on a dotted name (with no fromlist) returns the top-level
    ``data_loaders`` package, so the requested submodule is then read off that
    package as an attribute. `default_export` is meant to make that attribute
    the exported class rather than the module object.
    """
    loaders_package = __import__(f"data_loaders.{module_str}")
    return getattr(loaders_package, module_str)