diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 943864138e371..c54ea341b8779 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -14,7 +14,7 @@
 from contextlib import contextmanager, suppress
 from copy import copy, deepcopy
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Mapping, Optional, Union
 
 import numpy as np
 import torch
 
@@ -265,6 +265,16 @@ def _check_training_step_output(self, training_step_output):
             if training_step_output.grad_fn is None:
                 # TODO: Find why - RuntimeError: Expected to mark a variable ready only once ...
                 raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor")
+        elif self.trainer.lightning_module.automatic_optimization:
+            if not any((
+                isinstance(training_step_output, torch.Tensor),
+                (isinstance(training_step_output, Mapping)
+                 and 'loss' in training_step_output), training_step_output is None
+            )):
+                raise MisconfigurationException(
+                    "In automatic optimization, `training_step` must either return a Tensor, "
+                    "a dict with key 'loss' or None (where the step will be skipped)."
+                )
 
     def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
         # give the PL module a result for logging
diff --git a/tests/trainer/loops/test_training_loop.py b/tests/trainer/loops/test_training_loop.py
index 94becf6488fc3..9f17c3baef6d3 100644
--- a/tests/trainer/loops/test_training_loop.py
+++ b/tests/trainer/loops/test_training_loop.py
@@ -11,10 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import re
+
 import pytest
 import torch
 
 from pytorch_lightning import seed_everything, Trainer
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 
 
@@ -222,3 +225,56 @@ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
         else:
             assert trainer.batch_idx == batch_idx_
             assert trainer.global_step == batch_idx_ * max_epochs
+
+
+def test_should_stop_mid_epoch(tmpdir):
+    """Test that training correctly stops mid epoch and that validation is still called at the right time"""
+
+    class TestModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.validation_called_at = None
+
+        def training_step(self, batch, batch_idx):
+            if batch_idx == 4:
+                self.trainer.should_stop = True
+            return super().training_step(batch, batch_idx)
+
+        def validation_step(self, *args):
+            self.validation_called_at = (self.trainer.current_epoch, self.trainer.global_step)
+            return super().validation_step(*args)
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=10,
+        limit_val_batches=1,
+    )
+    trainer.fit(model)
+
+    assert trainer.current_epoch == 0
+    assert trainer.global_step == 5
+    assert model.validation_called_at == (0, 4)
+
+
+@pytest.mark.parametrize(['output'], [(5., ), ({'a': 5}, )])
+def test_warning_invalid_trainstep_output(tmpdir, output):
+
+    class TestModel(BoringModel):
+
+        def training_step(self, batch, batch_idx):
+            return output
+
+    model = TestModel()
+
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
+    with pytest.raises(
+        MisconfigurationException,
+        match=re.escape(
+            "In automatic optimization, `training_step` must either return a Tensor, "
+            "a dict with key 'loss' or None (where the step will be skipped)."
+        )
+    ):
+        trainer.fit(model)