
Commit 1e316d3

Merge pull request mmistakes#86 from pesser/config
Config
jhaux committed Jun 27, 2019
2 parents 01a7fb6 + 855422c commit 1e316d3
Showing 4 changed files with 38 additions and 73 deletions.
20 changes: 9 additions & 11 deletions edflow/edflow
@@ -29,7 +29,7 @@ from edflow.hooks.checkpoint_hooks.common import get_latest_checkpoint  # noqa
 def update_config(config, options):
     if options is not None:
         for option in options:
-            config.update(yaml.load(option))
+            config.update(yaml.full_load(option))
     # single format substitution, does not support nested structures
     for k, v in config.items():
         if isinstance(v, str):
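
Note: the hunk cuts off before the body of the substitution loop. A hypothetical completion, only to illustrate what "single format substitution" means here (the actual edflow code may differ): top-level string values get one pass of str.format against the config itself.

import yaml

def update_config(config, options):
    if options is not None:
        for option in options:
            config.update(yaml.full_load(option))
    # single format substitution, does not support nested structures:
    # e.g. {"root": "/data", "out": "{root}/out"} -> config["out"] == "/data/out"
    for k, v in config.items():
        if isinstance(v, str):
            config[k] = v.format(**config)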
@@ -47,13 +47,13 @@ def main(opt):
         for base in opt.base:
             print(base)
             with open(base) as f:
-                base_config.update(yaml.load(f))
+                base_config.update(yaml.full_load(f))
     print(base_config)
     # get path to implementation
     if opt.train:
         with open(opt.train) as f:
             config = base_config
-            config.update(yaml.load(f))
+            config.update(yaml.full_load(f))
         update_config(config, opt.option)
         impl = config["model"]
         name = config.get("experiment_name", None)
@@ -105,16 +105,15 @@ def main(opt):
         for base in opt.base:
             print(base)
             with open(base) as f:
-                base_config.update(yaml.load(f))
+                base_config.update(yaml.full_load(f))
         print(base_config)
         # get path to implementation
         with open(opt.train) as f:
             config = base_config
-            config.update(yaml.load(f))
+            config.update(yaml.full_load(f))
         update_config(config, opt.option)

-        logger.info("Training config: {}".format(opt.train))
-        logger.info(yaml.dump(config))
+        logger.info("Training config: {}\n{}".format(opt.train, yaml.dump(config)))

         train_process = mp.Process(
             target=train,
@@ -135,14 +134,13 @@ def main(opt):
         if opt.base is not None:
             for base in opt.base:
                 with open(base) as f:
-                    base_config.update(yaml.load(f))
+                    base_config.update(yaml.full_load(f))
         # get path to implementation
         with open(eval_config) as f:
             config = base_config
-            config.update(yaml.load(f))
+            config.update(yaml.full_load(f))
         update_config(config, opt.option)
-        logger.info("Evaluation config: {}".format(eval_config))
-        logger.info(yaml.dump(config))
+        logger.info("Evaluation config: {}\n{}".format(eval_config, yaml.dump(config)))
         nogpu = len(processes) > 0 or opt.nogpu
         bar_position = len(processes) + eval_idx
         test_process = mp.Process(
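
Note: every change in this file swaps yaml.load for yaml.full_load. A minimal illustration of the difference, assuming PyYAML >= 5.1 (where full_load was introduced):

import yaml

# Since PyYAML 5.1, calling yaml.load(stream) without an explicit Loader
# emits a deprecation warning. yaml.full_load(stream) is shorthand for
# yaml.load(stream, Loader=yaml.FullLoader), which resolves standard YAML
# tags but refuses to construct arbitrary Python objects, so a malicious
# config file cannot execute code when loaded.
doc = "model: some.module.Model\nbatch_size: 16"
config = yaml.full_load(doc)
assert config == {"model": "some.module.Model", "batch_size": 16}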
2 changes: 1 addition & 1 deletion edflow/hooks/runtime_input.py
@@ -40,7 +40,7 @@ def before_step(self, *args, **kwargs):
         """Checks if something changed and if yes runs the callback."""

         try:
-            updates = yaml.load(open(self.ufile, "r"))
+            updates = yaml.full_load(open(self.ufile, "r"))

             if self.last_updates is not None:
                 changes = {}
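
Note: the hook above re-reads self.ufile before every training step. A standalone sketch of that polling pattern (poll_updates and its arguments are illustrative names, not edflow API; the real hook keeps last_updates on self and wires the callback itself):

import yaml

def poll_updates(ufile, last_updates, callback):
    # Re-read the watched YAML file, as before_step does on every step.
    with open(ufile, "r") as f:
        updates = yaml.full_load(f)
    # On all but the first poll, forward only the keys whose values changed.
    if last_updates is not None:
        changes = {k: v for k, v in updates.items() if last_updates.get(k) != v}
        if changes:
            callback(changes)
    return updates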
85 changes: 26 additions & 59 deletions edflow/main.py
@@ -3,6 +3,7 @@
 import os
 import yaml
 import math
+import datetime

 # ignore broken pipe errors: https://www.quora.com/How-can-you-avoid-a-broken-pipe-error-on-Python
 from signal import signal, SIGPIPE, SIG_DFL
@@ -13,6 +14,7 @@
 import traceback

 from edflow.custom_logging import init_project, get_logger, LogSingleton
+from edflow.project_manager import ProjectManager as P


 def get_obj_from_str(string):
@@ -67,6 +69,15 @@ def decorator(method):
     return decorator


+def _save_config(config, prefix="config"):
+    now = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
+    fname = prefix + "_" + now + ".yaml"
+    path = os.path.join(P.configs, fname)
+    with open(path, "w") as f:
+        f.write(yaml.dump(config))
+    return path
+
+
 def train(args, job_queue, idx):
     traceable_process(_train, args, job_queue, idx)

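
Note: _save_config names each snapshot with an ISO-style timestamp, so repeated runs never overwrite earlier configs. A self-contained sketch of the round trip (configs_dir stands in for ProjectManager.configs):

import datetime
import os
import yaml

def save_config(config, configs_dir, prefix="config"):
    # Produces e.g. train_2019-06-27T14:03:12.yaml
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
    path = os.path.join(configs_dir, prefix + "_" + now + ".yaml")
    with open(path, "w") as f:
        f.write(yaml.dump(config))
    return path

path = save_config({"batch_size": 16}, ".", prefix="train")
# The snapshot can later be reloaded as a base config to reproduce the run.
with open(path) as f:
    assert yaml.full_load(f) == {"batch_size": 16}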
@@ -81,8 +92,7 @@ def _train(config, root, checkpoint=None, retrain=False):

     LogSingleton().set_default("train")
     logger = get_logger("train")
-    logger.info("Starting Training with config:")
-    logger.info(config)
+    logger.info("Starting Training.")

     implementations = get_implementations_from_config(
         config, ["model", "iterator", "dataset"]
@@ -136,6 +146,11 @@ def _train(config, root, checkpoint=None, retrain=False):
     if retrain:
         Trainer.reset_global_step()

+    # save current config
+    logger.info("Starting Training with config:\n{}".format(yaml.dump(config)))
+    cpath = _save_config(config, prefix="train")
+    logger.info("Saved config at {}".format(cpath))
+
     logger.info("Iterating.")
     Trainer.iterate(batches)

@@ -146,8 +161,7 @@ def _test(config, root, checkpoint=None, nogpu=False, bar_position=0):

     LogSingleton().set_default("latest_eval")
     logger = get_logger("test")
-    logger.info("Starting Evaluation with config")
-    logger.info(config)
+    logger.info("Starting Evaluation.")

     if "test_batch_size" in config:
         config["batch_size"] = config["test_batch_size"]
@@ -193,63 +207,16 @@ def _test(config, root, checkpoint=None, nogpu=False, bar_position=0):
     else:
         HBU_Evaluator.initialize()

+    # save current config
+    logger.info("Starting Evaluation with config:\n{}".format(yaml.dump(config)))
+    prefix = "eval"
+    if bar_position > 0:
+        prefix = prefix + str(bar_position)
+    cpath = _save_config(config, prefix=prefix)
+    logger.info("Saved config at {}".format(cpath))
+
     logger.info("Iterating")
     while True:
         HBU_Evaluator.iterate(batches)
         if not config.get("eval_forever", False):
             break
-
-
-def main(opt):
-    with open(opt.config) as f:
-        config = yaml.load(f)
-
-    P = init_project("logs")
-    logger = get_logger("main")
-    logger.info(opt)
-    logger.info(yaml.dump(config))
-    logger.info(P)
-
-    if opt.noeval:
-        train(config, P.train, opt.checkpoint, opt.retrain)
-    else:
-        train_process = mp.Process(
-            target=train, args=(config, P.train, opt.checkpoint, opt.retrain)
-        )
-        test_process = mp.Process(target=test, args=(config, P.latest_eval, True))
-
-        processes = [train_process, test_process]
-
-        try:
-            for p in processes:
-                p.start()
-
-            for p in processes:
-                p.join()
-
-        except KeyboardInterrupt:
-            logger.info("Terminating all processes")
-            for p in processes:
-                p.terminate()
-        finally:
-            logger.info("Finished")
-
-
-if __name__ == "__main__":
-    default_log_dir = os.path.join(os.getcwd(), "log")
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--config", required=True, help="path to config")
-    parser.add_argument("--checkpoint", help="path to checkpoint to restore")
-    parser.add_argument(
-        "--noeval", action="store_true", default=False, help="only run training"
-    )
-    parser.add_argument(
-        "--retrain",
-        action="store_true",
-        default=False,
-        help="reset global_step to zero",
-    )
-
-    opt = parser.parse_args()
-    main(opt)
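
Note: in the _test hunk above, the snapshot prefix encodes which evaluator wrote it, so parallel evaluation processes do not collide. The naming logic, extracted directly from the diff:

def eval_prefix(bar_position):
    # Evaluator 0 writes eval_<timestamp>.yaml, evaluator 2 writes
    # eval2_<timestamp>.yaml, and so on.
    prefix = "eval"
    if bar_position > 0:
        prefix = prefix + str(bar_position)
    return prefix

assert eval_prefix(0) == "eval"
assert eval_prefix(2) == "eval2"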
4 changes: 2 additions & 2 deletions edflow/project_manager.py
@@ -47,8 +47,8 @@ def __init__(self, base=None, given_directory=None, code_root=".", postfix=None)
     def setup(self):
         """Make all the directories."""

-        subdirs = ["code", "train", "eval", "ablation"]
-        subsubdirs = {"code": [], "train": ["checkpoints"], "eval": [], "ablation": []}
+        subdirs = ["code", "train", "eval", "configs"]
+        subsubdirs = {"code": [], "train": ["checkpoints"], "eval": [], "configs": []}

         root = ProjectManager.root
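
Note: this change replaces the "ablation" subdirectory with "configs", creating the directory that P.configs (used by the new _save_config in edflow/main.py) points at. A hypothetical sketch of how the two tables expand into a directory tree (make_dirs is illustrative; the real ProjectManager.setup may differ in details):

import os

def make_dirs(root, subdirs, subsubdirs):
    # One directory per entry in subdirs, plus any nested directories
    # listed for it in subsubdirs.
    for sub in subdirs:
        os.makedirs(os.path.join(root, sub), exist_ok=True)
        for subsub in subsubdirs[sub]:
            os.makedirs(os.path.join(root, sub, subsub), exist_ok=True)

make_dirs(
    "logs/my_run",
    ["code", "train", "eval", "configs"],
    {"code": [], "train": ["checkpoints"], "eval": [], "configs": []},
)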
