Commit c46ae62

Merge 45a010f into 4f391bc
EliaCereda authored Mar 6, 2021
2 parents 4f391bc + 45a010f commit c46ae62
Showing 42 changed files with 497 additions and 638 deletions.
15 changes: 15 additions & 0 deletions CHANGELOG.md
@@ -15,6 +15,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))


- Added `RunningStage.SANITY_CHECKING` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))


- Added `TrainerState.{FITTING,VALIDATING,TESTING,PREDICTING,TUNING}` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))


- Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915))


@@ -26,9 +32,18 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))


- Refactor `RunningStage` and `TrainerState` usage ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))


- Changed `trainer.evaluating` to return `True` if validating or testing ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))


### Deprecated


- Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))


### Removed

- Removed support for passing a bool value to `profiler` argument of Trainer ([#6164](https://github.com/PyTorchLightning/pytorch-lightning/pull/6164))
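The changelog entries above describe a user-facing migration: `trainer.running_sanity_check` is deprecated in favour of `trainer.sanity_checking`, explicit `TrainerState` members are added, and `trainer.evaluating` now covers both validation and testing. A minimal sketch of a callback querying the new flags — the `StageLogger` class and its messages are illustrative, not part of the library:

```python
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.trainer.states import TrainerState


class StageLogger(Callback):
    """Illustrative only: report which stage triggered this validation run."""

    def on_validation_end(self, trainer, pl_module):
        # Deprecated by this PR: trainer.running_sanity_check
        # Preferred:             trainer.sanity_checking
        if trainer.sanity_checking:
            print("validation ran as part of the sanity check")

        # `trainer.evaluating` is now True while validating or testing
        if trainer.evaluating:
            print(f"evaluating, trainer.state={trainer.state}")

        # One of the explicit states added by this PR
        if trainer.state == TrainerState.FITTING:
            print("this validation loop ran inside trainer.fit(...)")
```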
9 changes: 5 additions & 4 deletions pytorch_lightning/accelerators/accelerator.py
@@ -20,6 +20,7 @@
from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
from pytorch_lightning.plugins.training_type import TrainingTypePlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
from pytorch_lightning.utilities.enums import AMPType, LightningEnum
@@ -80,8 +81,8 @@ def setup(self, trainer: 'Trainer', model: LightningModule) -> None:
def start_training(self, trainer: 'Trainer') -> None:
self.training_type_plugin.start_training(trainer)

def start_testing(self, trainer: 'Trainer') -> None:
self.training_type_plugin.start_testing(trainer)
def start_evaluating(self, trainer: 'Trainer') -> None:
self.training_type_plugin.start_evaluating(trainer)

def start_predicting(self, trainer: 'Trainer') -> None:
self.training_type_plugin.start_predicting(trainer)
@@ -323,7 +324,7 @@ def setup_optimizers(self, trainer: 'Trainer') -> None:
trainer: the Trainer, these optimizers should be connected to
model: the model to be optimized by the created optimizers
"""
if trainer.testing:
if trainer.state not in (TrainerState.FITTING, TrainerState.TUNING):
return
optimizers, lr_schedulers, optimizer_frequencies = self.training_type_plugin.init_optimizers(
trainer=trainer, model=self.lightning_module
@@ -417,7 +418,7 @@ def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[I
@property
def results(self) -> Any:
"""
The results of the last training/testing run will be cached within the training type plugin.
The results of the last run will be cached within the training type plugin.
In distributed training, we make sure to transfer the results to the appropriate master process.
"""
return self.training_type_plugin.results
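In the accelerator, `start_testing` is renamed to `start_evaluating` and optimizer setup is now skipped whenever the trainer state is neither `FITTING` nor `TUNING`, instead of checking `trainer.testing`. A hedged sketch of a subclass observing the renamed hook — the `VerboseAccelerator` name and the print statement are hypothetical:

```python
from pytorch_lightning.accelerators import Accelerator


class VerboseAccelerator(Accelerator):
    """Hypothetical subclass used only to illustrate the renamed hook."""

    def start_evaluating(self, trainer):
        # Replaces the removed `start_testing`: dispatched for validation
        # and test runs alike, which now share the "evaluating" stage.
        print(f"start_evaluating, trainer.state={trainer.state}")
        super().start_evaluating(trainer)
```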
7 changes: 4 additions & 3 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -137,12 +137,13 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]):
self.patience = callback_state['patience']

def on_validation_end(self, trainer, pl_module):
if trainer.running_sanity_check:
from pytorch_lightning.trainer.states import TrainerState
if trainer.state != TrainerState.FITTING or trainer.sanity_checking:
return

self._run_early_stopping_check(trainer, pl_module)
self._run_early_stopping_check(trainer)

def _run_early_stopping_check(self, trainer, pl_module):
def _run_early_stopping_check(self, trainer):
"""
Checks whether the early stopping condition is met
and if so tells the trainer to stop the training.
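`EarlyStopping.on_validation_end` now returns early unless the trainer is fitting and not sanity checking, and `_run_early_stopping_check` drops its unused `pl_module` argument. A user-written callback can mirror the same guard; the sketch below is illustrative (the `ValidationLossAlert` class, the `val_loss` key, and the threshold are assumptions):

```python
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.trainer.states import TrainerState


class ValidationLossAlert(Callback):
    """Hypothetical callback mirroring the new EarlyStopping guard."""

    def on_validation_end(self, trainer, pl_module):
        # Skip outside `fit` and during the sanity check, as EarlyStopping does.
        if trainer.state != TrainerState.FITTING or trainer.sanity_checking:
            return
        val_loss = trainer.callback_metrics.get("val_loss")
        if val_loss is not None and float(val_loss) > 10.0:
            print(f"val_loss looks unusually high: {float(val_loss):.3f}")
```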
4 changes: 3 additions & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
@@ -213,12 +213,14 @@ def save_checkpoint(self, trainer, pl_module):
epoch = trainer.current_epoch
global_step = trainer.global_step

from pytorch_lightning.trainer.states import TrainerState
if (
trainer.fast_dev_run # disable checkpointing with fast_dev_run
or trainer.state != TrainerState.FITTING # don't save anything during non-fit
or trainer.sanity_checking # don't save anything during sanity check
or self.save_top_k == 0 # no models are saved
or self.period < 1 # no models are saved
or (epoch + 1) % self.period # skip epoch
or trainer.running_sanity_check # don't save anything during sanity check
or self._last_global_step_saved == global_step # already saved at the last step
):
return
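`ModelCheckpoint.save_checkpoint` now also skips saving whenever the trainer is not fitting or is sanity checking, in addition to the existing `fast_dev_run`, `save_top_k == 0`, and period checks. A small usage sketch under assumed settings (the `monitor` key, `save_top_k`, and `max_epochs` values are illustrative):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_cb = ModelCheckpoint(monitor="val_loss", save_top_k=2)
trainer = Trainer(callbacks=[checkpoint_cb], max_epochs=3)

# trainer.fit(model)   # checkpoints are written: trainer.state is FITTING
# trainer.test(model)  # nothing is written: state is TESTING, not FITTING
# the sanity-check pass at the start of fit() writes nothing either
```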
5 changes: 3 additions & 2 deletions pytorch_lightning/callbacks/progress.py
@@ -380,7 +380,6 @@ def init_test_tqdm(self) -> tqdm:
def on_sanity_check_start(self, trainer, pl_module):
super().on_sanity_check_start(trainer, pl_module)
self.val_progress_bar = self.init_sanity_tqdm()
reset(self.val_progress_bar, sum(trainer.num_sanity_val_batches))
self.main_progress_bar = tqdm(disable=True) # dummy progress bar

def on_sanity_check_end(self, trainer, pl_module):
@@ -412,7 +411,9 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data

def on_validation_start(self, trainer, pl_module):
super().on_validation_start(trainer, pl_module)
if not trainer.running_sanity_check:
if trainer.sanity_checking:
reset(self.val_progress_bar, sum(trainer.num_sanity_val_batches))
else:
self._update_bar(self.main_progress_bar) # fill up remaining
self.val_progress_bar = self.init_validation_tqdm()
reset(self.val_progress_bar, self.total_val_batches)
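The progress bar no longer resets the sanity-check bar in `on_sanity_check_start`; the reset moved into `on_validation_start`, keyed on the new `trainer.sanity_checking` flag. A sketch of a subclass that reports the sanity batch count at that point — the class name and message are illustrative:

```python
from pytorch_lightning.callbacks import ProgressBar


class ChattyProgressBar(ProgressBar):
    """Hypothetical subclass: announce the sanity check when validation starts."""

    def on_validation_start(self, trainer, pl_module):
        super().on_validation_start(trainer, pl_module)
        if trainer.sanity_checking:
            # one entry per validation dataloader, as in the diff above
            total = sum(trainer.num_sanity_val_batches)
            print(f"sanity checking over {total} validation batches")
```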
9 changes: 1 addition & 8 deletions pytorch_lightning/core/lightning.py
@@ -25,7 +25,7 @@
from argparse import Namespace
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import torch
from torch import ScriptModule, Tensor
@@ -44,8 +44,6 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args

if TYPE_CHECKING:
from pytorch_lightning.trainer.states import RunningStage
log = logging.getLogger(__name__)


@@ -69,7 +67,6 @@ class LightningModule(
"on_gpu",
"current_epoch",
"global_step",
"running_stage",
"global_rank",
"local_rank",
"logger",
@@ -172,10 +169,6 @@ def automatic_optimization(self) -> bool:
"""
return self._automatic_optimization

@property
def running_stage(self) -> Optional["RunningStage"]:
return self.trainer._running_stage if self.trainer else None

@automatic_optimization.setter
def automatic_optimization(self, automatic_optimization: bool) -> None:
self._automatic_optimization = automatic_optimization
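The `LightningModule.running_stage` property is removed, along with its entry in the property list shown above. Code that branched on `self.running_stage` can query the trainer's boolean stage flags instead; a minimal sketch under that assumption (the model, layer sizes, and the train-only noise are illustrative):

```python
import torch
from torch import nn
from pytorch_lightning import LightningModule


class StageAwareModel(LightningModule):
    """Hypothetical model: replaces the removed `self.running_stage` checks."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def forward(self, x):
        # Previously: `if self.running_stage == RunningStage.TRAINING: ...`
        if self.trainer is not None and self.trainer.training:
            x = x + 0.01 * torch.randn_like(x)  # train-only perturbation
        return self.layer(x)
```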
13 changes: 6 additions & 7 deletions pytorch_lightning/overrides/base.py
@@ -18,7 +18,6 @@
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.warnings import WarningCache

@@ -43,28 +42,28 @@ def __init__(self, pl_module: LightningModule):
self.module = pl_module

def forward(self, *inputs, **kwargs):
running_stage = self.module.running_stage
trainer = self.module.trainer

if running_stage == RunningStage.TRAINING:
if trainer and trainer.training:
output = self.module.training_step(*inputs, **kwargs)

# In manual_optimization, we need to prevent DDP reducer as
# it is done manually in ``LightningModule.manual_backward``
# `require_backward_grad_sync` will be reset in the
# ddp_plugin ``post_training_step`` hook
if not self.module.automatic_optimization:
self.module.trainer.model.require_backward_grad_sync = False
trainer.model.require_backward_grad_sync = False
warn_if_output_is_none(output, "training_step")

elif running_stage == RunningStage.TESTING:
elif trainer and trainer.testing:
output = self.module.test_step(*inputs, **kwargs)
warn_if_output_is_none(output, "test_step")

elif running_stage == RunningStage.EVALUATING:
elif trainer and (trainer.sanity_checking or trainer.validating):
output = self.module.validation_step(*inputs, **kwargs)
warn_if_output_is_none(output, "validation_step")

elif running_stage == RunningStage.PREDICTING:
elif trainer and trainer.predicting:
output = self.module.predict(*inputs, **kwargs)
warn_if_output_is_none(output, "predict")

16 changes: 9 additions & 7 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -27,6 +27,7 @@
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load
@@ -103,7 +104,7 @@ def start_training(self, trainer):
# reset optimizers, since main process is never used for training and thus does not have a valid optim state
trainer.optimizers = []

def start_testing(self, trainer):
def start_evaluating(self, trainer):
mp.spawn(self.new_process, **self.mp_spawn_kwargs)

def start_predicting(self, trainer):
@@ -152,7 +153,7 @@ def new_process(self, process_idx, trainer, mp_queue):

self.barrier()

results = trainer.train_or_test_or_predict()
results = trainer.run_stage()

# persist info in ddp_spawn
self.transfer_distrib_spawn_state_on_fit_end(results)
@@ -204,7 +205,6 @@ def on_save(self, checkpoint: dict) -> dict:
return checkpoint

def transfer_distrib_spawn_state_on_fit_end(self, results):
# TODO: is there a better way than accessing callback through model -> trainer -> callback?
checkpoint_callback = self.lightning_module.trainer.checkpoint_callback
best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None

@@ -213,8 +213,11 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):

# save the last weights
last_path = None
# TODO: is there a better way than accessing trainer through model -> trainer?
if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
if (
self.lightning_module.trainer.state == TrainerState.FITTING
and best_model_path is not None
and len(best_model_path) > 0
):
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
atomic_save(self.on_save(self.lightning_module.state_dict()), last_path)

@@ -224,14 +227,13 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
self.mp_queue.put(results)

def __recover_child_process_weights(self, best_path, last_path):
# TODO: is there a better way than accessing callback through model -> trainer -> callback?
# transfer back the best path to the trainer
if self.lightning_module.trainer.checkpoint_callback:
self.lightning_module.trainer.checkpoint_callback.best_model_path = best_path
# todo, pass also best score

# load last weights
if last_path is not None and not self.lightning_module.trainer.testing:
if last_path is not None and self.lightning_module.trainer.state == TrainerState.FITTING:
ckpt = pl_load(last_path, map_location=lambda storage, loc: storage)
self.lightning_module.load_state_dict(ckpt)

5 changes: 2 additions & 3 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -213,7 +213,7 @@ def init_deepspeed(self):
precision = self.lightning_module.trainer.accelerator.precision
model = LightningDeepSpeedModule(pl_module=self.model, precision=precision)

if self.lightning_module.trainer.training:
if self.lightning_module.trainer and self.lightning_module.trainer.training:
self._initialize_deepspeed_train(model)
else:
self._initialize_deepspeed_inference(model)
@@ -249,8 +249,7 @@ def _initialize_deepspeed_train(self, model):
)

# set optimizer for save/load, but deepspeed manages the specific optimizer logic
trainer = self.lightning_module.trainer
trainer.optimizers = [optimizer]
self.lightning_module.trainer.optimizers = [optimizer]
self.model = model

def _initialize_deepspeed_inference(self, model):
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/horovod.py
@@ -101,9 +101,9 @@ def start_training(self, trainer):
# Make sure all workers have finished training before returning to the user
hvd.join()

def start_testing(self, trainer):
def start_evaluating(self, trainer):
with ExitStack():
self._results = trainer.run_test()
self._results = trainer.run_evaluate()

# Make sure all workers have finished training before returning to the user
hvd.join()
12 changes: 6 additions & 6 deletions pytorch_lightning/plugins/training_type/rpc_sequential.py
@@ -24,7 +24,7 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.distributed import LightningDistributedModule
from pytorch_lightning.plugins.training_type.rpc import DEFAULT_RPC_TIMEOUT_SEC, RPCPlugin
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -208,7 +208,7 @@ def _skip_init_connections(self):
Returns: Whether to skip initialization
"""
return torch_distrib.is_initialized() and self.lightning_module.running_stage == RunningStage.TESTING
return torch_distrib.is_initialized() and self.lightning_module.trainer.state != TrainerState.FITTING

def init_model_parallel_groups(self):
num_model_parallel = 1 # TODO currently no support for vertical model parallel
@@ -231,7 +231,7 @@ def _infer_check_num_gpus(self):
return self.world_size

def handle_transferred_pipe_module(self) -> None:
if not self.lightning_module.running_stage == RunningStage.TESTING:
if self.lightning_module.trainer.state == TrainerState.FITTING:
torch_distrib.barrier() # Ensure we await main process initialization
# Add trainer/configure_optimizers to the pipe model for access in all worker processes
rpc_pipe.PipeModel.trainer = self.lightning_module.trainer
@@ -243,7 +243,7 @@ def init_pipe_module(self) -> None:
# Create pipe_module
model = self.lightning_module
self._find_and_init_pipe_module(model)
if not self.lightning_module.running_stage == RunningStage.TESTING:
if self.lightning_module.trainer.state == TrainerState.FITTING:
torch_distrib.barrier() # Ensure we join main process initialization
model.sequential_module.foreach_worker(register_optimizers, include_self=True)

@@ -333,9 +333,9 @@ def start_training(self, trainer) -> None:
if self.main_rpc_process:
super().start_training(trainer)

def start_testing(self, trainer) -> None:
def start_evaluating(self, trainer) -> None:
if self.main_rpc_process:
super().start_testing(trainer)
super().start_evaluating(trainer)


class LightningPipeModule(nn.Module):
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/sharded.py
@@ -16,6 +16,7 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import is_lightning_optimizer
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only

if _FAIRSCALE_AVAILABLE:
@@ -48,8 +49,7 @@ def _reinit_optimizers_with_oss(self):
trainer.convert_to_lightning_optimizers()

def _wrap_optimizers(self):
trainer = self.model.trainer
if trainer.testing is True:
if self.model.trainer.state != TrainerState.FITTING:
return
self._reinit_optimizers_with_oss()

4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -15,6 +15,7 @@

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only

if _FAIRSCALE_AVAILABLE:
@@ -44,8 +45,7 @@ def _reinit_optimizers_with_oss(self):
trainer.optimizers = optimizers

def _wrap_optimizers(self):
trainer = self.model.trainer
if trainer.testing:
if self.model.trainer.state != TrainerState.FITTING:
return
self._reinit_optimizers_with_oss()
