Fixing tests (#936)
* abs import

* rename test model

* update trainer

* revert test_step check

* move tags

* fix test_step

* clean tests

* fix template

* update dataset path

* fix parent order
Borda authored Feb 25, 2020
1 parent 20d15c8 commit 5dd2afe
Showing 15 changed files with 263 additions and 208 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-testing.yml
@@ -40,7 +40,7 @@ jobs:
- name: Cache datasets
uses: actions/cache@v1
with:
path: tests/models/mnist # This path is specific to Ubuntu
path: tests/datasets # This path is specific to Ubuntu
# Look to see if there is a cache hit for the corresponding requirements file
key: mnist-dataset
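
The cache step above only pays off because its path matches the directory the test suite now downloads MNIST into. A minimal sketch of the download side, assuming torchvision is installed and using the tests/datasets root from the diff above:

# Sketch: the download that a cache hit on tests/datasets short-circuits.
from torchvision import transforms
from torchvision.datasets import MNIST

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (1.0,))])
# download=True is a no-op when the cached files are already on disk
dataset = MNIST(root='tests/datasets', train=True, transform=transform, download=True)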

8 changes: 4 additions & 4 deletions pl_examples/basic_examples/lightning_module_template.py
@@ -188,6 +188,8 @@ def configure_optimizers(self):
return [optimizer], [scheduler]

def __dataloader(self, train):
# this is needed when you want some info about the dataset before binding to the trainer
self.prepare_data()
# init data generators
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (1.0,))])
@@ -208,10 +210,8 @@ def __dataloader(self, train):
def prepare_data(self):
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (1.0,))])
dataset = MNIST(root=self.hparams.data_root, train=True,
transform=transform, download=True)
dataset = MNIST(root=self.hparams.data_root, train=False,
transform=transform, download=True)
_ = MNIST(root=self.hparams.data_root, train=True,
transform=transform, download=True)

def train_dataloader(self):
log.info('Training data loader called.')
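
The template now separates a one-time download (prepare_data, whose MNIST instance is deliberately discarded) from dataloader construction, and __dataloader calls prepare_data() first so dataset info is available before binding to a trainer. A condensed, standalone sketch of that pattern; the class name and attribute values here are hypothetical stand-ins for the template's hparams:

# Hypothetical sketch of the download-once pattern used in the template above.
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

class MNISTTemplate:
    data_root = 'tests/datasets'  # stand-in for hparams.data_root
    batch_size = 32

    def prepare_data(self):
        # download only; the dataset object itself is thrown away
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.5,), (1.0,))])
        _ = MNIST(root=self.data_root, train=True, transform=transform, download=True)

    def train_dataloader(self):
        self.prepare_data()  # ensures the files exist before building the loader
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.5,), (1.0,))])
        dataset = MNIST(root=self.data_root, train=True, transform=transform)
        return DataLoader(dataset, batch_size=self.batch_size)
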
34 changes: 1 addition & 33 deletions pytorch_lightning/core/lightning.py
@@ -1,7 +1,6 @@
import collections
import inspect
import logging as log
import csv
import os
import warnings
from abc import ABC, abstractmethod
@@ -13,7 +12,7 @@
from pytorch_lightning.core.decorators import data_loader
from pytorch_lightning.core.grads import GradInformation
from pytorch_lightning.core.hooks import ModelHooks
from pytorch_lightning.core.saving import ModelIO
from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.utilities.debugging import MisconfigurationException
@@ -1316,34 +1315,3 @@ def get_tqdm_dict(self):
tqdm_dict['v_num'] = self.trainer.logger.version

return tqdm_dict


def load_hparams_from_tags_csv(tags_csv):
if not os.path.isfile(tags_csv):
log.warning(f'Missing Tags: {tags_csv}.')
return Namespace()

tags = {}
with open(tags_csv) as f:
csv_reader = csv.reader(f, delimiter=',')
for row in list(csv_reader)[1:]:
tags[row[0]] = convert(row[1])
ns = Namespace(**tags)
return ns


def convert(val):
constructors = [int, float, str]

if isinstance(val, str):
if val.lower() == 'true':
return True
if val.lower() == 'false':
return False

for c in constructors:
try:
return c(val)
except ValueError:
pass
return val
37 changes: 37 additions & 0 deletions pytorch_lightning/core/saving.py
@@ -1,3 +1,9 @@
import os
import csv
import logging as log
from argparse import Namespace


class ModelIO(object):

def on_load_checkpoint(self, checkpoint):
@@ -28,3 +34,34 @@ def on_hpc_load(self, checkpoint):
Hook to do whatever you need right before Slurm manager loads the model
:return:
"""


def load_hparams_from_tags_csv(tags_csv):
if not os.path.isfile(tags_csv):
log.warning(f'Missing Tags: {tags_csv}.')
return Namespace()

tags = {}
with open(tags_csv) as f:
csv_reader = csv.reader(f, delimiter=',')
for row in list(csv_reader)[1:]:
tags[row[0]] = convert(row[1])
ns = Namespace(**tags)
return ns


def convert(val):
constructors = [int, float, str]

if isinstance(val, str):
if val.lower() == 'true':
return True
if val.lower() == 'false':
return False

for c in constructors:
try:
return c(val)
except ValueError:
pass
return val
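
A small usage sketch for the relocated helpers: load_hparams_from_tags_csv skips the header row and runs every value through convert, which tries the boolean strings first, then int, then float, and falls back to the raw string. The file name and keys below are hypothetical:

# Hypothetical round trip through the moved helper.
import csv
from pytorch_lightning.core.saving import load_hparams_from_tags_csv

with open('meta_tags.csv', 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['key', 'value'])                 # header row, skipped on load
    writer.writerow(['learning_rate', '0.02'])        # -> float
    writer.writerow(['use_amp', 'True'])              # -> bool
    writer.writerow(['data_root', 'tests/datasets'])  # -> str

hparams = load_hparams_from_tags_csv('meta_tags.csv')
assert hparams.learning_rate == 0.02 and hparams.use_amp is True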
2 changes: 1 addition & 1 deletion pytorch_lightning/loggers/comet.py
@@ -22,8 +22,8 @@

from torch import is_tensor

from pytorch_lightning.utilities.debugging import MisconfigurationException
from .base import LightningLoggerBase, rank_zero_only
from ..utilities.debugging import MisconfigurationException

logger = getLogger(__name__)

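
The logger now imports the exception absolutely while keeping the relative import for its own base module; the absolute form does not depend on where comet.py sits inside the package. A minimal sketch exercising the import path as it exists in this repository at the time of the commit:

# Sketch: the absolute import introduced above, exercised in isolation.
from pytorch_lightning.utilities.debugging import MisconfigurationException

try:
    raise MisconfigurationException('demo: misconfigured logger')
except MisconfigurationException as err:
    print(err)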
17 changes: 10 additions & 7 deletions pytorch_lightning/trainer/data_loading.py
@@ -177,14 +177,19 @@ def reset_train_dataloader(self, model):
self.is_iterable_train_dataloader = (
EXIST_ITER_DATASET and isinstance(self.train_dataloader.dataset, IterableDataset)
)
if self.is_iterable_train_dataloader and not isinstance(self.val_check_interval, int):
if self.is_iterable_dataloader(self.train_dataloader) and not isinstance(self.val_check_interval, int):
m = '''
When using an iterableDataset for `train_dataloader`,
`Trainer(val_check_interval)` must be an int.
An int k specifies checking validation every k training batches
'''
raise MisconfigurationException(m)

def is_iterable_dataloader(self, dataloader):
return (
EXIST_ITER_DATASET and isinstance(dataloader.dataset, IterableDataset)
)

def reset_val_dataloader(self, model):
"""
Dataloaders are provided by the model
@@ -200,9 +205,8 @@ def reset_val_dataloader(self, model):
self.num_val_batches = 0

# add samplers
for i, dataloader in enumerate(self.val_dataloaders):
dl = self.auto_add_sampler(dataloader, train=False)
self.val_dataloaders[i] = dl
self.val_dataloaders = [self.auto_add_sampler(dl, train=False)
for dl in self.val_dataloaders if dl]

# determine number of validation batches
# val datasets could be none, 1 or 2+
@@ -227,9 +231,8 @@ def reset_test_dataloader(self, model):
self.num_test_batches = 0

# add samplers
for i, dataloader in enumerate(self.test_dataloaders):
dl = self.auto_add_sampler(dataloader, train=False)
self.test_dataloaders[i] = dl
self.test_dataloaders = [self.auto_add_sampler(dl, train=False)
for dl in self.test_dataloaders if dl]

# determine number of test batches
if self.test_dataloaders is not None:
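
Besides extracting the is_iterable_dataloader helper, the sampler loops become list comprehensions whose `if dl` clause drops empty or None entries before wrapping. A standalone sketch of that filtering, with a no-op stand-in for auto_add_sampler:

# Sketch of the comprehension pattern above; auto_add_sampler is a stand-in here.
import torch
from torch.utils.data import DataLoader, TensorDataset

def auto_add_sampler(dl, train):
    # the real Trainer method may wrap dl's sampler; identity is enough for the sketch
    return dl

dataset = TensorDataset(torch.zeros(4, 1))
val_dataloaders = [DataLoader(dataset), None]  # the None entry gets filtered out

val_dataloaders = [auto_add_sampler(dl, train=False)
                   for dl in val_dataloaders if dl]
assert len(val_dataloaders) == 1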
51 changes: 22 additions & 29 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -216,13 +216,13 @@ def reset_val_dataloader(self, model):
# this is just empty shell for code from other class
pass

def evaluate(self, model, dataloaders, max_batches, test=False):
def evaluate(self, model, dataloaders, max_batches, test_mode: bool = False):
"""Run evaluation code.
:param model: PT model
:param dataloaders: list of PT dataloaders
:param max_batches: Scalar
:param test: boolean
:param test_mode
:return:
"""
# enable eval mode
@@ -260,18 +260,14 @@ def evaluate(self, model, dataloaders, max_batches, test=False):
# -----------------
# RUN EVALUATION STEP
# -----------------
output = self.evaluation_forward(model,
batch,
batch_idx,
dataloader_idx,
test)
output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)

# track outputs for collation
dl_outputs.append(output)

# batch done
if batch_idx % self.progress_bar_refresh_rate == 0:
if test:
if test_mode:
self.test_progress_bar.update(self.progress_bar_refresh_rate)
else:
self.val_progress_bar.update(self.progress_bar_refresh_rate)
@@ -286,7 +282,7 @@ def evaluate(self, model, dataloaders, max_batches, test=False):

# give model a chance to do something with the outputs (and method defined)
model = self.get_model()
if test and self.is_overriden('test_end'):
if test_mode and self.is_overriden('test_end'):
eval_results = model.test_end(outputs)
elif self.is_overriden('validation_end'):
eval_results = model.validation_end(outputs)
@@ -299,19 +295,19 @@ def evaluate(self, model, dataloaders, max_batches, test=False):

return eval_results

def run_evaluation(self, test=False):
def run_evaluation(self, test_mode: bool = False):
# when testing make sure user defined a test step
if test and not self.is_overriden('test_step'):
m = '''You called `.test()` without defining model's `.test_step()`.
Please define and try again'''
if test_mode and not self.is_overriden('test_step'):
m = "You called `.test()` without defining model's `.test_step()`." \
" Please define and try again"
raise MisconfigurationException(m)

# hook
model = self.get_model()
model.on_pre_performance_check()

# select dataloaders
if test:
if test_mode:
if self.reload_dataloaders_every_epoch or self.test_dataloaders is None:
self.reset_test_dataloader(model)

@@ -331,18 +327,15 @@ def run_evaluation(self, test=False):

# init validation or test progress bar
# main progress bar will already be closed when testing so initial position is free
position = 2 * self.process_position + (not test)
desc = 'Testing' if test else 'Validating'
pbar = tqdm(desc=desc, total=max_batches, leave=test, position=position,
position = 2 * self.process_position + (not test_mode)
desc = 'Testing' if test_mode else 'Validating'
pbar = tqdm(desc=desc, total=max_batches, leave=test_mode, position=position,
disable=not self.show_progress_bar, dynamic_ncols=True,
file=sys.stdout)
setattr(self, f'{"test" if test else "val"}_progress_bar', pbar)
setattr(self, f'{"test" if test_mode else "val"}_progress_bar', pbar)

# run evaluation
eval_results = self.evaluate(self.model,
dataloaders,
max_batches,
test)
eval_results = self.evaluate(self.model, dataloaders, max_batches, test_mode)
_, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(
eval_results)

@@ -359,27 +352,27 @@ def run_evaluation(self, test=False):
model.on_post_performance_check()

# add model specific metrics
if not test:
if not test_mode:
self.main_progress_bar.set_postfix(**self.training_tqdm_dict)

# close progress bar
if test:
if test_mode:
self.test_progress_bar.close()
else:
self.val_progress_bar.close()

# model checkpointing
if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:
if self.proc_rank == 0 and self.checkpoint_callback is not None and not test_mode:
self.checkpoint_callback.on_validation_end()

def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test=False):
def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: bool = False):
# make dataloader_idx arg in validation_step optional
args = [batch, batch_idx]

if test and len(self.test_dataloaders) > 1:
if test_mode and len(self.test_dataloaders) > 1:
args.append(dataloader_idx)

elif not test and len(self.val_dataloaders) > 1:
elif not test_mode and len(self.val_dataloaders) > 1:
args.append(dataloader_idx)

# handle DP, DDP forward
@@ -402,7 +395,7 @@ def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test=False
args[0] = batch

# CPU
if test:
if test_mode:
output = model.test_step(*args)
else:
output = model.validation_step(*args)
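
The flag rename from test to test_mode makes call sites self-documenting, e.g. run_evaluation(test_mode=True) instead of a bare positional boolean. A toy, self-contained sketch of how the flag drives the progress-bar selection above (tqdm assumed installed; the class is a stand-in, not the real loop):

# Toy stand-in showing the test_mode flag pattern from the loop above.
import sys
from tqdm import tqdm

class MiniLoop:
    process_position = 0
    show_progress_bar = True

    def run_evaluation(self, test_mode: bool = False):
        position = 2 * self.process_position + (not test_mode)
        desc = 'Testing' if test_mode else 'Validating'
        pbar = tqdm(desc=desc, total=10, leave=test_mode, position=position,
                    disable=not self.show_progress_bar, dynamic_ncols=True,
                    file=sys.stdout)
        # binds either self.test_progress_bar or self.val_progress_bar
        setattr(self, f'{"test" if test_mode else "val"}_progress_bar', pbar)

loop = MiniLoop()
loop.run_evaluation(test_mode=True)   # sets loop.test_progress_bar
loop.run_evaluation(test_mode=False)  # sets loop.val_progress_bar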
20 changes: 10 additions & 10 deletions pytorch_lightning/trainer/trainer.py
@@ -84,7 +84,7 @@ def __init__(
track_grad_norm: int = -1,
check_val_every_n_epoch: int = 1,
fast_dev_run: bool = False,
accumulate_grad_batches: Union[int, Dict[int, int]] = 1,
accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
max_nb_epochs=None, # backward compatible, todo: remove in v0.8.0
min_nb_epochs=None, # backward compatible, todo: remove in v0.8.0
max_epochs: int = 1000,
@@ -681,7 +681,6 @@ def __init__(
self.train_dataloader = None
self.test_dataloaders = None
self.val_dataloaders = None
self.is_iterable_train_dataloader = False

# training state
self.model = None
@@ -1068,7 +1067,7 @@ def run_pretrain_routine(self, model: LightningModule):
if self.testing:
# only load test dataloader for testing
self.reset_test_dataloader(ref_model)
self.run_evaluation(test=True)
self.run_evaluation(test_mode=True)
return

# load the dataloaders
@@ -1087,15 +1086,17 @@ def run_pretrain_routine(self, model: LightningModule):
if not self.disable_validation and self.num_sanity_val_steps > 0:
# init progress bars for validation sanity check
pbar = tqdm(desc='Validation sanity check',
total=self.num_sanity_val_steps * len(self.val_dataloaders),
leave=False, position=2 * self.process_position,
disable=not self.show_progress_bar, dynamic_ncols=True)
total=self.num_sanity_val_steps * len(self.val_dataloaders),
leave=False, position=2 * self.process_position,
disable=not self.show_progress_bar, dynamic_ncols=True)
self.main_progress_bar = pbar
# dummy validation progress bar
self.val_progress_bar = tqdm(disable=True)

eval_results = self.evaluate(model, self.val_dataloaders,
self.num_sanity_val_steps, False)
eval_results = self.evaluate(model,
self.val_dataloaders,
self.num_sanity_val_steps,
False)
_, _, _, callback_metrics, _ = self.process_output(eval_results)

# close progress bars
@@ -1145,5 +1146,4 @@ def test(self, model: Optional[LightningModule] = None):
self.testing = True
if model is not None:
self.fit(model)
else:
self.run_evaluation(test=True)
self.run_evaluation(test_mode=True)
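
With the else branch removed, .test() always finishes with an explicit evaluation pass, whether or not a model was passed in. A simplified, runnable stand-in for the new control flow (not the real Trainer):

# Simplified stand-in for the new Trainer.test control flow above.
class MiniTrainer:
    def __init__(self):
        self.testing = False

    def fit(self, model):
        print(f'fitting {model!r}')

    def run_evaluation(self, test_mode: bool = False):
        print(f'running evaluation, test_mode={test_mode}')

    def test(self, model=None):
        self.testing = True
        if model is not None:
            self.fit(model)
        # no else branch any more: evaluation always runs last
        self.run_evaluation(test_mode=True)

MiniTrainer().test()          # evaluation only
MiniTrainer().test('model')   # fit, then evaluation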