diff --git a/requirements/pytorch/extra.txt b/requirements/pytorch/extra.txt index 63d93643c0e8d..79fbf9eab5fdb 100644 --- a/requirements/pytorch/extra.txt +++ b/requirements/pytorch/extra.txt @@ -5,6 +5,6 @@ matplotlib>3.1, <3.6.2 omegaconf >=2.0.5, <2.4.0 hydra-core >=1.0.5, <1.4.0 -jsonargparse[signatures] >=4.18.0, <4.22.0 +jsonargparse[signatures] @ https://github.com/omni-us/jsonargparse/zipball/future-annotations rich >=12.3.0, <=13.0.1 tensorboardX >=2.2, <=2.6 # min version is set by torch.onnx missing attribute diff --git a/src/lightning/pytorch/accelerators/accelerator.py b/src/lightning/pytorch/accelerators/accelerator.py index b27df0a34de3c..f6533a600cb99 100644 --- a/src/lightning/pytorch/accelerators/accelerator.py +++ b/src/lightning/pytorch/accelerators/accelerator.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from abc import ABC -from typing import Any, Dict +from typing import Any import lightning.pytorch as pl from lightning.fabric.accelerators.accelerator import Accelerator as _Accelerator @@ -25,14 +27,14 @@ class Accelerator(_Accelerator, ABC): .. warning:: Writing your own accelerator is an :ref:`experimental ` feature. """ - def setup(self, trainer: "pl.Trainer") -> None: + def setup(self, trainer: pl.Trainer) -> None: """Setup plugins for the trainer fit and creates optimizers. Args: trainer: the trainer instance """ - def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: + def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: """Get stats for a given device. Args: diff --git a/src/lightning/pytorch/accelerators/cpu.py b/src/lightning/pytorch/accelerators/cpu.py index c12bd0afe0574..ac29e88ca17bb 100644 --- a/src/lightning/pytorch/accelerators/cpu.py +++ b/src/lightning/pytorch/accelerators/cpu.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, List, Union +from __future__ import annotations + +from typing import Any import torch from lightning_utilities.core.imports import RequirementCache @@ -35,7 +37,7 @@ def setup_device(self, device: torch.device) -> None: if device.type != "cpu": raise MisconfigurationException(f"Device should be CPU, got {device} instead.") - def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: + def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: """Get CPU stats from ``psutil`` package.""" return get_cpu_stats() @@ -43,12 +45,12 @@ def teardown(self) -> None: pass @staticmethod - def parse_devices(devices: Union[int, str, List[int]]) -> int: + def parse_devices(devices: int | str | list[int]) -> int: """Accelerator device parsing logic.""" return _parse_cpu_cores(devices) @staticmethod - def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]: + def get_parallel_devices(devices: int | str | list[int]) -> list[torch.device]: """Gets parallel devices for the Accelerator.""" devices = _parse_cpu_cores(devices) return [torch.device("cpu")] * devices @@ -80,7 +82,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No _PSUTIL_AVAILABLE = RequirementCache("psutil") -def get_cpu_stats() -> Dict[str, float]: +def get_cpu_stats() -> dict[str, float]: if not _PSUTIL_AVAILABLE: raise ModuleNotFoundError( f"Fetching CPU device stats requires `psutil` to be installed. {str(_PSUTIL_AVAILABLE)}" diff --git a/src/lightning/pytorch/accelerators/cuda.py b/src/lightning/pytorch/accelerators/cuda.py index 9161bc92f8d0c..f550a9f76f08d 100644 --- a/src/lightning/pytorch/accelerators/cuda.py +++ b/src/lightning/pytorch/accelerators/cuda.py @@ -11,11 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging import os import shutil import subprocess -from typing import Any, Dict, List, Optional, Union +from typing import Any import torch @@ -44,7 +46,7 @@ def setup_device(self, device: torch.device) -> None: _check_cuda_matmul_precision(device) torch.cuda.set_device(device) - def setup(self, trainer: "pl.Trainer") -> None: + def setup(self, trainer: pl.Trainer) -> None: # TODO refactor input from trainer to local_rank @four4fish self.set_nvidia_flags(trainer.local_rank) _clear_cuda_memory() @@ -57,7 +59,7 @@ def set_nvidia_flags(local_rank: int) -> None: devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids) _log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]") - def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: + def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: """Gets stats for the given GPU device. 
Args: @@ -76,12 +78,12 @@ def teardown(self) -> None: _clear_cuda_memory() @staticmethod - def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]: + def parse_devices(devices: int | str | list[int]) -> list[int] | None: """Accelerator device parsing logic.""" return _parse_gpu_ids(devices, include_cuda=True) @staticmethod - def get_parallel_devices(devices: List[int]) -> List[torch.device]: + def get_parallel_devices(devices: list[int]) -> list[torch.device]: """Gets parallel devices for the Accelerator.""" return [torch.device("cuda", i) for i in devices] @@ -103,7 +105,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No ) -def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]: # pragma: no-cover +def get_nvidia_gpu_stats(device: _DEVICE) -> dict[str, float]: # pragma: no-cover """Get GPU stats including memory, fan speed, and temperature from nvidia-smi. Args: diff --git a/src/lightning/pytorch/accelerators/mps.py b/src/lightning/pytorch/accelerators/mps.py index 03ba218604128..d63ad638ff077 100644 --- a/src/lightning/pytorch/accelerators/mps.py +++ b/src/lightning/pytorch/accelerators/mps.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Union +from __future__ import annotations + +from typing import Any import torch @@ -39,7 +41,7 @@ def setup_device(self, device: torch.device) -> None: if device.type != "mps": raise MisconfigurationException(f"Device should be MPS, got {device} instead.") - def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: + def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: """Get M1 (cpu + gpu) stats from ``psutil`` package.""" return get_device_stats() @@ -47,12 +49,12 @@ def teardown(self) -> None: pass @staticmethod - def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]: + def parse_devices(devices: int | str | list[int]) -> list[int] | None: """Accelerator device parsing logic.""" return _parse_gpu_ids(devices, include_mps=True) @staticmethod - def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]: + def get_parallel_devices(devices: int | str | list[int]) -> list[torch.device]: """Gets parallel devices for the Accelerator.""" parsed_devices = MPSAccelerator.parse_devices(devices) assert parsed_devices is not None @@ -84,7 +86,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No _SWAP_PERCENT = "M1_swap_percent" -def get_device_stats() -> Dict[str, float]: +def get_device_stats() -> dict[str, float]: if not _PSUTIL_AVAILABLE: raise ModuleNotFoundError( f"Fetching MPS device stats requires `psutil` to be installed. {str(_PSUTIL_AVAILABLE)}" diff --git a/src/lightning/pytorch/accelerators/xla.py b/src/lightning/pytorch/accelerators/xla.py index fe9c1261c9b39..417a2a934c44b 100644 --- a/src/lightning/pytorch/accelerators/xla.py +++ b/src/lightning/pytorch/accelerators/xla.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict +from __future__ import annotations + +from typing import Any from lightning.fabric.accelerators import _AcceleratorRegistry from lightning.fabric.accelerators.xla import XLAAccelerator as FabricXLAAccelerator @@ -25,7 +27,7 @@ class XLAAccelerator(Accelerator, FabricXLAAccelerator): .. warning:: Use of this accelerator beyond import and instantiation is experimental. """ - def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: + def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: """Gets stats for the given XLA device. Args: diff --git a/src/lightning/pytorch/callbacks/batch_size_finder.py b/src/lightning/pytorch/callbacks/batch_size_finder.py index 02b78a22a5a7f..36b4214da10af 100644 --- a/src/lightning/pytorch/callbacks/batch_size_finder.py +++ b/src/lightning/pytorch/callbacks/batch_size_finder.py @@ -18,7 +18,7 @@ Finds optimal batch size """ -from typing import Optional +from __future__ import annotations import lightning.pytorch as pl from lightning.pytorch.callbacks.callback import Callback @@ -119,7 +119,7 @@ def __init__( if mode not in self.SUPPORTED_MODES: raise ValueError(f"`mode` should be either of {self.SUPPORTED_MODES}") - self.optimal_batch_size: Optional[int] = init_val + self.optimal_batch_size: int | None = init_val self._mode = mode self._steps_per_trial = steps_per_trial self._init_val = init_val @@ -127,7 +127,7 @@ def __init__( self._batch_arg_name = batch_arg_name self._early_exit = False - def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: + def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str | None = None) -> None: if trainer._accelerator_connector.is_distributed: raise MisconfigurationException("The Batch size finder is not supported with distributed strategies.") # TODO: check if this can be enabled (#4040) @@ -167,7 +167,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O " If this is not the intended behavior, please remove either one." 
) - def scale_batch_size(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def scale_batch_size(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: new_size = _scale_batch_size( trainer, self._mode, @@ -181,17 +181,17 @@ def scale_batch_size(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule if self._early_exit: raise _TunerExitException() - def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: self.scale_batch_size(trainer, pl_module) - def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: if trainer.sanity_checking or trainer.state.fn != "validate": return self.scale_batch_size(trainer, pl_module) - def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_test_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: self.scale_batch_size(trainer, pl_module) - def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_predict_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: self.scale_batch_size(trainer, pl_module) diff --git a/src/lightning/pytorch/callbacks/callback.py b/src/lightning/pytorch/callbacks/callback.py index 5a7ac0cca0f9b..1201b74b049e5 100644 --- a/src/lightning/pytorch/callbacks/callback.py +++ b/src/lightning/pytorch/callbacks/callback.py @@ -13,7 +13,9 @@ # limitations under the License. r"""Base class used to build new callbacks.""" -from typing import Any, Dict, Optional, Type +from __future__ import annotations + +from typing import Any from torch import Tensor from torch.optim import Optimizer @@ -39,7 +41,7 @@ def state_key(self) -> str: return self.__class__.__qualname__ @property - def _legacy_state_key(self) -> Type["Callback"]: + def _legacy_state_key(self) -> type[Callback]: """State key for checkpoints saved prior to version 1.5.0.""" return type(self) @@ -52,31 +54,31 @@ def _generate_state_key(self, **kwargs: Any) -> str: """ return f"{self.__class__.__qualname__}{repr(kwargs)}" - def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: + def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None: """Called when fit, validate, test, predict, or tune begins.""" - def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: + def teardown(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None: """Called when fit, validate, test, predict, or tune ends.""" - def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """Called when fit begins.""" - def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_fit_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """Called when fit ends.""" - def on_sanity_check_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_sanity_check_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """Called when the validation sanity check starts.""" - def on_sanity_check_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_sanity_check_end(self, 
trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """Called when the validation sanity check ends.""" def on_train_batch_start( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int + self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Any, batch_idx: int ) -> None: """Called when the train batch begins.""" def on_train_batch_end( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int + self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int ) -> None: """Called when the train batch ends. @@ -85,10 +87,10 @@ def on_train_batch_end( loss returned from ``training_step``. """ - def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """Called when the train epoch begins.""" - def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """Called when the train epoch ends. To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the @@ -116,28 +118,28 @@ def on_train_epoch_end(self, trainer, pl_module): pl_module.training_step_outputs.clear() """ - def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_validation_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """Called when the val epoch begins.""" - def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """Called when the val epoch ends.""" - def on_test_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_test_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """Called when the test epoch begins.""" - def on_test_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_test_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """Called when the test epoch ends.""" - def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_predict_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """Called when the predict epoch begins.""" - def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_predict_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """Called when the predict epoch ends.""" def on_validation_batch_start( self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", + trainer: pl.Trainer, + pl_module: pl.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0, @@ -146,9 +148,9 @@ def on_validation_batch_start( def on_validation_batch_end( self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", - outputs: Optional[STEP_OUTPUT], + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: STEP_OUTPUT | None, batch: Any, batch_idx: int, dataloader_idx: int = 0, @@ -157,8 +159,8 @@ def on_validation_batch_end( def on_test_batch_start( self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", + trainer: pl.Trainer, + pl_module: 
pl.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0, @@ -167,9 +169,9 @@ def on_test_batch_start( def on_test_batch_end( self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", - outputs: Optional[STEP_OUTPUT], + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: STEP_OUTPUT | None, batch: Any, batch_idx: int, dataloader_idx: int = 0, @@ -178,8 +180,8 @@ def on_test_batch_end( def on_predict_batch_start( self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", + trainer: pl.Trainer, + pl_module: pl.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0, @@ -188,8 +190,8 @@ def on_predict_batch_start( def on_predict_batch_end( self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", + trainer: pl.Trainer, + pl_module: pl.LightningModule, outputs: Any, batch: Any, batch_idx: int, @@ -197,34 +199,34 @@ def on_predict_batch_end( ) -> None: """Called when the predict batch ends.""" - def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """Called when the train begins.""" - def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """Called when the train ends.""" - def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """Called when the validation loop begins.""" - def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """Called when the validation loop ends.""" - def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_test_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """Called when the test begins.""" - def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """Called when the test ends.""" - def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_predict_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """Called when the predict begins.""" - def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_predict_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """Called when predict ends.""" - def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None: + def on_exception(self, trainer: pl.Trainer, pl_module: pl.LightningModule, exception: BaseException) -> None: """Called when any trainer execution is interrupted by an exception.""" - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: """Called when saving a checkpoint, implement to generate callback's ``state_dict``. Returns: @@ -232,7 +234,7 @@ def state_dict(self) -> Dict[str, Any]: """ return {} - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: """Called when loading a checkpoint, implement to reload callback state given callback's ``state_dict``. 
Args: @@ -241,7 +243,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: pass def on_save_checkpoint( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] + self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: dict[str, Any] ) -> None: r"""Called when saving a checkpoint to give you a chance to store anything else you might want to save. @@ -252,7 +254,7 @@ def on_save_checkpoint( """ def on_load_checkpoint( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] + self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: dict[str, Any] ) -> None: r"""Called when loading a model checkpoint, use to reload state. @@ -262,16 +264,16 @@ def on_load_checkpoint( checkpoint: the full checkpoint dictionary that got loaded by the Trainer. """ - def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", loss: Tensor) -> None: + def on_before_backward(self, trainer: pl.Trainer, pl_module: pl.LightningModule, loss: Tensor) -> None: """Called before ``loss.backward()``.""" - def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_after_backward(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """Called after ``loss.backward()`` and before optimizers are stepped.""" def on_before_optimizer_step( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer + self, trainer: pl.Trainer, pl_module: pl.LightningModule, optimizer: Optimizer ) -> None: """Called before ``optimizer.step()``.""" - def on_before_zero_grad(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer) -> None: + def on_before_zero_grad(self, trainer: pl.Trainer, pl_module: pl.LightningModule, optimizer: Optimizer) -> None: """Called before ``optimizer.zero_grad()``.""" diff --git a/src/lightning/pytorch/callbacks/device_stats_monitor.py b/src/lightning/pytorch/callbacks/device_stats_monitor.py index 4442e73dbbc6e..834b062500f67 100644 --- a/src/lightning/pytorch/callbacks/device_stats_monitor.py +++ b/src/lightning/pytorch/callbacks/device_stats_monitor.py @@ -18,7 +18,9 @@ Monitors and logs device stats during training. """ -from typing import Any, Dict, Optional +from __future__ import annotations + +from typing import Any import lightning.pytorch as pl from lightning.pytorch.accelerators.cpu import _PSUTIL_AVAILABLE @@ -52,13 +54,13 @@ class DeviceStatsMonitor(Callback): trainer = Trainer(callbacks=[device_stats]) """ - def __init__(self, cpu_stats: Optional[bool] = None) -> None: + def __init__(self, cpu_stats: bool | None = None) -> None: self._cpu_stats = cpu_stats def setup( self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", + trainer: pl.Trainer, + pl_module: pl.LightningModule, stage: str, ) -> None: if stage != "fit": @@ -74,7 +76,7 @@ def setup( f"`DeviceStatsMonitor` cannot log CPU stats as `psutil` is not installed. 
{str(_PSUTIL_AVAILABLE)} " ) - def _get_and_log_device_stats(self, trainer: "pl.Trainer", key: str) -> None: + def _get_and_log_device_stats(self, trainer: pl.Trainer, key: str) -> None: if not trainer._logger_connector.should_update_logs: return @@ -97,19 +99,19 @@ def _get_and_log_device_stats(self, trainer: "pl.Trainer", key: str) -> None: logger.log_metrics(prefixed_device_stats, step=trainer.fit_loop.epoch_loop._batches_that_stepped) def on_train_batch_start( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int + self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Any, batch_idx: int ) -> None: self._get_and_log_device_stats(trainer, "on_train_batch_start") def on_train_batch_end( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int + self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int ) -> None: self._get_and_log_device_stats(trainer, "on_train_batch_end") def on_validation_batch_start( self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", + trainer: pl.Trainer, + pl_module: pl.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0, @@ -118,9 +120,9 @@ def on_validation_batch_start( def on_validation_batch_end( self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", - outputs: Optional[STEP_OUTPUT], + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: STEP_OUTPUT | None, batch: Any, batch_idx: int, dataloader_idx: int = 0, @@ -129,8 +131,8 @@ def on_validation_batch_end( def on_test_batch_start( self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", + trainer: pl.Trainer, + pl_module: pl.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0, @@ -139,9 +141,9 @@ def on_test_batch_start( def on_test_batch_end( self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", - outputs: Optional[STEP_OUTPUT], + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: STEP_OUTPUT | None, batch: Any, batch_idx: int, dataloader_idx: int = 0, @@ -149,5 +151,5 @@ def on_test_batch_end( self._get_and_log_device_stats(trainer, "on_test_batch_end") -def _prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]: +def _prefix_metric_keys(metrics_dict: dict[str, float], prefix: str, separator: str) -> dict[str, float]: return {prefix + separator + k: v for k, v in metrics_dict.items()} diff --git a/src/lightning/pytorch/callbacks/early_stopping.py b/src/lightning/pytorch/callbacks/early_stopping.py index d6996e408164a..d0bff4c74517f 100644 --- a/src/lightning/pytorch/callbacks/early_stopping.py +++ b/src/lightning/pytorch/callbacks/early_stopping.py @@ -15,8 +15,10 @@ Monitor a metric and stop training when it stops improving. 
""" +from __future__ import annotations + import logging -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable import torch from torch import Tensor @@ -94,9 +96,9 @@ def __init__( mode: str = "min", strict: bool = True, check_finite: bool = True, - stopping_threshold: Optional[float] = None, - divergence_threshold: Optional[float] = None, - check_on_train_epoch_end: Optional[bool] = None, + stopping_threshold: float | None = None, + divergence_threshold: float | None = None, + check_on_train_epoch_end: bool | None = None, log_rank_zero_only: bool = False, ): super().__init__() @@ -125,13 +127,13 @@ def __init__( def state_key(self) -> str: return self._generate_state_key(monitor=self.monitor, mode=self.mode) - def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: + def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None: if self._check_on_train_epoch_end is None: # if the user runs validation multiple times per training epoch or multiple training epochs without # validation, then we run after validation instead of on train epoch end self._check_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1 - def _validate_condition_metric(self, logs: Dict[str, Tensor]) -> bool: + def _validate_condition_metric(self, logs: dict[str, Tensor]) -> bool: monitor_val = logs.get(self.monitor) error_msg = ( @@ -154,7 +156,7 @@ def _validate_condition_metric(self, logs: Dict[str, Tensor]) -> bool: def monitor_op(self) -> Callable: return self.mode_dict[self.mode] - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "wait_count": self.wait_count, "stopped_epoch": self.stopped_epoch, @@ -162,28 +164,28 @@ def state_dict(self) -> Dict[str, Any]: "patience": self.patience, } - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: self.wait_count = state_dict["wait_count"] self.stopped_epoch = state_dict["stopped_epoch"] self.best_score = state_dict["best_score"] self.patience = state_dict["patience"] - def _should_skip_check(self, trainer: "pl.Trainer") -> bool: + def _should_skip_check(self, trainer: pl.Trainer) -> bool: from lightning.pytorch.trainer.states import TrainerFn return trainer.state.fn != TrainerFn.FITTING or trainer.sanity_checking - def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: if not self._check_on_train_epoch_end or self._should_skip_check(trainer): return self._run_early_stopping_check(trainer) - def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: if self._check_on_train_epoch_end or self._should_skip_check(trainer): return self._run_early_stopping_check(trainer) - def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None: + def _run_early_stopping_check(self, trainer: pl.Trainer) -> None: """Checks whether the early stopping condition is met and if so tells the trainer to stop the training.""" logs = trainer.callback_metrics @@ -203,7 +205,7 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None: if reason and self.verbose: self._log_info(trainer, reason, self.log_rank_zero_only) - def _evaluate_stopping_criteria(self, current: Tensor) -> 
Tuple[bool, Optional[str]]: + def _evaluate_stopping_criteria(self, current: Tensor) -> tuple[bool, str | None]: should_stop = False reason = None if self.check_finite and not torch.isfinite(current): @@ -254,7 +256,7 @@ def _improvement_message(self, current: Tensor) -> str: return msg @staticmethod - def _log_info(trainer: Optional["pl.Trainer"], message: str, log_rank_zero_only: bool) -> None: + def _log_info(trainer: pl.Trainer | None, message: str, log_rank_zero_only: bool) -> None: rank = _get_rank( strategy=(trainer.strategy if trainer is not None else None), # type: ignore[arg-type] ) diff --git a/src/lightning/pytorch/callbacks/finetuning.py b/src/lightning/pytorch/callbacks/finetuning.py index e4a68024ea5f4..5f7cf8e55f67a 100644 --- a/src/lightning/pytorch/callbacks/finetuning.py +++ b/src/lightning/pytorch/callbacks/finetuning.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. r"""Finetuning Callback ^^^^^^^^^^^^^^^^^^^^ Freeze and unfreeze models for finetuning purposes.""" +from __future__ import annotations + import logging -from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union +from typing import Any, Callable, Generator, Iterable import torch from torch.nn import Module, ModuleDict @@ -77,15 +79,15 @@ class BaseFinetuning(Callback): """ def __init__(self) -> None: - self._internal_optimizer_metadata: Dict[int, List[Dict[str, Any]]] = {} + self._internal_optimizer_metadata: dict[int, list[dict[str, Any]]] = {} self._restarting = False - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "internal_optimizer_metadata": self._internal_optimizer_metadata, } - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: self._restarting = True if "internal_optimizer_metadata" in state_dict: # noqa: SIM401 self._internal_optimizer_metadata = state_dict["internal_optimizer_metadata"] @@ -93,7 +95,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: # compatibility to load from old checkpoints before PR #11887 self._internal_optimizer_metadata = state_dict # type: ignore[assignment] - def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: # restore the param_groups created during the previous training. if self._restarting: named_parameters = dict(pl_module.named_parameters()) @@ -105,7 +107,7 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - self._restarting = False @staticmethod - def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]: + def flatten_modules(modules: Module | Iterable[Module | Iterable]) -> list[Module]: """This function is used to flatten a module or an iterable of modules into a list of its leaf modules (modules with no children) and parent modules that have parameters directly themselves. @@ -132,7 +134,7 @@ def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) - @staticmethod def filter_params( - modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn: bool = True, requires_grad: bool = True + modules: Module | Iterable[Module | Iterable], train_bn: bool = True, requires_grad: bool = True ) -> Generator: """Yields the `requires_grad` parameters of a given module or list of modules. 
@@ -153,7 +155,7 @@ def filter_params( yield param @staticmethod - def make_trainable(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> None: + def make_trainable(modules: Module | Iterable[Module | Iterable]) -> None: """Unfreezes the parameters of the provided modules. Args: @@ -181,7 +183,7 @@ def freeze_module(module: Module) -> None: param.requires_grad = False @staticmethod - def freeze(modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn: bool = True) -> None: + def freeze(modules: Module | Iterable[Module | Iterable], train_bn: bool = True) -> None: """Freezes the parameters of the provided modules. Args: @@ -199,7 +201,7 @@ def freeze(modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn: BaseFinetuning.freeze_module(mod) @staticmethod - def filter_on_optimizer(optimizer: Optimizer, params: Iterable) -> List: + def filter_on_optimizer(optimizer: Optimizer, params: Iterable) -> list: """This function is used to exclude any parameter which already exists in this optimizer. Args: @@ -228,9 +230,9 @@ def filter_on_optimizer(optimizer: Optimizer, params: Iterable) -> List: @staticmethod def unfreeze_and_add_param_group( - modules: Union[Module, Iterable[Union[Module, Iterable]]], + modules: Module | Iterable[Module | Iterable], optimizer: Optimizer, - lr: Optional[float] = None, + lr: float | None = None, initial_denom_lr: float = 10.0, train_bn: bool = True, ) -> None: @@ -254,11 +256,11 @@ def unfreeze_and_add_param_group( if params: optimizer.add_param_group({"params": params, "lr": params_lr / denom_lr}) - def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: + def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None: self.freeze_before_training(pl_module) @staticmethod - def _apply_mapping_to_param_groups(param_groups: List[Dict[str, Any]], mapping: dict) -> List[Dict[str, Any]]: + def _apply_mapping_to_param_groups(param_groups: list[dict[str, Any]], mapping: dict) -> list[dict[str, Any]]: output = [] for g in param_groups: # skip params to save memory @@ -269,10 +271,10 @@ def _apply_mapping_to_param_groups(param_groups: List[Dict[str, Any]], mapping: def _store( self, - pl_module: "pl.LightningModule", + pl_module: pl.LightningModule, opt_idx: int, num_param_groups: int, - current_param_groups: List[Dict[str, Any]], + current_param_groups: list[dict[str, Any]], ) -> None: mapping = {p: n for n, p in pl_module.named_parameters()} if opt_idx not in self._internal_optimizer_metadata: @@ -285,7 +287,7 @@ def _store( self._apply_mapping_to_param_groups(current_param_groups[num_param_groups:], mapping) ) - def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """Called when the epoch begins.""" for opt_idx, optimizer in enumerate(trainer.optimizers): num_param_groups = len(optimizer.param_groups) @@ -293,11 +295,11 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo current_param_groups = optimizer.param_groups self._store(pl_module, opt_idx, num_param_groups, current_param_groups) - def finetune_function(self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer) -> None: + def finetune_function(self, pl_module: pl.LightningModule, epoch: int, optimizer: Optimizer) -> None: """Override to add your unfreeze logic.""" raise NotImplementedError - def freeze_before_training(self, pl_module: 
"pl.LightningModule") -> None: + def freeze_before_training(self, pl_module: pl.LightningModule) -> None: """Override to add your freeze logic.""" raise NotImplementedError @@ -337,7 +339,7 @@ def __init__( unfreeze_backbone_at_epoch: int = 10, lambda_func: Callable = multiplicative, backbone_initial_ratio_lr: float = 10e-2, - backbone_initial_lr: Optional[float] = None, + backbone_initial_lr: float | None = None, should_align: bool = True, initial_denom_lr: float = 10.0, train_bn: bool = True, @@ -349,25 +351,25 @@ def __init__( self.unfreeze_backbone_at_epoch: int = unfreeze_backbone_at_epoch self.lambda_func: Callable = lambda_func self.backbone_initial_ratio_lr: float = backbone_initial_ratio_lr - self.backbone_initial_lr: Optional[float] = backbone_initial_lr + self.backbone_initial_lr: float | None = backbone_initial_lr self.should_align: bool = should_align self.initial_denom_lr: float = initial_denom_lr self.train_bn: bool = train_bn self.verbose: bool = verbose self.rounding: int = rounding - self.previous_backbone_lr: Optional[float] = None + self.previous_backbone_lr: float | None = None - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "internal_optimizer_metadata": self._internal_optimizer_metadata, "previous_backbone_lr": self.previous_backbone_lr, } - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: self.previous_backbone_lr = state_dict["previous_backbone_lr"] super().load_state_dict(state_dict) - def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """ Raises: MisconfigurationException: @@ -377,10 +379,10 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - return super().on_fit_start(trainer, pl_module) raise MisconfigurationException("The LightningModule should have a nn.Module `backbone` attribute") - def freeze_before_training(self, pl_module: "pl.LightningModule") -> None: + def freeze_before_training(self, pl_module: pl.LightningModule) -> None: self.freeze(pl_module.backbone) - def finetune_function(self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer) -> None: + def finetune_function(self, pl_module: pl.LightningModule, epoch: int, optimizer: Optimizer) -> None: """Called when the epoch begins.""" if epoch == self.unfreeze_backbone_at_epoch: current_lr = optimizer.param_groups[0]["lr"] diff --git a/src/lightning/pytorch/callbacks/gradient_accumulation_scheduler.py b/src/lightning/pytorch/callbacks/gradient_accumulation_scheduler.py index 9c0b1a741f53d..1f754c0a039bf 100644 --- a/src/lightning/pytorch/callbacks/gradient_accumulation_scheduler.py +++ b/src/lightning/pytorch/callbacks/gradient_accumulation_scheduler.py @@ -20,7 +20,9 @@ """ -from typing import Any, Dict +from __future__ import annotations + +from typing import Any import lightning.pytorch as pl from lightning.pytorch.callbacks.callback import Callback @@ -62,7 +64,7 @@ class GradientAccumulationScheduler(Callback): >>> trainer = Trainer(callbacks=[accumulator]) """ - def __init__(self, scheduling: Dict[int, int]): + def __init__(self, scheduling: dict[int, int]): super().__init__() if not scheduling: # empty dict error @@ -98,7 +100,7 @@ def get_accumulate_grad_batches(self, epoch: int) -> int: break return accumulate_grad_batches - def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") 
-> None: + def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """Performns a configuration validation before training starts and raises errors for incompatible settings.""" @@ -139,5 +141,5 @@ def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") " callback. Either remove `accumulate_grad_batches` from the Trainer or remove the callback." ) - def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None: + def on_train_epoch_start(self, trainer: pl.Trainer, *_: Any) -> None: trainer.accumulate_grad_batches = self.get_accumulate_grad_batches(trainer.current_epoch) diff --git a/src/lightning/pytorch/callbacks/lambda_function.py b/src/lightning/pytorch/callbacks/lambda_function.py index e062656313eab..976f178815502 100644 --- a/src/lightning/pytorch/callbacks/lambda_function.py +++ b/src/lightning/pytorch/callbacks/lambda_function.py @@ -16,7 +16,9 @@ Create a simple callback on the fly using lambda functions. """ -from typing import Callable, Optional +from __future__ import annotations + +from typing import Callable from lightning.pytorch.callbacks.callback import Callback @@ -36,43 +38,43 @@ class LambdaCallback(Callback): def __init__( self, - setup: Optional[Callable] = None, - teardown: Optional[Callable] = None, - on_fit_start: Optional[Callable] = None, - on_fit_end: Optional[Callable] = None, - on_sanity_check_start: Optional[Callable] = None, - on_sanity_check_end: Optional[Callable] = None, - on_train_batch_start: Optional[Callable] = None, - on_train_batch_end: Optional[Callable] = None, - on_train_epoch_start: Optional[Callable] = None, - on_train_epoch_end: Optional[Callable] = None, - on_validation_epoch_start: Optional[Callable] = None, - on_validation_epoch_end: Optional[Callable] = None, - on_test_epoch_start: Optional[Callable] = None, - on_test_epoch_end: Optional[Callable] = None, - on_validation_batch_start: Optional[Callable] = None, - on_validation_batch_end: Optional[Callable] = None, - on_test_batch_start: Optional[Callable] = None, - on_test_batch_end: Optional[Callable] = None, - on_train_start: Optional[Callable] = None, - on_train_end: Optional[Callable] = None, - on_validation_start: Optional[Callable] = None, - on_validation_end: Optional[Callable] = None, - on_test_start: Optional[Callable] = None, - on_test_end: Optional[Callable] = None, - on_exception: Optional[Callable] = None, - on_save_checkpoint: Optional[Callable] = None, - on_load_checkpoint: Optional[Callable] = None, - on_before_backward: Optional[Callable] = None, - on_after_backward: Optional[Callable] = None, - on_before_optimizer_step: Optional[Callable] = None, - on_before_zero_grad: Optional[Callable] = None, - on_predict_start: Optional[Callable] = None, - on_predict_end: Optional[Callable] = None, - on_predict_batch_start: Optional[Callable] = None, - on_predict_batch_end: Optional[Callable] = None, - on_predict_epoch_start: Optional[Callable] = None, - on_predict_epoch_end: Optional[Callable] = None, + setup: Callable | None = None, + teardown: Callable | None = None, + on_fit_start: Callable | None = None, + on_fit_end: Callable | None = None, + on_sanity_check_start: Callable | None = None, + on_sanity_check_end: Callable | None = None, + on_train_batch_start: Callable | None = None, + on_train_batch_end: Callable | None = None, + on_train_epoch_start: Callable | None = None, + on_train_epoch_end: Callable | None = None, + on_validation_epoch_start: Callable | None = None, + on_validation_epoch_end: Callable | None = 
None, + on_test_epoch_start: Callable | None = None, + on_test_epoch_end: Callable | None = None, + on_validation_batch_start: Callable | None = None, + on_validation_batch_end: Callable | None = None, + on_test_batch_start: Callable | None = None, + on_test_batch_end: Callable | None = None, + on_train_start: Callable | None = None, + on_train_end: Callable | None = None, + on_validation_start: Callable | None = None, + on_validation_end: Callable | None = None, + on_test_start: Callable | None = None, + on_test_end: Callable | None = None, + on_exception: Callable | None = None, + on_save_checkpoint: Callable | None = None, + on_load_checkpoint: Callable | None = None, + on_before_backward: Callable | None = None, + on_after_backward: Callable | None = None, + on_before_optimizer_step: Callable | None = None, + on_before_zero_grad: Callable | None = None, + on_predict_start: Callable | None = None, + on_predict_end: Callable | None = None, + on_predict_batch_start: Callable | None = None, + on_predict_batch_end: Callable | None = None, + on_predict_epoch_start: Callable | None = None, + on_predict_epoch_end: Callable | None = None, ): for k, v in locals().items(): if k == "self": diff --git a/src/lightning/pytorch/callbacks/lr_finder.py b/src/lightning/pytorch/callbacks/lr_finder.py index c594d4af4542e..28294b6ba01d6 100644 --- a/src/lightning/pytorch/callbacks/lr_finder.py +++ b/src/lightning/pytorch/callbacks/lr_finder.py @@ -17,7 +17,7 @@ Finds optimal learning rate """ -from typing import Optional +from __future__ import annotations import lightning.pytorch as pl from lightning.pytorch.callbacks.callback import Callback @@ -85,7 +85,7 @@ def __init__( max_lr: float = 1, num_training_steps: int = 100, mode: str = "exponential", - early_stop_threshold: Optional[float] = 4.0, + early_stop_threshold: float | None = 4.0, update_attr: bool = True, attr_name: str = "", ) -> None: @@ -102,9 +102,9 @@ def __init__( self._attr_name = attr_name self._early_exit = False - self.lr_finder: Optional[_LRFinder] = None + self.lr_finder: _LRFinder | None = None - def lr_find(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def lr_find(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: with isolate_rng(): self.optimal_lr = _lr_find( trainer, @@ -121,5 +121,5 @@ def lr_find(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Non if self._early_exit: raise _TunerExitException() - def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: self.lr_find(trainer, pl_module) diff --git a/src/lightning/pytorch/callbacks/lr_monitor.py b/src/lightning/pytorch/callbacks/lr_monitor.py index d938db61d6d09..794b7204c2f80 100644 --- a/src/lightning/pytorch/callbacks/lr_monitor.py +++ b/src/lightning/pytorch/callbacks/lr_monitor.py @@ -19,9 +19,11 @@ Monitor and logs learning rate for lr schedulers during training. 
""" +from __future__ import annotations + import itertools from collections import defaultdict -from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple, Type +from typing import Any, DefaultDict import torch from torch.optim.optimizer import Optimizer @@ -86,15 +88,15 @@ def configure_optimizer(self): return [optimizer], [lr_scheduler] """ - def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = False) -> None: + def __init__(self, logging_interval: str | None = None, log_momentum: bool = False) -> None: if logging_interval not in (None, "step", "epoch"): raise MisconfigurationException("logging_interval should be `step` or `epoch` or `None`.") self.logging_interval = logging_interval self.log_momentum = log_momentum - self.lrs: Dict[str, List[float]] = {} + self.lrs: dict[str, list[float]] = {} - def on_train_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: + def on_train_start(self, trainer: pl.Trainer, *args: Any, **kwargs: Any) -> None: """Called before training, determines unique names for all lr schedulers in the case of multiple of the same type or in the case of multiple parameter groups. @@ -125,7 +127,7 @@ def _check_no_key(key: str) -> bool: ) # Find names for schedulers - names: List[List[str]] = [] + names: list[list[str]] = [] ( sched_hparam_keys, optimizers_with_scheduler, @@ -146,7 +148,7 @@ def _check_no_key(key: str) -> bool: self.lrs = {name: [] for name in names_flatten} self.last_momentum_values = {name + "-momentum": None for name in names_flatten} - def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: + def on_train_batch_start(self, trainer: pl.Trainer, *args: Any, **kwargs: Any) -> None: if not trainer._logger_connector.should_update_logs: return @@ -158,7 +160,7 @@ def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) for logger in trainer.loggers: logger.log_metrics(latest_stat, step=trainer.fit_loop.epoch_loop._batches_that_stepped) - def on_train_epoch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: + def on_train_epoch_start(self, trainer: pl.Trainer, *args: Any, **kwargs: Any) -> None: if self.logging_interval != "step": interval = "epoch" if self.logging_interval is None else "any" latest_stat = self._extract_stats(trainer, interval) @@ -167,7 +169,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) for logger in trainer.loggers: logger.log_metrics(latest_stat, step=trainer.fit_loop.epoch_loop._batches_that_stepped) - def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, float]: + def _extract_stats(self, trainer: pl.Trainer, interval: str) -> dict[str, float]: latest_stat = {} ( @@ -200,7 +202,7 @@ def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, floa return latest_stat - def _get_lr_momentum_stat(self, optimizer: Optimizer, names: List[str]) -> Dict[str, float]: + def _get_lr_momentum_stat(self, optimizer: Optimizer, names: list[str]) -> dict[str, float]: lr_momentum_stat = {} param_groups = optimizer.param_groups use_betas = "betas" in optimizer.defaults @@ -215,12 +217,12 @@ def _get_lr_momentum_stat(self, optimizer: Optimizer, names: List[str]) -> Dict[ return lr_momentum_stat - def _extract_lr(self, param_group: Dict[str, Any], name: str) -> Dict[str, Any]: + def _extract_lr(self, param_group: dict[str, Any], name: str) -> dict[str, Any]: lr = param_group["lr"] self.lrs[name].append(lr) return {name: lr} - def 
_remap_keys(self, names: List[List[str]], token: str = "/pg1") -> None: + def _remap_keys(self, names: list[list[str]], token: str = "/pg1") -> None: """This function is used the remap the keys if param groups for a given optimizer increased.""" for group_new_names in names: for new_name in group_new_names: @@ -230,7 +232,7 @@ def _remap_keys(self, names: List[List[str]], token: str = "/pg1") -> None: elif new_name not in self.lrs: self.lrs[new_name] = [] - def _extract_momentum(self, param_group: Dict[str, List], name: str, use_betas: bool) -> Dict[str, float]: + def _extract_momentum(self, param_group: dict[str, list], name: str, use_betas: bool) -> dict[str, float]: if not self.log_momentum: return {} @@ -239,14 +241,14 @@ def _extract_momentum(self, param_group: Dict[str, List], name: str, use_betas: return {name: momentum} def _add_prefix( - self, name: str, optimizer_cls: Type[Optimizer], seen_optimizer_types: DefaultDict[Type[Optimizer], int] + self, name: str, optimizer_cls: type[Optimizer], seen_optimizer_types: DefaultDict[type[Optimizer], int] ) -> str: if optimizer_cls not in seen_optimizer_types: return name count = seen_optimizer_types[optimizer_cls] return name + f"-{count - 1}" if count > 1 else name - def _add_suffix(self, name: str, param_groups: List[Dict], param_group_index: int, use_names: bool = True) -> str: + def _add_suffix(self, name: str, param_groups: list[dict], param_group_index: int, use_names: bool = True) -> str: if len(param_groups) > 1: if not use_names: return f"{name}/pg{param_group_index+1}" @@ -257,7 +259,7 @@ def _add_suffix(self, name: str, param_groups: List[Dict], param_group_index: in return f"{name}/{pg_name}" if pg_name else name return name - def _duplicate_param_group_names(self, param_groups: List[Dict]) -> Set[str]: + def _duplicate_param_group_names(self, param_groups: list[dict]) -> set[str]: names = [pg.get("name", f"pg{i}") for i, pg in enumerate(param_groups, start=1)] unique = set(names) if len(names) == len(unique): @@ -266,13 +268,13 @@ def _duplicate_param_group_names(self, param_groups: List[Dict]) -> Set[str]: def _find_names_from_schedulers( self, - lr_scheduler_configs: List[LRSchedulerConfig], - ) -> Tuple[List[List[str]], List[Optimizer], DefaultDict[Type[Optimizer], int]]: + lr_scheduler_configs: list[LRSchedulerConfig], + ) -> tuple[list[list[str]], list[Optimizer], DefaultDict[type[Optimizer], int]]: # Create unique names in the case we have multiple of the same learning # rate scheduler + multiple parameter groups names = [] - seen_optimizers: List[Optimizer] = [] - seen_optimizer_types: DefaultDict[Type[Optimizer], int] = defaultdict(int) + seen_optimizers: list[Optimizer] = [] + seen_optimizer_types: DefaultDict[type[Optimizer], int] = defaultdict(int) for config in lr_scheduler_configs: sch = config.scheduler name = config.name if config.name is not None else "lr-" + sch.optimizer.__class__.__name__ @@ -286,10 +288,10 @@ def _find_names_from_schedulers( def _find_names_from_optimizers( self, - optimizers: List[Any], - seen_optimizers: List[Optimizer], - seen_optimizer_types: DefaultDict[Type[Optimizer], int], - ) -> Tuple[List[List[str]], List[Optimizer]]: + optimizers: list[Any], + seen_optimizers: list[Optimizer], + seen_optimizer_types: DefaultDict[type[Optimizer], int], + ) -> tuple[list[list[str]], list[Optimizer]]: names = [] optimizers_without_scheduler = [] @@ -312,10 +314,10 @@ def _check_duplicates_and_update_name( self, optimizer: Optimizer, name: str, - seen_optimizers: List[Optimizer], - 
seen_optimizer_types: DefaultDict[Type[Optimizer], int], - lr_scheduler_config: Optional[LRSchedulerConfig], - ) -> List[str]: + seen_optimizers: list[Optimizer], + seen_optimizer_types: DefaultDict[type[Optimizer], int], + lr_scheduler_config: LRSchedulerConfig | None, + ) -> list[str]: seen_optimizers.append(optimizer) optimizer_cls = type(optimizer) if lr_scheduler_config is None or lr_scheduler_config.name is None: diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 40e5cfc008bae..9dbceab2d4f6a 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -17,6 +17,8 @@ Automatically save model checkpoints during training. """ +from __future__ import annotations + import logging import os import re @@ -24,7 +26,7 @@ import warnings from copy import deepcopy from datetime import timedelta -from typing import Any, Dict, Optional, Set +from typing import Any from weakref import proxy import torch @@ -207,19 +209,19 @@ class ModelCheckpoint(Checkpoint): def __init__( self, - dirpath: Optional[_PATH] = None, - filename: Optional[str] = None, - monitor: Optional[str] = None, + dirpath: _PATH | None = None, + filename: str | None = None, + monitor: str | None = None, verbose: bool = False, - save_last: Optional[bool] = None, + save_last: bool | None = None, save_top_k: int = 1, save_weights_only: bool = False, mode: str = "min", auto_insert_metric_name: bool = True, - every_n_train_steps: Optional[int] = None, - train_time_interval: Optional[timedelta] = None, - every_n_epochs: Optional[int] = None, - save_on_train_epoch_end: Optional[bool] = None, + every_n_train_steps: int | None = None, + train_time_interval: timedelta | None = None, + every_n_epochs: int | None = None, + save_on_train_epoch_end: bool | None = None, enable_version_counter: bool = True, ): super().__init__() @@ -232,16 +234,16 @@ def __init__( self._save_on_train_epoch_end = save_on_train_epoch_end self._enable_version_counter = enable_version_counter self._last_global_step_saved = 0 # no need to save when no steps were taken - self._last_time_checked: Optional[float] = None - self.current_score: Optional[Tensor] = None - self.best_k_models: Dict[str, Tensor] = {} + self._last_time_checked: float | None = None + self.current_score: Tensor | None = None + self.best_k_models: dict[str, Tensor] = {} self.kth_best_model_path = "" - self.best_model_score: Optional[Tensor] = None + self.best_model_score: Tensor | None = None self.best_model_path = "" self.last_model_path = "" self.kth_value: Tensor - self.dirpath: Optional[_PATH] + self.dirpath: _PATH | None self.__init_monitor_mode(mode) self.__init_ckpt_dir(dirpath, filename) self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval) @@ -257,20 +259,20 @@ def state_key(self) -> str: train_time_interval=self._train_time_interval, ) - def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: + def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None: dirpath = self.__resolve_ckpt_dir(trainer) dirpath = trainer.strategy.broadcast(dirpath) self.dirpath = dirpath if trainer.is_global_zero and stage == "fit": self.__warn_if_dir_not_empty(self.dirpath) - def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: self._last_time_checked = 
time.monotonic() def on_train_batch_end( self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", + trainer: pl.Trainer, + pl_module: pl.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, @@ -299,7 +301,7 @@ def on_train_batch_end( self._save_topk_checkpoint(trainer, monitor_candidates) self._save_last_checkpoint(trainer, monitor_candidates) - def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """Save a checkpoint at the end of the training epoch.""" if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer): monitor_candidates = self._monitor_candidates(trainer) @@ -307,7 +309,7 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu self._save_topk_checkpoint(trainer, monitor_candidates) self._save_last_checkpoint(trainer, monitor_candidates) - def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: """Save a checkpoint at the end of the validation stage.""" if not self._should_skip_saving_checkpoint(trainer) and not self._should_save_on_train_epoch_end(trainer): monitor_candidates = self._monitor_candidates(trainer) @@ -315,7 +317,7 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul self._save_topk_checkpoint(trainer, monitor_candidates) self._save_last_checkpoint(trainer, monitor_candidates) - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "monitor": self.monitor, "best_model_score": self.best_model_score, @@ -328,7 +330,7 @@ def state_dict(self) -> Dict[str, Any]: "last_model_path": self.last_model_path, } - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: dirpath_from_ckpt = state_dict.get("dirpath", self.dirpath) if self.dirpath == dirpath_from_ckpt: @@ -346,7 +348,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self.best_model_path = state_dict["best_model_path"] - def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None: + def _save_topk_checkpoint(self, trainer: pl.Trainer, monitor_candidates: dict[str, Tensor]) -> None: if self.save_top_k == 0: return @@ -365,7 +367,7 @@ def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[ else: self._save_none_monitor_checkpoint(trainer, monitor_candidates) - def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None: + def _save_checkpoint(self, trainer: pl.Trainer, filepath: str) -> None: trainer.save_checkpoint(filepath, self.save_weights_only) self._last_global_step_saved = trainer.global_step @@ -375,7 +377,7 @@ def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None: for logger in trainer.loggers: logger.after_save_checkpoint(proxy(self)) - def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool: + def _should_skip_saving_checkpoint(self, trainer: pl.Trainer) -> bool: from lightning.pytorch.trainer.states import TrainerFn return ( @@ -385,7 +387,7 @@ def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool: or self._last_global_step_saved == trainer.global_step # already saved at the last step ) - def _should_save_on_train_epoch_end(self, trainer: "pl.Trainer") -> bool: + def 
_should_save_on_train_epoch_end(self, trainer: pl.Trainer) -> bool: if self._save_on_train_epoch_end is not None: return self._save_on_train_epoch_end @@ -439,7 +441,7 @@ def __validate_init_configuration(self) -> None: " will duplicate the last checkpoint saved." ) - def __init_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) -> None: + def __init_ckpt_dir(self, dirpath: _PATH | None, filename: str | None) -> None: self._fs = get_filesystem(dirpath if dirpath else "") if dirpath and self._fs.protocol == "file": @@ -459,9 +461,9 @@ def __init_monitor_mode(self, mode: str) -> None: def __init_triggers( self, - every_n_train_steps: Optional[int], - every_n_epochs: Optional[int], - train_time_interval: Optional[timedelta], + every_n_train_steps: int | None, + every_n_epochs: int | None, + train_time_interval: timedelta | None, ) -> None: # Default to running once after each validation epoch if neither # every_n_train_steps nor every_n_epochs is set @@ -473,15 +475,15 @@ def __init_triggers( every_n_epochs = every_n_epochs or 0 every_n_train_steps = every_n_train_steps or 0 - self._train_time_interval: Optional[timedelta] = train_time_interval + self._train_time_interval: timedelta | None = train_time_interval self._every_n_epochs: int = every_n_epochs self._every_n_train_steps: int = every_n_train_steps @property - def every_n_epochs(self) -> Optional[int]: + def every_n_epochs(self) -> int | None: return self._every_n_epochs - def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] = None) -> bool: + def check_monitor_top_k(self, trainer: pl.Trainer, current: Tensor | None = None) -> bool: if current is None: return False @@ -503,8 +505,8 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] = @classmethod def _format_checkpoint_name( cls, - filename: Optional[str], - metrics: Dict[str, Tensor], + filename: str | None, + metrics: dict[str, Tensor], prefix: str = "", auto_insert_metric_name: bool = True, ) -> str: @@ -538,7 +540,7 @@ def _format_checkpoint_name( return filename def format_checkpoint_name( - self, metrics: Dict[str, Tensor], filename: Optional[str] = None, ver: Optional[int] = None + self, metrics: dict[str, Tensor], filename: str | None = None, ver: int | None = None ) -> str: """Generate a filename according to the defined template. @@ -577,7 +579,7 @@ def format_checkpoint_name( ckpt_name = f"{filename}{self.FILE_EXTENSION}" return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name - def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> _PATH: + def __resolve_ckpt_dir(self, trainer: pl.Trainer) -> _PATH: """Determines model checkpoint save directory at runtime. Reference attributes from the trainer's logger to determine where to save checkpoints. 
The path for saving weights is set in this priority: @@ -606,7 +608,7 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> _PATH: return ckpt_path - def _find_last_checkpoints(self, trainer: "pl.Trainer") -> Set[str]: + def _find_last_checkpoints(self, trainer: pl.Trainer) -> set[str]: # find all checkpoints in the folder ckpt_path = self.__resolve_ckpt_dir(trainer) if self._fs.exists(ckpt_path): @@ -622,7 +624,7 @@ def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None: rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.") def _get_metric_interpolated_filepath_name( - self, monitor_candidates: Dict[str, Tensor], trainer: "pl.Trainer", del_filepath: Optional[str] = None + self, monitor_candidates: dict[str, Tensor], trainer: pl.Trainer, del_filepath: str | None = None ) -> str: filepath = self.format_checkpoint_name(monitor_candidates) @@ -634,7 +636,7 @@ def _get_metric_interpolated_filepath_name( return filepath - def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, Tensor]: + def _monitor_candidates(self, trainer: pl.Trainer) -> dict[str, Tensor]: monitor_candidates = deepcopy(trainer.callback_metrics) # cast to int if necessary because `self.log("epoch", 123)` will convert it to float. if it's not a tensor # or does not exist we overwrite it as it's likely an error @@ -644,7 +646,7 @@ def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, Tensor]: monitor_candidates["step"] = step.int() if isinstance(step, Tensor) else torch.tensor(trainer.global_step) return monitor_candidates - def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None: + def _save_last_checkpoint(self, trainer: pl.Trainer, monitor_candidates: dict[str, Tensor]) -> None: if not self.save_last: return @@ -662,7 +664,7 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[ if previous and previous != filepath: self._remove_checkpoint(trainer, previous) - def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None: + def _save_monitor_checkpoint(self, trainer: pl.Trainer, monitor_candidates: dict[str, Tensor]) -> None: assert self.monitor current = monitor_candidates.get(self.monitor) if self.check_monitor_top_k(trainer, current): @@ -673,7 +675,7 @@ def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Di step = monitor_candidates["step"] rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} was not in top {self.save_top_k}") - def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None: + def _save_none_monitor_checkpoint(self, trainer: pl.Trainer, monitor_candidates: dict[str, Tensor]) -> None: filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer) # set the best model path before saving because it will be part of the state. 
previous, self.best_model_path = self.best_model_path, filepath @@ -682,7 +684,7 @@ def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidate self._remove_checkpoint(trainer, previous) def _update_best_and_save( - self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor] + self, current: Tensor, trainer: pl.Trainer, monitor_candidates: dict[str, Tensor] ) -> None: k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k @@ -723,7 +725,7 @@ def _update_best_and_save( if del_filepath is not None and filepath != del_filepath: self._remove_checkpoint(trainer, del_filepath) - def to_yaml(self, filepath: Optional[_PATH] = None) -> None: + def to_yaml(self, filepath: _PATH | None = None) -> None: """Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML file.""" best_k = {k: v.item() for k, v in self.best_k_models.items()} @@ -733,12 +735,12 @@ def to_yaml(self, filepath: Optional[_PATH] = None) -> None: with self._fs.open(filepath, "w") as fp: yaml.dump(best_k, fp) - def file_exists(self, filepath: _PATH, trainer: "pl.Trainer") -> bool: + def file_exists(self, filepath: _PATH, trainer: pl.Trainer) -> bool: """Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal state to diverge between ranks.""" exists = self._fs.exists(filepath) return trainer.strategy.broadcast(exists) - def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None: + def _remove_checkpoint(self, trainer: pl.Trainer, filepath: str) -> None: """Calls the strategy to remove the checkpoint file.""" trainer.strategy.remove_checkpoint(filepath) diff --git a/src/lightning/pytorch/callbacks/model_summary.py b/src/lightning/pytorch/callbacks/model_summary.py index 4fc788ba09226..52a4688930fe3 100644 --- a/src/lightning/pytorch/callbacks/model_summary.py +++ b/src/lightning/pytorch/callbacks/model_summary.py @@ -21,8 +21,10 @@ the name, type and number of parameters for each layer. 
""" +from __future__ import annotations + import logging -from typing import Any, Dict, List, Tuple, Union +from typing import Any import lightning.pytorch as pl from lightning.pytorch.callbacks.callback import Callback @@ -51,9 +53,9 @@ class ModelSummary(Callback): def __init__(self, max_depth: int = 1, **summarize_kwargs: Any) -> None: self._max_depth: int = max_depth - self._summarize_kwargs: Dict[str, Any] = summarize_kwargs + self._summarize_kwargs: dict[str, Any] = summarize_kwargs - def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: if not self._max_depth: return @@ -66,7 +68,7 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - if trainer.is_global_zero: self.summarize(summary_data, total_parameters, trainable_parameters, model_size, **self._summarize_kwargs) - def _summary(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Union[DeepSpeedSummary, Summary]: + def _summary(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> DeepSpeedSummary | Summary: from lightning.pytorch.strategies.deepspeed import DeepSpeedStrategy if isinstance(trainer.strategy, DeepSpeedStrategy) and trainer.strategy.zero_stage_3: @@ -75,7 +77,7 @@ def _summary(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Un @staticmethod def summarize( - summary_data: List[Tuple[str, List[str]]], + summary_data: list[tuple[str, list[str]]], total_parameters: int, trainable_parameters: int, model_size: float, diff --git a/src/lightning/pytorch/callbacks/on_exception_checkpoint.py b/src/lightning/pytorch/callbacks/on_exception_checkpoint.py index 760e774a256e8..b260d0ed80641 100644 --- a/src/lightning/pytorch/callbacks/on_exception_checkpoint.py +++ b/src/lightning/pytorch/callbacks/on_exception_checkpoint.py @@ -17,6 +17,8 @@ Automatically save a checkpoints on exception. 
""" +from __future__ import annotations + import os from typing import Any @@ -57,9 +59,9 @@ def __init__(self, dirpath: _PATH, filename: str = "on_exception") -> None: def ckpt_path(self) -> str: return os.path.join(self.dirpath, self.filename + self.FILE_EXTENSION) - def on_exception(self, trainer: "pl.Trainer", *_: Any, **__: Any) -> None: + def on_exception(self, trainer: pl.Trainer, *_: Any, **__: Any) -> None: # overwrite if necessary trainer.save_checkpoint(self.ckpt_path) - def teardown(self, trainer: "pl.Trainer", *_: Any, **__: Any) -> None: + def teardown(self, trainer: pl.Trainer, *_: Any, **__: Any) -> None: trainer.strategy.remove_checkpoint(self.ckpt_path) diff --git a/src/lightning/pytorch/callbacks/prediction_writer.py b/src/lightning/pytorch/callbacks/prediction_writer.py index 0f19c771027d5..1b1923c04964a 100644 --- a/src/lightning/pytorch/callbacks/prediction_writer.py +++ b/src/lightning/pytorch/callbacks/prediction_writer.py @@ -17,7 +17,9 @@ Aids in saving predictions """ -from typing import Any, Literal, Optional, Sequence +from __future__ import annotations + +from typing import Any, Literal, Sequence import lightning.pytorch as pl from lightning.pytorch.callbacks.callback import Callback @@ -108,10 +110,10 @@ def __init__(self, write_interval: Literal["batch", "epoch", "batch_and_epoch"] def write_on_batch_end( self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", + trainer: pl.Trainer, + pl_module: pl.LightningModule, prediction: Any, - batch_indices: Optional[Sequence[int]], + batch_indices: Sequence[int] | None, batch: Any, batch_idx: int, dataloader_idx: int, @@ -121,18 +123,18 @@ def write_on_batch_end( def write_on_epoch_end( self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", + trainer: pl.Trainer, + pl_module: pl.LightningModule, predictions: Sequence[Any], - batch_indices: Optional[Sequence[Any]], + batch_indices: Sequence[Any] | None, ) -> None: """Override with the logic to write all batches.""" raise NotImplementedError() def on_predict_batch_end( self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", + trainer: pl.Trainer, + pl_module: pl.LightningModule, outputs: Any, batch: Any, batch_idx: int, @@ -143,7 +145,7 @@ def on_predict_batch_end( batch_indices = trainer.predict_loop.current_batch_indices self.write_on_batch_end(trainer, pl_module, outputs, batch_indices, batch, batch_idx, dataloader_idx) - def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_predict_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: if not self.interval.on_epoch: return epoch_batch_indices = trainer.predict_loop.epoch_batch_indices diff --git a/src/lightning/pytorch/callbacks/progress/progress_bar.py b/src/lightning/pytorch/callbacks/progress/progress_bar.py index 5a0adbf74163a..5ba8bf41419c0 100644 --- a/src/lightning/pytorch/callbacks/progress/progress_bar.py +++ b/src/lightning/pytorch/callbacks/progress/progress_bar.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, Optional, Union +from __future__ import annotations + +from typing import Any import lightning.pytorch as pl from lightning.pytorch.callbacks import Callback @@ -45,11 +47,11 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch_idx): """ def __init__(self) -> None: - self._trainer: Optional["pl.Trainer"] = None - self._current_eval_dataloader_idx: Optional[int] = None + self._trainer: pl.Trainer | None = None + self._current_eval_dataloader_idx: int | None = None @property - def trainer(self) -> "pl.Trainer": + def trainer(self) -> pl.Trainer: if self._trainer is None: raise TypeError(f"The `{self.__class__.__name__}._trainer` reference has not been set yet.") return self._trainer @@ -75,7 +77,7 @@ def predict_description(self) -> str: return "Predicting" @property - def total_train_batches(self) -> Union[int, float]: + def total_train_batches(self) -> int | float: """The total number of training batches, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the training @@ -84,7 +86,7 @@ def total_train_batches(self) -> Union[int, float]: return self.trainer.num_training_batches @property - def total_val_batches_current_dataloader(self) -> Union[int, float]: + def total_val_batches_current_dataloader(self) -> int | float: """The total number of validation batches, which may change from epoch to epoch for current dataloader. Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the validation @@ -97,7 +99,7 @@ def total_val_batches_current_dataloader(self) -> Union[int, float]: return batches @property - def total_test_batches_current_dataloader(self) -> Union[int, float]: + def total_test_batches_current_dataloader(self) -> int | float: """The total number of testing batches, which may change from epoch to epoch for current dataloader. Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the test dataloader is @@ -110,7 +112,7 @@ def total_test_batches_current_dataloader(self) -> Union[int, float]: return batches @property - def total_predict_batches_current_dataloader(self) -> Union[int, float]: + def total_predict_batches_current_dataloader(self) -> int | float: """The total number of prediction batches, which may change from epoch to epoch for current dataloader. Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the predict dataloader @@ -120,7 +122,7 @@ def total_predict_batches_current_dataloader(self) -> Union[int, float]: return self.trainer.num_predict_batches[self._current_eval_dataloader_idx] @property - def total_val_batches(self) -> Union[int, float]: + def total_val_batches(self) -> int | float: """The total number of validation batches, which may change from epoch to epoch for all val dataloaders. Use this to set the total number of iterations in the progress bar. 
Can return ``inf`` if the predict dataloader @@ -159,14 +161,14 @@ def print(self, *args: Any, **kwargs: Any) -> None: """You should provide a way to print without breaking the progress bar.""" print(*args, **kwargs) - def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: + def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None: self._trainer = trainer if not trainer.is_global_zero: self.disable() def get_metrics( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" - ) -> Dict[str, Union[int, str, float, Dict[str, float]]]: + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ) -> dict[str, int | str | float | dict[str, float]]: r"""Combines progress bar metrics collected from the trainer with standard metrics from get_standard_metrics. Implement this to override the items displayed in the progress bar. @@ -196,7 +198,7 @@ def get_metrics(self, trainer, model): return {**standard_metrics, **pbar_metrics} -def get_standard_metrics(trainer: "pl.Trainer") -> Dict[str, Union[int, str]]: +def get_standard_metrics(trainer: pl.Trainer) -> dict[str, int | str]: r"""Returns the standard metrics displayed in the progress bar. Currently, it only includes the version of the experiment when using a logger. @@ -207,7 +209,7 @@ def get_standard_metrics(trainer: "pl.Trainer") -> Dict[str, Union[int, str]]: Return: Dictionary with the standard metrics to be displayed in the progress bar. """ - items_dict: Dict[str, Union[int, str]] = {} + items_dict: dict[str, int | str] = {} if trainer.loggers: from lightning.pytorch.loggers.utilities import _version diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 26e53c121573a..c070341423cec 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import math from dataclasses import dataclass from datetime import timedelta -from typing import Any, cast, Dict, Optional, Union +from typing import Any, cast from lightning_utilities.core.imports import RequirementCache @@ -36,7 +38,7 @@ class CustomBarColumn(BarColumn): """Overrides ``BarColumn`` to provide support for dataloaders that do not define a size (infinite size) such as ``IterableDataset``.""" - def render(self, task: "Task") -> _RichProgressBar: + def render(self, task: Task) -> _RichProgressBar: """Gets a progress bar widget for a task.""" assert task.total is not None assert task.remaining is not None @@ -60,7 +62,7 @@ class CustomInfiniteTask(Task): """ @property - def time_remaining(self) -> Optional[float]: + def time_remaining(self) -> float | None: return None class CustomProgress(Progress): @@ -70,7 +72,7 @@ def add_task( self, description: str, start: bool = True, - total: Optional[float] = 100.0, + total: float | None = 100.0, completed: int = 0, visible: bool = True, **fields: Any, @@ -104,11 +106,11 @@ class CustomTimeColumn(ProgressColumn): # Only refresh twice a second to prevent jitter max_refresh = 0.5 - def __init__(self, style: Union[str, Style]) -> None: + def __init__(self, style: str | Style) -> None: self.style = style super().__init__() - def render(self, task: "Task") -> Text: + def render(self, task: Task) -> Text: elapsed = task.finished_time if task.finished else task.elapsed remaining = task.time_remaining elapsed_delta = "-:--:--" if elapsed is None else str(timedelta(seconds=int(elapsed))) @@ -116,41 +118,41 @@ def render(self, task: "Task") -> Text: return Text(f"{elapsed_delta} • {remaining_delta}", style=self.style) class BatchesProcessedColumn(ProgressColumn): - def __init__(self, style: Union[str, Style]): + def __init__(self, style: str | Style): self.style = style super().__init__() - def render(self, task: "Task") -> RenderableType: + def render(self, task: Task) -> RenderableType: total = task.total if task.total != float("inf") else "--" return Text(f"{int(task.completed)}/{total}", style=self.style) class ProcessingSpeedColumn(ProgressColumn): - def __init__(self, style: Union[str, Style]): + def __init__(self, style: str | Style): self.style = style super().__init__() - def render(self, task: "Task") -> RenderableType: + def render(self, task: Task) -> RenderableType: task_speed = f"{task.speed:>.2f}" if task.speed is not None else "0.00" return Text(f"{task_speed}it/s", style=self.style) class MetricsTextColumn(ProgressColumn): """A column containing text.""" - def __init__(self, trainer: "pl.Trainer", style: Union[str, "Style"]): + def __init__(self, trainer: pl.Trainer, style: str | Style): self._trainer = trainer - self._tasks: Dict[Union[int, TaskID], Any] = {} + self._tasks: dict[int | TaskID, Any] = {} self._current_task_id = 0 - self._metrics: Dict[Union[str, "Style"], Any] = {} + self._metrics: dict[str | Style, Any] = {} self._style = style super().__init__() - def update(self, metrics: Dict[Any, Any]) -> None: + def update(self, metrics: dict[Any, Any]) -> None: # Called when metrics are ready to be rendered. # This is to prevent render from causing deadlock issues by requesting metrics # in separate threads. 
self._metrics = metrics - def render(self, task: "Task") -> Text: + def render(self, task: Task) -> Text: assert isinstance(self._trainer.progress_bar_callback, RichProgressBar) if ( self._trainer.state.fn != "fit" @@ -193,14 +195,14 @@ class RichProgressBarTheme: https://rich.readthedocs.io/en/stable/style.html """ - description: Union[str, Style] = "white" - progress_bar: Union[str, Style] = "#6206E0" - progress_bar_finished: Union[str, Style] = "#6206E0" - progress_bar_pulse: Union[str, Style] = "#6206E0" - batch_progress: Union[str, Style] = "white" - time: Union[str, Style] = "grey54" - processing_speed: Union[str, Style] = "grey70" - metrics: Union[str, Style] = "white" + description: str | Style = "white" + progress_bar: str | Style = "#6206E0" + progress_bar_finished: str | Style = "#6206E0" + progress_bar_pulse: str | Style = "#6206E0" + batch_progress: str | Style = "white" + time: str | Style = "grey54" + processing_speed: str | Style = "grey70" + metrics: str | Style = "white" class RichProgressBar(ProgressBar): @@ -241,7 +243,7 @@ def __init__( refresh_rate: int = 1, leave: bool = False, theme: RichProgressBarTheme = RichProgressBarTheme(), - console_kwargs: Optional[Dict[str, Any]] = None, + console_kwargs: dict[str, Any] | None = None, ) -> None: if not _RICH_AVAILABLE: raise ModuleNotFoundError( @@ -251,17 +253,17 @@ def __init__( super().__init__() self._refresh_rate: int = refresh_rate self._leave: bool = leave - self._console: Optional[Console] = None + self._console: Console | None = None self._console_kwargs = console_kwargs or {} self._enabled: bool = True - self.progress: Optional[CustomProgress] = None - self.train_progress_bar_id: Optional["TaskID"] - self.val_sanity_progress_bar_id: Optional["TaskID"] = None - self.val_progress_bar_id: Optional["TaskID"] - self.test_progress_bar_id: Optional["TaskID"] - self.predict_progress_bar_id: Optional["TaskID"] + self.progress: CustomProgress | None = None + self.train_progress_bar_id: TaskID | None + self.val_sanity_progress_bar_id: TaskID | None = None + self.val_progress_bar_id: TaskID | None + self.test_progress_bar_id: TaskID | None + self.predict_progress_bar_id: TaskID | None self._reset_progress_bar_ids() - self._metric_component: Optional["MetricsTextColumn"] = None + self._metric_component: MetricsTextColumn | None = None self._progress_stopped: bool = False self.theme = theme self._update_for_light_colab_theme() @@ -315,7 +317,7 @@ def disable(self) -> None: def enable(self) -> None: self._enabled = True - def _init_progress(self, trainer: "pl.Trainer") -> None: + def _init_progress(self, trainer: pl.Trainer) -> None: if self.is_enabled and (self.progress is None or self._progress_stopped): self._reset_progress_bar_ids() reconfigure(**self._console_kwargs) @@ -337,28 +339,28 @@ def refresh(self) -> None: if self.progress: self.progress.refresh() - def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: self._init_progress(trainer) - def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_predict_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: self._init_progress(trainer) - def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_test_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: self._init_progress(trainer) - def on_validation_start(self, trainer: 
"pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: self._init_progress(trainer) - def on_sanity_check_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_sanity_check_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: self._init_progress(trainer) - def on_sanity_check_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_sanity_check_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: if self.progress is not None: assert self.val_sanity_progress_bar_id is not None self.progress.update(self.val_sanity_progress_bar_id, advance=0, visible=False) self.refresh() - def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: if self.is_disabled: return total_batches = self.total_train_batches @@ -379,8 +381,8 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo def on_validation_batch_start( self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", + trainer: pl.Trainer, + pl_module: pl.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0, @@ -408,11 +410,11 @@ def on_validation_batch_start( self.refresh() - def _add_task(self, total_batches: Union[int, float], description: str, visible: bool = True) -> "TaskID": + def _add_task(self, total_batches: int | float, description: str, visible: bool = True) -> TaskID: assert self.progress is not None return self.progress.add_task(f"[{self.theme.description}]{description}", total=total_batches, visible=visible) - def _update(self, progress_bar_id: Optional["TaskID"], current: int, visible: bool = True) -> None: + def _update(self, progress_bar_id: TaskID | None, current: int, visible: bool = True) -> None: if self.progress is not None and self.is_enabled: assert progress_bar_id is not None total = self.progress.tasks[progress_bar_id].total @@ -425,30 +427,30 @@ def _update(self, progress_bar_id: Optional["TaskID"], current: int, visible: bo self.progress.update(progress_bar_id, advance=advance, visible=visible) self.refresh() - def _should_update(self, current: int, total: Union[int, float]) -> bool: + def _should_update(self, current: int, total: int | float) -> bool: return current % self.refresh_rate == 0 or current == total - def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: if self.is_enabled and self.val_progress_bar_id is not None and trainer.state.fn == "fit": assert self.progress is not None self.progress.update(self.val_progress_bar_id, advance=0, visible=False) self.refresh() - def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: if trainer.state.fn == "fit": self._update_metrics(trainer, pl_module) self.reset_dataloader_idx_tracker() - def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: self.reset_dataloader_idx_tracker() - def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_predict_end(self, trainer: pl.Trainer, pl_module: 
pl.LightningModule) -> None: self.reset_dataloader_idx_tracker() def on_test_batch_start( self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", + trainer: pl.Trainer, + pl_module: pl.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0, @@ -464,8 +466,8 @@ def on_test_batch_start( def on_predict_batch_start( self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", + trainer: pl.Trainer, + pl_module: pl.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0, @@ -482,20 +484,20 @@ def on_predict_batch_start( self.refresh() def on_train_batch_end( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int + self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int ) -> None: self._update(self.train_progress_bar_id, batch_idx + 1) self._update_metrics(trainer, pl_module) self.refresh() - def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: self._update_metrics(trainer, pl_module) def on_validation_batch_end( self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", - outputs: Optional[STEP_OUTPUT], + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: STEP_OUTPUT | None, batch: Any, batch_idx: int, dataloader_idx: int = 0, @@ -510,9 +512,9 @@ def on_validation_batch_end( def on_test_batch_end( self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", - outputs: Optional[STEP_OUTPUT], + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: STEP_OUTPUT | None, batch: Any, batch_idx: int, dataloader_idx: int = 0, @@ -525,8 +527,8 @@ def on_test_batch_end( def on_predict_batch_end( self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", + trainer: pl.Trainer, + pl_module: pl.LightningModule, outputs: Any, batch: Any, batch_idx: int, @@ -561,18 +563,18 @@ def _reset_progress_bar_ids(self) -> None: self.test_progress_bar_id = None self.predict_progress_bar_id = None - def _update_metrics(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def _update_metrics(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: metrics = self.get_metrics(trainer, pl_module) if self._metric_component: self._metric_component.update(metrics) - def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: + def teardown(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None: self._stop_progress() - def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None: + def on_exception(self, trainer: pl.Trainer, pl_module: pl.LightningModule, exception: BaseException) -> None: self._stop_progress() - def configure_columns(self, trainer: "pl.Trainer") -> list: + def configure_columns(self, trainer: pl.Trainer) -> list: return [ TextColumn("[progress.description]{task.description}"), CustomBarColumn( @@ -585,7 +587,7 @@ def configure_columns(self, trainer: "pl.Trainer") -> list: ProcessingSpeedColumn(style=self.theme.processing_speed), ] - def __getstate__(self) -> Dict: + def __getstate__(self) -> dict: state = self.__dict__.copy() # both the console and progress object can hold thread lock objects that are not pickleable state["progress"] = None diff --git a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py 
b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py index 1b903fb55572b..988dda1afb9d8 100644 --- a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py +++ b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py @@ -11,11 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import importlib import math import os import sys -from typing import Any, Dict, Optional, Union +from typing import Any from lightning.pytorch.utilities.types import STEP_OUTPUT @@ -42,7 +44,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @staticmethod - def format_num(n: Union[int, float, str]) -> str: + def format_num(n: int | float | str) -> str: """Add additional padding to the formatted numbers.""" should_be_padded = isinstance(n, (float, str)) if not isinstance(n, str): @@ -105,12 +107,12 @@ def __init__(self, refresh_rate: int = 1, process_position: int = 0): self._refresh_rate = self._resolve_refresh_rate(refresh_rate) self._process_position = process_position self._enabled = True - self._train_progress_bar: Optional[_tqdm] = None - self._val_progress_bar: Optional[_tqdm] = None - self._test_progress_bar: Optional[_tqdm] = None - self._predict_progress_bar: Optional[_tqdm] = None + self._train_progress_bar: _tqdm | None = None + self._val_progress_bar: _tqdm | None = None + self._test_progress_bar: _tqdm | None = None + self._predict_progress_bar: _tqdm | None = None - def __getstate__(self) -> Dict: + def __getstate__(self) -> dict: # can't pickle the tqdm objects return {k: v if not isinstance(v, _tqdm) else None for k, v in vars(self).items()} @@ -246,34 +248,34 @@ def on_sanity_check_end(self, *_: Any) -> None: def on_train_start(self, *_: Any) -> None: self.train_progress_bar = self.init_train_tqdm() - def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None: + def on_train_epoch_start(self, trainer: pl.Trainer, *_: Any) -> None: self.train_progress_bar.reset(convert_inf(self.total_train_batches)) self.train_progress_bar.initial = 0 self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}") def on_train_batch_end( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int + self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int ) -> None: n = batch_idx + 1 if self._should_update(n, self.train_progress_bar.total): _update_n(self.train_progress_bar, n) self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) - def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: if not self.train_progress_bar.disable: self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) def on_train_end(self, *_: Any) -> None: self.train_progress_bar.close() - def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: if not trainer.sanity_checking: self.val_progress_bar = self.init_validation_tqdm() def on_validation_batch_start( self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", + trainer: pl.Trainer, + pl_module: pl.LightningModule, batch: Any, batch_idx: int, 
dataloader_idx: int = 0, @@ -288,9 +290,9 @@ def on_validation_batch_start( def on_validation_batch_end( self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", - outputs: Optional[STEP_OUTPUT], + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: STEP_OUTPUT | None, batch: Any, batch_idx: int, dataloader_idx: int = 0, @@ -299,19 +301,19 @@ def on_validation_batch_end( if self._should_update(n, self.val_progress_bar.total): _update_n(self.val_progress_bar, n) - def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: if self._train_progress_bar is not None and trainer.state.fn == "fit": self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) self.val_progress_bar.close() self.reset_dataloader_idx_tracker() - def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_test_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: self.test_progress_bar = self.init_test_tqdm() def on_test_batch_start( self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", + trainer: pl.Trainer, + pl_module: pl.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0, @@ -325,9 +327,9 @@ def on_test_batch_start( def on_test_batch_end( self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", - outputs: Optional[STEP_OUTPUT], + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: STEP_OUTPUT | None, batch: Any, batch_idx: int, dataloader_idx: int = 0, @@ -336,17 +338,17 @@ def on_test_batch_end( if self._should_update(n, self.test_progress_bar.total): _update_n(self.test_progress_bar, n) - def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: self.test_progress_bar.close() self.reset_dataloader_idx_tracker() - def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_predict_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: self.predict_progress_bar = self.init_predict_tqdm() def on_predict_batch_start( self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", + trainer: pl.Trainer, + pl_module: pl.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0, @@ -360,8 +362,8 @@ def on_predict_batch_start( def on_predict_batch_end( self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", + trainer: pl.Trainer, + pl_module: pl.LightningModule, outputs: Any, batch: Any, batch_idx: int, @@ -371,7 +373,7 @@ def on_predict_batch_end( if self._should_update(n, self.predict_progress_bar.total): _update_n(self.predict_progress_bar, n) - def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_predict_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: self.predict_progress_bar.close() self.reset_dataloader_idx_tracker() @@ -403,7 +405,7 @@ def _resolve_refresh_rate(refresh_rate: int) -> int: return refresh_rate -def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]: +def convert_inf(x: int | float | None) -> int | float | None: """The tqdm doesn't support inf/nan values. We have to convert it to None. 
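Aside: the same modernization recurs in every file in this diff, so a single minimal sketch may help. The snippet below is illustrative only, not Lightning source; `Trainer`, `convert_total`, and `merge_metrics` are invented stand-ins. With the lazy annotations enabled by `from __future__ import annotations` (PEP 563), every annotation is stored as a string and never evaluated at import time, which is what makes builtin generics such as `dict[str, Any]` (PEP 585), `|` unions such as `int | float | None` (PEP 604), and unquoted forward references legal on interpreters that predate those features, and lets the `Dict`/`List`/`Optional`/`Union` imports from `typing` be dropped.

from __future__ import annotations  # PEP 563: annotations become lazy strings

import math
from typing import Any


def attach(trainer: Trainer) -> None:
    # unquoted forward reference: `Trainer` is defined later in this module,
    # which is fine because the annotation is never evaluated at def time
    ...


class Trainer:
    ...


def convert_total(x: int | float | None) -> int | float | None:
    # hypothetical helper in the spirit of `convert_inf` above: progress bars
    # cannot render inf/nan totals, so map them to None
    if x is None or math.isinf(x) or math.isnan(x):
        return None
    return x


def merge_metrics(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]:
    # `dict[...]` replaces `typing.Dict[...]` with no extra import
    return {**a, **b}

The same mechanism is why the quoted `"pl.Trainer"` hints can become plain `pl.Trainer`: the name only has to resolve when a type checker (or an explicit `typing.get_type_hints` call) asks for it, not when the module is imported.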
diff --git a/src/lightning/pytorch/callbacks/pruning.py b/src/lightning/pytorch/callbacks/pruning.py index 5cffe6171707e..0bb7959fbf9e2 100644 --- a/src/lightning/pytorch/callbacks/pruning.py +++ b/src/lightning/pytorch/callbacks/pruning.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. r"""ModelPruning ^^^^^^^^^^^^""" +from __future__ import annotations + import inspect import logging from copy import deepcopy from functools import partial -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Sequence, Tuple import torch.nn.utils.prune as pytorch_prune from lightning_utilities.core.apply_func import apply_to_collection @@ -52,7 +54,7 @@ class _LayerRef(TypedDict): data: nn.Module - names: List[Tuple[int, str]] + names: list[tuple[int, str]] class ModelPruning(Callback): @@ -60,17 +62,17 @@ class ModelPruning(Callback): def __init__( self, - pruning_fn: Union[Callable, str], + pruning_fn: Callable | str, parameters_to_prune: _PARAM_LIST = (), - parameter_names: Optional[List[str]] = None, + parameter_names: list[str] | None = None, use_global_unstructured: bool = True, - amount: Union[int, float, Callable[[int], Union[int, float]]] = 0.5, - apply_pruning: Union[bool, Callable[[int], bool]] = True, + amount: int | float | Callable[[int], int | float] = 0.5, + apply_pruning: bool | Callable[[int], bool] = True, make_pruning_permanent: bool = True, - use_lottery_ticket_hypothesis: Union[bool, Callable[[int], bool]] = True, + use_lottery_ticket_hypothesis: bool | Callable[[int], bool] = True, resample_parameters: bool = False, - pruning_dim: Optional[int] = None, - pruning_norm: Optional[int] = None, + pruning_dim: int | None = None, + pruning_norm: int | None = None, verbose: int = 0, prune_on_train_epoch_end: bool = True, ) -> None: @@ -160,9 +162,9 @@ def __init__( self._resample_parameters = resample_parameters self._prune_on_train_epoch_end = prune_on_train_epoch_end self._parameter_names = parameter_names or self.PARAMETER_NAMES - self._global_kwargs: Dict[str, Any] = {} - self._original_layers: Optional[Dict[int, _LayerRef]] = None - self._pruning_method_name: Optional[str] = None + self._global_kwargs: dict[str, Any] = {} + self._original_layers: dict[int, _LayerRef] | None = None + self._pruning_method_name: str | None = None for name in self._parameter_names: if name not in self.PARAMETER_NAMES: @@ -230,7 +232,7 @@ def filter_parameters_to_prune(self, parameters_to_prune: _PARAM_LIST = ()) -> _ """This function can be overridden to control which module to prune.""" return parameters_to_prune - def _create_pruning_fn(self, pruning_fn: str, **kwargs: Any) -> Union[Callable, pytorch_prune.BasePruningMethod]: + def _create_pruning_fn(self, pruning_fn: str, **kwargs: Any) -> Callable | pytorch_prune.BasePruningMethod: """This function takes `pruning_fn`, a function name. 
IF use_global_unstructured, pruning_fn will be resolved into its associated ``PyTorch BasePruningMethod`` ELSE, @@ -302,7 +304,7 @@ def _apply_local_pruning(self, amount: float) -> None: for module, name in self._parameters_to_prune: self.pruning_fn(module, name=name, amount=amount) - def _resolve_global_kwargs(self, amount: float) -> Dict[str, Any]: + def _resolve_global_kwargs(self, amount: float) -> dict[str, Any]: self._global_kwargs["amount"] = amount params = set(inspect.signature(self.pruning_fn).parameters) params.discard("self") @@ -314,14 +316,14 @@ def _apply_global_pruning(self, amount: float) -> None: ) @staticmethod - def _get_pruned_stats(module: nn.Module, name: str) -> Tuple[int, int]: + def _get_pruned_stats(module: nn.Module, name: str) -> tuple[int, int]: attr = f"{name}_mask" if not hasattr(module, attr): return 0, 1 mask = getattr(module, attr) return (mask == 0).sum().item(), mask.numel() - def apply_pruning(self, amount: Union[int, float]) -> None: + def apply_pruning(self, amount: int | float) -> None: """Applies pruning to ``parameters_to_prune``.""" if self._verbose: prev_stats = [self._get_pruned_stats(m, n) for m, n in self._parameters_to_prune] @@ -337,7 +339,7 @@ def apply_pruning(self, amount: Union[int, float]) -> None: @rank_zero_only def _log_sparsity_stats( - self, prev: List[Tuple[int, int]], curr: List[Tuple[int, int]], amount: Union[int, float] = 0 + self, prev: list[tuple[int, int]], curr: list[tuple[int, int]], amount: int | float = 0 ) -> None: total_params = sum(p.numel() for layer, _ in self._parameters_to_prune for p in layer.parameters()) prev_total_zeros = sum(zeros for zeros, _ in prev) @@ -357,7 +359,7 @@ def _log_sparsity_stats( f" {curr_mask_zeros} ({curr_mask_zeros / curr_mask_size:.2%})" ) - def setup(self, trainer: "pl.Trainer", pl_module: LightningModule, stage: str) -> None: + def setup(self, trainer: pl.Trainer, pl_module: LightningModule, stage: str) -> None: parameters_to_prune = self.sanitize_parameters_to_prune( pl_module, self._parameters_to_prune, parameter_names=self._parameter_names ) @@ -387,22 +389,22 @@ def _run_pruning(self, current_epoch: int) -> None: ): self.apply_lottery_ticket_hypothesis() - def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: LightningModule) -> None: + def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: LightningModule) -> None: if self._prune_on_train_epoch_end: rank_zero_debug("`ModelPruning.on_train_epoch_end`. Applying pruning") self._run_pruning(pl_module.current_epoch) - def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: LightningModule) -> None: + def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: LightningModule) -> None: if not trainer.sanity_checking and not self._prune_on_train_epoch_end: rank_zero_debug("`ModelPruning.on_validation_epoch_end`. Applying pruning") self._run_pruning(pl_module.current_epoch) - def on_train_end(self, trainer: "pl.Trainer", pl_module: LightningModule) -> None: + def on_train_end(self, trainer: pl.Trainer, pl_module: LightningModule) -> None: if self._make_pruning_permanent: rank_zero_debug("`ModelPruning.on_train_end`. Pruning is made permanent for this checkpoint") self.make_pruning_permanent(pl_module) - def _make_pruning_permanent_on_state_dict(self, pl_module: LightningModule) -> Dict[str, Any]: + def _make_pruning_permanent_on_state_dict(self, pl_module: LightningModule) -> dict[str, Any]: state_dict = pl_module.state_dict() # find the mask and the original weights. 
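For readers following the `_get_pruned_stats` change just above, here is a hedged, self-contained sketch (not Lightning code; `pruned_stats` and the toy layer are invented for the example) of how a `*_mask` buffer created by `torch.nn.utils.prune` yields the same `(zeros, total)` pair, written with the new `tuple[int, int]` annotation style:

from __future__ import annotations

import torch.nn as nn
import torch.nn.utils.prune as prune


def pruned_stats(module: nn.Module, name: str = "weight") -> tuple[int, int]:
    # count zeroed entries and total entries in the pruning mask buffer
    mask = getattr(module, f"{name}_mask")
    return int((mask == 0).sum()), mask.numel()


layer = nn.Linear(8, 4)
prune.l1_unstructured(layer, name="weight", amount=0.5)  # zero 50% of weights by L1 magnitude
zeros, total = pruned_stats(layer)
print(f"sparsity: {zeros}/{total} ({zeros / total:.0%})")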
@@ -419,7 +421,7 @@ def move_to_cpu(tensor: Tensor) -> Tensor: return apply_to_collection(state_dict, Tensor, move_to_cpu) - def on_save_checkpoint(self, trainer: "pl.Trainer", pl_module: LightningModule, checkpoint: Dict[str, Any]) -> None: + def on_save_checkpoint(self, trainer: pl.Trainer, pl_module: LightningModule, checkpoint: dict[str, Any]) -> None: if self._make_pruning_permanent: rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint") # manually prune the weights so training can keep going with the same buffers diff --git a/src/lightning/pytorch/callbacks/rich_model_summary.py b/src/lightning/pytorch/callbacks/rich_model_summary.py index f68c98259b112..c682dc0ad0dca 100644 --- a/src/lightning/pytorch/callbacks/rich_model_summary.py +++ b/src/lightning/pytorch/callbacks/rich_model_summary.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Tuple +from __future__ import annotations + +from typing import Any from lightning.pytorch.callbacks import ModelSummary from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE @@ -67,7 +69,7 @@ def __init__(self, max_depth: int = 1, **summarize_kwargs: Any) -> None: @staticmethod def summarize( - summary_data: List[Tuple[str, List[str]]], + summary_data: list[tuple[str, list[str]]], total_parameters: int, trainable_parameters: int, model_size: float, diff --git a/src/lightning/pytorch/callbacks/stochastic_weight_avg.py b/src/lightning/pytorch/callbacks/stochastic_weight_avg.py index dc6ae074a31b4..ebe441b40d077 100644 --- a/src/lightning/pytorch/callbacks/stochastic_weight_avg.py +++ b/src/lightning/pytorch/callbacks/stochastic_weight_avg.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. r"""Stochastic Weight Averaging Callback ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^""" +from __future__ import annotations + from copy import deepcopy -from typing import Any, Callable, cast, Dict, List, Optional, Union +from typing import Any, Callable, cast import torch from torch import nn, Tensor @@ -34,12 +36,12 @@ class StochasticWeightAveraging(Callback): def __init__( self, - swa_lrs: Union[float, List[float]], - swa_epoch_start: Union[int, float] = 0.8, + swa_lrs: float | list[float], + swa_epoch_start: int | float = 0.8, annealing_epochs: int = 10, annealing_strategy: str = "cos", - avg_fn: Optional[_AVG_FN] = None, - device: Optional[Union[torch.device, str]] = torch.device("cpu"), + avg_fn: _AVG_FN | None = None, + device: torch.device | str | None = torch.device("cpu"), ): r"""Implements the Stochastic Weight Averaging (SWA) Callback to average a model. @@ -109,21 +111,21 @@ def __init__( if device is not None and not isinstance(device, (torch.device, str)): raise MisconfigurationException(f"device is expected to be a torch.device or a str. 
Found {device}") - self.n_averaged: Optional[Tensor] = None + self.n_averaged: Tensor | None = None self._swa_epoch_start = swa_epoch_start self._swa_lrs = swa_lrs self._annealing_epochs = annealing_epochs self._annealing_strategy = annealing_strategy self._avg_fn = avg_fn or self.avg_fn self._device = device - self._model_contains_batch_norm: Optional[bool] = None - self._average_model: Optional["pl.LightningModule"] = None + self._model_contains_batch_norm: bool | None = None + self._average_model: pl.LightningModule | None = None self._initialized = False - self._swa_scheduler: Optional[LRScheduler] = None - self._scheduler_state: Optional[Dict] = None + self._swa_scheduler: LRScheduler | None = None + self._scheduler_state: dict | None = None self._init_n_averaged = 0 self._latest_update_epoch = -1 - self.momenta: Dict[nn.modules.batchnorm._BatchNorm, Optional[float]] = {} + self.momenta: dict[nn.modules.batchnorm._BatchNorm, float | None] = {} self._max_epochs: int @property @@ -136,17 +138,17 @@ def swa_end(self) -> int: return self._max_epochs - 1 # 0-based @staticmethod - def pl_module_contains_batch_norm(pl_module: "pl.LightningModule") -> bool: + def pl_module_contains_batch_norm(pl_module: pl.LightningModule) -> bool: return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules()) - def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: + def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None: if isinstance(trainer.strategy, (FSDPStrategy, DeepSpeedStrategy)): raise MisconfigurationException("SWA does not currently support sharded models.") # copy the model before moving it to accelerator device. self._average_model = deepcopy(pl_module) - def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: if len(trainer.optimizers) != 1: raise MisconfigurationException("SWA currently works with 1 `optimizer`.") @@ -168,7 +170,7 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - if self._scheduler_state is not None: self._clear_schedulers(trainer) - def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: if (not self._initialized) and (self.swa_start <= trainer.current_epoch <= self.swa_end): self._initialized = True @@ -251,10 +253,10 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo assert isinstance(trainer.fit_loop.max_batches, int), "Iterable-style datasets are not supported" trainer.accumulate_grad_batches = trainer.fit_loop.max_batches - def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any) -> None: + def on_train_epoch_end(self, trainer: pl.Trainer, *args: Any) -> None: trainer.fit_loop._skip_backward = False - def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: # the trainer increases the current epoch before this hook is called if self._model_contains_batch_norm and trainer.current_epoch - 1 == self.swa_end + 1: # BatchNorm epoch update. 
Reset state @@ -269,11 +271,11 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - self.transfer_weights(self._average_model, pl_module) @staticmethod - def transfer_weights(src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule") -> None: + def transfer_weights(src_pl_module: pl.LightningModule, dst_pl_module: pl.LightningModule) -> None: for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()): dst_param.detach().copy_(src_param.to(dst_param.device)) - def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule") -> None: + def reset_batch_norm_and_save_state(self, pl_module: pl.LightningModule) -> None: """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154.""" self.momenta = {} for module in pl_module.modules(): @@ -303,7 +305,7 @@ def reset_momenta(self) -> None: @staticmethod def update_parameters( - average_model: "pl.LightningModule", model: "pl.LightningModule", n_averaged: Tensor, avg_fn: _AVG_FN + average_model: pl.LightningModule, model: pl.LightningModule, n_averaged: Tensor, avg_fn: _AVG_FN ) -> None: """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112.""" for p_swa, p_model in zip(average_model.parameters(), model.parameters()): @@ -319,7 +321,7 @@ def avg_fn(averaged_model_parameter: Tensor, model_parameter: Tensor, num_averag """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97.""" return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1) - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "n_averaged": 0 if self.n_averaged is None else self.n_averaged.item(), "latest_update_epoch": self._latest_update_epoch, @@ -327,14 +329,14 @@ def state_dict(self) -> Dict[str, Any]: "average_model_state": None if self._average_model is None else self._average_model.state_dict(), } - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: self._init_n_averaged = state_dict["n_averaged"] self._latest_update_epoch = state_dict["latest_update_epoch"] self._scheduler_state = state_dict["scheduler_state"] self._load_average_model_state(state_dict["average_model_state"]) @staticmethod - def _clear_schedulers(trainer: "pl.Trainer") -> None: + def _clear_schedulers(trainer: pl.Trainer) -> None: # If we have scheduler state saved, clear the scheduler configs so that we don't try to # load state into the wrong type of schedulers when restoring scheduler checkpoint state. # We'll configure the scheduler and re-load its state in on_train_epoch_start. diff --git a/src/lightning/pytorch/callbacks/timer.py b/src/lightning/pytorch/callbacks/timer.py index bb6245dbb00ea..4f294de7ca5e9 100644 --- a/src/lightning/pytorch/callbacks/timer.py +++ b/src/lightning/pytorch/callbacks/timer.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
r"""Timer ^^^^^""" +from __future__ import annotations + import logging import time from datetime import timedelta -from typing import Any, Dict, Optional, Union +from typing import Any import lightning.pytorch as pl from lightning.pytorch.callbacks.callback import Callback @@ -73,7 +75,7 @@ class Timer(Callback): def __init__( self, - duration: Optional[Union[str, timedelta, Dict[str, int]]] = None, + duration: str | timedelta | dict[str, int] | None = None, interval: str = Interval.step, verbose: bool = True, ) -> None: @@ -92,16 +94,16 @@ def __init__( self._duration = duration.total_seconds() if duration is not None else None self._interval = interval self._verbose = verbose - self._start_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage} - self._end_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage} + self._start_time: dict[RunningStage, float | None] = {stage: None for stage in RunningStage} + self._end_time: dict[RunningStage, float | None] = {stage: None for stage in RunningStage} self._offset = 0 - def start_time(self, stage: str = RunningStage.TRAINING) -> Optional[float]: + def start_time(self, stage: str = RunningStage.TRAINING) -> float | None: """Return the start time of a particular stage (in seconds)""" stage = RunningStage(stage) return self._start_time[stage] - def end_time(self, stage: str = RunningStage.TRAINING) -> Optional[float]: + def end_time(self, stage: str = RunningStage.TRAINING) -> float | None: """Return the end time of a particular stage (in seconds)""" stage = RunningStage(stage) return self._end_time[stage] @@ -117,55 +119,55 @@ def time_elapsed(self, stage: str = RunningStage.TRAINING) -> float: return time.monotonic() - start + offset return end - start + offset - def time_remaining(self, stage: str = RunningStage.TRAINING) -> Optional[float]: + def time_remaining(self, stage: str = RunningStage.TRAINING) -> float | None: """Return the time remaining for a particular stage (in seconds)""" if self._duration is not None: return self._duration - self.time_elapsed(stage) return None - def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: self._start_time[RunningStage.TRAINING] = time.monotonic() - def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: self._end_time[RunningStage.TRAINING] = time.monotonic() - def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: self._start_time[RunningStage.VALIDATING] = time.monotonic() - def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: self._end_time[RunningStage.VALIDATING] = time.monotonic() - def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_test_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: self._start_time[RunningStage.TESTING] = time.monotonic() - def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: self._end_time[RunningStage.TESTING] = time.monotonic() - 
def on_fit_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: + def on_fit_start(self, trainer: pl.Trainer, *args: Any, **kwargs: Any) -> None: # this checks the time after the state is reloaded, regardless of the interval. # this is necessary in case we load a state whose timer is already depleted if self._duration is None: return self._check_time_remaining(trainer) - def on_train_batch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: + def on_train_batch_end(self, trainer: pl.Trainer, *args: Any, **kwargs: Any) -> None: if self._interval != Interval.step or self._duration is None: return self._check_time_remaining(trainer) - def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: + def on_train_epoch_end(self, trainer: pl.Trainer, *args: Any, **kwargs: Any) -> None: if self._interval != Interval.epoch or self._duration is None: return self._check_time_remaining(trainer) - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return {"time_elapsed": {stage.value: self.time_elapsed(stage) for stage in RunningStage}} - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: time_elapsed = state_dict.get("time_elapsed", {}) self._offset = time_elapsed.get(RunningStage.TRAINING.value, 0) - def _check_time_remaining(self, trainer: "pl.Trainer") -> None: + def _check_time_remaining(self, trainer: pl.Trainer) -> None: assert self._duration is not None should_stop = self.time_elapsed() >= self._duration should_stop = trainer.strategy.broadcast(should_stop) diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index c1458c90debf0..bac71f6210e94 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -11,11 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import os import sys from functools import partial, update_wrapper from types import MethodType -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union import torch from lightning_utilities.core.imports import RequirementCache @@ -92,24 +94,24 @@ def __init__( if not _JSONARGPARSE_SIGNATURES_AVAILABLE: raise ModuleNotFoundError(f"{_JSONARGPARSE_SIGNATURES_AVAILABLE}") super().__init__(*args, description=description, env_prefix=env_prefix, default_env=default_env, **kwargs) - self.callback_keys: List[str] = [] + self.callback_keys: list[str] = [] # separate optimizers and lr schedulers to know which were added - self._optimizers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {} - self._lr_schedulers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {} + self._optimizers: dict[str, tuple[type | tuple[type, ...], str]] = {} + self._lr_schedulers: dict[str, tuple[type | tuple[type, ...], str]] = {} def add_lightning_class_args( self, - lightning_class: Union[ - Callable[..., Union[Trainer, LightningModule, LightningDataModule, Callback]], - Type[Trainer], - Type[LightningModule], - Type[LightningDataModule], - Type[Callback], - ], + lightning_class: ( + Callable[..., Trainer | LightningModule | LightningDataModule | Callback] + | type[Trainer] + | type[LightningModule] + | type[LightningDataModule] + | type[Callback] + ), nested_key: str, subclass_mode: bool = False, required: bool = True, - ) -> List[str]: + ) -> list[str]: """Adds arguments from a lightning class to a nested key of the parser. Args: @@ -145,7 +147,7 @@ def add_lightning_class_args( def add_optimizer_args( self, - optimizer_class: Union[Type[Optimizer], Tuple[Type[Optimizer], ...]] = (Optimizer,), + optimizer_class: type[Optimizer] | tuple[type[Optimizer], ...] = (Optimizer,), nested_key: str = "optimizer", link_to: str = "AUTOMATIC", ) -> None: @@ -160,7 +162,7 @@ def add_optimizer_args( assert all(issubclass(o, Optimizer) for o in optimizer_class) else: assert issubclass(optimizer_class, Optimizer) - kwargs: Dict[str, Any] = {"instantiate": False, "fail_untyped": False, "skip": {"params"}} + kwargs: dict[str, Any] = {"instantiate": False, "fail_untyped": False, "skip": {"params"}} if isinstance(optimizer_class, tuple): self.add_subclass_arguments(optimizer_class, nested_key, **kwargs) else: @@ -169,7 +171,7 @@ def add_optimizer_args( def add_lr_scheduler_args( self, - lr_scheduler_class: Union[LRSchedulerType, Tuple[LRSchedulerType, ...]] = LRSchedulerTypeTuple, + lr_scheduler_class: LRSchedulerType | tuple[LRSchedulerType, ...] 
= LRSchedulerTypeTuple, nested_key: str = "lr_scheduler", link_to: str = "AUTOMATIC", ) -> None: @@ -185,7 +187,7 @@ def add_lr_scheduler_args( assert all(issubclass(o, LRSchedulerTypeTuple) for o in lr_scheduler_class) else: assert issubclass(lr_scheduler_class, LRSchedulerTypeTuple) - kwargs: Dict[str, Any] = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}} + kwargs: dict[str, Any] = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}} if isinstance(lr_scheduler_class, tuple): self.add_subclass_arguments(lr_scheduler_class, nested_key, **kwargs) else: @@ -292,14 +294,14 @@ class LightningCLI: def __init__( self, - model_class: Optional[Union[Type[LightningModule], Callable[..., LightningModule]]] = None, - datamodule_class: Optional[Union[Type[LightningDataModule], Callable[..., LightningDataModule]]] = None, - save_config_callback: Optional[Type[SaveConfigCallback]] = SaveConfigCallback, - save_config_kwargs: Optional[Dict[str, Any]] = None, - trainer_class: Union[Type[Trainer], Callable[..., Trainer]] = Trainer, - trainer_defaults: Optional[Dict[str, Any]] = None, - seed_everything_default: Union[bool, int] = True, - parser_kwargs: Optional[Union[Dict[str, Any], Dict[str, Dict[str, Any]]]] = None, + model_class: type[LightningModule] | Callable[..., LightningModule] | None = None, + datamodule_class: type[LightningDataModule] | Callable[..., LightningDataModule] | None = None, + save_config_callback: type[SaveConfigCallback] | None = SaveConfigCallback, + save_config_kwargs: dict[str, Any] | None = None, + trainer_class: type[Trainer] | Callable[..., Trainer] = Trainer, + trainer_defaults: dict[str, Any] | None = None, + seed_everything_default: bool | int = True, + parser_kwargs: dict[str, Any] | dict[str, dict[str, Any]] | None = None, subclass_mode_model: bool = False, subclass_mode_data: bool = False, args: ArgsType = None, @@ -378,7 +380,7 @@ def __init__( if self.subcommand is not None: self._run_subcommand(self.subcommand) - def _setup_parser_kwargs(self, parser_kwargs: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def _setup_parser_kwargs(self, parser_kwargs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: subcommand_names = self.subcommands().keys() main_kwargs = {k: v for k, v in parser_kwargs.items() if k not in subcommand_names} subparser_kwargs = {k: v for k, v in parser_kwargs.items() if k in subcommand_names} @@ -394,12 +396,12 @@ def init_parser(self, **kwargs: Any) -> LightningArgumentParser: return parser def setup_parser( - self, add_subcommands: bool, main_kwargs: Dict[str, Any], subparser_kwargs: Dict[str, Any] + self, add_subcommands: bool, main_kwargs: dict[str, Any], subparser_kwargs: dict[str, Any] ) -> None: """Initialize and setup the parser, subcommands, and arguments.""" self.parser = self.init_parser(**main_kwargs) if add_subcommands: - self._subcommand_method_arguments: Dict[str, List[str]] = {} + self._subcommand_method_arguments: dict[str, list[str]] = {} self._add_subcommands(self.parser, **subparser_kwargs) else: self._add_arguments(self.parser) @@ -453,7 +455,7 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: """ @staticmethod - def subcommands() -> Dict[str, Set[str]]: + def subcommands() -> dict[str, set[str]]: """Defines the list of available subcommands and the arguments to skip.""" return { "fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"}, @@ -464,7 +466,7 @@ def subcommands() -> Dict[str, Set[str]]: def _add_subcommands(self, parser: 
LightningArgumentParser, **kwargs: Any) -> None: """Adds subcommands to the input parser.""" - self._subcommand_parsers: Dict[str, LightningArgumentParser] = {} + self._subcommand_parsers: dict[str, LightningArgumentParser] = {} parser_subcommands = parser.add_subcommands() # the user might have passed a builder function trainer_class = ( @@ -481,11 +483,11 @@ def _add_subcommands(self, parser: LightningArgumentParser, **kwargs: Any) -> No self._subcommand_parsers[subcommand] = subcommand_parser parser_subcommands.add_subcommand(subcommand, subcommand_parser, help=description) - def _prepare_subcommand_parser(self, klass: Type, subcommand: str, **kwargs: Any) -> LightningArgumentParser: + def _prepare_subcommand_parser(self, klass: type, subcommand: str, **kwargs: Any) -> LightningArgumentParser: parser = self.init_parser(**kwargs) self._add_arguments(parser) # subcommand arguments - skip: Set[Union[str, int]] = set(self.subcommands()[subcommand]) + skip: set[str | int] = set(self.subcommands()[subcommand]) added = parser.add_method_arguments(klass, subcommand, skip=skip) # need to save which arguments were added to pass them to the method later self._subcommand_method_arguments[subcommand] = added @@ -538,7 +540,7 @@ def instantiate_trainer(self, **kwargs: Any) -> Trainer: trainer_config = {**self._get(self.config_init, "trainer", default={}), **kwargs} return self._instantiate_trainer(trainer_config, extra_callbacks) - def _instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback]) -> Trainer: + def _instantiate_trainer(self, config: dict[str, Any], callbacks: list[Callback]) -> Trainer: key = "callbacks" if key in config: if config[key] is None: @@ -563,7 +565,7 @@ def _instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback] ) return self.trainer_class(**config) - def _parser(self, subcommand: Optional[str]) -> LightningArgumentParser: + def _parser(self, subcommand: str | None) -> LightningArgumentParser: if subcommand is None: return self.parser # return the subcommand parser for the subcommand passed @@ -571,7 +573,7 @@ def _parser(self, subcommand: Optional[str]) -> LightningArgumentParser: @staticmethod def configure_optimizers( - lightning_module: LightningModule, optimizer: Optimizer, lr_scheduler: Optional[LRSchedulerTypeUnion] = None + lightning_module: LightningModule, optimizer: Optimizer, lr_scheduler: LRSchedulerTypeUnion | None = None ) -> Any: """Override to customize the :meth:`~lightning.pytorch.core.module.LightningModule.configure_optimizers` method. 
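[Editorial aside, not part of the patch] The cli.py hunks above and below all apply the same mechanical change: Optional[X], Union[X, Y], Dict, List, Tuple, Set and Type spellings become X | None, X | Y, dict, list, tuple, set and type. That spelling is safe on interpreters that predate runtime support for PEP 604/585 only because "from __future__ import annotations" (PEP 563) stores every annotation as an unevaluated string. Below is a minimal, self-contained sketch of that behaviour; the Config class and load function are hypothetical and exist purely for illustration.

from __future__ import annotations  # PEP 563: annotations are stored as plain strings

from typing import Any


class Config:
    """Hypothetical container used only to demonstrate the annotation style."""

    def __init__(self, defaults: dict[str, Any] | None = None) -> None:
        self.defaults = defaults or {}


def load(path: str | None = None, overrides: list[str] | None = None) -> Config:
    # None of the annotations above are evaluated at definition time, so the
    # PEP 604 unions and built-in generics are accepted even on Python 3.7/3.8.
    return Config({"path": path, "overrides": overrides or []})


# The annotations survive as strings in __annotations__:
print(load.__annotations__)
# {'path': 'str | None', 'overrides': 'list[str] | None', 'return': 'Config'}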
@@ -590,7 +592,7 @@ def configure_optimizers( } return [optimizer], [lr_scheduler] - def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None: + def _add_configure_optimizers_method_to_model(self, subcommand: str | None) -> None: """Overrides the model's :meth:`~lightning.pytorch.core.module.LightningModule.configure_optimizers` method if a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC'.""" if not self.auto_configure_optimizers: @@ -599,8 +601,8 @@ def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) - parser = self._parser(subcommand) def get_automatic( - class_type: Union[Type, Tuple[Type, ...]], register: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] - ) -> List[str]: + class_type: type | tuple[type, ...], register: dict[str, tuple[type | tuple[type, ...], str]] + ) -> list[str]: automatic = [] for key, (base_class, link_to) in register.items(): if not isinstance(base_class, tuple): @@ -652,7 +654,7 @@ def get_automatic( # override the existing method self.model.configure_optimizers = MethodType(fn, self.model) - def _get(self, config: Namespace, key: str, default: Optional[Any] = None) -> Any: + def _get(self, config: Namespace, key: str, default: Any | None = None) -> Any: """Utility to get a config value which might be inside a subcommand.""" return config.get(str(self.subcommand), config).get(key, default) @@ -671,7 +673,7 @@ def _run_subcommand(self, subcommand: str) -> None: if callable(after_fn): after_fn() - def _prepare_subcommand_kwargs(self, subcommand: str) -> Dict[str, Any]: + def _prepare_subcommand_kwargs(self, subcommand: str) -> dict[str, Any]: """Prepares the keyword arguments to pass to the subcommand to run.""" fn_kwargs = { k: v for k, v in self.config_init[subcommand].items() if k in self._subcommand_method_arguments[subcommand] @@ -694,26 +696,24 @@ def _set_seed(self) -> None: self.config["seed_everything"] = config_seed -def _class_path_from_class(class_type: Type) -> str: +def _class_path_from_class(class_type: type) -> str: return class_type.__module__ + "." + class_type.__name__ -def _global_add_class_path( - class_type: Type, init_args: Optional[Union[Namespace, Dict[str, Any]]] = None -) -> Dict[str, Any]: +def _global_add_class_path(class_type: type, init_args: Namespace | dict[str, Any] | None = None) -> dict[str, Any]: if isinstance(init_args, Namespace): init_args = init_args.as_dict() return {"class_path": _class_path_from_class(class_type), "init_args": init_args or {}} -def _add_class_path_generator(class_type: Type) -> Callable[[Namespace], Dict[str, Any]]: - def add_class_path(init_args: Namespace) -> Dict[str, Any]: +def _add_class_path_generator(class_type: type) -> Callable[[Namespace], dict[str, Any]]: + def add_class_path(init_args: Namespace) -> dict[str, Any]: return _global_add_class_path(class_type, init_args) return add_class_path -def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any: +def instantiate_class(args: Any | tuple[Any, ...], init: dict[str, Any]) -> Any: """Instantiates a class with the given args and init. 
Args: @@ -732,7 +732,7 @@ def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) - return args_class(*args, **kwargs) -def _get_short_description(component: object) -> Optional[str]: +def _get_short_description(component: object) -> str | None: if component.__doc__ is None: return None try: diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index 025c7b52130b8..00c3e16e617fa 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """LightningDataModule for loading DataLoaders with ease.""" +from __future__ import annotations + import inspect -from typing import Any, cast, Dict, IO, Iterable, Optional, Union +from typing import Any, cast, IO, Iterable from lightning_utilities import apply_to_collection from torch.utils.data import DataLoader, Dataset, IterableDataset @@ -64,7 +66,7 @@ def teardown(self): ... """ - name: Optional[str] = None + name: str | None = None CHECKPOINT_HYPER_PARAMS_KEY = "datamodule_hyper_parameters" CHECKPOINT_HYPER_PARAMS_NAME = "datamodule_hparams_name" CHECKPOINT_HYPER_PARAMS_TYPE = "datamodule_hparams_type" @@ -72,19 +74,19 @@ def teardown(self): def __init__(self) -> None: super().__init__() # Pointer to the trainer object - self.trainer: Optional["pl.Trainer"] = None + self.trainer: pl.Trainer | None = None @classmethod def from_datasets( cls, - train_dataset: Optional[Union[Dataset, Iterable[Dataset]]] = None, - val_dataset: Optional[Union[Dataset, Iterable[Dataset]]] = None, - test_dataset: Optional[Union[Dataset, Iterable[Dataset]]] = None, - predict_dataset: Optional[Union[Dataset, Iterable[Dataset]]] = None, + train_dataset: Dataset | Iterable[Dataset] | None = None, + val_dataset: Dataset | Iterable[Dataset] | None = None, + test_dataset: Dataset | Iterable[Dataset] | None = None, + predict_dataset: Dataset | Iterable[Dataset] | None = None, batch_size: int = 1, num_workers: int = 0, **datamodule_kwargs: Any, - ) -> "LightningDataModule": + ) -> LightningDataModule: r"""Create an instance from torch.utils.data.Dataset. Args: @@ -137,7 +139,7 @@ def predict_dataloader() -> EVAL_DATALOADERS: datamodule.predict_dataloader = predict_dataloader # type: ignore[method-assign] return datamodule - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: """Called when saving a checkpoint, implement to generate and save datamodule state. Returns: @@ -145,7 +147,7 @@ def state_dict(self) -> Dict[str, Any]: """ return {} - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: """Called when loading a checkpoint, implement to reload datamodule state given datamodule state_dict. Args: @@ -156,8 +158,8 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: @classmethod def load_from_checkpoint( cls, - checkpoint_path: Union[_PATH, IO], - hparams_file: Optional[_PATH] = None, + checkpoint_path: _PATH | IO, + hparams_file: _PATH | None = None, **kwargs: Any, ) -> Self: r""" diff --git a/src/lightning/pytorch/core/hooks.py b/src/lightning/pytorch/core/hooks.py index df11b48f9b232..c5448dc8c2c9f 100644 --- a/src/lightning/pytorch/core/hooks.py +++ b/src/lightning/pytorch/core/hooks.py @@ -13,7 +13,9 @@ # limitations under the License. 
"""Various hooks to be used in the Lightning code.""" -from typing import Any, Dict, Optional +from __future__ import annotations + +from typing import Any import torch from torch import Tensor @@ -63,7 +65,7 @@ def on_predict_start(self) -> None: def on_predict_end(self) -> None: """Called at the end of predicting.""" - def on_train_batch_start(self, batch: Any, batch_idx: int) -> Optional[int]: + def on_train_batch_start(self, batch: Any, batch_idx: int) -> int | None: """Called in the training loop before anything happens for that batch. If you return -1 here, you will skip training for the rest of the current epoch. @@ -92,7 +94,7 @@ def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: """ def on_validation_batch_end( - self, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int = 0 + self, outputs: STEP_OUTPUT | None, batch: Any, batch_idx: int, dataloader_idx: int = 0 ) -> None: """Called in the validation loop after the batch. @@ -113,7 +115,7 @@ def on_test_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = """ def on_test_batch_end( - self, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int = 0 + self, outputs: STEP_OUTPUT | None, batch: Any, batch_idx: int, dataloader_idx: int = 0 ) -> None: """Called in the test loop after the batch. @@ -133,7 +135,7 @@ def on_predict_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int dataloader_idx: the index of the dataloader """ - def on_predict_batch_end(self, outputs: Optional[Any], batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: + def on_predict_batch_end(self, outputs: Any | None, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: """Called in the predict loop after the batch. Args: @@ -624,7 +626,7 @@ def on_after_batch_transfer(self, batch, dataloader_idx): class CheckpointHooks: """Hooks to be used with Checkpointing.""" - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: r"""Called by Lightning to restore your model. If you saved something with :meth:`on_save_checkpoint` this is your chance to restore this. @@ -642,7 +644,7 @@ def on_load_checkpoint(self, checkpoint): There is no need for you to restore anything regarding training. """ - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None: r"""Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save. diff --git a/src/lightning/pytorch/core/mixins/hparams_mixin.py b/src/lightning/pytorch/core/mixins/hparams_mixin.py index d30caeda6b59c..680efcfad2f98 100644 --- a/src/lightning/pytorch/core/mixins/hparams_mixin.py +++ b/src/lightning/pytorch/core/mixins/hparams_mixin.py @@ -11,11 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import copy import inspect import types from argparse import Namespace -from typing import Any, List, MutableMapping, Optional, Sequence, Union +from typing import Any, MutableMapping, Sequence from lightning.pytorch.utilities.parsing import AttributeDict, save_hyperparameters @@ -24,7 +26,7 @@ class HyperparametersMixin: - __jit_unused_properties__: List[str] = ["hparams", "hparams_initial"] + __jit_unused_properties__: list[str] = ["hparams", "hparams_initial"] def __init__(self) -> None: super().__init__() @@ -33,8 +35,8 @@ def __init__(self) -> None: def save_hyperparameters( self, *args: Any, - ignore: Optional[Union[Sequence[str], str]] = None, - frame: Optional[types.FrameType] = None, + ignore: Sequence[str] | str | None = None, + frame: types.FrameType | None = None, logger: bool = True, ) -> None: """Save arguments to ``hparams`` attribute. @@ -110,7 +112,7 @@ class ``__init__`` to be ignored frame = current_frame.f_back save_hyperparameters(self, *args, ignore=ignore, frame=frame) - def _set_hparams(self, hp: Union[MutableMapping, Namespace, str]) -> None: + def _set_hparams(self, hp: MutableMapping | Namespace | str) -> None: hp = self._to_hparams_dict(hp) if isinstance(hp, dict) and isinstance(self.hparams, dict): @@ -119,7 +121,7 @@ def _set_hparams(self, hp: Union[MutableMapping, Namespace, str]) -> None: self._hparams = hp @staticmethod - def _to_hparams_dict(hp: Union[MutableMapping, Namespace, str]) -> Union[MutableMapping, AttributeDict]: + def _to_hparams_dict(hp: MutableMapping | Namespace | str) -> MutableMapping | AttributeDict: if isinstance(hp, Namespace): hp = vars(hp) if isinstance(hp, dict): @@ -131,7 +133,7 @@ def _to_hparams_dict(hp: Union[MutableMapping, Namespace, str]) -> Union[Mutable return hp @property - def hparams(self) -> Union[AttributeDict, MutableMapping]: + def hparams(self) -> AttributeDict | MutableMapping: """The collection of hyperparameters saved with :meth:`save_hyperparameters`. It is mutable by the user. For the frozen set of initial hyperparameters, use :attr:`hparams_initial`. diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index 92aff9d367a78..fc0ea52096ee5 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -12,28 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. """The LightningModule - an nn.Module with many additional features.""" +from __future__ import annotations + import logging import numbers import operator import weakref from contextlib import contextmanager from pathlib import Path -from typing import ( - Any, - Callable, - cast, - Dict, - Generator, - IO, - List, - Literal, - Mapping, - Optional, - overload, - Sequence, - Tuple, - Union, -) +from typing import Any, Callable, cast, Generator, IO, List, Literal, Mapping, overload, Sequence, Union import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -88,7 +75,7 @@ class LightningModule( ): # Below is for property support of JIT # since none of these are important when using JIT, we are going to ignore them. 
- __jit_unused_properties__: List[str] = ( + __jit_unused_properties__: list[str] = ( [ "example_input_array", "on_gpu", @@ -115,27 +102,27 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) # pointer to the trainer object - self._trainer: Optional["pl.Trainer"] = None + self._trainer: pl.Trainer | None = None # optionally can be set by user - self._example_input_array: Optional[Union[Tensor, Tuple, Dict]] = None - self._current_fx_name: Optional[str] = None + self._example_input_array: Tensor | tuple | dict | None = None + self._current_fx_name: str | None = None self._automatic_optimization: bool = True - self._param_requires_grad_state: Dict[str, bool] = {} - self._metric_attributes: Optional[Dict[int, str]] = None + self._param_requires_grad_state: dict[str, bool] = {} + self._metric_attributes: dict[int, str] | None = None self._register_sharded_tensor_state_dict_hooks_if_available() - self._compiler_ctx: Optional[Dict[str, Any]] = None + self._compiler_ctx: dict[str, Any] | None = None # attributes only used when using fabric - self._fabric: Optional["lf.Fabric"] = None - self._fabric_optimizers: List[_FabricOptimizer] = [] + self._fabric: lf.Fabric | None = None + self._fabric_optimizers: list[_FabricOptimizer] = [] @overload - def optimizers(self, use_pl_optimizer: Literal[True] = True) -> Union[LightningOptimizer, List[LightningOptimizer]]: + def optimizers(self, use_pl_optimizer: Literal[True] = True) -> LightningOptimizer | list[LightningOptimizer]: ... @overload - def optimizers(self, use_pl_optimizer: Literal[False]) -> Union[Optimizer, List[Optimizer]]: + def optimizers(self, use_pl_optimizer: Literal[False]) -> Optimizer | list[Optimizer]: ... @overload @@ -170,7 +157,7 @@ def optimizers(self, use_pl_optimizer: bool = True) -> MODULE_OPTIMIZERS: # multiple opts return opts - def lr_schedulers(self) -> Union[None, List[LRSchedulerPLType], LRSchedulerPLType]: + def lr_schedulers(self) -> None | list[LRSchedulerPLType] | LRSchedulerPLType: """Returns the learning rate scheduler(s) that are being used during training. Useful for manual optimization. @@ -182,7 +169,7 @@ def lr_schedulers(self) -> Union[None, List[LRSchedulerPLType], LRSchedulerPLTyp return None # ignore other keys "interval", "frequency", etc. 
- lr_schedulers: List[LRSchedulerPLType] = [config.scheduler for config in self.trainer.lr_scheduler_configs] + lr_schedulers: list[LRSchedulerPLType] = [config.scheduler for config in self.trainer.lr_scheduler_configs] # single scheduler if len(lr_schedulers) == 1: @@ -192,7 +179,7 @@ def lr_schedulers(self) -> Union[None, List[LRSchedulerPLType], LRSchedulerPLTyp return lr_schedulers @property - def trainer(self) -> "pl.Trainer": + def trainer(self) -> pl.Trainer: if self._fabric is not None: return _TrainerFabricShim(fabric=self._fabric) # type: ignore[return-value] if not self._jit_is_scripting and self._trainer is None: @@ -200,7 +187,7 @@ def trainer(self) -> "pl.Trainer": return self._trainer # type: ignore[return-value] @trainer.setter - def trainer(self, trainer: Optional["pl.Trainer"]) -> None: + def trainer(self, trainer: pl.Trainer | None) -> None: for v in self.children(): if isinstance(v, LightningModule): v.trainer = trainer # type: ignore[assignment] @@ -210,11 +197,11 @@ def trainer(self, trainer: Optional["pl.Trainer"]) -> None: self._trainer = trainer @property - def fabric(self) -> Optional["lf.Fabric"]: + def fabric(self) -> lf.Fabric | None: return self._fabric @fabric.setter - def fabric(self, fabric: Optional["lf.Fabric"]) -> None: + def fabric(self, fabric: lf.Fabric | None) -> None: for v in self.children(): if isinstance(v, LightningModule): v.fabric = fabric @@ -223,7 +210,7 @@ def fabric(self, fabric: Optional["lf.Fabric"]) -> None: self._fabric = fabric @property - def example_input_array(self) -> Optional[Union[Tensor, Tuple, Dict]]: + def example_input_array(self) -> Tensor | tuple | dict | None: """The example input array is a specification of what the module can consume in the :meth:`forward` method. The return type is interpreted as follows: @@ -237,7 +224,7 @@ def example_input_array(self) -> Optional[Union[Tensor, Tuple, Dict]]: return self._example_input_array @example_input_array.setter - def example_input_array(self, example: Optional[Union[Tensor, Tuple, Dict]]) -> None: + def example_input_array(self, example: Tensor | tuple | dict | None) -> None: self._example_input_array = example @property @@ -281,14 +268,14 @@ def automatic_optimization(self, automatic_optimization: bool) -> None: self._automatic_optimization = automatic_optimization @property - def logger(self) -> Optional[Union[Logger, FabricLogger]]: + def logger(self) -> Logger | FabricLogger | None: """Reference to the logger object in the Trainer.""" if self._fabric is not None: return self._fabric.logger return self._trainer.logger if self._trainer is not None else None @property - def loggers(self) -> Union[List[Logger], List[FabricLogger]]: + def loggers(self) -> list[Logger] | list[FabricLogger]: """Reference to the list of loggers in the Trainer.""" if self._fabric is not None: return self._fabric.loggers @@ -315,7 +302,7 @@ def _on_before_batch_transfer(self, batch: Any, dataloader_idx: int = 0) -> Any: return self._call_batch_hook("on_before_batch_transfer", batch, dataloader_idx) def _apply_batch_transfer_handler( - self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0 + self, batch: Any, device: torch.device | None = None, dataloader_idx: int = 0 ) -> Any: device = device or self.device batch = self._call_batch_hook("transfer_batch_to_device", batch, device, dataloader_idx) @@ -346,16 +333,16 @@ def log( name: str, value: _METRIC, prog_bar: bool = False, - logger: Optional[bool] = None, - on_step: Optional[bool] = None, - on_epoch: Optional[bool] = 
None, - reduce_fx: Union[str, Callable] = "mean", + logger: bool | None = None, + on_step: bool | None = None, + on_epoch: bool | None = None, + reduce_fx: str | Callable = "mean", enable_graph: bool = False, sync_dist: bool = False, - sync_dist_group: Optional[Any] = None, + sync_dist_group: Any | None = None, add_dataloader_idx: bool = True, - batch_size: Optional[int] = None, - metric_attribute: Optional[str] = None, + batch_size: int | None = None, + metric_attribute: str | None = None, rank_zero_only: bool = False, ) -> None: """Log a key, value pair. @@ -508,15 +495,15 @@ def log_dict( self, dictionary: Mapping[str, _METRIC], prog_bar: bool = False, - logger: Optional[bool] = None, - on_step: Optional[bool] = None, - on_epoch: Optional[bool] = None, - reduce_fx: Union[str, Callable] = "mean", + logger: bool | None = None, + on_step: bool | None = None, + on_epoch: bool | None = None, + reduce_fx: str | Callable = "mean", enable_graph: bool = False, sync_dist: bool = False, - sync_dist_group: Optional[Any] = None, + sync_dist_group: Any | None = None, add_dataloader_idx: bool = True, - batch_size: Optional[int] = None, + batch_size: int | None = None, rank_zero_only: bool = False, ) -> None: """Log a dictionary of values at once. @@ -555,7 +542,7 @@ def log_dict( if self._fabric is not None: return self._log_dict_through_fabric(dictionary=dictionary, logger=logger) - kwargs: Dict[str, bool] = {} + kwargs: dict[str, bool] = {} if isinstance(dictionary, MetricCollection): kwargs["keep_base"] = False @@ -580,7 +567,7 @@ def log_dict( ) return None - def _log_dict_through_fabric(self, dictionary: Mapping[str, Any], logger: Optional[bool] = None) -> None: + def _log_dict_through_fabric(self, dictionary: Mapping[str, Any], logger: bool | None = None) -> None: if logger is False: # Passing `logger=False` with Fabric does not make much sense because there is no other destination to # log to, but we support it in case the original code was written for Trainer use @@ -604,7 +591,7 @@ def __check_not_nested(value: dict, name: str) -> None: def __check_allowed(v: Any, name: str, value: Any) -> None: raise ValueError(f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged") - def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor: + def __to_tensor(self, value: Tensor | numbers.Number, name: str) -> Tensor: value = value.clone().detach() if isinstance(value, Tensor) else torch.tensor(value, device=self.device) if not torch.numel(value) == 1: raise ValueError( @@ -615,8 +602,8 @@ def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor return value def all_gather( - self, data: Union[Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False - ) -> Union[Tensor, Dict, List, Tuple]: + self, data: Tensor | dict | list | tuple, group: Any | None = None, sync_grads: bool = False + ) -> Tensor | dict | list | tuple: r"""Gather tensors or collections of tensors from multiple processes. This method needs to be called on all processes. Failing to do so will cause your program to stall forever. @@ -701,7 +688,7 @@ def training_step(self, batch, batch_idx): """ rank_zero_warn("`training_step` must be implemented to be used with the Lightning Trainer") - def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: + def validation_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT | None: r"""Operates on a single batch of data from the validation set. 
In this step you'd might generate examples or calculate anything of interest like accuracy. @@ -768,7 +755,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0): the model goes back to training mode and gradients are enabled. """ - def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: + def test_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT | None: r"""Operates on a single batch of data from the test set. In this step you'd normally generate examples or calculate anything of interest such as accuracy. @@ -874,7 +861,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): """ return self(batch) - def configure_callbacks(self) -> Union[Sequence[Callback], Callback]: + def configure_callbacks(self) -> Sequence[Callback] | Callback: """Configure model-specific callbacks. When the model gets attached, e.g., when ``.fit()`` or ``.test()`` gets called, the list or a callback returned here will be merged with the list of callbacks passed to the Trainer's ``callbacks`` argument. If a callback returned here has the same type as one or several callbacks @@ -1040,7 +1027,7 @@ def backward(self, loss): else: loss.backward(*args, **kwargs) - def toggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> None: + def toggle_optimizer(self, optimizer: Optimizer | LightningOptimizer) -> None: """Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to prevent dangling gradients in multiple-optimizer setup. @@ -1068,7 +1055,7 @@ def toggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> N param.requires_grad = param_requires_grad_state[param] self._param_requires_grad_state = param_requires_grad_state - def untoggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> None: + def untoggle_optimizer(self, optimizer: Optimizer | LightningOptimizer) -> None: """Resets the state of required gradients that were toggled with :meth:`toggle_optimizer`. Args: @@ -1086,8 +1073,8 @@ def untoggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> def clip_gradients( self, optimizer: Optimizer, - gradient_clip_val: Optional[Union[int, float]] = None, - gradient_clip_algorithm: Optional[str] = None, + gradient_clip_val: int | float | None = None, + gradient_clip_algorithm: str | None = None, ) -> None: """Handles gradient clipping internally. @@ -1153,8 +1140,8 @@ def clip_gradients( def configure_gradient_clipping( self, optimizer: Optimizer, - gradient_clip_val: Optional[Union[int, float]] = None, - gradient_clip_algorithm: Optional[str] = None, + gradient_clip_val: int | float | None = None, + gradient_clip_algorithm: str | None = None, ) -> None: """Perform gradient clipping for the optimizer parameters. Called before :meth:`optimizer_step`. @@ -1180,7 +1167,7 @@ def configure_gradient_clipping(self, optimizer, gradient_clip_val, gradient_cli optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm ) - def lr_scheduler_step(self, scheduler: LRSchedulerTypeUnion, metric: Optional[Any]) -> None: + def lr_scheduler_step(self, scheduler: LRSchedulerTypeUnion, metric: Any | None) -> None: r""" Override this method to adjust the default way the :class:`~lightning.pytorch.trainer.trainer.Trainer` calls each scheduler. 
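[Editorial aside, not part of the patch] The module.py hunks keep rewriting public signatures in the same way, for example Optional[Union[int, float]] becomes int | float | None. One caveat, sketched below with a hypothetical clip_value helper: the string form of the annotation is stored on every interpreter, but evaluating it (for instance through typing.get_type_hints) requires runtime support for the syntax, so built-in generics resolve from Python 3.9 and "|" unions only from 3.10; tools that target older interpreters have to parse the string rather than evaluate it.

from __future__ import annotations

import typing


def clip_value(value: int | float | None = None, limits: dict[str, float] | None = None) -> float:
    """Hypothetical helper used only to illustrate runtime evaluation of annotations."""
    limits = limits or {"min": 0.0, "max": 1.0}
    if value is None:
        return limits["min"]
    return max(limits["min"], min(float(value), limits["max"]))


print(clip_value.__annotations__["value"])  # always the string 'int | float | None'

try:
    # Succeeds on Python 3.10+, where `int | float | None` is valid at runtime.
    print(typing.get_type_hints(clip_value))
except TypeError:
    # On 3.8/3.9 the `|` union cannot be evaluated, even though the annotation
    # itself was accepted thanks to the __future__ import.
    print("this interpreter cannot evaluate PEP 604 unions at runtime")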
@@ -1214,8 +1201,8 @@ def optimizer_step( self, epoch: int, batch_idx: int, - optimizer: Union[Optimizer, LightningOptimizer], - optimizer_closure: Optional[Callable[[], Any]] = None, + optimizer: Optimizer | LightningOptimizer, + optimizer_closure: Callable[[], Any] | None = None, ) -> None: r"""Override this method to adjust the default way the :class:`~lightning.pytorch.trainer.trainer.Trainer` calls the optimizer. @@ -1306,7 +1293,7 @@ def _verify_is_manual_optimization(self, fn_name: str) -> None: ) @torch.no_grad() - def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = None, **kwargs: Any) -> None: + def to_onnx(self, file_path: str | Path, input_sample: Any | None = None, **kwargs: Any) -> None: """Saves the model in ONNX format. Args: @@ -1352,11 +1339,11 @@ def forward(self, x): @torch.no_grad() def to_torchscript( self, - file_path: Optional[Union[str, Path]] = None, - method: Optional[str] = "script", - example_inputs: Optional[Any] = None, + file_path: str | Path | None = None, + method: str | None = "script", + example_inputs: Any | None = None, **kwargs: Any, - ) -> Union[ScriptModule, Dict[str, ScriptModule]]: + ) -> ScriptModule | dict[str, ScriptModule]: """By default compiles the whole model to a :class:`~torch.jit.ScriptModule`. If you want to use tracing, please provided the argument ``method='trace'`` and make sure that either the `example_inputs` argument is provided, or the model has :attr:`example_input_array` set. If you would like to customize the modules that @@ -1435,9 +1422,9 @@ def forward(self, x): @classmethod def load_from_checkpoint( cls, - checkpoint_path: Union[_PATH, IO], + checkpoint_path: _PATH | IO, map_location: _MAP_LOCATION_TYPE = None, - hparams_file: Optional[_PATH] = None, + hparams_file: _PATH | None = None, strict: bool = True, **kwargs: Any, ) -> Self: @@ -1522,7 +1509,7 @@ def load_from_checkpoint( ) return cast(Self, loaded) - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: state = dict(self.__dict__) state["_trainer"] = None return state diff --git a/src/lightning/pytorch/core/optimizer.py b/src/lightning/pytorch/core/optimizer.py index ee1cd45661165..5cd3b3740a959 100644 --- a/src/lightning/pytorch/core/optimizer.py +++ b/src/lightning/pytorch/core/optimizer.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + from contextlib import contextmanager from dataclasses import fields -from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union +from typing import Any, Callable, Generator from weakref import proxy import torch @@ -43,7 +45,7 @@ def __init__(self, optimizer: Optimizer): self.__class__ = type("Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {}) self._optimizer = optimizer - self._strategy: Optional[pl.strategies.Strategy] = None + self._strategy: pl.strategies.Strategy | None = None # to inject logic around the optimizer step, particularly useful with manual optimization self._on_before_step = do_nothing_closure self._on_after_step = do_nothing_closure @@ -54,8 +56,8 @@ def optimizer(self) -> Optimizer: @classmethod def _to_lightning_optimizer( - cls, optimizer: Union[Optimizer, "LightningOptimizer"], strategy: "pl.strategies.Strategy" - ) -> "LightningOptimizer": + cls, optimizer: Optimizer | LightningOptimizer, strategy: pl.strategies.Strategy + ) -> LightningOptimizer: # the user could return a `LightningOptimizer` from `configure_optimizers`, see test: # tests/core/test_lightning_optimizer.py::test_lightning_optimizer[False] lightning_optimizer = optimizer if isinstance(optimizer, LightningOptimizer) else cls(optimizer) @@ -84,7 +86,7 @@ def toggle_model(self, sync_grad: bool = True) -> Generator[None, None, None]: yield lightning_module.untoggle_optimizer(self) - def step(self, closure: Optional[Callable[[], Any]] = None, **kwargs: Any) -> Any: + def step(self, closure: Callable[[], Any] | None = None, **kwargs: Any) -> Any: """Performs a single optimization step (parameter update). Args: @@ -160,8 +162,8 @@ def closure_dis(): def _init_optimizers_and_lr_schedulers( - model: "pl.LightningModule", -) -> Tuple[List[Optimizer], List[LRSchedulerConfig]]: + model: pl.LightningModule, +) -> tuple[list[Optimizer], list[LRSchedulerConfig]]: """Calls `LightningModule.configure_optimizers` and parses and validates the output.""" from lightning.pytorch.trainer import call @@ -185,9 +187,7 @@ def _init_optimizers_and_lr_schedulers( return optimizers, lr_scheduler_configs -def _configure_optimizers( - optim_conf: Union[Dict[str, Any], List, Optimizer, Tuple] -) -> Tuple[List, List, Optional[str]]: +def _configure_optimizers(optim_conf: dict[str, Any] | list | Optimizer | tuple) -> tuple[list, list, str | None]: optimizers, lr_schedulers = [], [] monitor = None @@ -235,7 +235,7 @@ def _configure_optimizers( return optimizers, lr_schedulers, monitor -def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> List[LRSchedulerConfig]: +def _configure_schedulers_automatic_opt(schedulers: list, monitor: str | None) -> list[LRSchedulerConfig]: """Convert each scheduler into `LRSchedulerConfig` with relevant information, when using automatic optimization.""" lr_scheduler_configs = [] @@ -291,7 +291,7 @@ def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str] return lr_scheduler_configs -def _configure_schedulers_manual_opt(schedulers: list) -> List[LRSchedulerConfig]: +def _configure_schedulers_manual_opt(schedulers: list) -> list[LRSchedulerConfig]: """Convert each scheduler into `LRSchedulerConfig` structure with relevant information, when using manual optimization.""" lr_scheduler_configs = [] @@ -316,7 +316,7 @@ def _configure_schedulers_manual_opt(schedulers: list) -> List[LRSchedulerConfig return lr_scheduler_configs -def 
_validate_scheduler_api(lr_scheduler_configs: List[LRSchedulerConfig], model: "pl.LightningModule") -> None: +def _validate_scheduler_api(lr_scheduler_configs: list[LRSchedulerConfig], model: pl.LightningModule) -> None: for config in lr_scheduler_configs: scheduler = config.scheduler if not isinstance(scheduler, _Stateful): @@ -333,7 +333,7 @@ def _validate_scheduler_api(lr_scheduler_configs: List[LRSchedulerConfig], model ) -def _validate_multiple_optimizers_support(optimizers: List[Optimizer], model: "pl.LightningModule") -> None: +def _validate_multiple_optimizers_support(optimizers: list[Optimizer], model: pl.LightningModule) -> None: if model.automatic_optimization and len(optimizers) > 1: raise RuntimeError( "Training with multiple optimizers is only supported with manual optimization. Set" @@ -342,7 +342,7 @@ def _validate_multiple_optimizers_support(optimizers: List[Optimizer], model: "p ) -def _validate_optimizers_attached(optimizers: List[Optimizer], lr_scheduler_configs: List[LRSchedulerConfig]) -> None: +def _validate_optimizers_attached(optimizers: list[Optimizer], lr_scheduler_configs: list[LRSchedulerConfig]) -> None: for config in lr_scheduler_configs: if config.scheduler.optimizer not in optimizers: raise MisconfigurationException( @@ -350,7 +350,7 @@ def _validate_optimizers_attached(optimizers: List[Optimizer], lr_scheduler_conf ) -def _validate_optim_conf(optim_conf: Dict[str, Any]) -> None: +def _validate_optim_conf(optim_conf: dict[str, Any]) -> None: valid_keys = {"optimizer", "lr_scheduler", "monitor"} extra_keys = optim_conf.keys() - valid_keys if extra_keys: @@ -366,20 +366,20 @@ class _MockOptimizer(Optimizer): def __init__(self) -> None: super().__init__([torch.zeros(1)], {}) - def add_param_group(self, param_group: Dict[Any, Any]) -> None: + def add_param_group(self, param_group: dict[Any, Any]) -> None: pass # Do Nothing - def load_state_dict(self, state_dict: Dict[Any, Any]) -> None: + def load_state_dict(self, state_dict: dict[Any, Any]) -> None: pass # Do Nothing - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return {} # Return Empty - def step(self, closure: Optional[Callable] = None) -> None: + def step(self, closure: Callable | None = None) -> None: if closure is not None: closure() - def zero_grad(self, set_to_none: Optional[bool] = True) -> None: + def zero_grad(self, set_to_none: bool | None = True) -> None: pass # Do Nothing def __repr__(self) -> str: diff --git a/src/lightning/pytorch/core/saving.py b/src/lightning/pytorch/core/saving.py index a62a87227cf9d..9a56ea27f53ce 100644 --- a/src/lightning/pytorch/core/saving.py +++ b/src/lightning/pytorch/core/saving.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import ast import contextlib import csv @@ -22,7 +24,7 @@ from copy import deepcopy from enum import Enum from pathlib import Path -from typing import Any, Callable, Dict, IO, Optional, Type, Union +from typing import Any, Callable, IO from warnings import warn import yaml @@ -50,13 +52,13 @@ def _load_from_checkpoint( - cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]], - checkpoint_path: Union[_PATH, IO], + cls: type[pl.LightningModule] | type[pl.LightningDataModule], + checkpoint_path: _PATH | IO, map_location: _MAP_LOCATION_TYPE = None, - hparams_file: Optional[_PATH] = None, - strict: Optional[bool] = None, + hparams_file: _PATH | None = None, + strict: bool | None = None, **kwargs: Any, -) -> Union["pl.LightningModule", "pl.LightningDataModule"]: +) -> pl.LightningModule | pl.LightningDataModule: with pl_legacy_patch(): checkpoint = pl_load(checkpoint_path, map_location=map_location) @@ -98,11 +100,11 @@ def _load_from_checkpoint( def _load_state( - cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]], - checkpoint: Dict[str, Any], - strict: Optional[bool] = None, + cls: type[pl.LightningModule] | type[pl.LightningDataModule], + checkpoint: dict[str, Any], + strict: bool | None = None, **cls_kwargs_new: Any, -) -> Union["pl.LightningModule", "pl.LightningDataModule"]: +) -> pl.LightningModule | pl.LightningDataModule: cls_spec = inspect.getfullargspec(cls.__init__) cls_init_args_name = inspect.signature(cls.__init__).parameters.keys() @@ -166,9 +168,7 @@ def _load_state( return obj -def _convert_loaded_hparams( - model_args: Dict[str, Any], hparams_type: Optional[Union[Callable, str]] = None -) -> Dict[str, Any]: +def _convert_loaded_hparams(model_args: dict[str, Any], hparams_type: Callable | str | None = None) -> dict[str, Any]: """Convert hparams according given type in callable or string (past) format.""" # if not hparams type define if not hparams_type: @@ -209,7 +209,7 @@ def update_hparams(hparams: dict, updates: dict) -> None: hparams.update({k: v}) -def load_hparams_from_tags_csv(tags_csv: _PATH) -> Dict[str, Any]: +def load_hparams_from_tags_csv(tags_csv: _PATH) -> dict[str, Any]: """Load hparams from a file. >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here') @@ -230,7 +230,7 @@ def load_hparams_from_tags_csv(tags_csv: _PATH) -> Dict[str, Any]: return {row[0]: convert(row[1]) for row in list(csv_reader)[1:]} -def save_hparams_to_tags_csv(tags_csv: _PATH, hparams: Union[dict, Namespace]) -> None: +def save_hparams_to_tags_csv(tags_csv: _PATH, hparams: dict | Namespace) -> None: fs = get_filesystem(tags_csv) if not fs.isdir(os.path.dirname(tags_csv)): raise RuntimeError(f"Missing folder: {os.path.dirname(tags_csv)}.") @@ -246,7 +246,7 @@ def save_hparams_to_tags_csv(tags_csv: _PATH, hparams: Union[dict, Namespace]) - writer.writerow({"key": k, "value": v}) -def load_hparams_from_yaml(config_yaml: _PATH, use_omegaconf: bool = True) -> Dict[str, Any]: +def load_hparams_from_yaml(config_yaml: _PATH, use_omegaconf: bool = True) -> dict[str, Any]: """Load hparams from a file. 
Args: @@ -276,7 +276,7 @@ def load_hparams_from_yaml(config_yaml: _PATH, use_omegaconf: bool = True) -> Di return hparams -def save_hparams_to_yaml(config_yaml: _PATH, hparams: Union[dict, Namespace], use_omegaconf: bool = True) -> None: +def save_hparams_to_yaml(config_yaml: _PATH, hparams: dict | Namespace, use_omegaconf: bool = True) -> None: """ Args: config_yaml: path to new YAML file @@ -327,7 +327,7 @@ def save_hparams_to_yaml(config_yaml: _PATH, hparams: Union[dict, Namespace], us yaml.dump(hparams_allowed, fp) -def convert(val: str) -> Union[int, float, bool, str]: +def convert(val: str) -> int | float | bool | str: try: return ast.literal_eval(val) except (ValueError, SyntaxError) as err: diff --git a/src/lightning/pytorch/demos/boring_classes.py b/src/lightning/pytorch/demos/boring_classes.py index 8a917b0670b63..42ee07f3fc1b5 100644 --- a/src/lightning/pytorch/demos/boring_classes.py +++ b/src/lightning/pytorch/demos/boring_classes.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Iterator, List, Optional, Tuple +from __future__ import annotations + +from typing import Iterator import torch import torch.nn as nn @@ -35,7 +37,7 @@ def __init__(self, size: int, length: int): self.len = length self.data = torch.randn(length, size) - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: a = self.data[index] b = a + 2 return {"a": a, "b": b} @@ -114,7 +116,7 @@ def __init__(self) -> None: def forward(self, x: Tensor) -> Tensor: return self.layer(x) - def loss(self, preds: Tensor, labels: Optional[Tensor] = None) -> Tensor: + def loss(self, preds: Tensor, labels: Tensor | None = None) -> Tensor: if labels is None: labels = torch.ones_like(preds) # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls @@ -127,13 +129,13 @@ def step(self, batch: Tensor) -> Tensor: def training_step(self, batch: Tensor, batch_idx: int) -> STEP_OUTPUT: return {"loss": self.step(batch)} - def validation_step(self, batch: Tensor, batch_idx: int) -> Optional[STEP_OUTPUT]: + def validation_step(self, batch: Tensor, batch_idx: int) -> STEP_OUTPUT | None: return {"x": self.step(batch)} - def test_step(self, batch: Tensor, batch_idx: int) -> Optional[STEP_OUTPUT]: + def test_step(self, batch: Tensor, batch_idx: int) -> STEP_OUTPUT | None: return {"y": self.step(batch)} - def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[_TORCH_LRSCHEDULER]]: + def configure_optimizers(self) -> tuple[list[torch.optim.Optimizer], list[_TORCH_LRSCHEDULER]]: optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) return [optimizer], [lr_scheduler] diff --git a/src/lightning/pytorch/demos/mnist_datamodule.py b/src/lightning/pytorch/demos/mnist_datamodule.py index 63c36e108a4b0..563896ecee359 100644 --- a/src/lightning/pytorch/demos/mnist_datamodule.py +++ b/src/lightning/pytorch/demos/mnist_datamodule.py @@ -11,12 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import logging import os import random import time import urllib -from typing import Any, Callable, Optional, Sized, Tuple, Union +from typing import Any, Callable, Sized from urllib.error import HTTPError from warnings import warn @@ -65,7 +67,7 @@ def __init__( data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME self.data, self.targets = self._try_load(os.path.join(self.cached_folder_path, data_file)) - def __getitem__(self, idx: int) -> Tuple[Tensor, int]: + def __getitem__(self, idx: int) -> tuple[Tensor, int]: img = self.data[idx].float().unsqueeze(0) target = int(self.targets[idx]) @@ -101,7 +103,7 @@ def _download(self, data_folder: str) -> None: urllib.request.urlretrieve(url, fpath) # noqa: S310 @staticmethod - def _try_load(path_data: str, trials: int = 30, delta: float = 1.0) -> Tuple[Tensor, Tensor]: + def _try_load(path_data: str, trials: int = 30, delta: float = 1.0) -> tuple[Tensor, Tensor]: """Resolving loading from the same time from multiple concurrent processes.""" res, exception = None, None assert trials, "at least some trial has to be set" @@ -122,7 +124,7 @@ def _try_load(path_data: str, trials: int = 30, delta: float = 1.0) -> Tuple[Ten return res @staticmethod - def normalize_tensor(tensor: Tensor, mean: Union[int, float] = 0.0, std: Union[int, float] = 1.0) -> Tensor: + def normalize_tensor(tensor: Tensor, mean: int | float = 0.0, std: int | float = 1.0) -> Tensor: mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device) std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device) return tensor.sub(mean).div(std) @@ -240,7 +242,7 @@ def test_dataloader(self) -> DataLoader: ) @property - def default_transforms(self) -> Optional[Callable]: + def default_transforms(self) -> Callable | None: if not _TORCHVISION_AVAILABLE: return None if self.normalize: diff --git a/src/lightning/pytorch/demos/transformer.py b/src/lightning/pytorch/demos/transformer.py index 0c328ba40188c..79fd8714445eb 100644 --- a/src/lightning/pytorch/demos/transformer.py +++ b/src/lightning/pytorch/demos/transformer.py @@ -3,10 +3,11 @@ Code is adapted from the PyTorch examples at https://github.com/pytorch/examples/blob/main/word_language_model """ +from __future__ import annotations + import math import os from pathlib import Path -from typing import Dict, List, Optional, Tuple import requests import torch @@ -38,7 +39,7 @@ def __init__( self.vocab_size = vocab_size self.src_mask = None - def forward(self, input: Tensor, target: Tensor, mask: Optional[Tensor] = None) -> Tensor: + def forward(self, input: Tensor, target: Tensor, mask: Tensor | None = None) -> Tensor: b, t = input.shape # we assume target is already shifted w.r.t. 
input @@ -91,7 +92,7 @@ def vocab_size(self) -> int: def __len__(self) -> int: return len(self.data) // self.block_size - 1 - def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: + def __getitem__(self, index: int) -> tuple[Tensor, Tensor]: start = index * self.block_size end = start + self.block_size input = self.data[start:end] @@ -110,8 +111,8 @@ def download(destination: Path) -> None: class Dictionary: def __init__(self) -> None: - self.word2idx: Dict[str, int] = {} - self.idx2word: List[str] = [] + self.word2idx: dict[str, int] = {} + self.idx2word: list[str] = [] def add_word(self, word: str) -> int: if word not in self.word2idx: @@ -123,7 +124,7 @@ def __len__(self) -> int: return len(self.idx2word) -def tokenize(path: Path) -> Tuple[Tensor, Dictionary]: +def tokenize(path: Path) -> tuple[Tensor, Dictionary]: dictionary = Dictionary() assert os.path.exists(path) @@ -136,10 +137,10 @@ def tokenize(path: Path) -> Tuple[Tensor, Dictionary]: # Tokenize file content with open(path, encoding="utf8") as f: - idss: List[Tensor] = [] + idss: list[Tensor] = [] for line in f: words = line.split() + ["<eos>"] - ids: List[int] = [] + ids: list[int] = [] for word in words: ids.append(dictionary.word2idx[word]) idss.append(torch.tensor(ids).type(torch.int64)) diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index 5dccdcd5c6698..5105250bb21ba 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -16,10 +16,12 @@ ------------ """ +from __future__ import annotations + import logging import os from argparse import Namespace -from typing import Any, Dict, Mapping, Optional, Union +from typing import Any, Mapping from lightning_utilities.core.imports import module_available from torch import Tensor @@ -210,12 +212,12 @@ def __init__(self, *args, **kwarg): def __init__( self, - api_key: Optional[str] = None, - save_dir: Optional[str] = None, - project_name: Optional[str] = None, - rest_api_key: Optional[str] = None, - experiment_name: Optional[str] = None, - experiment_key: Optional[str] = None, + api_key: str | None = None, + save_dir: str | None = None, + project_name: str | None = None, + rest_api_key: str | None = None, + experiment_name: str | None = None, + experiment_key: str | None = None, offline: bool = False, prefix: str = "", **kwargs: Any, @@ -226,8 +228,8 @@ def __init__( ) super().__init__() self._experiment = None - self._save_dir: Optional[str] - self.rest_api_key: Optional[str] + self._save_dir: str | None + self.rest_api_key: str | None # Determine online or offline mode based on which arguments were passed to CometLogger api_key = api_key or comet_ml.config.get_api_key(None, comet_ml.config.get_config()) @@ -249,12 +251,12 @@ def __init__( log.info(f"CometLogger will be initialized in {self.mode} mode") - self._project_name: Optional[str] = project_name - self._experiment_key: Optional[str] = experiment_key - self._experiment_name: Optional[str] = experiment_name + self._project_name: str | None = project_name + self._experiment_key: str | None = experiment_key + self._experiment_name: str | None = experiment_name self._prefix: str = prefix self._kwargs: Any = kwargs - self._future_experiment_key: Optional[str] = None + self._future_experiment_key: str | None = None if rest_api_key is not None: # Comet.ml rest API, used to determine version number @@ -266,7 +268,7 @@ def __init__( @property @rank_zero_experiment - def experiment(self) -> Union[CometExperiment, CometExistingExperiment, 
CometOfflineExperiment]: + def experiment(self) -> CometExperiment | CometExistingExperiment | CometOfflineExperiment: r""" Actual Comet object. To use Comet features in your :class:`~lightning.pytorch.core.module.LightningModule` do the following. @@ -312,13 +314,13 @@ def experiment(self) -> Union[CometExperiment, CometExistingExperiment, CometOff return self._experiment @rank_zero_only - def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: dict[str, Any] | Namespace) -> None: params = _convert_params(params) params = _flatten_dict(params) self.experiment.log_parameters(params) @rank_zero_only - def log_metrics(self, metrics: Mapping[str, Union[Tensor, float]], step: Optional[int] = None) -> None: + def log_metrics(self, metrics: Mapping[str, Tensor | float], step: int | None = None) -> None: assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" # Comet.ml expects metrics to be a dictionary of detached tensors on CPU metrics_without_epoch = metrics.copy() @@ -352,7 +354,7 @@ def finalize(self, status: str) -> None: self.reset_experiment() @property - def save_dir(self) -> Optional[str]: + def save_dir(self) -> str | None: """Gets the save directory. Returns: @@ -408,7 +410,7 @@ def version(self) -> str: return self._future_experiment_key - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: state = self.__dict__.copy() # Save the experiment id in case an experiment object already exists, @@ -422,6 +424,6 @@ def __getstate__(self) -> Dict[str, Any]: state["_experiment"] = None return state - def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None: + def log_graph(self, model: Module, input_array: Tensor | None = None) -> None: if self._experiment is not None: self._experiment.set_model_graph(model) diff --git a/src/lightning/pytorch/loggers/csv_logs.py b/src/lightning/pytorch/loggers/csv_logs.py index 0f24eda761179..a77ad5019cc82 100644 --- a/src/lightning/pytorch/loggers/csv_logs.py +++ b/src/lightning/pytorch/loggers/csv_logs.py @@ -18,10 +18,12 @@ CSV logger for basic experiment logging that does not require opening ports """ +from __future__ import annotations + import logging import os from argparse import Namespace -from typing import Any, Dict, Optional, Union +from typing import Any from lightning.fabric.loggers.csv_logs import _ExperimentWriter as _FabricExperimentWriter from lightning.fabric.loggers.csv_logs import CSVLogger as FabricCSVLogger @@ -51,9 +53,9 @@ class ExperimentWriter(_FabricExperimentWriter): def __init__(self, log_dir: str) -> None: super().__init__(log_dir=log_dir) - self.hparams: Dict[str, Any] = {} + self.hparams: dict[str, Any] = {} - def log_hparams(self, params: Dict[str, Any]) -> None: + def log_hparams(self, params: dict[str, Any]) -> None: """Record hparams.""" self.hparams.update(params) @@ -90,7 +92,7 @@ def __init__( self, save_dir: _PATH, name: str = "lightning_logs", - version: Optional[Union[int, str]] = None, + version: int | str | None = None, prefix: str = "", flush_logs_every_n_steps: int = 100, ): @@ -133,7 +135,7 @@ def save_dir(self) -> str: return self._save_dir @rank_zero_only - def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: dict[str, Any] | Namespace) -> None: params = _convert_params(params) self.experiment.log_hparams(params) diff --git a/src/lightning/pytorch/loggers/logger.py b/src/lightning/pytorch/loggers/logger.py index 
52a51ab8eb74a..06384ba196a2a 100644 --- a/src/lightning/pytorch/loggers/logger.py +++ b/src/lightning/pytorch/loggers/logger.py @@ -14,11 +14,13 @@ """Abstract base class used to build new loggers.""" +from __future__ import annotations + import functools import operator from abc import ABC from collections import defaultdict -from typing import Any, Callable, Dict, Mapping, Optional, Sequence +from typing import Any, Callable, Mapping, Sequence import numpy as np @@ -40,7 +42,7 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: pass @property - def save_dir(self) -> Optional[str]: + def save_dir(self) -> str | None: """Return the root directory where experiment logs get saved, or `None` if the logger does not save data locally.""" return None @@ -77,7 +79,7 @@ def version(self) -> str: """Return the experiment version.""" return "" - def __getitem__(self, idx: int) -> "DummyLogger": + def __getitem__(self, idx: int) -> DummyLogger: # enables self.logger[0].experiment.add_image(...) return self @@ -93,9 +95,9 @@ def method(*args: Any, **kwargs: Any) -> None: # TODO: this should have been deprecated def merge_dicts( # pragma: no cover dicts: Sequence[Mapping], - agg_key_funcs: Optional[Mapping] = None, + agg_key_funcs: Mapping | None = None, default_func: Callable[[Sequence[float]], float] = np.mean, -) -> Dict: +) -> dict: """Merge a sequence with dictionaries into one dictionary by aggregating the same keys with some given function. @@ -131,7 +133,7 @@ def merge_dicts( # pragma: no cover """ agg_key_funcs = agg_key_funcs or {} keys = list(functools.reduce(operator.or_, [set(d.keys()) for d in dicts])) - d_out: Dict = defaultdict(dict) + d_out: dict = defaultdict(dict) for k in keys: fn = agg_key_funcs.get(k) values_to_agg = [v for v in [d_in.get(k) for d_in in dicts] if v is not None] diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py index f1386909653ec..a24aa2d7c92e9 100644 --- a/src/lightning/pytorch/loggers/mlflow.py +++ b/src/lightning/pytorch/loggers/mlflow.py @@ -15,6 +15,8 @@ MLflow Logger ------------- """ +from __future__ import annotations + import logging import os import re @@ -22,7 +24,7 @@ from argparse import Namespace from pathlib import Path from time import time -from typing import Any, Dict, List, Literal, Mapping, Optional, Union +from typing import Any, Literal, Mapping import yaml from lightning_utilities.core.imports import RequirementCache @@ -56,7 +58,7 @@ from mlflow.tracking.context.registry import resolve_tags else: - def resolve_tags(tags: Optional[Dict] = None) -> Optional[Dict]: + def resolve_tags(tags: dict | None = None) -> dict | None: """ Args: tags: A dictionary of tags to override. 
If specified, tags passed in this argument will @@ -137,14 +139,14 @@ def any_lightning_module_function_or_hook(self): def __init__( self, experiment_name: str = "lightning_logs", - run_name: Optional[str] = None, - tracking_uri: Optional[str] = os.getenv("MLFLOW_TRACKING_URI"), - tags: Optional[Dict[str, Any]] = None, - save_dir: Optional[str] = "./mlruns", + run_name: str | None = None, + tracking_uri: str | None = os.getenv("MLFLOW_TRACKING_URI"), + tags: dict[str, Any] | None = None, + save_dir: str | None = "./mlruns", log_model: Literal[True, False, "all"] = False, prefix: str = "", - artifact_location: Optional[str] = None, - run_id: Optional[str] = None, + artifact_location: str | None = None, + run_id: str | None = None, ): if not _MLFLOW_AVAILABLE: raise ModuleNotFoundError(str(_MLFLOW_AVAILABLE)) @@ -153,14 +155,14 @@ def __init__( tracking_uri = f"{LOCAL_FILE_URI_PREFIX}{save_dir}" self._experiment_name = experiment_name - self._experiment_id: Optional[str] = None + self._experiment_id: str | None = None self._tracking_uri = tracking_uri self._run_name = run_name self._run_id = run_id self.tags = tags self._log_model = log_model - self._logged_model_time: Dict[str, float] = {} - self._checkpoint_callback: Optional[ModelCheckpoint] = None + self._logged_model_time: dict[str, float] = {} + self._checkpoint_callback: ModelCheckpoint | None = None self._prefix = prefix self._artifact_location = artifact_location @@ -213,7 +215,7 @@ def experiment(self) -> MlflowClient: return self._mlflow_client @property - def run_id(self) -> Optional[str]: + def run_id(self) -> str | None: """Create the experiment if it does not exist to get the run id. Returns: @@ -223,7 +225,7 @@ def run_id(self) -> Optional[str]: return self._run_id @property - def experiment_id(self) -> Optional[str]: + def experiment_id(self) -> str | None: """Create the experiment if it does not exist to get the experiment id. Returns: @@ -233,7 +235,7 @@ def experiment_id(self) -> Optional[str]: return self._experiment_id @rank_zero_only - def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: dict[str, Any] | Namespace) -> None: params = _convert_params(params) params = _flatten_dict(params) @@ -246,11 +248,11 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100]) @rank_zero_only - def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None: + def log_metrics(self, metrics: Mapping[str, float], step: int | None = None) -> None: assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) - metrics_list: List[Metric] = [] + metrics_list: list[Metric] = [] timestamp_ms = int(time() * 1000) for k, v in metrics.items(): @@ -289,7 +291,7 @@ def finalize(self, status: str = "success") -> None: self.experiment.set_terminated(self.run_id, status) @property - def save_dir(self) -> Optional[str]: + def save_dir(self) -> str | None: """The root file directory in which MLflow experiments are saved. Return: @@ -301,7 +303,7 @@ def save_dir(self) -> Optional[str]: return None @property - def name(self) -> Optional[str]: + def name(self) -> str | None: """Get the experiment id. 
Returns: @@ -310,7 +312,7 @@ def name(self) -> Optional[str]: return self.experiment_id @property - def version(self) -> Optional[str]: + def version(self) -> str | None: """Get the run id. Returns: diff --git a/src/lightning/pytorch/loggers/neptune.py b/src/lightning/pytorch/loggers/neptune.py index 53a15f50c43f8..e0ab4f9c43543 100644 --- a/src/lightning/pytorch/loggers/neptune.py +++ b/src/lightning/pytorch/loggers/neptune.py @@ -15,15 +15,13 @@ Neptune Logger -------------- """ -__all__ = [ - "NeptuneLogger", -] +from __future__ import annotations import contextlib import logging import os from argparse import Namespace -from typing import Any, Dict, Generator, List, Optional, Set, Union +from typing import Any, Generator from lightning_utilities.core.imports import RequirementCache from torch import Tensor @@ -35,6 +33,10 @@ from lightning.pytorch.utilities.model_summary import ModelSummary from lightning.pytorch.utilities.rank_zero import rank_zero_only +__all__ = [ + "NeptuneLogger", +] + # neptune is available with two names on PyPI : `neptune` and `neptune-client` _NEPTUNE_AVAILABLE = RequirementCache("neptune>=1.0") _NEPTUNE_CLIENT_AVAILABLE = RequirementCache("neptune-client") @@ -232,11 +234,11 @@ def any_lightning_module_function_or_hook(self): def __init__( self, *, # force users to call `NeptuneLogger` initializer with `kwargs` - api_key: Optional[str] = None, - project: Optional[str] = None, - name: Optional[str] = None, - run: Optional[Union["Run", "Handler"]] = None, - log_model_checkpoints: Optional[bool] = True, + api_key: str | None = None, + project: str | None = None, + name: str | None = None, + run: Run | Handler | None = None, + log_model_checkpoints: bool | None = True, prefix: str = "training", **neptune_run_kwargs: Any, ): @@ -252,7 +254,7 @@ def __init__( self._api_key = api_key self._run_instance = run self._neptune_run_kwargs = neptune_run_kwargs - self._run_short_id: Optional[str] = None + self._run_short_id: str | None = None if self._run_instance is not None: self._retrieve_run_data() @@ -280,8 +282,8 @@ def _retrieve_run_data(self) -> None: self._run_name = "offline-name" @property - def _neptune_init_args(self) -> Dict: - args: Dict = {} + def _neptune_init_args(self) -> dict: + args: dict = {} # Backward compatibility in case of previous version retrieval with contextlib.suppress(AttributeError): args = self._neptune_run_kwargs @@ -310,10 +312,10 @@ def _construct_path_with_prefix(self, *keys: str) -> str: @staticmethod def _verify_input_arguments( - api_key: Optional[str], - project: Optional[str], - name: Optional[str], - run: Optional[Union["Run", "Handler"]], + api_key: str | None, + project: str | None, + name: str | None, + run: Run | Handler | None, neptune_run_kwargs: dict, ) -> None: # check if user passed the client `Run`/`Handler` object @@ -328,13 +330,13 @@ def _verify_input_arguments( " you can't provide other neptune.init_run() parameters.\n" ) - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: state = self.__dict__.copy() # Run instance can't be pickled state["_run_instance"] = None return state - def __setstate__(self, state: Dict[str, Any]) -> None: + def __setstate__(self, state: dict[str, Any]) -> None: self.__dict__ = state self._run_instance = neptune.init_run(**self._neptune_init_args) @@ -380,7 +382,7 @@ def run(self) -> Run: return self._run_instance @rank_zero_only - def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: # skipcq: PYL-W0221 + def log_hyperparams(self, 
params: dict[str, Any] | Namespace) -> None: # skipcq: PYL-W0221 r"""Log hyperparameters to the run. Hyperparameters will be logged under the "/hyperparams" namespace. @@ -423,7 +425,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: # self.run[parameters_key] = stringify_unsupported(params) @rank_zero_only - def log_metrics(self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None) -> None: + def log_metrics(self, metrics: dict[str, Tensor | float], step: int | None = None) -> None: """Log metrics (numeric values) in Neptune runs. Args: @@ -452,7 +454,7 @@ def finalize(self, status: str) -> None: super().finalize(status) @property - def save_dir(self) -> Optional[str]: + def save_dir(self) -> str | None: """Gets the save directory of the experiment which in this case is ``None`` because Neptune does not save locally. @@ -462,7 +464,7 @@ def save_dir(self) -> Optional[str]: return os.path.join(os.getcwd(), ".neptune") @rank_zero_only - def log_model_summary(self, model: "pl.LightningModule", max_depth: int = -1) -> None: + def log_model_summary(self, model: pl.LightningModule, max_depth: int = -1) -> None: model_str = str(ModelSummary(model=model, max_depth=max_depth)) self.run[self._construct_path_with_prefix("model/summary")] = neptune.types.File.from_content( content=model_str, extension="txt" @@ -531,16 +533,16 @@ def _get_full_model_name(model_path: str, checkpoint_callback: Checkpoint) -> st return model_path @classmethod - def _get_full_model_names_from_exp_structure(cls, exp_structure: Dict[str, Any], namespace: str) -> Set[str]: + def _get_full_model_names_from_exp_structure(cls, exp_structure: dict[str, Any], namespace: str) -> set[str]: """Returns all paths to properties which were already logged in `namespace`""" - structure_keys: List[str] = namespace.split(cls.LOGGER_JOIN_CHAR) + structure_keys: list[str] = namespace.split(cls.LOGGER_JOIN_CHAR) for key in structure_keys: exp_structure = exp_structure[key] uploaded_models_dict = exp_structure return set(cls._dict_paths(uploaded_models_dict)) @classmethod - def _dict_paths(cls, d: Dict[str, Any], path_in_build: Optional[str] = None) -> Generator: + def _dict_paths(cls, d: dict[str, Any], path_in_build: str | None = None) -> Generator: for k, v in d.items(): path = f"{path_in_build}/{k}" if path_in_build is not None else k if not isinstance(v, dict): @@ -549,12 +551,12 @@ def _dict_paths(cls, d: Dict[str, Any], path_in_build: Optional[str] = None) -> yield from cls._dict_paths(v, path) @property - def name(self) -> Optional[str]: + def name(self) -> str | None: """Return the experiment name or 'offline-name' when exp is run in offline mode.""" return self._run_name @property - def version(self) -> Optional[str]: + def version(self) -> str | None: """Return the experiment version. 
It's Neptune Run's short_id diff --git a/src/lightning/pytorch/loggers/tensorboard.py b/src/lightning/pytorch/loggers/tensorboard.py index b14d1357e2e28..399f705557921 100644 --- a/src/lightning/pytorch/loggers/tensorboard.py +++ b/src/lightning/pytorch/loggers/tensorboard.py @@ -16,10 +16,12 @@ ------------------ """ +from __future__ import annotations + import logging import os from argparse import Namespace -from typing import Any, Dict, Optional, Union +from typing import Any from torch import Tensor @@ -101,12 +103,12 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger): def __init__( self, save_dir: _PATH, - name: Optional[str] = "lightning_logs", - version: Optional[Union[int, str]] = None, + name: str | None = "lightning_logs", + version: int | str | None = None, log_graph: bool = False, default_hp_metric: bool = True, prefix: str = "", - sub_dir: Optional[_PATH] = None, + sub_dir: _PATH | None = None, **kwargs: Any, ): super().__init__( @@ -124,7 +126,7 @@ def __init__( f"{str(_TENSORBOARD_AVAILABLE)}" ) self._log_graph = log_graph and _TENSORBOARD_AVAILABLE - self.hparams: Union[Dict[str, Any], Namespace] = {} + self.hparams: dict[str, Any] | Namespace = {} @property def root_dir(self) -> str: @@ -161,9 +163,7 @@ def save_dir(self) -> str: return self._root_dir @rank_zero_only - def log_hyperparams( - self, params: Union[Dict[str, Any], Namespace], metrics: Optional[Dict[str, Any]] = None - ) -> None: + def log_hyperparams(self, params: dict[str, Any] | Namespace, metrics: dict[str, Any] | None = None) -> None: """Record hyperparameters. TensorBoard logs with and without saved hyperparameters are incompatible, the hyperparameters are then not displayed in the TensorBoard. Please delete or move the previously saved logs to display the new ones with hyperparameters. @@ -183,7 +183,7 @@ def log_hyperparams( return super().log_hyperparams(params=params, metrics=metrics) @rank_zero_only - def log_graph(self, model: "pl.LightningModule", input_array: Optional[Tensor] = None) -> None: + def log_graph(self, model: pl.LightningModule, input_array: Tensor | None = None) -> None: if not self._log_graph: return diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index ddc9e24749318..6faa9b583cfe1 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -13,8 +13,10 @@ # limitations under the License. """Utilities for loggers.""" +from __future__ import annotations + from pathlib import Path -from typing import Any, List, Tuple, Union +from typing import Any from torch import Tensor @@ -22,14 +24,14 @@ from lightning.pytorch.callbacks import Checkpoint -def _version(loggers: List[Any], separator: str = "_") -> Union[int, str]: +def _version(loggers: list[Any], separator: str = "_") -> int | str: if len(loggers) == 1: return loggers[0].version # Concatenate versions together, removing duplicates and preserving order return separator.join(dict.fromkeys(str(logger.version) for logger in loggers)) -def _scan_checkpoints(checkpoint_callback: Checkpoint, logged_model_time: dict) -> List[Tuple[float, str, float, str]]: +def _scan_checkpoints(checkpoint_callback: Checkpoint, logged_model_time: dict) -> list[tuple[float, str, float, str]]: """Return the checkpoints to be logged. 
Args: @@ -55,7 +57,7 @@ def _scan_checkpoints(checkpoint_callback: Checkpoint, logged_model_time: dict) return checkpoints -def _log_hyperparams(trainer: "pl.Trainer") -> None: +def _log_hyperparams(trainer: pl.Trainer) -> None: if not trainer.loggers: return diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py index d5f928bd8e69d..ee822c79aa2c1 100644 --- a/src/lightning/pytorch/loggers/wandb.py +++ b/src/lightning/pytorch/loggers/wandb.py @@ -15,10 +15,12 @@ Weights and Biases Logger ------------------------- """ +from __future__ import annotations + import os from argparse import Namespace from pathlib import Path -from typing import Any, Dict, List, Mapping, Optional, Union +from typing import Any, Mapping import torch.nn as nn from lightning_utilities.core.imports import RequirementCache @@ -287,18 +289,18 @@ def any_lightning_module_function_or_hook(self): def __init__( self, - name: Optional[str] = None, + name: str | None = None, save_dir: _PATH = ".", - version: Optional[str] = None, + version: str | None = None, offline: bool = False, - dir: Optional[_PATH] = None, - id: Optional[str] = None, - anonymous: Optional[bool] = None, - project: Optional[str] = None, - log_model: Union[str, bool] = False, - experiment: Union[Run, RunDisabled, None] = None, + dir: _PATH | None = None, + id: str | None = None, + anonymous: bool | None = None, + project: str | None = None, + log_model: str | bool = False, + experiment: Run | RunDisabled | None = None, prefix: str = "", - checkpoint_name: Optional[str] = None, + checkpoint_name: str | None = None, **kwargs: Any, ) -> None: if wandb is None: @@ -326,8 +328,8 @@ def __init__( self._log_model = log_model self._prefix = prefix self._experiment = experiment - self._logged_model_time: Dict[str, float] = {} - self._checkpoint_callback: Optional[ModelCheckpoint] = None + self._logged_model_time: dict[str, float] = {} + self._checkpoint_callback: ModelCheckpoint | None = None # paths are processed as strings if save_dir is not None: @@ -338,7 +340,7 @@ def __init__( project = project or os.environ.get("WANDB_PROJECT", "lightning_logs") # set wandb init arguments - self._wandb_init: Dict[str, Any] = { + self._wandb_init: dict[str, Any] = { "name": name, "project": project, "dir": save_dir or dir, @@ -354,7 +356,7 @@ def __init__( self._id = self._wandb_init.get("id") self._checkpoint_name = checkpoint_name - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: # Hack: If the 'spawn' launch method is used, the logger will get pickled and this `__getstate__` gets called. # We create an experiment here in the main process, and attach to it in the worker process. # Using wandb-service, we persist the same experiment even if multiple `Trainer.fit/test/validate` calls @@ -376,7 +378,7 @@ def __getstate__(self) -> Dict[str, Any]: @property @rank_zero_experiment - def experiment(self) -> Union[Run, RunDisabled]: + def experiment(self) -> Run | RunDisabled: r""" Actual wandb object. 
To use wandb features in your @@ -422,13 +424,13 @@ def watch(self, model: nn.Module, log: str = "gradients", log_freq: int = 100, l self.experiment.watch(model, log=log, log_freq=log_freq, log_graph=log_graph) @rank_zero_only - def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: dict[str, Any] | Namespace) -> None: params = _convert_params(params) params = _sanitize_callable_params(params) self.experiment.config.update(params, allow_val_change=True) @rank_zero_only - def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None: + def log_metrics(self, metrics: Mapping[str, float], step: int | None = None) -> None: assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) @@ -441,10 +443,10 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) def log_table( self, key: str, - columns: Optional[List[str]] = None, - data: Optional[List[List[Any]]] = None, + columns: list[str] | None = None, + data: list[list[Any]] | None = None, dataframe: Any = None, - step: Optional[int] = None, + step: int | None = None, ) -> None: """Log a Table containing any object type (text, image, audio, video, molecule, html, etc). @@ -458,10 +460,10 @@ def log_table( def log_text( self, key: str, - columns: Optional[List[str]] = None, - data: Optional[List[List[str]]] = None, + columns: list[str] | None = None, + data: list[list[str]] | None = None, dataframe: Any = None, - step: Optional[int] = None, + step: int | None = None, ) -> None: """Log text as a Table. @@ -471,7 +473,7 @@ def log_text( self.log_table(key, columns, data, dataframe, step) @rank_zero_only - def log_image(self, key: str, images: List[Any], step: Optional[int] = None, **kwargs: Any) -> None: + def log_image(self, key: str, images: list[Any], step: int | None = None, **kwargs: Any) -> None: """Log images (tensors, numpy arrays, PIL Images or file paths). Optional kwargs are lists passed to each image (ex: caption, masks, boxes). @@ -487,7 +489,7 @@ def log_image(self, key: str, images: List[Any], step: Optional[int] = None, **k self.log_metrics(metrics, step) @property - def save_dir(self) -> Optional[str]: + def save_dir(self) -> str | None: """Gets the save directory. Returns: @@ -496,7 +498,7 @@ def save_dir(self) -> Optional[str]: return self._save_dir @property - def name(self) -> Optional[str]: + def name(self) -> str | None: """The project name of this experiment. Returns: @@ -506,7 +508,7 @@ def name(self) -> Optional[str]: return self._project @property - def version(self) -> Optional[str]: + def version(self) -> str | None: """Gets the id of the experiment. Returns: @@ -526,9 +528,9 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: @rank_zero_only def download_artifact( artifact: str, - save_dir: Optional[_PATH] = None, - artifact_type: Optional[str] = None, - use_artifact: Optional[bool] = True, + save_dir: _PATH | None = None, + artifact_type: str | None = None, + use_artifact: bool | None = True, ) -> str: """Downloads an artifact from the wandb server. 
@@ -550,7 +552,7 @@ def download_artifact( save_dir = None if save_dir is None else os.fspath(save_dir) return artifact.download(root=save_dir) - def use_artifact(self, artifact: str, artifact_type: Optional[str] = None) -> "wandb.Artifact": + def use_artifact(self, artifact: str, artifact_type: str | None = None) -> wandb.Artifact: """Logs to the wandb dashboard that the mentioned artifact is used by the run. Args: diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py index 14f39e5c0fa08..0b64a802f56bb 100644 --- a/src/lightning/pytorch/loops/evaluation_loop.py +++ b/src/lightning/pytorch/loops/evaluation_loop.py @@ -11,11 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import shutil import sys from collections import ChainMap, defaultdict, OrderedDict -from typing import Any, DefaultDict, Iterable, List, Optional, Tuple, Union +from typing import Any, DefaultDict, Iterable from lightning_utilities.core.apply_func import apply_to_collection from torch import Tensor @@ -53,7 +55,7 @@ class _EvaluationLoop(_Loop): def __init__( self, - trainer: "pl.Trainer", + trainer: pl.Trainer, trainer_fn: TrainerFn, stage: RunningStage, verbose: bool = True, @@ -64,16 +66,16 @@ def __init__( self.inference_mode = inference_mode self.batch_progress = _BatchProgress() # across dataloaders # list in "sequential" mode, number otherwise - self._max_batches: Union[int, float, List[Union[int, float]]] = [] + self._max_batches: int | float | list[int | float] = [] self._results = _ResultCollection(training=False) - self._logged_outputs: List[_OUT_DICT] = [] + self._logged_outputs: list[_OUT_DICT] = [] self._has_run: bool = False self._trainer_fn = trainer_fn self._stage = stage self._data_source = _DataLoaderSource(None, f"{stage.dataloader_prefix}_dataloader") - self._combined_loader: Optional[CombinedLoader] = None - self._data_fetcher: Optional[_DataFetcher] = None + self._combined_loader: CombinedLoader | None = None + self._data_fetcher: _DataFetcher | None = None self._seen_batches_per_dataloader: DefaultDict[int, int] = defaultdict(int) self._last_val_dl_reload_epoch = float("-inf") @@ -85,7 +87,7 @@ def num_dataloaders(self) -> int: return len(combined_loader.flattened) @property - def max_batches(self) -> Union[int, float, List[Union[int, float]]]: + def max_batches(self) -> int | float | list[int | float]: """In "sequential" mode, the max number of batches to run per dataloader. Otherwise, the max batches to run. 
@@ -115,7 +117,7 @@ def _is_sequential(self) -> bool: return self._combined_loader._mode == "sequential" @_no_grad_context - def run(self) -> List[_OUT_DICT]: + def run(self) -> list[_OUT_DICT]: self.setup_data() if self.skip: return [] @@ -268,7 +270,7 @@ def on_run_start(self) -> None: self._on_evaluation_start() self._on_evaluation_epoch_start() - def on_run_end(self) -> List[_OUT_DICT]: + def on_run_end(self) -> list[_OUT_DICT]: """Runs the ``_on_evaluation_epoch_end`` hook.""" # if `done` returned True before any iterations were done, this won't have been called in `on_advance_end` self.trainer._logger_connector.epoch_end_reached() @@ -421,7 +423,7 @@ def _evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> N if not self.batch_progress.is_last_batch and trainer.received_sigterm: raise SIGTERMException - def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]) -> OrderedDict: + def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int | None) -> OrderedDict: """Helper method to build the arguments for the current step. Args: @@ -451,7 +453,7 @@ def _verify_dataloader_idx_requirement(self) -> None: ) @staticmethod - def _get_keys(data: dict) -> Iterable[Tuple[str, ...]]: + def _get_keys(data: dict) -> Iterable[tuple[str, ...]]: for k, v in data.items(): if isinstance(v, dict): for new_key in apply_to_collection(v, dict, _EvaluationLoop._get_keys): @@ -460,7 +462,7 @@ def _get_keys(data: dict) -> Iterable[Tuple[str, ...]]: yield k, @staticmethod - def _find_value(data: dict, target: Iterable[str]) -> Optional[Any]: + def _find_value(data: dict, target: Iterable[str]) -> Any | None: target_start, *rest = target if target_start not in data: return None @@ -470,7 +472,7 @@ def _find_value(data: dict, target: Iterable[str]) -> Optional[Any]: return _EvaluationLoop._find_value(result, rest) @staticmethod - def _print_results(results: List[_OUT_DICT], stage: str) -> None: + def _print_results(results: list[_OUT_DICT], stage: str) -> None: # remove the dl idx suffix results = [{k.split("/dataloader_idx_")[0]: v for k, v in result.items()} for result in results] metrics_paths = {k for keys in apply_to_collection(results, dict, _EvaluationLoop._get_keys) for k in keys} @@ -487,7 +489,7 @@ def _print_results(results: List[_OUT_DICT], stage: str) -> None: term_size = shutil.get_terminal_size(fallback=(120, 30)).columns or 120 max_length = int(min(max(len(max(metrics_strs, key=len)), len(max(headers, key=len)), 25), term_size / 2)) - rows: List[List[Any]] = [[] for _ in metrics_paths] + rows: list[list[Any]] = [[] for _ in metrics_paths] for result in results: for metric, row in zip(metrics_paths, rows): diff --git a/src/lightning/pytorch/loops/fetchers.py b/src/lightning/pytorch/loops/fetchers.py index 8df73c891c920..0079a37bdb13d 100644 --- a/src/lightning/pytorch/loops/fetchers.py +++ b/src/lightning/pytorch/loops/fetchers.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Iterator, List, Optional, Tuple, Union +from __future__ import annotations + +from typing import Any, Iterator from lightning.fabric.utilities.data import sized_len from lightning.pytorch.utilities.combined_loader import _Sequential, CombinedLoader @@ -25,8 +27,8 @@ def _profile_nothing() -> None: class _DataFetcher(Iterator): def __init__(self) -> None: - self._combined_loader: Optional[CombinedLoader] = None - self.iterator: Optional[Iterator] = None + self._combined_loader: CombinedLoader | None = None + self.iterator: Iterator | None = None self.fetched: int = 0 self.done: bool = False self._start_profiler = _profile_nothing @@ -43,7 +45,7 @@ def combined_loader(self) -> CombinedLoader: def setup(self, combined_loader: CombinedLoader) -> None: self._combined_loader = combined_loader - def __iter__(self) -> "_DataFetcher": + def __iter__(self) -> _DataFetcher: self.reset() self.iterator = iter(self.combined_loader) return self @@ -85,14 +87,14 @@ def __init__(self, prefetch_batches: int = 1) -> None: if prefetch_batches < 0: raise ValueError("`prefetch_batches` should at least be 0.") self.prefetch_batches = prefetch_batches - self.batches: List[Any] = [] - self._len: Optional[int] = None + self.batches: list[Any] = [] + self._len: int | None = None def setup(self, combined_loader: CombinedLoader) -> None: super().setup(combined_loader) self._len = sized_len(combined_loader) - def __iter__(self) -> "_PrefetchDataFetcher": + def __iter__(self) -> _PrefetchDataFetcher: super().__iter__() if self._len is not None: # ignore pre-fetching, it's not necessary @@ -166,12 +168,12 @@ def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None: ... """ - def __iter__(self) -> "_DataLoaderIterDataFetcher": + def __iter__(self) -> _DataLoaderIterDataFetcher: super().__iter__() self.iterator_wrapper = iter(_DataFetcherWrapper(self)) return self - def __next__(self) -> Union["_DataFetcherWrapper", Tuple["_DataFetcherWrapper", int, int]]: + def __next__(self) -> _DataFetcherWrapper | tuple[_DataFetcherWrapper, int, int]: if self.done: raise StopIteration assert isinstance(self.iterator_wrapper, _DataFetcherWrapper) diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py index 16c11ba45c677..0525f68679549 100644 --- a/src/lightning/pytorch/loops/fit_loop.py +++ b/src/lightning/pytorch/loops/fit_loop.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import logging -from typing import Optional, Union import torch @@ -73,9 +74,9 @@ class _FitLoop(_Loop): def __init__( self, - trainer: "pl.Trainer", - min_epochs: Optional[int] = 0, - max_epochs: Optional[int] = None, + trainer: pl.Trainer, + min_epochs: int | None = 0, + max_epochs: int | None = None, ) -> None: super().__init__(trainer) if isinstance(max_epochs, int) and max_epochs < -1: @@ -88,11 +89,11 @@ def __init__( self.min_epochs = min_epochs self.epoch_loop = _TrainingEpochLoop(trainer) self.epoch_progress = _Progress() - self.max_batches: Union[int, float] = float("inf") + self.max_batches: int | float = float("inf") self._data_source = _DataLoaderSource(None, "train_dataloader") - self._combined_loader: Optional[CombinedLoader] = None - self._data_fetcher: Optional[_DataFetcher] = None + self._combined_loader: CombinedLoader | None = None + self._data_fetcher: _DataFetcher | None = None self._last_train_dl_reload_epoch = float("-inf") @property @@ -106,7 +107,7 @@ def batch_idx(self) -> int: return self.epoch_loop.batch_idx @property - def min_steps(self) -> Optional[int]: + def min_steps(self) -> int | None: """Returns the minimum number of steps to run.""" return self.epoch_loop.min_steps diff --git a/src/lightning/pytorch/loops/loop.py b/src/lightning/pytorch/loops/loop.py index 2a3bf1dfc4a9b..2a8857b92efbc 100644 --- a/src/lightning/pytorch/loops/loop.py +++ b/src/lightning/pytorch/loops/loop.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional +from __future__ import annotations import lightning.pytorch as pl from lightning.pytorch.loops.progress import _BaseProgress @@ -20,7 +20,7 @@ class _Loop: """Basic Loops interface.""" - def __init__(self, trainer: "pl.Trainer") -> None: + def __init__(self, trainer: pl.Trainer) -> None: self._restarting = False self.trainer = trainer @@ -37,7 +37,7 @@ def restarting(self, restarting: bool) -> None: if isinstance(loop, _Loop): loop.restarting = restarting - def on_save_checkpoint(self) -> Dict: + def on_save_checkpoint(self) -> dict: """Called when saving a model checkpoint, use to persist loop state. Returns: @@ -45,10 +45,10 @@ def on_save_checkpoint(self) -> Dict: """ return {} - def on_load_checkpoint(self, state_dict: Dict) -> None: + def on_load_checkpoint(self, state_dict: dict) -> None: """Called when loading a model checkpoint, use to reload loop state.""" - def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Dict: + def state_dict(self, destination: dict | None = None, prefix: str = "") -> dict: """The state dict is determined by the state and progress of this loop and all its children. 
Args: @@ -71,7 +71,7 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Di def load_state_dict( self, - state_dict: Dict, + state_dict: dict, prefix: str = "", ) -> None: """Loads the state of this loop and all its children.""" @@ -81,7 +81,7 @@ def load_state_dict( v.load_state_dict(state_dict.copy(), prefix + k + ".") self.restarting = True - def _load_from_state_dict(self, state_dict: Dict, prefix: str) -> None: + def _load_from_state_dict(self, state_dict: dict, prefix: str) -> None: for k, v in self.__dict__.items(): key = prefix + k if key not in state_dict: diff --git a/src/lightning/pytorch/loops/optimization/automatic.py b/src/lightning/pytorch/loops/optimization/automatic.py index ea350464d9b77..de4a1ac0395c6 100644 --- a/src/lightning/pytorch/loops/optimization/automatic.py +++ b/src/lightning/pytorch/loops/optimization/automatic.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from dataclasses import dataclass, field from functools import partial -from typing import Any, Callable, Dict, Optional, OrderedDict +from typing import Any, Callable, Dict, OrderedDict import torch from torch import Tensor @@ -42,9 +44,9 @@ class ClosureResult(OutputResult): extra: Any keys other than the loss returned. """ - closure_loss: Optional[Tensor] - loss: Optional[Tensor] = field(init=False, default=None) - extra: Dict[str, Any] = field(default_factory=dict) + closure_loss: Tensor | None + loss: Tensor | None = field(init=False, default=None) + extra: dict[str, Any] = field(default_factory=dict) def __post_init__(self) -> None: self._clone_loss() @@ -55,9 +57,7 @@ def _clone_loss(self) -> None: self.loss = self.closure_loss.detach().clone() @classmethod - def from_training_step_output( - cls, training_step_output: Optional[STEP_OUTPUT], normalize: int = 1 - ) -> "ClosureResult": + def from_training_step_output(cls, training_step_output: STEP_OUTPUT | None, normalize: int = 1) -> ClosureResult: closure_loss, extra = None, {} if isinstance(training_step_output, dict): @@ -82,7 +82,7 @@ def from_training_step_output( return cls(closure_loss, extra=extra) - def asdict(self) -> Dict[str, Any]: + def asdict(self) -> dict[str, Any]: return {"loss": self.loss, **self.extra} @@ -114,8 +114,8 @@ class Closure(AbstractClosure[ClosureResult]): def __init__( self, step_fn: Callable[[], ClosureResult], - backward_fn: Optional[Callable[[Tensor], None]] = None, - zero_grad_fn: Optional[Callable[[], None]] = None, + backward_fn: Callable[[Tensor], None] | None = None, + zero_grad_fn: Callable[[], None] | None = None, ): super().__init__() self._step_fn = step_fn @@ -136,7 +136,7 @@ def closure(self, *args: Any, **kwargs: Any) -> ClosureResult: return step_output - def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]: + def __call__(self, *args: Any, **kwargs: Any) -> Tensor | None: self._result = self.closure(*args, **kwargs) return self._result.loss @@ -149,7 +149,7 @@ class _AutomaticOptimization(_Loop): output_result_cls = ClosureResult - def __init__(self, trainer: "pl.Trainer") -> None: + def __init__(self, trainer: pl.Trainer) -> None: super().__init__(trainer) self.optim_progress: _OptimizationProgress = _OptimizationProgress() self._skip_backward: bool = False @@ -201,7 +201,7 @@ def _make_step_fn(self, kwargs: OrderedDict) -> Callable[[], ClosureResult]: """Build the step 
function that runs the `training_step` and processes its output.""" return partial(self._training_step, kwargs) - def _make_zero_grad_fn(self, batch_idx: int, optimizer: Optimizer) -> Optional[Callable[[], None]]: + def _make_zero_grad_fn(self, batch_idx: int, optimizer: Optimizer) -> Callable[[], None] | None: """Build a `zero_grad` function that zeroes the gradients before back-propagation. Returns ``None`` in the case backward needs to be skipped. @@ -219,7 +219,7 @@ def zero_grad_fn() -> None: return zero_grad_fn - def _make_backward_fn(self, optimizer: Optimizer) -> Optional[Callable[[Tensor], None]]: + def _make_backward_fn(self, optimizer: Optimizer) -> Callable[[Tensor], None] | None: """Build a `backward` function that handles back-propagation through the output produced by the `training_step` function. @@ -236,7 +236,7 @@ def backward_fn(loss: Tensor) -> None: def _optimizer_step( self, batch_idx: int, - train_step_and_backward_closure: Callable[[], Optional[Tensor]], + train_step_and_backward_closure: Callable[[], Tensor | None], ) -> None: """Performs the optimizer step and some sanity checking. diff --git a/src/lightning/pytorch/loops/optimization/closure.py b/src/lightning/pytorch/loops/optimization/closure.py index ec85a96e54042..ec96a083a7039 100644 --- a/src/lightning/pytorch/loops/optimization/closure.py +++ b/src/lightning/pytorch/loops/optimization/closure.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Dict, Generic, Optional, TypeVar +from typing import Any, Generic, TypeVar from lightning.pytorch.utilities.exceptions import MisconfigurationException @@ -22,7 +24,7 @@ @dataclass class OutputResult: - def asdict(self) -> Dict[str, Any]: + def asdict(self) -> dict[str, Any]: raise NotImplementedError @@ -39,7 +41,7 @@ class AbstractClosure(ABC, Generic[T]): def __init__(self) -> None: super().__init__() - self._result: Optional[T] = None + self._result: T | None = None def consume_result(self) -> T: """The cached result from the last time the closure was called. diff --git a/src/lightning/pytorch/loops/optimization/manual.py b/src/lightning/pytorch/loops/optimization/manual.py index 9cc77793591b0..8085e24b04085 100644 --- a/src/lightning/pytorch/loops/optimization/manual.py +++ b/src/lightning/pytorch/loops/optimization/manual.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from collections import OrderedDict from contextlib import suppress from dataclasses import dataclass, field -from typing import Any, Dict, Optional +from typing import Any, Dict from torch import Tensor @@ -38,10 +40,10 @@ class ManualResult(OutputResult): extra: Anything returned by the ``training_step``. 
""" - extra: Dict[str, Any] = field(default_factory=dict) + extra: dict[str, Any] = field(default_factory=dict) @classmethod - def from_training_step_output(cls, training_step_output: Optional[STEP_OUTPUT]) -> "ManualResult": + def from_training_step_output(cls, training_step_output: STEP_OUTPUT | None) -> ManualResult: extra = {} if isinstance(training_step_output, dict): extra = training_step_output.copy() @@ -58,7 +60,7 @@ def from_training_step_output(cls, training_step_output: Optional[STEP_OUTPUT]) return cls(extra=extra) - def asdict(self) -> Dict[str, Any]: + def asdict(self) -> dict[str, Any]: return self.extra @@ -76,7 +78,7 @@ class _ManualOptimization(_Loop): output_result_cls = ManualResult - def __init__(self, trainer: "pl.Trainer") -> None: + def __init__(self, trainer: pl.Trainer) -> None: super().__init__(trainer) # since manual optimization does not track lr scheduler or optimizer frequencies, we use a simpler progress than # `_OptimizationProgress` diff --git a/src/lightning/pytorch/loops/prediction_loop.py b/src/lightning/pytorch/loops/prediction_loop.py index 227d3246bba3c..b6c389851d0d4 100644 --- a/src/lightning/pytorch/loops/prediction_loop.py +++ b/src/lightning/pytorch/loops/prediction_loop.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from collections import OrderedDict -from typing import Any, Dict, List, Optional, Union +from typing import Any import torch from lightning_utilities import WarningCache @@ -44,21 +46,21 @@ class _PredictionLoop(_Loop): """Top-level loop where prediction starts.""" - def __init__(self, trainer: "pl.Trainer", inference_mode: bool = True) -> None: + def __init__(self, trainer: pl.Trainer, inference_mode: bool = True) -> None: super().__init__(trainer) self.inference_mode = inference_mode # dataloaders x batches x samples. 
used by PredictionWriter - self.epoch_batch_indices: List[List[List[int]]] = [] - self.current_batch_indices: List[int] = [] # used by PredictionWriter + self.epoch_batch_indices: list[list[list[int]]] = [] + self.current_batch_indices: list[int] = [] # used by PredictionWriter self.batch_progress = _Progress() # across dataloaders - self.max_batches: List[Union[int, float]] = [] + self.max_batches: list[int | float] = [] self._warning_cache = WarningCache() self._data_source = _DataLoaderSource(None, "predict_dataloader") - self._combined_loader: Optional[CombinedLoader] = None - self._data_fetcher: Optional[_DataFetcher] = None + self._combined_loader: CombinedLoader | None = None + self._data_fetcher: _DataFetcher | None = None self._results = None # for `trainer._results` access - self._predictions: List[List[Any]] = [] # dataloaders x batches + self._predictions: list[list[Any]] = [] # dataloaders x batches self._return_predictions = False @property @@ -67,7 +69,7 @@ def return_predictions(self) -> bool: return self._return_predictions @return_predictions.setter - def return_predictions(self, return_predictions: Optional[bool] = None) -> None: + def return_predictions(self, return_predictions: bool | None = None) -> None: # Strategies that spawn or fork don't support returning predictions return_supported = not isinstance(self.trainer.strategy.launcher, _MultiProcessingLauncher) if return_predictions and not return_supported: @@ -79,7 +81,7 @@ def return_predictions(self, return_predictions: Optional[bool] = None) -> None: self._return_predictions = return_supported if return_predictions is None else return_predictions @property - def predictions(self) -> List[Any]: + def predictions(self) -> list[Any]: """The cached predictions.""" if self._predictions == []: return self._predictions @@ -97,7 +99,7 @@ def skip(self) -> bool: return sum(self.max_batches) == 0 @_no_grad_context - def run(self) -> Optional[_PREDICT_OUTPUT]: + def run(self) -> _PREDICT_OUTPUT | None: self.setup_data() if self.skip: return None @@ -190,7 +192,7 @@ def on_run_start(self) -> None: self._on_predict_start() self._on_predict_epoch_start() - def on_run_end(self) -> Optional[_PREDICT_OUTPUT]: + def on_run_end(self) -> _PREDICT_OUTPUT | None: """Calls ``on_predict_epoch_end`` and ``on_predict_end`` hooks and returns results from all dataloaders.""" results = self._on_predict_epoch_end() self._on_predict_end() @@ -240,7 +242,7 @@ def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None if self._return_predictions or any_on_epoch: self._predictions[dataloader_idx].append(move_data_to_device(predictions, torch.device("cpu"))) - def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]) -> Dict[str, Any]: + def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int | None) -> dict[str, Any]: """Assembles the keyword arguments for the ``predict_step`` Args: @@ -256,7 +258,7 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int step_kwargs["dataloader_idx"] = dataloader_idx return step_kwargs - def _get_batch_indices(self, dataloader: object) -> List[List[int]]: # batches x samples + def _get_batch_indices(self, dataloader: object) -> list[list[int]]: # batches x samples """Returns a reference to the seen batch indices if the dataloader has a batch sampler wrapped by our :class:`~lightning.pytorch.overrides.distributed._IndexBatchSamplerWrapper`.""" batch_sampler = getattr(dataloader, "batch_sampler", None) @@ -308,7 +310,7 @@ 
def _on_predict_epoch_start(self) -> None: call._call_callback_hooks(trainer, "on_predict_epoch_start") call._call_lightning_module_hook(trainer, "on_predict_epoch_start") - def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]: + def _on_predict_epoch_end(self) -> _PREDICT_OUTPUT | None: """Calls ``on_predict_epoch_end`` hook. Returns: diff --git a/src/lightning/pytorch/loops/progress.py b/src/lightning/pytorch/loops/progress.py index d2e52f44d7ba7..607ee6f4a7480 100644 --- a/src/lightning/pytorch/loops/progress.py +++ b/src/lightning/pytorch/loops/progress.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from dataclasses import asdict, dataclass, field -from typing import Type @dataclass @@ -26,7 +27,7 @@ def load_state_dict(self, state_dict: dict) -> None: self.__dict__.update(state_dict) @classmethod - def from_state_dict(cls, state_dict: dict) -> "_BaseProgress": + def from_state_dict(cls, state_dict: dict) -> _BaseProgress: obj = cls() obj.load_state_dict(state_dict) return obj @@ -148,7 +149,7 @@ def increment_completed(self) -> None: self.current.completed += 1 @classmethod - def from_defaults(cls, tracker_cls: Type[_ReadyCompletedTracker], **kwargs: int) -> "_Progress": + def from_defaults(cls, tracker_cls: type[_ReadyCompletedTracker], **kwargs: int) -> _Progress: """Utility function to easily create an instance from keyword arguments to both ``Tracker``s.""" return cls(total=tracker_cls(**kwargs), current=tracker_cls(**kwargs)) diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py index 208faf8f16d05..abc6909e14836 100644 --- a/src/lightning/pytorch/loops/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/training_epoch_loop.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import math from collections import OrderedDict -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union import lightning.pytorch as pl from lightning.pytorch import loops # import as loops to avoid circular imports @@ -55,7 +57,7 @@ class _TrainingEpochLoop(loops._Loop): max_steps: The maximum number of steps (batches) to process """ - def __init__(self, trainer: "pl.Trainer", min_steps: Optional[int] = None, max_steps: int = -1) -> None: + def __init__(self, trainer: pl.Trainer, min_steps: int | None = None, max_steps: int = -1) -> None: super().__init__(trainer) if max_steps < -1: raise MisconfigurationException( @@ -270,12 +272,12 @@ def teardown(self) -> None: self._results.cpu() self.val_loop.teardown() - def on_save_checkpoint(self) -> Dict: + def on_save_checkpoint(self) -> dict: state_dict = super().on_save_checkpoint() state_dict["_batches_that_stepped"] = self._batches_that_stepped return state_dict - def on_load_checkpoint(self, state_dict: Dict) -> None: + def on_load_checkpoint(self, state_dict: dict) -> None: self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0) def _accumulated_batches_reached(self) -> bool: @@ -359,7 +361,7 @@ def _update_learning_rates(self, interval: str, update_plateau_schedulers: bool) ) self.scheduler_progress.increment_completed() - def _get_monitor_value(self, key: str) -> Optional[Any]: + def _get_monitor_value(self, key: str) -> Any | None: # this is a separate method to aid in testing return self.trainer.callback_metrics.get(key) diff --git a/src/lightning/pytorch/loops/utilities.py b/src/lightning/pytorch/loops/utilities.py index b449355582c46..548e5c9000a73 100644 --- a/src/lightning/pytorch/loops/utilities.py +++ b/src/lightning/pytorch/loops/utilities.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import inspect from contextlib import contextmanager -from typing import Any, Callable, ContextManager, Generator, Optional, Tuple, Type +from typing import Any, Callable, ContextManager, Generator import torch import torch.distributed as dist @@ -35,7 +37,7 @@ from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature -def check_finite_loss(loss: Optional[Tensor]) -> None: +def check_finite_loss(loss: Tensor | None) -> None: """Checks for finite loss value. Args: @@ -46,12 +48,12 @@ def check_finite_loss(loss: Optional[Tensor]) -> None: def _parse_loop_limits( - min_steps: Optional[int], + min_steps: int | None, max_steps: int, - min_epochs: Optional[int], - max_epochs: Optional[int], - trainer: "pl.Trainer", -) -> Tuple[int, int]: + min_epochs: int | None, + max_epochs: int | None, + trainer: pl.Trainer, +) -> tuple[int, int]: """This utility computes the default values for the minimum and maximum number of steps and epochs given the values the user has selected. 
@@ -127,7 +129,7 @@ def _reset_progress(loop: _Loop) -> None: _reset_progress(v) -def _select_data_fetcher(trainer: "pl.Trainer") -> _DataFetcher: +def _select_data_fetcher(trainer: pl.Trainer) -> _DataFetcher: lightning_module = trainer.lightning_module if trainer.testing: step_fx_name = "test_step" @@ -155,7 +157,7 @@ def _decorator(self: _Loop, *args: Any, **kwargs: Any) -> Any: raise TypeError(f"`{type(self).__name__}` needs to be a Loop.") if not hasattr(self, "inference_mode"): raise TypeError(f"`{type(self).__name__}.inference_mode` needs to be defined") - context_manager: Type[ContextManager] + context_manager: type[ContextManager] if dist.is_available() and dist.is_initialized() and dist.get_backend() == "gloo": # noqa: SIM114 # gloo backend does not work properly. # https://github.com/Lightning-AI/lightning/pull/12715/files#r854569110 @@ -180,7 +182,7 @@ def _decorator(self: _Loop, *args: Any, **kwargs: Any) -> Any: def _verify_dataloader_idx_requirement( - hooks: Tuple[str, ...], is_expected: bool, stage: RunningStage, pl_module: "pl.LightningModule" + hooks: tuple[str, ...], is_expected: bool, stage: RunningStage, pl_module: pl.LightningModule ) -> None: for hook in hooks: fx = getattr(pl_module, hook) diff --git a/src/lightning/pytorch/overrides/base.py b/src/lightning/pytorch/overrides/base.py index 8a5282cb0334b..f2e8e2b77b1d7 100644 --- a/src/lightning/pytorch/overrides/base.py +++ b/src/lightning/pytorch/overrides/base.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from typing import Any import torch @@ -20,7 +22,7 @@ class _LightningPrecisionModuleWrapperBase(_DeviceDtypeModuleMixin, torch.nn.Module): - def __init__(self, pl_module: "pl.LightningModule") -> None: + def __init__(self, pl_module: pl.LightningModule) -> None: """Wraps the user's LightningModule. Requires overriding all ``*_step`` methods and ``forward`` so that it can safely be wrapped by ``*DataParallel``. diff --git a/src/lightning/pytorch/overrides/distributed.py b/src/lightning/pytorch/overrides/distributed.py index 1480163dc57c0..1c5372873ff27 100644 --- a/src/lightning/pytorch/overrides/distributed.py +++ b/src/lightning/pytorch/overrides/distributed.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
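# --- Illustrative sketch, not part of the diff: both `torch.no_grad` and
# `torch.inference_mode` are context-manager classes, which is why the loop
# utility above can hold either one in a `type[ContextManager]` variable and
# instantiate it later. The flag and tensors below are made up for demonstration.
from __future__ import annotations

from typing import ContextManager

import torch

use_inference_mode = True  # e.g. False when the gloo workaround mentioned above applies
context_manager: type[ContextManager] = torch.inference_mode if use_inference_mode else torch.no_grad

x = torch.ones(2, requires_grad=True)
with context_manager():
    y = x * 3
assert not y.requires_grad  # gradients are disabled inside either context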
+from __future__ import annotations + import itertools -from typing import Any, Callable, cast, Dict, Iterable, Iterator, List, Optional, Sized, Union +from typing import Any, Callable, cast, Iterable, Iterator, Sized import torch from torch import Tensor @@ -24,9 +26,7 @@ from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info -def _find_tensors( - obj: Union[Tensor, list, tuple, dict, Any] -) -> Union[List[Tensor], itertools.chain]: # pragma: no-cover +def _find_tensors(obj: Tensor | list | tuple | dict | Any) -> list[Tensor] | itertools.chain: # pragma: no-cover """Recursively find all tensors contained in the specified object.""" if isinstance(obj, Tensor): return [obj] @@ -59,9 +59,9 @@ def prepare_for_backward(model: DistributedDataParallel, output: Any) -> None: def _register_ddp_comm_hook( model: DistributedDataParallel, - ddp_comm_state: Optional[object] = None, - ddp_comm_hook: Optional[Callable] = None, - ddp_comm_wrapper: Optional[Callable] = None, + ddp_comm_state: object | None = None, + ddp_comm_hook: Callable | None = None, + ddp_comm_wrapper: Callable | None = None, ) -> None: """Function to register communication hook for DDP model https://pytorch.org/docs/master/ddp_comm_hooks.html. @@ -212,7 +212,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # have at least one batch, or the DistributedDataParallel could lock up. assert self.num_samples >= 1 or self.total_size == 0 - def __iter__(self) -> Iterator[List[int]]: + def __iter__(self) -> Iterator[list[int]]: if not isinstance(self.dataset, Sized): raise TypeError("The given dataset must implement the `__len__` method.") if self.shuffle: @@ -235,7 +235,7 @@ def __iter__(self) -> Iterator[List[int]]: class UnrepeatedDistributedSamplerWrapper(UnrepeatedDistributedSampler): """Equivalent class to ``DistributedSamplerWrapper`` but for the ``UnrepeatedDistributedSampler``.""" - def __init__(self, sampler: Union[Sampler, Iterable], *args: Any, **kwargs: Any) -> None: + def __init__(self, sampler: Sampler | Iterable, *args: Any, **kwargs: Any) -> None: super().__init__(_DatasetSamplerWrapper(sampler), *args, **kwargs) def __iter__(self) -> Iterator: @@ -248,7 +248,7 @@ class _IndexBatchSamplerWrapper(BatchSampler): def __init__(self, batch_sampler: BatchSampler) -> None: # do not call super().__init__() on purpose - self.seen_batch_indices: List[List[int]] = [] + self.seen_batch_indices: list[list[int]] = [] self.__dict__ = { k: v @@ -256,15 +256,15 @@ def __init__(self, batch_sampler: BatchSampler) -> None: if k not in ("__next__", "__iter__", "__len__", "__getstate__") } self._batch_sampler = batch_sampler - self._iterator: Optional[Iterator[List[int]]] = None + self._iterator: Iterator[list[int]] | None = None - def __next__(self) -> List[int]: + def __next__(self) -> list[int]: assert self._iterator is not None batch = next(self._iterator) self.seen_batch_indices.append(batch) return batch - def __iter__(self) -> Iterator[List[int]]: + def __iter__(self) -> Iterator[list[int]]: self.seen_batch_indices = [] self._iterator = iter(self._batch_sampler) return self @@ -272,7 +272,7 @@ def __iter__(self) -> Iterator[List[int]]: def __len__(self) -> int: return len(self._batch_sampler) - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: state = self.__dict__.copy() state["_iterator"] = None # cannot pickle 'generator' object return state diff --git a/src/lightning/pytorch/plugins/io/async_plugin.py b/src/lightning/pytorch/plugins/io/async_plugin.py index 
509f40de0f08d..fea76c1c86545 100644 --- a/src/lightning/pytorch/plugins/io/async_plugin.py +++ b/src/lightning/pytorch/plugins/io/async_plugin.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from concurrent.futures import ThreadPoolExecutor -from typing import Any, Optional +from typing import Any from lightning.fabric.plugins import CheckpointIO from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO @@ -28,11 +30,11 @@ class AsyncCheckpointIO(_WrappingCheckpointIO): checkpoint_io: A checkpoint IO plugin that is used as the basis for async checkpointing. """ - def __init__(self, checkpoint_io: Optional["CheckpointIO"] = None) -> None: + def __init__(self, checkpoint_io: CheckpointIO | None = None) -> None: super().__init__(checkpoint_io) self._executor = ThreadPoolExecutor(max_workers=1) - self._error: Optional[BaseException] = None + self._error: BaseException | None = None def save_checkpoint(self, *args: Any, **kwargs: Any) -> None: """Uses the ``ThreadPoolExecutor`` to save the checkpoints using the base ``checkpoint_io``.""" diff --git a/src/lightning/pytorch/plugins/io/wrapper.py b/src/lightning/pytorch/plugins/io/wrapper.py index 78d2ce83e9469..104a351e90814 100644 --- a/src/lightning/pytorch/plugins/io/wrapper.py +++ b/src/lightning/pytorch/plugins/io/wrapper.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from __future__ import annotations + +from typing import Any from lightning.fabric.plugins import CheckpointIO @@ -23,7 +25,7 @@ class _WrappingCheckpointIO(CheckpointIO): checkpoint_io: A checkpoint IO plugin that is used as the basis. """ - def __init__(self, checkpoint_io: Optional["CheckpointIO"] = None) -> None: + def __init__(self, checkpoint_io: CheckpointIO | None = None) -> None: super().__init__() self._checkpoint_io = checkpoint_io @@ -36,11 +38,11 @@ def __init__(self, checkpoint_io: Optional["CheckpointIO"] = None) -> None: self._base_checkpoint_io_configured = True @property - def checkpoint_io(self) -> Optional["CheckpointIO"]: + def checkpoint_io(self) -> CheckpointIO | None: return self._checkpoint_io @checkpoint_io.setter - def checkpoint_io(self, checkpoint_io: "CheckpointIO") -> None: + def checkpoint_io(self, checkpoint_io: CheckpointIO) -> None: assert not isinstance(checkpoint_io, _WrappingCheckpointIO) if self._checkpoint_io is None: @@ -60,7 +62,7 @@ def remove_checkpoint(self, *args: Any, **kwargs: Any) -> None: assert self.checkpoint_io is not None self.checkpoint_io.remove_checkpoint(*args, **kwargs) - def load_checkpoint(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: + def load_checkpoint(self, *args: Any, **kwargs: Any) -> dict[str, Any]: """Uses the base ``checkpoint_io`` to load the checkpoint.""" assert self.checkpoint_io is not None return self.checkpoint_io.load_checkpoint(*args, **kwargs) diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py index 4ffde4e655b30..3062272e981a3 100644 --- a/src/lightning/pytorch/plugins/precision/amp.py +++ b/src/lightning/pytorch/plugins/precision/amp.py @@ -9,8 +9,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from contextlib import contextmanager -from typing import Any, Callable, cast, Dict, Generator, Literal, Optional, Union +from typing import Any, Callable, cast, Generator, Literal import torch from torch import Tensor @@ -38,7 +40,7 @@ def __init__( self, precision: Literal["16-mixed", "bf16-mixed"], device: str, - scaler: Optional[torch.cuda.amp.GradScaler] = None, + scaler: torch.cuda.amp.GradScaler | None = None, ) -> None: if precision not in ("16-mixed", "bf16-mixed"): raise ValueError( @@ -56,7 +58,7 @@ def __init__( self.device = device self.scaler = scaler - def pre_backward(self, tensor: Tensor, module: "pl.LightningModule") -> Tensor: # type: ignore[override] + def pre_backward(self, tensor: Tensor, module: pl.LightningModule) -> Tensor: # type: ignore[override] if self.scaler is not None: tensor = self.scaler.scale(tensor) return super().pre_backward(tensor, module) @@ -64,7 +66,7 @@ def pre_backward(self, tensor: Tensor, module: "pl.LightningModule") -> Tensor: def optimizer_step( # type: ignore[override] self, optimizer: Optimizable, - model: "pl.LightningModule", + model: pl.LightningModule, closure: Callable[[], Any], **kwargs: Any, ) -> Any: @@ -94,7 +96,7 @@ def optimizer_step( # type: ignore[override] def clip_gradients( self, optimizer: Optimizer, - clip_val: Union[int, float] = 0.0, + clip_val: int | float = 0.0, gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, ) -> None: if clip_val > 0 and _optimizer_handles_unscaling(optimizer): @@ -115,11 +117,11 @@ def forward_context(self) -> Generator[None, None, None]: with self.autocast_context_manager(): yield - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: if self.scaler is not None: return self.scaler.state_dict() return {} - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: if self.scaler is not None: self.scaler.load_state_dict(state_dict) diff --git a/src/lightning/pytorch/plugins/precision/deepspeed.py b/src/lightning/pytorch/plugins/precision/deepspeed.py index ec030e0d8e6be..7213b76a69c9a 100644 --- a/src/lightning/pytorch/plugins/precision/deepspeed.py +++ b/src/lightning/pytorch/plugins/precision/deepspeed.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, cast, Literal, Optional, TYPE_CHECKING, Union +from __future__ import annotations + +from typing import Any, Callable, cast, Literal, TYPE_CHECKING from torch import Tensor from torch.optim import LBFGS, Optimizer @@ -58,8 +60,8 @@ def __init__(self, precision: Literal["32-true", "16-mixed", "bf16-mixed"]) -> N def backward( # type: ignore[override] self, tensor: Tensor, - model: "pl.LightningModule", - optimizer: Optional[Steppable], + model: pl.LightningModule, + optimizer: Steppable | None, *args: Any, **kwargs: Any, ) -> None: @@ -77,13 +79,13 @@ def backward( # type: ignore[override] "You have overridden the `LightningModule.backward` hook but it will be ignored since DeepSpeed handles" " the backward logic internally." 
) - deepspeed_engine: "deepspeed.DeepSpeedEngine" = model.trainer.model + deepspeed_engine: deepspeed.DeepSpeedEngine = model.trainer.model deepspeed_engine.backward(tensor, *args, **kwargs) def optimizer_step( # type: ignore[override] self, optimizer: Steppable, - model: "pl.LightningModule", + model: pl.LightningModule, closure: Callable[[], Any], **kwargs: Any, ) -> Any: @@ -98,13 +100,13 @@ def optimizer_step( # type: ignore[override] "Skipping backward by returning `None` from your `training_step` is not supported by `DeepSpeed`" ) # DeepSpeed handles the optimizer step internally - deepspeed_engine: "deepspeed.DeepSpeedEngine" = model.trainer.model + deepspeed_engine: deepspeed.DeepSpeedEngine = model.trainer.model return deepspeed_engine.step(**kwargs) def clip_gradients( self, optimizer: Optimizer, - clip_val: Union[int, float] = 0.0, + clip_val: int | float = 0.0, gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, ) -> None: """DeepSpeed handles gradient clipping internally.""" diff --git a/src/lightning/pytorch/plugins/precision/double.py b/src/lightning/pytorch/plugins/precision/double.py index 77fa9c4171a2b..59eb77d46dbdd 100644 --- a/src/lightning/pytorch/plugins/precision/double.py +++ b/src/lightning/pytorch/plugins/precision/double.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from contextlib import contextmanager -from typing import Any, cast, Generator, List, Literal, Tuple +from typing import Any, cast, Generator, Literal import torch import torch.nn as nn @@ -75,8 +77,8 @@ class DoublePrecisionPlugin(PrecisionPlugin): precision: Literal["64-true"] = "64-true" def connect( - self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any] - ) -> Tuple[nn.Module, List["Optimizer"], List[Any]]: + self, model: nn.Module, optimizers: list[Optimizer], lr_schedulers: list[Any] + ) -> tuple[nn.Module, list[Optimizer], list[Any]]: """Converts the model to double precision and wraps it in a ``LightningDoublePrecisionModule`` to convert incoming floating point data to double (``torch.float64``) precision. diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py index 5828c35d521d3..800e683188221 100644 --- a/src/lightning/pytorch/plugins/precision/fsdp.py +++ b/src/lightning/pytorch/plugins/precision/fsdp.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + from contextlib import contextmanager -from typing import Any, Generator, Literal, Optional +from typing import Any, Generator, Literal import torch @@ -35,7 +37,7 @@ class FSDPMixedPrecisionPlugin(MixedPrecisionPlugin): """ def __init__( - self, precision: Literal["16-mixed", "bf16-mixed"], device: str, scaler: Optional[ShardedGradScaler] = None + self, precision: Literal["16-mixed", "bf16-mixed"], device: str, scaler: ShardedGradScaler | None = None ) -> None: if not _TORCH_GREATER_EQUAL_1_12: raise MisconfigurationException("`FSDPMixedPrecisionPlugin` is supported from PyTorch v1.12.0 onwards.") @@ -54,7 +56,7 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None: ) @property - def mixed_precision_config(self) -> Optional[MixedPrecision]: + def mixed_precision_config(self) -> MixedPrecision | None: assert MixedPrecision is not None if self.precision == "16-mixed": diff --git a/src/lightning/pytorch/plugins/precision/precision_plugin.py b/src/lightning/pytorch/plugins/precision/precision_plugin.py index 89fa734013083..2bf90d63f61ce 100644 --- a/src/lightning/pytorch/plugins/precision/precision_plugin.py +++ b/src/lightning/pytorch/plugins/precision/precision_plugin.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import contextlib from functools import partial -from typing import Any, Callable, Generator, List, Optional, Tuple, Union +from typing import Any, Callable, Generator import torch from torch import Tensor @@ -35,12 +37,12 @@ class PrecisionPlugin(FabricPrecision, CheckpointHooks): """ def connect( - self, model: Module, optimizers: List[Optimizer], lr_schedulers: List[Any] - ) -> Tuple[Module, List[Optimizer], List[Any]]: + self, model: Module, optimizers: list[Optimizer], lr_schedulers: list[Any] + ) -> tuple[Module, list[Optimizer], list[Any]]: """Connects this plugin to the accelerator and the training process.""" return model, optimizers, lr_schedulers - def pre_backward(self, tensor: Tensor, module: "pl.LightningModule") -> Tensor: # type: ignore[override] + def pre_backward(self, tensor: Tensor, module: pl.LightningModule) -> Tensor: # type: ignore[override] trainer = module.trainer call._call_callback_hooks(trainer, "on_before_backward", tensor) call._call_lightning_module_hook(trainer, "on_before_backward", tensor) @@ -49,8 +51,8 @@ def pre_backward(self, tensor: Tensor, module: "pl.LightningModule") -> Tensor: def backward( # type: ignore[override] self, tensor: Tensor, - model: "pl.LightningModule", - optimizer: Optional[Steppable], + model: pl.LightningModule, + optimizer: Steppable | None, *args: Any, **kwargs: Any, ) -> None: @@ -66,7 +68,7 @@ def backward( # type: ignore[override] """ model.backward(tensor, *args, **kwargs) - def post_backward(self, tensor: Tensor, module: "pl.LightningModule") -> Tensor: # type: ignore[override] + def post_backward(self, tensor: Tensor, module: pl.LightningModule) -> Tensor: # type: ignore[override] # once backward has been applied, release graph closure_loss = tensor.detach() trainer = module.trainer @@ -74,7 +76,7 @@ def post_backward(self, tensor: Tensor, module: "pl.LightningModule") -> Tensor: call._call_lightning_module_hook(trainer, "on_after_backward") return closure_loss - def _after_closure(self, model: "pl.LightningModule", optimizer: Steppable) -> None: + def _after_closure(self, model: 
pl.LightningModule, optimizer: Steppable) -> None: """Utility to share some code after the closure has been run.""" trainer = model.trainer call._call_callback_hooks(trainer, "on_before_optimizer_step", optimizer) @@ -88,7 +90,7 @@ def _after_closure(self, model: "pl.LightningModule", optimizer: Steppable) -> N def _wrap_closure( self, - model: "pl.LightningModule", + model: pl.LightningModule, optimizer: Optimizer, closure: Callable[[], Any], ) -> Any: @@ -105,7 +107,7 @@ def _wrap_closure( def optimizer_step( # type: ignore[override] self, optimizer: Steppable, - model: "pl.LightningModule", + model: pl.LightningModule, closure: Callable[[], Any], **kwargs: Any, ) -> Any: @@ -115,10 +117,10 @@ def optimizer_step( # type: ignore[override] def _clip_gradients( self, - model: Union["pl.LightningModule", Module], + model: pl.LightningModule | Module, optimizer: Steppable, - clip_val: Optional[Union[int, float]] = None, - gradient_clip_algorithm: Optional[GradClipAlgorithmType] = None, + clip_val: int | float | None = None, + gradient_clip_algorithm: GradClipAlgorithmType | None = None, ) -> None: if not isinstance(model, pl.LightningModule) or not model.automatic_optimization: # the configuration validator disallows clipping on manual @@ -135,7 +137,7 @@ def _clip_gradients( def clip_gradients( self, optimizer: Optimizer, - clip_val: Union[int, float] = 0.0, + clip_val: int | float = 0.0, gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, ) -> None: """Clips the gradients.""" @@ -146,12 +148,12 @@ def clip_gradients( elif gradient_clip_algorithm == GradClipAlgorithmType.NORM: self.clip_grad_by_norm(optimizer, clip_val) - def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None: + def clip_grad_by_value(self, optimizer: Optimizer, clip_val: int | float) -> None: """Clip gradients by value.""" parameters = self.main_params(optimizer) torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val) - def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None: + def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: int | float) -> None: """Clip gradients by norm.""" parameters = self.main_params(optimizer) torch.nn.utils.clip_grad_norm_(parameters, clip_val) diff --git a/src/lightning/pytorch/plugins/precision/xla.py b/src/lightning/pytorch/plugins/precision/xla.py index c8bae9cc6845e..027c4d0a20286 100644 --- a/src/lightning/pytorch/plugins/precision/xla.py +++ b/src/lightning/pytorch/plugins/precision/xla.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from functools import partial from typing import Any, Callable @@ -39,7 +41,7 @@ def _tpu_wrap_closure(self, optimizer: Optimizable, closure: Callable[[], Any]) def optimizer_step( # type: ignore[override] self, optimizer: Optimizable, - model: "pl.LightningModule", + model: pl.LightningModule, closure: Callable[[], Any], **kwargs: Any, ) -> Any: diff --git a/src/lightning/pytorch/plugins/precision/xlabf16.py b/src/lightning/pytorch/plugins/precision/xlabf16.py index e9c0bdd1feda9..d7ab6b185076e 100644 --- a/src/lightning/pytorch/plugins/precision/xlabf16.py +++ b/src/lightning/pytorch/plugins/precision/xlabf16.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
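# --- Illustrative sketch, not part of the diff: the two torch.nn.utils clipping
# primitives that `clip_grad_by_norm` and `clip_grad_by_value` above delegate to.
# The parameter and gradient values are made up for demonstration only.
import torch

param = torch.nn.Parameter(torch.zeros(3))
param.grad = torch.tensor([3.0, 4.0, 0.0])  # L2 norm = 5.0
torch.nn.utils.clip_grad_norm_([param], max_norm=1.0)
print(param.grad)  # rescaled to (approximately) unit norm: ~[0.6, 0.8, 0.0]

param.grad = torch.tensor([3.0, -4.0, 0.5])
torch.nn.utils.clip_grad_value_([param], clip_value=1.0)
print(param.grad)  # clamped elementwise to [-1.0, 1.0]: [1.0, -1.0, 0.5]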
# See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os -from typing import Any, List, Literal, Tuple +from typing import Any, Literal import torch.nn as nn from torch.optim import Optimizer @@ -26,8 +28,8 @@ class XLABf16PrecisionPlugin(XLAPrecisionPlugin): precision: Literal["bf16-mixed"] = "bf16-mixed" def connect( - self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any] - ) -> Tuple[nn.Module, List[Optimizer], List[Any]]: + self, model: nn.Module, optimizers: list[Optimizer], lr_schedulers: list[Any] + ) -> tuple[nn.Module, list[Optimizer], list[Any]]: os.environ["XLA_USE_BF16"] = "1" return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers) diff --git a/src/lightning/pytorch/profilers/advanced.py b/src/lightning/pytorch/profilers/advanced.py index cc1600af34784..7fdc3cc291158 100644 --- a/src/lightning/pytorch/profilers/advanced.py +++ b/src/lightning/pytorch/profilers/advanced.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Profiler to check if there are any bottlenecks in your code.""" +from __future__ import annotations + import cProfile import io import logging import pstats from pathlib import Path -from typing import Dict, Optional, Tuple, Union from lightning.pytorch.profilers.profiler import Profiler @@ -33,8 +34,8 @@ class AdvancedProfiler(Profiler): def __init__( self, - dirpath: Optional[Union[str, Path]] = None, - filename: Optional[str] = None, + dirpath: str | Path | None = None, + filename: str | None = None, line_count_restriction: float = 1.0, ) -> None: """ @@ -55,7 +56,7 @@ def __init__( If you attempt to stop recording an action which was never started. """ super().__init__(dirpath=dirpath, filename=filename) - self.profiled_actions: Dict[str, cProfile.Profile] = {} + self.profiled_actions: dict[str, cProfile.Profile] = {} self.line_count_restriction = line_count_restriction def start(self, action_name: str) -> None: @@ -78,11 +79,11 @@ def summary(self) -> str: recorded_stats[action_name] = s.getvalue() return self._stats_to_str(recorded_stats) - def teardown(self, stage: Optional[str]) -> None: + def teardown(self, stage: str | None) -> None: super().teardown(stage=stage) self.profiled_actions = {} - def __reduce__(self) -> Tuple: + def __reduce__(self) -> tuple: # avoids `TypeError: cannot pickle 'cProfile.Profile' object` return ( self.__class__, diff --git a/src/lightning/pytorch/profilers/profiler.py b/src/lightning/pytorch/profilers/profiler.py index 5bc23251a873a..bb959a9a361fb 100644 --- a/src/lightning/pytorch/profilers/profiler.py +++ b/src/lightning/pytorch/profilers/profiler.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
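# --- Illustrative sketch, not part of the diff: the cProfile/pstats pattern that
# AdvancedProfiler wraps per recorded action (one Profile per action name, dumped
# through a string stream). The profiled workload is made up for demonstration.
import cProfile
import io
import pstats

profile = cProfile.Profile()
profile.enable()
sum(i * i for i in range(100_000))  # stand-in for a profiled "action"
profile.disable()

stream = io.StringIO()
pstats.Stats(profile, stream=stream).strip_dirs().sort_stats("cumulative").print_stats(1.0)
print(stream.getvalue())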
"""Profiler to check if there are any bottlenecks in your code.""" +from __future__ import annotations + import logging import os from abc import ABC, abstractmethod from contextlib import contextmanager from pathlib import Path -from typing import Any, Callable, Dict, Generator, Optional, TextIO, Union +from typing import Any, Callable, Generator, TextIO from lightning.fabric.utilities.cloud_io import get_filesystem @@ -29,16 +31,16 @@ class Profiler(ABC): def __init__( self, - dirpath: Optional[Union[str, Path]] = None, - filename: Optional[str] = None, + dirpath: str | Path | None = None, + filename: str | None = None, ) -> None: self.dirpath = dirpath self.filename = filename - self._output_file: Optional[TextIO] = None - self._write_stream: Optional[Callable] = None - self._local_rank: Optional[int] = None - self._stage: Optional[str] = None + self._output_file: TextIO | None = None + self._write_stream: Callable | None = None + self._local_rank: int | None = None + self._stage: str | None = None @abstractmethod def start(self, action_name: str) -> None: @@ -73,9 +75,7 @@ def _rank_zero_info(self, *args: Any, **kwargs: Any) -> None: if self._local_rank in (None, 0): log.info(*args, **kwargs) - def _prepare_filename( - self, action_name: Optional[str] = None, extension: str = ".txt", split_token: str = "-" - ) -> str: + def _prepare_filename(self, action_name: str | None = None, extension: str = ".txt", split_token: str = "-") -> str: args = [] if self._stage is not None: args.append(self._stage) @@ -113,7 +113,7 @@ def describe(self) -> None: self._output_file.flush() self.teardown(stage=self._stage) - def _stats_to_str(self, stats: Dict[str, str]) -> str: + def _stats_to_str(self, stats: dict[str, str]) -> str: stage = f"{self._stage.upper()} " if self._stage is not None else "" output = [stage + "Profiler Report"] for action, value in stats.items(): @@ -124,13 +124,13 @@ def _stats_to_str(self, stats: Dict[str, str]) -> str: output.append(value) return os.linesep.join(output) - def setup(self, stage: str, local_rank: Optional[int] = None, log_dir: Optional[str] = None) -> None: + def setup(self, stage: str, local_rank: int | None = None, log_dir: str | None = None) -> None: """Execute arbitrary pre-profiling set-up steps.""" self._stage = stage self._local_rank = local_rank self.dirpath = self.dirpath or log_dir - def teardown(self, stage: Optional[str]) -> None: + def teardown(self, stage: str | None) -> None: """Execute arbitrary post-profiling tear-down steps. Closes the currently open file and stream. diff --git a/src/lightning/pytorch/profilers/pytorch.py b/src/lightning/pytorch/profilers/pytorch.py index fe3ab1c18968c..2d7543a315d02 100644 --- a/src/lightning/pytorch/profilers/pytorch.py +++ b/src/lightning/pytorch/profilers/pytorch.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Profiler to check if there are any bottlenecks in your code.""" +from __future__ import annotations + import inspect import logging import os from functools import lru_cache, partial from pathlib import Path -from typing import Any, Callable, ContextManager, Dict, List, Optional, Type, TYPE_CHECKING, Union +from typing import Any, Callable, ContextManager, TYPE_CHECKING, Union import torch from torch import nn, Tensor @@ -64,8 +66,8 @@ class RegisterRecordFunction: def __init__(self, model: nn.Module) -> None: self._model = model - self._records: Dict[str, record_function] = {} - self._handles: Dict[str, List["RemovableHandle"]] = {} + self._records: dict[str, record_function] = {} + self._handles: dict[str, list[RemovableHandle]] = {} def _start_recording_forward(self, _: nn.Module, input: Tensor, record_name: str) -> Tensor: # Add [pl][module] in name for pytorch profiler to recognize @@ -120,9 +122,9 @@ def reset(self) -> None: self._test_step_reached_end = False self._predict_step_reached_end = False # used to stop profiler when `ProfilerAction.RECORD_AND_SAVE` is reached. - self._current_action: Optional[str] = None - self._prev_schedule_action: Optional[ProfilerAction] = None - self._start_action_name: Optional[str] = None + self._current_action: str | None = None + self._prev_schedule_action: ProfilerAction | None = None + self._start_action_name: str | None = None def setup(self, start_action_name: str) -> None: self._start_action_name = start_action_name @@ -189,7 +191,7 @@ def has_finished(self) -> bool: return self._predict_step_reached_end return False - def __call__(self, num_step: int) -> "ProfilerAction": + def __call__(self, num_step: int) -> ProfilerAction: # ignore the provided input. Keep internal state instead. if self._current_action is None or self.has_finished: return ProfilerAction.NONE @@ -230,15 +232,15 @@ class PyTorchProfiler(Profiler): def __init__( self, - dirpath: Optional[Union[str, Path]] = None, - filename: Optional[str] = None, + dirpath: str | Path | None = None, + filename: str | None = None, group_by_input_shapes: bool = False, emit_nvtx: bool = False, export_to_chrome: bool = True, row_limit: int = 20, - sort_by_key: Optional[str] = None, + sort_by_key: str | None = None, record_module_names: bool = True, - table_kwargs: Optional[Dict[str, Any]] = None, + table_kwargs: dict[str, Any] | None = None, **profiler_kwargs: Any, ) -> None: r"""This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of. 
@@ -300,14 +302,14 @@ def __init__( self._profiler_kwargs = profiler_kwargs self._table_kwargs = table_kwargs if table_kwargs is not None else {} - self.profiler: Optional[_PROFILER] = None - self.function_events: Optional["EventList"] = None - self._lightning_module: Optional["LightningModule"] = None # set by ProfilerConnector - self._register: Optional[RegisterRecordFunction] = None - self._parent_profiler: Optional[ContextManager] = None - self._recording_map: Dict[str, record_function] = {} - self._start_action_name: Optional[str] = None - self._schedule: Optional[ScheduleWrapper] = None + self.profiler: _PROFILER | None = None + self.function_events: EventList | None = None + self._lightning_module: LightningModule | None = None # set by ProfilerConnector + self._register: RegisterRecordFunction | None = None + self._parent_profiler: ContextManager | None = None + self._recording_map: dict[str, record_function] = {} + self._start_action_name: str | None = None + self._schedule: ScheduleWrapper | None = None if _KINETO_AVAILABLE: self._init_kineto(profiler_kwargs) @@ -356,7 +358,7 @@ def _init_kineto(self, profiler_kwargs: Any) -> None: self._profiler_kwargs["with_stack"] = with_stack @property - def _total_steps(self) -> Union[int, float]: + def _total_steps(self) -> int | float: assert self._schedule is not None assert self._lightning_module is not None trainer = self._lightning_module.trainer @@ -393,14 +395,14 @@ def _should_override_schedule(self) -> bool: @staticmethod @lru_cache(1) - def _default_schedule() -> Optional[Callable]: + def _default_schedule() -> Callable | None: if _KINETO_AVAILABLE: # Those schedule defaults allow the profiling overhead to be negligible over training time. return torch.profiler.schedule(wait=1, warmup=1, active=3) return None - def _default_activities(self) -> List["ProfilerActivity"]: - activities: List["ProfilerActivity"] = [] + def _default_activities(self) -> list[ProfilerActivity]: + activities: list[ProfilerActivity] = [] if not _KINETO_AVAILABLE: return activities if self._profiler_kwargs.get("use_cpu", True): @@ -520,7 +522,7 @@ def _create_profilers(self) -> None: torch.profiler.profile if _KINETO_AVAILABLE else torch.autograd.profiler.profile ) - def _create_profiler(self, profiler: Type[_PROFILER]) -> _PROFILER: + def _create_profiler(self, profiler: type[_PROFILER]) -> _PROFILER: init_parameters = inspect.signature(profiler.__init__).parameters kwargs = {k: v for k, v in self._profiler_kwargs.items() if k in init_parameters} return profiler(**kwargs) @@ -553,7 +555,7 @@ def _delete_profilers(self) -> None: self._register.__exit__(None, None, None) self._register = None - def teardown(self, stage: Optional[str]) -> None: + def teardown(self, stage: str | None) -> None: self._delete_profilers() for k in list(self._recording_map): diff --git a/src/lightning/pytorch/profilers/simple.py b/src/lightning/pytorch/profilers/simple.py index 3af44d4178ab5..7475cc30e3b2e 100644 --- a/src/lightning/pytorch/profilers/simple.py +++ b/src/lightning/pytorch/profilers/simple.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
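# --- Illustrative sketch, not part of the diff: the default kineto schedule shown
# above (`wait=1, warmup=1, active=3`) driving a standalone torch.profiler run.
# The workload and step count are made up for demonstration only.
import torch
from torch.profiler import ProfilerActivity, profile, schedule

with profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=1, warmup=1, active=3),
) as prof:
    for _ in range(5):  # skip 1 step, warm up for 1, record the last 3
        torch.matmul(torch.randn(64, 64), torch.randn(64, 64))
        prof.step()

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))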
"""Profiler to check if there are any bottlenecks in your code.""" +from __future__ import annotations + import logging import os import time from collections import defaultdict from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import List, Tuple import torch @@ -37,8 +39,8 @@ class SimpleProfiler(Profiler): def __init__( self, - dirpath: Optional[Union[str, Path]] = None, - filename: Optional[str] = None, + dirpath: str | Path | None = None, + filename: str | None = None, extended: bool = True, ) -> None: """ @@ -59,8 +61,8 @@ def __init__( if you attempt to stop recording an action which was never started. """ super().__init__(dirpath=dirpath, filename=filename) - self.current_actions: Dict[str, float] = {} - self.recorded_durations: Dict = defaultdict(list) + self.current_actions: dict[str, float] = {} + self.recorded_durations: dict = defaultdict(list) self.extended = extended self.start_time = time.monotonic() @@ -77,7 +79,7 @@ def stop(self, action_name: str) -> None: duration = end_time - start_time self.recorded_durations[action_name].append(duration) - def _make_report_extended(self) -> Tuple[_TABLE_DATA_EXTENDED, float, float]: + def _make_report_extended(self) -> tuple[_TABLE_DATA_EXTENDED, float, float]: total_duration = time.monotonic() - self.start_time report = [] diff --git a/src/lightning/pytorch/profilers/xla.py b/src/lightning/pytorch/profilers/xla.py index b6ebe70fd283e..cc60c766876e0 100644 --- a/src/lightning/pytorch/profilers/xla.py +++ b/src/lightning/pytorch/profilers/xla.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging -from typing import Dict from lightning.fabric.accelerators.xla import _XLA_AVAILABLE from lightning.pytorch.profilers.profiler import Profiler @@ -42,8 +43,8 @@ def __init__(self, port: int = 9012) -> None: raise ModuleNotFoundError(str(_XLA_AVAILABLE)) super().__init__(dirpath=None, filename=None) self.port = port - self._recording_map: Dict = {} - self._step_recoding_map: Dict = {} + self._recording_map: dict = {} + self._step_recoding_map: dict = {} self._start_trace: bool = False def start(self, action_name: str) -> None: diff --git a/src/lightning/pytorch/serve/servable_module.py b/src/lightning/pytorch/serve/servable_module.py index 33efa9956a16f..742ae3247b9e6 100644 --- a/src/lightning/pytorch/serve/servable_module.py +++ b/src/lightning/pytorch/serve/servable_module.py @@ -1,5 +1,7 @@ +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Tuple +from typing import Any, Callable import torch from torch import Tensor @@ -55,11 +57,11 @@ def configure_response(self): """ @abstractmethod - def configure_payload(self) -> Dict[str, Any]: + def configure_payload(self) -> dict[str, Any]: """Returns a request payload as a dictionary.""" @abstractmethod - def configure_serialization(self) -> Tuple[Dict[str, Callable], Dict[str, Callable]]: + def configure_serialization(self) -> tuple[dict[str, Callable], dict[str, Callable]]: """Returns a tuple of dictionaries. 
The first dictionary contains the name of the ``serve_step`` input variables name as its keys @@ -70,7 +72,7 @@ def configure_serialization(self) -> Tuple[Dict[str, Callable], Dict[str, Callab """ @abstractmethod - def serve_step(self, *args: Tensor, **kwargs: Tensor) -> Dict[str, Tensor]: + def serve_step(self, *args: Tensor, **kwargs: Tensor) -> dict[str, Tensor]: r"""Returns the predictions of your model as a dictionary. .. code-block:: python @@ -87,5 +89,5 @@ def serve_step(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: """ @abstractmethod - def configure_response(self) -> Dict[str, Any]: + def configure_response(self) -> dict[str, Any]: """Returns a response to validate the server response.""" diff --git a/src/lightning/pytorch/serve/servable_module_validator.py b/src/lightning/pytorch/serve/servable_module_validator.py index e6db99091eca6..d5003c1c760cf 100644 --- a/src/lightning/pytorch/serve/servable_module_validator.py +++ b/src/lightning/pytorch/serve/servable_module_validator.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import contextlib import logging import time from multiprocessing import Process -from typing import Any, Dict, Literal, Optional +from typing import Any, Literal import requests import torch @@ -40,7 +42,7 @@ class ServableModuleValidator(Callback): def __init__( self, - optimization: Optional[Literal["trace", "script", "onnx", "tensorrt"]] = None, + optimization: Literal["trace", "script", "onnx", "tensorrt"] | None = None, server: Literal["fastapi", "ml_server", "torchserve", "sagemaker"] = "fastapi", host: str = "127.0.0.1", port: int = 8080, @@ -68,10 +70,10 @@ def __init__( self.server = server self.timeout = timeout self.exit_on_failure = exit_on_failure - self.resp: Optional[requests.Response] = None + self.resp: requests.Response | None = None @rank_zero_only - def on_train_start(self, trainer: "pl.Trainer", servable_module: "pl.LightningModule") -> None: + def on_train_start(self, trainer: pl.Trainer, servable_module: pl.LightningModule) -> None: if isinstance(trainer.strategy, _NOT_SUPPORTED_STRATEGIES): raise Exception( f"The current strategy {trainer.strategy.__class__.__qualname__} used " @@ -128,11 +130,11 @@ def on_train_start(self, trainer: "pl.Trainer", servable_module: "pl.LightningMo _logger.info(f"Your model is servable and the received payload was {self.resp.json()}.") @property - def successful(self) -> Optional[bool]: + def successful(self) -> bool | None: """Returns whether the model was successfully served.""" return self.resp.status_code == 200 if self.resp else None - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return {"successful": self.successful, "optimization": self.optimization, "server": self.server} @staticmethod @@ -153,7 +155,7 @@ def ping() -> bool: return True @app.post("/serve") - async def serve(payload: dict = Body(...)) -> Dict[str, Any]: + async def serve(payload: dict = Body(...)) -> dict[str, Any]: body = payload["body"] for key, deserializer in deserializers.items(): diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py index 44faadf166707..c034f09b06b9e 100644 --- a/src/lightning/pytorch/strategies/ddp.py +++ b/src/lightning/pytorch/strategies/ddp.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import logging from contextlib import nullcontext from datetime import timedelta -from typing import Any, Callable, Dict, List, Literal, Optional, Union +from typing import Any, Callable, Literal import torch import torch.distributed @@ -68,17 +70,17 @@ class DDPStrategy(ParallelStrategy): def __init__( self, - accelerator: Optional["pl.accelerators.Accelerator"] = None, - parallel_devices: Optional[List[torch.device]] = None, - cluster_environment: Optional[ClusterEnvironment] = None, - checkpoint_io: Optional[CheckpointIO] = None, - precision_plugin: Optional[PrecisionPlugin] = None, - ddp_comm_state: Optional[object] = None, - ddp_comm_hook: Optional[Callable] = None, - ddp_comm_wrapper: Optional[Callable] = None, - model_averaging_period: Optional[int] = None, - process_group_backend: Optional[str] = None, - timeout: Optional[timedelta] = default_pg_timeout, + accelerator: pl.accelerators.Accelerator | None = None, + parallel_devices: list[torch.device] | None = None, + cluster_environment: ClusterEnvironment | None = None, + checkpoint_io: CheckpointIO | None = None, + precision_plugin: PrecisionPlugin | None = None, + ddp_comm_state: object | None = None, + ddp_comm_hook: Callable | None = None, + ddp_comm_wrapper: Callable | None = None, + model_averaging_period: int | None = None, + process_group_backend: str | None = None, + timeout: timedelta | None = default_pg_timeout, start_method: Literal["popen", "spawn", "fork", "forkserver"] = "popen", **kwargs: Any, ) -> None: @@ -97,9 +99,9 @@ def __init__( self._ddp_comm_hook = ddp_comm_hook self._ddp_comm_wrapper = ddp_comm_wrapper self._model_averaging_period = model_averaging_period - self._model_averager: Optional[ModelAverager] = None - self._process_group_backend: Optional[str] = process_group_backend - self._timeout: Optional[timedelta] = timeout + self._model_averager: ModelAverager | None = None + self._process_group_backend: str | None = process_group_backend + self._timeout: timedelta | None = timeout self._start_method = start_method @property @@ -129,11 +131,11 @@ def num_processes(self) -> int: return len(self.parallel_devices) if self.parallel_devices is not None else 0 @property - def distributed_sampler_kwargs(self) -> Dict[str, Any]: + def distributed_sampler_kwargs(self) -> dict[str, Any]: return {"num_replicas": (self.num_nodes * self.num_processes), "rank": self.global_rank} @property - def process_group_backend(self) -> Optional[str]: + def process_group_backend(self) -> str | None: return self._process_group_backend def _configure_launcher(self) -> None: @@ -147,7 +149,7 @@ def setup_environment(self) -> None: self.setup_distributed() super().setup_environment() - def setup(self, trainer: "pl.Trainer") -> None: + def setup(self, trainer: pl.Trainer) -> None: assert self.accelerator is not None self.accelerator.setup(trainer) @@ -249,7 +251,7 @@ def optimizer_step( self, optimizer: Optimizer, closure: Callable[[], Any], - model: Optional[Union["pl.LightningModule", Module]] = None, + model: pl.LightningModule | Module | None = None, **kwargs: Any, ) -> Any: """Performs the actual optimizer step. 
@@ -276,7 +278,7 @@ def configure_ddp(self) -> None: self.model = self._setup_model(self.model) self._register_ddp_hooks() - def determine_ddp_device_ids(self) -> Optional[List[int]]: + def determine_ddp_device_ids(self) -> list[int] | None: if self.root_device.type == "cpu": return None return [self.root_device.index] @@ -311,9 +313,7 @@ def model_to_device(self) -> None: assert self.model is not None self.model.to(self.root_device) - def reduce( - self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean" - ) -> Tensor: + def reduce(self, tensor: Tensor, group: Any | None = None, reduce_op: ReduceOp | str | None = "mean") -> Tensor: """Reduces a tensor from several distributed processes to one aggregated tensor. Args: @@ -404,12 +404,12 @@ def teardown(self) -> None: class _DDPForwardRedirection(_ForwardRedirection): - def on_after_inner_forward(self, wrapper_module: Module, original_module: "pl.LightningModule") -> None: + def on_after_inner_forward(self, wrapper_module: Module, original_module: pl.LightningModule) -> None: # In manual_optimization, we need to prevent DDP reducer as # it is done manually in `LightningModule.manual_backward` if isinstance(wrapper_module, DistributedDataParallel) and not original_module.automatic_optimization: wrapper_module.require_backward_grad_sync = False - def on_after_outer_forward(self, wrapper_module: Module, original_module: "pl.LightningModule") -> None: + def on_after_outer_forward(self, wrapper_module: Module, original_module: pl.LightningModule) -> None: if isinstance(wrapper_module, DistributedDataParallel) and not original_module.automatic_optimization: wrapper_module.require_backward_grad_sync = True diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index 4a8a30a18fe27..4577a0b9c6eb6 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import argparse import contextlib import json @@ -19,7 +21,7 @@ import platform from collections import OrderedDict from pathlib import Path -from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, Generator, Mapping, TYPE_CHECKING import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -71,7 +73,7 @@ class DeepSpeedStrategy(DDPStrategy): def __init__( self, - accelerator: Optional["pl.accelerators.Accelerator"] = None, + accelerator: pl.accelerators.Accelerator | None = None, zero_optimization: bool = True, stage: int = 2, remote_device: str = "cpu", @@ -98,11 +100,11 @@ def __init__( allgather_bucket_size: int = 200_000_000, reduce_bucket_size: int = 200_000_000, zero_allow_untested_optimizer: bool = True, - logging_batch_size_per_gpu: Union[str, int] = "auto", - config: Optional[Union[_PATH, Dict[str, Any]]] = None, + logging_batch_size_per_gpu: str | int = "auto", + config: _PATH | dict[str, Any] | None = None, logging_level: int = logging.WARN, - parallel_devices: Optional[List[torch.device]] = None, - cluster_environment: Optional[ClusterEnvironment] = None, + parallel_devices: list[torch.device] | None = None, + cluster_environment: ClusterEnvironment | None = None, loss_scale: float = 0, initial_scale_power: int = 16, loss_scale_window: int = 1000, @@ -113,8 +115,8 @@ def __init__( contiguous_memory_optimization: bool = False, synchronize_checkpoint_boundary: bool = False, load_full_weights: bool = False, - precision_plugin: Optional[PrecisionPlugin] = None, - process_group_backend: Optional[str] = None, + precision_plugin: PrecisionPlugin | None = None, + process_group_backend: str | None = None, ) -> None: """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large billion parameter models. `For more information: https://pytorch- @@ -310,7 +312,7 @@ def __init__( self.hysteresis = hysteresis self.min_loss_scale = min_loss_scale - def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Optional[Dict[str, Any]]: + def _load_config(self, config: _PATH | dict[str, Any] | None) -> dict[str, Any] | None: if config is None and self.DEEPSPEED_ENV_VAR in os.environ: rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable") config = os.environ[self.DEEPSPEED_ENV_VAR] @@ -332,7 +334,7 @@ def setup_distributed(self) -> None: self._format_config() self._config_initialized = True - def setup(self, trainer: "pl.Trainer") -> None: + def setup(self, trainer: pl.Trainer) -> None: assert self.accelerator is not None self.accelerator.setup(trainer) # we set the device so that optimizers can be created with distributed comms. @@ -372,8 +374,8 @@ def restore_checkpoint_after_setup(self) -> bool: return True def _setup_model_and_optimizers( - self, model: Module, optimizers: List[Optimizer] - ) -> Tuple["deepspeed.DeepSpeedEngine", List[Optimizer]]: + self, model: Module, optimizers: list[Optimizer] + ) -> tuple[deepspeed.DeepSpeedEngine, list[Optimizer]]: """Setup a model and multiple optimizers together. Currently only a single optimizer is supported. 
@@ -400,9 +402,9 @@ def _setup_model_and_optimizers( def _setup_model_and_optimizer( self, model: Module, - optimizer: Optional[Optimizer], - lr_scheduler: Optional[Union[LRScheduler, ReduceLROnPlateau]] = None, - ) -> Tuple["deepspeed.DeepSpeedEngine", Optimizer]: + optimizer: Optimizer | None, + lr_scheduler: LRScheduler | ReduceLROnPlateau | None = None, + ) -> tuple[deepspeed.DeepSpeedEngine, Optimizer]: """Initialize one model and one optimizer with an optional learning rate scheduler. This calls :func:`deepspeed.initialize` internally. @@ -447,7 +449,7 @@ def init_deepspeed(self) -> None: else: self._initialize_deepspeed_inference(self.model) - def _init_optimizers(self) -> Tuple[Optimizer, Optional[LRSchedulerConfig]]: + def _init_optimizers(self) -> tuple[Optimizer, LRSchedulerConfig | None]: assert self.lightning_module is not None optimizers, lr_schedulers = _init_optimizers_and_lr_schedulers(self.lightning_module) if len(optimizers) > 1 or len(lr_schedulers) > 1: @@ -565,10 +567,10 @@ def _initialize_deepspeed_inference(self, model: Module) -> None: self.model = model @property - def distributed_sampler_kwargs(self) -> Dict[str, int]: + def distributed_sampler_kwargs(self) -> dict[str, int]: return {"num_replicas": self.world_size, "rank": self.global_rank} - def setup_optimizers(self, trainer: "pl.Trainer") -> None: + def setup_optimizers(self, trainer: pl.Trainer) -> None: """Creates optimizers and schedulers. Args: @@ -661,7 +663,7 @@ def _create_default_config( self, zero_optimization: bool, zero_allow_untested_optimizer: bool, - logging_batch_size_per_gpu: Union[str, int], + logging_batch_size_per_gpu: str | int, partition_activations: bool, cpu_checkpointing: bool, contiguous_memory_optimization: bool, @@ -682,7 +684,7 @@ def _create_default_config( overlap_events: bool, thread_count: int, **zero_kwargs: Any, - ) -> Dict: + ) -> dict: cfg = { "activation_checkpointing": { "partition_activations": partition_activations, @@ -727,14 +729,14 @@ def _create_default_config( return cfg @property - def deepspeed_engine(self) -> "deepspeed.DeepSpeedEngine": + def deepspeed_engine(self) -> deepspeed.DeepSpeedEngine: return self.model @property def _multi_device(self) -> bool: return self.num_processes > 1 or self.num_nodes > 1 - def save_checkpoint(self, checkpoint: Dict, filepath: _PATH, storage_options: Optional[Any] = None) -> None: + def save_checkpoint(self, checkpoint: dict, filepath: _PATH, storage_options: Any | None = None) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. 
Args: @@ -769,7 +771,7 @@ def save_checkpoint(self, checkpoint: Dict, filepath: _PATH, storage_options: Op checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys} self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint, tag="checkpoint") - def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: + def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: if self.load_full_weights and self.zero_stage_3: # Broadcast to ensure we load from the rank 0 checkpoint # This doesn't have to be the case when using deepspeed sharded checkpointing @@ -823,9 +825,9 @@ def _restore_zero_state(self, ckpt: Mapping[str, Any]) -> None: assert self.lightning_module is not None def load(module: torch.nn.Module, prefix: str = "") -> None: - missing_keys: List[str] = [] - unexpected_keys: List[str] = [] - error_msgs: List[str] = [] + missing_keys: list[str] = [] + unexpected_keys: list[str] = [] + error_msgs: list[str] = [] state_dict = ckpt["state_dict"] # copy state_dict so _load_from_state_dict can modify it @@ -893,6 +895,6 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None: offload_optimizer_device="nvme", ) - def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any: + def batch_to_device(self, batch: Any, device: torch.device | None = None, dataloader_idx: int = 0) -> Any: batch = apply_to_collection(batch, Tensor, function=_fp_to_half, precision=self.precision_plugin.precision) return super().batch_to_device(batch, device, dataloader_idx) diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index 7864d857e91cc..3722564d10fcf 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
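# --- Illustrative sketch, not part of the diff: the kind of dictionary that can be
# passed through DeepSpeedStrategy's `config` argument instead of the individual
# keyword arguments mapped by `_create_default_config`. All values are made up;
# the authoritative schema is the DeepSpeed configuration documentation.
deepspeed_config: dict = {
    "zero_allow_untested_optimizer": True,
    "zero_optimization": {
        "stage": 2,
        "allgather_bucket_size": 200_000_000,
        "reduce_bucket_size": 200_000_000,
    },
    "activation_checkpointing": {
        "partition_activations": False,
        "cpu_checkpointing": False,
        "contiguous_memory_optimization": False,
        "synchronize_checkpoint_boundary": False,
    },
}
# e.g. DeepSpeedStrategy(config=deepspeed_config), assuming deepspeed is installed.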
+from __future__ import annotations + import contextlib import logging from datetime import timedelta -from typing import Any, Dict, Generator, List, Optional, Type, Union +from typing import Any, Generator import torch from torch import Tensor @@ -103,20 +105,20 @@ class FSDPStrategy(ParallelStrategy): """ strategy_name = "fsdp" - _registered_strategies: List[str] = [] + _registered_strategies: list[str] = [] def __init__( self, - accelerator: Optional["pl.accelerators.Accelerator"] = None, - parallel_devices: Optional[List[torch.device]] = None, - cluster_environment: Optional[ClusterEnvironment] = None, - checkpoint_io: Optional[CheckpointIO] = None, - precision_plugin: Optional[PrecisionPlugin] = None, - process_group_backend: Optional[str] = None, - timeout: Optional[timedelta] = default_pg_timeout, - cpu_offload: Union[bool, "CPUOffload", None] = None, - mixed_precision: Optional[MixedPrecision] = None, - activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None, + accelerator: pl.accelerators.Accelerator | None = None, + parallel_devices: list[torch.device] | None = None, + cluster_environment: ClusterEnvironment | None = None, + checkpoint_io: CheckpointIO | None = None, + precision_plugin: PrecisionPlugin | None = None, + process_group_backend: str | None = None, + timeout: timedelta | None = default_pg_timeout, + cpu_offload: bool | CPUOffload | None = None, + mixed_precision: MixedPrecision | None = None, + activation_checkpointing: type[Module] | list[type[Module]] | None = None, **kwargs: Any, ) -> None: if not _TORCH_GREATER_EQUAL_1_12: @@ -132,7 +134,7 @@ def __init__( self._process_group = None self.num_nodes = 1 self._process_group_backend = process_group_backend - self._timeout: Optional[timedelta] = timeout + self._timeout: timedelta | None = timeout self.cpu_offload = _init_cpu_offload(cpu_offload) self.mixed_precision = mixed_precision if activation_checkpointing and not _TORCH_GREATER_EQUAL_1_13: @@ -147,7 +149,7 @@ def __init__( # `self.trainer.model.parameters()` and enables support for multiple parameter groups. self.kwargs.setdefault("use_orig_params", True) - def lightning_module_state_dict(self) -> Dict[str, Any]: + def lightning_module_state_dict(self) -> dict[str, Any]: """Gathers the full state dict by unsharding all the parameters. To avoid OOM, the returned parameters will only be returned on rank 0 and on CPU. 
All other ranks get an empty @@ -172,18 +174,18 @@ def num_processes(self) -> int: return len(self.parallel_devices) if self.parallel_devices is not None else 0 @property - def process_group(self) -> Optional[ProcessGroup]: + def process_group(self) -> ProcessGroup | None: if self._process_group is None: # The strategy should have already initilized process group in setup_environment() self._process_group = _get_default_group() return self._process_group @property - def process_group_backend(self) -> Optional[str]: + def process_group_backend(self) -> str | None: return self._process_group_backend @property - def mixed_precision_config(self) -> Optional[MixedPrecision]: + def mixed_precision_config(self) -> MixedPrecision | None: if self.mixed_precision: return self.mixed_precision plugin = self.precision_plugin @@ -192,7 +194,7 @@ def mixed_precision_config(self) -> Optional[MixedPrecision]: return None @property - def distributed_sampler_kwargs(self) -> Dict: + def distributed_sampler_kwargs(self) -> dict: return {"num_replicas": (self.num_nodes * self.num_processes), "rank": self.global_rank} def setup_environment(self) -> None: @@ -250,7 +252,7 @@ def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel: return wrapped_module - def setup(self, trainer: "pl.Trainer") -> None: + def setup(self, trainer: pl.Trainer) -> None: assert self.accelerator is not None assert self.model is not None self.accelerator.setup(trainer) @@ -276,7 +278,7 @@ def setup(self, trainer: "pl.Trainer") -> None: self.setup_precision_plugin() - def setup_optimizers(self, trainer: "pl.Trainer") -> None: + def setup_optimizers(self, trainer: pl.Trainer) -> None: if self.kwargs.get("use_orig_params"): return super().setup_optimizers(trainer) @@ -313,7 +315,7 @@ def model_sharded_context(self) -> Generator: ): yield - def barrier(self, name: Optional[str] = None) -> None: + def barrier(self, name: str | None = None) -> None: if not torch.distributed.is_initialized(): return if torch.distributed.get_backend() == "nccl": @@ -331,9 +333,9 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: def reduce( self, - tensor: Union[Tensor, Any], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = "mean", + tensor: Tensor | Any, + group: Any | None = None, + reduce_op: ReduceOp | str | None = "mean", ) -> Tensor: """Reduces a tensor from several distributed processes to one aggregated tensor. @@ -350,7 +352,7 @@ def reduce( return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op) return tensor - def _determine_device_ids(self) -> List[int]: + def _determine_device_ids(self) -> list[int]: return [self.root_device.index] def teardown(self) -> None: @@ -375,7 +377,7 @@ def teardown(self) -> None: self.accelerator.teardown() @classmethod - def get_registered_strategies(cls) -> List[str]: + def get_registered_strategies(cls) -> list[str]: return cls._registered_strategies @classmethod diff --git a/src/lightning/pytorch/strategies/launchers/multiprocessing.py b/src/lightning/pytorch/strategies/launchers/multiprocessing.py index bb8e5a382ad70..2eeb0895790f0 100644 --- a/src/lightning/pytorch/strategies/launchers/multiprocessing.py +++ b/src/lightning/pytorch/strategies/launchers/multiprocessing.py @@ -11,13 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import logging import os import queue import tempfile from contextlib import suppress from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Literal, NamedTuple, Optional, Union +from typing import Any, Callable, Literal, NamedTuple import numpy as np import torch @@ -63,7 +65,7 @@ class _MultiProcessingLauncher(_Launcher): """ def __init__( - self, strategy: "pl.strategies.ParallelStrategy", start_method: Literal["spawn", "fork", "forkserver"] = "spawn" + self, strategy: pl.strategies.ParallelStrategy, start_method: Literal["spawn", "fork", "forkserver"] = "spawn" ) -> None: self._strategy = strategy self._start_method = start_method @@ -72,7 +74,7 @@ def __init__( f"The start method '{self._start_method}' is not available on this platform. Available methods are:" f" {', '.join(mp.get_all_start_methods())}" ) - self.procs: List[mp.Process] = [] + self.procs: list[mp.Process] = [] @property def is_interactive_compatible(self) -> bool: @@ -81,7 +83,7 @@ def is_interactive_compatible(self) -> bool: # initialization. For more context, see https://github.com/Lightning-AI/lightning/issues/7550 return self._start_method == "fork" - def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any: + def launch(self, function: Callable, *args: Any, trainer: pl.Trainer | None = None, **kwargs: Any) -> Any: """Launches processes that run the given function in parallel. The function is allowed to have a return value. However, when all processes join, only the return value @@ -134,12 +136,12 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] def _wrapping_function( self, process_idx: int, - trainer: Optional["pl.Trainer"], + trainer: pl.Trainer | None, function: Callable, args: Any, kwargs: Any, - return_queue: Union[mp.SimpleQueue, queue.Queue], - global_states: Optional["_GlobalStateSnapshot"] = None, + return_queue: mp.SimpleQueue | queue.Queue, + global_states: _GlobalStateSnapshot | None = None, ) -> None: if global_states: global_states.restore() @@ -152,7 +154,7 @@ def _wrapping_function( if process_idx == 0: return_queue.put(move_data_to_device(results, "cpu")) - def _recover_results_in_main_process(self, worker_output: "_WorkerOutput", trainer: "pl.Trainer") -> None: + def _recover_results_in_main_process(self, worker_output: _WorkerOutput, trainer: pl.Trainer) -> None: # transfer back the best path to the trainer if trainer.checkpoint_callback and hasattr(trainer.checkpoint_callback, "best_model_path"): trainer.checkpoint_callback.best_model_path = str(worker_output.best_model_path) @@ -171,7 +173,7 @@ def _recover_results_in_main_process(self, worker_output: "_WorkerOutput", train # get the `callback_metrics` and set it to the trainer self.update_main_process_results(trainer, worker_output.extra) - def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_WorkerOutput"]: + def _collect_rank_zero_results(self, trainer: pl.Trainer, results: Any) -> _WorkerOutput | None: rank_zero_debug("Collecting results from rank 0 process.") checkpoint_callback = trainer.checkpoint_callback best_model_path = ( @@ -208,7 +210,7 @@ def _check_torchdistx_support(self) -> None: f" initialization when `start_method='spawn'`." 
) - def get_extra_results(self, trainer: "pl.Trainer") -> Dict[str, Any]: + def get_extra_results(self, trainer: pl.Trainer) -> dict[str, Any]: """Gather extra state from the Trainer and return it as a dictionary for sending back to the main process. To avoid issues with memory sharing, we cast the data to numpy. @@ -224,7 +226,7 @@ def get_extra_results(self, trainer: "pl.Trainer") -> Dict[str, Any]: ) # send as numpy to avoid issues with memory sharing return {"callback_metrics": callback_metrics} - def update_main_process_results(self, trainer: "pl.Trainer", extra: Dict[str, Any]) -> None: + def update_main_process_results(self, trainer: pl.Trainer, extra: dict[str, Any]) -> None: """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we cast back the data to ``torch.Tensor``. @@ -244,18 +246,18 @@ def kill(self, signum: _SIGNUM) -> None: with suppress(ProcessLookupError): os.kill(proc.pid, signum) - def __getstate__(self) -> Dict: + def __getstate__(self) -> dict: state = self.__dict__.copy() state["procs"] = [] # SpawnProcess can't be pickled return state class _WorkerOutput(NamedTuple): - best_model_path: Optional[_PATH] - weights_path: Optional[_PATH] + best_model_path: _PATH | None + weights_path: _PATH | None trainer_state: TrainerState trainer_results: Any - extra: Dict[str, Any] + extra: dict[str, Any] @dataclass @@ -279,10 +281,10 @@ class _GlobalStateSnapshot: use_deterministic_algorithms: bool use_deterministic_algorithms_warn_only: bool cudnn_benchmark: bool - rng_states: Dict[str, Any] + rng_states: dict[str, Any] @classmethod - def capture(cls) -> "_GlobalStateSnapshot": + def capture(cls) -> _GlobalStateSnapshot: """Capture a few global states from torch, numpy, etc., that we want to restore in a spawned worker process.""" return cls( diff --git a/src/lightning/pytorch/strategies/launchers/subprocess_script.py b/src/lightning/pytorch/strategies/launchers/subprocess_script.py index a46f09f93022a..1041f74e06f40 100644 --- a/src/lightning/pytorch/strategies/launchers/subprocess_script.py +++ b/src/lightning/pytorch/strategies/launchers/subprocess_script.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging import os import subprocess -from typing import Any, Callable, List, Optional +from typing import Any, Callable from lightning_utilities.core.imports import RequirementCache @@ -70,13 +72,13 @@ def __init__(self, cluster_environment: ClusterEnvironment, num_processes: int, self.cluster_environment = cluster_environment self.num_processes = num_processes self.num_nodes = num_nodes - self.procs: List[subprocess.Popen] = [] # launched subprocesses. does not include the launcher + self.procs: list[subprocess.Popen] = [] # launched subprocesses. does not include the launcher @property def is_interactive_compatible(self) -> bool: return False - def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any: + def launch(self, function: Callable, *args: Any, trainer: pl.Trainer | None = None, **kwargs: Any) -> Any: """Creates new processes, then calls the given function. 
Arguments: @@ -119,7 +121,7 @@ def _call_children_scripts(self) -> None: del env_copy["PL_GLOBAL_SEED"] hydra_in_use = False - cwd: Optional[str] = None + cwd: str | None = None if _HYDRA_AVAILABLE: from hydra.core.hydra_config import HydraConfig diff --git a/src/lightning/pytorch/strategies/launchers/xla.py b/src/lightning/pytorch/strategies/launchers/xla.py index d56a218ae0e64..6ce3e01d3b22e 100644 --- a/src/lightning/pytorch/strategies/launchers/xla.py +++ b/src/lightning/pytorch/strategies/launchers/xla.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os import queue -from typing import Any, Callable, Optional, Union +from typing import Any, Callable import torch.multiprocessing as mp @@ -46,7 +48,7 @@ class _XLALauncher(_MultiProcessingLauncher): strategy: A reference to the strategy that is used together with this launcher """ - def __init__(self, strategy: "pl.strategies.XLAStrategy") -> None: + def __init__(self, strategy: pl.strategies.XLAStrategy) -> None: if not _XLA_AVAILABLE: raise ModuleNotFoundError(str(_XLA_AVAILABLE)) super().__init__(strategy=strategy, start_method="fork") @@ -55,7 +57,7 @@ def __init__(self, strategy: "pl.strategies.XLAStrategy") -> None: def is_interactive_compatible(self) -> bool: return True - def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any: + def launch(self, function: Callable, *args: Any, trainer: pl.Trainer | None = None, **kwargs: Any) -> Any: """Launches processes that run the given function in parallel. The function is allowed to have a return value. 
However, when all processes join, only the return value @@ -72,7 +74,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] using_pjrt = pjrt.using_pjrt() # pjrt requires that the queue is serializable - return_queue: Union[queue.Queue, mp.SimpleQueue] = ( + return_queue: queue.Queue | mp.SimpleQueue = ( mp.Manager().Queue() if using_pjrt else mp.get_context(self._start_method).SimpleQueue() ) @@ -110,12 +112,12 @@ def _wrapping_function( # XLA's multiprocessing returns the global index, not the local index as torch's multiprocessing # https://github.com/pytorch/xla/blob/v1.13.0/torch_xla/distributed/xla_multiprocessing.py#L321 process_idx: int, - trainer: Optional["pl.Trainer"], + trainer: pl.Trainer | None, function: Callable, args: Any, kwargs: Any, - return_queue: Union[mp.SimpleQueue, queue.Queue], - global_states: Optional[_GlobalStateSnapshot] = None, + return_queue: mp.SimpleQueue | queue.Queue, + global_states: _GlobalStateSnapshot | None = None, ) -> None: import torch_xla.core.xla_model as xm from torch_xla.experimental import pjrt @@ -137,7 +139,7 @@ def _wrapping_function( _rank_teardown(self._strategy.local_rank) - def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_WorkerOutput"]: + def _collect_rank_zero_results(self, trainer: pl.Trainer, results: Any) -> _WorkerOutput | None: rank_zero_debug("Collecting results from rank 0 process.") checkpoint_callback = trainer.checkpoint_callback best_model_path = ( diff --git a/src/lightning/pytorch/strategies/parallel.py b/src/lightning/pytorch/strategies/parallel.py index 8f3b991e144bc..f16559ce22fa4 100644 --- a/src/lightning/pytorch/strategies/parallel.py +++ b/src/lightning/pytorch/strategies/parallel.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + from abc import ABC, abstractmethod from contextlib import contextmanager -from typing import Any, Dict, Generator, List, Optional +from typing import Any, Generator import torch from torch import Tensor @@ -31,16 +33,16 @@ class ParallelStrategy(Strategy, ABC): def __init__( self, - accelerator: Optional["pl.accelerators.Accelerator"] = None, - parallel_devices: Optional[List[torch.device]] = None, - cluster_environment: Optional[ClusterEnvironment] = None, - checkpoint_io: Optional[CheckpointIO] = None, - precision_plugin: Optional[PrecisionPlugin] = None, + accelerator: pl.accelerators.Accelerator | None = None, + parallel_devices: list[torch.device] | None = None, + cluster_environment: ClusterEnvironment | None = None, + checkpoint_io: CheckpointIO | None = None, + precision_plugin: PrecisionPlugin | None = None, ): super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin) self.parallel_devices = parallel_devices - self.cluster_environment: Optional[ClusterEnvironment] = cluster_environment - self._layer_sync: Optional[LayerSync] = None + self.cluster_environment: ClusterEnvironment | None = cluster_environment + self._layer_sync: LayerSync | None = None @property @abstractmethod @@ -68,21 +70,21 @@ def is_global_zero(self) -> bool: return self.global_rank == 0 @property - def parallel_devices(self) -> Optional[List[torch.device]]: + def parallel_devices(self) -> list[torch.device] | None: return self._parallel_devices @parallel_devices.setter - def parallel_devices(self, parallel_devices: Optional[List[torch.device]]) -> None: + def parallel_devices(self, parallel_devices: list[torch.device] | None) -> None: self._parallel_devices = parallel_devices @property - def distributed_sampler_kwargs(self) -> Dict[str, Any]: + def distributed_sampler_kwargs(self) -> dict[str, Any]: return { "num_replicas": len(self.parallel_devices) if self.parallel_devices is not None else 0, "rank": self.global_rank, } - def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: + def all_gather(self, tensor: Tensor, group: Any | None = None, sync_grads: bool = False) -> Tensor: """Perform a all_gather on all processes.""" return _all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) diff --git a/src/lightning/pytorch/strategies/single_xla.py b/src/lightning/pytorch/strategies/single_xla.py index c502ba1a2e7aa..36e85a0d5e79c 100644 --- a/src/lightning/pytorch/strategies/single_xla.py +++ b/src/lightning/pytorch/strategies/single_xla.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import os -from typing import Optional import torch @@ -33,9 +34,9 @@ class SingleDeviceXLAStrategy(SingleDeviceStrategy): def __init__( self, device: _DEVICE, - accelerator: Optional["pl.accelerators.Accelerator"] = None, - checkpoint_io: Optional[CheckpointIO] = None, - precision_plugin: Optional[PrecisionPlugin] = None, + accelerator: pl.accelerators.Accelerator | None = None, + checkpoint_io: CheckpointIO | None = None, + precision_plugin: PrecisionPlugin | None = None, debug: bool = False, ): if not _XLA_AVAILABLE: @@ -63,10 +64,10 @@ def checkpoint_io(self) -> CheckpointIO: return self._checkpoint_io @checkpoint_io.setter - def checkpoint_io(self, io: Optional[CheckpointIO]) -> None: + def checkpoint_io(self, io: CheckpointIO | None) -> None: self._checkpoint_io = io - def setup(self, trainer: "pl.Trainer") -> None: + def setup(self, trainer: pl.Trainer) -> None: assert self.model, "self.model must be set before find_shared_parameters(self.model)" shared_params = find_shared_parameters(self.model) self.model_to_device() diff --git a/src/lightning/pytorch/strategies/strategy.py b/src/lightning/pytorch/strategies/strategy.py index 0b700735ea295..2fb7b85f6e9b9 100644 --- a/src/lightning/pytorch/strategies/strategy.py +++ b/src/lightning/pytorch/strategies/strategy.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import contextlib import logging from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Generator, Mapping, TypeVar import torch from torch import Tensor @@ -47,31 +49,31 @@ class Strategy(ABC): def __init__( self, - accelerator: Optional["pl.accelerators.Accelerator"] = None, - checkpoint_io: Optional[CheckpointIO] = None, - precision_plugin: Optional[PrecisionPlugin] = None, + accelerator: pl.accelerators.Accelerator | None = None, + checkpoint_io: CheckpointIO | None = None, + precision_plugin: PrecisionPlugin | None = None, ) -> None: - self._accelerator: Optional["pl.accelerators.Accelerator"] = accelerator - self._checkpoint_io: Optional[CheckpointIO] = checkpoint_io - self._precision_plugin: Optional[PrecisionPlugin] = precision_plugin - self._lightning_module: Optional[pl.LightningModule] = None - self._model: Optional[Module] = None - self._launcher: Optional[_Launcher] = None + self._accelerator: pl.accelerators.Accelerator | None = accelerator + self._checkpoint_io: CheckpointIO | None = checkpoint_io + self._precision_plugin: PrecisionPlugin | None = precision_plugin + self._lightning_module: pl.LightningModule | None = None + self._model: Module | None = None + self._launcher: _Launcher | None = None self._forward_redirection: _ForwardRedirection = _ForwardRedirection() - self._optimizers: List[Optimizer] = [] - self._lightning_optimizers: List[LightningOptimizer] = [] - self.lr_scheduler_configs: List[LRSchedulerConfig] = [] + self._optimizers: list[Optimizer] = [] + self._lightning_optimizers: list[LightningOptimizer] = [] + self.lr_scheduler_configs: list[LRSchedulerConfig] = [] @property - def launcher(self) -> Optional[_Launcher]: + def launcher(self) -> _Launcher | None: return self._launcher @property - def accelerator(self) -> Optional["pl.accelerators.Accelerator"]: + def accelerator(self) -> pl.accelerators.Accelerator 
| None: return self._accelerator @accelerator.setter - def accelerator(self, accelerator: "pl.accelerators.Accelerator") -> None: + def accelerator(self, accelerator: pl.accelerators.Accelerator) -> None: self._accelerator = accelerator @property @@ -84,7 +86,7 @@ def checkpoint_io(self) -> CheckpointIO: return self._checkpoint_io @checkpoint_io.setter - def checkpoint_io(self, io: Optional[CheckpointIO]) -> None: + def checkpoint_io(self, io: CheckpointIO | None) -> None: self._checkpoint_io = io @property @@ -92,19 +94,19 @@ def precision_plugin(self) -> PrecisionPlugin: return self._precision_plugin if self._precision_plugin is not None else PrecisionPlugin() @precision_plugin.setter - def precision_plugin(self, precision_plugin: Optional[PrecisionPlugin]) -> None: + def precision_plugin(self, precision_plugin: PrecisionPlugin | None) -> None: self._precision_plugin = precision_plugin @property - def optimizers(self) -> List[Optimizer]: + def optimizers(self) -> list[Optimizer]: return self._optimizers @optimizers.setter - def optimizers(self, optimizers: List[Optimizer]) -> None: + def optimizers(self, optimizers: list[Optimizer]) -> None: self._optimizers = optimizers self._lightning_optimizers = [LightningOptimizer._to_lightning_optimizer(opt, self) for opt in optimizers] - def connect(self, model: "pl.LightningModule") -> None: + def connect(self, model: pl.LightningModule) -> None: """Called by the accelerator to connect the accelerator and the model with this plugin.""" self._lightning_module = model self.model = model @@ -121,7 +123,7 @@ def setup_environment(self) -> None: assert self.accelerator is not None self.accelerator.setup_device(self.root_device) - def setup_optimizers(self, trainer: "pl.Trainer") -> None: + def setup_optimizers(self, trainer: pl.Trainer) -> None: """Creates optimizers and schedulers. Args: @@ -132,7 +134,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None: assert self.lightning_module is not None self.optimizers, self.lr_scheduler_configs = _init_optimizers_and_lr_schedulers(self.lightning_module) - def setup(self, trainer: "pl.Trainer") -> None: + def setup(self, trainer: pl.Trainer) -> None: """Setup plugins for the trainer fit and creates optimizers. Args: @@ -154,7 +156,7 @@ def setup_precision_plugin(self) -> None: self.optimizers = optimizers self.lr_scheduler_configs = lr_scheduler_configs - def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: + def optimizer_state(self, optimizer: Optimizer) -> dict[str, Tensor]: """Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom plugins. @@ -174,7 +176,7 @@ def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: def backward( self, closure_loss: Tensor, - optimizer: Optional[Optimizer], + optimizer: Optimizer | None, *args: Any, **kwargs: Any, ) -> Tensor: @@ -202,7 +204,7 @@ def optimizer_step( self, optimizer: Optimizer, closure: Callable[[], Any], - model: Optional[Union["pl.LightningModule", Module]] = None, + model: pl.LightningModule | Module | None = None, **kwargs: Any, ) -> Any: r"""Performs the actual optimizer step. 
@@ -218,7 +220,7 @@ def optimizer_step( assert isinstance(model, pl.LightningModule) return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs) - def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]: + def _setup_model_and_optimizers(self, model: Module, optimizers: list[Optimizer]) -> tuple[Module, list[Optimizer]]: """Setup a model and multiple optimizers together. The returned objects are expected to be in the same order they were passed in. The default implementation will @@ -239,7 +241,7 @@ def _setup_optimizer(self, optimizer: Optimizer) -> Optimizer: # TODO: standardize this across all plugins in Lightning and Fabric. Related refactor: #7324 return optimizer - def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any: + def batch_to_device(self, batch: Any, device: torch.device | None = None, dataloader_idx: int = 0) -> Any: """Moves the batch to the correct device. The returned batch is of the same type as the input batch, just @@ -273,10 +275,10 @@ def is_global_zero(self) -> bool: @abstractmethod def reduce( self, - tensor: Union[Tensor, Any], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = "mean", - ) -> Union[Tensor, Any]: + tensor: Tensor | Any, + group: Any | None = None, + reduce_op: ReduceOp | str | None = "mean", + ) -> Tensor | Any: """Reduces the given tensor (e.g. across GPUs/processes). Args: @@ -287,7 +289,7 @@ def reduce( """ @abstractmethod - def barrier(self, name: Optional[str] = None) -> None: + def barrier(self, name: str | None = None) -> None: """Synchronizes all processes which blocks processes until the whole group enters this function. Args: @@ -304,7 +306,7 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: """ @abstractmethod - def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: + def all_gather(self, tensor: Tensor, group: Any | None = None, sync_grads: bool = False) -> Tensor: """Perform an all_gather on all processes. Args: @@ -324,20 +326,20 @@ def post_backward(self, closure_loss: Tensor) -> None: """Run after precision plugin executes backward.""" @property - def model(self) -> Optional[Module]: + def model(self) -> Module | None: """Returns the potentially wrapped LightningModule.""" return self._model if self._model is not None else self._lightning_module @model.setter - def model(self, new_model: Optional[Module]) -> None: + def model(self, new_model: Module | None) -> None: self._model = new_model @property - def lightning_module(self) -> Optional["pl.LightningModule"]: + def lightning_module(self) -> pl.LightningModule | None: """Returns the pure LightningModule without potential wrappers.""" return self._lightning_module - def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: + def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: torch.cuda.empty_cache() return self.checkpoint_io.load_checkpoint(checkpoint_path) @@ -370,7 +372,7 @@ def post_training_step(self) -> None: """ pass - def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: + def validation_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT | None: """The actual validation step. 
See :meth:`~lightning.pytorch.core.module.LightningModule.validation_step` for more details @@ -382,7 +384,7 @@ def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs) return self.lightning_module.validation_step(*args, **kwargs) - def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: + def test_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT | None: """The actual test step. See :meth:`~lightning.pytorch.core.module.LightningModule.test_step` for more details @@ -437,14 +439,12 @@ def handles_gradient_accumulation(self) -> bool: """Whether the plugin handles gradient accumulation internally.""" return False - def lightning_module_state_dict(self) -> Dict[str, Any]: + def lightning_module_state_dict(self) -> dict[str, Any]: """Returns model state.""" assert self.lightning_module is not None return self.lightning_module.state_dict() - def save_checkpoint( - self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None - ) -> None: + def save_checkpoint(self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Any | None = None) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. Args: @@ -533,13 +533,13 @@ def on_exception(self, exception: BaseException) -> None: """Called when the trainer execution is interrupted by an exception.""" pass - def __getstate__(self) -> Dict: + def __getstate__(self) -> dict: # `LightningOptimizer` overrides `self.__class__` so they cannot be pickled state = dict(vars(self)) # copy state["_lightning_optimizers"] = [] return state - def __setstate__(self, state: Dict) -> None: + def __setstate__(self, state: dict) -> None: self.__dict__ = state self.optimizers = self.optimizers # re-create the `_lightning_optimizers` @@ -551,7 +551,7 @@ class _ForwardRedirection: """ def __call__( - self, wrapper_module: Module, original_module: "pl.LightningModule", method_name: str, *args: Any, **kwargs: Any + self, wrapper_module: Module, original_module: pl.LightningModule, method_name: str, *args: Any, **kwargs: Any ) -> STEP_OUTPUT: """Reroutes a method call through the `wrapper_module`'s `forward` method. @@ -585,8 +585,8 @@ def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any: self.on_after_outer_forward(wrapper_module, original_module) return wrapper_output - def on_after_inner_forward(self, wrapper_module: Module, original_module: "pl.LightningModule") -> None: + def on_after_inner_forward(self, wrapper_module: Module, original_module: pl.LightningModule) -> None: pass - def on_after_outer_forward(self, wrapper_module: Module, original_module: "pl.LightningModule") -> None: + def on_after_outer_forward(self, wrapper_module: Module, original_module: pl.LightningModule) -> None: pass diff --git a/src/lightning/pytorch/strategies/xla.py b/src/lightning/pytorch/strategies/xla.py index 5ea1d4bf92fcb..5be2a408410c7 100644 --- a/src/lightning/pytorch/strategies/xla.py +++ b/src/lightning/pytorch/strategies/xla.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import io import os -from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union +from typing import Any, TYPE_CHECKING import torch from torch import Tensor @@ -49,10 +51,10 @@ class XLAStrategy(DDPStrategy): def __init__( self, - accelerator: Optional["pl.accelerators.Accelerator"] = None, - parallel_devices: Optional[List[torch.device]] = None, - checkpoint_io: Optional[CheckpointIO] = None, - precision_plugin: Optional[PrecisionPlugin] = None, + accelerator: pl.accelerators.Accelerator | None = None, + parallel_devices: list[torch.device] | None = None, + checkpoint_io: CheckpointIO | None = None, + precision_plugin: PrecisionPlugin | None = None, debug: bool = False, sync_module_states: bool = True, **_: Any, @@ -67,7 +69,7 @@ def __init__( precision_plugin=precision_plugin, start_method="fork", ) - self._checkpoint_io: Optional[CheckpointIO] + self._checkpoint_io: CheckpointIO | None self.debug = debug self._launched = False self._sync_module_states = sync_module_states @@ -82,7 +84,7 @@ def checkpoint_io(self) -> CheckpointIO: return self._checkpoint_io @checkpoint_io.setter - def checkpoint_io(self, io: Optional[CheckpointIO]) -> None: + def checkpoint_io(self, io: CheckpointIO | None) -> None: self._checkpoint_io = io @property @@ -96,7 +98,7 @@ def root_device(self) -> torch.device: def _configure_launcher(self) -> None: self._launcher = _XLALauncher(self) - def setup(self, trainer: "pl.Trainer") -> None: + def setup(self, trainer: pl.Trainer) -> None: assert self.accelerator self.accelerator.setup(trainer) @@ -123,10 +125,10 @@ def _setup_model(self, model: Module) -> Module: # type: ignore return model @property - def distributed_sampler_kwargs(self) -> Dict[str, int]: + def distributed_sampler_kwargs(self) -> dict[str, int]: return {"num_replicas": self.world_size, "rank": self.global_rank} - def process_dataloader(self, dataloader: object) -> "MpDeviceLoader": + def process_dataloader(self, dataloader: object) -> MpDeviceLoader: from torch_xla.distributed.parallel_loader import MpDeviceLoader if isinstance(dataloader, MpDeviceLoader): @@ -146,7 +148,7 @@ def model_to_device(self) -> None: assert self.model is not None self.model = self.model.to(self.root_device) - def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None: + def barrier(self, name: str | None = None, *args: Any, **kwargs: Any) -> None: if not self._launched: return @@ -187,9 +189,7 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: return obj - def reduce( - self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None - ) -> Tensor: + def reduce(self, output: Tensor | Any, group: Any | None = None, reduce_op: ReduceOp | str | None = None) -> Tensor: if not isinstance(output, Tensor): output = torch.tensor(output, device=self.root_device) @@ -243,7 +243,7 @@ def remove_checkpoint(self, filepath: _PATH) -> None: if self.local_rank == 0: self.checkpoint_io.remove_checkpoint(filepath) - def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: + def all_gather(self, tensor: Tensor, group: Any | None = None, sync_grads: bool = False) -> Tensor: """Function to gather a tensor from several distributed processes. 
Args: diff --git a/src/lightning/pytorch/trainer/call.py b/src/lightning/pytorch/trainer/call.py index 8e5cf8dac976b..72048133ec60a 100644 --- a/src/lightning/pytorch/trainer/call.py +++ b/src/lightning/pytorch/trainer/call.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging from copy import deepcopy -from typing import Any, Callable, Dict, Optional, Type, Union +from typing import Any, Callable from lightning_utilities.core.imports import module_available from packaging.version import Version @@ -27,7 +29,7 @@ log = logging.getLogger(__name__) -def _call_and_handle_interrupt(trainer: "pl.Trainer", trainer_fn: Callable, *args: Any, **kwargs: Any) -> Any: +def _call_and_handle_interrupt(trainer: pl.Trainer, trainer_fn: Callable, *args: Any, **kwargs: Any) -> Any: r"""Error handling, intended to be used only for main trainer function entry points (fit, validate, test, predict) as all errors should funnel through them. @@ -69,7 +71,7 @@ def _call_and_handle_interrupt(trainer: "pl.Trainer", trainer_fn: Callable, *arg raise -def _call_setup_hook(trainer: "pl.Trainer") -> None: +def _call_setup_hook(trainer: pl.Trainer) -> None: assert trainer.state.fn is not None fn = trainer.state.fn @@ -83,7 +85,7 @@ def _call_setup_hook(trainer: "pl.Trainer") -> None: trainer.strategy.barrier("post_setup") -def _call_configure_sharded_model(trainer: "pl.Trainer") -> None: +def _call_configure_sharded_model(trainer: pl.Trainer) -> None: with trainer.strategy.model_sharded_context(): # experimental support for torchdistx if module_available("torchdistx.deferred_init"): @@ -94,7 +96,7 @@ def _call_configure_sharded_model(trainer: "pl.Trainer") -> None: _call_lightning_module_hook(trainer, "configure_sharded_model") -def _call_teardown_hook(trainer: "pl.Trainer") -> None: +def _call_teardown_hook(trainer: pl.Trainer) -> None: assert trainer.state.fn is not None fn = trainer.state.fn @@ -118,10 +120,10 @@ def _call_teardown_hook(trainer: "pl.Trainer") -> None: def _call_lightning_module_hook( - trainer: "pl.Trainer", + trainer: pl.Trainer, hook_name: str, *args: Any, - pl_module: Optional["pl.LightningModule"] = None, + pl_module: pl.LightningModule | None = None, **kwargs: Any, ) -> Any: pl_module = pl_module or trainer.lightning_module @@ -146,7 +148,7 @@ def _call_lightning_module_hook( def _call_lightning_datamodule_hook( - trainer: "pl.Trainer", + trainer: pl.Trainer, hook_name: str, *args: Any, **kwargs: Any, @@ -162,10 +164,10 @@ def _call_lightning_datamodule_hook( def _call_callback_hooks( - trainer: "pl.Trainer", + trainer: pl.Trainer, hook_name: str, *args: Any, - monitoring_callbacks: Optional[bool] = None, + monitoring_callbacks: bool | None = None, **kwargs: Any, ) -> None: log.debug(f"{trainer.__class__.__name__}: calling callback hook: {hook_name}") @@ -193,7 +195,7 @@ def _call_callback_hooks( pl_module._current_fx_name = prev_fx_name -def _call_callbacks_state_dict(trainer: "pl.Trainer") -> Dict[str, dict]: +def _call_callbacks_state_dict(trainer: pl.Trainer) -> dict[str, dict]: """Called when saving a model checkpoint, calls and returns every callback's `state_dict`, keyed by `Callback.state_key`.""" callback_state_dicts = {} @@ -204,7 +206,7 @@ def _call_callbacks_state_dict(trainer: "pl.Trainer") -> Dict[str, dict]: return callback_state_dicts -def _call_callbacks_on_save_checkpoint(trainer: 
"pl.Trainer", checkpoint: Dict[str, Any]) -> None: +def _call_callbacks_on_save_checkpoint(trainer: pl.Trainer, checkpoint: dict[str, Any]) -> None: """Called when saving a model checkpoint, calls every callback's `on_save_checkpoint` hook.""" pl_module = trainer.lightning_module if pl_module: @@ -220,7 +222,7 @@ def _call_callbacks_on_save_checkpoint(trainer: "pl.Trainer", checkpoint: Dict[s pl_module._current_fx_name = prev_fx_name -def _call_callbacks_on_load_checkpoint(trainer: "pl.Trainer", checkpoint: Dict[str, Any]) -> None: +def _call_callbacks_on_load_checkpoint(trainer: pl.Trainer, checkpoint: dict[str, Any]) -> None: """Called when loading a model checkpoint. Calls every callback's `on_load_checkpoint` hook. We have a dedicated function for this rather than using @@ -231,7 +233,7 @@ def _call_callbacks_on_load_checkpoint(trainer: "pl.Trainer", checkpoint: Dict[s prev_fx_name = pl_module._current_fx_name pl_module._current_fx_name = "on_load_checkpoint" - callback_states: Optional[Dict[Union[Type, str], Dict]] = checkpoint.get("callbacks") + callback_states: dict[type | str, dict] | None = checkpoint.get("callbacks") if callback_states is None: return @@ -255,9 +257,9 @@ def _call_callbacks_on_load_checkpoint(trainer: "pl.Trainer", checkpoint: Dict[s pl_module._current_fx_name = prev_fx_name -def _call_callbacks_load_state_dict(trainer: "pl.Trainer", checkpoint: Dict[str, Any]) -> None: +def _call_callbacks_load_state_dict(trainer: pl.Trainer, checkpoint: dict[str, Any]) -> None: """Called when loading a model checkpoint, calls every callback's `load_state_dict`.""" - callback_states: Optional[Dict[Union[Type, str], Dict]] = checkpoint.get("callbacks") + callback_states: dict[type | str, dict] | None = checkpoint.get("callbacks") if callback_states is None: return @@ -270,7 +272,7 @@ def _call_callbacks_load_state_dict(trainer: "pl.Trainer", checkpoint: Dict[str, def _call_strategy_hook( - trainer: "pl.Trainer", + trainer: pl.Trainer, hook_name: str, *args: Any, **kwargs: Any, diff --git a/src/lightning/pytorch/trainer/configuration_validator.py b/src/lightning/pytorch/trainer/configuration_validator.py index 6695708d4894b..3a06932e7d0db 100644 --- a/src/lightning/pytorch/trainer/configuration_validator.py +++ b/src/lightning/pytorch/trainer/configuration_validator.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import lightning.pytorch as pl from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch.trainer.states import TrainerFn @@ -22,7 +24,7 @@ from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature -def _verify_loop_configurations(trainer: "pl.Trainer") -> None: +def _verify_loop_configurations(trainer: pl.Trainer) -> None: r"""Checks that the model is configured correctly before the run is started. 
Args: @@ -46,7 +48,7 @@ def _verify_loop_configurations(trainer: "pl.Trainer") -> None: __verify_batch_transfer_support(trainer) -def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None: +def __verify_train_val_loop_configuration(trainer: pl.Trainer, model: pl.LightningModule) -> None: # verify minimum training requirements has_training_step = is_overridden("training_step", model) if not has_training_step: @@ -89,7 +91,7 @@ def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.Ligh ) -def __verify_eval_loop_configuration(model: "pl.LightningModule", stage: str) -> None: +def __verify_eval_loop_configuration(model: pl.LightningModule, stage: str) -> None: step_name = "validation_step" if stage == "val" else f"{stage}_step" has_step = is_overridden(step_name, model) @@ -116,7 +118,7 @@ def __verify_eval_loop_configuration(model: "pl.LightningModule", stage: str) -> ) -def __verify_batch_transfer_support(trainer: "pl.Trainer") -> None: +def __verify_batch_transfer_support(trainer: pl.Trainer) -> None: batch_transfer_hooks = ("transfer_batch_to_device", "on_after_batch_transfer") datahook_selector = trainer._data_connector._datahook_selector assert datahook_selector is not None @@ -132,7 +134,7 @@ def __verify_batch_transfer_support(trainer: "pl.Trainer") -> None: raise MisconfigurationException(f"Overriding `{hook}` is not supported with IPUs.") -def __verify_manual_optimization_support(trainer: "pl.Trainer", model: "pl.LightningModule") -> None: +def __verify_manual_optimization_support(trainer: pl.Trainer, model: pl.LightningModule) -> None: if model.automatic_optimization: return if trainer.gradient_clip_val is not None and trainer.gradient_clip_val > 0: @@ -149,7 +151,7 @@ def __verify_manual_optimization_support(trainer: "pl.Trainer", model: "pl.Light ) -def __check_training_step_requires_dataloader_iter(model: "pl.LightningModule") -> None: +def __check_training_step_requires_dataloader_iter(model: pl.LightningModule) -> None: """Check if the current `training_step` is requesting `dataloader_iter`.""" if is_param_in_hook_signature(model.training_step, "dataloader_iter", explicit=True): for hook in ("on_train_batch_start", "on_train_batch_end"): diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index 42f46cd75047d..55d24fcd050fc 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import logging import os from collections import Counter -from typing import Dict, List, Literal, Optional, Union +from typing import Literal import torch @@ -77,16 +79,16 @@ class _AcceleratorConnector: def __init__( self, - devices: Union[List[int], str, int] = "auto", + devices: list[int] | str | int = "auto", num_nodes: int = 1, - accelerator: Union[str, Accelerator] = "auto", - strategy: Union[str, Strategy] = "auto", - plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None, + accelerator: str | Accelerator = "auto", + strategy: str | Strategy = "auto", + plugins: PLUGIN_INPUT | list[PLUGIN_INPUT] | None = None, precision: _PRECISION_INPUT = "32-true", sync_batchnorm: bool = False, - benchmark: Optional[bool] = None, + benchmark: bool | None = None, use_distributed_sampler: bool = True, - deterministic: Optional[Union[bool, _LITERAL_WARN]] = None, + deterministic: bool | _LITERAL_WARN | None = None, ) -> None: """The AcceleratorConnector parses several Trainer arguments and instantiates the Strategy including other components such as the Accelerator and Precision plugins. @@ -126,14 +128,14 @@ def __init__( # Raise an exception if there are conflicts between flags # Set each valid flag to `self._x_flag` after validation - self._strategy_flag: Union[Strategy, str] = "auto" - self._accelerator_flag: Union[Accelerator, str] = "auto" + self._strategy_flag: Strategy | str = "auto" + self._accelerator_flag: Accelerator | str = "auto" self._precision_flag: _PRECISION_INPUT_STR = "32-true" - self._precision_plugin_flag: Optional[PrecisionPlugin] = None - self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None - self._parallel_devices: List[Union[int, torch.device, str]] = [] - self._layer_sync: Optional[LayerSync] = TorchSyncBatchNorm() if sync_batchnorm else None - self.checkpoint_io: Optional[CheckpointIO] = None + self._precision_plugin_flag: PrecisionPlugin | None = None + self._cluster_environment_flag: ClusterEnvironment | str | None = None + self._parallel_devices: list[int | torch.device | str] = [] + self._layer_sync: LayerSync | None = TorchSyncBatchNorm() if sync_batchnorm else None + self.checkpoint_io: CheckpointIO | None = None self._check_config_and_set_final_flags( strategy=strategy, @@ -171,10 +173,10 @@ def __init__( def _check_config_and_set_final_flags( self, - strategy: Union[str, Strategy], - accelerator: Union[str, Accelerator], + strategy: str | Strategy, + accelerator: str | Accelerator, precision: _PRECISION_INPUT, - plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]], + plugins: PLUGIN_INPUT | list[PLUGIN_INPUT] | None, sync_batchnorm: bool, ) -> None: """This method checks: @@ -238,7 +240,7 @@ def _check_config_and_set_final_flags( self._precision_flag = _convert_precision_to_unified_args(precision) if plugins: - plugins_flags_types: Dict[str, int] = Counter() + plugins_flags_types: dict[str, int] = Counter() for plugin in plugins: if isinstance(plugin, PrecisionPlugin): self._precision_plugin_flag = plugin @@ -317,7 +319,7 @@ def _check_config_and_set_final_flags( def _check_device_config_and_set_final_flags( self, - devices: Union[List[int], str, int], + devices: list[int] | str | int, num_nodes: int, ) -> None: self._num_nodes_flag = int(num_nodes) if num_nodes is not None else 1 @@ -410,7 +412,7 @@ def _choose_and_init_cluster_environment(self) -> ClusterEnvironment: return BaguaEnvironment() return LightningEnvironment() - def _choose_strategy(self) -> Union[Strategy, str]: + def 
_choose_strategy(self) -> Strategy | str: if self._accelerator_flag == "ipu": if not _LIGHTNING_GRAPHCORE_AVAILABLE: raise ImportError( @@ -653,9 +655,7 @@ def is_distributed(self) -> bool: return False -def _set_torch_flags( - *, deterministic: Optional[Union[bool, _LITERAL_WARN]] = None, benchmark: Optional[bool] = None -) -> None: +def _set_torch_flags(*, deterministic: bool | _LITERAL_WARN | None = None, benchmark: bool | None = None) -> None: if deterministic: if benchmark is None: # Set benchmark to False to ensure determinism diff --git a/src/lightning/pytorch/trainer/connectors/callback_connector.py b/src/lightning/pytorch/trainer/connectors/callback_connector.py index d649755172658..f669738ab3daf 100644 --- a/src/lightning/pytorch/trainer/connectors/callback_connector.py +++ b/src/lightning/pytorch/trainer/connectors/callback_connector.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging import os from datetime import timedelta -from typing import Dict, List, Optional, Sequence, Union +from typing import Sequence import lightning.pytorch as pl from lightning.fabric.utilities.registry import _load_external_callbacks @@ -41,17 +43,17 @@ class _CallbackConnector: - def __init__(self, trainer: "pl.Trainer"): + def __init__(self, trainer: pl.Trainer): self.trainer = trainer def on_trainer_init( self, - callbacks: Optional[Union[List[Callback], Callback]], + callbacks: list[Callback] | Callback | None, enable_checkpointing: bool, enable_progress_bar: bool, - default_root_dir: Optional[str], + default_root_dir: str | None, enable_model_summary: bool, - max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None, + max_time: str | timedelta | dict[str, int] | None = None, ) -> None: # init folder paths for checkpoint + weights save callbacks self.trainer._default_root_dir = default_root_dir or os.getcwd() @@ -139,7 +141,7 @@ def _configure_progress_bar(self, enable_progress_bar: bool = True) -> None: progress_bar_callback = TQDMProgressBar() self.trainer.callbacks.append(progress_bar_callback) - def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None) -> None: + def _configure_timer_callback(self, max_time: str | timedelta | dict[str, int] | None = None) -> None: if max_time is None: return if any(isinstance(cb, Timer) for cb in self.trainer.callbacks): @@ -186,7 +188,7 @@ def _attach_model_callbacks(self) -> None: trainer.callbacks = all_callbacks @staticmethod - def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]: + def _reorder_callbacks(callbacks: list[Callback]) -> list[Callback]: """Moves all the tuner specific callbacks at the beginning of the list and all the `ModelCheckpoint` callbacks to the end of the list. The sequential order within the group of checkpoint callbacks is preserved, as well as the order of all other callbacks. @@ -198,9 +200,9 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]: A new list in which the first elements are tuner specific callbacks and last elements are ModelCheckpoints if there were any present in the input. 
""" - tuner_callbacks: List[Callback] = [] - other_callbacks: List[Callback] = [] - checkpoint_callbacks: List[Callback] = [] + tuner_callbacks: list[Callback] = [] + other_callbacks: list[Callback] = [] + checkpoint_callbacks: list[Callback] = [] for cb in callbacks: if isinstance(cb, (BatchSizeFinder, LearningRateFinder)): @@ -213,7 +215,7 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]: return tuner_callbacks + other_callbacks + checkpoint_callbacks -def _validate_callbacks_list(callbacks: List[Callback]) -> None: +def _validate_callbacks_list(callbacks: list[Callback]) -> None: stateful_callbacks = [cb for cb in callbacks if is_overridden("state_dict", instance=cb)] seen_callbacks = set() for callback in stateful_callbacks: diff --git a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py index a27ebc5c212b1..ee8c3b48bb042 100644 --- a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py +++ b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging import os import re -from typing import Any, Dict, Optional +from typing import Any import torch from fsspec.core import url_to_fs @@ -43,15 +45,15 @@ class _CheckpointConnector: - def __init__(self, trainer: "pl.Trainer") -> None: + def __init__(self, trainer: pl.Trainer) -> None: self.trainer = trainer - self._ckpt_path: Optional[_PATH] = None + self._ckpt_path: _PATH | None = None # flag to know if the user is changing the checkpoint path statefully. See `trainer.ckpt_path.setter` self._user_managed: bool = False - self._loaded_checkpoint: Dict[str, Any] = {} + self._loaded_checkpoint: dict[str, Any] = {} @property - def _hpc_resume_path(self) -> Optional[str]: + def _hpc_resume_path(self) -> str | None: dir_path_hpc = self.trainer.default_root_dir dir_path_hpc = str(dir_path_hpc) fs, path = url_to_fs(dir_path_hpc) @@ -64,7 +66,7 @@ def _hpc_resume_path(self) -> Optional[str]: return dir_path_hpc + fs.sep + f"hpc_ckpt_{max_version}.ckpt" return None - def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None: + def resume_start(self, checkpoint_path: _PATH | None = None) -> None: """Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority: 1. from HPC weights if `checkpoint_path` is ``None`` and on SLURM or passed keyword `"hpc"`. 
@@ -83,8 +85,8 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None: self._loaded_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, checkpoint_path) def _select_ckpt_path( - self, state_fn: TrainerFn, ckpt_path: Optional[_PATH], model_provided: bool, model_connected: bool - ) -> Optional[_PATH]: + self, state_fn: TrainerFn, ckpt_path: _PATH | None, model_provided: bool, model_connected: bool + ) -> _PATH | None: """Called by the ``Trainer`` to select the checkpoint path source.""" if self._user_managed: if ckpt_path: @@ -113,8 +115,8 @@ def _select_ckpt_path( return ckpt_path def _parse_ckpt_path( - self, state_fn: TrainerFn, ckpt_path: Optional[_PATH], model_provided: bool, model_connected: bool - ) -> Optional[_PATH]: + self, state_fn: TrainerFn, ckpt_path: _PATH | None, model_provided: bool, model_connected: bool + ) -> _PATH | None: """Converts the ``ckpt_path`` special values into an actual filepath, depending on the trainer configuration.""" if ckpt_path is None and SLURMEnvironment.detect() and self._hpc_resume_path is not None: @@ -223,7 +225,7 @@ def resume_end(self) -> None: # wait for all to catch up self.trainer.strategy.barrier("_CheckpointConnector.resume_end") - def restore(self, checkpoint_path: Optional[_PATH] = None) -> None: + def restore(self, checkpoint_path: _PATH | None = None) -> None: """Attempt to restore everything at once from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore, in this priority: @@ -390,7 +392,7 @@ def restore_lr_schedulers(self) -> None: for config, lrs_state in zip(self.trainer.lr_scheduler_configs, lr_schedulers): config.scheduler.load_state_dict(lrs_state) - def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None: + def _restore_modules_and_callbacks(self, checkpoint_path: _PATH | None = None) -> None: # restore modules after setup self.resume_start(checkpoint_path) self.restore_model() @@ -487,10 +489,10 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: call._call_lightning_module_hook(trainer, "on_save_checkpoint", checkpoint) return checkpoint - def _get_lightning_module_state_dict(self) -> Dict[str, Tensor]: + def _get_lightning_module_state_dict(self) -> dict[str, Tensor]: return self.trainer.strategy.lightning_module_state_dict() - def _get_loops_state_dict(self) -> Dict[str, Any]: + def _get_loops_state_dict(self) -> dict[str, Any]: return { "fit_loop": self.trainer.fit_loop.state_dict(), "validate_loop": self.trainer.validate_loop.state_dict(), @@ -499,7 +501,7 @@ def _get_loops_state_dict(self) -> Dict[str, Any]: } @staticmethod - def __max_ckpt_version_in_folder(dir_path: _PATH, name_key: str = "ckpt_") -> Optional[int]: + def __max_ckpt_version_in_folder(dir_path: _PATH, name_key: str = "ckpt_") -> int | None: """List up files in `dir_path` with `name_key`, then yield maximum suffix number. Args: diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py index 3152bd9777e83..d90d400a46715 100644 --- a/src/lightning/pytorch/trainer/connectors/data_connector.py +++ b/src/lightning/pytorch/trainer/connectors/data_connector.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import multiprocessing import os from dataclasses import dataclass, field -from typing import Any, Iterable, Optional, Tuple, Union +from typing import Any, Iterable from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler @@ -44,15 +46,15 @@ class _DataConnector: - def __init__(self, trainer: "pl.Trainer"): + def __init__(self, trainer: pl.Trainer): self.trainer = trainer - self._datahook_selector: Optional[_DataHookSelector] = None + self._datahook_selector: _DataHookSelector | None = None def on_trainer_init( self, - val_check_interval: Optional[Union[int, float]], + val_check_interval: int | float | None, reload_dataloaders_every_n_epochs: int, - check_val_every_n_epoch: Optional[int], + check_val_every_n_epoch: int | None, ) -> None: self.trainer.datamodule = None @@ -101,12 +103,12 @@ def prepare_data(self) -> None: def attach_data( self, - model: "pl.LightningModule", - train_dataloaders: Optional[TRAIN_DATALOADERS] = None, - val_dataloaders: Optional[EVAL_DATALOADERS] = None, - test_dataloaders: Optional[EVAL_DATALOADERS] = None, - predict_dataloaders: Optional[EVAL_DATALOADERS] = None, - datamodule: Optional["pl.LightningDataModule"] = None, + model: pl.LightningModule, + train_dataloaders: TRAIN_DATALOADERS | None = None, + val_dataloaders: EVAL_DATALOADERS | None = None, + test_dataloaders: EVAL_DATALOADERS | None = None, + predict_dataloaders: EVAL_DATALOADERS | None = None, + datamodule: pl.LightningDataModule | None = None, ) -> None: # set up the passed in dataloaders (if needed) self.attach_dataloaders( @@ -123,11 +125,11 @@ def attach_data( def attach_dataloaders( self, - model: "pl.LightningModule", - train_dataloaders: Optional[TRAIN_DATALOADERS] = None, - val_dataloaders: Optional[EVAL_DATALOADERS] = None, - test_dataloaders: Optional[EVAL_DATALOADERS] = None, - predict_dataloaders: Optional[EVAL_DATALOADERS] = None, + model: pl.LightningModule, + train_dataloaders: TRAIN_DATALOADERS | None = None, + val_dataloaders: EVAL_DATALOADERS | None = None, + test_dataloaders: EVAL_DATALOADERS | None = None, + predict_dataloaders: EVAL_DATALOADERS | None = None, ) -> None: trainer = self.trainer @@ -145,9 +147,7 @@ def attach_dataloaders( trainer.test_loop._data_source.instance = test_dataloaders if test_dataloaders is not None else model trainer.predict_loop._data_source.instance = predict_dataloaders if predict_dataloaders is not None else model - def attach_datamodule( - self, model: "pl.LightningModule", datamodule: Optional["pl.LightningDataModule"] = None - ) -> None: + def attach_datamodule(self, model: pl.LightningModule, datamodule: pl.LightningDataModule | None = None) -> None: # If we have a datamodule, attach necessary hooks + dataloaders self._datahook_selector = _DataHookSelector(model, datamodule) @@ -208,8 +208,8 @@ def _prepare_dataloader(self, dataloader: object, shuffle: bool, mode: RunningSt return dataloader def _resolve_sampler( - self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None - ) -> Union[Sampler, Iterable]: + self, dataloader: DataLoader, shuffle: bool, mode: RunningStage | None = None + ) -> Sampler | Iterable: if self._requires_distributed_sampler(dataloader): distributed_sampler_kwargs = self.trainer.distributed_sampler_kwargs assert distributed_sampler_kwargs is not None @@ -244,8 +244,8 @@ def _resolve_sampler( def _get_distributed_sampler( dataloader: DataLoader, shuffle: bool, - 
overfit_batches: Union[int, float], - mode: Optional[RunningStage] = None, + overfit_batches: int | float, + mode: RunningStage | None = None, **kwargs: Any, ) -> DistributedSampler: """This function is used to created the distributed sampler injected within the user DataLoader.""" @@ -291,10 +291,10 @@ class _DataLoaderSource: that returns the desired dataloader(s). """ - instance: Optional[Union[TRAIN_DATALOADERS, EVAL_DATALOADERS, "pl.LightningModule", "pl.LightningDataModule"]] + instance: TRAIN_DATALOADERS | EVAL_DATALOADERS | pl.LightningModule | pl.LightningDataModule | None name: str - def dataloader(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]: + def dataloader(self) -> TRAIN_DATALOADERS | EVAL_DATALOADERS: """Returns the dataloader from the source. If the source is a module, the method with the corresponding :attr:`name` gets called. @@ -322,7 +322,7 @@ def is_module(self) -> bool: return isinstance(self.instance, (pl.LightningModule, pl.LightningDataModule)) -def _request_dataloader(data_source: _DataLoaderSource) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]: +def _request_dataloader(data_source: _DataLoaderSource) -> TRAIN_DATALOADERS | EVAL_DATALOADERS: """Requests a dataloader by calling dataloader hooks corresponding to the given stage. Returns: @@ -350,13 +350,13 @@ class _DataHookSelector: datamodule: A ``LightningDataModule`` """ - model: "pl.LightningModule" - datamodule: Optional["pl.LightningDataModule"] - _valid_hooks: Tuple[str, ...] = field( + model: pl.LightningModule + datamodule: pl.LightningDataModule | None + _valid_hooks: tuple[str, ...] = field( default=("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer") ) - def get_instance(self, hook_name: str) -> Union["pl.LightningModule", "pl.LightningDataModule"]: + def get_instance(self, hook_name: str) -> pl.LightningModule | pl.LightningDataModule: if hook_name not in self._valid_hooks: raise ValueError( f"`{hook_name}` is not a shared hook within `LightningModule` and `LightningDataModule`." @@ -444,9 +444,7 @@ def _worker_check(dataloader: object, using_spawn: bool, name: str) -> None: ) -def _parse_num_batches( - stage: RunningStage, length: Union[int, float], limit_batches: Union[int, float] -) -> Union[int, float]: +def _parse_num_batches(stage: RunningStage, length: int | float, limit_batches: int | float) -> int | float: if length == 0: return int(length) @@ -473,9 +471,7 @@ def _parse_num_batches( return num_batches -def _process_dataloader( - trainer: "pl.Trainer", trainer_fn: TrainerFn, stage: RunningStage, dataloader: object -) -> object: +def _process_dataloader(trainer: pl.Trainer, trainer_fn: TrainerFn, stage: RunningStage, dataloader: object) -> object: if stage != RunningStage.TRAINING: is_shuffled = _is_dataloader_shuffled(dataloader) # limit this warning only for samplers assigned automatically when shuffle is set diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py b/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py index a889b192795de..1e223c89cb13c 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional, Tuple, Union +from __future__ import annotations from typing_extensions import TypedDict @@ -20,8 +20,8 @@ class _FxValidator: class _LogOptions(TypedDict): - allowed_on_step: Union[Tuple[bool], Tuple[bool, bool]] - allowed_on_epoch: Union[Tuple[bool], Tuple[bool, bool]] + allowed_on_step: tuple[bool] | tuple[bool, bool] + allowed_on_epoch: tuple[bool] | tuple[bool, bool] default_on_step: bool default_on_epoch: bool @@ -162,9 +162,7 @@ def check_logging(cls, fx_name: str) -> None: ) @classmethod - def get_default_logging_levels( - cls, fx_name: str, on_step: Optional[bool], on_epoch: Optional[bool] - ) -> Tuple[bool, bool]: + def get_default_logging_levels(cls, fx_name: str, on_step: bool | None, on_epoch: bool | None) -> tuple[bool, bool]: """Return default logging levels for given hook.""" fx_config = cls.functions[fx_name] assert fx_config is not None @@ -188,8 +186,8 @@ def check_logging_levels(cls, fx_name: str, on_step: bool, on_epoch: bool) -> No @classmethod def check_logging_and_get_default_levels( - cls, fx_name: str, on_step: Optional[bool], on_epoch: Optional[bool] - ) -> Tuple[bool, bool]: + cls, fx_name: str, on_step: bool | None, on_epoch: bool | None + ) -> tuple[bool, bool]: """Check if the given hook name is allowed to log and return logging levels.""" cls.check_logging(fx_name) on_step, on_epoch = cls.get_default_logging_levels(fx_name, on_step, on_epoch) diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py b/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py index 773257f119f15..0bb17b751aeeb 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Iterable, Optional, Union +from __future__ import annotations + +from typing import Any, Iterable from lightning_utilities.core.apply_func import apply_to_collection from lightning_utilities.core.rank_zero import WarningCache @@ -29,18 +31,18 @@ class _LoggerConnector: - def __init__(self, trainer: "pl.Trainer") -> None: + def __init__(self, trainer: pl.Trainer) -> None: self.trainer = trainer self._progress_bar_metrics: _PBAR_DICT = {} self._logged_metrics: _OUT_DICT = {} self._callback_metrics: _OUT_DICT = {} self._epoch_end_reached = False - self._current_fx: Optional[str] = None - self._batch_idx: Optional[int] = None + self._current_fx: str | None = None + self._batch_idx: int | None = None def on_trainer_init( self, - logger: Union[bool, Logger, Iterable[Logger]], + logger: bool | Logger | Iterable[Logger], log_every_n_steps: int, ) -> None: self.configure_logger(logger) @@ -55,7 +57,7 @@ def should_update_logs(self) -> bool: should_log = (trainer.fit_loop.epoch_loop._batches_that_stepped + 1) % trainer.log_every_n_steps == 0 return should_log or trainer.should_stop - def configure_logger(self, logger: Union[bool, Logger, Iterable[Logger]]) -> None: + def configure_logger(self, logger: bool | Logger | Iterable[Logger]) -> None: if not logger: # logger is None or logger is False self.trainer.loggers = [] @@ -78,7 +80,7 @@ def configure_logger(self, logger: Union[bool, Logger, Iterable[Logger]]) -> Non else: self.trainer.loggers = [logger] - def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None: + def log_metrics(self, metrics: _OUT_DICT, step: int | None = None) -> None: """Logs the metric dict passed in. If `step` parameter is None and `step` key is presented is metrics, uses metrics["step"] as a step. @@ -168,7 +170,7 @@ def update_train_epoch_metrics(self) -> None: def on_epoch_start(self) -> None: self._epoch_end_reached = False - def on_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> None: + def on_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int | None = None) -> None: self._batch_idx = batch_idx self._epoch_end_reached = False diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py index 06ce4d021d9e1..fc1b84419e356 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + from dataclasses import dataclass from functools import partial, wraps -from typing import Any, Callable, cast, Dict, Generator, List, Optional, Tuple, Union +from typing import Any, Callable, cast, Dict, Generator, Union import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -46,11 +48,11 @@ class _METRICS(TypedDict): @dataclass class _Sync: - fn: Optional[Callable] = None + fn: Callable | None = None _should: bool = False rank_zero_only: bool = False - _op: Optional[str] = None - _group: Optional[Any] = None + _op: str | None = None + _group: Any | None = None def __post_init__(self) -> None: self._generate_sync_fn() @@ -66,21 +68,21 @@ def should(self, should: bool) -> None: self._generate_sync_fn() @property - def op(self) -> Optional[str]: + def op(self) -> str | None: return self._op @op.setter - def op(self, op: Optional[str]) -> None: + def op(self, op: str | None) -> None: self._op = op # `self._fn` needs to be re-generated. self._generate_sync_fn() @property - def group(self) -> Optional[Any]: + def group(self) -> Any | None: return self._group @group.setter - def group(self, group: Optional[Any]) -> None: + def group(self, group: Any | None) -> None: self._group = group # `self._fn` needs to be re-generated. self._generate_sync_fn() @@ -113,9 +115,9 @@ class _Metadata: reduce_fx: Callable = "mean" if _TORCH_EQUAL_2_0 else torch.mean # type: ignore[assignment] enable_graph: bool = False add_dataloader_idx: bool = True - dataloader_idx: Optional[int] = None - metric_attribute: Optional[str] = None - _sync: Optional[_Sync] = None + dataloader_idx: int | None = None + metric_attribute: str | None = None + _sync: _Sync | None = None def __post_init__(self) -> None: if not self.on_step and not self.on_epoch: @@ -200,7 +202,7 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None: self.cumulated_batch_size: Tensor self.add_state("cumulated_batch_size", torch.tensor(0), dist_reduce_fx=torch.sum) # this is defined here only because upstream is missing the type annotation - self._forward_cache: Optional[Any] = None + self._forward_cache: Any | None = None def update(self, value: _VALUE, batch_size: int) -> None: if self.is_tensor: @@ -262,7 +264,7 @@ def forward(self, value: _VALUE, batch_size: int) -> None: def _wrap_compute(self, compute: Any) -> Any: # Override to avoid syncing - we handle it ourselves. 
@wraps(compute) - def wrapped_func(*args: Any, **kwargs: Any) -> Optional[Any]: + def wrapped_func(*args: Any, **kwargs: Any) -> Any | None: if not self._update_called: rank_zero_warn( f"The ``compute`` method of metric {self.__class__.__name__}" @@ -288,7 +290,7 @@ def __repr__(self) -> str: state += f", cumulated_batch_size={self.cumulated_batch_size}" return f"{self.__class__.__name__}({state})" - def to(self, *args: Any, **kwargs: Any) -> "_ResultMetric": + def to(self, *args: Any, **kwargs: Any) -> _ResultMetric: d = self.__dict__ if _TORCH_GREATER_EQUAL_2_0: # https://github.com/pytorch/pytorch/issues/96198 d = dict(d) @@ -316,15 +318,15 @@ class _ResultCollection(dict): def __init__(self, training: bool) -> None: super().__init__() self.training = training - self.batch: Optional[Any] = None - self.batch_size: Optional[int] = None - self.dataloader_idx: Optional[int] = None + self.batch: Any | None = None + self.batch_size: int | None = None + self.dataloader_idx: int | None = None @property - def result_metrics(self) -> List[_ResultMetric]: + def result_metrics(self) -> list[_ResultMetric]: return list(self.values()) - def _extract_batch_size(self, value: _ResultMetric, batch_size: Optional[int], meta: _Metadata) -> int: + def _extract_batch_size(self, value: _ResultMetric, batch_size: int | None, meta: _Metadata) -> int: # check if we have extracted the batch size already if batch_size is None: batch_size = self.batch_size @@ -353,10 +355,10 @@ def log( enable_graph: bool = False, sync_dist: bool = False, sync_dist_fn: Callable = _Sync.no_op, - sync_dist_group: Optional[Any] = None, + sync_dist_group: Any | None = None, add_dataloader_idx: bool = True, - batch_size: Optional[int] = None, - metric_attribute: Optional[str] = None, + batch_size: int | None = None, + metric_attribute: str | None = None, rank_zero_only: bool = False, ) -> None: """See :meth:`~lightning.pytorch.core.module.LightningModule.log`""" @@ -414,7 +416,7 @@ def update_metrics(self, key: str, value: _VALUE, batch_size: int) -> None: result_metric.has_reset = False @staticmethod - def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[Tensor]: + def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Tensor | None: cache = None if on_step and result_metric.meta.on_step: cache = result_metric._forward_cache @@ -448,7 +450,7 @@ def valid_items(self) -> Generator: """This function is used to iterate over current valid metrics.""" return ((k, v) for k, v in self.items() if not v.has_reset and self.dataloader_idx == v.meta.dataloader_idx) - def _forked_name(self, result_metric: _ResultMetric, on_step: bool) -> Tuple[str, str]: + def _forked_name(self, result_metric: _ResultMetric, on_step: bool) -> tuple[str, str]: name = result_metric.meta.name forked_name = result_metric.meta.forked_name(on_step) add_dataloader_idx = result_metric.meta.add_dataloader_idx @@ -485,7 +487,7 @@ def metrics(self, on_step: bool) -> _METRICS: return metrics - def reset(self, metrics: Optional[bool] = None, fx: Optional[str] = None) -> None: + def reset(self, metrics: bool | None = None, fx: str | None = None) -> None: """Reset the result collection. 
Args: @@ -500,12 +502,12 @@ def reset(self, metrics: Optional[bool] = None, fx: Optional[str] = None) -> Non if requested_type and same_fx: item.reset() - def to(self, *args: Any, **kwargs: Any) -> "_ResultCollection": + def to(self, *args: Any, **kwargs: Any) -> _ResultCollection: """Move all data to the given device.""" self.update(apply_to_collection(dict(self), (Tensor, Metric), move_data_to_device, *args, **kwargs)) return self - def cpu(self) -> "_ResultCollection": + def cpu(self) -> _ResultCollection: """Move all data to CPU.""" return self.to(device="cpu") diff --git a/src/lightning/pytorch/trainer/connectors/signal_connector.py b/src/lightning/pytorch/trainer/connectors/signal_connector.py index 7e6b7cd0c5e91..a481a29d83d41 100644 --- a/src/lightning/pytorch/trainer/connectors/signal_connector.py +++ b/src/lightning/pytorch/trainer/connectors/signal_connector.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import os import signal @@ -5,7 +7,7 @@ import threading from subprocess import call from types import FrameType -from typing import Any, Callable, Dict, List, Set, Union +from typing import Any, Callable, Union from lightning_utilities.core.rank_zero import rank_prefixed_message @@ -22,7 +24,7 @@ class _HandlersCompose: - def __init__(self, signal_handlers: Union[List[_HANDLER], _HANDLER]) -> None: + def __init__(self, signal_handlers: list[_HANDLER] | _HANDLER) -> None: if not isinstance(signal_handlers, list): signal_handlers = [signal_handlers] self.signal_handlers = signal_handlers @@ -36,17 +38,17 @@ def __call__(self, signum: _SIGNUM, frame: FrameType) -> None: class _SignalConnector: - def __init__(self, trainer: "pl.Trainer") -> None: + def __init__(self, trainer: pl.Trainer) -> None: self.received_sigterm = False self.trainer = trainer - self._original_handlers: Dict[_SIGNUM, _HANDLER] = {} + self._original_handlers: dict[_SIGNUM, _HANDLER] = {} def register_signal_handlers(self) -> None: self.received_sigterm = False self._original_handlers = self._get_current_signal_handlers() - sigusr_handlers: List[_HANDLER] = [] - sigterm_handlers: List[_HANDLER] = [self._sigterm_notifier_fn] + sigusr_handlers: list[_HANDLER] = [] + sigterm_handlers: list[_HANDLER] = [self._sigterm_notifier_fn] environment = self.trainer._accelerator_connector.cluster_environment if isinstance(environment, SLURMEnvironment) and environment.auto_requeue: @@ -125,7 +127,7 @@ def teardown(self) -> None: self._original_handlers = {} @staticmethod - def _get_current_signal_handlers() -> Dict[_SIGNUM, _HANDLER]: + def _get_current_signal_handlers() -> dict[_SIGNUM, _HANDLER]: """Collects the currently assigned signal handlers.""" valid_signals = _SignalConnector._valid_signals() if not _IS_WINDOWS: @@ -134,7 +136,7 @@ def _get_current_signal_handlers() -> Dict[_SIGNUM, _HANDLER]: return {signum: signal.getsignal(signum) for signum in valid_signals} @staticmethod - def _valid_signals() -> Set[signal.Signals]: + def _valid_signals() -> set[signal.Signals]: """Returns all valid signals supported on the current platform. 
Behaves identically to :func:`signals.valid_signals` in Python 3.8+ and implements the equivalent behavior for @@ -168,7 +170,7 @@ def _register_signal(signum: _SIGNUM, handlers: _HANDLER) -> None: if threading.current_thread() is threading.main_thread(): signal.signal(signum, handlers) # type: ignore[arg-type] - def __getstate__(self) -> Dict: + def __getstate__(self) -> dict: state = self.__dict__.copy() state["_original_handlers"] = {} return state diff --git a/src/lightning/pytorch/trainer/setup.py b/src/lightning/pytorch/trainer/setup.py index 36f9c27e70983..9619f8f712595 100644 --- a/src/lightning/pytorch/trainer/setup.py +++ b/src/lightning/pytorch/trainer/setup.py @@ -13,7 +13,7 @@ # limitations under the License. """Houses the methods used to set up the Trainer.""" -from typing import Optional, Union +from __future__ import annotations import lightning.pytorch as pl from lightning.fabric.utilities.warnings import PossibleUserWarning @@ -33,14 +33,14 @@ def _init_debugging_flags( - trainer: "pl.Trainer", - limit_train_batches: Optional[Union[int, float]], - limit_val_batches: Optional[Union[int, float]], - limit_test_batches: Optional[Union[int, float]], - limit_predict_batches: Optional[Union[int, float]], - fast_dev_run: Union[int, bool], - overfit_batches: Union[int, float], - val_check_interval: Optional[Union[int, float]], + trainer: pl.Trainer, + limit_train_batches: int | float | None, + limit_val_batches: int | float | None, + limit_test_batches: int | float | None, + limit_predict_batches: int | float | None, + fast_dev_run: int | bool, + overfit_batches: int | float, + val_check_interval: int | float | None, num_sanity_val_steps: int, ) -> None: # init debugging flags @@ -89,7 +89,7 @@ def _init_debugging_flags( trainer.limit_val_batches = overfit_batches -def _determine_batch_limits(batches: Optional[Union[int, float]], name: str) -> Union[int, float]: +def _determine_batch_limits(batches: int | float | None, name: str) -> int | float: if batches is None: # batches is optional to know if the user passed a value so that we can show the above info messages only to the # users that set a value explicitly @@ -122,7 +122,7 @@ def _determine_batch_limits(batches: Optional[Union[int, float]], name: str) -> ) -def _init_profiler(trainer: "pl.Trainer", profiler: Optional[Union[Profiler, str]]) -> None: +def _init_profiler(trainer: pl.Trainer, profiler: Profiler | str | None) -> None: if isinstance(profiler, str): PROFILERS = { "simple": SimpleProfiler, @@ -141,7 +141,7 @@ def _init_profiler(trainer: "pl.Trainer", profiler: Optional[Union[Profiler, str trainer.profiler = profiler or PassThroughProfiler() -def _log_device_info(trainer: "pl.Trainer") -> None: +def _log_device_info(trainer: pl.Trainer) -> None: if CUDAAccelerator.is_available(): gpu_available = True gpu_type = " (cuda)" diff --git a/src/lightning/pytorch/trainer/states.py b/src/lightning/pytorch/trainer/states.py index 73b7cb71dcf82..c2096744c17a7 100644 --- a/src/lightning/pytorch/trainer/states.py +++ b/src/lightning/pytorch/trainer/states.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + from dataclasses import dataclass -from typing import Optional from lightning.pytorch.utilities import LightningEnum @@ -66,7 +67,7 @@ def evaluating(self) -> bool: return self in (self.VALIDATING, self.TESTING, self.SANITY_CHECKING) @property - def dataloader_prefix(self) -> Optional[str]: + def dataloader_prefix(self) -> str | None: if self in (self.VALIDATING, self.SANITY_CHECKING): return "val" return self.value @@ -77,8 +78,8 @@ class TrainerState: """Dataclass to encapsulate the current :class:`~lightning.pytorch.trainer.trainer.Trainer` state.""" status: TrainerStatus = TrainerStatus.INITIALIZING - fn: Optional[TrainerFn] = None - stage: Optional[RunningStage] = None + fn: TrainerFn | None = None + stage: RunningStage | None = None @property def finished(self) -> bool: diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 74545a15acf44..bf44196629724 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -19,12 +19,14 @@ # DO NOT REMOVE THIS NOTICE # - WILLIAM FALCON """Trainer to automate the training.""" +from __future__ import annotations + import logging import math import os import warnings from datetime import timedelta -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import Any, Iterable from weakref import proxy import torch @@ -89,45 +91,45 @@ class Trainer: def __init__( self, *, - accelerator: Union[str, Accelerator] = "auto", - strategy: Union[str, Strategy] = "auto", - devices: Union[List[int], str, int] = "auto", + accelerator: str | Accelerator = "auto", + strategy: str | Strategy = "auto", + devices: list[int] | str | int = "auto", num_nodes: int = 1, precision: _PRECISION_INPUT = "32-true", - logger: Optional[Union[Logger, Iterable[Logger], bool]] = None, - callbacks: Optional[Union[List[Callback], Callback]] = None, - fast_dev_run: Union[int, bool] = False, - max_epochs: Optional[int] = None, - min_epochs: Optional[int] = None, + logger: Logger | Iterable[Logger] | bool | None = None, + callbacks: list[Callback] | Callback | None = None, + fast_dev_run: int | bool = False, + max_epochs: int | None = None, + min_epochs: int | None = None, max_steps: int = -1, - min_steps: Optional[int] = None, - max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None, - limit_train_batches: Optional[Union[int, float]] = None, - limit_val_batches: Optional[Union[int, float]] = None, - limit_test_batches: Optional[Union[int, float]] = None, - limit_predict_batches: Optional[Union[int, float]] = None, - overfit_batches: Union[int, float] = 0.0, - val_check_interval: Optional[Union[int, float]] = None, - check_val_every_n_epoch: Optional[int] = 1, - num_sanity_val_steps: Optional[int] = None, - log_every_n_steps: Optional[int] = None, - enable_checkpointing: Optional[bool] = None, - enable_progress_bar: Optional[bool] = None, - enable_model_summary: Optional[bool] = None, + min_steps: int | None = None, + max_time: str | timedelta | dict[str, int] | None = None, + limit_train_batches: int | float | None = None, + limit_val_batches: int | float | None = None, + limit_test_batches: int | float | None = None, + limit_predict_batches: int | float | None = None, + overfit_batches: int | float = 0.0, + val_check_interval: int | float | None = None, + check_val_every_n_epoch: int | None = 1, + num_sanity_val_steps: int | None = None, + log_every_n_steps: int | None = None, + enable_checkpointing: bool | None = None, + 
enable_progress_bar: bool | None = None, + enable_model_summary: bool | None = None, accumulate_grad_batches: int = 1, - gradient_clip_val: Optional[Union[int, float]] = None, - gradient_clip_algorithm: Optional[str] = None, - deterministic: Optional[Union[bool, _LITERAL_WARN]] = None, - benchmark: Optional[bool] = None, + gradient_clip_val: int | float | None = None, + gradient_clip_algorithm: str | None = None, + deterministic: bool | _LITERAL_WARN | None = None, + benchmark: bool | None = None, inference_mode: bool = True, use_distributed_sampler: bool = True, - profiler: Optional[Union[Profiler, str]] = None, + profiler: Profiler | str | None = None, detect_anomaly: bool = False, barebones: bool = False, - plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None, + plugins: PLUGIN_INPUT | list[PLUGIN_INPUT] | None = None, sync_batchnorm: bool = False, reload_dataloaders_every_n_epochs: int = 0, - default_root_dir: Optional[_PATH] = None, + default_root_dir: _PATH | None = None, ) -> None: r"""Customize every aspect of training via flags. @@ -435,7 +437,7 @@ def __init__( ) # init data flags - self.check_val_every_n_epoch: Optional[int] + self.check_val_every_n_epoch: int | None self._data_connector.on_trainer_init( val_check_interval, reload_dataloaders_every_n_epochs, @@ -454,8 +456,8 @@ def __init__( f"Allowed algorithms: {GradClipAlgorithmType.supported_types()}." ) - self.gradient_clip_val: Optional[Union[int, float]] = gradient_clip_val - self.gradient_clip_algorithm: Optional[GradClipAlgorithmType] = ( + self.gradient_clip_val: int | float | None = gradient_clip_val + self.gradient_clip_algorithm: GradClipAlgorithmType | None = ( GradClipAlgorithmType(gradient_clip_algorithm.lower()) if gradient_clip_algorithm is not None else None ) @@ -475,17 +477,17 @@ def __init__( setup._init_profiler(self, profiler) # init logger flags - self._loggers: List[Logger] + self._loggers: list[Logger] self._logger_connector.on_trainer_init(logger, log_every_n_steps) # init debugging flags - self.val_check_batch: Union[int, float] - self.val_check_interval: Union[int, float] - self.num_sanity_val_steps: Union[int, float] - self.limit_train_batches: Union[int, float] - self.limit_val_batches: Union[int, float] - self.limit_test_batches: Union[int, float] - self.limit_predict_batches: Union[int, float] + self.val_check_batch: int | float + self.val_check_interval: int | float + self.num_sanity_val_steps: int | float + self.limit_train_batches: int | float + self.limit_val_batches: int | float + self.limit_test_batches: int | float + self.limit_predict_batches: int | float setup._init_debugging_flags( self, limit_train_batches, @@ -500,11 +502,11 @@ def __init__( def fit( self, - model: "pl.LightningModule", - train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, - val_dataloaders: Optional[EVAL_DATALOADERS] = None, - datamodule: Optional[LightningDataModule] = None, - ckpt_path: Optional[str] = None, + model: pl.LightningModule, + train_dataloaders: TRAIN_DATALOADERS | LightningDataModule | None = None, + val_dataloaders: EVAL_DATALOADERS | None = None, + datamodule: LightningDataModule | None = None, + ckpt_path: str | None = None, ) -> None: r"""Runs the full optimization routine. 
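The hunk above is representative of the change applied throughout this PR: `Optional[...]`/`Union[...]` spellings become PEP 604 unions (`int | None`) and PEP 585 builtin generics (`list[int]`), which only evaluate at runtime on newer interpreters (3.10 for `X | Y`, 3.9 for `list[int]`) unless annotation evaluation is postponed. A minimal standalone sketch of the pattern, not taken from the patch, using a hypothetical `fit` function:

from __future__ import annotations  # PEP 563: annotations are stored as strings, never evaluated at definition time


def fit(max_epochs: int | None = None, devices: list[int] | str | int = "auto") -> None:
    """Hypothetical signature using the new-style hints; importable on Python 3.7+."""


# The hints stay plain strings, so `int | None` never executes on import even on old interpreters.
assert isinstance(fit.__annotations__["max_epochs"], str)
print(fit.__annotations__)  # e.g. {'max_epochs': 'int | None', 'devices': 'list[int] | str | int', 'return': 'None'}

Anything that needs real type objects at runtime has to resolve these strings itself, for example via typing.get_type_hints on an interpreter that supports the syntax.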
@@ -541,11 +543,11 @@ def fit( def _fit_impl( self, - model: "pl.LightningModule", - train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None, - val_dataloaders: Optional[EVAL_DATALOADERS] = None, - datamodule: Optional[LightningDataModule] = None, - ckpt_path: Optional[str] = None, + model: pl.LightningModule, + train_dataloaders: TRAIN_DATALOADERS | LightningDataModule | None = None, + val_dataloaders: EVAL_DATALOADERS | None = None, + datamodule: LightningDataModule | None = None, + ckpt_path: str | None = None, ) -> None: log.debug(f"{self.__class__.__name__}: trainer fit stage") @@ -582,11 +584,11 @@ def _fit_impl( def validate( self, - model: Optional["pl.LightningModule"] = None, - dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, - ckpt_path: Optional[str] = None, + model: pl.LightningModule | None = None, + dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None, + ckpt_path: str | None = None, verbose: bool = True, - datamodule: Optional[LightningDataModule] = None, + datamodule: LightningDataModule | None = None, ) -> _EVALUATE_OUTPUT: r"""Perform one evaluation epoch over the validation set. @@ -640,12 +642,12 @@ def validate( def _validate_impl( self, - model: Optional["pl.LightningModule"] = None, - dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, - ckpt_path: Optional[str] = None, + model: pl.LightningModule | None = None, + dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None, + ckpt_path: str | None = None, verbose: bool = True, - datamodule: Optional[LightningDataModule] = None, - ) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]: + datamodule: LightningDataModule | None = None, + ) -> _PREDICT_OUTPUT | _EVALUATE_OUTPUT | None: # -------------------- # SETUP HOOK # -------------------- @@ -688,11 +690,11 @@ def _validate_impl( def test( self, - model: Optional["pl.LightningModule"] = None, - dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, - ckpt_path: Optional[str] = None, + model: pl.LightningModule | None = None, + dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None, + ckpt_path: str | None = None, verbose: bool = True, - datamodule: Optional[LightningDataModule] = None, + datamodule: LightningDataModule | None = None, ) -> _EVALUATE_OUTPUT: r"""Perform one evaluation epoch over the test set. It's separated from fit to make sure you never run on your test set until you want to. 
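The signatures converted above are the public entry points; the following is a hedged usage sketch, not part of the patch, with a deliberately tiny, hypothetical `TinyModule` standing in for a real LightningModule:

import torch
from torch.utils.data import DataLoader, TensorDataset

import lightning.pytorch as pl


class TinyModule(pl.LightningModule):
    """Hypothetical minimal module, only here to exercise fit/validate."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log("val_loss", torch.nn.functional.mse_loss(self.layer(x), y))

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


def loader() -> DataLoader:
    x = torch.randn(64, 8)
    return DataLoader(TensorDataset(x, x.sum(dim=1, keepdim=True)), batch_size=16)


trainer = pl.Trainer(max_epochs=1, accelerator="cpu", devices=1, logger=False, enable_checkpointing=False)
trainer.fit(TinyModule(), train_dataloaders=loader(), val_dataloaders=loader())
trainer.validate(TinyModule(), dataloaders=loader())  # dataloaders could instead be a LightningDataModule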
@@ -747,12 +749,12 @@ def test( def _test_impl( self, - model: Optional["pl.LightningModule"] = None, - dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, - ckpt_path: Optional[str] = None, + model: pl.LightningModule | None = None, + dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None, + ckpt_path: str | None = None, verbose: bool = True, - datamodule: Optional[LightningDataModule] = None, - ) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]: + datamodule: LightningDataModule | None = None, + ) -> _PREDICT_OUTPUT | _EVALUATE_OUTPUT | None: # -------------------- # SETUP HOOK # -------------------- @@ -795,12 +797,12 @@ def _test_impl( def predict( self, - model: Optional["pl.LightningModule"] = None, - dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, - datamodule: Optional[LightningDataModule] = None, - return_predictions: Optional[bool] = None, - ckpt_path: Optional[str] = None, - ) -> Optional[_PREDICT_OUTPUT]: + model: pl.LightningModule | None = None, + dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None, + datamodule: LightningDataModule | None = None, + return_predictions: bool | None = None, + ckpt_path: str | None = None, + ) -> _PREDICT_OUTPUT | None: r"""Run inference on your data. This will call the model forward function to compute predictions. Useful to perform distributed and batched predictions. Logging is disabled in the predict hooks. @@ -855,12 +857,12 @@ def predict( def _predict_impl( self, - model: Optional["pl.LightningModule"] = None, - dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None, - datamodule: Optional[LightningDataModule] = None, - return_predictions: Optional[bool] = None, - ckpt_path: Optional[str] = None, - ) -> Optional[_PREDICT_OUTPUT]: + model: pl.LightningModule | None = None, + dataloaders: EVAL_DATALOADERS | LightningDataModule | None = None, + datamodule: LightningDataModule | None = None, + return_predictions: bool | None = None, + ckpt_path: str | None = None, + ) -> _PREDICT_OUTPUT | None: # -------------------- # SETUP HOOK # -------------------- @@ -899,8 +901,8 @@ def _predict_impl( return results def _run( - self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None - ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: + self, model: pl.LightningModule, ckpt_path: _PATH | None = None + ) -> _EVALUATE_OUTPUT | _PREDICT_OUTPUT | None: _verify_strategy_supports_compile(model, self.strategy) if self.state.fn == TrainerFn.FITTING: @@ -1010,7 +1012,7 @@ def _teardown(self) -> None: self._logger_connector.teardown() self._signal_connector.teardown() - def _run_stage(self) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]: + def _run_stage(self) -> _PREDICT_OUTPUT | _EVALUATE_OUTPUT | None: # wait for all to join if on distributed self.strategy.barrier("run-stage") @@ -1111,7 +1113,7 @@ def num_nodes(self) -> int: return getattr(self.strategy, "num_nodes", 1) @property - def device_ids(self) -> List[int]: + def device_ids(self) -> list[int]: """List of device indexes per node.""" devices = ( self.strategy.parallel_devices @@ -1133,20 +1135,20 @@ def num_devices(self) -> int: return len(self.device_ids) @property - def lightning_module(self) -> "pl.LightningModule": + def lightning_module(self) -> pl.LightningModule: # TODO: this is actually an optional return return self.strategy.lightning_module # type: ignore[return-value] @property - def optimizers(self) -> List[Optimizer]: + def optimizers(self) -> list[Optimizer]: 
return self.strategy.optimizers @optimizers.setter - def optimizers(self, new_optims: List[Optimizer]) -> None: + def optimizers(self, new_optims: list[Optimizer]) -> None: self.strategy.optimizers = new_optims @property - def lr_scheduler_configs(self) -> List[LRSchedulerConfig]: + def lr_scheduler_configs(self) -> list[LRSchedulerConfig]: return self.strategy.lr_scheduler_configs @property @@ -1154,11 +1156,11 @@ def precision(self) -> _PRECISION_INPUT_STR: return self.strategy.precision_plugin.precision @property - def scaler(self) -> Optional[Any]: + def scaler(self) -> Any | None: return getattr(self.precision_plugin, "scaler", None) @property - def model(self) -> Optional[torch.nn.Module]: + def model(self) -> torch.nn.Module | None: """The LightningModule, but possibly wrapped into DataParallel or DistributedDataParallel. To access the pure LightningModule, use @@ -1171,7 +1173,7 @@ def model(self) -> Optional[torch.nn.Module]: """ @property - def log_dir(self) -> Optional[str]: + def log_dir(self) -> str | None: """The directory for the current experiment. Use this to save images to, etc... .. code-block:: python @@ -1204,7 +1206,7 @@ def training_step(self, batch, batch_idx): return self.strategy.is_global_zero @property - def distributed_sampler_kwargs(self) -> Optional[Dict[str, Any]]: + def distributed_sampler_kwargs(self) -> dict[str, Any] | None: if isinstance(self.strategy, ParallelStrategy): return self.strategy.distributed_sampler_kwargs return None @@ -1229,33 +1231,33 @@ def default_root_dir(self) -> str: return self._default_root_dir @property - def early_stopping_callback(self) -> Optional[EarlyStopping]: + def early_stopping_callback(self) -> EarlyStopping | None: """The first :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping` callback in the Trainer.callbacks list, or ``None`` if it doesn't exist.""" callbacks = self.early_stopping_callbacks return callbacks[0] if len(callbacks) > 0 else None @property - def early_stopping_callbacks(self) -> List[EarlyStopping]: + def early_stopping_callbacks(self) -> list[EarlyStopping]: """A list of all instances of :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping` found in the Trainer.callbacks list.""" return [c for c in self.callbacks if isinstance(c, EarlyStopping)] @property - def checkpoint_callback(self) -> Optional[Checkpoint]: + def checkpoint_callback(self) -> Checkpoint | None: """The first :class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` callback in the Trainer.callbacks list, or ``None`` if it doesn't exist.""" callbacks = self.checkpoint_callbacks return callbacks[0] if len(callbacks) > 0 else None @property - def checkpoint_callbacks(self) -> List[Checkpoint]: + def checkpoint_callbacks(self) -> list[Checkpoint]: """A list of all instances of :class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` found in the Trainer.callbacks list.""" return [c for c in self.callbacks if isinstance(c, Checkpoint)] @property - def progress_bar_callback(self) -> Optional[ProgressBar]: + def progress_bar_callback(self) -> ProgressBar | None: """An instance of :class:`~lightning.pytorch.callbacks.progress.progress_bar.ProgressBar` found in the Trainer.callbacks list, or ``None`` if one doesn't exist.""" for c in self.callbacks: @@ -1264,7 +1266,7 @@ def progress_bar_callback(self) -> Optional[ProgressBar]: return None @property - def ckpt_path(self) -> Optional[_PATH]: + def ckpt_path(self) -> _PATH | None: """Set to the path/URL of a checkpoint loaded via 
:meth:`~lightning.pytorch.trainer.trainer.Trainer.fit`, :meth:`~lightning.pytorch.trainer.trainer.Trainer.validate`, :meth:`~lightning.pytorch.trainer.trainer.Trainer.test`, or @@ -1272,7 +1274,7 @@ def ckpt_path(self) -> Optional[_PATH]: return self._checkpoint_connector._ckpt_path @ckpt_path.setter - def ckpt_path(self, ckpt_path: Optional[_PATH]) -> None: + def ckpt_path(self, ckpt_path: _PATH | None) -> None: """Allows you to manage which checkpoint is loaded statefully. .. code-block:: python @@ -1289,9 +1291,7 @@ def ckpt_path(self, ckpt_path: Optional[_PATH]) -> None: self._checkpoint_connector._ckpt_path = ckpt_path self._checkpoint_connector._user_managed = bool(ckpt_path) - def save_checkpoint( - self, filepath: _PATH, weights_only: bool = False, storage_options: Optional[Any] = None - ) -> None: + def save_checkpoint(self, filepath: _PATH, weights_only: bool = False, storage_options: Any | None = None) -> None: r"""Runs routine to create a checkpoint. Args: @@ -1409,11 +1409,11 @@ def current_epoch(self) -> int: return self.fit_loop.epoch_progress.current.completed @property - def max_epochs(self) -> Optional[int]: + def max_epochs(self) -> int | None: return self.fit_loop.max_epochs @property - def min_epochs(self) -> Optional[int]: + def min_epochs(self) -> int | None: return self.fit_loop.min_epochs @property @@ -1421,7 +1421,7 @@ def max_steps(self) -> int: return self.fit_loop.max_steps @property - def min_steps(self) -> Optional[int]: + def min_steps(self) -> int | None: return self.fit_loop.min_steps @property @@ -1430,14 +1430,14 @@ def is_last_batch(self) -> bool: return self.fit_loop.epoch_loop.batch_progress.is_last_batch @property - def train_dataloader(self) -> Optional[TRAIN_DATALOADERS]: + def train_dataloader(self) -> TRAIN_DATALOADERS | None: """The training dataloader(s) used during ``trainer.fit()``.""" if (combined_loader := self.fit_loop._combined_loader) is not None: return combined_loader.iterables return None @property - def val_dataloaders(self) -> Optional[EVAL_DATALOADERS]: + def val_dataloaders(self) -> EVAL_DATALOADERS | None: """The validation dataloader(s) used during ``trainer.fit()`` or ``trainer.validate()``.""" if (combined_loader := self.fit_loop.epoch_loop.val_loop._combined_loader) is not None or ( combined_loader := self.validate_loop._combined_loader @@ -1446,26 +1446,26 @@ def val_dataloaders(self) -> Optional[EVAL_DATALOADERS]: return None @property - def test_dataloaders(self) -> Optional[EVAL_DATALOADERS]: + def test_dataloaders(self) -> EVAL_DATALOADERS | None: """The test dataloader(s) used during ``trainer.test()``.""" if (combined_loader := self.test_loop._combined_loader) is not None: return combined_loader.iterables return None @property - def predict_dataloaders(self) -> Optional[EVAL_DATALOADERS]: + def predict_dataloaders(self) -> EVAL_DATALOADERS | None: """The prediction dataloader(s) used during ``trainer.predict()``.""" if (combined_loader := self.predict_loop._combined_loader) is not None: return combined_loader.iterables return None @property - def num_training_batches(self) -> Union[int, float]: + def num_training_batches(self) -> int | float: """The number of training batches that will be used during ``trainer.fit()``.""" return self.fit_loop.max_batches @property - def num_sanity_val_batches(self) -> Union[int, float, List[Union[int, float]]]: + def num_sanity_val_batches(self) -> int | float | list[int | float]: """The number of validation batches that will be used during the sanity-checking part of 
``trainer.fit()``.""" max_batches = self.fit_loop.epoch_loop.val_loop.max_batches @@ -1476,7 +1476,7 @@ def num_sanity_val_batches(self) -> Union[int, float, List[Union[int, float]]]: return min(sanity_val_steps, max_batches) @property - def num_val_batches(self) -> Union[int, float, List[Union[int, float]]]: + def num_val_batches(self) -> int | float | list[int | float]: """The number of validation batches that will be used during ``trainer.fit()`` or ``trainer.validate()``.""" if self.state.fn == TrainerFn.VALIDATING: @@ -1486,12 +1486,12 @@ def num_val_batches(self) -> Union[int, float, List[Union[int, float]]]: return self.fit_loop.epoch_loop.val_loop._max_batches @property - def num_test_batches(self) -> Union[int, float, List[Union[int, float]]]: + def num_test_batches(self) -> int | float | list[int | float]: """The number of test batches that will be used during ``trainer.test()``.""" return self.test_loop.max_batches @property - def num_predict_batches(self) -> List[Union[int, float]]: + def num_predict_batches(self) -> list[int | float]: """The number of prediction batches that will be used during ``trainer.predict()``.""" return self.predict_loop.max_batches @@ -1506,7 +1506,7 @@ def _evaluation_loop(self) -> _EvaluationLoop: raise RuntimeError("The `Trainer._evaluation_loop` property isn't defined. Accessed outside of scope") @property - def _active_loop(self) -> Optional[Union[_FitLoop, _EvaluationLoop, _PredictionLoop]]: + def _active_loop(self) -> _FitLoop | _EvaluationLoop | _PredictionLoop | None: if self.training: return self.fit_loop if self.sanity_checking or self.evaluating: @@ -1520,19 +1520,19 @@ def _active_loop(self) -> Optional[Union[_FitLoop, _EvaluationLoop, _PredictionL """ @property - def logger(self) -> Optional[Logger]: + def logger(self) -> Logger | None: """The first :class:`~lightning.pytorch.loggers.logger.Logger` being used.""" return self.loggers[0] if len(self.loggers) > 0 else None @logger.setter - def logger(self, logger: Optional[Logger]) -> None: + def logger(self, logger: Logger | None) -> None: if not logger: self.loggers = [] else: self.loggers = [logger] @property - def loggers(self) -> List[Logger]: + def loggers(self) -> list[Logger]: """The list of :class:`~lightning.pytorch.loggers.logger.Logger` used. .. code-block:: python @@ -1543,7 +1543,7 @@ def loggers(self) -> List[Logger]: return self._loggers @loggers.setter - def loggers(self, loggers: Optional[List[Logger]]) -> None: + def loggers(self, loggers: list[Logger] | None) -> None: self._loggers = loggers if loggers else [] @property @@ -1580,7 +1580,7 @@ def progress_bar_metrics(self) -> _PBAR_DICT: return self._logger_connector.progress_bar_metrics @property - def _results(self) -> Optional[_ResultCollection]: + def _results(self) -> _ResultCollection | None: active_loop = self._active_loop if active_loop is not None: return active_loop._results @@ -1591,7 +1591,7 @@ def _results(self) -> Optional[_ResultCollection]: """ @property - def estimated_stepping_batches(self) -> Union[int, float]: + def estimated_stepping_batches(self) -> int | float: r""" The estimated number of batches that will ``optimizer.step()`` during training. 
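`estimated_stepping_batches`, retyped `int | float` above because it can be infinite when training is unbounded, is usually read inside `configure_optimizers` to size a step-based scheduler. A hedged sketch with a hypothetical module (training_step omitted for brevity):

import torch
import lightning.pytorch as pl


class TinyModule(pl.LightningModule):
    """Hypothetical module; only configure_optimizers is shown."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        # Number of optimizer steps the Trainer expects to run; assumes max_epochs or max_steps
        # bounds training, otherwise this is float("inf") and OneCycleLR would reject it.
        total_steps = self.trainer.estimated_stepping_batches
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-2, total_steps=total_steps)
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]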
diff --git a/src/lightning/pytorch/tuner/batch_size_scaling.py b/src/lightning/pytorch/tuner/batch_size_scaling.py index 2b3ec7ef7f33a..d5ab7d8a9ca87 100644 --- a/src/lightning/pytorch/tuner/batch_size_scaling.py +++ b/src/lightning/pytorch/tuner/batch_size_scaling.py @@ -11,11 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License +from __future__ import annotations + import logging import os import uuid from copy import deepcopy -from typing import Any, Dict, Optional, Tuple +from typing import Any import lightning.pytorch as pl from lightning.pytorch.utilities.memory import garbage_collection_cuda, is_oom_error @@ -26,13 +28,13 @@ def _scale_batch_size( - trainer: "pl.Trainer", + trainer: pl.Trainer, mode: str = "power", steps_per_trial: int = 3, init_val: int = 2, max_trials: int = 25, batch_arg_name: str = "batch_size", -) -> Optional[int]: +) -> int | None: """Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM) error. @@ -97,7 +99,7 @@ def _scale_batch_size( return new_size -def __scale_batch_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]: +def __scale_batch_dump_params(trainer: pl.Trainer) -> dict[str, Any]: dumped_params = { "loggers": trainer.loggers, "callbacks": trainer.callbacks, @@ -118,7 +120,7 @@ def __scale_batch_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]: return dumped_params -def __scale_batch_reset_params(trainer: "pl.Trainer", steps_per_trial: int) -> None: +def __scale_batch_reset_params(trainer: pl.Trainer, steps_per_trial: int) -> None: from lightning.pytorch.loggers.logger import DummyLogger trainer.logger = DummyLogger() if trainer.logger is not None else None @@ -137,7 +139,7 @@ def __scale_batch_reset_params(trainer: "pl.Trainer", steps_per_trial: int) -> N loop.verbose = False -def __scale_batch_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: +def __scale_batch_restore_params(trainer: pl.Trainer, params: dict[str, Any]) -> None: # TODO: There are more states that needs to be reset (#4512 and #4870) trainer.loggers = params["loggers"] trainer.callbacks = params["callbacks"] @@ -164,11 +166,11 @@ def __scale_batch_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) def _run_power_scaling( - trainer: "pl.Trainer", + trainer: pl.Trainer, new_size: int, batch_arg_name: str, max_trials: int, - params: Dict[str, Any], + params: dict[str, Any], ) -> int: """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered.""" # this flag is used to determine whether the previously scaled batch size, right before OOM, was a success or not @@ -206,11 +208,11 @@ def _run_power_scaling( def _run_binary_scaling( - trainer: "pl.Trainer", + trainer: pl.Trainer, new_size: int, batch_arg_name: str, max_trials: int, - params: Dict[str, Any], + params: dict[str, Any], ) -> int: """Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is encountered. @@ -270,12 +272,12 @@ def _run_binary_scaling( def _adjust_batch_size( - trainer: "pl.Trainer", + trainer: pl.Trainer, batch_arg_name: str = "batch_size", factor: float = 1.0, - value: Optional[int] = None, - desc: Optional[str] = None, -) -> Tuple[int, bool]: + value: int | None = None, + desc: str | None = None, +) -> tuple[int, bool]: """Helper function for adjusting the batch size. 
Args: @@ -316,14 +318,14 @@ def _adjust_batch_size( return new_size, changed -def _reset_dataloaders(trainer: "pl.Trainer") -> None: +def _reset_dataloaders(trainer: pl.Trainer) -> None: loop = trainer._active_loop assert loop is not None loop._combined_loader = None # force a reload loop.setup_data() -def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: +def _try_loop_run(trainer: pl.Trainer, params: dict[str, Any]) -> None: loop = trainer._active_loop assert loop is not None loop.load_state_dict(deepcopy(params["loop_state_dict"])) @@ -331,7 +333,7 @@ def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: loop.run() -def _reset_progress(trainer: "pl.Trainer") -> None: +def _reset_progress(trainer: pl.Trainer) -> None: if trainer.lightning_module.automatic_optimization: trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.reset() else: diff --git a/src/lightning/pytorch/tuner/lr_finder.py b/src/lightning/pytorch/tuner/lr_finder.py index 745f9e51bfaff..edd1d25b007ce 100644 --- a/src/lightning/pytorch/tuner/lr_finder.py +++ b/src/lightning/pytorch/tuner/lr_finder.py @@ -11,12 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import importlib import logging import os import uuid from copy import deepcopy -from typing import Any, cast, Dict, List, Optional, TYPE_CHECKING, Union +from typing import Any, cast, TYPE_CHECKING import torch from lightning_utilities.core.imports import RequirementCache @@ -43,7 +45,7 @@ log = logging.getLogger(__name__) -def _determine_lr_attr_name(model: "pl.LightningModule", attr_name: str = "") -> str: +def _determine_lr_attr_name(model: pl.LightningModule, attr_name: str = "") -> str: if attr_name: if not lightning_hasattr(model, attr_name): raise AttributeError( @@ -98,10 +100,10 @@ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int) - self.lr_max = lr_max self.num_training = num_training - self.results: Dict[str, Any] = {} + self.results: dict[str, Any] = {} self._total_batch_idx = 0 # for debug purpose - def _exchange_scheduler(self, trainer: "pl.Trainer") -> None: + def _exchange_scheduler(self, trainer: pl.Trainer) -> None: # TODO: update docs here """Decorate `trainer.strategy.setup_optimizers` method such that it sets the user's originally specified optimizer together with a new scheduler that takes care of the learning rate search.""" @@ -130,7 +132,7 @@ def _exchange_scheduler(self, trainer: "pl.Trainer") -> None: trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step")] _validate_optimizers_attached(trainer.optimizers, trainer.lr_scheduler_configs) - def plot(self, suggest: bool = False, show: bool = False, ax: Optional["Axes"] = None) -> Optional["plt.Figure"]: + def plot(self, suggest: bool = False, show: bool = False, ax: Axes | None = None) -> plt.Figure | None: """Plot results from lr_find run Args: suggest: if True, will mark suggested lr to use with a red point @@ -171,7 +173,7 @@ def plot(self, suggest: bool = False, show: bool = False, ax: Optional["Axes"] = return fig - def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]: + def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> float | None: """This will propose a suggestion for an initial learning rate based on the point with the steepest negative gradient. 
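`_LRFinder.plot` and `_LRFinder.suggestion` above are normally reached through the public `Tuner.lr_find` entry point. A hedged sketch of that flow, assuming `TinyModule` is any LightningModule that exposes a tunable `self.lr` (or `self.learning_rate`) attribute:

import lightning.pytorch as pl
from lightning.pytorch.tuner import Tuner

model = TinyModule()  # hypothetical module with a `self.lr` attribute read by configure_optimizers
trainer = pl.Trainer(max_epochs=1)

lr_finder = Tuner(trainer).lr_find(model, min_lr=1e-8, max_lr=1)
if lr_finder is not None:  # return type per the new annotation: `_LRFinder | None`
    fig = lr_finder.plot(suggest=True)  # marks the suggested point; needs matplotlib installed
    print(lr_finder.suggestion())       # lr at the steepest negative gradient, or None if none is found
    # with the default update_attr=True the suggestion is also written back to `model.lr`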
@@ -205,16 +207,16 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float] def _lr_find( - trainer: "pl.Trainer", - model: "pl.LightningModule", + trainer: pl.Trainer, + model: pl.LightningModule, min_lr: float = 1e-8, max_lr: float = 1, num_training: int = 100, mode: str = "exponential", - early_stop_threshold: Optional[float] = 4.0, + early_stop_threshold: float | None = 4.0, update_attr: bool = False, attr_name: str = "", -) -> Optional[_LRFinder]: +) -> _LRFinder | None: """Enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in picking a good starting learning rate. @@ -302,7 +304,7 @@ def _lr_find( return lr_finder -def __lr_finder_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]: +def __lr_finder_dump_params(trainer: pl.Trainer) -> dict[str, Any]: return { "optimizers": trainer.strategy.optimizers, "lr_scheduler_configs": trainer.strategy.lr_scheduler_configs, @@ -314,7 +316,7 @@ def __lr_finder_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]: } -def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_stop_threshold: Optional[float]) -> None: +def __lr_finder_reset_params(trainer: pl.Trainer, num_training: int, early_stop_threshold: float | None) -> None: from lightning.pytorch.loggers.logger import DummyLogger trainer.strategy.lr_scheduler_configs = [] @@ -327,7 +329,7 @@ def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_sto trainer.limit_val_batches = num_training -def __lr_finder_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: +def __lr_finder_restore_params(trainer: pl.Trainer, params: dict[str, Any]) -> None: trainer.strategy.optimizers = params["optimizers"] trainer.strategy.lr_scheduler_configs = params["lr_scheduler_configs"] trainer.callbacks = params["callbacks"] @@ -360,22 +362,22 @@ class _LRCallback(Callback): def __init__( self, num_training: int, - early_stop_threshold: Optional[float] = 4.0, + early_stop_threshold: float | None = 4.0, progress_bar_refresh_rate: int = 0, beta: float = 0.98, ): self.num_training = num_training self.early_stop_threshold = early_stop_threshold self.beta = beta - self.losses: List[float] = [] - self.lrs: List[float] = [] + self.losses: list[float] = [] + self.lrs: list[float] = [] self.avg_loss = 0.0 self.best_loss = 0.0 self.progress_bar_refresh_rate = progress_bar_refresh_rate self.progress_bar = None def on_train_batch_start( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int + self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Any, batch_idx: int ) -> None: """Called before each training batch, logs the lr that will be used.""" if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0: @@ -387,7 +389,7 @@ def on_train_batch_start( self.lrs.append(trainer.lr_scheduler_configs[0].scheduler.lr[0]) # type: ignore[union-attr] def on_train_batch_end( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int + self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int ) -> None: """Called when the training batch ends, logs the calculated loss.""" if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0: @@ -443,7 +445,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in self.num_iter = num_iter super().__init__(optimizer, last_epoch) - def get_lr(self) -> 
List[float]: + def get_lr(self) -> list[float]: curr_iter = self.last_epoch + 1 r = curr_iter / self.num_iter @@ -455,7 +457,7 @@ def get_lr(self) -> List[float]: return val @property - def lr(self) -> Union[float, List[float]]: + def lr(self) -> float | list[float]: return self._lr @@ -478,7 +480,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in self.num_iter = num_iter super().__init__(optimizer, last_epoch) - def get_lr(self) -> List[float]: + def get_lr(self) -> list[float]: curr_iter = self.last_epoch + 1 r = curr_iter / self.num_iter @@ -490,11 +492,11 @@ def get_lr(self) -> List[float]: return val @property - def lr(self) -> Union[float, List[float]]: + def lr(self) -> float | list[float]: return self._lr -def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: +def _try_loop_run(trainer: pl.Trainer, params: dict[str, Any]) -> None: loop = trainer.fit_loop loop.load_state_dict(deepcopy(params["loop_state_dict"])) loop.restarting = False diff --git a/src/lightning/pytorch/tuner/tuning.py b/src/lightning/pytorch/tuner/tuning.py index 64065b9576faa..0d66fe2e59815 100644 --- a/src/lightning/pytorch/tuner/tuning.py +++ b/src/lightning/pytorch/tuner/tuning.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Literal, Optional, Union +from __future__ import annotations + +from typing import Literal import lightning.pytorch as pl from lightning.pytorch.callbacks.callback import Callback @@ -22,23 +24,23 @@ class Tuner: """Tuner class to tune your model.""" - def __init__(self, trainer: "pl.Trainer") -> None: + def __init__(self, trainer: pl.Trainer) -> None: self._trainer = trainer def scale_batch_size( self, - model: "pl.LightningModule", - train_dataloaders: Optional[Union[TRAIN_DATALOADERS, "pl.LightningDataModule"]] = None, - val_dataloaders: Optional[EVAL_DATALOADERS] = None, - dataloaders: Optional[EVAL_DATALOADERS] = None, - datamodule: Optional["pl.LightningDataModule"] = None, + model: pl.LightningModule, + train_dataloaders: TRAIN_DATALOADERS | pl.LightningDataModule | None = None, + val_dataloaders: EVAL_DATALOADERS | None = None, + dataloaders: EVAL_DATALOADERS | None = None, + datamodule: pl.LightningDataModule | None = None, method: Literal["fit", "validate", "test", "predict"] = "fit", mode: str = "power", steps_per_trial: int = 3, init_val: int = 2, max_trials: int = 25, batch_arg_name: str = "batch_size", - ) -> Optional[int]: + ) -> int | None: """Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM) error. 
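`Tuner.scale_batch_size` above is the public entry point for the batch-size scaling helpers in the previous file. A hedged usage sketch, assuming `TinyModule` is any LightningModule (or attached datamodule) whose dataloaders read a `self.batch_size` attribute, matching the default `batch_arg_name`:

import lightning.pytorch as pl
from lightning.pytorch.tuner import Tuner

model = TinyModule(batch_size=2)  # hypothetical module whose DataLoader uses `self.batch_size`
trainer = pl.Trainer(max_epochs=1)

# "power" keeps doubling the batch size until an OOM error or `max_trials` is reached;
# "binsearch" additionally narrows down between the last working and the failing size.
new_size = Tuner(trainer).scale_batch_size(model, mode="power", init_val=2, max_trials=25)
print(new_size)  # return type per the new annotation: `int | None`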
@@ -103,11 +105,11 @@ def scale_batch_size( def lr_find( self, - model: "pl.LightningModule", - train_dataloaders: Optional[Union[TRAIN_DATALOADERS, "pl.LightningDataModule"]] = None, - val_dataloaders: Optional[EVAL_DATALOADERS] = None, - dataloaders: Optional[EVAL_DATALOADERS] = None, - datamodule: Optional["pl.LightningDataModule"] = None, + model: pl.LightningModule, + train_dataloaders: TRAIN_DATALOADERS | pl.LightningDataModule | None = None, + val_dataloaders: EVAL_DATALOADERS | None = None, + dataloaders: EVAL_DATALOADERS | None = None, + datamodule: pl.LightningDataModule | None = None, method: Literal["fit", "validate", "test", "predict"] = "fit", min_lr: float = 1e-8, max_lr: float = 1, @@ -116,7 +118,7 @@ def lr_find( early_stop_threshold: float = 4.0, update_attr: bool = True, attr_name: str = "", - ) -> Optional["pl.tuner.lr_finder._LRFinder"]: + ) -> pl.tuner.lr_finder._LRFinder | None: """Enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in picking a good starting learning rate. @@ -180,9 +182,9 @@ def lr_find( def _check_tuner_configuration( - train_dataloaders: Optional[Union[TRAIN_DATALOADERS, "pl.LightningDataModule"]] = None, - val_dataloaders: Optional[EVAL_DATALOADERS] = None, - dataloaders: Optional[EVAL_DATALOADERS] = None, + train_dataloaders: TRAIN_DATALOADERS | pl.LightningDataModule | None = None, + val_dataloaders: EVAL_DATALOADERS | None = None, + dataloaders: EVAL_DATALOADERS | None = None, method: Literal["fit", "validate", "test", "predict"] = "fit", ) -> None: supported_methods = ("fit", "validate", "test", "predict") @@ -203,7 +205,7 @@ def _check_tuner_configuration( ) -def _check_lr_find_configuration(trainer: "pl.Trainer") -> None: +def _check_lr_find_configuration(trainer: pl.Trainer) -> None: # local import to avoid circular import from lightning.pytorch.callbacks.lr_finder import LearningRateFinder @@ -215,7 +217,7 @@ def _check_lr_find_configuration(trainer: "pl.Trainer") -> None: ) -def _check_scale_batch_size_configuration(trainer: "pl.Trainer") -> None: +def _check_scale_batch_size_configuration(trainer: pl.Trainer) -> None: if trainer._accelerator_connector.is_distributed: raise ValueError("Tuning the batch size is currently not supported with distributed strategies.") diff --git a/src/lightning/pytorch/utilities/_pytree.py b/src/lightning/pytorch/utilities/_pytree.py index 1604f37b4cb9e..199759761955b 100644 --- a/src/lightning/pytorch/utilities/_pytree.py +++ b/src/lightning/pytorch/utilities/_pytree.py @@ -1,4 +1,6 @@ -from typing import Any, List, Tuple +from __future__ import annotations + +from typing import Any from torch.utils._pytree import _get_node_type, LeafSpec, PyTree, SUPPORTED_NODES, tree_unflatten, TreeSpec @@ -15,7 +17,7 @@ def _is_leaf_or_primitive_container(pytree: PyTree) -> bool: return all(isinstance(child, (int, float, str)) for child in child_pytrees) -def _tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]: +def _tree_flatten(pytree: PyTree) -> tuple[list[Any], TreeSpec]: """Copy of :func:`torch.utils._pytree.tree_flatten` using our custom leaf function.""" if _is_leaf_or_primitive_container(pytree): return [pytree], LeafSpec() @@ -24,8 +26,8 @@ def _tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]: flatten_fn = SUPPORTED_NODES[node_type].flatten_fn child_pytrees, context = flatten_fn(pytree) - result: List[Any] = [] - children_specs: List["TreeSpec"] = [] + result: list[Any] = [] + children_specs: list[TreeSpec] = [] for child in child_pytrees: 
flat, child_spec = _tree_flatten(child) result += flat @@ -34,6 +36,6 @@ def _tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]: return result, TreeSpec(node_type, context, children_specs) -def _map_and_unflatten(fn: Any, values: List[Any], spec: TreeSpec) -> PyTree: +def _map_and_unflatten(fn: Any, values: list[Any], spec: TreeSpec) -> PyTree: """Utility function to apply a function and unflatten it.""" return tree_unflatten([fn(i) for i in values], spec) diff --git a/src/lightning/pytorch/utilities/argparse.py b/src/lightning/pytorch/utilities/argparse.py index 888a3b3755e1e..20f83f50e64b6 100644 --- a/src/lightning/pytorch/utilities/argparse.py +++ b/src/lightning/pytorch/utilities/argparse.py @@ -13,18 +13,20 @@ # limitations under the License. """Utilities for Argument Parsing within Lightning Components.""" +from __future__ import annotations + import inspect import os from argparse import Namespace from ast import literal_eval from contextlib import suppress from functools import wraps -from typing import Any, Callable, cast, Type, TypeVar +from typing import Any, Callable, cast, TypeVar _T = TypeVar("_T", bound=Callable[..., Any]) -def _parse_env_variables(cls: Type, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace: +def _parse_env_variables(cls: type, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace: """Parse environment arguments if they are defined. Examples: diff --git a/src/lightning/pytorch/utilities/combined_loader.py b/src/lightning/pytorch/utilities/combined_loader.py index 0e012dbae145b..d1096f85e03be 100644 --- a/src/lightning/pytorch/utilities/combined_loader.py +++ b/src/lightning/pytorch/utilities/combined_loader.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import contextlib from collections.abc import Iterable -from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Callable, Iterator, List, Literal, Tuple, TypeVar from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter from typing_extensions import Self, TypedDict @@ -25,9 +27,9 @@ class _ModeIterator(Iterator[_T]): - def __init__(self, iterables: List[Iterable]) -> None: + def __init__(self, iterables: list[Iterable]) -> None: self.iterables = iterables - self.iterators: List[Iterator] = [] + self.iterators: list[Iterator] = [] def __next__(self) -> _T: raise NotImplementedError @@ -39,7 +41,7 @@ def __iter__(self) -> Self: def reset(self) -> None: self.iterators = [] - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: state = self.__dict__.copy() # workaround an inconvenient `NotImplementedError`: @@ -53,11 +55,11 @@ def __getstate__(self) -> Dict[str, Any]: class _MaxSizeCycle(_ModeIterator[List]): - def __init__(self, iterables: List[Iterable]) -> None: + def __init__(self, iterables: list[Iterable]) -> None: super().__init__(iterables) - self._consumed: List[bool] = [] + self._consumed: list[bool] = [] - def __next__(self) -> List: + def __next__(self) -> list: n = len(self.iterators) out = [None] * n # values per iterator for i in range(n): @@ -83,31 +85,31 @@ def reset(self) -> None: class _MinSize(_ModeIterator[List]): - def __next__(self) -> List: + def __next__(self) -> list: return [next(it) for it in self.iterators] class _Sequential(_ModeIterator[Tuple[Any, int, int]]): - def __init__(self, iterables: List[Iterable], limits: Optional[List[Union[int, float]]] = None) -> None: + def __init__(self, iterables: list[Iterable], limits: list[int | float] | None = None) -> None: super().__init__(iterables) self._iterator_idx = 0 # what would be dataloader_idx self._idx = 0 # what would be batch_idx self.limits = limits @property - def limits(self) -> Optional[List[Union[int, float]]]: + def limits(self) -> list[int | float] | None: """Optional limits per iterator.""" return self._limits @limits.setter - def limits(self, limits: Optional[List[Union[int, float]]]) -> None: + def limits(self, limits: list[int | float] | None) -> None: if limits is not None and len(limits) != len(self.iterables): raise ValueError( f"Mismatch in number of limits ({len(limits)}) and number of iterables ({len(self.iterables)})" ) self._limits = limits - def __next__(self) -> Tuple[Any, int, int]: + def __next__(self) -> tuple[Any, int, int]: n = len(self.iterables) if n == 0 or self._iterator_idx >= n: raise StopIteration @@ -156,7 +158,7 @@ def _use_next_iterator(self) -> None: class _MaxSize(_ModeIterator[List]): - def __next__(self) -> List: + def __next__(self) -> list: n = len(self.iterators) out = [None] * n all_exhausted = True @@ -170,8 +172,8 @@ def __next__(self) -> List: class _CombinationMode(TypedDict): - fn: Callable[[List[int]], int] - iterator: Type[_ModeIterator] + fn: Callable[[list[int]], int] + iterator: type[_ModeIterator] _SUPPORTED_MODES = { @@ -246,7 +248,7 @@ def __init__(self, iterables: Any, mode: _LITERAL_SUPPORTED_MODES = "min_size") self._iterables = iterables self._flattened, self._spec = _tree_flatten(iterables) self._mode = mode - self._iterator: Optional[_ModeIterator] = None + self._iterator: _ModeIterator | None = None @property def iterables(self) -> Any: @@ -264,12 +266,12 @@ def 
batch_sampler(self) -> Any: return _map_and_unflatten(lambda x: getattr(x, "batch_sampler", None), self.flattened, self._spec) @property - def flattened(self) -> List[Any]: + def flattened(self) -> list[Any]: """Return the flat list of iterables.""" return self._flattened @flattened.setter - def flattened(self, flattened: List[Any]) -> None: + def flattened(self, flattened: list[Any]) -> None: """Setter to conveniently update the list of iterables.""" if len(flattened) != len(self._flattened): raise ValueError( diff --git a/src/lightning/pytorch/utilities/compile.py b/src/lightning/pytorch/utilities/compile.py index 67d9808a66222..906b810206372 100644 --- a/src/lightning/pytorch/utilities/compile.py +++ b/src/lightning/pytorch/utilities/compile.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from __future__ import annotations import torch @@ -21,7 +21,7 @@ from lightning.pytorch.utilities.model_helpers import _check_mixed_imports -def from_compiled(model: "torch._dynamo.OptimizedModule") -> "pl.LightningModule": +def from_compiled(model: torch._dynamo.OptimizedModule) -> pl.LightningModule: """Returns an instance LightningModule from the output of ``torch.compile``. .. warning:: This is an :ref:`experimental ` feature. @@ -70,7 +70,7 @@ def from_compiled(model: "torch._dynamo.OptimizedModule") -> "pl.LightningModule return orig_module -def to_uncompiled(model: Union["pl.LightningModule", "torch._dynamo.OptimizedModule"]) -> "pl.LightningModule": +def to_uncompiled(model: pl.LightningModule | torch._dynamo.OptimizedModule) -> pl.LightningModule: """Returns an instance of LightningModule without any compilation optimizations from a compiled model. .. warning:: This is an :ref:`experimental ` feature. @@ -113,7 +113,7 @@ def to_uncompiled(model: Union["pl.LightningModule", "torch._dynamo.OptimizedMod return model -def _maybe_unwrap_optimized(model: object) -> "pl.LightningModule": +def _maybe_unwrap_optimized(model: object) -> pl.LightningModule: if not _TORCH_GREATER_EQUAL_2_0: if not isinstance(model, pl.LightningModule): _check_mixed_imports(model) @@ -131,7 +131,7 @@ def _maybe_unwrap_optimized(model: object) -> "pl.LightningModule": ) -def _verify_strategy_supports_compile(model: "pl.LightningModule", strategy: Strategy) -> None: +def _verify_strategy_supports_compile(model: pl.LightningModule, strategy: Strategy) -> None: if model._compiler_ctx is not None: supported_strategies = (SingleDeviceStrategy, DDPStrategy, FSDPStrategy) if not isinstance(strategy, supported_strategies): diff --git a/src/lightning/pytorch/utilities/data.py b/src/lightning/pytorch/utilities/data.py index ce69cdc4d664a..4070cdea0de8b 100644 --- a/src/lightning/pytorch/utilities/data.py +++ b/src/lightning/pytorch/utilities/data.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
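The `combined_loader.py` changes above are annotation-only; as a reminder of what the class does, a small usage sketch (illustrative, not part of this patch; the `lightning.pytorch.utilities.CombinedLoader` import path, the toy loaders, and the non-default mode names are assumptions, while the `"min_size"` default and the mode iterators are visible above):

# Sketch: iterating two dataloaders of different lengths together.
from torch.utils.data import DataLoader

from lightning.pytorch.utilities import CombinedLoader

iterables = {
    "short": DataLoader(range(8), batch_size=4),  # 2 batches
    "long": DataLoader(range(32), batch_size=4),  # 8 batches
}
# "min_size" (the default) stops with the shortest iterable, "max_size_cycle" restarts exhausted
# iterables until the longest one finishes (the _MaxSizeCycle iterator above), "max_size" fills
# exhausted slots with None, and "sequential" consumes the iterables one after another.
combined = CombinedLoader(iterables, mode="max_size_cycle")
for batch in combined:
    print(batch["short"], batch["long"])  # each batch mirrors the structure of `iterables`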
+from __future__ import annotations + import inspect from dataclasses import fields -from typing import Any, Dict, Generator, Iterable, Mapping, Optional, Sized, Tuple, Union +from typing import Any, Generator, Iterable, Mapping, Sized, Union import torch from lightning_utilities.core.apply_func import is_dataclass_instance @@ -38,7 +40,7 @@ warning_cache = WarningCache() -def _extract_batch_size(batch: BType) -> Generator[Optional[int], None, None]: +def _extract_batch_size(batch: BType) -> Generator[int | None, None, None]: if isinstance(batch, Tensor): if batch.ndim == 0: yield 1 @@ -89,7 +91,7 @@ def extract_batch_size(batch: BType) -> int: def has_len_all_ranks( dataloader: object, - strategy: "pl.strategies.Strategy", + strategy: pl.strategies.Strategy, allow_zero_length_dataloader_with_multiple_devices: bool = False, ) -> TypeGuard[Sized]: """Checks if a given object has ``__len__`` method implemented on all ranks.""" @@ -127,7 +129,7 @@ def has_len_all_ranks( def _update_dataloader( - dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None + dataloader: DataLoader, sampler: Sampler | Iterable, mode: RunningStage | None = None ) -> DataLoader: dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode) return _reinstantiate_wrapped_cls(dataloader, *dl_args, **dl_kwargs) @@ -135,10 +137,10 @@ def _update_dataloader( def _get_dataloader_init_args_and_kwargs( dataloader: DataLoader, - sampler: Union[Sampler, Iterable], - mode: Optional[RunningStage] = None, + sampler: Sampler | Iterable, + mode: RunningStage | None = None, disallow_batch_sampler: bool = False, -) -> Tuple[Tuple[Any], Dict[str, Any]]: +) -> tuple[tuple[Any], dict[str, Any]]: if not isinstance(dataloader, DataLoader): raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`") @@ -230,10 +232,10 @@ def _get_dataloader_init_args_and_kwargs( def _dataloader_init_kwargs_resolve_sampler( dataloader: DataLoader, - sampler: Union[Sampler, Iterable], - mode: Optional[RunningStage] = None, + sampler: Sampler | Iterable, + mode: RunningStage | None = None, disallow_batch_sampler: bool = False, -) -> Dict[str, Any]: +) -> dict[str, Any]: """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its re-instantiation. diff --git a/src/lightning/pytorch/utilities/grads.py b/src/lightning/pytorch/utilities/grads.py index c6f0d062df209..4acd2d0678054 100644 --- a/src/lightning/pytorch/utilities/grads.py +++ b/src/lightning/pytorch/utilities/grads.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Utilities to describe gradients.""" -from typing import Dict, Union +from __future__ import annotations import torch from torch.nn import Module -def grad_norm(module: Module, norm_type: Union[float, int, str], group_separator: str = "/") -> Dict[str, float]: +def grad_norm(module: Module, norm_type: float | int | str, group_separator: str = "/") -> dict[str, float]: """Compute each parameter's gradient's norm and their overall norm. The overall norm is computed over all gradients together, as if they diff --git a/src/lightning/pytorch/utilities/meta.py b/src/lightning/pytorch/utilities/meta.py index cd3999a690e70..2188195743f45 100644 --- a/src/lightning/pytorch/utilities/meta.py +++ b/src/lightning/pytorch/utilities/meta.py @@ -11,19 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
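The `grads.py` hunk above is likewise annotation-only; for reference, `grad_norm` is normally wired into a module hook (a sketch, assuming the public `lightning.pytorch.utilities.grad_norm` export and the 2.0 `on_before_optimizer_step(self, optimizer)` hook signature; the class is made up and its training methods are omitted):

# Sketch: logging per-parameter gradient norms plus their overall norm before each optimizer step.
import lightning.pytorch as pl
from lightning.pytorch.utilities import grad_norm


class LitWithGradLogging(pl.LightningModule):
    # training_step/configure_optimizers omitted for brevity
    def on_before_optimizer_step(self, optimizer):
        # Returns a dict with one entry per parameter plus the overall norm, keyed via group_separator.
        self.log_dict(grad_norm(self, norm_type=2, group_separator="/"))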
# See the License for the specific language governing permissions and # limitations under the License. -from typing import Mapping, Optional, Union +from __future__ import annotations + +from typing import Mapping from lightning_utilities.core.imports import module_available from torch import Tensor from torch.nn import Module, Parameter -def _is_deferred(module: Optional[Module]) -> bool: +def _is_deferred(module: Module | None) -> bool: if module is None or not module_available("torchdistx.fake"): return False from torchdistx.fake import is_fake - def any_fake(tensors: Mapping[str, Optional[Union[Tensor, Parameter]]]) -> bool: + def any_fake(tensors: Mapping[str, Tensor | Parameter | None]) -> bool: return any(is_fake(t) for t in tensors.values() if t is not None) is_deferred = any(_is_deferred(m) for m in module.children()) diff --git a/src/lightning/pytorch/utilities/migration/migration.py b/src/lightning/pytorch/utilities/migration/migration.py index 40803650e5c42..4c22c12c71dd9 100644 --- a/src/lightning/pytorch/utilities/migration/migration.py +++ b/src/lightning/pytorch/utilities/migration/migration.py @@ -28,8 +28,10 @@ cp model.ckpt model.ckpt.backup python -m lightning.pytorch.utilities.upgrade_checkpoint model.ckpt """ +from __future__ import annotations + import re -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch.callbacks.early_stopping import EarlyStopping @@ -39,7 +41,7 @@ _CHECKPOINT = Dict[str, Any] -def _migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]: +def _migration_index() -> dict[str, list[Callable[[_CHECKPOINT], _CHECKPOINT]]]: """Migration functions returned here will get executed in the order they are listed.""" return { "0.10.0": [_migrate_model_checkpoint_early_stopping], @@ -127,7 +129,7 @@ def _migrate_loop_batches_that_stepped(checkpoint: _CHECKPOINT) -> _CHECKPOINT: return checkpoint -def _get_fit_loop_initial_state_1_6_0() -> Dict: +def _get_fit_loop_initial_state_1_6_0() -> dict: return { "epoch_loop.batch_loop.manual_loop.optim_step_progress": { "current": {"completed": 0, "ready": 0}, diff --git a/src/lightning/pytorch/utilities/migration/utils.py b/src/lightning/pytorch/utilities/migration/utils.py index 56f018a9fe900..36b867445bb43 100644 --- a/src/lightning/pytorch/utilities/migration/utils.py +++ b/src/lightning/pytorch/utilities/migration/utils.py @@ -11,12 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging import os import sys import threading from types import ModuleType, TracebackType -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Dict from packaging.version import Version @@ -33,8 +35,8 @@ def migrate_checkpoint( - checkpoint: _CHECKPOINT, target_version: Optional[str] = None -) -> Tuple[_CHECKPOINT, Dict[str, List[str]]]: + checkpoint: _CHECKPOINT, target_version: str | None = None +) -> tuple[_CHECKPOINT, dict[str, list[str]]]: """Applies Lightning version migrations to a checkpoint dictionary. 
Args: @@ -86,7 +88,7 @@ class pl_legacy_patch: torch.load("path/to/legacy/checkpoint.ckpt") """ - def __enter__(self) -> "pl_legacy_patch": + def __enter__(self) -> pl_legacy_patch: _lock.acquire() # `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse` legacy_argparse_module = ModuleType("lightning.pytorch.utilities.argparse_utils") @@ -99,9 +101,9 @@ def __enter__(self) -> "pl_legacy_patch": def __exit__( self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - exc_traceback: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + exc_traceback: TracebackType | None, ) -> None: if hasattr(pl.utilities.argparse, "_gpus_arg_default"): delattr(pl.utilities.argparse, "_gpus_arg_default") @@ -109,7 +111,7 @@ def __exit__( _lock.release() -def _pl_migrate_checkpoint(checkpoint: _CHECKPOINT, checkpoint_path: Optional[_PATH] = None) -> _CHECKPOINT: +def _pl_migrate_checkpoint(checkpoint: _CHECKPOINT, checkpoint_path: _PATH | None = None) -> _CHECKPOINT: """Applies Lightning version migrations to a checkpoint dictionary and prints infos for the user. This function is used by the Lightning Trainer when resuming from a checkpoint. @@ -148,7 +150,7 @@ def _set_legacy_version(checkpoint: _CHECKPOINT, version: str) -> None: checkpoint.setdefault("legacy_pytorch-lightning_version", version) -def _should_upgrade(checkpoint: _CHECKPOINT, target: str, max_version: Optional[str] = None) -> bool: +def _should_upgrade(checkpoint: _CHECKPOINT, target: str, max_version: str | None = None) -> bool: """Returns whether a checkpoint qualifies for an upgrade when the version is lower than the given target.""" target_version = Version(target) is_lte_max_version = max_version is None or target_version <= Version(max_version) diff --git a/src/lightning/pytorch/utilities/model_helpers.py b/src/lightning/pytorch/utilities/model_helpers.py index d683e294b235b..d57e3cebc9e8b 100644 --- a/src/lightning/pytorch/utilities/model_helpers.py +++ b/src/lightning/pytorch/utilities/model_helpers.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Type +from __future__ import annotations + +from typing import Any from lightning_utilities.core.imports import RequirementCache from torch import nn @@ -19,7 +21,7 @@ import lightning.pytorch as pl -def is_overridden(method_name: str, instance: Optional[object] = None, parent: Optional[Type[object]] = None) -> bool: +def is_overridden(method_name: str, instance: object | None = None, parent: type[object] | None = None) -> bool: if instance is None: # if `self.lightning_module` was passed as instance, it can be `None` return False diff --git a/src/lightning/pytorch/utilities/model_summary/model_summary.py b/src/lightning/pytorch/utilities/model_summary/model_summary.py index 4476ac5b25fab..daffa6c510de5 100644 --- a/src/lightning/pytorch/utilities/model_summary/model_summary.py +++ b/src/lightning/pytorch/utilities/model_summary/model_summary.py @@ -13,11 +13,13 @@ # limitations under the License. 
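For the migration utilities above, a usage sketch (illustrative, not part of this patch; the `lightning.pytorch.utilities.migration` exports are assumed, and the checkpoint path placeholder mirrors the docstring above):

# Sketch: upgrading an old checkpoint dictionary in memory.
import torch

from lightning.pytorch.utilities.migration import migrate_checkpoint, pl_legacy_patch

with pl_legacy_patch():  # restores legacy import paths that very old checkpoints may reference
    checkpoint = torch.load("path/to/legacy/checkpoint.ckpt")

# Applies every migration the checkpoint's version predates and returns a record of what ran.
checkpoint, applied_migrations = migrate_checkpoint(checkpoint)
torch.save(checkpoint, "path/to/legacy/checkpoint.ckpt")

The `python -m lightning.pytorch.utilities.upgrade_checkpoint` entry point referenced in the migration module docstring wraps the same logic for checkpoint files on disk.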
"""Utilities related to model weights summary.""" +from __future__ import annotations + import contextlib import logging import math from collections import OrderedDict -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any import torch import torch.nn as nn @@ -71,13 +73,13 @@ def __init__(self, module: nn.Module) -> None: super().__init__() self._module = module self._hook_handle = self._register_hook() - self._in_size: Optional[Union[str, List]] = None - self._out_size: Optional[Union[str, List]] = None + self._in_size: str | list | None = None + self._out_size: str | list | None = None def __del__(self) -> None: self.detach_hook() - def _register_hook(self) -> Optional[RemovableHandle]: + def _register_hook(self) -> RemovableHandle | None: """Registers a hook on the module that computes the input- and output size(s) on the first forward pass. If the hook is called, it will remove itself from the from the module, meaning that recursive models will only record their input- and output shapes once. Registering hooks on :class:`~torch.jit.ScriptModule` is not @@ -121,11 +123,11 @@ def detach_hook(self) -> None: self._hook_handle.remove() @property - def in_size(self) -> Union[str, List]: + def in_size(self) -> str | list: return self._in_size or UNKNOWN_SIZE @property - def out_size(self) -> Union[str, List]: + def out_size(self) -> str | list: return self._out_size or UNKNOWN_SIZE @property @@ -196,7 +198,7 @@ class ModelSummary: 0.530 Total estimated model params size (MB) """ - def __init__(self, model: "pl.LightningModule", max_depth: int = 1) -> None: + def __init__(self, model: pl.LightningModule, max_depth: int = 1) -> None: self._model = model if not isinstance(max_depth, int) or max_depth < -1: @@ -211,8 +213,8 @@ def __init__(self, model: "pl.LightningModule", max_depth: int = 1) -> None: self._precision_megabytes = (precision / 8.0) * 1e-6 @property - def named_modules(self) -> List[Tuple[str, nn.Module]]: - mods: List[Tuple[str, nn.Module]] + def named_modules(self) -> list[tuple[str, nn.Module]]: + mods: list[tuple[str, nn.Module]] if self._max_depth == 0: mods = [] elif self._max_depth == 1: @@ -224,23 +226,23 @@ def named_modules(self) -> List[Tuple[str, nn.Module]]: return mods @property - def layer_names(self) -> List[str]: + def layer_names(self) -> list[str]: return list(self._layer_summary.keys()) @property - def layer_types(self) -> List[str]: + def layer_types(self) -> list[str]: return [layer.layer_type for layer in self._layer_summary.values()] @property - def in_sizes(self) -> List: + def in_sizes(self) -> list: return [layer.in_size for layer in self._layer_summary.values()] @property - def out_sizes(self) -> List: + def out_sizes(self) -> list: return [layer.out_size for layer in self._layer_summary.values()] @property - def param_nums(self) -> List[int]: + def param_nums(self) -> list[int]: return [layer.num_parameters for layer in self._layer_summary.values()] @property @@ -261,7 +263,7 @@ def total_layer_params(self) -> int: def model_size(self) -> float: return self.total_parameters * self._precision_megabytes - def summarize(self) -> Dict[str, LayerSummary]: + def summarize(self) -> dict[str, LayerSummary]: summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules) if self._model.example_input_array is not None: self._forward_example_input() @@ -299,7 +301,7 @@ def _forward_example_input(self) -> None: model(input_) model.train(mode) # restore mode of module - def _get_summary_data(self) -> 
List[Tuple[str, List[str]]]: + def _get_summary_data(self) -> list[tuple[str, list[str]]]: """Makes a summary listing with: Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size @@ -320,7 +322,7 @@ def _get_summary_data(self) -> List[Tuple[str, List[str]]]: return arrays - def _add_leftover_params_to_summary(self, arrays: List[Tuple[str, List[str]]], total_leftover_params: int) -> None: + def _add_leftover_params_to_summary(self, arrays: list[tuple[str, list[str]]], total_leftover_params: int) -> None: """Add summary of params not associated with module or layer to model summary.""" layer_summaries = dict(arrays) layer_summaries[" "].append(" ") @@ -345,7 +347,7 @@ def __repr__(self) -> str: return str(self) -def parse_batch_shape(batch: Any) -> Union[str, List]: +def parse_batch_shape(batch: Any) -> str | list: if hasattr(batch, "shape"): return list(batch.shape) @@ -359,7 +361,7 @@ def _format_summary_table( total_parameters: int, trainable_parameters: int, model_size: float, - *cols: Tuple[str, List[str]], + *cols: tuple[str, list[str]], ) -> str: """Takes in a number of arrays, each specifying a column in the summary table, and combines them all into one big string defining the summary table that are nicely formatted.""" @@ -452,7 +454,7 @@ def _is_lazy_weight_tensor(p: Tensor) -> bool: return False -def summarize(lightning_module: "pl.LightningModule", max_depth: int = 1) -> ModelSummary: +def summarize(lightning_module: pl.LightningModule, max_depth: int = 1) -> ModelSummary: """Summarize the LightningModule specified by `lightning_module`. Args: diff --git a/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py b/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py index fc84fbeb54b89..6cc225b692d0a 100644 --- a/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py +++ b/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py @@ -13,8 +13,9 @@ # limitations under the License. 
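The `ModelSummary`/`summarize` API above is used like this (a sketch, not from this patch; the autoencoder layers and sizes are made up, the `lightning.pytorch.utilities.model_summary` export is assumed, and `max_depth`, `param_nums`, and the `example_input_array` handling appear in the code above):

# Sketch: summarizing a module, including In/Out sizes via example_input_array.
import torch
from torch import nn

import lightning.pytorch as pl
from lightning.pytorch.utilities.model_summary import summarize


class LitAutoEncoder(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
        # With this set, summarize() runs a forward pass and fills the in/out size columns.
        self.example_input_array = torch.zeros(1, 28 * 28)

    def forward(self, x):
        return self.decoder(self.encoder(x))


summary = summarize(LitAutoEncoder(), max_depth=1)  # equivalent to ModelSummary(model, max_depth=1)
print(summary)             # the table assembled by _format_summary_table
print(summary.param_nums)  # per-layer parameter counts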
"""Utilities that can be used with Deepspeed.""" +from __future__ import annotations + from collections import OrderedDict -from typing import Dict, List, Tuple import torch from lightning_utilities.core.imports import RequirementCache @@ -51,7 +52,7 @@ def partitioned_size(p: Parameter) -> int: class DeepSpeedSummary(ModelSummary): - def summarize(self) -> Dict[str, DeepSpeedLayerSummary]: # type: ignore[override] + def summarize(self) -> dict[str, DeepSpeedLayerSummary]: # type: ignore[override] summary = OrderedDict((name, DeepSpeedLayerSummary(module)) for name, module in self.named_modules) if self._model.example_input_array is not None: self._forward_example_input() @@ -78,10 +79,10 @@ def trainable_parameters(self) -> int: ) @property - def parameters_per_layer(self) -> List[int]: + def parameters_per_layer(self) -> list[int]: return [layer.average_shard_parameters for layer in self._layer_summary.values()] - def _get_summary_data(self) -> List[Tuple[str, List[str]]]: + def _get_summary_data(self) -> list[tuple[str, list[str]]]: """Makes a summary listing with: Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size @@ -103,7 +104,7 @@ def _get_summary_data(self) -> List[Tuple[str, List[str]]]: return arrays - def _add_leftover_params_to_summary(self, arrays: List[Tuple[str, List[str]]], total_leftover_params: int) -> None: + def _add_leftover_params_to_summary(self, arrays: list[tuple[str, list[str]]], total_leftover_params: int) -> None: """Add summary of params not associated with module or layer to model summary.""" super()._add_leftover_params_to_summary(arrays, total_leftover_params) layer_summaries = dict(arrays) diff --git a/src/lightning/pytorch/utilities/parameter_tying.py b/src/lightning/pytorch/utilities/parameter_tying.py index 9b12b456db451..2fd08112c3f68 100644 --- a/src/lightning/pytorch/utilities/parameter_tying.py +++ b/src/lightning/pytorch/utilities/parameter_tying.py @@ -16,17 +16,17 @@ Reference: https://github.com/pytorch/fairseq/blob/1f7ef9ed1e1061f8c7f88f8b94c7186834398690/fairseq/trainer.py#L110-L118 """ -from typing import Dict, List, Optional +from __future__ import annotations from torch import nn -def find_shared_parameters(module: nn.Module) -> List[str]: +def find_shared_parameters(module: nn.Module) -> list[str]: """Returns a list of names of shared parameters set in the module.""" return _find_shared_parameters(module) -def _find_shared_parameters(module: nn.Module, tied_parameters: Optional[Dict] = None, prefix: str = "") -> List[str]: +def _find_shared_parameters(module: nn.Module, tied_parameters: dict | None = None, prefix: str = "") -> list[str]: if tied_parameters is None: tied_parameters = {} for name, param in module._parameters.items(): diff --git a/src/lightning/pytorch/utilities/parsing.py b/src/lightning/pytorch/utilities/parsing.py index 9958bc9f249a7..9f78ed3019f4f 100644 --- a/src/lightning/pytorch/utilities/parsing.py +++ b/src/lightning/pytorch/utilities/parsing.py @@ -13,12 +13,14 @@ # limitations under the License. 
"""Utilities used for parameter parsing.""" +from __future__ import annotations + import copy import inspect import pickle import types from dataclasses import fields, is_dataclass -from typing import Any, Dict, List, Literal, MutableMapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, Dict, Literal, MutableMapping, Sequence from torch import nn @@ -44,7 +46,7 @@ def clean_namespace(hparams: MutableMapping) -> None: del hparams[k] -def parse_class_init_keys(cls: Type) -> Tuple[str, Optional[str], Optional[str]]: +def parse_class_init_keys(cls: type) -> tuple[str, str | None, str | None]: """Parse key words for standard ``self``, ``*args`` and ``**kwargs``. Examples: @@ -63,9 +65,9 @@ def parse_class_init_keys(cls: Type) -> Tuple[str, Optional[str], Optional[str]] n_self = init_params[0].name def _get_first_if_any( - params: List[inspect.Parameter], + params: list[inspect.Parameter], param_type: Literal[inspect._ParameterKind.VAR_POSITIONAL, inspect._ParameterKind.VAR_KEYWORD], - ) -> Optional[str]: + ) -> str | None: for p in params: if p.kind == param_type: return p.name @@ -77,13 +79,13 @@ def _get_first_if_any( return n_self, n_args, n_kwargs -def get_init_args(frame: types.FrameType) -> Dict[str, Any]: # pragma: no-cover +def get_init_args(frame: types.FrameType) -> dict[str, Any]: # pragma: no-cover """For backwards compatibility: #16369.""" _, local_args = _get_init_args(frame) return local_args -def _get_init_args(frame: types.FrameType) -> Tuple[Optional[Any], Dict[str, Any]]: +def _get_init_args(frame: types.FrameType) -> tuple[Any | None, dict[str, Any]]: _, _, _, local_vars = inspect.getargvalues(frame) if "__class__" not in local_vars: return None, {} @@ -104,10 +106,10 @@ def _get_init_args(frame: types.FrameType) -> Tuple[Optional[Any], Dict[str, Any def collect_init_args( frame: types.FrameType, - path_args: List[Dict[str, Any]], + path_args: list[dict[str, Any]], inside: bool = False, - classes: Tuple[Type, ...] = (), -) -> List[Dict[str, Any]]: + classes: tuple[type, ...] = (), +) -> list[dict[str, Any]]: """Recursively collects the arguments passed to the child constructors in the inheritance tree. Args: @@ -137,7 +139,7 @@ def collect_init_args( def save_hyperparameters( - obj: Any, *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None + obj: Any, *args: Any, ignore: Sequence[str] | str | None = None, frame: types.FrameType | None = None ) -> None: """See :meth:`~lightning.pytorch.LightningModule.save_hyperparameters`""" @@ -218,7 +220,7 @@ class AttributeDict(Dict): "new_key": 42 """ - def __getattr__(self, key: str) -> Optional[Any]: + def __getattr__(self, key: str) -> Any | None: try: return self[key] except KeyError as exp: @@ -236,13 +238,13 @@ def __repr__(self) -> str: return "\n".join(rows) -def _lightning_get_all_attr_holders(model: "pl.LightningModule", attribute: str) -> List[Any]: +def _lightning_get_all_attr_holders(model: pl.LightningModule, attribute: str) -> list[Any]: """Special attribute finding for Lightning. Gets all of the objects or dicts that holds attribute. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. 
""" - holders: List[Any] = [] + holders: list[Any] = [] # Check if attribute in model if hasattr(model, attribute): @@ -264,7 +266,7 @@ def _lightning_get_all_attr_holders(model: "pl.LightningModule", attribute: str) return holders -def _lightning_get_first_attr_holder(model: "pl.LightningModule", attribute: str) -> Optional[Any]: +def _lightning_get_first_attr_holder(model: pl.LightningModule, attribute: str) -> Any | None: """Special attribute finding for Lightning. Gets the object or dict that holds attribute, or None. Checks for attribute in model namespace, the old hparams @@ -277,7 +279,7 @@ def _lightning_get_first_attr_holder(model: "pl.LightningModule", attribute: str return holders[-1] -def lightning_hasattr(model: "pl.LightningModule", attribute: str) -> bool: +def lightning_hasattr(model: pl.LightningModule, attribute: str) -> bool: """Special hasattr for Lightning. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. @@ -285,7 +287,7 @@ def lightning_hasattr(model: "pl.LightningModule", attribute: str) -> bool: return _lightning_get_first_attr_holder(model, attribute) is not None -def lightning_getattr(model: "pl.LightningModule", attribute: str) -> Optional[Any]: +def lightning_getattr(model: pl.LightningModule, attribute: str) -> Any | None: """Special getattr for Lightning. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. @@ -306,7 +308,7 @@ def lightning_getattr(model: "pl.LightningModule", attribute: str) -> Optional[A return getattr(holder, attribute) -def lightning_setattr(model: "pl.LightningModule", attribute: str, value: Any) -> None: +def lightning_setattr(model: pl.LightningModule, attribute: str, value: Any) -> None: """Special setattr for Lightning. Checks for attribute in model namespace and the old hparams namespace/dict. Will also set the attribute on datamodule, if it exists. diff --git a/src/lightning/pytorch/utilities/signature_utils.py b/src/lightning/pytorch/utilities/signature_utils.py index 0f41c5948fb46..be127e2cdd771 100644 --- a/src/lightning/pytorch/utilities/signature_utils.py +++ b/src/lightning/pytorch/utilities/signature_utils.py @@ -11,12 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import inspect -from typing import Callable, Optional +from typing import Callable def is_param_in_hook_signature( - hook_fx: Callable, param: str, explicit: bool = False, min_args: Optional[int] = None + hook_fx: Callable, param: str, explicit: bool = False, min_args: int | None = None ) -> bool: """ Args: diff --git a/src/lightning/pytorch/utilities/testing/_runif.py b/src/lightning/pytorch/utilities/testing/_runif.py index 732bc26cf5e8e..cab1dc76f6ad7 100644 --- a/src/lightning/pytorch/utilities/testing/_runif.py +++ b/src/lightning/pytorch/utilities/testing/_runif.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Dict, List, Optional, Tuple +from __future__ import annotations from lightning_utilities.core.imports import RequirementCache @@ -28,12 +28,12 @@ def _runif_reasons( *, min_cuda_gpus: int = 0, - min_torch: Optional[str] = None, - max_torch: Optional[str] = None, - min_python: Optional[str] = None, + min_torch: str | None = None, + max_torch: str | None = None, + min_python: str | None = None, bf16_cuda: bool = False, tpu: bool = False, - mps: Optional[bool] = None, + mps: bool | None = None, skip_windows: bool = False, standalone: bool = False, deepspeed: bool = False, @@ -43,7 +43,7 @@ def _runif_reasons( psutil: bool = False, sklearn: bool = False, onnx: bool = False, -) -> Tuple[List[str], Dict[str, bool]]: +) -> tuple[list[str], dict[str, bool]]: """Construct reasons for pytest skipif. Args: diff --git a/src/lightning/pytorch/utilities/types.py b/src/lightning/pytorch/utilities/types.py index f8c3d5777cd40..8ba4fc7fe1246 100644 --- a/src/lightning/pytorch/utilities/types.py +++ b/src/lightning/pytorch/utilities/types.py @@ -16,9 +16,11 @@ - Do not include any `_TYPE` suffix - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`) """ +from __future__ import annotations + from contextlib import contextmanager from dataclasses import dataclass -from typing import Any, Generator, List, Mapping, Optional, Protocol, runtime_checkable, Type, Union +from typing import Any, Generator, List, Mapping, Protocol, runtime_checkable, Type, Union import torch from torch import Tensor @@ -42,11 +44,11 @@ class DistributedDataParallel(Protocol): def __init__( self, module: torch.nn.Module, - device_ids: Optional[List[Union[int, torch.device]]] = None, - output_device: Optional[Union[int, torch.device]] = None, + device_ids: list[int | torch.device] | None = None, + output_device: int | torch.device | None = None, dim: int = 0, broadcast_buffers: bool = True, - process_group: Optional[ProcessGroup] = None, + process_group: ProcessGroup | None = None, bucket_cap_mb: int = 25, find_unused_parameters: bool = False, check_reduction: bool = False, @@ -69,9 +71,9 @@ def no_sync(self) -> Generator: @dataclass class LRSchedulerConfig: - scheduler: Union[_TORCH_LRSCHEDULER, ReduceLROnPlateau] + scheduler: _TORCH_LRSCHEDULER | ReduceLROnPlateau # no custom name - name: Optional[str] = None + name: str | None = None # after epoch is over interval: str = "epoch" # every epoch/batch @@ -79,6 +81,6 @@ class LRSchedulerConfig: # most often not ReduceLROnPlateau scheduler reduce_on_plateau: bool = False # value to monitor for ReduceLROnPlateau - monitor: Optional[str] = None + monitor: str | None = None # enforce that the monitor exists for ReduceLROnPlateau strict: bool = True diff --git a/src/lightning/pytorch/utilities/upgrade_checkpoint.py b/src/lightning/pytorch/utilities/upgrade_checkpoint.py index 87ad6031f9f24..e924c72cc7586 100644 --- a/src/lightning/pytorch/utilities/upgrade_checkpoint.py +++ b/src/lightning/pytorch/utilities/upgrade_checkpoint.py @@ -11,12 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
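A note on the pattern running through these files (`Dict` kept in migration.py and parsing.py, `List`/`Tuple` kept in combined_loader.py, `List`/`Type`/`Union` kept in types.py): `from __future__ import annotations` (PEP 563) turns annotations into lazily evaluated strings, so the PEP 585/604 spellings are safe in signatures even on Python 3.7/3.8, but expressions evaluated at runtime, such as type aliases and subscripted base classes, still need the `typing` forms until 3.9/3.10. A minimal sketch of the distinction (the names are made up):

from __future__ import annotations  # annotations are no longer evaluated at import time

from typing import Dict, List


def lookup(table: dict[str, int], key: str) -> int | None:
    # Fine on Python 3.8: these annotations stay as strings and are never evaluated.
    return table.get(key)


# Runtime expressions are still evaluated eagerly, so the typing aliases stay:
_CHECKPOINT = Dict[str, int]  # dict[str, int] here raises TypeError before Python 3.9


class IntStack(List[int]):  # a subscripted base class is evaluated at class-creation time
    pass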
+from __future__ import annotations
+
 import glob
 import logging
 from argparse import ArgumentParser, Namespace
 from pathlib import Path
 from shutil import copyfile
-from typing import List
 
 import torch
 from tqdm import tqdm
@@ -29,7 +30,7 @@ def _upgrade(args: Namespace) -> None:
     path = Path(args.path).absolute()
     extension: str = args.extension if args.extension.startswith(".") else f".{args.extension}"
-    files: List[Path] = []
+    files: list[Path] = []
     if not path.exists():
         _log.error(