Merge pull request mmistakes#250 from pesser/pesser-dev-tests
make tests pass
theRealSuperMario authored Feb 14, 2020
2 parents 4f48de0 + d504dc3 commit b6cd0b2
Showing 39 changed files with 156 additions and 1,827 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]
### Added
- Debug options: `debug/disable_integrations` (default `True`) and `debug/max_examples` (default: five batches of examples).
- Epoch and batch step are restored when resuming training from a checkpoint.
- Added option to save checkpoint zero with `--ckpt_zero True`.
- Added support for `project` and `entity` in `integrations/wandb`.
- Logging figures to TensorBoard is now possible via `log_tensorboard_figures`.
- Added support for `eval_functor` in test mode.
- Use `-p <rundir/configs/config.yaml>` as a shortcut for `-b <rundir/configs/config.yaml> -p <rundir>`.
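The new options above are all driven from the run config. Loaded into Python (edflow reads the YAML config into a nested dict), a config exercising them might look like the sketch below; the key names follow the diffs in this commit, while the concrete values are only illustrative.

    # Hypothetical config sketch; keys mirror the options added in this commit,
    # values are placeholders.
    config = {
        "model": "model.Model",
        "iterator": "iterator.Iterator",
        "datasets": {
            "train": "dataset.Dataset",
            "validation": "dataset.Dataset",
        },
        "ckpt_zero": True,  # also settable on the command line via --ckpt_zero True
        "integrations": {
            "wandb": {
                "active": True,
                "project": "my-project",  # placeholder
                "entity": "my-team",      # placeholder
            },
        },
        "debug": {  # only consulted when running in debug mode
            "disable_integrations": True,
            "max_steps": 10,
            "max_examples": 80,
        },
    }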
8 changes: 5 additions & 3 deletions edflow/edsetup
@@ -36,8 +36,8 @@ def create_edflow_project(project_name, replace: bool = False, **kwargs):
with open(source_config, "r+") as source_config_file:
source_config_dict = yaml.load(source_config_file, Loader=yaml.FullLoader)

path_keys = ["model_path", "dataset_path", "iterator_path"]
path_defaults = ["model.py", "dataset.py", "iterator.py"]
path_keys = ["model_path", "dataset_path", "dataset_path", "iterator_path"]
path_defaults = ["model.py", "dataset.py", "dataset.py", "iterator.py"]

destination_training_files = list()
for key, default in zip(path_keys, path_defaults):
@@ -61,8 +61,10 @@ def create_edflow_project(project_name, replace: bool = False, **kwargs):
for file, class_name in zip(training_files_to_module, training_classes)
]
training_parameters_dict = dict(
zip(["model", "dataset", "iterator"], full_address_to_class)
zip(["model", "train_dataset", "validation_dataset", "iterator"], full_address_to_class)
)
source_config_dict["datasets"]["train"] = training_parameters_dict.pop("train_dataset")
source_config_dict["datasets"]["validation"] = training_parameters_dict.pop("validation_dataset")
source_config_dict.update(training_parameters_dict)

with open(destination_config, "w+") as new_config_file:
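Net effect of the change above: the scaffolded config no longer carries a single `dataset` key but a nested `datasets` block with `train` and `validation` entries, both pointing at the generated dataset class. A condensed sketch of the dict handling, with made-up module paths standing in for the generated project files:

    # Condensed sketch of the config assembly above; the module paths are
    # illustrative stand-ins for the generated project modules.
    addresses = [
        "myproject.model.Model",
        "myproject.dataset.Dataset",
        "myproject.dataset.Dataset",
        "myproject.iterator.Iterator",
    ]
    params = dict(
        zip(["model", "train_dataset", "validation_dataset", "iterator"], addresses)
    )
    config = {"datasets": {}}
    config["datasets"]["train"] = params.pop("train_dataset")
    config["datasets"]["validation"] = params.pop("validation_dataset")
    config.update(params)
    # config == {"datasets": {"train": "myproject.dataset.Dataset",
    #                         "validation": "myproject.dataset.Dataset"},
    #            "model": "myproject.model.Model",
    #            "iterator": "myproject.iterator.Iterator"}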
4 changes: 3 additions & 1 deletion edflow/edsetup_files/config.yaml
@@ -1,4 +1,6 @@
dataset: dataset.Dataset
datasets:
train: dataset.Dataset
validation: dataset.Dataset
iterator: iterator.Iterator
model: model.Model

21 changes: 19 additions & 2 deletions edflow/hooks/checkpoint_hooks/lambda_checkpoint_hook.py
@@ -16,6 +16,7 @@ def __init__(
save,
restore,
interval=None,
ckpt_zero=False,
modelname="model",
):
"""
@@ -29,11 +30,27 @@ def __init__(
self._save = save
self._restore = restore
self.interval = interval
self.ckpt_zero = ckpt_zero

os.makedirs(root_path, exist_ok=True)
self.savename = os.path.join(root_path, "{}-{{}}.ckpt".format(modelname))
self._active = False

def before_epoch(self, epoch):
"""
Parameters
----------
epoch :
Returns
-------
"""
if self.ckpt_zero and self.global_step_getter() == 0:
self.save(force_active=True)

def after_epoch(self, epoch):
"""
@@ -86,9 +103,9 @@ def at_exception(self, *args, **kwargs):
"""
self.save()

def save(self):
def save(self, force_active=False):
""" """
if self._active:
if self._active or force_active:
savename = self.savename.format(self.global_step_getter())
self._save(savename)
self.logger.info("Saved model to {}".format(savename))
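With `ckpt_zero=True`, `before_epoch` now forces a save at global step 0 even though the hook normally only saves while it is active. A stripped-down sketch of that control flow (this is not the real `LambdaCheckpointHook` constructor, only the behaviour added here):

    # Minimal sketch of the ckpt_zero control flow; the real hook additionally
    # handles save intervals, activation and exception handling.
    class CheckpointHookSketch:
        def __init__(self, save_fn, global_step_getter, ckpt_zero=False):
            self._save = save_fn
            self.global_step_getter = global_step_getter
            self.ckpt_zero = ckpt_zero
            self._active = False  # toggled elsewhere during training

        def before_epoch(self, epoch):
            # new in this commit: write a checkpoint before any training step
            if self.ckpt_zero and self.global_step_getter() == 0:
                self.save(force_active=True)

        def save(self, force_active=False):
            if self._active or force_active:
                self._save("model-{}.ckpt".format(self.global_step_getter()))

    hook = CheckpointHookSketch(print, lambda: 0, ckpt_zero=True)
    hook.before_epoch(0)  # prints "model-0.ckpt"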
46 changes: 34 additions & 12 deletions edflow/iterators/model_iterator.py
@@ -1,4 +1,4 @@
import signal, sys
import signal, sys, math
from tqdm import tqdm, trange

from edflow.custom_logging import get_logger
@@ -157,26 +157,43 @@ def _iterate(self, batches):
validation_frequency = self.config.get(
"val_freq", self.config.get("log_freq", -1)
)
batches_per_epoch = 0 if epoch_hooks_only else len(batches["train"])
if "max_batches_per_epoch" in self.config:
batches_per_epoch = min(
batches_per_epoch, self.config["max_batches_per_epoch"]
)
num_epochs = 1 if epoch_hooks_only else self.num_epochs
start_epoch = (
0 if epoch_hooks_only else (self.get_global_step() // batches_per_epoch)
)
start_step = (
0 if epoch_hooks_only else (self.get_global_step() % batches_per_epoch)
)
for epoch_step in trange(
num_epochs, desc=desc_epoch, position=pos, dynamic_ncols=True
start_epoch,
num_epochs,
initial=start_epoch,
total=num_epochs,
desc=desc_epoch,
position=pos,
dynamic_ncols=True,
leave=False,
):
self._epoch_step = epoch_step

############# run one batch on each split until new epoch or max steps
batches["train"].reset()
self.run_hooks(epoch_step, before=True)

if epoch_hooks_only:
batches_per_epoch = 0
else:
batches_per_epoch = len(batches["train"])
if "max_batches_per_epoch" in self.config:
batches_per_epoch = min(
batches_per_epoch, self.config["max_batches_per_epoch"]
)
for batch_step in trange(
batches_per_epoch, desc=desc_batch, position=pos + 1, dynamic_ncols=True
start_step,
batches_per_epoch,
initial=start_step,
total=batches_per_epoch,
desc=desc_batch,
position=pos + 1,
dynamic_ncols=True,
leave=False,
):
self._batch_step = batch_step

@@ -202,6 +219,7 @@ def split_op():
if self.get_global_step() >= self.config.get("num_steps", float("inf")):
break
self.run_hooks(epoch_step, before=False)
start_step = 0

############# run one epoch on each split
# only continue a split as long as someone is retrieving results
@@ -212,8 +230,9 @@ def split_op():
tqdm_iterator = trange(
len(batches[split]),
desc=split,
position=pos + 2,
position=pos + 1,
dynamic_ncols=True,
leave=False,
)
for batch_step in tqdm_iterator:
self._batch_step = batch_step
@@ -256,6 +275,9 @@ def split_op():
break
self.run_hooks(epoch_step, before=False, epoch_hooks=True)

if self.get_global_step() >= self.config.get("num_steps", float("inf")):
break

def run(self, fetches, feed_dict):
"""Runs all fetch ops and stores the results.
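The restoration logic above turns the restored global step back into an epoch index and a batch offset within that epoch, so a resumed run picks up exactly where the checkpoint left off (the offset is reset to zero after the first resumed epoch). The arithmetic, as a tiny sketch with a worked example:

    # Sketch of the epoch/batch restoration arithmetic used in _iterate above.
    def resume_position(global_step, batches_per_epoch):
        start_epoch = global_step // batches_per_epoch
        start_step = global_step % batches_per_epoch
        return start_epoch, start_step

    # A checkpoint at global step 2500 with 1000 batches per epoch resumes
    # in epoch 2, at batch 500 of that epoch.
    print(resume_position(2500, 1000))  # (2, 500)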
14 changes: 13 additions & 1 deletion edflow/iterators/template_iterator.py
@@ -29,6 +29,7 @@ def __init__(self, *args, **kwargs):
save=self.save,
restore=self.restore,
interval=set_default(self.config, "ckpt_freq", None),
ckpt_zero=set_default(self.config, "ckpt_zero", False),
)
# write checkpoints after epoch or when interrupted during training
if not self.config.get("test_mode", False):
@@ -74,7 +75,18 @@ def __init__(self, *args, **kwargs):
os.environ["WANDB_RUN_ID"] = ProjectManager.root.strip("/").replace(
"/", "-"
)
wandb.init(name=ProjectManager.root, config=self.config)
wandb_project = set_default(
self.config, "integrations/wandb/project", None
)
wandb_entity = set_default(
self.config, "integrations/wandb/entity", None
)
wandb.init(
name=ProjectManager.root,
config=self.config,
project=wandb_project,
entity=wandb_entity,
)

handlers = set_default(
self.config,
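`project` and `entity` are read from the nested config via edflow's slash-separated keys and default to `None`, which preserves the previous `wandb.init` behaviour when they are not set. An illustrative config excerpt (the project and entity names are placeholders):

    # Hypothetical config excerpt for the wandb integration; "active" is the
    # existing switch, project/entity are the keys added in this commit.
    config = {
        "integrations": {
            "wandb": {
                "active": True,
                "project": "my-wandb-project",  # placeholder
                "entity": "my-wandb-entity",    # placeholder
            }
        }
    }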
36 changes: 31 additions & 5 deletions edflow/main.py
@@ -6,7 +6,7 @@
import datetime

from edflow.custom_logging import log, run
from edflow.util import get_obj_from_str
from edflow.util import get_obj_from_str, retrieve, set_value


def _save_config(config, prefix="config"):
@@ -25,6 +25,16 @@ def train(config, root, checkpoint=None, retrain=False, debug=False):
"""Run training. Loads model, iterator and dataset according to config."""
from edflow.iterators.batches import make_batches

# disable integrations in debug mode
if debug:
if retrieve(config, "debug/disable_integrations", default=True):
integrations = retrieve(config, "integrations", default=dict())
for k in integrations:
config["integrations"][k]["active"] = False
max_steps = retrieve(config, "debug/max_steps", default=5 * 2)
if max_steps > 0:
config["num_steps"] = max_steps

# backwards compatibility
if not "datasets" in config:
config["datasets"] = {"train": config["dataset"]}
@@ -48,8 +58,16 @@ def train(config, root, checkpoint=None, retrain=False, debug=False):
datasets[split].expand = True
logger.info("{} dataset size: {}".format(split, len(datasets[split])))
if debug:
logger.info("Monkey patching {} dataset __len__".format(split))
type(datasets[split]).__len__ = lambda self: 100
max_examples = retrieve(
config, "debug/max_examples", default=5 * config["batch_size"]
)
if max_examples > 0:
logger.info(
"Monkey patching {} dataset __len__ to {} examples".format(
split, max_examples
)
)
type(datasets[split]).__len__ = lambda self: max_examples

n_processes = config.get("n_data_processes", min(16, config["batch_size"]))
n_prefetch = config.get("n_prefetch", 1)
@@ -139,8 +157,16 @@ def test(config, root, checkpoint=None, nogpu=False, bar_position=0, debug=False
datasets[split].expand = True
logger.info("{} dataset size: {}".format(split, len(datasets[split])))
if debug:
logger.info("Monkey patching {} dataset __len__".format(split))
type(datasets[split]).__len__ = lambda self: 100
max_examples = retrieve(
config, "debug/max_examples", default=5 * config["batch_size"]
)
if max_examples > 0:
logger.info(
"Monkey patching {} dataset __len__ to {} examples".format(
split, max_examples
)
)
type(datasets[split]).__len__ = lambda self: max_examples

n_processes = config.get("n_data_processes", min(16, config["batch_size"]))
n_prefetch = config.get("n_prefetch", 1)
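Debug mode no longer hard-codes a 100-example dataset: integrations are switched off, `num_steps` is capped via `debug/max_steps`, and each split's `__len__` is monkey patched down to `debug/max_examples`. The defaults in the code above work out as follows (the batch size is an illustrative value):

    # Effective debug-mode defaults, following the code above.
    batch_size = 16                 # illustrative value
    max_steps = 5 * 2               # default for debug/max_steps -> 10 training steps
    max_examples = 5 * batch_size   # default for debug/max_examples -> 80 examples
    print(max_steps, max_examples)  # 10 80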
21 changes: 0 additions & 21 deletions examples/eval_hook/dataset.py

This file was deleted.

17 changes: 0 additions & 17 deletions examples/eval_hook/mnist_config.yaml

This file was deleted.

25 changes: 0 additions & 25 deletions examples/eval_hook/mnist_config_cb.yaml

This file was deleted.

