Refactor RunningStage usage in advance of implementing Trainer.validate() #4945

Merged: 39 commits, Mar 6, 2021
Changes from 23 commits

Commits
104d796  Update code (carmocca, Feb 24, 2021)
23c2d3b  More property updates (carmocca, Feb 24, 2021)
dd09134  Move properties. Introduce trainer._fitting (carmocca, Feb 25, 2021)
68469ed  Use trainer.fitting (carmocca, Feb 25, 2021)
77df20b  Merge branch 'master' into feature/trainer-validate-1 (carmocca, Feb 25, 2021)
89aa994  Fix reset dataloaders (carmocca, Feb 25, 2021)
3c6e99c  Unused code (carmocca, Feb 25, 2021)
9ba12d7  RunningStage.SANITY_CHECKING (carmocca, Feb 25, 2021)
f3d16a4  Use setters (carmocca, Feb 25, 2021)
0697c3e  Fix bugs (carmocca, Feb 25, 2021)
39686ae  Fix bugs (carmocca, Feb 25, 2021)
0ed386b  TrainerState.{FITTING,VALIDATING,TESTING,PREDICTING,TUNING} (carmocca, Feb 25, 2021)
18e851c  Fix bugs (carmocca, Feb 25, 2021)
dec84ec  Fix bugs (carmocca, Feb 25, 2021)
0b21f4d  Fix tests (carmocca, Feb 25, 2021)
94707ba  Merge branch 'master' into feature/trainer-validate-1 (carmocca, Feb 25, 2021)
8cdac8e  Update CHANGELOG. Add deprecation warning. Fix tests (carmocca, Feb 25, 2021)
73916f4  Unused imports (carmocca, Feb 25, 2021)
6c62eec  Optional trainer (carmocca, Feb 25, 2021)
0a211a2  More deprecation. More refactoring (carmocca, Feb 25, 2021)
6c6752c  Correct version (carmocca, Feb 25, 2021)
34ae418  Use properties (carmocca, Feb 26, 2021)
24f1c1e  Address comments (carmocca, Mar 1, 2021)
13483f5  Merge branch 'master' into feature/trainer-validate-1 (carmocca, Mar 3, 2021)
1e5d84d  flake8 (carmocca, Mar 3, 2021)
dccf603  Missed renamings (carmocca, Mar 3, 2021)
56b4494  Typo (carmocca, Mar 3, 2021)
94e93ad  is -> == (carmocca, Mar 3, 2021)
6c53cdc  Also for tests (carmocca, Mar 3, 2021)
99db197  Merge branch 'master' into feature/trainer-validate-1 (carmocca, Mar 4, 2021)
1b0709d  Merge branch 'master' into feature/trainer-validate-1 (carmocca, Mar 5, 2021)
b64b46e  Typo (carmocca, Mar 5, 2021)
7d42798  Address @tchaton's comments (carmocca, Mar 5, 2021)
7a3f8cd  PEP8 (carmocca, Mar 5, 2021)
c0ef3fa  Correct property (carmocca, Mar 5, 2021)
63c9493  Update CHANGELOG (carmocca, Mar 5, 2021)
10f7f21  Apply suggestions from code review (carmocca, Mar 6, 2021)
d1dc4c9  Update pytorch_lightning/trainer/trainer.py (carmocca, Mar 6, 2021)
45a010f  Remove called sanity check (carmocca, Mar 6, 2021)
CHANGELOG.md (18 additions, 0 deletions)
@@ -15,12 +15,30 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))


- Added `RunningStage.SANITY_CHECKING` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))


- Added `TrainerState.{FITTING,VALIDATING,TESTING,PREDICTING,TUNING}` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))


### Changed


- Refactor `RunningStage` and `TrainerState` usage ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))


- Changed `trainer.evaluating` to return `True` if validating or testing ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))


### Deprecated


- Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))


- Deprecated `trainer.tested_ckpt_path` in favor of `trainer.evaluated_ckpt_path` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))


### Removed

- Removed support for passing a bool value to `profiler` argument of Trainer ([#6164](https://github.com/PyTorchLightning/pytorch-lightning/pull/6164))
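Taken together, these CHANGELOG entries describe a small user-facing shift in how the trainer's stage is queried. The callback below is a hypothetical sketch, not part of the PR; it only assumes the names listed above (`TrainerState.FITTING`, `trainer.sanity_checking`, `trainer.evaluating`).

```python
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.trainer.states import TrainerState


class StageAwareCallback(Callback):
    """Illustrative callback that branches on the refactored trainer state."""

    def on_validation_end(self, trainer, pl_module):
        if trainer.sanity_checking:  # replaces the deprecated trainer.running_sanity_check
            return
        if trainer.state is TrainerState.FITTING:
            print("validation loop ran inside fit()")
        elif trainer.evaluating:  # now True while validating or testing
            print("standalone evaluation run")
```

Passing such a callback via `Trainer(callbacks=[StageAwareCallback()])` would hit the first branch during `fit()` and the second during `trainer.test()` (or, after the follow-up PR, `trainer.validate()`).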
pytorch_lightning/accelerators/accelerator.py (5 additions, 4 deletions)
@@ -20,6 +20,7 @@
from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
from pytorch_lightning.plugins.training_type import TrainingTypePlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
from pytorch_lightning.utilities.enums import AMPType, LightningEnum
@@ -81,8 +82,8 @@ def setup(self, trainer: 'Trainer', model: LightningModule) -> None:
def start_training(self, trainer: 'Trainer') -> None:
self.training_type_plugin.start_training(trainer)

def start_testing(self, trainer: 'Trainer') -> None:
self.training_type_plugin.start_testing(trainer)
def start_evaluating(self, trainer: 'Trainer') -> None:
self.training_type_plugin.start_evaluating(trainer)

def start_predicting(self, trainer: 'Trainer') -> None:
self.training_type_plugin.start_predicting(trainer)
@@ -330,7 +331,7 @@ def setup_optimizers(self, trainer: 'Trainer') -> None:
trainer: the Trainer, these optimizers should be connected to
model: the model to be optimized by the created optimizers
"""
if trainer.testing:
if trainer.state not in (TrainerState.FITTING, TrainerState.TUNING):
return
optimizers, lr_schedulers, optimizer_frequencies = self.training_type_plugin.init_optimizers(
trainer=trainer, model=self.lightning_module
@@ -426,7 +427,7 @@ def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[I
@property
def results(self) -> Any:
"""
The results of the last training/testing run will be cached within the training type plugin.
The results of the last run will be cached within the training type plugin.
In distributed training, we make sure to transfer the results to the appropriate master process.
"""
return self.training_type_plugin.results
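The `setup_optimizers` hunk above swaps the old `trainer.testing` check for an explicit allow-list of states that actually run an optimization loop. A minimal sketch of that predicate, assuming only the `TrainerState` members added in this PR (the helper name is made up for illustration):

```python
from pytorch_lightning.trainer.states import TrainerState


def needs_optimizers(trainer) -> bool:
    # Only fit() and tune() optimize parameters; validate/test/predict skip optimizer setup.
    return trainer.state in (TrainerState.FITTING, TrainerState.TUNING)
```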
pytorch_lightning/callbacks/early_stopping.py (4 additions, 3 deletions)
@@ -133,12 +133,13 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]):
self.patience = callback_state['patience']

def on_validation_end(self, trainer, pl_module):
if trainer.running_sanity_check:
from pytorch_lightning.trainer.states import TrainerState
if trainer.state is not TrainerState.FITTING or trainer.sanity_checking:
return

self._run_early_stopping_check(trainer, pl_module)
self._run_early_stopping_check(trainer)

def _run_early_stopping_check(self, trainer, pl_module):
def _run_early_stopping_check(self, trainer):
"""
Checks whether the early stopping condition is met
and if so tells the trainer to stop the training.
pytorch_lightning/callbacks/model_checkpoint.py (3 additions, 1 deletion)
@@ -213,12 +213,14 @@ def save_checkpoint(self, trainer, pl_module):
epoch = trainer.current_epoch
global_step = trainer.global_step

from pytorch_lightning.trainer.states import TrainerState
if (
trainer.fast_dev_run # disable checkpointing with fast_dev_run
or trainer.state is not TrainerState.FITTING # don't save anything during non-fit
or trainer.sanity_checking # don't save anything during sanity check
or self.save_top_k == 0 # no models are saved
or self.period < 1 # no models are saved
or (epoch + 1) % self.period # skip epoch
or trainer.running_sanity_check # don't save anything during sanity check
or self._last_global_step_saved == global_step # already saved at the last step
):
return
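Both callback changes above converge on the same skip rule: act only while fitting, and never during the sanity check. A hedged sketch of that shared predicate (the function name is hypothetical):

```python
from pytorch_lightning.trainer.states import TrainerState


def should_skip_callback(trainer) -> bool:
    # Mirrors the guards added to EarlyStopping.on_validation_end and
    # ModelCheckpoint.save_checkpoint in the hunks above.
    return trainer.state is not TrainerState.FITTING or trainer.sanity_checking
```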
pytorch_lightning/callbacks/progress.py (3 additions, 2 deletions)
@@ -380,7 +380,6 @@ def init_test_tqdm(self) -> tqdm:
def on_sanity_check_start(self, trainer, pl_module):
super().on_sanity_check_start(trainer, pl_module)
self.val_progress_bar = self.init_sanity_tqdm()
reset(self.val_progress_bar, sum(trainer.num_sanity_val_batches))
self.main_progress_bar = tqdm(disable=True) # dummy progress bar

def on_sanity_check_end(self, trainer, pl_module):
@@ -412,7 +411,9 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data

def on_validation_start(self, trainer, pl_module):
super().on_validation_start(trainer, pl_module)
if not trainer.running_sanity_check:
if trainer.sanity_checking:
reset(self.val_progress_bar, sum(trainer.num_sanity_val_batches))
else:
self._update_bar(self.main_progress_bar) # fill up remaining
self.val_progress_bar = self.init_validation_tqdm()
reset(self.val_progress_bar, self.total_val_batches)
pytorch_lightning/core/lightning.py (1 addition, 8 deletions)
@@ -24,7 +24,7 @@
from argparse import Namespace
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import ScriptModule, Tensor
@@ -44,9 +44,6 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args

if TYPE_CHECKING:
from pytorch_lightning.trainer.states import RunningStage


class LightningModule(
ABC,
@@ -171,10 +168,6 @@ def automatic_optimization(self) -> bool:
"""
return self._automatic_optimization

@property
def running_stage(self) -> Optional["RunningStage"]:
return self.trainer._running_stage if self.trainer else None

@automatic_optimization.setter
def automatic_optimization(self, automatic_optimization: bool) -> None:
self._automatic_optimization = automatic_optimization
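With `LightningModule.running_stage` removed, code that compared against `RunningStage` values is expected to query the attached trainer's boolean properties instead. A hedged migration sketch (user-level code, not part of the diff; it assumes the properties used elsewhere in this PR):

```python
def current_stage(pl_module) -> str:
    trainer = pl_module.trainer
    if trainer is None:            # module not attached to a Trainer
        return "detached"
    if trainer.training:
        return "train"
    if trainer.sanity_checking or trainer.validating:
        return "validation"
    if trainer.testing:
        return "test"
    if trainer.predicting:
        return "predict"
    return "idle"
```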
pytorch_lightning/overrides/base.py (6 additions, 7 deletions)
@@ -18,7 +18,6 @@
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.warnings import WarningCache

@@ -43,28 +42,28 @@ def __init__(self, pl_module: LightningModule):
self.module = pl_module

def forward(self, *inputs, **kwargs):
running_stage = self.module.running_stage
trainer = self.module.trainer

if running_stage == RunningStage.TRAINING:
if trainer and trainer.training:
output = self.module.training_step(*inputs, **kwargs)

# In manual_optimization, we need to prevent DDP reducer as
# it is done manually in ``LightningModule.manual_backward``
# `require_backward_grad_sync` will be reset in the
# ddp_plugin ``post_training_step`` hook
if not self.module.automatic_optimization:
self.module.trainer.model.require_backward_grad_sync = False
trainer.model.require_backward_grad_sync = False
warn_if_output_is_none(output, "training_step")

elif running_stage == RunningStage.TESTING:
elif trainer and trainer.testing:
output = self.module.test_step(*inputs, **kwargs)
warn_if_output_is_none(output, "test_step")

elif running_stage == RunningStage.EVALUATING:
elif trainer and (trainer.sanity_checking or trainer.validating):
output = self.module.validation_step(*inputs, **kwargs)
warn_if_output_is_none(output, "validation_step")

elif running_stage == RunningStage.PREDICTING:
elif trainer and trainer.predicting:
output = self.module.predict(*inputs, **kwargs)
warn_if_output_is_none(output, "predict")

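Note that every branch above is guarded with `trainer and ...` rather than assuming a stage is always set: a `LightningModule` used outside of a `Trainer` has no trainer attached. A small sketch of that situation (the model class is illustrative; it assumes an unattached module's `trainer` attribute is `None` in this release):

```python
import torch
from pytorch_lightning import LightningModule


class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.layer(x)


model = TinyModel()
assert model.trainer is None       # no stage can be queried here
out = model(torch.randn(1, 4))     # plain forward still works, no *_step dispatch
```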
pytorch_lightning/plugins/training_type/ddp_spawn.py (8 additions, 3 deletions)
@@ -27,6 +27,7 @@
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load
@@ -156,7 +157,7 @@ def new_process(self, process_idx, trainer, mp_queue):

self.barrier()

results = trainer.train_or_test_or_predict()
results = trainer.run_stage()

# persist info in ddp_spawn
self.transfer_distrib_spawn_state_on_fit_end(results)
@@ -218,7 +219,11 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
# save the last weights
last_path = None
# TODO: is there a better way than accessing trainer through model -> trainer?
if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
if (
self.lightning_module.trainer.state is TrainerState.FITTING
and best_model_path is not None
and len(best_model_path) > 0
):
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
atomic_save(self.on_save(self.lightning_module.state_dict()), last_path)

@@ -235,7 +240,7 @@ def __recover_child_process_weights(self, best_path, last_path):
# todo, pass also best score

# load last weights
if last_path is not None and not self.lightning_module.trainer.testing:
if last_path is not None and self.lightning_module.trainer.state is TrainerState.FITTING:
ckpt = pl_load(last_path, map_location=lambda storage, loc: storage)
self.lightning_module.load_state_dict(ckpt)

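The spawn plugins now call a single `trainer.run_stage()` instead of `train_or_test_or_predict()`. The trainer-side implementation is not shown in this diff; the sketch below is only a rough guess at the shape of such a dispatcher, reusing the `run_train`/`run_evaluate` names that appear in the final hunk of this diff (the `run_predict` call is an assumption):

```python
def run_stage(trainer):
    # Hypothetical dispatcher: pick the loop that matches the current trainer state.
    if trainer.evaluating:             # validating or testing
        return trainer.run_evaluate()
    if trainer.predicting:
        return trainer.run_predict()   # assumed name, not shown in this diff
    return trainer.run_train()
```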
pytorch_lightning/plugins/training_type/deepspeed.py (2 additions, 3 deletions)
@@ -213,7 +213,7 @@ def init_deepspeed(self):
precision = self.lightning_module.trainer.accelerator.precision
model = LightningDeepSpeedModule(pl_module=self.model, precision=precision)

if self.lightning_module.trainer.training:
if self.lightning_module.trainer and self.lightning_module.trainer.training:
self._initialize_deepspeed_train(model)
else:
self._initialize_deepspeed_inference(model)
@@ -249,8 +249,7 @@ def _initialize_deepspeed_train(self, model):
)

# set optimizer for save/load, but deepspeed manages the specific optimizer logic
trainer = self.lightning_module.trainer
trainer.optimizers = [optimizer]
self.lightning_module.trainer.optimizers = [optimizer]
self.model = model

def _initialize_deepspeed_inference(self, model):
pytorch_lightning/plugins/training_type/rpc_sequential.py (4 additions, 4 deletions)
@@ -24,7 +24,7 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.distributed import LightningDistributedModule
from pytorch_lightning.plugins.training_type.rpc import DEFAULT_RPC_TIMEOUT_SEC, RPCPlugin
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -208,7 +208,7 @@ def _skip_init_connections(self):
Returns: Whether to skip initialization

"""
return torch_distrib.is_initialized() and self.lightning_module.running_stage == RunningStage.TESTING
return torch_distrib.is_initialized() and self.lightning_module.trainer.state is not TrainerState.FITTING

def init_model_parallel_groups(self):
num_model_parallel = 1 # TODO currently no support for vertical model parallel
@@ -231,7 +231,7 @@ def _infer_check_num_gpus(self):
return self.world_size

def handle_transferred_pipe_module(self) -> None:
if not self.lightning_module.running_stage == RunningStage.TESTING:
if self.lightning_module.trainer.state is TrainerState.FITTING:
torch_distrib.barrier() # Ensure we await main process initialization
# Add trainer/configure_optimizers to the pipe model for access in all worker processes
rpc_pipe.PipeModel.trainer = self.lightning_module.trainer
Expand All @@ -243,7 +243,7 @@ def init_pipe_module(self) -> None:
# Create pipe_module
model = self.lightning_module
self._find_and_init_pipe_module(model)
if not self.lightning_module.running_stage == RunningStage.TESTING:
if self.lightning_module.trainer.state is TrainerState.FITTING:
torch_distrib.barrier() # Ensure we join main process initialization
model.sequential_module.foreach_worker(register_optimizers, include_self=True)

pytorch_lightning/plugins/training_type/sharded.py (2 additions, 2 deletions)
@@ -3,6 +3,7 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import is_lightning_optimizer
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only

if _FAIRSCALE_AVAILABLE:
@@ -35,8 +36,7 @@ def _reinit_optimizers_with_oss(self):
trainer.convert_to_lightning_optimizers()

def _wrap_optimizers(self):
trainer = self.model.trainer
if trainer.testing is True:
if self.model.trainer.state is not TrainerState.FITTING:
return
self._reinit_optimizers_with_oss()

pytorch_lightning/plugins/training_type/sharded_spawn.py (2 additions, 2 deletions)
@@ -2,6 +2,7 @@

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only

if _FAIRSCALE_AVAILABLE:
@@ -31,8 +32,7 @@ def _reinit_optimizers_with_oss(self):
trainer.optimizers = optimizers

def _wrap_optimizers(self):
trainer = self.model.trainer
if trainer.testing:
if self.model.trainer.state is not TrainerState.FITTING:
return
self._reinit_optimizers_with_oss()

pytorch_lightning/plugins/training_type/tpu_spawn.py (9 additions, 5 deletions)
@@ -9,6 +9,7 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -97,7 +98,7 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
trainer.save_checkpoint = self.save_checkpoint
self.barrier()

results = trainer.train_or_test_or_predict()
results = trainer.run_stage()

self.__save_end_of_training_weights(self.lightning_module)
self.transfer_distrib_spawn_state_on_fit_end(results)
@@ -124,7 +125,11 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
# save the last weights
last_path = None
# TODO: is there a better way than accessing trainer through model -> trainer?
if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
if (
self.lightning_module.trainer.state is TrainerState.FITTING
and best_model_path is not None
and len(best_model_path) > 0
):
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
xm.save(self.lightning_module.state_dict(), last_path)

@@ -214,7 +219,7 @@ def post_dispatch(self) -> None:
# todo, pass also best score

# load last weights
if last_path and not self.lightning_module.trainer.testing:
if last_path and model.trainer.state is not TrainerState.FITTING:
ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt)

@@ -227,8 +232,7 @@ def __load_weights_on_main_process(self) -> None:
model = self.lightning_module

# load weights if not interrupted
# TODO: check for trainer reference
if on_colab_kaggle() and not model.trainer.testing:
if on_colab_kaggle() and model.trainer.state is TrainerState.FITTING:
self.load_spawn_weights(model)

self._model = model
@@ -117,9 +117,9 @@ def start_training(self, trainer: 'Trainer') -> None:
# double dispatch to initiate the training loop
self._results = trainer.run_train()

def start_testing(self, trainer: 'Trainer') -> None:
def start_evaluating(self, trainer: 'Trainer') -> None:
# double dispatch to initiate the test loop
self._results = trainer.run_test()
self._results = trainer.run_evaluate()

def start_predicting(self, trainer: 'Trainer') -> None:
# double dispatch to initiate the predicting loop
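The final hunk (its file header was not captured in this extract, though the methods match the base training-type plugin) completes the rename: `start_testing`/`run_test` become `start_evaluating`/`run_evaluate`, so both `trainer.test()` and the `trainer.validate()` added by the follow-up PR can share one evaluation entry point. A hedged sketch of where that double dispatch lands, using the import path shown earlier in this diff (the subclass is purely illustrative):

```python
from pytorch_lightning.plugins.training_type import TrainingTypePlugin


class LoggingPlugin(TrainingTypePlugin):  # illustrative subclass, not part of the PR
    def start_evaluating(self, trainer):
        # both trainer.test() and (after the follow-up PR) trainer.validate() arrive here
        print(f"evaluating, trainer.state={trainer.state}")
        return super().start_evaluating(trainer)
```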