Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use typing forward references #7770

Merged
merged 2 commits into from
May 31, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 32 additions & 36 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,42 +324,38 @@ def log(
' `https://github.com/PyTorchLightning/pytorch-lightning/discussions`'
)

if self._results is not None:
# TODO: if logged twice fail with crash

# set the default depending on the fx_name
on_step = self.__auto_choose_log_on_step(on_step)
on_epoch = self.__auto_choose_log_on_epoch(on_epoch)

assert self._current_fx_name is not None
self.trainer.logger_connector.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch)

# make sure user doesn't introduce logic for multi-dataloaders
if "/dataloader_idx_" in name:
raise MisconfigurationException(
f"Logged key: {name} should not contain information about dataloader_idx."
)

value = self.__sync(
value,
sync_fn=self.trainer.training_type_plugin.reduce,
sync_dist=sync_dist,
sync_dist_op=sync_dist_op,
sync_dist_group=sync_dist_group,
device=self.device,
)
# set the default depending on the fx_name
on_step = self.__auto_choose_log_on_step(on_step)
on_epoch = self.__auto_choose_log_on_epoch(on_epoch)

assert self._current_fx_name is not None
self.trainer.logger_connector.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch)

# make sure user doesn't introduce logic for multi-dataloaders
if "/dataloader_idx_" in name:
raise MisconfigurationException(f"Logged key: {name} should not contain information about dataloader_idx.")

value = self.__sync(
value,
sync_fn=self.trainer.training_type_plugin.reduce,
sync_dist=sync_dist,
sync_dist_op=sync_dist_op,
sync_dist_group=sync_dist_group,
device=self.device,
)

self._results.log(
name,
value,
prog_bar=prog_bar,
logger=logger,
on_step=on_step,
on_epoch=on_epoch,
reduce_fx=reduce_fx,
enable_graph=enable_graph,
dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None),
)
assert self._results is not None
self._results.log(
name,
value,
prog_bar=prog_bar,
logger=logger,
on_step=on_step,
on_epoch=on_epoch,
reduce_fx=reduce_fx,
enable_graph=enable_graph,
dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None),
)

def log_dict(
self,
Expand All @@ -378,7 +374,7 @@ def log_dict(
add_dataloader_idx: bool = True,
) -> None:
"""
Log a dictonary of values at once
Log a dictionary of values at once

Example::

Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/overrides/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin


class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):

def __init__(self, pl_module: LightningModule):
def __init__(self, pl_module: 'pl.LightningModule') -> None:
"""
Wraps the user's LightningModule and redirects the forward call to the appropriate
method, either ``training_step``, ``validation_step`` or ``test_step``.
Expand Down Expand Up @@ -66,7 +66,7 @@ def on_post_move_to_device(self):
pass


def unwrap_lightning_module(wrapped_model) -> LightningModule:
def unwrap_lightning_module(wrapped_model) -> 'pl.LightningModule':
model = wrapped_model
if isinstance(model, (DistributedDataParallel, DataParallel)):
model = model.module
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/overrides/data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import torch

from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
Expand Down Expand Up @@ -53,7 +53,7 @@ class LightningParallelModule(_LightningModuleWrapperBase):

"""

def __init__(self, pl_module: LightningModule):
def __init__(self, pl_module: 'pl.LightningModule') -> None:
super().__init__(pl_module)
_ignore_scalar_return_in_dp()

Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/overrides/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import BatchSampler, DistributedSampler, Sampler

from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase


class LightningDistributedModule(_LightningModuleWrapperBase):

def __init__(self, pl_module: LightningModule):
def __init__(self, pl_module: 'pl.LightningModule') -> None:
"""
Wraps the user's LightningModule and redirects the forward call to the appropriate
method, either ``training_step``, ``validation_step``, ``test_step`` or ``predict``.
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/precision/apex_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, amp_level: str = "O2") -> None:
def master_params(self, optimizer: Optimizer) -> _PARAMETERS:
return amp.master_params(optimizer)

def dispatch(self, trainer: "pl.Trainer") -> None:
def dispatch(self, trainer: 'pl.Trainer') -> None:
if not self._connected:
accelerator = trainer.accelerator
_, accelerator.optimizers = amp.initialize(
Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/plugins/training_type/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@

import torch

import pytorch_lightning as pl
from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
Expand Down Expand Up @@ -51,7 +51,7 @@ def remove_module_hooks(model: torch.nn.Module) -> None:

class LightningDeepSpeedModule(_LightningModuleWrapperBase):

def __init__(self, pl_module: LightningModule, precision: int):
def __init__(self, pl_module: 'pl.LightningModule', precision: int) -> None:
super().__init__(pl_module)
self.precision = precision

Expand Down Expand Up @@ -378,7 +378,7 @@ def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
return distributed_sampler_kwargs

def init_optimizers(self, trainer, model: LightningModule) -> Tuple[List, List, List]:
def init_optimizers(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> Tuple[List, List, List]:
# Skip initializing optimizers here as DeepSpeed handles optimizers via config.
# User may have specified config options instead in configure_optimizers, but this is handled
# via `_initialize_deepspeed_train`
Expand Down
7 changes: 3 additions & 4 deletions pytorch_lightning/plugins/training_type/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torch
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl
from pytorch_lightning.overrides.base import unwrap_lightning_module
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
Expand Down Expand Up @@ -99,7 +99,7 @@ def torch_distributed_backend(self):
return torch_backend

@staticmethod
def configure_sync_batchnorm(model: LightningModule) -> LightningModule:
def configure_sync_batchnorm(model: 'pl.LightningModule') -> 'pl.LightningModule':
"""
Add global batchnorm for a model spread across multiple GPUs and nodes.

Expand All @@ -112,8 +112,7 @@ def configure_sync_batchnorm(model: LightningModule) -> LightningModule:
Return:
LightningModule with batchnorm layers synchronized between process groups
"""
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
return model
return torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

@contextmanager
def block_backward_sync(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/accelerators/test_multi_nodes_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,5 +125,5 @@ def backward(self, loss, optimizer, optimizer_idx):
}

# we don't want to enable val metrics during steps because it is not something that users should do
# on purpose DO NOT allow step_b... it's silly to monitor val step metrics
# on purpose DO NOT allow b_step... it's silly to monitor val step metrics
assert set(trainer.callback_metrics) == {'a', 'a2', 'b', 'a_epoch', 'b_epoch', 'a_step'}