Add warning to trainstep output (#7779)
* Update training_loop.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update test_training_loop.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update test_training_loop.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update training_loop.py

* Update pytorch_lightning/trainer/training_loop.py

Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>

* Update test_training_loop.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update training_loop.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update pytorch_lightning/trainer/training_loop.py

Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>

* Update training_loop.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update test_training_loop.py

* Update training_loop.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* escape regex

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>

(cherry picked from commit 6a0d503)
justusschock authored and SeanNaren committed Jun 8, 2021
1 parent db77223 · commit 8c37c96
Showing 2 changed files with 67 additions and 1 deletion.
pytorch_lightning/trainer/training_loop.py (11 additions, 1 deletion)

@@ -14,7 +14,7 @@
 
 from contextlib import contextmanager, suppress
 from copy import copy, deepcopy
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Mapping, Optional, Union
 
 import numpy as np
 import torch
@@ -265,6 +265,16 @@ def _check_training_step_output(self, training_step_output):
             if training_step_output.grad_fn is None:
                 # TODO: Find why - RuntimeError: Expected to mark a variable ready only once ...
                 raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor")
+        elif self.trainer.lightning_module.automatic_optimization:
+            if not any((
+                isinstance(training_step_output, torch.Tensor),
+                (isinstance(training_step_output, Mapping)
+                 and 'loss' in training_step_output), training_step_output is None
+            )):
+                raise MisconfigurationException(
+                    "In automatic optimization, `training_step` must either return a Tensor, "
+                    "a dict with key 'loss' or None (where the step will be skipped)."
+                )
 
     def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
         # give the PL module a result for logging
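Note on the new check: under automatic optimization, `training_step` may now return exactly three shapes of output: a loss Tensor, a mapping containing a 'loss' key, or None (in which case the step is skipped); anything else raises a MisconfigurationException. Below is a standalone restatement of the predicate for illustration only; the helper name `_is_valid_automatic_output` is ours, not part of the diff.

from typing import Any, Mapping

import torch


def _is_valid_automatic_output(output: Any) -> bool:
    # Mirrors the any((...)) test added to _check_training_step_output:
    # accept a Tensor, a mapping with a 'loss' key, or None; reject the rest.
    return any((
        isinstance(output, torch.Tensor),
        isinstance(output, Mapping) and 'loss' in output,
        output is None,
    ))


# Accepted forms:
assert _is_valid_automatic_output(torch.tensor(1.0))            # bare loss Tensor
assert _is_valid_automatic_output({'loss': torch.tensor(1.0)})  # dict with a 'loss' key
assert _is_valid_automatic_output(None)                         # step is skipped

# Rejected forms (the trainer now raises MisconfigurationException for these):
assert not _is_valid_automatic_output(5.0)                      # raw float
assert not _is_valid_automatic_output({'a': 5})                 # dict without 'loss'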
tests/trainer/loops/test_training_loop.py (56 additions, 0 deletions)

@@ -11,10 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import re
+
 import pytest
 import torch
 
 from pytorch_lightning import seed_everything, Trainer
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 
 
@@ -222,3 +225,56 @@ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
     else:
         assert trainer.batch_idx == batch_idx_
         assert trainer.global_step == batch_idx_ * max_epochs
+
+
+def test_should_stop_mid_epoch(tmpdir):
+    """Test that training correctly stops mid epoch and that validation is still called at the right time"""
+
+    class TestModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.validation_called_at = None
+
+        def training_step(self, batch, batch_idx):
+            if batch_idx == 4:
+                self.trainer.should_stop = True
+            return super().training_step(batch, batch_idx)
+
+        def validation_step(self, *args):
+            self.validation_called_at = (self.trainer.current_epoch, self.trainer.global_step)
+            return super().validation_step(*args)
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=10,
+        limit_val_batches=1,
+    )
+    trainer.fit(model)
+
+    assert trainer.current_epoch == 0
+    assert trainer.global_step == 5
+    assert model.validation_called_at == (0, 4)
+
+
+@pytest.mark.parametrize(['output'], [(5., ), ({'a': 5}, )])
+def test_warning_invalid_trainstep_output(tmpdir, output):
+
+    class TestModel(BoringModel):
+
+        def training_step(self, batch, batch_idx):
+            return output
+
+    model = TestModel()
+
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
+    with pytest.raises(
+        MisconfigurationException,
+        match=re.escape(
+            "In automatic optimization, `training_step` must either return a Tensor, "
+            "a dict with key 'loss' or None (where the step will be skipped)."
+        )
+    ):
+        trainer.fit(model)
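Two details of the new tests are worth noting. In test_should_stop_mid_epoch, setting `self.trainer.should_stop = True` inside `training_step` at batch_idx 4 lets that step finish (training stops with global_step == 5) while validation still runs once, recorded at epoch 0 / global step 4. In test_warning_invalid_trainstep_output, the expected message is wrapped in re.escape because pytest.raises(match=...) treats its argument as a regular expression and the message contains literal parentheses; this is what the final "escape regex" commit in the message above fixed. A minimal reproduction of the failure outside the test harness might look like the following sketch (assuming Lightning at this commit; BoringModel is the helper from the project's test suite, and BadOutputModel is a hypothetical name of ours):

import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel


class BadOutputModel(BoringModel):
    # Hypothetical model: returns an invalid training_step output under
    # automatic optimization (a bare float is neither a Tensor, a 'loss'
    # dict, nor None).

    def training_step(self, batch, batch_idx):
        return 5.0


def test_bad_output_raises(tmpdir):
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
    with pytest.raises(MisconfigurationException):
        trainer.fit(BadOutputModel())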
