Refactor Trainer in advance of implementing Trainer.validate
* Replace the `Trainer.testing` attribute with `Trainer.evaluating`, which is currently set to `'test'` if the top-level function called by the user was `Trainer.test(…)` and `None` otherwise. In the next PR, it will be set to `'validation'` when the user calls `validate(…)`.
* Update the other components to use the new attribute instead of `Trainer.testing`.
* Disable the `EarlyStopping` and `ModelCheckpoint` callbacks when `evaluating`. This has no effect when evaluating on the test set, since they were already disabled, but it will be necessary for the validation set (a sketch of the pattern follows the changed-files summary below).
* Rename a few other attributes of `Trainer` to clarify that they will be used by both `test(…)` and `validate(…)`.
EliaCereda committed Dec 2, 2020
1 parent add387c commit edb3e83
Showing 10 changed files with 80 additions and 56 deletions.
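For illustration, a minimal sketch (not part of this commit) of how a component can branch on the new `Trainer.evaluating` attribute; the callback class `SkipWhileEvaluating` is hypothetical:

```python
from pytorch_lightning.callbacks import Callback


class SkipWhileEvaluating(Callback):
    """Hypothetical callback showing how components consult Trainer.evaluating."""

    def on_validation_end(self, trainer, pl_module):
        # trainer.evaluating is 'test' while Trainer.test(...) is running and
        # None during an ordinary fit(); a follow-up PR will add 'validation'.
        if trainer.evaluating:
            return  # skip side effects such as saving checkpoints or early stopping
        # ... normal behaviour during training ...
```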
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/accelerator.py
@@ -60,7 +60,7 @@ def broadcast(self, obj, src=0):
return obj

def train_or_test(self):
if self.trainer.testing:
if self.trainer.evaluating:
results = self.trainer.run_test()
else:
results = self.trainer.train()
@@ -160,7 +160,7 @@ def early_stopping_should_stop(self, pl_module):
return self.trainer.should_stop

def setup_optimizers(self, model):
if self.trainer.testing is True:
if self.trainer.evaluating:
return

optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -134,13 +134,13 @@ def on_load_checkpoint(self, checkpointed_state):
self.patience = checkpointed_state['patience']

def on_validation_end(self, trainer, pl_module):
if trainer.running_sanity_check:
if trainer.running_sanity_check or trainer.evaluating:
return

self._run_early_stopping_check(trainer, pl_module)

def on_validation_epoch_end(self, trainer, pl_module):
if trainer.running_sanity_check:
if trainer.running_sanity_check or trainer.evaluating:
return

if self._validate_condition_metric(trainer.logger_connector.callback_metrics):
1 change: 1 addition & 0 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -220,6 +220,7 @@ def save_checkpoint(self, trainer, pl_module):
or self.period < 1 # no models are saved
or (epoch + 1) % self.period # skip epoch
or trainer.running_sanity_check # don't save anything during sanity check
or trainer.evaluating # don't save anything during evaluation: might delete the checkpoint being evaluated
or self.last_global_step_saved == global_step # already saved at the last step
):
return
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/configuration_validator.py
@@ -31,12 +31,12 @@ def verify_loop_configurations(self, model: LightningModule):
model: The model to check the configuration.
"""
if not self.trainer.testing:
if not self.trainer.evaluating:
self.__verify_train_loop_configuration(model)
self.__verify_eval_loop_configuration(model, 'validation')
else:
# check test loop configuration
self.__verify_eval_loop_configuration(model, 'test')
# check evaluation loop configurations
self.__verify_eval_loop_configuration(model, self.trainer.evaluating)

def __verify_train_loop_configuration(self, model):
# -----------------------------------
@@ -265,7 +265,7 @@ def prepare_eval_loop_results(self):
for dl_idx in range(self.trainer.evaluation_loop.num_dataloaders):
self.add_to_eval_loop_results(dl_idx, has_been_initialized)

def get_evaluate_epoch_results(self, test_mode):
def get_evaluate_epoch_results(self):
if not self.trainer.running_sanity_check:
# log all the metrics as a single dict
metrics_to_log = self.cached_results.get_epoch_log_metrics()
@@ -274,11 +274,11 @@ def get_evaluate_epoch_results(self, test_mode):

self.prepare_eval_loop_results()

# log results of test
if test_mode and self.trainer.is_global_zero and self.trainer.verbose_test:
# log results of evaluation
if self.trainer.evaluating and self.trainer.is_global_zero and self.trainer.verbose_evaluate:
print('-' * 80)
for result_idx, results in enumerate(self.eval_loop_results):
print(f'DATALOADER:{result_idx} TEST RESULTS')
print(f'DATALOADER:{result_idx} {self.trainer.evaluating.upper()} RESULTS')
pprint(results)
print('-' * 80)

6 changes: 5 additions & 1 deletion pytorch_lightning/trainer/connectors/model_connector.py
@@ -36,7 +36,11 @@ def copy_trainer_model_properties(self, model):
m.use_ddp2 = self.trainer.use_ddp2
m.use_ddp = self.trainer.use_ddp
m.use_amp = self.trainer.amp_backend is not None
m.testing = self.trainer.testing
# Currently, the only users of m.testing appear to be DP and DDP,
# which use it to determine whether the model is currently inside
# the validation or test loop. For this reason it must check if
# trainer.evaluating is equal to "test" specifically.
m.testing = self.trainer.evaluating == 'test'
m.use_single_gpu = self.trainer.use_single_gpu
m.use_tpu = self.trainer.use_tpu
m.tpu_local_core_rank = self.trainer.tpu_local_core_rank
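The comment added above refers to the forward override that DP/DDP apply to the LightningModule. Roughly, the dispatch follows the pattern sketched below (a simplification of that behaviour, not the actual Lightning override code):

```python
# Simplified sketch: the DP/DDP forward override picks the evaluation hook based
# on `module.testing`, which is why Trainer.evaluating must be reduced to a
# boolean (== 'test') before being copied onto the model. `dispatch_step` is an
# illustrative name, not a real Lightning function.
def dispatch_step(module, *args, **kwargs):
    if module.training:
        return module.training_step(*args, **kwargs)
    if module.testing:
        return module.test_step(*args, **kwargs)
    return module.validation_step(*args, **kwargs)
```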
17 changes: 10 additions & 7 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -13,6 +13,7 @@
# limitations under the License.
import torch

import pytorch_lightning as pl
from pytorch_lightning.core.step_result import EvalResult, Result
from pytorch_lightning.trainer.supporters import PredictionCollection
from pytorch_lightning.utilities.distributed import rank_zero_warn
@@ -22,7 +23,7 @@


class EvaluationLoop(object):
def __init__(self, trainer):
def __init__(self, trainer: 'pl.Trainer'):
self.trainer = trainer
self.testing = False
self.outputs = []
@@ -39,13 +40,15 @@ def on_trainer_init(self):
self.trainer.test_dataloaders = None
self.trainer.val_dataloaders = None
self.trainer.running_sanity_check = False
self.trainer.testing = False

# when .test() is called, it sets this
self.trainer.tested_ckpt_path = None
# .validate() sets this to 'validation' and .test() sets this to 'test'
self.trainer.evaluating = None

# when true, prints test results
self.trainer.verbose_test = True
# .validate() and .test() set this when they load a checkpoint
self.trainer.evaluated_ckpt_path = None

# when true, print evaluation results in .validate() and .test()
self.trainer.verbose_evaluate = True

def get_evaluation_dataloaders(self, max_batches):
# select dataloaders
@@ -216,7 +219,7 @@ def evaluation_epoch_end(self):

def log_epoch_metrics_on_evaluation_end(self):
# get the final loop results
eval_loop_results = self.trainer.logger_connector.get_evaluate_epoch_results(self.testing)
eval_loop_results = self.trainer.logger_connector.get_evaluate_epoch_results()
return eval_loop_results

def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
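Assuming the defaults assigned in `on_trainer_init` above, a freshly constructed `Trainer` now carries the following bookkeeping state; a small sketch for reference:

```python
from pytorch_lightning import Trainer

trainer = Trainer()

# Defaults assigned in EvaluationLoop.on_trainer_init (see the diff above)
assert trainer.evaluating is None            # 'test' inside .test(); 'validation' planned for .validate()
assert trainer.evaluated_ckpt_path is None   # set when .test() (later .validate()) loads a checkpoint
assert trainer.verbose_evaluate is True      # print evaluation results when True
```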
82 changes: 49 additions & 33 deletions pytorch_lightning/trainer/trainer.py
@@ -441,10 +441,6 @@ def fit(
# hook
self.data_connector.prepare_data(model)

# bookkeeping
# we reuse fit in .test() but change its behavior using this flag
self.testing = os.environ.get('PL_TESTING_MODE', self.testing)

# ----------------------------
# SET UP TRAINING
# ----------------------------
@@ -720,33 +716,31 @@ def test(
datamodule: Optional[LightningDataModule] = None,
):
r"""
Separates from fit to make sure you never run on your test set until you want to.
Perform one evaluation epoch over the test set. It's separated from
fit to make sure you never run on your test set until you want to.
Args:
ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
If ``None``, use the weights from the last epoch to test. Default to ``best``.
If ``None``, use the current weights of the model. Default to ``best``.
datamodule: A instance of :class:`LightningDataModule`.
model: The model to test.
test_dataloaders: Either a single
Pytorch Dataloader or a list of them, specifying validation samples.
verbose: If True, prints the test results
model: The model to evaluate.
test_dataloaders: Either a single PyTorch DataLoader or a list of them,
specifying test samples.
verbose: If True, prints the test results.
Returns:
The final test result dictionary. If no test_epoch_end is defined returns a list of dictionaries
The dictionary with final test results returned by test_epoch_end.
If test_epoch_end is not defined, the output is a list of the dictionaries
returned by test_step.
"""
# --------------------
# SETUP HOOK
# --------------------
self.verbose_test = verbose
self.verbose_evaluate = verbose

self.logger_connector.set_stage("test")

# If you supply a datamodule you can't supply train_dataloader or val_dataloaders
# If you supply a datamodule you can't supply test_dataloaders
if test_dataloaders and datamodule:
raise MisconfigurationException(
'You cannot pass test_dataloaders to trainer.test if you supply a datamodule'
@@ -756,15 +750,15 @@ def test(
self.data_connector.attach_datamodule(model or self.get_model(), datamodule, 'test')

if model is not None:
results = self.__test_given_model(model, test_dataloaders)
results = self.__evaluate_given_model(model, test_dataloaders, 'test')
else:
results = self.__test_using_best_weights(ckpt_path, test_dataloaders)
results = self.__evaluate_using_best_weights(ckpt_path, test_dataloaders, 'test')

self.teardown('test')

return results

def __test_using_best_weights(self, ckpt_path, test_dataloaders):
def __evaluate_using_best_weights(self, ckpt_path, test_dataloaders, stage: str):
model = self.get_model()

# if user requests the best checkpoint but we don't have it, error
@@ -796,40 +790,56 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)

# run tests
self.tested_ckpt_path = ckpt_path
self.testing = True
os.environ['PL_TESTING_MODE'] = '1'
self.evaluating = stage
self.evaluated_ckpt_path = ckpt_path
self.model = model
results = self.fit(model)
self.testing = False
del os.environ['PL_TESTING_MODE']
self.evaluating = None

# teardown
if self.is_function_implemented('teardown'):
model_ref = self.get_model()
model_ref.teardown('test')
model_ref.teardown(stage)

return results

def __test_given_model(self, model, test_dataloaders):
def __evaluate_given_model(self, model, test_dataloaders, stage: str):

# attach data
if test_dataloaders is not None:
self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)

# run test
# sets up testing so we short circuit to eval
self.testing = True
self.evaluating = stage
self.model = model
results = self.fit(model)
self.testing = False
self.evaluating = None

# teardown
if self.is_function_implemented('teardown'):
model.teardown('test')
model.teardown(stage)

return results

@property
def testing(self):
warnings.warn(
'Trainer.testing has been deprecated in v1.1 and will be removed '
'in v1.3, use Trainer.evaluating instead.',
DeprecationWarning, stacklevel=2
)
return bool(self.evaluating)

@property
def tested_ckpt_path(self):
warnings.warn(
'Trainer.tested_ckpt_path has been renamed Trainer.evaluated_ckpt_path '
'in v1.1 and will be removed in v1.3.',
DeprecationWarning, stacklevel=2
)
return self.evaluated_ckpt_path

def tune(
self,
model: LightningModule,
@@ -856,11 +866,17 @@ def tune(

def call_setup_hook(self, model):
# call setup after the ddp process has connected
stage_name = 'test' if self.testing else 'fit'
stage_name = self.evaluating or 'fit'

if self.datamodule is not None:
called = self.datamodule.has_setup_test if self.testing else self.datamodule.has_setup_fit
called = {
None: self.datamodule.has_setup_fit,
'test': self.datamodule.has_setup_test,
}[self.evaluating]

if not called:
self.datamodule.setup(stage_name)

self.setup(model, stage_name)
model.setup(stage_name)

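Putting the `trainer.py` changes together, the sketch below shows the public behaviour after this commit. `TinyModel` and the random dataloader are placeholders invented for the example; only `Trainer.test`, `evaluated_ckpt_path`, and the deprecated `tested_ckpt_path` come from this commit:

```python
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class TinyModel(LightningModule):
    """Placeholder model, only here to make the sketch self-contained."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.mse_loss(self(x), y)

    def test_step(self, batch, batch_idx):
        x, y = batch
        self.log('test_loss', F.mse_loss(self(x), y))

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def make_loader():
    x, y = torch.randn(64, 4), torch.randn(64, 1)
    return DataLoader(TensorDataset(x, y), batch_size=16)


trainer = Trainer(max_epochs=1)
trainer.fit(TinyModel(), make_loader())

# One evaluation epoch over the test set; ckpt_path='best' loads the best checkpoint saved during fit()
results = trainer.test(test_dataloaders=make_loader(), ckpt_path='best', verbose=True)

print(trainer.evaluated_ckpt_path)  # new attribute: the checkpoint that was evaluated
print(trainer.tested_ckpt_path)     # old name still works but now emits a DeprecationWarning
```

Internally, `test(…)` now records the stage in `self.evaluating` before delegating to `fit(…)` and resets it to `None` afterwards, instead of flipping `self.testing` and the `PL_TESTING_MODE` environment variable.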
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/training_loop.py
@@ -161,7 +161,7 @@ def setup_training(self, model: LightningModule):
ref_model.on_pretrain_routine_start()

# print model summary
if self.trainer.is_global_zero and self.trainer.weights_summary is not None and not self.trainer.testing:
if self.trainer.is_global_zero and self.trainer.weights_summary is not None and not self.trainer.evaluating:
if self.trainer.weights_summary in ModelSummary.MODES:
ref_model.summarize(mode=self.trainer.weights_summary)
else:
6 changes: 3 additions & 3 deletions tests/trainer/test_trainer.py
@@ -728,12 +728,12 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k):
trainer.test(ckpt_path=ckpt_path)
else:
trainer.test(ckpt_path=ckpt_path)
assert trainer.tested_ckpt_path == trainer.checkpoint_callback.best_model_path
assert trainer.evaluated_ckpt_path == trainer.checkpoint_callback.best_model_path
elif ckpt_path is None:
# ckpt_path is None, meaning we don't load any checkpoints and
# use the weights from the end of training
trainer.test(ckpt_path=ckpt_path)
assert trainer.tested_ckpt_path is None
assert trainer.evaluated_ckpt_path is None
else:
# specific checkpoint, pick one from saved ones
if save_top_k == 0:
@@ -746,7 +746,7 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k):
].absolute()
)
trainer.test(ckpt_path=ckpt_path)
assert trainer.tested_ckpt_path == ckpt_path
assert trainer.evaluated_ckpt_path == ckpt_path


def test_disabled_training(tmpdir):