Skip to content

Commit

Permalink
Fixes #2551 (#3858)
Browse files Browse the repository at this point in the history
  • Loading branch information
williamFalcon authored Oct 5, 2020
1 parent 97e62b3 commit f58c760
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 2 deletions.
32 changes: 32 additions & 0 deletions pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,22 @@ def on_train_batch_end(
"""
# do something when the batch ends

def on_validation_model_eval(
self
) -> None:
"""
Sets the model to eval during the val loop
"""
self.eval()

def on_validation_model_train(
self
) -> None:
"""
Sets the model to train during the val loop
"""
self.train()

def on_validation_batch_start(
self, batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
Expand Down Expand Up @@ -192,6 +208,22 @@ def on_test_batch_end(
"""
# do something when the batch ends

def on_test_model_eval(
self
) -> None:
"""
Sets the model to eval during the test loop
"""
self.eval()

def on_test_model_train(
self
) -> None:
"""
Sets the model to train during the test loop
"""
self.train()

def on_batch_start(self, batch: Any) -> None:
"""
Called in the training loop before anything happens for that batch.
Expand Down
14 changes: 14 additions & 0 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,20 @@ def on_evaluation_start(self, *args, **kwargs):
else:
self.trainer.call_hook('on_validation_start', *args, **kwargs)

def on_evaluation_model_eval(self, *args, **kwargs):
model_ref = self.trainer.get_model()
if self.testing:
model_ref.on_test_model_eval()
else:
model_ref.on_validation_model_eval()

def on_evaluation_model_train(self, *args, **kwargs):
model_ref = self.trainer.get_model()
if self.testing:
model_ref.on_test_model_train()
else:
model_ref.on_validation_model_train()

def on_evaluation_end(self, *args, **kwargs):
if self.testing:
self.trainer.call_hook('on_test_end', *args, **kwargs)
Expand Down
5 changes: 3 additions & 2 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,8 +543,9 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None):

# enable eval mode + no grads
model = self.get_model()
self.evaluation_loop.on_evaluation_model_eval()

model.zero_grad()
model.eval()
torch.set_grad_enabled(False)

# hook
Expand Down Expand Up @@ -615,7 +616,7 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None):
self.evaluation_loop.on_evaluation_epoch_end()

# enable train mode again
model.train()
self.evaluation_loop.on_evaluation_model_train()
torch.set_grad_enabled(True)

# hook
Expand Down
Empty file.
35 changes: 35 additions & 0 deletions tests/trainer/model_hooks/test_model_hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from tests.base.boring_model import BoringModel
from pytorch_lightning import Trainer
from unittest import mock


@mock.patch('pytorch_lightning.core.hooks.ModelHooks.on_validation_model_eval')
@mock.patch('pytorch_lightning.core.hooks.ModelHooks.on_validation_model_train')
@mock.patch('pytorch_lightning.core.hooks.ModelHooks.on_test_model_eval')
@mock.patch('pytorch_lightning.core.hooks.ModelHooks.on_test_model_train')
def test_eval_train_calls(test_train_mock, test_eval_mock, val_train_mock, val_eval_mock, tmpdir):
"""
Tests that only training_step can be used
"""
model = BoringModel()
model.validation_epoch_end = None

trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=2,
row_log_interval=1,
weights_summary=None,
)

trainer.fit(model)
trainer.test()

# sanity + 2 epochs
assert val_eval_mock.call_count == 3
assert val_train_mock.call_count == 3

# test is called only once
assert test_eval_mock.call_count == 1
assert test_train_mock.call_count == 1

0 comments on commit f58c760

Please sign in to comment.