Pass {fit,validate,test,predict} to setup() and teardown() #6386

Merged: 9 commits, Mar 8, 2021
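For orientation, a minimal sketch (not taken from the PR description) of which stage string each Trainer entry point now hands to `setup()`/`teardown()` on the LightningModule, the DataModule, and every Callback; `model` stands for whatever module/datamodule pair you already have:

```python
import pytorch_lightning as pl

trainer = pl.Trainer(fast_dev_run=True)

# entry point              -> stage passed to setup()/teardown()
# trainer.fit(model)       -> 'fit'
# trainer.validate(model)  -> 'validate'
# trainer.test(model)      -> 'test'
# trainer.predict(model)   -> 'predict'
```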
9 changes: 9 additions & 0 deletions CHANGELOG.md
@@ -38,6 +38,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed `trainer.evaluating` to return `True` if validating or testing ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))


- Changed `setup()` and `teardown()` stage argument to take any of `{fit,validate,test,predict}` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386))


### Deprecated


@@ -98,6 +101,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed DP reduction with collection ([#6324](https://github.com/PyTorchLightning/pytorch-lightning/pull/6324))


- Fixed `.teardown(stage='fit')` getting called during `trainer.test` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386))


- Fixed `.on_fit_{start,end}()` getting called during `trainer.test` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386))


- Fixed PyTorch Profiler with `emit_nvtx` ([#6260](https://github.com/PyTorchLightning/pytorch-lightning/pull/6260))


1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -371,6 +371,7 @@ def package_list_from_file(file):
doctest_global_setup = """
import importlib
import os
from typing import Optional
import torch
from torch import nn
import pytorch_lightning as pl
12 changes: 6 additions & 6 deletions docs/source/extensions/datamodules.rst
@@ -80,7 +80,7 @@ The equivalent DataModule just organizes the same exact code, but makes it reusa
self.data_dir = data_dir
self.batch_size = batch_size

def setup(self, stage=None):
def setup(self, stage: Optional[str] = None):
self.mnist_test = MNIST(self.data_dir, train=False)
mnist_full = MNIST(self.data_dir, train=True)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
@@ -138,7 +138,7 @@ Here's a more realistic, complex DataModule that shows how much more reusable th
MNIST(self.data_dir, train=True, download=True)
MNIST(self.data_dir, train=False, download=True)

def setup(self, stage=None):
def setup(self, stage: Optional[str] = None):

# Assign train/val datasets for use in dataloaders
if stage == 'fit' or stage is None:
@@ -382,12 +382,12 @@ still ensures the method runs on the correct devices)

dm = MNISTDataModule()
dm.prepare_data()
dm.setup('fit')
dm.setup(stage='fit')

model = Model(num_classes=dm.num_classes, width=dm.width, vocab=dm.vocab)
trainer.fit(model, dm)

dm.setup('test')
dm.setup(stage='test')
trainer.test(datamodule=dm)

----------------
@@ -403,7 +403,7 @@ You can of course use DataModules in plain PyTorch code as well.
dm.prepare_data()

# splits/transforms
dm.setup('fit')
dm.setup(stage='fit')

# use data
for batch in dm.train_dataloader():
@@ -412,7 +412,7 @@ You can of course use DataModules in plain PyTorch code as well.
...

# lazy load test data
dm.setup('test')
dm.setup(stage='test')
for batch in dm.test_dataloader():
...

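As an aside (not part of the diff), the lazy-loading pattern above extends to the new stage values; a sketch assuming the `MNISTDataModule` defined earlier in this file:

```python
dm = MNISTDataModule()
dm.prepare_data()

# lazily set up only the validation split
dm.setup(stage='validate')
for batch in dm.val_dataloader():
    ...
```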
8 changes: 4 additions & 4 deletions docs/source/starter/introduction_guide.rst
@@ -240,7 +240,7 @@ In this case, it's better to group the full definition of a dataset into a `Data
tokenize()
build_vocab()

def setup(self):
def setup(self, stage: Optional[str] = None):
# called on every GPU
vocab = load_vocab()
self.vocab_size = len(vocab)
@@ -310,8 +310,8 @@ An alternative to using a DataModule is to defer initialization of the models mo
download_data()
tokenize()

def setup(self, step):
# step is either 'fit' or 'test' 90% of the time not relevant
def setup(self, stage: Optional[str] = None):
# stage is either 'fit', 'validate', 'test', or 'predict'. 90% of the time it's not relevant
data = load_data()
num_classes = data.classes
self.l1 = nn.Linear(..., num_classes)
@@ -598,7 +598,7 @@ In this method we do all the preparation we need to do once (instead of on every
MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())

def setup(self, stage):
def setup(self, stage: Optional[str] = None):
# transform
transform=transforms.Compose([transforms.ToTensor()])
mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transform)
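To show all four stage values in one place, a hedged sketch of a datamodule written against the new signature; the class and split names are invented, not taken from the diff:

```python
from typing import Optional

import pytorch_lightning as pl


class SketchDataModule(pl.LightningDataModule):

    def setup(self, stage: Optional[str] = None) -> None:
        # stage is one of 'fit', 'validate', 'test', 'predict' (or None for all)
        if stage in (None, 'fit'):
            self.train_split, self.val_split = ..., ...
        if stage in (None, 'validate'):
            self.val_split = ...
        if stage in (None, 'test'):
            self.test_split = ...
        if stage in (None, 'predict'):
            self.predict_split = ...
```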
2 changes: 1 addition & 1 deletion docs/source/starter/new-project.rst
@@ -651,7 +651,7 @@ Make your data code reusable by organizing it into a :class:`~pytorch_lightning.
MNIST(os.getcwd(), train=False, download=True)

# OPTIONAL, called for every GPU/machine (assigning state is OK)
def setup(self, stage):
def setup(self, stage: Optional[str] = None):
# transforms
transform=transforms.Compose([
transforms.ToTensor(),
10 changes: 5 additions & 5 deletions pytorch_lightning/callbacks/base.py
@@ -17,7 +17,7 @@
"""

import abc
from typing import Any, Dict
from typing import Any, Dict, Optional

from pytorch_lightning.core.lightning import LightningModule

@@ -33,12 +33,12 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModul
"""Called before accelerator is being setup"""
pass

def setup(self, trainer, pl_module: LightningModule, stage: str) -> None:
"""Called when fit or test begins"""
def setup(self, trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
"""Called when fit, validate, test, predict, or tune begins"""
pass

def teardown(self, trainer, pl_module: LightningModule, stage: str) -> None:
"""Called when fit or test ends"""
def teardown(self, trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
"""Called when fit, validate, test, predict, or tune ends"""
pass

def on_init_start(self, trainer) -> None:
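For illustration, a sketch of a user callback written against the updated hook signature; `StagePrinter` is an invented name, not part of this PR:

```python
from typing import Optional

from pytorch_lightning.callbacks import Callback


class StagePrinter(Callback):

    def setup(self, trainer, pl_module, stage: Optional[str] = None) -> None:
        # stage is now any of 'fit', 'validate', 'test', 'predict' (or None)
        print(f"setup for stage={stage}")

    def teardown(self, trainer, pl_module, stage: Optional[str] = None) -> None:
        print(f"teardown for stage={stage}")
```

Passing `callbacks=[StagePrinter()]` to a Trainer would then print the stage once per entry point.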
58 changes: 39 additions & 19 deletions pytorch_lightning/core/datamodule.py
@@ -55,10 +55,10 @@ def __call__(cls, *args, **kwargs):
def track_data_hook_calls(fn):
"""A decorator that checks if prepare_data/setup have been called.

- When dm.prepare_data() is called, dm.has_prepared_data gets set to True
- When dm.setup('fit') is called, dm.has_setup_fit gets set to True
- When dm.setup('test') is called, dm.has_setup_test gets set to True
- When dm.setup() is called without stage arg, both dm.has_setup_fit and dm.has_setup_test get set to True
- When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True
- When ``dm.setup()`` is called without a stage, ``dm.has_setup_{fit,validate,test}`` get set to True
- When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``,
the corresponding ``dm.has_setup_{stage}`` attribute gets set to True

Args:
fn (function): Function that will be tracked to see if it has been called.
@@ -77,15 +77,15 @@ def wrapped_fn(*args, **kwargs):
if fn.__name__ == "setup":

# Get stage either by grabbing from args or checking kwargs.
# If not provided, set call status of 'fit' and 'test' to True.
# If not provided, set call status of 'fit', 'validate', and 'test' to True.
# We do this so __attach_datamodule in trainer.py doesn't mistakenly call setup('test') on trainer.test()
stage = args[1] if len(args) > 1 else kwargs.get("stage", None)

if stage == "fit" or stage is None:
obj._has_setup_fit = True

if stage == "test" or stage is None:
obj._has_setup_test = True
if stage is None:
for s in ("fit", "validate", "test"):
setattr(obj, f"_has_setup_{s}", True)
else:
setattr(obj, f"_has_setup_{stage}", True)

if fn.__name__ == "prepare_data":
obj._has_prepared_data = True
@@ -156,7 +156,9 @@ def __init__(
# Private attrs to keep track of whether or not data hooks have been called yet
self._has_prepared_data = False
self._has_setup_fit = False
self._has_setup_validate = False
self._has_setup_test = False
self._has_setup_predict = False

@property
def train_transforms(self):
@@ -214,32 +216,50 @@ def size(self, dim=None) -> Union[Tuple, int]:
return self.dims

@property
def has_prepared_data(self):
"""Return bool letting you know if datamodule.prepare_data() has been called or not.
def has_prepared_data(self) -> bool:
"""Return bool letting you know if ``datamodule.prepare_data()`` has been called or not.

Returns:
bool: True if datamodule.prepare_data() has been called. False by default.
bool: True if ``datamodule.prepare_data()`` has been called. False by default.
"""
return self._has_prepared_data

@property
def has_setup_fit(self):
"""Return bool letting you know if datamodule.setup('fit') has been called or not.
def has_setup_fit(self) -> bool:
"""Return bool letting you know if ``datamodule.setup(stage='fit')`` has been called or not.

Returns:
bool: True if datamodule.setup('fit') has been called. False by default.
bool: True if ``datamodule.setup(stage='fit')`` has been called. False by default.
"""
return self._has_setup_fit

@property
def has_setup_test(self):
"""Return bool letting you know if datamodule.setup('test') has been called or not.
def has_setup_validate(self) -> bool:
"""Return bool letting you know if ``datamodule.setup(stage='validate')`` has been called or not.

Returns:
bool: True if ``datamodule.setup(stage='validate')`` has been called. False by default.
"""
return self._has_setup_validate

@property
def has_setup_test(self) -> bool:
"""Return bool letting you know if ``datamodule.setup(stage='test')`` has been called or not.

Returns:
bool: True if datamodule.setup('test') has been called. False by default.
bool: True if ``datamodule.setup(stage='test')`` has been called. False by default.
"""
return self._has_setup_test

@property
def has_setup_predict(self) -> bool:
"""Return bool letting you know if ``datamodule.setup(stage='predict')`` has been called or not.

Returns:
bool: True if ``datamodule.setup(stage='predict')`` has been called. False by default.
"""
return self._has_setup_predict

@abstractmethod
def prepare_data(self, *args, **kwargs):
pass
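A short usage sketch of the tracking flags; the guard pattern is illustrative and assumes the `MNISTDataModule` from the docs above rather than anything added by this PR:

```python
dm = MNISTDataModule()

# track_data_hook_calls sets these flags when the hooks run,
# so manual callers can avoid repeating an expensive setup
if not dm.has_setup_fit:
    dm.setup(stage='fit')

if not dm.has_setup_validate:
    dm.setup(stage='validate')
```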
12 changes: 6 additions & 6 deletions pytorch_lightning/core/hooks.py
@@ -25,14 +25,14 @@
class ModelHooks:
"""Hooks to be used in LightningModule."""

def setup(self, stage: str) -> None:
def setup(self, stage: Optional[str] = None) -> None:
"""
Called at the beginning of fit and test.
Called at the beginning of fit (train + validate), validate, test, predict, or tune.
This is a good hook when you need to build models dynamically or adjust something about them.
This hook is called on every process when using DDP.

Args:
stage: either 'fit' or 'test'
stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``

Example::

@@ -53,12 +53,12 @@ def setup(stage):

"""

def teardown(self, stage: str) -> None:
def teardown(self, stage: Optional[str] = None) -> None:
"""
Called at the end of fit and test.
Called at the end of fit (train + validate), validate, test, predict, or tune.

Args:
stage: either 'fit' or 'test'
stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``
"""

def on_fit_start(self) -> None:
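To show the setup/teardown symmetry on the LightningModule side, a minimal sketch; the cached attribute is a placeholder, not an API:

```python
from typing import Optional

import pytorch_lightning as pl


class LitSketch(pl.LightningModule):

    def setup(self, stage: Optional[str] = None) -> None:
        if stage == 'fit':
            # build state that only training/validation needs
            self.fit_cache = []

    def teardown(self, stage: Optional[str] = None) -> None:
        if stage == 'fit':
            # clean up exactly what setup('fit') created
            del self.fit_cache
```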
26 changes: 13 additions & 13 deletions pytorch_lightning/trainer/callback_hook.py
@@ -15,7 +15,7 @@
from abc import ABC
from copy import deepcopy
from inspect import signature
from typing import Any, Callable, Dict, List, Type
from typing import Any, Callable, Dict, List, Type, Optional

from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.lightning import LightningModule
@@ -29,18 +29,18 @@ class TrainerCallbackHookMixin(ABC):
callbacks: List[Callback] = []
lightning_module: LightningModule

def on_before_accelerator_backend_setup(self, model):
"""Called in the beginning of fit and test"""
def on_before_accelerator_backend_setup(self, model: LightningModule) -> None:
"""Called at the beginning of fit (train + validate), validate, test, predict, or tune."""
for callback in self.callbacks:
callback.on_before_accelerator_backend_setup(self, model)

def setup(self, model, stage: str):
"""Called in the beginning of fit and test"""
def setup(self, model: LightningModule, stage: Optional[str]) -> None:
"""Called at the beginning of fit (train + validate), validate, test, predict, or tune."""
for callback in self.callbacks:
callback.setup(self, model, stage)

def teardown(self, stage: str):
"""Called at the end of fit and test"""
def teardown(self, stage: Optional[str] = None) -> None:
"""Called at the end of fit (train + validate), validate, test, predict, or tune."""
for callback in self.callbacks:
callback.teardown(self, self.lightning_module, stage)

@@ -124,15 +124,15 @@ def on_train_end(self):
for callback in self.callbacks:
callback.on_train_end(self, self.lightning_module)

def on_pretrain_routine_start(self, model):
"""Called when the train begins."""
def on_pretrain_routine_start(self) -> None:
"""Called when the pre-train routine begins."""
for callback in self.callbacks:
callback.on_pretrain_routine_start(self, model)
callback.on_pretrain_routine_start(self, self.lightning_module)

def on_pretrain_routine_end(self, model):
"""Called when the train ends."""
def on_pretrain_routine_end(self) -> None:
"""Called when the pre-train routine ends."""
for callback in self.callbacks:
callback.on_pretrain_routine_end(self, model)
callback.on_pretrain_routine_end(self, self.lightning_module)

def on_batch_start(self):
"""Called when the training batch begins."""
6 changes: 4 additions & 2 deletions pytorch_lightning/trainer/model_hooks.py
@@ -14,6 +14,7 @@

import inspect
from abc import ABC
from typing import Optional

from pytorch_lightning.core.lightning import LightningModule

@@ -22,13 +23,14 @@ class TrainerModelHooksMixin(ABC):

lightning_module: LightningModule

def is_function_implemented(self, f_name, model=None):
def is_function_implemented(self, f_name: str, model: Optional[LightningModule] = None) -> bool:
# note: currently unused - kept as it is public
Contributor:

if it's unused we need to add pragma no cover to exclude it from coverage
or test it anyway xD

Contributor Author:

I initially removed it: #6386 (comment)

What do you think is the best option?

cc: @Borda

if model is None:
model = self.lightning_module
f_op = getattr(model, f_name, None)
return callable(f_op)

def has_arg(self, f_name, arg_name):
def has_arg(self, f_name: str, arg_name: str) -> bool:
model = self.lightning_module
f_op = getattr(model, f_name, None)
return arg_name in inspect.signature(f_op).parameters
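For reference, a standalone sketch of what these two helpers check, using plain `inspect` outside the trainer; the `Dummy` class is invented:

```python
import inspect


class Dummy:
    def setup(self, stage=None):
        ...


model = Dummy()

# mirrors is_function_implemented: is the hook present and callable?
f_op = getattr(model, "setup", None)
print(callable(f_op))  # True

# mirrors has_arg: does the hook accept a `stage` parameter?
print("stage" in inspect.signature(f_op).parameters)  # True
```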
20 changes: 10 additions & 10 deletions pytorch_lightning/trainer/states.py
@@ -21,20 +21,20 @@ class TrainerState(LightningEnum):
functions such as `trainer.fit()` and `trainer.test()`.

>>> # you can compare the type with a string
>>> TrainerState.FITTING == 'FITTING'
>>> TrainerState.FITTING == 'fit'
True
>>> # which is case insensitive
>>> TrainerState.FINISHED == 'finished'
>>> TrainerState.FINISHED == 'FINISHED'
True
"""
INITIALIZING = 'INITIALIZING' # trainer creation
FITTING = 'FITTING' # trainer.fit()
VALIDATING = 'VALIDATING' # trainer.validate()
TESTING = 'TESTING' # trainer.test()
PREDICTING = 'PREDICTING' # trainer.predict()
TUNING = 'TUNING' # trainer.tune()
FINISHED = 'FINISHED'
INTERRUPTED = 'INTERRUPTED'
INITIALIZING = 'initializing' # trainer creation
FITTING = 'fit' # trainer.fit()
VALIDATING = 'validate' # trainer.validate()
TESTING = 'test' # trainer.test()
PREDICTING = 'predict' # trainer.predict()
TUNING = 'tune' # trainer.tune()
FINISHED = 'finished'
INTERRUPTED = 'interrupted'

@property
def stopped(self) -> bool:
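A doctest-style sketch of why the values were lowercased: they now match the stage strings handed to `setup()`/`teardown()`, while `LightningEnum` comparison stays case-insensitive, as the docstring above shows:

```python
from pytorch_lightning.trainer.states import TrainerState

assert TrainerState.FITTING == 'fit'
assert TrainerState.VALIDATING == 'validate'
assert TrainerState.TESTING == 'test'
assert TrainerState.PREDICTING == 'predict'

# case-insensitive comparison keeps older uppercase checks working
assert TrainerState.FINISHED == 'FINISHED'
```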