-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun.py
140 lines (118 loc) · 5.19 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""Generic script to run any method in a TensorFlow model."""
from argparse import ArgumentParser
import json
import os
import sys
import boilerplate as tfbp
USAGE = (
    "Usage:\n New run: python run.py [method] [save_dir] [model] [data_loader]"
    " [hyperparameters...]\n Existing run: python run.py [method] [save_dir] "
    "[data_loader]? [hyperparameters...]"
)


def exit_with_usage():
    """Print the command-line usage to stderr and exit with status 1."""
    print(USAGE, file=sys.stderr)
    # `sys.exit` rather than the bare `exit` helper, which is injected by `site` and is
    # not available when the interpreter is started with `-S`.
    sys.exit(1)


def str_to_bool(text):
    """Convert a command-line string into a bool.

    `bool` cannot be used directly as an argparse `type` because any non-empty string,
    including "False", is truthy.

    Args:
        text: raw string taken from the command line.

    Returns:
        The parsed boolean. The empty string maps to False, like `bool("")`.

    Raises:
        argparse.ArgumentTypeError: if `text` is not a recognized boolean spelling.
    """
    # Imported locally so this block does not depend on a change to the file's imports.
    from argparse import ArgumentTypeError

    lowered = text.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise ArgumentTypeError(f"expected a boolean value, got {text!r}")


def hparam_arg_type(default):
    """Return the argparse `type` callable matching a hyperparameter default value."""
    if isinstance(default, bool):
        return str_to_bool
    return type(default)


def merge_hparams(saved_hparams, *default_hparams):
    """Build the effective hyperparameter defaults.

    Args:
        saved_hparams: dict of previously saved values; only names that also appear in
            one of `default_hparams` are used.
        *default_hparams: dicts of default values. When a name appears in several of
            them, the last default wins.

    Returns:
        A dict mapping each known hyperparameter name to its saved value if there is
        one, or to its default otherwise.
    """
    merged = {}
    for defaults in default_hparams:
        for name, value in defaults.items():
            merged[name] = saved_hparams.get(name, value)
    return merged


def add_hparam_arguments(parser, hparams):
    """Add one `--name` keyword argument to `parser` for each hyperparameter.

    Args:
        parser: the `ArgumentParser` to extend.
        hparams: dict mapping hyperparameter names to their default values. The type
            of each default determines how the command-line value is parsed.

    Raises:
        ValueError: if a list/tuple default is empty, since the element type cannot be
            inferred from it.
    """
    for name, value in hparams.items():
        # Make sure to correctly parse hyperparameters whose values are lists/tuples.
        if isinstance(value, (list, tuple)):
            if not value:
                raise ValueError(
                    f"Cannot infer type of hyperparameter `{name}`. Please provide a "
                    "default value with nonzero length."
                )
            parser.add_argument(
                f"--{name}",
                f"--{name}_",
                nargs="+",
                type=hparam_arg_type(value[0]),
                default=value,
            )
        else:
            parser.add_argument(
                f"--{name}", type=hparam_arg_type(value), default=value
            )


def main():
    """Run the requested method of a model, as described by the command line.

    The model and data loader classes are resolved with `tfbp.get_model` and
    `tfbp.get_data_loader`, either from the names given on the command line (new run)
    or from `experiments/<save_dir>/modules.json` (existing run). Their
    `default_hparams`, overridden by `experiments/<save_dir>/hparams.json` when that
    file exists, become `--name` command-line options.

    Raises:
        ValueError: if the requested method is not in `Model._methods`, or if a
            list/tuple hyperparameter has an empty default.
        ModuleNotFoundError: if the data loader recorded for an existing run cannot be
            imported and no replacement was given on the command line.
    """
    if len(sys.argv) < 3:
        exit_with_usage()

    # Avoid errors due to a missing `experiments` directory.
    os.makedirs("experiments", exist_ok=True)

    # Dynamically parse arguments from the command line depending on the model and data
    # loader provided. The `method` and `save_dir` arguments are always required.
    parser = ArgumentParser()
    parser.add_argument("method", type=str)
    parser.add_argument("save_dir", type=str)

    # If modules.json exists, the model and the data loader modules can be inferred from
    # `save_dir`, and the data loader can be optionally changed from its default.
    #
    # Note that we need to use `sys` because we need to read the command line args to
    # determine what to parse with argparse.
    save_dir = os.path.join("experiments", sys.argv[2])
    modules_json_path = os.path.join(save_dir, "modules.json")

    if os.path.exists(modules_json_path):
        with open(modules_json_path) as f:
            classes = json.load(f)

        if len(sys.argv) >= 4 and not sys.argv[3].startswith("--"):
            classes["data_loader"] = sys.argv[3]

        Model = tfbp.get_model(classes["model"])

        # The model shouldn't be provided for an existing run, but for convenience this
        # error is handled for the user.
        try:
            DataLoader = tfbp.get_data_loader(classes["data_loader"])
        except ModuleNotFoundError:
            if len(sys.argv) < 5 or sys.argv[4].startswith("--"):
                raise
            # TODO: set up proper logging as part of the boilerplate.
            print(
                "Warning: model saved at",
                save_dir,
                f"already points to `models.{classes['model']}`, ignoring...",
                file=sys.stderr,
            )
            classes["data_loader"] = sys.argv[4]
            DataLoader = tfbp.get_data_loader(classes["data_loader"])
            parser.add_argument("model", type=str)
            parser.add_argument("data_loader", type=str)
        else:
            # The data loader is optional for an existing run, so the positional must
            # be optional too; otherwise argparse rejects a command line that omits it.
            parser.add_argument("data_loader", type=str, nargs="?")
    else:
        # A new run needs both the model and the data loader names.
        if len(sys.argv) < 5 or any(a.startswith("--") for a in sys.argv[3:5]):
            exit_with_usage()

        Model = tfbp.get_model(sys.argv[3])
        DataLoader = tfbp.get_data_loader(sys.argv[4])
        parser.add_argument("model", type=str)
        parser.add_argument("data_loader", type=str)

        os.makedirs(save_dir, exist_ok=True)
        with open(modules_json_path, "w") as f:
            json.dump(
                {"model": sys.argv[3], "data_loader": sys.argv[4]},
                f,
                indent=4,
                sort_keys=True,
            )

    saved_hparams = {}
    hparams_json_path = os.path.join(save_dir, "hparams.json")
    if os.path.exists(hparams_json_path):
        with open(hparams_json_path) as f:
            saved_hparams = json.load(f)

    # Add a keyword argument to the argument parser for each hyperparameter.
    hparams = merge_hparams(
        saved_hparams, Model.default_hparams, DataLoader.default_hparams
    )
    add_hparam_arguments(parser, hparams)

    # Collect parsed hyperparameters. `method` is deliberately left in `kwargs`: it is
    # forwarded to both constructors, as before.
    FLAGS = parser.parse_args()
    kwargs = dict(vars(FLAGS))
    for key in ("model", "save_dir", "data_loader"):
        kwargs.pop(key, None)

    # Instantiate model and data loader.
    model = Model(os.path.join("experiments", FLAGS.save_dir), **kwargs)
    data_loader = DataLoader(**kwargs)

    # Restore the model's weights, or save them for a new run.
    if os.path.isfile(os.path.join(model.save_dir, "checkpoint")):
        model.restore()
    else:
        model.save()

    # Run the specified model method.
    if FLAGS.method not in Model._methods:
        methods_str = "\n ".join(Model._methods.keys())
        raise ValueError(
            f"Model does not have a runnable method `{FLAGS.method}`. Methods available:"
            f"\n {methods_str}"
        )
    Model._methods[FLAGS.method](model, data_loader)


if __name__ == "__main__":
    main()