
Remove no return warning from val/test step #6139

Merged — 8 commits merged on Mar 6, 2021
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -27,6 +27,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274))


- Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139))


### Changed

- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))
@@ -49,6 +52,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed support for passing a bool value to `profiler` argument of Trainer ([#6164](https://github.com/PyTorchLightning/pytorch-lightning/pull/6164))


- Removed no return warning from val/test step ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139))


- Removed passing a `ModelCheckpoint` instance to `Trainer(checkpoint_callback)` ([#6166](https://github.com/PyTorchLightning/pytorch-lightning/pull/6166))


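For reference (not part of the diff), the two CHANGELOG entries above amount to the following behavior: validation and test steps may now return nothing without triggering a warning, while a `predict` that returns `None` warns instead. A minimal sketch, with an illustrative module name:

```python
import torch
from pytorch_lightning import LightningModule


class SketchModel(LightningModule):
    """Illustrative module showing only the hooks affected by this PR."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        # training_step still warns if it returns None under automatic optimization.
        return self.layer(batch).sum()

    def validation_step(self, batch, batch_idx):
        # Returning nothing here no longer emits "validation_step returned None".
        pass

    def test_step(self, batch, batch_idx):
        # Likewise, test_step stays silent when nothing is returned.
        pass

    def predict(self, batch, batch_idx, dataloader_idx=None):
        # After this PR, a None result from predict triggers the new warning.
        return None

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```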
19 changes: 0 additions & 19 deletions pytorch_lightning/overrides/base.py
@@ -11,17 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

import torch
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.warnings import WarningCache

warning_cache = WarningCache()


class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
@@ -53,20 +48,12 @@ def forward(self, *inputs, **kwargs):
# ddp_plugin ``post_training_step`` hook
if not self.module.automatic_optimization:
trainer.model.require_backward_grad_sync = False
warn_if_output_is_none(output, "training_step")

elif trainer and trainer.testing:
output = self.module.test_step(*inputs, **kwargs)
warn_if_output_is_none(output, "test_step")

elif trainer and (trainer.sanity_checking or trainer.validating):
output = self.module.validation_step(*inputs, **kwargs)
warn_if_output_is_none(output, "validation_step")

elif trainer and trainer.predicting:
output = self.module.predict(*inputs, **kwargs)
warn_if_output_is_none(output, "predict")

else:
output = self.module(*inputs, **kwargs)

@@ -76,12 +63,6 @@ def on_post_move_to_device(self):
pass


def warn_if_output_is_none(output: Any, method_name: str) -> None:
""" Warns user about which method returned None. """
if output is None:
warning_cache.warn(f'Your {method_name} returned None. Did you forget to return an output?')


def unwrap_lightning_module(wrapped_model) -> LightningModule:
model = wrapped_model
if isinstance(model, (DistributedDataParallel, DataParallel)):
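The wrapper's `forward` above now dispatches to the step methods without the `None` check. A condensed, illustrative version of how the updated test further down (`tests/overrides/test_data_parallel.py`) exercises that dispatch, with `MagicMock` standing in for a real `LightningModule`:

```python
from unittest.mock import MagicMock

import torch

from pytorch_lightning.overrides import LightningDistributedModule

# A mocked LightningModule whose trainer is in the "validating" state.
pl_module = MagicMock()
pl_module.trainer.sanity_checking = False
for state in ("training", "testing", "validating", "predicting"):
    setattr(pl_module.trainer, state, state == "validating")

wrapped = LightningDistributedModule(pl_module)
batch, batch_idx = torch.rand(2, 32), 0

# forward() routes to validation_step; no warning is emitted regardless of
# what the step returns, now that the None check is gone.
wrapped(batch, batch_idx)
pl_module.validation_step.assert_called_with(batch, batch_idx)
```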
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -214,8 +214,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
# save the last weights
last_path = None
if (
self.lightning_module.trainer.state == TrainerState.FITTING
and best_model_path is not None
self.lightning_module.trainer.state == TrainerState.FITTING and best_model_path is not None
and len(best_model_path) > 0
):
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -139,8 +139,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
# save the last weights
last_path = None
if (
self.lightning_module.trainer.state == TrainerState.FITTING
and best_model_path is not None
self.lightning_module.trainer.state == TrainerState.FITTING and best_model_path is not None
and len(best_model_path) > 0
):
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
@@ -297,9 +297,7 @@ def get_evaluate_epoch_results(self):

# log results of evaluation
if (
self.trainer.state != TrainerState.FITTING
and self.trainer.evaluating
and self.trainer.is_global_zero
self.trainer.state != TrainerState.FITTING and self.trainer.evaluating and self.trainer.is_global_zero
and self.trainer.verbose_evaluate
):
print('-' * 80)
3 changes: 1 addition & 2 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -60,8 +60,7 @@ def get_evaluation_dataloaders(self):
self.trainer.reset_val_dataloader(model)
if self.trainer.sanity_checking:
self.trainer.num_sanity_val_batches = [
min(self.trainer.num_sanity_val_steps, val_batches)
for val_batches in self.trainer.num_val_batches
min(self.trainer.num_sanity_val_steps, val_batches) for val_batches in self.trainer.num_val_batches
]
max_batches = self.trainer.num_sanity_val_batches
else:
6 changes: 6 additions & 0 deletions pytorch_lightning/trainer/predict_loop.py
@@ -14,6 +14,7 @@
import torch

from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.warnings import WarningCache


class PredictLoop(object):
@@ -22,6 +23,7 @@ def __init__(self, trainer):
self.trainer = trainer
self.max_batches = None
self.num_dataloaders = None
self.warning_cache = WarningCache()

def on_trainer_init(self):
self.trainer.num_predict_batches = []
@@ -74,6 +76,10 @@ def predict(self, batch, batch_idx, dataloader_idx):

model_ref._current_fx_name = "predict"
predictions = self.trainer.accelerator.predict(args)

if predictions is None:
self.warning_cache.warn("predict returned None if it was on purpose, ignore this warning...")

self._predictions[dataloader_idx].append(predictions)
self.trainer._progress_bar_callback.on_predict_batch_end(
self.trainer, model_ref, predictions, batch, batch_idx, dataloader_idx
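The `WarningCache` instantiated above deduplicates messages, so the new predict warning fires once per loop rather than on every batch. A rough stand-in for its behavior (illustrative only, not the library's actual implementation):

```python
import warnings


class IllustrativeWarningCache:
    """Sketch of the semantics of pytorch_lightning.utilities.warnings.WarningCache."""

    def __init__(self) -> None:
        self._seen = set()

    def warn(self, message: str, **kwargs) -> None:
        # Each distinct message is emitted at most once per cache instance.
        if message not in self._seen:
            self._seen.add(message)
            warnings.warn(message, **kwargs)


cache = IllustrativeWarningCache()
cache.warn("predict returned None")  # emitted
cache.warn("predict returned None")  # deduplicated, skipped
```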
3 changes: 1 addition & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -875,8 +875,7 @@ def test(
# Attach datamodule to get setup/prepare_data added to model before the call to it below
self.data_connector.attach_datamodule(model, datamodule)
results = (
self.__evaluate_given_model(model, dataloaders=test_dataloaders)
if model_provided else
self.__evaluate_given_model(model, dataloaders=test_dataloaders) if model_provided else
self.__evaluate_using_weights(model, ckpt_path=ckpt_path, dataloaders=test_dataloaders)
)

2 changes: 1 addition & 1 deletion tests/helpers/utils.py
@@ -124,7 +124,7 @@ def no_warning_call(warning_type, match: Optional[str] = None):

try:
w = record.pop(warning_type)
if not ((match and match in w.text) or w):
if not (match and match in str(w.message)):
return
except AssertionError:
# no warning raised
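Based on the snippet above, the fixed `no_warning_call` helper fails only when a warning of the given type whose message contains `match` is actually raised inside the block. A small usage sketch (the warning text here is illustrative):

```python
import warnings

from tests.helpers.utils import no_warning_call

# Passes: a UserWarning is raised, but its message does not contain the match string.
with no_warning_call(UserWarning, match="training_step returned None"):
    warnings.warn("some unrelated warning", UserWarning)

# Also passes: no warning is raised at all inside the block.
with no_warning_call(UserWarning, match="training_step returned None"):
    pass
```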
66 changes: 9 additions & 57 deletions tests/overrides/test_data_parallel.py
@@ -5,7 +5,6 @@
from torch.nn import DataParallel

from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.base import warning_cache
from pytorch_lightning.overrides.data_parallel import (
LightningParallelModule,
python_scalar_to_tensor,
@@ -20,12 +19,14 @@
LightningParallelModule,
LightningDistributedModule,
])
@pytest.mark.parametrize("stage", [
("training", "training_step"),
("testing", "test_step"),
("validating", "validation_step"),
("predicting", "predict"),
])
@pytest.mark.parametrize(
"stage", [
("training", "training_step"),
("testing", "test_step"),
("validating", "validation_step"),
("predicting", "predict"),
]
)
def test_lightning_wrapper_module_methods(wrapper_class, stage):
""" Test that the LightningWrapper redirects .forward() to the LightningModule methods. """
pl_module = MagicMock()
@@ -36,63 +37,14 @@ def test_lightning_wrapper_module_methods(wrapper_class, stage):

prop, step = stage
pl_module.trainer.sanity_checking = False

for p in ("training", "testing", "validating", "predicting"):
setattr(pl_module.trainer, p, p == prop)

wrapped_module(batch, batch_idx)

getattr(pl_module, step).assert_called_with(batch, batch_idx)


@pytest.mark.parametrize("wrapper_class", [
LightningParallelModule,
LightningDistributedModule,
])
@pytest.mark.parametrize("stage", [
("training", "training_step"),
("testing", "test_step"),
("validating", "validation_step"),
])
def test_lightning_wrapper_module_warn_none_output(wrapper_class, stage):
""" Test that the LightningWrapper module warns about forgotten return statement. """
warning_cache.clear()
pl_module = MagicMock()

prop, step = stage
pl_module.trainer.sanity_checking = False
for p in ("training", "testing", "validating", "predicting"):
setattr(pl_module.trainer, p, p == prop)

wrapped_module = wrapper_class(pl_module)

getattr(pl_module, step).return_value = None

with pytest.warns(UserWarning, match=f"Your {step} returned None"):
wrapped_module()


@pytest.mark.parametrize("wrapper_class", [
LightningParallelModule,
LightningDistributedModule,
])
def test_lightning_wrapper_module_no_warn(wrapper_class):
warning_cache.clear()
pl_module = MagicMock()

pl_module.trainer.sanity_checking = False
pl_module.trainer.training = False
pl_module.trainer.testing = False
pl_module.trainer.validating = False
pl_module.trainer.predicting = False

wrapped_module = wrapper_class(pl_module)

with pytest.warns(None) as record:
wrapped_module()
pl_module.assert_called()
assert not record


@pytest.mark.parametrize(
"inp,expected", [
[torch.tensor(1.0), torch.tensor([1.0])],
24 changes: 16 additions & 8 deletions tests/trainer/data_flow/test_train_loop_flow_scalar_1_0.py
@@ -24,6 +24,7 @@
from pytorch_lightning.core.lightning import LightningModule
from tests.helpers.boring_model import BoringModel
from tests.helpers.deterministic_model import DeterministicModel
from tests.helpers.utils import no_warning_call


@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@@ -211,7 +212,8 @@ def backward(self, loss, optimizer, optimizer_idx):

def test_train_step_no_return(tmpdir):
"""
Tests that only training_step can be used
Tests that only training_step raises a warning when
nothing is returned in case of automatic_optimization
"""

class TestModel(BoringModel):
Expand All @@ -231,20 +233,26 @@ def validation_epoch_end(self, outputs):
assert len(outputs) == 0

model = TestModel()
trainer = Trainer(
trainer_args = dict(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=2,
log_every_n_steps=1,
weights_summary=None,
fast_dev_run=2,
)

with pytest.warns(UserWarning, match=r'.*training_step returned None.*'):
trainer = Trainer(**trainer_args)

with pytest.warns(UserWarning, match=r'training_step returned None .*'):
trainer.fit(model)

assert model.training_step_called
assert model.validation_step_called

model = TestModel()
model.automatic_optimization = False
trainer = Trainer(**trainer_args)

with no_warning_call(UserWarning, match=r'training_step returned None .*'):
trainer.fit(model)


def test_training_step_no_return_when_even(tmpdir):
"""
2 changes: 2 additions & 0 deletions tests/trainer/test_states.py
@@ -34,6 +34,7 @@ def test_trainer_state_while_running(tmpdir, extra_params):
trainer = Trainer(default_root_dir=tmpdir, **extra_params, auto_lr_find=True)

class TestModel(BoringModel):

def __init__(self, expected_state):
super().__init__()
self.expected_state = expected_state
@@ -78,6 +79,7 @@ def test_interrupt_state_on_keyboard_interrupt(tmpdir, extra_params):
model = BoringModel()

class InterruptCallback(Callback):

def on_batch_start(self, trainer, pl_module):
raise KeyboardInterrupt

21 changes: 19 additions & 2 deletions tests/trainer/test_trainer.py
@@ -1393,11 +1393,11 @@ def predict_dataloader(self):
return self._dataloaders


def predict(tmpdir, accelerator, gpus, num_processes, plugins=None, datamodule=True):
def predict(tmpdir, accelerator, gpus, num_processes, model=None, plugins=None, datamodule=True):

dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))]

model = BoringModel()
model = model or BoringModel()
datamodule = TestLightningDataModule(dataloaders)

trainer = Trainer(
@@ -1422,6 +1422,23 @@ def predict(tmpdir, accelerator, gpus, num_processes, plugins=None, datamodule=T
assert results[0][0].shape == torch.Size([1, 2])


def test_trainer_predict_no_return(tmpdir):
"""
Test trainer.predict warns when nothing is returned
"""

class CustomBoringModel(BoringModel):

def predict(self, batch, batch_idx, dataloader_idx=None):
if (batch_idx + 1) % 2 == 0:
return

return super().predict(batch, batch_idx, dataloader_idx)

with pytest.warns(UserWarning, match='predict returned None'):
predict(tmpdir, None, None, 1, model=CustomBoringModel())


@pytest.mark.parametrize('datamodule', [False, True])
def test_trainer_predict_cpu(tmpdir, datamodule):
predict(tmpdir, None, None, 1, datamodule=datamodule)