diff --git a/.actions/assistant.py b/.actions/assistant.py index 15a20e63c61dc..dc1fa055d5e62 100644 --- a/.actions/assistant.py +++ b/.actions/assistant.py @@ -18,10 +18,11 @@ import shutil import tempfile import urllib.request +from collections.abc import Iterable, Iterator, Sequence from itertools import chain from os.path import dirname, isfile from pathlib import Path -from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple +from typing import Any, Optional from packaging.requirements import Requirement from packaging.version import Version @@ -127,7 +128,7 @@ def _parse_requirements(lines: Iterable[str]) -> Iterator[_RequirementWithCommen pip_argument = None -def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str = "all") -> List[str]: +def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str = "all") -> list[str]: """Loading requirements from a file. >>> path_req = os.path.join(_PROJECT_ROOT, "requirements") @@ -222,7 +223,7 @@ def _load_aggregate_requirements(req_dir: str = "requirements", freeze_requireme fp.writelines([ln + os.linesep for ln in requires] + [os.linesep]) -def _retrieve_files(directory: str, *ext: str) -> List[str]: +def _retrieve_files(directory: str, *ext: str) -> list[str]: all_files = [] for root, _, files in os.walk(directory): for fname in files: @@ -232,7 +233,7 @@ def _retrieve_files(directory: str, *ext: str) -> List[str]: return all_files -def _replace_imports(lines: List[str], mapping: List[Tuple[str, str]], lightning_by: str = "") -> List[str]: +def _replace_imports(lines: list[str], mapping: list[tuple[str, str]], lightning_by: str = "") -> list[str]: """Replace imports of standalone package to lightning. >>> lns = [ @@ -320,7 +321,7 @@ def copy_replace_imports( fo.writelines(lines) -def create_mirror_package(source_dir: str, package_mapping: Dict[str, str]) -> None: +def create_mirror_package(source_dir: str, package_mapping: dict[str, str]) -> None: """Create a mirror package with adjusted imports.""" # replace imports and copy the code mapping = package_mapping.copy() diff --git a/.github/workflows/_legacy-checkpoints.yml b/.github/workflows/_legacy-checkpoints.yml index 15d226eed7fec..b6af39d3313ab 100644 --- a/.github/workflows/_legacy-checkpoints.yml +++ b/.github/workflows/_legacy-checkpoints.yml @@ -60,7 +60,7 @@ jobs: - uses: actions/setup-python@v5 with: # Python version here needs to be supported by all PL versions listed in back-compatible-versions.txt. - python-version: 3.8 + python-version: "3.9" - name: Install PL from source env: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 24fc40566b152..c5e65de1d7eb7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -74,7 +74,7 @@ repos: hooks: # try to fix what is possible - id: ruff - args: ["--fix"] + args: ["--fix", "--unsafe-fixes"] # perform formatting updates - id: ruff-format # validate if all is fine with preview mode diff --git a/docs/source-pytorch/accelerators/tpu_faq.rst b/docs/source-pytorch/accelerators/tpu_faq.rst index 109449ef2cc9a..766f9dcacb32e 100644 --- a/docs/source-pytorch/accelerators/tpu_faq.rst +++ b/docs/source-pytorch/accelerators/tpu_faq.rst @@ -40,9 +40,9 @@ Unsupported datatype transfer to TPUs? .. 
code-block:: - File "/usr/local/lib/python3.8/dist-packages/torch_xla/utils/utils.py", line 205, in _for_each_instance_rewrite + File "/usr/local/lib/python3.9/dist-packages/torch_xla/utils/utils.py", line 205, in _for_each_instance_rewrite v = _for_each_instance_rewrite(result.__dict__[k], select_fn, fn, rwmap) - File "/usr/local/lib/python3.8/dist-packages/torch_xla/utils/utils.py", line 206, in _for_each_instance_rewrite + File "/usr/local/lib/python3.9/dist-packages/torch_xla/utils/utils.py", line 206, in _for_each_instance_rewrite result.__dict__[k] = v TypeError: 'mappingproxy' object does not support item assignment diff --git a/docs/source-pytorch/advanced/post_training_quantization.rst b/docs/source-pytorch/advanced/post_training_quantization.rst index 504a57a4191cf..f925c6ccd47b4 100644 --- a/docs/source-pytorch/advanced/post_training_quantization.rst +++ b/docs/source-pytorch/advanced/post_training_quantization.rst @@ -33,7 +33,7 @@ Installation Prerequisites ============= -Python version: 3.8, 3.9, 3.10 +Python version: 3.9, 3.10 Install IntelĀ® Neural Compressor ================================ diff --git a/examples/fabric/build_your_own_trainer/trainer.py b/examples/fabric/build_your_own_trainer/trainer.py index 7af01ede054a8..f4f31c114f084 100644 --- a/examples/fabric/build_your_own_trainer/trainer.py +++ b/examples/fabric/build_your_own_trainer/trainer.py @@ -1,7 +1,7 @@ import os -from collections.abc import Mapping +from collections.abc import Iterable, Mapping from functools import partial -from typing import Any, Iterable, List, Literal, Optional, Tuple, Union, cast +from typing import Any, Literal, Optional, Union, cast import lightning as L import torch @@ -19,11 +19,11 @@ def __init__( self, accelerator: Union[str, Accelerator] = "auto", strategy: Union[str, Strategy] = "auto", - devices: Union[List[int], str, int] = "auto", + devices: Union[list[int], str, int] = "auto", precision: Union[str, int] = "32-true", plugins: Optional[Union[str, Any]] = None, - callbacks: Optional[Union[List[Any], Any]] = None, - loggers: Optional[Union[Logger, List[Logger]]] = None, + callbacks: Optional[Union[list[Any], Any]] = None, + loggers: Optional[Union[Logger, list[Logger]]] = None, max_epochs: Optional[int] = 1000, max_steps: Optional[int] = None, grad_accum_steps: int = 1, @@ -465,7 +465,7 @@ def get_latest_checkpoint(checkpoint_dir: str) -> Optional[str]: def _parse_optimizers_schedulers( self, configure_optim_output - ) -> Tuple[ + ) -> tuple[ Optional[L.fabric.utilities.types.Optimizable], Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]], ]: diff --git a/examples/fabric/reinforcement_learning/rl/agent.py b/examples/fabric/reinforcement_learning/rl/agent.py index fcf5bd0b9371d..16a4cd6d86c73 100644 --- a/examples/fabric/reinforcement_learning/rl/agent.py +++ b/examples/fabric/reinforcement_learning/rl/agent.py @@ -1,5 +1,4 @@ import math -from typing import Dict, Tuple import gymnasium as gym import torch @@ -43,7 +42,7 @@ def __init__(self, envs: gym.vector.SyncVectorEnv, act_fun: str = "relu", ortho_ layer_init(torch.nn.Linear(64, envs.single_action_space.n), std=0.01, ortho_init=ortho_init), ) - def get_action(self, x: Tensor, action: Tensor = None) -> Tuple[Tensor, Tensor, Tensor]: + def get_action(self, x: Tensor, action: Tensor = None) -> tuple[Tensor, Tensor, Tensor]: logits = self.actor(x) distribution = Categorical(logits=logits) if action is None: @@ -58,12 +57,12 @@ def get_greedy_action(self, x: Tensor) -> Tensor: def 
get_value(self, x: Tensor) -> Tensor: return self.critic(x) - def get_action_and_value(self, x: Tensor, action: Tensor = None) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + def get_action_and_value(self, x: Tensor, action: Tensor = None) -> tuple[Tensor, Tensor, Tensor, Tensor]: action, log_prob, entropy = self.get_action(x, action) value = self.get_value(x) return action, log_prob, entropy, value - def forward(self, x: Tensor, action: Tensor = None) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + def forward(self, x: Tensor, action: Tensor = None) -> tuple[Tensor, Tensor, Tensor, Tensor]: return self.get_action_and_value(x, action) @torch.no_grad() @@ -77,7 +76,7 @@ def estimate_returns_and_advantages( num_steps: int, gamma: float, gae_lambda: float, - ) -> Tuple[Tensor, Tensor]: + ) -> tuple[Tensor, Tensor]: next_value = self.get_value(next_obs).reshape(1, -1) advantages = torch.zeros_like(rewards) lastgaelam = 0 @@ -143,7 +142,7 @@ def __init__( self.avg_value_loss = MeanMetric(**torchmetrics_kwargs) self.avg_ent_loss = MeanMetric(**torchmetrics_kwargs) - def get_action(self, x: Tensor, action: Tensor = None) -> Tuple[Tensor, Tensor, Tensor]: + def get_action(self, x: Tensor, action: Tensor = None) -> tuple[Tensor, Tensor, Tensor]: logits = self.actor(x) distribution = Categorical(logits=logits) if action is None: @@ -158,12 +157,12 @@ def get_greedy_action(self, x: Tensor) -> Tensor: def get_value(self, x: Tensor) -> Tensor: return self.critic(x) - def get_action_and_value(self, x: Tensor, action: Tensor = None) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + def get_action_and_value(self, x: Tensor, action: Tensor = None) -> tuple[Tensor, Tensor, Tensor, Tensor]: action, log_prob, entropy = self.get_action(x, action) value = self.get_value(x) return action, log_prob, entropy, value - def forward(self, x: Tensor, action: Tensor = None) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + def forward(self, x: Tensor, action: Tensor = None) -> tuple[Tensor, Tensor, Tensor, Tensor]: return self.get_action_and_value(x, action) @torch.no_grad() @@ -177,7 +176,7 @@ def estimate_returns_and_advantages( num_steps: int, gamma: float, gae_lambda: float, - ) -> Tuple[Tensor, Tensor]: + ) -> tuple[Tensor, Tensor]: next_value = self.get_value(next_obs).reshape(1, -1) advantages = torch.zeros_like(rewards) lastgaelam = 0 @@ -193,7 +192,7 @@ def estimate_returns_and_advantages( returns = advantages + values return returns, advantages - def training_step(self, batch: Dict[str, Tensor]): + def training_step(self, batch: dict[str, Tensor]): # Get actions and values given the current observations _, newlogprob, entropy, newvalue = self(batch["obs"], batch["actions"].long()) logratio = newlogprob - batch["logprobs"] diff --git a/examples/fabric/reinforcement_learning/train_fabric.py b/examples/fabric/reinforcement_learning/train_fabric.py index 068359602a096..4df52d7cd0455 100644 --- a/examples/fabric/reinforcement_learning/train_fabric.py +++ b/examples/fabric/reinforcement_learning/train_fabric.py @@ -21,7 +21,6 @@ import os import time from datetime import datetime -from typing import Dict import gymnasium as gym import torch @@ -38,7 +37,7 @@ def train( fabric: Fabric, agent: PPOLightningAgent, optimizer: torch.optim.Optimizer, - data: Dict[str, Tensor], + data: dict[str, Tensor], global_step: int, args: argparse.Namespace, ): diff --git a/examples/fabric/reinforcement_learning/train_torch.py b/examples/fabric/reinforcement_learning/train_torch.py index cf74e03f5202e..dad16ad10fb0b 100644 --- 
a/examples/fabric/reinforcement_learning/train_torch.py +++ b/examples/fabric/reinforcement_learning/train_torch.py @@ -22,7 +22,6 @@ import random import time from datetime import datetime -from typing import Dict import gymnasium as gym import torch @@ -41,7 +40,7 @@ def train( agent: PPOAgent, optimizer: torch.optim.Optimizer, - data: Dict[str, Tensor], + data: dict[str, Tensor], logger: SummaryWriter, global_step: int, args: argparse.Namespace, diff --git a/examples/fabric/tensor_parallel/model.py b/examples/fabric/tensor_parallel/model.py index 3c9e7de472b90..71f2634867e9b 100644 --- a/examples/fabric/tensor_parallel/model.py +++ b/examples/fabric/tensor_parallel/model.py @@ -9,7 +9,7 @@ from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional import torch import torch.nn.functional as F @@ -87,7 +87,7 @@ def apply_rotary_emb( xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided diff --git a/examples/pytorch/basics/autoencoder.py b/examples/pytorch/basics/autoencoder.py index 260cce4a548b1..d6f594b12f57b 100644 --- a/examples/pytorch/basics/autoencoder.py +++ b/examples/pytorch/basics/autoencoder.py @@ -18,7 +18,7 @@ """ from os import path -from typing import Optional, Tuple +from typing import Optional import torch import torch.nn.functional as F @@ -45,7 +45,7 @@ def __init__( nrow: int = 8, padding: int = 2, normalize: bool = True, - value_range: Optional[Tuple[int, int]] = None, + value_range: Optional[tuple[int, int]] = None, scale_each: bool = False, pad_value: int = 0, ) -> None: diff --git a/examples/pytorch/domain_templates/reinforce_learn_Qnet.py b/examples/pytorch/domain_templates/reinforce_learn_Qnet.py index 497cb658c275f..b3bfaaea93e7f 100644 --- a/examples/pytorch/domain_templates/reinforce_learn_Qnet.py +++ b/examples/pytorch/domain_templates/reinforce_learn_Qnet.py @@ -35,7 +35,7 @@ import argparse import random from collections import OrderedDict, deque, namedtuple -from typing import Iterator, List, Tuple +from collections.abc import Iterator import gym import torch @@ -102,7 +102,7 @@ def append(self, experience: Experience) -> None: """ self.buffer.append(experience) - def sample(self, batch_size: int) -> Tuple: + def sample(self, batch_size: int) -> tuple: indices = random.sample(range(len(self.buffer)), batch_size) states, actions, rewards, dones, next_states = zip(*(self.buffer[idx] for idx in indices)) @@ -190,7 +190,7 @@ def get_action(self, net: nn.Module, epsilon: float, device: str) -> int: return action @torch.no_grad() - def play_step(self, net: nn.Module, epsilon: float = 0.0, device: str = "cpu") -> Tuple[float, bool]: + def play_step(self, net: nn.Module, epsilon: float = 0.0, device: str = "cpu") -> tuple[float, bool]: """Carries out a single interaction step between the agent and the environment. Args: @@ -295,7 +295,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ return self.net(x) - def dqn_mse_loss(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + def dqn_mse_loss(self, batch: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: """Calculates the mse loss using a mini batch from the replay buffer. 
Args: @@ -318,7 +318,7 @@ def dqn_mse_loss(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor return nn.MSELoss()(state_action_values, expected_state_action_values) - def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> OrderedDict: + def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], nb_batch) -> OrderedDict: """Carries out a single step through the environment to update the replay buffer. Then calculates loss based on the minibatch received. @@ -356,7 +356,7 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> O return OrderedDict({"loss": loss, "log": log, "progress_bar": log}) - def configure_optimizers(self) -> List[Optimizer]: + def configure_optimizers(self) -> list[Optimizer]: """Initialize Adam optimizer.""" optimizer = optim.Adam(self.net.parameters(), lr=self.lr) return [optimizer] diff --git a/examples/pytorch/domain_templates/reinforce_learn_ppo.py b/examples/pytorch/domain_templates/reinforce_learn_ppo.py index bc3f8c1b9b193..1fb083894c284 100644 --- a/examples/pytorch/domain_templates/reinforce_learn_ppo.py +++ b/examples/pytorch/domain_templates/reinforce_learn_ppo.py @@ -30,7 +30,8 @@ """ import argparse -from typing import Callable, Iterator, List, Tuple +from collections.abc import Iterator +from typing import Callable import gym import torch @@ -41,7 +42,7 @@ from torch.utils.data import DataLoader, IterableDataset -def create_mlp(input_shape: Tuple[int], n_actions: int, hidden_size: int = 128): +def create_mlp(input_shape: tuple[int], n_actions: int, hidden_size: int = 128): """Simple Multi-Layer Perceptron network.""" return nn.Sequential( nn.Linear(input_shape[0], hidden_size), @@ -227,7 +228,7 @@ def __init__( self.state = torch.FloatTensor(self.env.reset()) - def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Passes in a state x through the network and returns the policy and a sampled action. Args: @@ -242,7 +243,7 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Te return pi, action, value - def discount_rewards(self, rewards: List[float], discount: float) -> List[float]: + def discount_rewards(self, rewards: list[float], discount: float) -> list[float]: """Calculate the discounted rewards of all rewards in list. Args: @@ -263,7 +264,7 @@ def discount_rewards(self, rewards: List[float], discount: float) -> List[float] return list(reversed(cumul_reward)) - def calc_advantage(self, rewards: List[float], values: List[float], last_value: float) -> List[float]: + def calc_advantage(self, rewards: list[float], values: list[float], last_value: float) -> list[float]: """Calculate the advantage given rewards, state values, and the last value of episode. 
Args: @@ -281,7 +282,7 @@ def calc_advantage(self, rewards: List[float], values: List[float], last_value: delta = [rews[i] + self.gamma * vals[i + 1] - vals[i] for i in range(len(rews) - 1)] return self.discount_rewards(delta, self.gamma * self.lam) - def generate_trajectory_samples(self) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: + def generate_trajectory_samples(self) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]: """ Contains the logic for generating trajectory data to train policy and value network Yield: @@ -375,7 +376,7 @@ def critic_loss(self, state, action, logp_old, qval, adv) -> torch.Tensor: value = self.critic(state) return (qval - value).pow(2).mean() - def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor]): + def training_step(self, batch: tuple[torch.Tensor, torch.Tensor]): """Carries out a single update to actor and critic network from a batch of replay buffer. Args: @@ -406,7 +407,7 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor]): self.log("loss_critic", loss_critic, on_step=False, on_epoch=True, prog_bar=False, logger=True) self.log("loss_actor", loss_actor, on_step=False, on_epoch=True, prog_bar=True, logger=True) - def configure_optimizers(self) -> List[Optimizer]: + def configure_optimizers(self) -> list[Optimizer]: """Initialize Adam optimizer.""" optimizer_actor = torch.optim.Adam(self.actor.parameters(), lr=self.lr_actor) optimizer_critic = torch.optim.Adam(self.critic.parameters(), lr=self.lr_critic) diff --git a/examples/pytorch/servable_module/production.py b/examples/pytorch/servable_module/production.py index f1d5a06e3c584..da0c42d12a865 100644 --- a/examples/pytorch/servable_module/production.py +++ b/examples/pytorch/servable_module/production.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from io import BytesIO from os import path -from typing import Dict, Optional +from typing import Optional import numpy as np import torch @@ -93,7 +93,7 @@ def configure_payload(self): def configure_serialization(self): return {"x": Image(224, 224).deserialize}, {"output": Top1().serialize} - def serve_step(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: + def serve_step(self, x: torch.Tensor) -> dict[str, torch.Tensor]: return {"output": self.model(x)} def configure_response(self): diff --git a/examples/pytorch/tensor_parallel/model.py b/examples/pytorch/tensor_parallel/model.py index 3c9e7de472b90..71f2634867e9b 100644 --- a/examples/pytorch/tensor_parallel/model.py +++ b/examples/pytorch/tensor_parallel/model.py @@ -9,7 +9,7 @@ from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional import torch import torch.nn.functional as F @@ -87,7 +87,7 @@ def apply_rotary_emb( xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided diff --git a/pyproject.toml b/pyproject.toml index da4cd7f197d5a..48439bee75332 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ ignore-words-list = "te, compiletime" [tool.ruff] line-length = 120 -target-version = "py38" +target-version = "py39" # Exclude a variety of commonly ignored directories. 
exclude = [ ".git", diff --git a/setup.py b/setup.py index bfc329bb8fe88..92f0265eafb9f 100755 --- a/setup.py +++ b/setup.py @@ -45,9 +45,10 @@ import logging import os import tempfile +from collections.abc import Generator, Mapping from importlib.util import module_from_spec, spec_from_file_location from types import ModuleType -from typing import Generator, Mapping, Optional +from typing import Optional import setuptools import setuptools.command.egg_info diff --git a/src/lightning/__setup__.py b/src/lightning/__setup__.py index 09eab5601f443..2d3bb0e7d1f33 100644 --- a/src/lightning/__setup__.py +++ b/src/lightning/__setup__.py @@ -3,7 +3,7 @@ from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path from types import ModuleType -from typing import Any, Dict +from typing import Any from setuptools import find_namespace_packages @@ -26,7 +26,7 @@ def _load_py_module(name: str, location: str) -> ModuleType: _ASSISTANT = _load_py_module(name="assistant", location=os.path.join(_PROJECT_ROOT, ".actions", "assistant.py")) -def _prepare_extras() -> Dict[str, Any]: +def _prepare_extras() -> dict[str, Any]: # https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-extras # Define package extras. These are only installed if you specify them. # From remote, use like `pip install "lightning[dev, docs]"` @@ -63,7 +63,7 @@ def _prepare_extras() -> Dict[str, Any]: return extras -def _setup_args() -> Dict[str, Any]: +def _setup_args() -> dict[str, Any]: about = _load_py_module("about", os.path.join(_PACKAGE_ROOT, "__about__.py")) version = _load_py_module("version", os.path.join(_PACKAGE_ROOT, "__version__.py")) long_description = _ASSISTANT.load_readme_description( diff --git a/src/lightning/fabric/accelerators/cpu.py b/src/lightning/fabric/accelerators/cpu.py index 0334210ecd76d..2997d1ada3352 100644 --- a/src/lightning/fabric/accelerators/cpu.py +++ b/src/lightning/fabric/accelerators/cpu.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union +from typing import Union import torch from typing_extensions import override @@ -45,7 +45,7 @@ def parse_devices(devices: Union[int, str]) -> int: @staticmethod @override - def get_parallel_devices(devices: Union[int, str]) -> List[torch.device]: + def get_parallel_devices(devices: Union[int, str]) -> list[torch.device]: """Gets parallel devices for the Accelerator.""" devices = _parse_cpu_cores(devices) return [torch.device("cpu")] * devices diff --git a/src/lightning/fabric/accelerators/cuda.py b/src/lightning/fabric/accelerators/cuda.py index 4afc9be723fc2..5b8a4c2f80bed 100644 --- a/src/lightning/fabric/accelerators/cuda.py +++ b/src/lightning/fabric/accelerators/cuda.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from functools import lru_cache -from typing import List, Optional, Union +from typing import Optional, Union import torch from typing_extensions import override @@ -43,7 +43,7 @@ def teardown(self) -> None: @staticmethod @override - def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]: + def parse_devices(devices: Union[int, str, list[int]]) -> Optional[list[int]]: """Accelerator device parsing logic.""" from lightning.fabric.utilities.device_parser import _parse_gpu_ids @@ -51,7 +51,7 @@ def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]: @staticmethod @override - def get_parallel_devices(devices: List[int]) -> List[torch.device]: + def get_parallel_devices(devices: list[int]) -> list[torch.device]: """Gets parallel devices for the Accelerator.""" return [torch.device("cuda", i) for i in devices] @@ -76,7 +76,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No ) -def find_usable_cuda_devices(num_devices: int = -1) -> List[int]: +def find_usable_cuda_devices(num_devices: int = -1) -> list[int]: """Returns a list of all available and usable CUDA GPU devices. A GPU is considered usable if we can successfully move a tensor to the device, and this is what this function @@ -129,7 +129,7 @@ def find_usable_cuda_devices(num_devices: int = -1) -> List[int]: return available_devices -def _get_all_visible_cuda_devices() -> List[int]: +def _get_all_visible_cuda_devices() -> list[int]: """Returns a list of all visible CUDA GPU devices. Devices masked by the environment variabale ``CUDA_VISIBLE_DEVICES`` won't be returned here. For example, assume you diff --git a/src/lightning/fabric/accelerators/mps.py b/src/lightning/fabric/accelerators/mps.py index 75497169cda0f..b535ba57ed4cb 100644 --- a/src/lightning/fabric/accelerators/mps.py +++ b/src/lightning/fabric/accelerators/mps.py @@ -14,7 +14,7 @@ import os import platform from functools import lru_cache -from typing import List, Optional, Union +from typing import Optional, Union import torch from typing_extensions import override @@ -46,7 +46,7 @@ def teardown(self) -> None: @staticmethod @override - def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]: + def parse_devices(devices: Union[int, str, list[int]]) -> Optional[list[int]]: """Accelerator device parsing logic.""" from lightning.fabric.utilities.device_parser import _parse_gpu_ids @@ -54,7 +54,7 @@ def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]: @staticmethod @override - def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]: + def get_parallel_devices(devices: Union[int, str, list[int]]) -> list[torch.device]: """Gets parallel devices for the Accelerator.""" parsed_devices = MPSAccelerator.parse_devices(devices) assert parsed_devices is not None @@ -84,7 +84,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No ) -def _get_all_available_mps_gpus() -> List[int]: +def _get_all_available_mps_gpus() -> list[int]: """ Returns: A list of all available MPS GPUs diff --git a/src/lightning/fabric/accelerators/registry.py b/src/lightning/fabric/accelerators/registry.py index 1299b1e148aa8..17d5233336d50 100644 --- a/src/lightning/fabric/accelerators/registry.py +++ b/src/lightning/fabric/accelerators/registry.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Optional from typing_extensions import override @@ -68,7 +68,7 @@ def register( if name in self and not override: raise MisconfigurationException(f"'{name}' is already present in the registry. HINT: Use `override=True`.") - data: Dict[str, Any] = {} + data: dict[str, Any] = {} data["description"] = description data["init_params"] = init_params @@ -107,7 +107,7 @@ def remove(self, name: str) -> None: """Removes the registered accelerator by name.""" self.pop(name) - def available_accelerators(self) -> List[str]: + def available_accelerators(self) -> list[str]: """Returns a list of registered accelerators.""" return list(self.keys()) diff --git a/src/lightning/fabric/accelerators/xla.py b/src/lightning/fabric/accelerators/xla.py index 38d7380dc7905..4a1f25a91062b 100644 --- a/src/lightning/fabric/accelerators/xla.py +++ b/src/lightning/fabric/accelerators/xla.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools -from typing import Any, List, Union +from typing import Any, Union import torch from lightning_utilities.core.imports import RequirementCache @@ -47,13 +47,13 @@ def teardown(self) -> None: @staticmethod @override - def parse_devices(devices: Union[int, str, List[int]]) -> Union[int, List[int]]: + def parse_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]: """Accelerator device parsing logic.""" return _parse_tpu_devices(devices) @staticmethod @override - def get_parallel_devices(devices: Union[int, List[int]]) -> List[torch.device]: + def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]: """Gets parallel devices for the Accelerator.""" devices = _parse_tpu_devices(devices) if isinstance(devices, int): @@ -115,7 +115,7 @@ def _using_pjrt() -> bool: return pjrt.using_pjrt() -def _parse_tpu_devices(devices: Union[int, str, List[int]]) -> Union[int, List[int]]: +def _parse_tpu_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]: """Parses the TPU devices given in the format as accepted by the :class:`~lightning.pytorch.trainer.trainer.Trainer` and :class:`~lightning.fabric.Fabric`. 
@@ -152,7 +152,7 @@ def _check_tpu_devices_valid(devices: object) -> None: ) -def _parse_tpu_devices_str(devices: str) -> Union[int, List[int]]: +def _parse_tpu_devices_str(devices: str) -> Union[int, list[int]]: devices = devices.strip() try: return int(devices) diff --git a/src/lightning/fabric/cli.py b/src/lightning/fabric/cli.py index 7c81afa916196..5f18884e83d79 100644 --- a/src/lightning/fabric/cli.py +++ b/src/lightning/fabric/cli.py @@ -17,7 +17,7 @@ import subprocess import sys from argparse import Namespace -from typing import Any, List, Optional +from typing import Any, Optional import torch from lightning_utilities.core.imports import RequirementCache @@ -39,7 +39,7 @@ _SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu") -def _get_supported_strategies() -> List[str]: +def _get_supported_strategies() -> list[str]: """Returns strategy choices from the registry, with the ones removed that are incompatible to be launched from the CLI or ones that require further configuration by the user.""" available_strategies = STRATEGY_REGISTRY.available_strategies() @@ -221,7 +221,7 @@ def _get_num_processes(accelerator: str, devices: str) -> int: return len(parsed_devices) if parsed_devices is not None else 0 -def _torchrun_launch(args: Namespace, script_args: List[str]) -> None: +def _torchrun_launch(args: Namespace, script_args: list[str]) -> None: """This will invoke `torchrun` programmatically to launch the given script in new processes.""" import torch.distributed.run as torchrun @@ -242,7 +242,7 @@ def _torchrun_launch(args: Namespace, script_args: List[str]) -> None: torchrun.main(torchrun_args) -def main(args: Namespace, script_args: Optional[List[str]] = None) -> None: +def main(args: Namespace, script_args: Optional[list[str]] = None) -> None: _set_env_variables(args) _torchrun_launch(args, script_args or []) diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py index 9fb66255830c6..9161d5f1bd6c2 100644 --- a/src/lightning/fabric/connector.py +++ b/src/lightning/fabric/connector.py @@ -13,7 +13,7 @@ # limitations under the License. 
import os from collections import Counter -from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, Optional, Union, cast import torch from typing_extensions import get_args @@ -99,10 +99,10 @@ def __init__( self, accelerator: Union[str, Accelerator] = "auto", strategy: Union[str, Strategy] = "auto", - devices: Union[List[int], str, int] = "auto", + devices: Union[list[int], str, int] = "auto", num_nodes: int = 1, precision: Optional[_PRECISION_INPUT] = None, - plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None, + plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]] = None, ) -> None: # These arguments can be set through environment variables set by the CLI accelerator = self._argument_from_env("accelerator", accelerator, default="auto") @@ -124,7 +124,7 @@ def __init__( self._precision_input: _PRECISION_INPUT_STR = "32-true" self._precision_instance: Optional[Precision] = None self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None - self._parallel_devices: List[Union[int, torch.device, str]] = [] + self._parallel_devices: list[Union[int, torch.device, str]] = [] self.checkpoint_io: Optional[CheckpointIO] = None self._check_config_and_set_final_flags( @@ -165,7 +165,7 @@ def _check_config_and_set_final_flags( strategy: Union[str, Strategy], accelerator: Union[str, Accelerator], precision: Optional[_PRECISION_INPUT], - plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]], + plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]], ) -> None: """This method checks: @@ -224,7 +224,7 @@ def _check_config_and_set_final_flags( precision_input = _convert_precision_to_unified_args(precision) if plugins: - plugins_flags_types: Dict[str, int] = Counter() + plugins_flags_types: dict[str, int] = Counter() for plugin in plugins: if isinstance(plugin, Precision): self._precision_instance = plugin @@ -295,7 +295,7 @@ def _check_config_and_set_final_flags( self._accelerator_flag = "cuda" self._parallel_devices = self._strategy_flag.parallel_devices - def _check_device_config_and_set_final_flags(self, devices: Union[List[int], str, int], num_nodes: int) -> None: + def _check_device_config_and_set_final_flags(self, devices: Union[list[int], str, int], num_nodes: int) -> None: if not isinstance(num_nodes, int) or num_nodes < 1: raise ValueError(f"`num_nodes` must be a positive integer, but got {num_nodes}.") diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py index 0ff5b04b30b0a..058e5e7c40751 100644 --- a/src/lightning/fabric/fabric.py +++ b/src/lightning/fabric/fabric.py @@ -13,20 +13,14 @@ # limitations under the License. 
import inspect import os -from contextlib import contextmanager, nullcontext +from collections.abc import Generator, Mapping, Sequence +from contextlib import AbstractContextManager, contextmanager, nullcontext from functools import partial from pathlib import Path from typing import ( Any, Callable, - ContextManager, - Dict, - Generator, - List, - Mapping, Optional, - Sequence, - Tuple, Union, cast, overload, @@ -118,12 +112,12 @@ def __init__( *, accelerator: Union[str, Accelerator] = "auto", strategy: Union[str, Strategy] = "auto", - devices: Union[List[int], str, int] = "auto", + devices: Union[list[int], str, int] = "auto", num_nodes: int = 1, precision: Optional[_PRECISION_INPUT] = None, - plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None, - callbacks: Optional[Union[List[Any], Any]] = None, - loggers: Optional[Union[Logger, List[Logger]]] = None, + plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]] = None, + callbacks: Optional[Union[list[Any], Any]] = None, + loggers: Optional[Union[Logger, list[Logger]]] = None, ) -> None: self._connector = _Connector( accelerator=accelerator, @@ -192,7 +186,7 @@ def is_global_zero(self) -> bool: return self._strategy.is_global_zero @property - def loggers(self) -> List[Logger]: + def loggers(self) -> list[Logger]: """Returns all loggers passed to Fabric.""" return self._loggers @@ -326,7 +320,7 @@ def setup_module( self._models_setup += 1 return module - def setup_optimizers(self, *optimizers: Optimizer) -> Union[_FabricOptimizer, Tuple[_FabricOptimizer, ...]]: + def setup_optimizers(self, *optimizers: Optimizer) -> Union[_FabricOptimizer, tuple[_FabricOptimizer, ...]]: r"""Set up one or more optimizers for accelerated training. Some strategies do not allow setting up model and optimizer independently. For them, you should call @@ -349,7 +343,7 @@ def setup_optimizers(self, *optimizers: Optimizer) -> Union[_FabricOptimizer, Tu def setup_dataloaders( self, *dataloaders: DataLoader, use_distributed_sampler: bool = True, move_to_device: bool = True - ) -> Union[DataLoader, List[DataLoader]]: + ) -> Union[DataLoader, list[DataLoader]]: r"""Set up one or multiple dataloaders for accelerated training. If you need different settings for each dataloader, call this method individually for each one. @@ -489,7 +483,7 @@ def clip_gradients( ) raise ValueError("You have to specify either `clip_val` or `max_norm` to do gradient clipping!") - def autocast(self) -> ContextManager: + def autocast(self) -> AbstractContextManager: """A context manager to automatically convert operations for the chosen precision. Use this only if the `forward` method of your model does not cover all operations you wish to run with the @@ -564,8 +558,8 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: return self._strategy.broadcast(obj, src=src) def all_gather( - self, data: Union[Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False - ) -> Union[Tensor, Dict, List, Tuple]: + self, data: Union[Tensor, dict, list, tuple], group: Optional[Any] = None, sync_grads: bool = False + ) -> Union[Tensor, dict, list, tuple]: """Gather tensors or collections of tensors from multiple processes. 
This method needs to be called on all processes and the tensors need to have the same shape across all @@ -589,10 +583,10 @@ def all_gather( def all_reduce( self, - data: Union[Tensor, Dict, List, Tuple], + data: Union[Tensor, dict, list, tuple], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean", - ) -> Union[Tensor, Dict, List, Tuple]: + ) -> Union[Tensor, dict, list, tuple]: """Reduce tensors or collections of tensors from multiple processes. The reduction on tensors is applied in-place, meaning the result will be placed back into the input tensor. @@ -639,7 +633,7 @@ def rank_zero_first(self, local: bool = False) -> Generator: if rank == 0: barrier() - def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> ContextManager: + def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> AbstractContextManager: r"""Skip gradient synchronization during backward to avoid redundant communication overhead. Use this context manager when performing gradient accumulation to speed up training with multiple devices. @@ -681,7 +675,7 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Conte forward_module, _ = _unwrap_compiled(module._forward_module) return self._strategy._backward_sync_control.no_backward_sync(forward_module, enabled) - def sharded_model(self) -> ContextManager: + def sharded_model(self) -> AbstractContextManager: r"""Instantiate a model under this context manager to prepare it for model-parallel sharding. .. deprecated:: This context manager is deprecated in favor of :meth:`init_module`, use it instead. @@ -693,12 +687,12 @@ def sharded_model(self) -> ContextManager: return self.strategy.module_sharded_context() return nullcontext() - def init_tensor(self) -> ContextManager: + def init_tensor(self) -> AbstractContextManager: """Tensors that you instantiate under this context manager will be created on the device right away and have the right data type depending on the precision setting in Fabric.""" return self._strategy.tensor_init_context() - def init_module(self, empty_init: Optional[bool] = None) -> ContextManager: + def init_module(self, empty_init: Optional[bool] = None) -> AbstractContextManager: """Instantiate the model and its parameters under this context manager to reduce peak memory usage. The parameters get created on the device and with the right data type right away without wasting memory being @@ -716,8 +710,8 @@ def init_module(self, empty_init: Optional[bool] = None) -> ContextManager: def save( self, path: Union[str, Path], - state: Dict[str, Union[nn.Module, Optimizer, Any]], - filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None, + state: dict[str, Union[nn.Module, Optimizer, Any]], + filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, ) -> None: r"""Save checkpoint contents to a file. @@ -750,9 +744,9 @@ def save( def load( self, path: Union[str, Path], - state: Optional[Dict[str, Union[nn.Module, Optimizer, Any]]] = None, + state: Optional[dict[str, Union[nn.Module, Optimizer, Any]]] = None, strict: bool = True, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Load a checkpoint from a file and restore the state of objects (modules, optimizers, etc.) How and which processes load gets determined by the `strategy`. 
@@ -933,7 +927,7 @@ def _wrap_with_setup(self, to_run: Callable, *args: Any, **kwargs: Any) -> Any: with _replace_dunder_methods(DataLoader, "dataset"), _replace_dunder_methods(BatchSampler): return to_run(*args, **kwargs) - def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module: + def _move_model_to_device(self, model: nn.Module, optimizers: list[Optimizer]) -> nn.Module: try: initial_name, initial_param = next(model.named_parameters()) except StopIteration: @@ -1061,7 +1055,7 @@ def _validate_setup_dataloaders(self, dataloaders: Sequence[DataLoader]) -> None raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.") @staticmethod - def _configure_callbacks(callbacks: Optional[Union[List[Any], Any]]) -> List[Any]: + def _configure_callbacks(callbacks: Optional[Union[list[Any], Any]]) -> list[Any]: callbacks = callbacks if callbacks is not None else [] callbacks = callbacks if isinstance(callbacks, list) else [callbacks] callbacks.extend(_load_external_callbacks("lightning.fabric.callbacks_factory")) diff --git a/src/lightning/fabric/loggers/csv_logs.py b/src/lightning/fabric/loggers/csv_logs.py index 4dbb56fa691db..dd7dfc63671f0 100644 --- a/src/lightning/fabric/loggers/csv_logs.py +++ b/src/lightning/fabric/loggers/csv_logs.py @@ -16,7 +16,7 @@ import logging import os from argparse import Namespace -from typing import Any, Dict, List, Optional, Set, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import override @@ -138,13 +138,13 @@ def experiment(self) -> "_ExperimentWriter": @override @rank_zero_only - def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: raise NotImplementedError("The `CSVLogger` does not yet support logging hyperparameters.") @override @rank_zero_only def log_metrics( # type: ignore[override] - self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None + self, metrics: dict[str, Union[Tensor, float]], step: Optional[int] = None ) -> None: metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) if step is None: @@ -200,8 +200,8 @@ class _ExperimentWriter: NAME_METRICS_FILE = "metrics.csv" def __init__(self, log_dir: str) -> None: - self.metrics: List[Dict[str, float]] = [] - self.metrics_keys: List[str] = [] + self.metrics: list[dict[str, float]] = [] + self.metrics_keys: list[str] = [] self._fs = get_filesystem(log_dir) self.log_dir = log_dir @@ -210,7 +210,7 @@ def __init__(self, log_dir: str) -> None: self._check_log_dir_exists() self._fs.makedirs(self.log_dir, exist_ok=True) - def log_metrics(self, metrics_dict: Dict[str, float], step: Optional[int] = None) -> None: + def log_metrics(self, metrics_dict: dict[str, float], step: Optional[int] = None) -> None: """Record metrics.""" def _handle_value(value: Union[Tensor, Any]) -> Any: @@ -246,7 +246,7 @@ def save(self) -> None: self.metrics = [] # reset - def _record_new_keys(self) -> Set[str]: + def _record_new_keys(self) -> set[str]: """Records new keys that have not been logged before.""" current_keys = set().union(*self.metrics) new_keys = current_keys - set(self.metrics_keys) @@ -254,7 +254,7 @@ def _record_new_keys(self) -> Set[str]: self.metrics_keys.sort() return new_keys - def _rewrite_with_new_header(self, fieldnames: List[str]) -> None: + def _rewrite_with_new_header(self, fieldnames: list[str]) -> None: with self._fs.open(self.metrics_file_path, "r", newline="") 
as file: metrics = list(csv.DictReader(file)) diff --git a/src/lightning/fabric/loggers/logger.py b/src/lightning/fabric/loggers/logger.py index 5647ab9c1c7a2..39a9fa06a08d0 100644 --- a/src/lightning/fabric/loggers/logger.py +++ b/src/lightning/fabric/loggers/logger.py @@ -16,7 +16,7 @@ from abc import ABC, abstractmethod from argparse import Namespace from functools import wraps -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Optional, Union from torch import Tensor from torch.nn import Module @@ -55,7 +55,7 @@ def group_separator(self) -> str: return "/" @abstractmethod - def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: + def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None) -> None: """Records metrics. This method logs metrics as soon as it received them. Args: @@ -66,7 +66,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> pass @abstractmethod - def log_hyperparams(self, params: Union[Dict[str, Any], Namespace], *args: Any, **kwargs: Any) -> None: + def log_hyperparams(self, params: Union[dict[str, Any], Namespace], *args: Any, **kwargs: Any) -> None: """Record hyperparameters. Args: diff --git a/src/lightning/fabric/loggers/tensorboard.py b/src/lightning/fabric/loggers/tensorboard.py index 685c832088818..14bc3d6b76220 100644 --- a/src/lightning/fabric/loggers/tensorboard.py +++ b/src/lightning/fabric/loggers/tensorboard.py @@ -14,7 +14,8 @@ import os from argparse import Namespace -from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, Optional, Union from lightning_utilities.core.imports import RequirementCache from torch import Tensor @@ -219,7 +220,7 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) @override @rank_zero_only def log_hyperparams( - self, params: Union[Dict[str, Any], Namespace], metrics: Optional[Dict[str, Any]] = None + self, params: Union[dict[str, Any], Namespace], metrics: Optional[dict[str, Any]] = None ) -> None: """Record hyperparameters. TensorBoard logs with and without saved hyperparameters are incompatible, the hyperparameters are then not displayed in the TensorBoard. Please delete or move the previously saved logs to @@ -318,12 +319,12 @@ def _get_next_version(self) -> int: return max(existing_versions) + 1 @staticmethod - def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]: + def _sanitize_params(params: dict[str, Any]) -> dict[str, Any]: params = _utils_sanitize_params(params) # logging of arrays with dimension > 1 is not supported, sanitize as string return {k: str(v) if hasattr(v, "ndim") and v.ndim > 1 else v for k, v in params.items()} - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: state = self.__dict__.copy() state["_experiment"] = None return state diff --git a/src/lightning/fabric/plugins/collectives/collective.py b/src/lightning/fabric/plugins/collectives/collective.py index 3b336b590aec8..9408fd87da400 100644 --- a/src/lightning/fabric/plugins/collectives/collective.py +++ b/src/lightning/fabric/plugins/collectives/collective.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, List, Optional +from typing import Any, Optional from torch import Tensor from typing_extensions import Self @@ -47,19 +47,19 @@ def all_reduce(self, tensor: Tensor, op: str) -> Tensor: ... 
def reduce(self, tensor: Tensor, dst: int, op: str) -> Tensor: ... @abstractmethod - def all_gather(self, tensor_list: List[Tensor], tensor: Tensor) -> List[Tensor]: ... + def all_gather(self, tensor_list: list[Tensor], tensor: Tensor) -> list[Tensor]: ... @abstractmethod - def gather(self, tensor: Tensor, gather_list: List[Tensor], dst: int = 0) -> List[Tensor]: ... + def gather(self, tensor: Tensor, gather_list: list[Tensor], dst: int = 0) -> list[Tensor]: ... @abstractmethod - def scatter(self, tensor: Tensor, scatter_list: List[Tensor], src: int = 0) -> Tensor: ... + def scatter(self, tensor: Tensor, scatter_list: list[Tensor], src: int = 0) -> Tensor: ... @abstractmethod - def reduce_scatter(self, output: Tensor, input_list: List[Tensor], op: str) -> Tensor: ... + def reduce_scatter(self, output: Tensor, input_list: list[Tensor], op: str) -> Tensor: ... @abstractmethod - def all_to_all(self, output_tensor_list: List[Tensor], input_tensor_list: List[Tensor]) -> List[Tensor]: ... + def all_to_all(self, output_tensor_list: list[Tensor], input_tensor_list: list[Tensor]) -> list[Tensor]: ... @abstractmethod def send(self, tensor: Tensor, dst: int, tag: int = 0) -> None: ... @@ -68,7 +68,7 @@ def send(self, tensor: Tensor, dst: int, tag: int = 0) -> None: ... def recv(self, tensor: Tensor, src: Optional[int] = None, tag: int = 0) -> Tensor: ... @abstractmethod - def barrier(self, device_ids: Optional[List[int]] = None) -> None: ... + def barrier(self, device_ids: Optional[list[int]] = None) -> None: ... @classmethod @abstractmethod diff --git a/src/lightning/fabric/plugins/collectives/single_device.py b/src/lightning/fabric/plugins/collectives/single_device.py index 9b635f6cdc7c1..73378715f2454 100644 --- a/src/lightning/fabric/plugins/collectives/single_device.py +++ b/src/lightning/fabric/plugins/collectives/single_device.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import Any from torch import Tensor from typing_extensions import override @@ -37,31 +37,31 @@ def reduce(self, tensor: Tensor, *_: Any, **__: Any) -> Tensor: return tensor @override - def all_gather(self, tensor_list: List[Tensor], tensor: Tensor, **__: Any) -> List[Tensor]: + def all_gather(self, tensor_list: list[Tensor], tensor: Tensor, **__: Any) -> list[Tensor]: return [tensor] @override - def gather(self, tensor: Tensor, *_: Any, **__: Any) -> List[Tensor]: + def gather(self, tensor: Tensor, *_: Any, **__: Any) -> list[Tensor]: return [tensor] @override def scatter( self, tensor: Tensor, - scatter_list: List[Tensor], + scatter_list: list[Tensor], *_: Any, **__: Any, ) -> Tensor: return scatter_list[0] @override - def reduce_scatter(self, output: Tensor, input_list: List[Tensor], *_: Any, **__: Any) -> Tensor: + def reduce_scatter(self, output: Tensor, input_list: list[Tensor], *_: Any, **__: Any) -> Tensor: return input_list[0] @override def all_to_all( - self, output_tensor_list: List[Tensor], input_tensor_list: List[Tensor], *_: Any, **__: Any - ) -> List[Tensor]: + self, output_tensor_list: list[Tensor], input_tensor_list: list[Tensor], *_: Any, **__: Any + ) -> list[Tensor]: return input_tensor_list @override diff --git a/src/lightning/fabric/plugins/collectives/torch_collective.py b/src/lightning/fabric/plugins/collectives/torch_collective.py index 0dea3033f3dff..81e15a33cb983 100644 --- a/src/lightning/fabric/plugins/collectives/torch_collective.py +++ b/src/lightning/fabric/plugins/collectives/torch_collective.py @@ -1,6 +1,6 @@ import datetime import os -from typing import Any, List, Optional, 
Union +from typing import Any, Optional, Union import torch import torch.distributed as dist @@ -66,30 +66,30 @@ def reduce(self, tensor: Tensor, dst: int, op: Union[str, ReduceOp, RedOpType] = return tensor @override - def all_gather(self, tensor_list: List[Tensor], tensor: Tensor) -> List[Tensor]: + def all_gather(self, tensor_list: list[Tensor], tensor: Tensor) -> list[Tensor]: dist.all_gather(tensor_list, tensor, group=self.group) return tensor_list @override - def gather(self, tensor: Tensor, gather_list: List[Tensor], dst: int = 0) -> List[Tensor]: + def gather(self, tensor: Tensor, gather_list: list[Tensor], dst: int = 0) -> list[Tensor]: dist.gather(tensor, gather_list, dst, group=self.group) return gather_list @override - def scatter(self, tensor: Tensor, scatter_list: List[Tensor], src: int = 0) -> Tensor: + def scatter(self, tensor: Tensor, scatter_list: list[Tensor], src: int = 0) -> Tensor: dist.scatter(tensor, scatter_list, src, group=self.group) return tensor @override def reduce_scatter( - self, output: Tensor, input_list: List[Tensor], op: Union[str, ReduceOp, RedOpType] = "sum" + self, output: Tensor, input_list: list[Tensor], op: Union[str, ReduceOp, RedOpType] = "sum" ) -> Tensor: op = self._convert_to_native_op(op) dist.reduce_scatter(output, input_list, op=op, group=self.group) return output @override - def all_to_all(self, output_tensor_list: List[Tensor], input_tensor_list: List[Tensor]) -> List[Tensor]: + def all_to_all(self, output_tensor_list: list[Tensor], input_tensor_list: list[Tensor]) -> list[Tensor]: dist.all_to_all(output_tensor_list, input_tensor_list, group=self.group) return output_tensor_list @@ -102,28 +102,28 @@ def recv(self, tensor: Tensor, src: Optional[int] = None, tag: int = 0) -> Tenso dist.recv(tensor, src, tag=tag, group=self.group) # type: ignore[arg-type] return tensor - def all_gather_object(self, object_list: List[Any], obj: Any) -> List[Any]: + def all_gather_object(self, object_list: list[Any], obj: Any) -> list[Any]: dist.all_gather_object(object_list, obj, group=self.group) return object_list def broadcast_object_list( - self, object_list: List[Any], src: int, device: Optional[torch.device] = None - ) -> List[Any]: + self, object_list: list[Any], src: int, device: Optional[torch.device] = None + ) -> list[Any]: dist.broadcast_object_list(object_list, src, group=self.group, device=device) return object_list - def gather_object(self, obj: Any, object_gather_list: List[Any], dst: int = 0) -> List[Any]: + def gather_object(self, obj: Any, object_gather_list: list[Any], dst: int = 0) -> list[Any]: dist.gather_object(obj, object_gather_list, dst, group=self.group) return object_gather_list def scatter_object_list( - self, scatter_object_output_list: List[Any], scatter_object_input_list: List[Any], src: int = 0 - ) -> List[Any]: + self, scatter_object_output_list: list[Any], scatter_object_input_list: list[Any], src: int = 0 + ) -> list[Any]: dist.scatter_object_list(scatter_object_output_list, scatter_object_input_list, src, group=self.group) return scatter_object_output_list @override - def barrier(self, device_ids: Optional[List[int]] = None) -> None: + def barrier(self, device_ids: Optional[list[int]] = None) -> None: if self.group == dist.GroupMember.NON_GROUP_MEMBER: return dist.barrier(group=self.group, device_ids=device_ids) diff --git a/src/lightning/fabric/plugins/environments/lsf.py b/src/lightning/fabric/plugins/environments/lsf.py index 6a23006d3c1d9..f0a07d61d9f03 100644 --- a/src/lightning/fabric/plugins/environments/lsf.py 
+++ b/src/lightning/fabric/plugins/environments/lsf.py @@ -14,7 +14,6 @@ import logging import os import socket -from typing import Dict, List from typing_extensions import override @@ -144,14 +143,14 @@ def _get_node_rank(self) -> int: """ hosts = self._read_hosts() - count: Dict[str, int] = {} + count: dict[str, int] = {} for host in hosts: if host not in count: count[host] = len(count) return count[socket.gethostname()] @staticmethod - def _read_hosts() -> List[str]: + def _read_hosts() -> list[str]: """Read compute hosts that are a part of the compute job. LSF uses the Job Step Manager (JSM) to manage job steps. Job steps are executed by the JSM from "launch" nodes. diff --git a/src/lightning/fabric/plugins/io/checkpoint_io.py b/src/lightning/fabric/plugins/io/checkpoint_io.py index 79fc9e88b737e..3a33dac3335d1 100644 --- a/src/lightning/fabric/plugins/io/checkpoint_io.py +++ b/src/lightning/fabric/plugins/io/checkpoint_io.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, Dict, Optional +from typing import Any, Optional from lightning.fabric.utilities.types import _PATH @@ -36,7 +36,7 @@ class CheckpointIO(ABC): """ @abstractmethod - def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: + def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. Args: @@ -47,7 +47,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio """ @abstractmethod - def load_checkpoint(self, path: _PATH, map_location: Optional[Any] = None) -> Dict[str, Any]: + def load_checkpoint(self, path: _PATH, map_location: Optional[Any] = None) -> dict[str, Any]: """Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages. Args: diff --git a/src/lightning/fabric/plugins/io/torch_io.py b/src/lightning/fabric/plugins/io/torch_io.py index 02de1aa274a32..90a5f62ba7413 100644 --- a/src/lightning/fabric/plugins/io/torch_io.py +++ b/src/lightning/fabric/plugins/io/torch_io.py @@ -13,7 +13,7 @@ # limitations under the License. import logging import os -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Optional from typing_extensions import override @@ -34,7 +34,7 @@ class TorchCheckpointIO(CheckpointIO): """ @override - def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: + def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. Args: @@ -60,7 +60,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio @override def load_checkpoint( self, path: _PATH, map_location: Optional[Callable] = lambda storage, loc: storage - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Loads checkpoint using :func:`torch.load`, with additional handling for ``fsspec`` remote loading of files. Args: diff --git a/src/lightning/fabric/plugins/io/xla.py b/src/lightning/fabric/plugins/io/xla.py index 5c154d81a9915..146fa2f33b510 100644 --- a/src/lightning/fabric/plugins/io/xla.py +++ b/src/lightning/fabric/plugins/io/xla.py @@ -13,7 +13,7 @@ # limitations under the License. 
import logging import os -from typing import Any, Dict, Optional +from typing import Any, Optional import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -41,7 +41,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @override - def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: + def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. Args: diff --git a/src/lightning/fabric/plugins/precision/amp.py b/src/lightning/fabric/plugins/precision/amp.py index c624e821af28c..d5fc1f0c1cc2a 100644 --- a/src/lightning/fabric/plugins/precision/amp.py +++ b/src/lightning/fabric/plugins/precision/amp.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, ContextManager, Dict, Literal, Optional +from contextlib import AbstractContextManager +from typing import Any, Literal, Optional import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -59,7 +60,7 @@ def __init__( self._desired_input_dtype = torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16 @override - def forward_context(self) -> ContextManager: + def forward_context(self) -> AbstractContextManager: return torch.autocast(self.device, dtype=self._desired_input_dtype) @override @@ -93,13 +94,13 @@ def optimizer_step( return step_output @override - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: if self.scaler is not None: return self.scaler.state_dict() return {} @override - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: if self.scaler is not None: self.scaler.load_state_dict(state_dict) diff --git a/src/lightning/fabric/plugins/precision/bitsandbytes.py b/src/lightning/fabric/plugins/precision/bitsandbytes.py index 394415452890a..ecb1d8a442655 100644 --- a/src/lightning/fabric/plugins/precision/bitsandbytes.py +++ b/src/lightning/fabric/plugins/precision/bitsandbytes.py @@ -16,10 +16,11 @@ import math import os import warnings -from contextlib import ExitStack +from collections import OrderedDict +from contextlib import AbstractContextManager, ExitStack from functools import partial from types import ModuleType -from typing import Any, Callable, ContextManager, Literal, Optional, OrderedDict, Set, Tuple, Type, cast +from typing import Any, Callable, Literal, Optional, cast import torch from lightning_utilities import apply_to_collection @@ -70,7 +71,7 @@ def __init__( self, mode: Literal["nf4", "nf4-dq", "fp4", "fp4-dq", "int8", "int8-training"], dtype: Optional[torch.dtype] = None, - ignore_modules: Optional[Set[str]] = None, + ignore_modules: Optional[set[str]] = None, ) -> None: _import_bitsandbytes() @@ -122,11 +123,11 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module: return module @override - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: return _DtypeContextManager(self.dtype) @override - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: if self.ignore_modules: # cannot patch the Linear class if the user wants 
to skip some submodules raise RuntimeError( @@ -144,7 +145,7 @@ def module_init_context(self) -> ContextManager: return stack @override - def forward_context(self) -> ContextManager: + def forward_context(self) -> AbstractContextManager: return _DtypeContextManager(self.dtype) @override @@ -175,7 +176,7 @@ def _ignore_missing_weights_hook(module: torch.nn.Module, incompatible_keys: _In def _replace_param( - param: torch.nn.Parameter, data: torch.Tensor, quant_state: Optional[Tuple] = None + param: torch.nn.Parameter, data: torch.Tensor, quant_state: Optional[tuple] = None ) -> torch.nn.Parameter: bnb = _import_bitsandbytes() @@ -418,7 +419,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: return bnb -def _convert_layers(module: torch.nn.Module, linear_cls: Type, ignore_modules: Set[str], prefix: str = "") -> None: +def _convert_layers(module: torch.nn.Module, linear_cls: type, ignore_modules: set[str], prefix: str = "") -> None: for name, child in module.named_children(): fullname = f"{prefix}.{name}" if prefix else name if isinstance(child, torch.nn.Linear) and not any(fullname.startswith(s) for s in ignore_modules): diff --git a/src/lightning/fabric/plugins/precision/deepspeed.py b/src/lightning/fabric/plugins/precision/deepspeed.py index 2fcaa38258e3a..526095008f376 100644 --- a/src/lightning/fabric/plugins/precision/deepspeed.py +++ b/src/lightning/fabric/plugins/precision/deepspeed.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import nullcontext -from typing import TYPE_CHECKING, Any, ContextManager, Literal +from contextlib import AbstractContextManager, nullcontext +from typing import TYPE_CHECKING, Any, Literal import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -68,13 +68,13 @@ def convert_module(self, module: Module) -> Module: return module @override - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: if "true" not in self.precision: return nullcontext() return _DtypeContextManager(self._desired_dtype) @override - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: return self.tensor_init_context() @override diff --git a/src/lightning/fabric/plugins/precision/double.py b/src/lightning/fabric/plugins/precision/double.py index 0a857499f3d34..9aa0365a55e70 100644 --- a/src/lightning/fabric/plugins/precision/double.py +++ b/src/lightning/fabric/plugins/precision/double.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
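Editor's note: the recurring `ContextManager` → `AbstractContextManager` change in the precision plugins is purely an annotation swap. A minimal sketch, assuming nothing beyond the standard library and `torch.autocast`; the function name below is invented.

```python
# Sketch of the annotation swap only; `make_forward_context` is an invented name.
from contextlib import AbstractContextManager, nullcontext

import torch


def make_forward_context(device_type: str = "cpu", mixed: bool = True) -> AbstractContextManager:
    # Both `torch.autocast(...)` and `nullcontext()` already satisfy the contextlib ABC,
    # so `typing.ContextManager` can be replaced 1:1 in the return annotation.
    if mixed:
        return torch.autocast(device_type, dtype=torch.bfloat16)
    return nullcontext()


with make_forward_context("cpu", mixed=True):
    out = torch.randn(4, 4) @ torch.randn(4, 4)  # runs under bf16 autocast on CPU
```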
-from typing import Any, ContextManager, Literal
+from contextlib import AbstractContextManager
+from typing import Any, Literal

 import torch
 from lightning_utilities.core.apply_func import apply_to_collection
@@ -33,15 +34,15 @@ def convert_module(self, module: Module) -> Module:
         return module.double()

     @override
-    def tensor_init_context(self) -> ContextManager:
+    def tensor_init_context(self) -> AbstractContextManager:
         return _DtypeContextManager(torch.double)

     @override
-    def module_init_context(self) -> ContextManager:
+    def module_init_context(self) -> AbstractContextManager:
         return self.tensor_init_context()

     @override
-    def forward_context(self) -> ContextManager:
+    def forward_context(self) -> AbstractContextManager:
         return self.tensor_init_context()

     @override
diff --git a/src/lightning/fabric/plugins/precision/fsdp.py b/src/lightning/fabric/plugins/precision/fsdp.py
index 179fc21cdd90d..0b78ad72a441f 100644
--- a/src/lightning/fabric/plugins/precision/fsdp.py
+++ b/src/lightning/fabric/plugins/precision/fsdp.py
@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Any, ContextManager, Dict, Literal, Optional
+from contextlib import AbstractContextManager
+from typing import TYPE_CHECKING, Any, Literal, Optional

 import torch
 from lightning_utilities import apply_to_collection
@@ -100,15 +101,15 @@ def mixed_precision_config(self) -> "TorchMixedPrecision":
         )

     @override
-    def tensor_init_context(self) -> ContextManager:
+    def tensor_init_context(self) -> AbstractContextManager:
         return _DtypeContextManager(self._desired_input_dtype)

     @override
-    def module_init_context(self) -> ContextManager:
+    def module_init_context(self) -> AbstractContextManager:
         return _DtypeContextManager(self.mixed_precision_config.param_dtype or torch.float32)

     @override
-    def forward_context(self) -> ContextManager:
+    def forward_context(self) -> AbstractContextManager:
         if "mixed" in self.precision:
             return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
         return self.tensor_init_context()
@@ -150,12 +151,12 @@ def unscale_gradients(self, optimizer: Optimizer) -> None:
         scaler.unscale_(optimizer)

     @override
-    def state_dict(self) -> Dict[str, Any]:
+    def state_dict(self) -> dict[str, Any]:
         if self.scaler is not None:
             return self.scaler.state_dict()
         return {}

     @override
-    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         if self.scaler is not None:
             self.scaler.load_state_dict(state_dict)
diff --git a/src/lightning/fabric/plugins/precision/half.py b/src/lightning/fabric/plugins/precision/half.py
index 32ca7da815213..fcb28ad33274c 100644
--- a/src/lightning/fabric/plugins/precision/half.py
+++ b/src/lightning/fabric/plugins/precision/half.py
@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, ContextManager, Literal +from contextlib import AbstractContextManager +from typing import Any, Literal import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -42,15 +43,15 @@ def convert_module(self, module: Module) -> Module: return module.to(dtype=self._desired_input_dtype) @override - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: return _DtypeContextManager(self._desired_input_dtype) @override - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: return self.tensor_init_context() @override - def forward_context(self) -> ContextManager: + def forward_context(self) -> AbstractContextManager: return self.tensor_init_context() @override diff --git a/src/lightning/fabric/plugins/precision/precision.py b/src/lightning/fabric/plugins/precision/precision.py index fbff54f8e3595..1dfab2a7bc649 100644 --- a/src/lightning/fabric/plugins/precision/precision.py +++ b/src/lightning/fabric/plugins/precision/precision.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import nullcontext -from typing import Any, ContextManager, Dict, Literal, Optional, Union +from contextlib import AbstractContextManager, nullcontext +from typing import Any, Literal, Optional, Union from torch import Tensor from torch.nn import Module @@ -53,11 +53,11 @@ def convert_module(self, module: Module) -> Module: """ return module - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: """Controls how tensors get created (device, dtype).""" return nullcontext() - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: """Instantiate module parameters or tensors in the precision type this plugin handles. This is optional and depends on the precision limitations during optimization. @@ -65,7 +65,7 @@ def module_init_context(self) -> ContextManager: """ return nullcontext() - def forward_context(self) -> ContextManager: + def forward_context(self) -> AbstractContextManager: """A contextmanager for managing model forward/training_step/evaluation_step/predict_step.""" return nullcontext() @@ -135,7 +135,7 @@ def main_params(self, optimizer: Optimizer) -> _PARAMETERS: def unscale_gradients(self, optimizer: Optimizer) -> None: return - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: """Called when saving a checkpoint, implement to generate precision plugin state_dict. Returns: @@ -144,7 +144,7 @@ def state_dict(self) -> Dict[str, Any]: """ return {} - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: """Called when loading a checkpoint, implement to reload precision plugin state given precision plugin state_dict. diff --git a/src/lightning/fabric/plugins/precision/transformer_engine.py b/src/lightning/fabric/plugins/precision/transformer_engine.py index cb5296b21fc39..c3ef84a453e73 100644 --- a/src/lightning/fabric/plugins/precision/transformer_engine.py +++ b/src/lightning/fabric/plugins/precision/transformer_engine.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
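Editor's note: several hunks above return `_DtypeContextManager(...)`, whose body is not part of this diff. The sketch below is a guess at the usual shape of such a helper (temporarily swapping torch's global default dtype); the class name and details are assumptions, not Lightning's implementation.

```python
# Assumed shape of a dtype context helper; NOT Lightning's `_DtypeContextManager`.
from contextlib import AbstractContextManager
from typing import Any

import torch


class DefaultDtypeContext(AbstractContextManager):
    """Temporarily switch torch's global default dtype, restoring it on exit."""

    def __init__(self, dtype: torch.dtype) -> None:
        self._new_dtype = dtype

    def __enter__(self) -> None:
        self._previous = torch.get_default_dtype()
        torch.set_default_dtype(self._new_dtype)

    def __exit__(self, *exc: Any) -> None:
        torch.set_default_dtype(self._previous)


with DefaultDtypeContext(torch.float64):
    weights = torch.zeros(4)  # created as float64 inside the block
assert weights.dtype == torch.float64
```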
import logging -from contextlib import ExitStack -from typing import TYPE_CHECKING, Any, ContextManager, Literal, Mapping, Optional, Union +from collections.abc import Mapping +from contextlib import AbstractContextManager, ExitStack +from typing import TYPE_CHECKING, Any, Literal, Optional, Union import torch from lightning_utilities import apply_to_collection @@ -106,11 +107,11 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module: return module @override - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: return _DtypeContextManager(self.weights_dtype) @override - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: dtype_ctx = self.tensor_init_context() stack = ExitStack() if self.replace_layers: @@ -125,7 +126,7 @@ def module_init_context(self) -> ContextManager: return stack @override - def forward_context(self) -> ContextManager: + def forward_context(self) -> AbstractContextManager: dtype_ctx = _DtypeContextManager(self.weights_dtype) fallback_autocast_ctx = torch.autocast(device_type="cuda", dtype=self.fallback_compute_dtype) import transformer_engine.pytorch as te diff --git a/src/lightning/fabric/plugins/precision/utils.py b/src/lightning/fabric/plugins/precision/utils.py index 887dbc937a1f6..8362384cb1042 100644 --- a/src/lightning/fabric/plugins/precision/utils.py +++ b/src/lightning/fabric/plugins/precision/utils.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Mapping, Type, Union +from collections.abc import Mapping +from typing import Any, Union import torch from torch import Tensor @@ -43,7 +44,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: class _ClassReplacementContextManager: """A context manager to monkeypatch classes.""" - def __init__(self, mapping: Mapping[str, Type]) -> None: + def __init__(self, mapping: Mapping[str, type]) -> None: self._mapping = mapping self._originals = {} self._modules = {} diff --git a/src/lightning/fabric/strategies/ddp.py b/src/lightning/fabric/strategies/ddp.py index c38780655ce6e..ce47e4e403c34 100644 --- a/src/lightning/fabric/strategies/ddp.py +++ b/src/lightning/fabric/strategies/ddp.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
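Editor's note: `module_init_context` in the bitsandbytes and transformer-engine plugins (and in the strategies further down) builds its return value with `contextlib.ExitStack`. A minimal sketch of that composition pattern, with invented names and only standard `torch` context managers; it assumes torch >= 2.0, where `torch.device` can be used as a context manager.

```python
# Invented names; shows only how ExitStack bundles several contexts into one return value.
from contextlib import AbstractContextManager, ExitStack

import torch


def module_init_context(empty_init: bool = False) -> AbstractContextManager:
    stack = ExitStack()
    if empty_init:
        # Skip real weight allocation; the strategies use their own empty-init helper here.
        stack.enter_context(torch.device("meta"))  # torch >= 2.0 device context manager
    # A precision plugin would push its dtype/init context onto the same stack.
    return stack  # ExitStack is itself an AbstractContextManager


with module_init_context(empty_init=True):
    layer = torch.nn.Linear(8, 8)
print(layer.weight.is_meta)  # True
```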
-from contextlib import nullcontext +from contextlib import AbstractContextManager, nullcontext from datetime import timedelta -from typing import Any, ContextManager, Dict, List, Literal, Optional, Union +from typing import Any, Literal, Optional, Union import torch import torch.distributed @@ -55,7 +55,7 @@ class DDPStrategy(ParallelStrategy): def __init__( self, accelerator: Optional[Accelerator] = None, - parallel_devices: Optional[List[torch.device]] = None, + parallel_devices: Optional[list[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, precision: Optional[Precision] = None, @@ -99,7 +99,7 @@ def num_processes(self) -> int: @property @override - def distributed_sampler_kwargs(self) -> Dict[str, Any]: + def distributed_sampler_kwargs(self) -> dict[str, Any]: return {"num_replicas": (self.num_nodes * self.num_processes), "rank": self.global_rank} @property @@ -171,14 +171,14 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: return obj[0] @override - def get_module_state_dict(self, module: Module) -> Dict[str, Union[Any, Tensor]]: + def get_module_state_dict(self, module: Module) -> dict[str, Union[Any, Tensor]]: if isinstance(module, DistributedDataParallel): module = module.module return super().get_module_state_dict(module) @override def load_module_state_dict( - self, module: Module, state_dict: Dict[str, Union[Any, Tensor]], strict: bool = True + self, module: Module, state_dict: dict[str, Union[Any, Tensor]], strict: bool = True ) -> None: if isinstance(module, DistributedDataParallel): module = module.module @@ -225,13 +225,13 @@ def _set_world_ranks(self) -> None: # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank - def _determine_ddp_device_ids(self) -> Optional[List[int]]: + def _determine_ddp_device_ids(self) -> Optional[list[int]]: return None if self.root_device.type == "cpu" else [self.root_device.index] class _DDPBackwardSyncControl(_BackwardSyncControl): @override - def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager: + def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager: """Blocks gradient synchronization inside the :class:`~torch.nn.parallel.distributed.DistributedDataParallel` wrapper.""" if not enabled: diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index e71b8e2db3d58..03d90cd5df057 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -16,10 +16,11 @@ import logging import os import platform -from contextlib import ExitStack +from collections.abc import Mapping +from contextlib import AbstractContextManager, ExitStack from itertools import chain from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, List, Mapping, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch from lightning_utilities.core.imports import RequirementCache @@ -80,9 +81,9 @@ def __init__( reduce_bucket_size: int = 200_000_000, zero_allow_untested_optimizer: bool = True, logging_batch_size_per_gpu: Optional[int] = None, - config: Optional[Union[_PATH, Dict[str, Any]]] = None, + config: Optional[Union[_PATH, dict[str, Any]]] = None, logging_level: int = logging.WARN, - parallel_devices: Optional[List[torch.device]] = None, + 
parallel_devices: Optional[list[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, loss_scale: float = 0, initial_scale_power: int = 16, @@ -302,7 +303,7 @@ def zero_stage_3(self) -> bool: @property @override - def distributed_sampler_kwargs(self) -> Dict[str, int]: + def distributed_sampler_kwargs(self) -> dict[str, int]: return {"num_replicas": self.world_size, "rank": self.global_rank} @property @@ -311,8 +312,8 @@ def model(self) -> "DeepSpeedEngine": @override def setup_module_and_optimizers( - self, module: Module, optimizers: List[Optimizer] - ) -> Tuple["DeepSpeedEngine", List[Optimizer]]: + self, module: Module, optimizers: list[Optimizer] + ) -> tuple["DeepSpeedEngine", list[Optimizer]]: """Set up a model and multiple optimizers together. Currently, only a single optimizer is supported. @@ -352,7 +353,7 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: raise NotImplementedError(self._err_msg_joint_setup_required()) @override - def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: + def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager: if self.zero_stage_3 and empty_init is False: raise NotImplementedError( f"`{empty_init=}` is not a valid choice with `DeepSpeedStrategy` when ZeRO stage 3 is enabled." @@ -365,7 +366,7 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManag return stack @override - def module_sharded_context(self) -> ContextManager: + def module_sharded_context(self) -> AbstractContextManager: # Current limitation in Fabric: The config needs to be fully determined at the time of calling the context # manager. Later modifications through e.g. `Fabric.setup()` won't have an effect here. @@ -382,9 +383,9 @@ def module_sharded_context(self) -> ContextManager: def save_checkpoint( self, path: _PATH, - state: Dict[str, Union[Module, Optimizer, Any]], + state: dict[str, Union[Module, Optimizer, Any]], storage_options: Optional[Any] = None, - filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None, + filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, ) -> None: """Save model, optimizer, and other state in a checkpoint directory. @@ -447,9 +448,9 @@ def save_checkpoint( def load_checkpoint( self, path: _PATH, - state: Optional[Union[Module, Optimizer, Dict[str, Union[Module, Optimizer, Any]]]] = None, + state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, strict: bool = True, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Load the contents from a checkpoint and restore the state of the given objects. Args: @@ -595,7 +596,7 @@ def _initialize_engine( self, model: Module, optimizer: Optional[Optimizer] = None, - ) -> Tuple["DeepSpeedEngine", Optimizer]: + ) -> tuple["DeepSpeedEngine", Optimizer]: """Initialize one model and one optimizer with an optional learning rate scheduler. This calls ``deepspeed.initialize`` internally. 
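Editor's note: `_DDPBackwardSyncControl.no_backward_sync` above wraps DDP's `no_sync()` or falls back to a `nullcontext()`. A sketch of how that is typically used for gradient accumulation; the loop in the comment assumes hypothetical `ddp_model`, `batches`, `loss_fn`, and `optimizer` objects.

```python
# Sketch of the call-site behaviour enabled by `no_backward_sync`; names below are invented.
from contextlib import AbstractContextManager, nullcontext

from torch.nn.parallel import DistributedDataParallel


def no_backward_sync(module: DistributedDataParallel, enabled: bool) -> AbstractContextManager:
    # Skip the gradient all-reduce on accumulation steps, sync on the final one.
    return module.no_sync() if enabled else nullcontext()


# Typical gradient-accumulation loop (assumes `ddp_model`, `batches`, `loss_fn`, `optimizer`):
# for i, batch in enumerate(batches):
#     is_accumulating = (i + 1) % 4 != 0
#     with no_backward_sync(ddp_model, enabled=is_accumulating):
#         loss_fn(ddp_model(batch)).backward()
#     if not is_accumulating:
#         optimizer.step()
#         optimizer.zero_grad()
```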
@@ -714,7 +715,7 @@ def _create_default_config( overlap_events: bool, thread_count: int, **zero_kwargs: Any, - ) -> Dict: + ) -> dict: cfg = { "activation_checkpointing": { "partition_activations": partition_activations, @@ -769,9 +770,9 @@ def _restore_zero_state(self, module: Module, ckpt: Mapping[str, Any]) -> None: import deepspeed def load(module: torch.nn.Module, prefix: str = "") -> None: - missing_keys: List[str] = [] - unexpected_keys: List[str] = [] - error_msgs: List[str] = [] + missing_keys: list[str] = [] + unexpected_keys: list[str] = [] + error_msgs: list[str] = [] state_dict = ckpt["state_dict"] # copy state_dict so _load_from_state_dict can modify it @@ -802,7 +803,7 @@ def load(module: torch.nn.Module, prefix: str = "") -> None: load(module, prefix="") - def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Optional[Dict[str, Any]]: + def _load_config(self, config: Optional[Union[_PATH, dict[str, Any]]]) -> Optional[dict[str, Any]]: if config is None and self.DEEPSPEED_ENV_VAR in os.environ: rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable") config = os.environ[self.DEEPSPEED_ENV_VAR] @@ -817,14 +818,14 @@ def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Option return config -def _get_deepspeed_engines_from_state(state: Dict[str, Any]) -> List["DeepSpeedEngine"]: +def _get_deepspeed_engines_from_state(state: dict[str, Any]) -> list["DeepSpeedEngine"]: from deepspeed import DeepSpeedEngine modules = chain(*(module.modules() for module in state.values() if isinstance(module, Module))) return [engine for engine in modules if isinstance(engine, DeepSpeedEngine)] -def _validate_state_keys(state: Dict[str, Any]) -> None: +def _validate_state_keys(state: dict[str, Any]) -> None: # DeepSpeed merges the client state into its internal engine state when saving, but it does not check for # colliding keys from the user. We explicitly check it here: deepspeed_internal_keys = { @@ -851,7 +852,7 @@ def _validate_state_keys(state: Dict[str, Any]) -> None: ) -def _validate_device_index_selection(parallel_devices: List[torch.device]) -> None: +def _validate_device_index_selection(parallel_devices: list[torch.device]) -> None: selected_device_indices = [device.index for device in parallel_devices] expected_device_indices = list(range(len(parallel_devices))) if selected_device_indices != expected_device_indices: @@ -903,7 +904,7 @@ def _validate_checkpoint_directory(path: _PATH) -> None: def _format_precision_config( - config: Dict[str, Any], + config: dict[str, Any], precision: str, loss_scale: float, loss_scale_window: int, diff --git a/src/lightning/fabric/strategies/dp.py b/src/lightning/fabric/strategies/dp.py index 2fed307af5129..f407040649c54 100644 --- a/src/lightning/fabric/strategies/dp.py +++ b/src/lightning/fabric/strategies/dp.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
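Editor's note: both `_convert_layers` (bitsandbytes hunk) and the nested `load()` inside `_restore_zero_state` walk submodules recursively while building dotted prefixes. A self-contained illustration of that traversal pattern; this is not Lightning code.

```python
# Standalone illustration of the prefix-recursion pattern; not Lightning code.
import torch


def collect_linear_names(module: torch.nn.Module, prefix: str = "") -> list[str]:
    """Depth-first walk over submodules, building dotted names like `_convert_layers` does."""
    names: list[str] = []
    for name, child in module.named_children():
        fullname = f"{prefix}.{name}" if prefix else name
        if isinstance(child, torch.nn.Linear):
            names.append(fullname)
        names.extend(collect_linear_names(child, fullname))
    return names


model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Sequential(torch.nn.Linear(4, 2)))
print(collect_linear_names(model))  # ['0', '1.0']
```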
-from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union import torch from torch import Tensor @@ -35,7 +35,7 @@ class DataParallelStrategy(ParallelStrategy): def __init__( self, accelerator: Optional[Accelerator] = None, - parallel_devices: Optional[List[torch.device]] = None, + parallel_devices: Optional[list[torch.device]] = None, checkpoint_io: Optional[CheckpointIO] = None, precision: Optional[Precision] = None, ): @@ -95,14 +95,14 @@ def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool: return decision @override - def get_module_state_dict(self, module: Module) -> Dict[str, Union[Any, Tensor]]: + def get_module_state_dict(self, module: Module) -> dict[str, Union[Any, Tensor]]: if isinstance(module, DataParallel): module = module.module return super().get_module_state_dict(module) @override def load_module_state_dict( - self, module: Module, state_dict: Dict[str, Union[Any, Tensor]], strict: bool = True + self, module: Module, state_dict: dict[str, Union[Any, Tensor]], strict: bool = True ) -> None: if isinstance(module, DataParallel): module = module.module diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index e7fdd29f6287f..9dd5b2c62d4c9 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -13,7 +13,8 @@ # limitations under the License. import shutil import warnings -from contextlib import ExitStack, nullcontext +from collections.abc import Generator +from contextlib import AbstractContextManager, ExitStack, nullcontext from datetime import timedelta from functools import partial from pathlib import Path @@ -21,15 +22,8 @@ TYPE_CHECKING, Any, Callable, - ContextManager, - Dict, - Generator, - List, Literal, Optional, - Set, - Tuple, - Type, Union, ) @@ -78,7 +72,7 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision, ShardingStrategy from torch.distributed.fsdp.wrap import ModuleWrapPolicy - _POLICY = Union[Set[Type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy] + _POLICY = Union[set[type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy] _SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]] @@ -143,7 +137,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded): def __init__( self, accelerator: Optional[Accelerator] = None, - parallel_devices: Optional[List[torch.device]] = None, + parallel_devices: Optional[list[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, precision: Optional[Precision] = None, process_group_backend: Optional[str] = None, @@ -151,11 +145,11 @@ def __init__( cpu_offload: Union[bool, "CPUOffload", None] = None, mixed_precision: Optional["MixedPrecision"] = None, auto_wrap_policy: Optional["_POLICY"] = None, - activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None, + activation_checkpointing: Optional[Union[type[Module], list[type[Module]]]] = None, activation_checkpointing_policy: Optional["_POLICY"] = None, sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD", state_dict_type: Literal["full", "sharded"] = "sharded", - device_mesh: Optional[Union[Tuple[int], "DeviceMesh"]] = None, + device_mesh: Optional[Union[tuple[int], "DeviceMesh"]] = None, **kwargs: Any, ) -> None: super().__init__( @@ -216,7 +210,7 @@ def num_processes(self) -> int: @property @override - def distributed_sampler_kwargs(self) -> Dict[str, Any]: + def 
distributed_sampler_kwargs(self) -> dict[str, Any]: return {"num_replicas": (self.num_nodes * self.num_processes), "rank": self.global_rank} @property @@ -267,8 +261,8 @@ def setup_environment(self) -> None: @override def setup_module_and_optimizers( - self, module: Module, optimizers: List[Optimizer] - ) -> Tuple[Module, List[Optimizer]]: + self, module: Module, optimizers: list[Optimizer] + ) -> tuple[Module, list[Optimizer]]: """Wraps the model into a :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module and sets `use_orig_params=True` to keep the reference to the original parameters in the optimizer.""" use_orig_params = self._fsdp_kwargs.get("use_orig_params") @@ -340,7 +334,7 @@ def module_to_device(self, module: Module) -> None: pass @override - def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: + def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager: precision_init_ctx = self.precision.module_init_context() module_sharded_ctx = self.module_sharded_context() stack = ExitStack() @@ -354,7 +348,7 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManag return stack @override - def module_sharded_context(self) -> ContextManager: + def module_sharded_context(self) -> AbstractContextManager: from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel from torch.distributed.fsdp.wrap import enable_wrap @@ -419,9 +413,9 @@ def clip_gradients_norm( def save_checkpoint( self, path: _PATH, - state: Dict[str, Union[Module, Optimizer, Any]], + state: dict[str, Union[Module, Optimizer, Any]], storage_options: Optional[Any] = None, - filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None, + filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, ) -> None: """Save model, optimizer, and other state to a checkpoint on disk. 
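Editor's note: the `_POLICY` alias above now reads `set[type[Module]]`, and `_auto_wrap_policy_kwargs` converts such a set into torch's `ModuleWrapPolicy`. A hedged, standalone sketch of that conversion; the function name is invented and `TransformerEncoderLayer` is just an example class.

```python
# Hedged sketch of the set -> ModuleWrapPolicy conversion; the function name is invented.
from typing import Optional

from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.nn import Module, TransformerEncoderLayer


def auto_wrap_policy_kwargs(policy: Optional[set[type[Module]]], kwargs: dict) -> dict:
    if policy is None:
        return kwargs
    if isinstance(policy, set):
        # A plain set of module classes becomes torch's ModuleWrapPolicy.
        policy = ModuleWrapPolicy(policy)
    return {**kwargs, "auto_wrap_policy": policy}


fsdp_kwargs = auto_wrap_policy_kwargs({TransformerEncoderLayer}, {})
print(type(fsdp_kwargs["auto_wrap_policy"]).__name__)  # ModuleWrapPolicy
```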
@@ -473,8 +467,8 @@ def save_checkpoint( # replace the modules and optimizer objects in the state with their local state dict # and separate the user's metadata - converted_state: Dict[str, Any] = {} - metadata: Dict[str, Any] = {} + converted_state: dict[str, Any] = {} + metadata: dict[str, Any] = {} with state_dict_ctx: for key, obj in state.items(): converted: Any @@ -499,7 +493,7 @@ def save_checkpoint( shutil.rmtree(path) state_dict_ctx = _get_full_state_dict_context(module, world_size=self.world_size) - full_state: Dict[str, Any] = {} + full_state: dict[str, Any] = {} with state_dict_ctx: for key, obj in state.items(): if isinstance(obj, Module): @@ -519,9 +513,9 @@ def save_checkpoint( def load_checkpoint( self, path: _PATH, - state: Optional[Union[Module, Optimizer, Dict[str, Union[Module, Optimizer, Any]]]] = None, + state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, strict: bool = True, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Load the contents from a checkpoint and restore the state of the given objects.""" if not state: raise ValueError( @@ -683,9 +677,9 @@ def _set_world_ranks(self) -> None: def _activation_checkpointing_kwargs( - activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]], + activation_checkpointing: Optional[Union[type[Module], list[type[Module]]]], activation_checkpointing_policy: Optional["_POLICY"], -) -> Dict: +) -> dict: if activation_checkpointing is None and activation_checkpointing_policy is None: return {} if activation_checkpointing is not None and activation_checkpointing_policy is not None: @@ -707,7 +701,7 @@ def _activation_checkpointing_kwargs( return {"auto_wrap_policy": activation_checkpointing_policy} -def _auto_wrap_policy_kwargs(policy: Optional["_POLICY"], kwargs: Dict) -> Dict: +def _auto_wrap_policy_kwargs(policy: Optional["_POLICY"], kwargs: dict) -> dict: if policy is None: return kwargs if isinstance(policy, set): @@ -719,7 +713,7 @@ def _auto_wrap_policy_kwargs(policy: Optional["_POLICY"], kwargs: Dict) -> Dict: return kwargs -def _setup_activation_checkpointing(module: Module, activation_checkpointing_kwargs: Dict) -> None: +def _setup_activation_checkpointing(module: Module, activation_checkpointing_kwargs: dict) -> None: if not activation_checkpointing_kwargs: return @@ -745,7 +739,7 @@ def _setup_activation_checkpointing(module: Module, activation_checkpointing_kwa class _FSDPBackwardSyncControl(_BackwardSyncControl): @override - def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager: + def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager: """Blocks gradient synchronization inside the :class:`~torch.distributed.fsdp.FullyShardedDataParallel` wrapper.""" if not enabled: @@ -768,7 +762,7 @@ def _init_cpu_offload(cpu_offload: Optional[Union[bool, "CPUOffload"]]) -> "CPUO return cpu_offload if isinstance(cpu_offload, CPUOffload) else CPUOffload(offload_params=bool(cpu_offload)) -def _init_sharding_strategy(sharding_strategy: "_SHARDING_STRATEGY", kwargs: Dict) -> "ShardingStrategy": +def _init_sharding_strategy(sharding_strategy: "_SHARDING_STRATEGY", kwargs: dict) -> "ShardingStrategy": from torch.distributed.fsdp import ShardingStrategy if kwargs.get("process_group") is not None and kwargs.get("device_mesh") is not None: @@ -858,7 +852,7 @@ def _move_torchmetrics_to_device(module: torch.nn.Module, device: torch.device) metric.to(device) # `.to()` is in-place -def _distributed_checkpoint_save(converted_state: 
Dict[str, Any], path: Path) -> None: +def _distributed_checkpoint_save(converted_state: dict[str, Any], path: Path) -> None: if _TORCH_GREATER_EQUAL_2_3: from torch.distributed.checkpoint import save @@ -877,7 +871,7 @@ def _distributed_checkpoint_save(converted_state: Dict[str, Any], path: Path) -> save(converted_state, writer) -def _distributed_checkpoint_load(module_state: Dict[str, Any], path: Path) -> None: +def _distributed_checkpoint_load(module_state: dict[str, Any], path: Path) -> None: if _TORCH_GREATER_EQUAL_2_3: from torch.distributed.checkpoint import load diff --git a/src/lightning/fabric/strategies/launchers/multiprocessing.py b/src/lightning/fabric/strategies/launchers/multiprocessing.py index 14a063f28f336..d9b96dca5471d 100644 --- a/src/lightning/fabric/strategies/launchers/multiprocessing.py +++ b/src/lightning/fabric/strategies/launchers/multiprocessing.py @@ -16,7 +16,7 @@ from dataclasses import dataclass from multiprocessing.queues import SimpleQueue from textwrap import dedent -from typing import TYPE_CHECKING, Any, Callable, Dict, Literal, Optional +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional import torch import torch.backends.cudnn @@ -167,7 +167,7 @@ class _GlobalStateSnapshot: use_deterministic_algorithms: bool use_deterministic_algorithms_warn_only: bool cudnn_benchmark: bool - rng_states: Dict[str, Any] + rng_states: dict[str, Any] @classmethod def capture(cls) -> "_GlobalStateSnapshot": diff --git a/src/lightning/fabric/strategies/launchers/subprocess_script.py b/src/lightning/fabric/strategies/launchers/subprocess_script.py index 63ae8b0beee4b..a28fe971c7ac4 100644 --- a/src/lightning/fabric/strategies/launchers/subprocess_script.py +++ b/src/lightning/fabric/strategies/launchers/subprocess_script.py @@ -18,7 +18,8 @@ import sys import threading import time -from typing import Any, Callable, List, Optional, Sequence, Tuple +from collections.abc import Sequence +from typing import Any, Callable, Optional from lightning_utilities.core.imports import RequirementCache from typing_extensions import override @@ -80,7 +81,7 @@ def __init__( self.cluster_environment = cluster_environment self.num_processes = num_processes self.num_nodes = num_nodes - self.procs: List[subprocess.Popen] = [] # launched child subprocesses, does not include the launcher + self.procs: list[subprocess.Popen] = [] # launched child subprocesses, does not include the launcher @property @override @@ -162,7 +163,7 @@ def _basic_subprocess_cmd() -> Sequence[str]: return [sys.executable, "-m", __main__.__spec__.name] + sys.argv[1:] -def _hydra_subprocess_cmd(local_rank: int) -> Tuple[Sequence[str], str]: +def _hydra_subprocess_cmd(local_rank: int) -> tuple[Sequence[str], str]: from hydra.core.hydra_config import HydraConfig from hydra.utils import get_original_cwd, to_absolute_path @@ -183,13 +184,13 @@ def _hydra_subprocess_cmd(local_rank: int) -> Tuple[Sequence[str], str]: return command, cwd -def _launch_process_observer(child_processes: List[subprocess.Popen]) -> None: +def _launch_process_observer(child_processes: list[subprocess.Popen]) -> None: """Launches a thread that runs along the main process and monitors the health of all processes.""" _ChildProcessObserver(child_processes=child_processes, main_pid=os.getpid()).start() class _ChildProcessObserver(threading.Thread): - def __init__(self, main_pid: int, child_processes: List[subprocess.Popen], sleep_period: int = 5) -> None: + def __init__(self, main_pid: int, child_processes: list[subprocess.Popen], 
sleep_period: int = 5) -> None: super().__init__(daemon=True, name="child-process-observer") # thread stops if the main process exits self._main_pid = main_pid self._child_processes = child_processes diff --git a/src/lightning/fabric/strategies/model_parallel.py b/src/lightning/fabric/strategies/model_parallel.py index 86b93d35e66f3..ad1fc19074d06 100644 --- a/src/lightning/fabric/strategies/model_parallel.py +++ b/src/lightning/fabric/strategies/model_parallel.py @@ -13,10 +13,11 @@ # limitations under the License. import itertools import shutil -from contextlib import ExitStack +from collections.abc import Generator +from contextlib import AbstractContextManager, ExitStack from datetime import timedelta from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generator, Literal, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, TypeVar, Union import torch from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only @@ -144,7 +145,7 @@ def num_processes(self) -> int: @property @override - def distributed_sampler_kwargs(self) -> Dict[str, Any]: + def distributed_sampler_kwargs(self) -> dict[str, Any]: assert self.device_mesh is not None data_parallel_mesh = self.device_mesh["data_parallel"] return {"num_replicas": data_parallel_mesh.size(), "rank": data_parallel_mesh.get_local_rank()} @@ -194,7 +195,7 @@ def module_to_device(self, module: Module) -> None: pass @override - def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: + def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager: precision_init_ctx = self.precision.module_init_context() stack = ExitStack() if empty_init: @@ -234,9 +235,9 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: def save_checkpoint( self, path: _PATH, - state: Dict[str, Union[Module, Optimizer, Any]], + state: dict[str, Union[Module, Optimizer, Any]], storage_options: Optional[Any] = None, - filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None, + filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, ) -> None: """Save model, optimizer, and other state to a checkpoint on disk. 
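Editor's note: `_GlobalStateSnapshot.capture()` in the multiprocessing launcher stores `rng_states: dict[str, Any]` plus a few global flags. Below is a simplified stand-in showing the capture/restore idea; the field set here is illustrative, and the real class tracks more state (deterministic-algorithm flags, NumPy/CUDA RNG, etc.).

```python
# Simplified stand-in for the snapshot idea; field names here are illustrative only.
import random
from dataclasses import dataclass
from typing import Any

import torch


@dataclass
class GlobalStateSnapshot:
    cudnn_benchmark: bool
    rng_states: dict[str, Any]

    @classmethod
    def capture(cls) -> "GlobalStateSnapshot":
        return cls(
            cudnn_benchmark=torch.backends.cudnn.benchmark,
            rng_states={"python": random.getstate(), "torch": torch.get_rng_state()},
        )

    def restore(self) -> None:
        torch.backends.cudnn.benchmark = self.cudnn_benchmark
        random.setstate(self.rng_states["python"])
        torch.set_rng_state(self.rng_states["torch"])


snapshot = GlobalStateSnapshot.capture()  # in the main process, before spawning
snapshot.restore()                        # in the spawned worker
```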
@@ -272,9 +273,9 @@ def save_checkpoint( def load_checkpoint( self, path: _PATH, - state: Optional[Union[Module, Optimizer, Dict[str, Union[Module, Optimizer, Any]]]] = None, + state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, strict: bool = True, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Load the contents from a checkpoint and restore the state of the given objects.""" if not state: raise ValueError( @@ -318,12 +319,12 @@ def _set_world_ranks(self) -> None: class _ParallelBackwardSyncControl(_BackwardSyncControl): @override - def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager: + def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager: """Blocks gradient synchronization inside the FSDP2 modules.""" return _FSDPNoSync(module=module, enabled=enabled) -class _FSDPNoSync(ContextManager): +class _FSDPNoSync(AbstractContextManager): def __init__(self, module: Module, enabled: bool) -> None: self._module = module self._enabled = enabled @@ -344,10 +345,10 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: def _save_checkpoint( path: Path, - state: Dict[str, Union[Module, Optimizer, Any]], + state: dict[str, Union[Module, Optimizer, Any]], full_state_dict: bool, rank: int, - filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None, + filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, ) -> None: if path.is_dir() and full_state_dict and not _is_sharded_checkpoint(path): raise IsADirectoryError(f"The checkpoint path exists and is a directory: {path}") @@ -373,8 +374,8 @@ def _save_checkpoint( # replace the modules and optimizer objects in the state with their local state dict # and separate the user's metadata - converted_state: Dict[str, Any] = {} - metadata: Dict[str, Any] = {} + converted_state: dict[str, Any] = {} + metadata: dict[str, Any] = {} for key, obj in state.items(): converted: Any if isinstance(obj, Module): @@ -405,10 +406,10 @@ def _save_checkpoint( def _load_checkpoint( path: Path, - state: Dict[str, Union[Module, Optimizer, Any]], + state: dict[str, Union[Module, Optimizer, Any]], strict: bool = True, optimizer_states_from_list: bool = False, -) -> Dict[str, Any]: +) -> dict[str, Any]: from torch.distributed.checkpoint.state_dict import ( StateDictOptions, get_model_state_dict, @@ -537,7 +538,7 @@ def _load_raw_module_state_from_path(path: Path, module: Module, world_size: int def _load_raw_module_state( - state_dict: Dict[str, Any], module: Module, world_size: int = 1, strict: bool = True + state_dict: dict[str, Any], module: Module, world_size: int = 1, strict: bool = True ) -> None: """Loads the state dict into the module by gathering all weights first and then and writing back to each shard.""" from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -583,7 +584,7 @@ def _named_parameters_and_buffers_to_load(module: Module) -> Generator: yield param_name, param -def _rekey_optimizer_state_if_needed(optimizer_state_dict: Dict[str, Any], module: Module) -> Dict[str, Any]: +def _rekey_optimizer_state_if_needed(optimizer_state_dict: dict[str, Any], module: Module) -> dict[str, Any]: """Handles the case where the optimizer state is saved from a normal optimizer and converts the keys to parameter names.""" from torch.distributed.fsdp import FullyShardedDataParallel as FSDP diff --git a/src/lightning/fabric/strategies/parallel.py b/src/lightning/fabric/strategies/parallel.py index a12a0611c90ab..d9bc1a03d1bb5 100644 --- 
a/src/lightning/fabric/strategies/parallel.py +++ b/src/lightning/fabric/strategies/parallel.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC -from typing import Any, Dict, List, Optional +from typing import Any, Optional import torch from torch import Tensor @@ -33,7 +33,7 @@ class ParallelStrategy(Strategy, ABC): def __init__( self, accelerator: Optional[Accelerator] = None, - parallel_devices: Optional[List[torch.device]] = None, + parallel_devices: Optional[list[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, precision: Optional[Precision] = None, @@ -64,15 +64,15 @@ def is_global_zero(self) -> bool: return self.global_rank == 0 @property - def parallel_devices(self) -> Optional[List[torch.device]]: + def parallel_devices(self) -> Optional[list[torch.device]]: return self._parallel_devices @parallel_devices.setter - def parallel_devices(self, parallel_devices: Optional[List[torch.device]]) -> None: + def parallel_devices(self, parallel_devices: Optional[list[torch.device]]) -> None: self._parallel_devices = parallel_devices @property - def distributed_sampler_kwargs(self) -> Optional[Dict[str, Any]]: + def distributed_sampler_kwargs(self) -> Optional[dict[str, Any]]: """Arguments for the ``DistributedSampler``. If this method is not defined, or it returns ``None``, then the ``DistributedSampler`` will not be used. diff --git a/src/lightning/fabric/strategies/registry.py b/src/lightning/fabric/strategies/registry.py index d7899584b5a88..d2376463c3111 100644 --- a/src/lightning/fabric/strategies/registry.py +++ b/src/lightning/fabric/strategies/registry.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Optional from typing_extensions import override @@ -65,7 +65,7 @@ def register( if name in self and not override: raise ValueError(f"'{name}' is already present in the registry. HINT: Use `override=True`.") - data: Dict[str, Any] = {} + data: dict[str, Any] = {} data["description"] = description if description is not None else "" data["init_params"] = init_params @@ -104,7 +104,7 @@ def remove(self, name: str) -> None: """Removes the registered strategy by name.""" self.pop(name) - def available_strategies(self) -> List: + def available_strategies(self) -> list: """Returns a list of registered strategies.""" return list(self.keys()) diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py index 6bfed6a270b68..4daad9b954b2f 100644 --- a/src/lightning/fabric/strategies/strategy.py +++ b/src/lightning/fabric/strategies/strategy.py @@ -13,8 +13,9 @@ # limitations under the License. 
import logging from abc import ABC, abstractmethod -from contextlib import ExitStack -from typing import Any, Callable, ContextManager, Dict, Iterable, List, Optional, Tuple, TypeVar, Union +from collections.abc import Iterable +from contextlib import AbstractContextManager, ExitStack +from typing import Any, Callable, Optional, TypeVar, Union import torch from torch import Tensor @@ -117,7 +118,7 @@ def process_dataloader(self, dataloader: DataLoader) -> DataLoader: """ return dataloader - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: """Controls how tensors get created (device, dtype).""" precision_init_ctx = self.precision.tensor_init_context() stack = ExitStack() @@ -125,7 +126,7 @@ def tensor_init_context(self) -> ContextManager: stack.enter_context(precision_init_ctx) return stack - def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: + def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager: """A context manager wrapping the model instantiation. Here, the strategy can control how the parameters of the model get created (device, dtype) and or apply other @@ -144,8 +145,8 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManag return stack def setup_module_and_optimizers( - self, module: Module, optimizers: List[Optimizer] - ) -> Tuple[Module, List[Optimizer]]: + self, module: Module, optimizers: list[Optimizer] + ) -> tuple[Module, list[Optimizer]]: """Set up a model and multiple optimizers together. The returned objects are expected to be in the same order they were passed in. The default implementation will @@ -256,9 +257,9 @@ def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool: def save_checkpoint( self, path: _PATH, - state: Dict[str, Union[Module, Optimizer, Any]], + state: dict[str, Union[Module, Optimizer, Any]], storage_options: Optional[Any] = None, - filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None, + filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, ) -> None: """Save model, optimizer, and other state as a checkpoint file. @@ -276,17 +277,17 @@ def save_checkpoint( if self.is_global_zero: self.checkpoint_io.save_checkpoint(checkpoint=state, path=path, storage_options=storage_options) - def get_module_state_dict(self, module: Module) -> Dict[str, Union[Any, Tensor]]: + def get_module_state_dict(self, module: Module) -> dict[str, Union[Any, Tensor]]: """Returns model state.""" return module.state_dict() def load_module_state_dict( - self, module: Module, state_dict: Dict[str, Union[Any, Tensor]], strict: bool = True + self, module: Module, state_dict: dict[str, Union[Any, Tensor]], strict: bool = True ) -> None: """Loads the given state into the model.""" module.load_state_dict(state_dict, strict=strict) - def get_optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: + def get_optimizer_state(self, optimizer: Optimizer) -> dict[str, Tensor]: """Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom plugins. 
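Editor's note: the registry hunk keeps the same `data: dict[str, Any]` record per entry. A rough, self-contained sketch of that pattern; the real `_StrategyRegistry` subclasses dict as well but also handles `override=True` and required-parameter checks.

```python
# Rough sketch of the registry pattern; not the actual _StrategyRegistry implementation.
from typing import Any, Optional


class SimpleRegistry(dict):
    def register(self, name: str, cls: type, description: Optional[str] = None, **init_params: Any) -> None:
        if name in self:
            raise ValueError(f"'{name}' is already present in the registry.")
        data: dict[str, Any] = {"cls": cls, "description": description or "", "init_params": init_params}
        self[name] = data

    def available_strategies(self) -> list:
        return list(self.keys())


registry = SimpleRegistry()
registry.register("noop", object, description="placeholder entry")
print(registry.available_strategies())  # ['noop']
```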
@@ -304,9 +305,9 @@ def get_optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: def load_checkpoint( self, path: _PATH, - state: Optional[Union[Module, Optimizer, Dict[str, Union[Module, Optimizer, Any]]]] = None, + state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, strict: bool = True, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Load the contents from a checkpoint and restore the state of the given objects. Args: @@ -394,9 +395,9 @@ def _err_msg_joint_setup_required(self) -> str: ) def _convert_stateful_objects_in_state( - self, state: Dict[str, Union[Module, Optimizer, Any]], filter: Dict[str, Callable[[str, Any], bool]] - ) -> Dict[str, Any]: - converted_state: Dict[str, Any] = {} + self, state: dict[str, Union[Module, Optimizer, Any]], filter: dict[str, Callable[[str, Any], bool]] + ) -> dict[str, Any]: + converted_state: dict[str, Any] = {} for key, obj in state.items(): # convert the state if isinstance(obj, Module): @@ -421,7 +422,7 @@ class _BackwardSyncControl(ABC): """ @abstractmethod - def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager: + def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager: """Blocks the synchronization of gradients during the backward pass. This is a context manager. It is only effective if it wraps a call to `.backward()`. @@ -433,7 +434,7 @@ class _Sharded(ABC): """Mixin-interface for any :class:`Strategy` that wants to expose functionality for sharding model parameters.""" @abstractmethod - def module_sharded_context(self) -> ContextManager: + def module_sharded_context(self) -> AbstractContextManager: """A context manager that goes over the instantiation of an :class:`torch.nn.Module` and handles sharding of parameters on creation. @@ -454,7 +455,7 @@ def _validate_keys_for_strict_loading( def _apply_filter( - key: str, filter: Dict[str, Callable[[str, Any], bool]], source_dict: object, target_dict: Dict[str, Any] + key: str, filter: dict[str, Callable[[str, Any], bool]], source_dict: object, target_dict: dict[str, Any] ) -> None: # filter out if necessary if key in filter and isinstance(source_dict, dict): diff --git a/src/lightning/fabric/strategies/xla.py b/src/lightning/fabric/strategies/xla.py index 28d65558cae90..3b2e10e87b0a7 100644 --- a/src/lightning/fabric/strategies/xla.py +++ b/src/lightning/fabric/strategies/xla.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import io -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch from torch import Tensor @@ -43,7 +43,7 @@ class XLAStrategy(ParallelStrategy): def __init__( self, accelerator: Optional[Accelerator] = None, - parallel_devices: Optional[List[torch.device]] = None, + parallel_devices: Optional[list[torch.device]] = None, checkpoint_io: Optional[XLACheckpointIO] = None, precision: Optional[XLAPrecision] = None, sync_module_states: bool = True, @@ -276,9 +276,9 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: def save_checkpoint( self, path: _PATH, - state: Dict[str, Union[Module, Optimizer, Any]], + state: dict[str, Union[Module, Optimizer, Any]], storage_options: Optional[Any] = None, - filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None, + filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, ) -> None: """Save model, optimizer, and other state as a checkpoint file. diff --git a/src/lightning/fabric/strategies/xla_fsdp.py b/src/lightning/fabric/strategies/xla_fsdp.py index e4c080d8110db..935ef72713bcc 100644 --- a/src/lightning/fabric/strategies/xla_fsdp.py +++ b/src/lightning/fabric/strategies/xla_fsdp.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import io -from contextlib import ExitStack, nullcontext +from contextlib import AbstractContextManager, ExitStack, nullcontext from functools import partial from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, List, Literal, Optional, Set, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union import torch from torch import Tensor @@ -46,7 +46,7 @@ if TYPE_CHECKING: from torch_xla.distributed.parallel_loader import MpDeviceLoader -_POLICY_SET = Set[Type[Module]] +_POLICY_SET = set[type[Module]] _POLICY = Union[_POLICY_SET, Callable[[Module, bool, int], bool]] @@ -83,7 +83,7 @@ class XLAFSDPStrategy(ParallelStrategy, _Sharded): def __init__( self, accelerator: Optional[Accelerator] = None, - parallel_devices: Optional[List[torch.device]] = None, + parallel_devices: Optional[list[torch.device]] = None, checkpoint_io: Optional[XLACheckpointIO] = None, precision: Optional[XLAPrecision] = None, auto_wrap_policy: Optional[_POLICY] = None, @@ -196,8 +196,8 @@ def setup_environment(self) -> None: @override def setup_module_and_optimizers( - self, module: Module, optimizers: List[Optimizer] - ) -> Tuple[Module, List[Optimizer]]: + self, module: Module, optimizers: list[Optimizer] + ) -> tuple[Module, list[Optimizer]]: """Returns NotImplementedError since for XLAFSDP optimizer setup must happen after module setup.""" raise NotImplementedError( f"The `{type(self).__name__}` does not support the joint setup of module and optimizer(s)." 
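Editor's note: `_apply_filter` is only re-annotated here, not shown in full, so the snippet below is a guess at how a per-key `filter: dict[str, Callable[[str, Any], bool]]` might be applied when building checkpoint state. Treat it as an illustration of the signature, not the actual implementation.

```python
# Guess at the behaviour behind the `filter` parameter; signature taken from the diff,
# implementation assumed.
from typing import Any, Callable


def apply_filter(
    key: str,
    filter: dict[str, Callable[[str, Any], bool]],
    source_dict: dict[str, Any],
    target_dict: dict[str, Any],
) -> None:
    keep = filter.get(key)
    target_dict[key] = {k: v for k, v in source_dict.items() if keep is None or keep(k, v)}


converted: dict[str, Any] = {}
model_state = {"weight": 1.0, "aux_buffer": 2.0}
only_weights = {"model": lambda name, value: not name.endswith("_buffer")}
apply_filter("model", only_weights, model_state, converted)
print(converted)  # {'model': {'weight': 1.0}}
```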
@@ -225,7 +225,7 @@ def setup_module(self, module: Module) -> Module: def module_to_device(self, module: Module) -> None: pass - def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: + def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager: precision_init_ctx = self.precision.module_init_context() module_sharded_ctx = self.module_sharded_context() stack = ExitStack() @@ -235,7 +235,7 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManag return stack @override - def module_sharded_context(self) -> ContextManager: + def module_sharded_context(self) -> AbstractContextManager: return nullcontext() @override @@ -408,9 +408,9 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: def save_checkpoint( self, path: _PATH, - state: Dict[str, Union[Module, Optimizer, Any]], + state: dict[str, Union[Module, Optimizer, Any]], storage_options: Optional[Any] = None, - filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None, + filter: Optional[dict[str, Callable[[str, Any], bool]]] = None, ) -> None: """Save model, optimizer, and other state in the provided checkpoint directory. @@ -483,13 +483,13 @@ def save_checkpoint( def _save_checkpoint_shard( self, path: Path, - state: Dict[str, Union[Module, Optimizer, Any]], + state: dict[str, Union[Module, Optimizer, Any]], storage_options: Optional[Any], - filter: Optional[Dict[str, Callable[[str, Any], bool]]], + filter: Optional[dict[str, Callable[[str, Any], bool]]], ) -> None: from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XLAFSDP - converted_state: Dict[str, Any] = {} + converted_state: dict[str, Any] = {} for key, obj in state.items(): # convert the state if isinstance(obj, Module) and isinstance(obj, XLAFSDP): @@ -512,9 +512,9 @@ def _save_checkpoint_shard( def load_checkpoint( self, path: _PATH, - state: Optional[Union[Module, Optimizer, Dict[str, Union[Module, Optimizer, Any]]]] = None, + state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None, strict: bool = True, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Given a folder, load the contents from a checkpoint and restore the state of the given objects. 
The strategy currently only supports saving and loading sharded checkpoints which are stored in form of a @@ -617,7 +617,7 @@ def load_checkpoint( def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None: strategy_registry.register("xla_fsdp", cls, description=cls.__name__) - def _parse_fsdp_kwargs(self) -> Dict: + def _parse_fsdp_kwargs(self) -> dict: # this needs to be delayed because `self.precision` isn't available at init kwargs = self._fsdp_kwargs.copy() precision = self.precision @@ -629,7 +629,7 @@ def _parse_fsdp_kwargs(self) -> Dict: return _activation_checkpointing_kwargs(self._activation_checkpointing_policy, kwargs) -def _auto_wrap_policy_kwargs(policy: Optional["_POLICY"], kwargs: Dict) -> Dict: +def _auto_wrap_policy_kwargs(policy: Optional["_POLICY"], kwargs: dict) -> dict: if policy is None: return kwargs if isinstance(policy, set): @@ -649,7 +649,7 @@ def _activation_checkpointing_auto_wrapper(policy: _POLICY_SET, module: Module, return XLAFSDP(module, *args, **kwargs) -def _activation_checkpointing_kwargs(policy: Optional[_POLICY_SET], kwargs: Dict) -> Dict: +def _activation_checkpointing_kwargs(policy: Optional[_POLICY_SET], kwargs: dict) -> dict: if not policy: return kwargs if "auto_wrapper_callable" in kwargs: @@ -668,7 +668,7 @@ def _activation_checkpointing_kwargs(policy: Optional[_POLICY_SET], kwargs: Dict class _XLAFSDPBackwardSyncControl(_BackwardSyncControl): @override - def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager: + def no_backward_sync(self, module: Module, enabled: bool) -> AbstractContextManager: """Blocks gradient synchronization inside the :class:`~torch_xla.distributed.fsdp.XlaFullyShardedDataParallel` wrapper.""" if not enabled: diff --git a/src/lightning/fabric/utilities/apply_func.py b/src/lightning/fabric/utilities/apply_func.py index d43565f494d3c..35693a5fcf1fb 100644 --- a/src/lightning/fabric/utilities/apply_func.py +++ b/src/lightning/fabric/utilities/apply_func.py @@ -15,7 +15,7 @@ from abc import ABC from functools import partial -from typing import TYPE_CHECKING, Any, Callable, List, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Union import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -34,7 +34,7 @@ def _from_numpy(value: "np.ndarray", device: _DEVICE) -> Tensor: return torch.from_numpy(value).to(device) -CONVERSION_DTYPES: List[Tuple[Any, Callable[[Any, Any], Tensor]]] = [ +CONVERSION_DTYPES: list[tuple[Any, Callable[[Any, Any], Tensor]]] = [ # bool -> uint8 as bool -> torch.bool triggers RuntimeError: Unsupported data type for NCCL process group (bool, partial(torch.tensor, dtype=torch.uint8)), (int, partial(torch.tensor, dtype=torch.int)), diff --git a/src/lightning/fabric/utilities/cloud_io.py b/src/lightning/fabric/utilities/cloud_io.py index 7ecc9eea501a6..9d0a33afd0b77 100644 --- a/src/lightning/fabric/utilities/cloud_io.py +++ b/src/lightning/fabric/utilities/cloud_io.py @@ -16,7 +16,7 @@ import io import logging from pathlib import Path -from typing import IO, Any, Dict, Union +from typing import IO, Any, Union import fsspec import fsspec.utils @@ -69,7 +69,7 @@ def get_filesystem(path: _PATH, **kwargs: Any) -> AbstractFileSystem: return fs -def _atomic_save(checkpoint: Dict[str, Any], filepath: Union[str, Path]) -> None: +def _atomic_save(checkpoint: dict[str, Any], filepath: Union[str, Path]) -> None: """Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints. 
Args: diff --git a/src/lightning/fabric/utilities/data.py b/src/lightning/fabric/utilities/data.py index 1ec0edce38050..ea35d8c3da4a9 100644 --- a/src/lightning/fabric/utilities/data.py +++ b/src/lightning/fabric/utilities/data.py @@ -16,9 +16,10 @@ import inspect import os from collections import OrderedDict +from collections.abc import Generator, Iterable, Sized from contextlib import contextmanager from functools import partial -from typing import Any, Callable, Dict, Generator, Iterable, Optional, Sized, Tuple, Type, Union +from typing import Any, Callable, Optional, Union from lightning_utilities.core.inheritance import get_all_subclasses from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler @@ -79,7 +80,7 @@ def _update_dataloader(dataloader: DataLoader, sampler: Union[Sampler, Iterable] def _get_dataloader_init_args_and_kwargs( dataloader: DataLoader, sampler: Union[Sampler, Iterable], -) -> Tuple[Tuple[Any], Dict[str, Any]]: +) -> tuple[tuple[Any], dict[str, Any]]: if not isinstance(dataloader, DataLoader): raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`") @@ -172,7 +173,7 @@ def _get_dataloader_init_args_and_kwargs( def _dataloader_init_kwargs_resolve_sampler( dataloader: DataLoader, sampler: Union[Sampler, Iterable], -) -> Dict[str, Any]: +) -> dict[str, Any]: """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its re- instantiation.""" batch_sampler = getattr(dataloader, "batch_sampler") @@ -249,7 +250,7 @@ def _auto_add_worker_init_fn(dataloader: object, rank: int) -> None: dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank) -def _reinstantiate_wrapped_cls(orig_object: Any, *args: Any, explicit_cls: Optional[Type] = None, **kwargs: Any) -> Any: +def _reinstantiate_wrapped_cls(orig_object: Any, *args: Any, explicit_cls: Optional[type] = None, **kwargs: Any) -> Any: constructor = type(orig_object) if explicit_cls is None else explicit_cls try: @@ -355,7 +356,7 @@ def wrapper(obj: Any, *args: Any) -> None: @contextmanager -def _replace_dunder_methods(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]: +def _replace_dunder_methods(base_cls: type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]: """This context manager is used to add support for re-instantiation of custom (subclasses) of `base_cls`. It patches the ``__init__``, ``__setattr__`` and ``__delattr__`` methods. @@ -366,8 +367,8 @@ def _replace_dunder_methods(base_cls: Type, store_explicit_arg: Optional[str] = # Check that __init__ belongs to the class # https://stackoverflow.com/a/5253424 if "__init__" in cls.__dict__: - cls.__old__init__ = cls.__init__ - cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg) + cls.__old__init__ = cls.__init__ # type: ignore[misc] + cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg) # type: ignore[misc] # we want at least one setattr/delattr in the chain to be patched and it can happen, that none of the subclasses # implement `__setattr__`/`__delattr__`. 
Therefore, we are always patching the `base_cls` @@ -389,11 +390,11 @@ def _replace_dunder_methods(base_cls: Type, store_explicit_arg: Optional[str] = def _replace_value_in_saved_args( replace_key: str, replace_value: Any, - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - default_kwargs: Dict[str, Any], - arg_names: Tuple[str, ...], -) -> Tuple[bool, Tuple[Any, ...], Dict[str, Any]]: + args: tuple[Any, ...], + kwargs: dict[str, Any], + default_kwargs: dict[str, Any], + arg_names: tuple[str, ...], +) -> tuple[bool, tuple[Any, ...], dict[str, Any]]: """Tries to replace an argument value in a saved list of args and kwargs. Returns a tuple indicating success of the operation and modified saved args and kwargs @@ -420,7 +421,7 @@ def _set_sampler_epoch(dataloader: object, epoch: int) -> None: """ # cannot use a set because samplers might be unhashable: use a dict based on the id to drop duplicates - objects: Dict[int, Any] = {} + objects: dict[int, Any] = {} # check dataloader.sampler if (sampler := getattr(dataloader, "sampler", None)) is not None: objects[id(sampler)] = sampler @@ -458,7 +459,7 @@ def _num_cpus_available() -> int: return 1 if cpu_count is None else cpu_count -class AttributeDict(Dict): +class AttributeDict(dict): """A container to store state variables of your program. This is a drop-in replacement for a Python dictionary, with the additional functionality to access and modify keys diff --git a/src/lightning/fabric/utilities/device_dtype_mixin.py b/src/lightning/fabric/utilities/device_dtype_mixin.py index 9f06dc50cfbef..ff5a0949e4207 100644 --- a/src/lightning/fabric/utilities/device_dtype_mixin.py +++ b/src/lightning/fabric/utilities/device_dtype_mixin.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union import torch from torch.nn import Module @@ -20,7 +20,7 @@ class _DeviceDtypeModuleMixin(Module): - __jit_unused_properties__: List[str] = ["device", "dtype"] + __jit_unused_properties__: list[str] = ["device", "dtype"] def __init__(self) -> None: super().__init__() diff --git a/src/lightning/fabric/utilities/device_parser.py b/src/lightning/fabric/utilities/device_parser.py index 16965d944caec..ff5bebd9b4516 100644 --- a/src/lightning/fabric/utilities/device_parser.py +++ b/src/lightning/fabric/utilities/device_parser.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, MutableSequence, Optional, Tuple, Union +from collections.abc import MutableSequence +from typing import Optional, Union import torch @@ -19,7 +20,7 @@ from lightning.fabric.utilities.types import _DEVICE -def _determine_root_gpu_device(gpus: List[_DEVICE]) -> Optional[_DEVICE]: +def _determine_root_gpu_device(gpus: list[_DEVICE]) -> Optional[_DEVICE]: """ Args: gpus: Non-empty list of ints representing which GPUs to use @@ -46,10 +47,10 @@ def _determine_root_gpu_device(gpus: List[_DEVICE]) -> Optional[_DEVICE]: def _parse_gpu_ids( - gpus: Optional[Union[int, str, List[int]]], + gpus: Optional[Union[int, str, list[int]]], include_cuda: bool = False, include_mps: bool = False, -) -> Optional[List[int]]: +) -> Optional[list[int]]: """Parses the GPU IDs given in the format as accepted by the :class:`~lightning.pytorch.trainer.trainer.Trainer`. 
Args: @@ -102,7 +103,7 @@ def _parse_gpu_ids( return _sanitize_gpu_ids(gpus, include_cuda=include_cuda, include_mps=include_mps) -def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[int, List[int]]: +def _normalize_parse_gpu_string_input(s: Union[int, str, list[int]]) -> Union[int, list[int]]: if not isinstance(s, str): return s if s == "-1": @@ -112,7 +113,7 @@ def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[in return int(s.strip()) -def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: bool = False) -> List[int]: +def _sanitize_gpu_ids(gpus: list[int], include_cuda: bool = False, include_mps: bool = False) -> list[int]: """Checks that each of the GPUs in the list is actually available. Raises a MisconfigurationException if any of the GPUs is not available. @@ -139,8 +140,8 @@ def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: def _normalize_parse_gpu_input_to_list( - gpus: Union[int, List[int], Tuple[int, ...]], include_cuda: bool, include_mps: bool -) -> Optional[List[int]]: + gpus: Union[int, list[int], tuple[int, ...]], include_cuda: bool, include_mps: bool +) -> Optional[list[int]]: assert gpus is not None if isinstance(gpus, (MutableSequence, tuple)): return list(gpus) @@ -154,7 +155,7 @@ def _normalize_parse_gpu_input_to_list( return list(range(gpus)) -def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = False) -> List[int]: +def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = False) -> list[int]: """ Returns: A list of all available GPUs @@ -167,7 +168,7 @@ def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = Fals return cuda_gpus + mps_gpus -def _check_unique(device_ids: List[int]) -> None: +def _check_unique(device_ids: list[int]) -> None: """Checks that the device_ids are unique. Args: diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py index 0e6c52dfb09b9..ec4eb261f2d3e 100644 --- a/src/lightning/fabric/utilities/distributed.py +++ b/src/lightning/fabric/utilities/distributed.py @@ -4,10 +4,11 @@ import os import signal import time +from collections.abc import Iterable, Iterator, Sized from contextlib import nullcontext from datetime import timedelta from pathlib import Path -from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sized, Union +from typing import TYPE_CHECKING, Any, Optional, Union import torch import torch.nn.functional as F @@ -99,7 +100,7 @@ def is_shared_filesystem(strategy: "Strategy", path: Optional[_PATH] = None, tim return all_found -def _gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tensor]: +def _gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> list[Tensor]: """Function to gather all tensors from several DDP processes onto a list that is broadcasted to all processes. Works on tensors that have the same number of dimensions, but where each dimension may differ. 
In this case @@ -153,7 +154,7 @@ def _gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Ten return gathered_result -def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]: +def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> list[Tensor]: gathered_result = [torch.zeros_like(result) for _ in range(world_size)] torch.distributed.all_gather(gathered_result, result, group) return gathered_result @@ -345,7 +346,7 @@ def __init__(self, sampler: Union[Sampler, Iterable]) -> None: ) self._sampler = sampler # defer materializing an iterator until it is necessary - self._sampler_list: Optional[List[Any]] = None + self._sampler_list: Optional[list[Any]] = None @override def __getitem__(self, index: int) -> Any: diff --git a/src/lightning/fabric/utilities/init.py b/src/lightning/fabric/utilities/init.py index c92dfd8c2e82b..2760c6bd227c1 100644 --- a/src/lightning/fabric/utilities/init.py +++ b/src/lightning/fabric/utilities/init.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import itertools -from typing import Any, Callable, Dict, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union import torch from torch.nn import Module, Parameter @@ -46,7 +47,7 @@ def __torch_function__( func: Callable, types: Sequence, args: Sequence[Any] = (), - kwargs: Optional[Dict] = None, + kwargs: Optional[dict] = None, ) -> Any: kwargs = kwargs or {} if not self.enabled: diff --git a/src/lightning/fabric/utilities/load.py b/src/lightning/fabric/utilities/load.py index a1c3b6933b2f6..9e158c1677b6d 100644 --- a/src/lightning/fabric/utilities/load.py +++ b/src/lightning/fabric/utilities/load.py @@ -13,10 +13,12 @@ import os import pickle import warnings +from collections import OrderedDict +from collections.abc import Sequence from functools import partial from io import BytesIO from pathlib import Path -from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, OrderedDict, Sequence, Set, Union +from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Union import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -134,7 +136,7 @@ def __torch_function__( func: Callable, types: Sequence, args: Sequence[Any] = (), - kwargs: Optional[Dict] = None, + kwargs: Optional[dict] = None, ) -> Any: kwargs = kwargs or {} loaded_args = [(arg._load_tensor() if isinstance(arg, _NotYetLoadedTensor) else arg) for arg in args] @@ -219,7 +221,7 @@ def _load_tensor(t: _NotYetLoadedTensor) -> Tensor: def _move_state_into( - source: Dict[str, Any], destination: Dict[str, Union[Any, _Stateful]], keys: Optional[Set[str]] = None + source: dict[str, Any], destination: dict[str, Union[Any, _Stateful]], keys: Optional[set[str]] = None ) -> None: """Takes the state from the source destination and moves it into the destination dictionary. @@ -235,7 +237,7 @@ def _move_state_into( destination[key] = state -def _load_distributed_checkpoint(checkpoint_folder: Path) -> Dict[str, Any]: +def _load_distributed_checkpoint(checkpoint_folder: Path) -> dict[str, Any]: """Loads a sharded checkpoint saved with the `torch.distributed.checkpoint` into a full state dict. The current implementation assumes that the entire checkpoint fits in CPU memory. 
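# Illustrative sketch (not part of the patch): the `lightning.fabric.utilities.load`
# hunks around this point follow the same pattern, importing ABCs from
# `collections.abc` and annotating containers with the subscripted builtins.
# `merge_states` is a hypothetical helper, not a Lightning API; it only mirrors
# the `_move_state_into`-style signature, assuming a Python 3.9+ interpreter.
from collections.abc import Sequence
from typing import Any, Optional


def merge_states(sources: Sequence[dict[str, Any]], keys: Optional[set[str]] = None) -> dict[str, Any]:
    merged: dict[str, Any] = {}
    for state in sources:
        # Copy every entry, or only the requested keys when a filter is given.
        for key, value in state.items():
            if keys is None or key in keys:
                merged[key] = value
    return merged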
@@ -248,7 +250,7 @@ def _load_distributed_checkpoint(checkpoint_folder: Path) -> Dict[str, Any]: from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner from torch.distributed.checkpoint.state_dict_loader import _load_state_dict - checkpoint: Dict[str, Any] = {} + checkpoint: dict[str, Any] = {} _load_state_dict( checkpoint, storage_reader=FileSystemReader(checkpoint_folder), diff --git a/src/lightning/fabric/utilities/logger.py b/src/lightning/fabric/utilities/logger.py index 07b76ad9b04d8..dd2b0a3663fc9 100644 --- a/src/lightning/fabric/utilities/logger.py +++ b/src/lightning/fabric/utilities/logger.py @@ -15,15 +15,16 @@ import inspect import json from argparse import Namespace +from collections.abc import Mapping, MutableMapping from dataclasses import asdict, is_dataclass -from typing import Any, Dict, Mapping, MutableMapping, Optional, Union +from typing import Any, Optional, Union from torch import Tensor from lightning.fabric.utilities.imports import _NUMPY_AVAILABLE -def _convert_params(params: Optional[Union[Dict[str, Any], Namespace]]) -> Dict[str, Any]: +def _convert_params(params: Optional[Union[dict[str, Any], Namespace]]) -> dict[str, Any]: """Ensure parameters are a dict or convert to dict if necessary. Args: @@ -43,7 +44,7 @@ def _convert_params(params: Optional[Union[Dict[str, Any], Namespace]]) -> Dict[ return params -def _sanitize_callable_params(params: Dict[str, Any]) -> Dict[str, Any]: +def _sanitize_callable_params(params: dict[str, Any]) -> dict[str, Any]: """Sanitize callable params dict, e.g. ``{'a': } -> {'a': 'function_****'}``. Args: @@ -73,7 +74,7 @@ def _sanitize_callable(val: Any) -> Any: return {key: _sanitize_callable(val) for key, val in params.items()} -def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent_key: str = "") -> Dict[str, Any]: +def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent_key: str = "") -> dict[str, Any]: """Flatten hierarchical dict, e.g. ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``. Args: @@ -92,7 +93,7 @@ def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent {'5/a': 123} """ - result: Dict[str, Any] = {} + result: dict[str, Any] = {} for k, v in params.items(): new_key = parent_key + delimiter + str(k) if parent_key else str(k) if is_dataclass(v) and not isinstance(v, type): @@ -107,7 +108,7 @@ def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent return result -def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]: +def _sanitize_params(params: dict[str, Any]) -> dict[str, Any]: """Returns params with non-primitvies converted to strings for logging. >>> import torch @@ -140,7 +141,7 @@ def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]: return params -def _convert_json_serializable(params: Dict[str, Any]) -> Dict[str, Any]: +def _convert_json_serializable(params: dict[str, Any]) -> dict[str, Any]: """Convert non-serializable objects in params to string.""" return {k: str(v) if not _is_json_serializable(v) else v for k, v in params.items()} diff --git a/src/lightning/fabric/utilities/optimizer.py b/src/lightning/fabric/utilities/optimizer.py index 2c57ec9d1f64a..df83f9b1ca542 100644 --- a/src/lightning/fabric/utilities/optimizer.py +++ b/src/lightning/fabric/utilities/optimizer.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from collections.abc import MutableMapping -from typing import Iterable +from collections.abc import Iterable, MutableMapping from torch import Tensor from torch.optim import Optimizer diff --git a/src/lightning/fabric/utilities/registry.py b/src/lightning/fabric/utilities/registry.py index 9ad0f90221429..7d8f6ca17712e 100644 --- a/src/lightning/fabric/utilities/registry.py +++ b/src/lightning/fabric/utilities/registry.py @@ -15,7 +15,7 @@ from importlib.metadata import entry_points from inspect import getmembers, isclass from types import ModuleType -from typing import Any, List, Type, Union +from typing import Any, Union from lightning_utilities import is_overridden @@ -24,7 +24,7 @@ _log = logging.getLogger(__name__) -def _load_external_callbacks(group: str) -> List[Any]: +def _load_external_callbacks(group: str) -> list[Any]: """Collect external callbacks registered through entry points. The entry points are expected to be functions returning a list of callbacks. @@ -40,10 +40,10 @@ def _load_external_callbacks(group: str) -> List[Any]: entry_points(group=group) if _PYTHON_GREATER_EQUAL_3_10_0 else entry_points().get(group, {}) # type: ignore[arg-type] ) - external_callbacks: List[Any] = [] + external_callbacks: list[Any] = [] for factory in factories: callback_factory = factory.load() - callbacks_list: Union[List[Any], Any] = callback_factory() + callbacks_list: Union[list[Any], Any] = callback_factory() callbacks_list = [callbacks_list] if not isinstance(callbacks_list, list) else callbacks_list if callbacks_list: _log.info( @@ -54,7 +54,7 @@ def _load_external_callbacks(group: str) -> List[Any]: return external_callbacks -def _register_classes(registry: Any, method: str, module: ModuleType, parent: Type[object]) -> None: +def _register_classes(registry: Any, method: str, module: ModuleType, parent: type[object]) -> None: for _, member in getmembers(module, isclass): if issubclass(member, parent) and is_overridden(method, member, parent): register_fn = getattr(member, method) diff --git a/src/lightning/fabric/utilities/seed.py b/src/lightning/fabric/utilities/seed.py index a2d627828a77e..dac26f32f3607 100644 --- a/src/lightning/fabric/utilities/seed.py +++ b/src/lightning/fabric/utilities/seed.py @@ -3,7 +3,7 @@ import random from random import getstate as python_get_rng_state from random import setstate as python_set_rng_state -from typing import Any, Dict, List, Optional +from typing import Any, Optional import torch @@ -107,7 +107,7 @@ def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: np.random.seed(seed_sequence[3] & 0xFFFFFFFF) # numpy takes 32-bit seed only -def _generate_seed_sequence(base_seed: int, worker_id: int, global_rank: int, count: int) -> List[int]: +def _generate_seed_sequence(base_seed: int, worker_id: int, global_rank: int, count: int) -> list[int]: """Generates a sequence of seeds from a base seed, worker id and rank using the linear congruential generator (LCG) algorithm.""" # Combine base seed, worker id and rank into a unique 64-bit number @@ -120,7 +120,7 @@ def _generate_seed_sequence(base_seed: int, worker_id: int, global_rank: int, co return seeds -def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]: +def _collect_rng_states(include_cuda: bool = True) -> dict[str, Any]: r"""Collect the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python.""" states = { "torch": torch.get_rng_state(), @@ -135,7 +135,7 @@ def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]: return 
states -def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None: +def _set_rng_states(rng_state_dict: dict[str, Any]) -> None: r"""Set the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python in the current process.""" torch.set_rng_state(rng_state_dict["torch"]) diff --git a/src/lightning/fabric/utilities/spike.py b/src/lightning/fabric/utilities/spike.py index 5dca5990064e8..04c554461c58c 100644 --- a/src/lightning/fabric/utilities/spike.py +++ b/src/lightning/fabric/utilities/spike.py @@ -2,7 +2,7 @@ import operator import os import warnings -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, Optional, Union import torch from lightning_utilities.core.imports import compare_version @@ -66,7 +66,7 @@ def __init__( self.warmup = warmup self.atol = atol self.rtol = rtol - self.bad_batches: List[int] = [] + self.bad_batches: list[int] = [] self.exclude_batches_path = exclude_batches_path self.finite_only = finite_only @@ -147,7 +147,7 @@ def _update_stats(self, val: torch.Tensor) -> None: self.running_mean.update(val) self.last_val = val - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "last_val": self.last_val.item() if isinstance(self.last_val, torch.Tensor) else self.last_val, "mode": self.mode, @@ -160,7 +160,7 @@ def state_dict(self) -> Dict[str, Any]: "mean": self.running_mean.base_metric.state_dict(), } - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: self.last_val = state_dict.pop("last_val") self.mode = state_dict.pop("mode") self.warmup = state_dict.pop("warmup") diff --git a/src/lightning/fabric/utilities/testing/_runif.py b/src/lightning/fabric/utilities/testing/_runif.py index 6f0513465cab5..6f5d933f9dae3 100644 --- a/src/lightning/fabric/utilities/testing/_runif.py +++ b/src/lightning/fabric/utilities/testing/_runif.py @@ -14,7 +14,7 @@ import operator import os import sys -from typing import Dict, List, Optional, Tuple +from typing import Optional import torch from lightning_utilities.core.imports import RequirementCache, compare_version @@ -40,7 +40,7 @@ def _runif_reasons( standalone: bool = False, deepspeed: bool = False, dynamo: bool = False, -) -> Tuple[List[str], Dict[str, bool]]: +) -> tuple[list[str], dict[str, bool]]: """Construct reasons for pytest skipif. Args: diff --git a/src/lightning/fabric/utilities/throughput.py b/src/lightning/fabric/utilities/throughput.py index 598a32228ac75..72b33a41f168c 100644 --- a/src/lightning/fabric/utilities/throughput.py +++ b/src/lightning/fabric/utilities/throughput.py @@ -13,7 +13,7 @@ # limitations under the License. 
# Adapted from https://github.com/mosaicml/composer/blob/f2a2dc820/composer/callbacks/speed_monitor.py from collections import deque -from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union import torch from typing_extensions import override @@ -24,7 +24,7 @@ from lightning.fabric import Fabric from lightning.fabric.plugins import Precision -_THROUGHPUT_METRICS = Dict[str, Union[int, float]] +_THROUGHPUT_METRICS = dict[str, Union[int, float]] # The API design of this class follows `torchmetrics.Metric` but it doesn't need to be an actual Metric because there's @@ -108,7 +108,7 @@ def __init__( self._batches: _MonotonicWindow[int] = _MonotonicWindow(maxlen=window_size) self._samples: _MonotonicWindow[int] = _MonotonicWindow(maxlen=window_size) self._lengths: _MonotonicWindow[int] = _MonotonicWindow(maxlen=window_size) - self._flops: Deque[int] = deque(maxlen=window_size) + self._flops: deque[int] = deque(maxlen=window_size) def update( self, @@ -302,7 +302,7 @@ def measure_flops( return flop_counter.get_total_flops() -_CUDA_FLOPS: Dict[str, Dict[Union[str, torch.dtype], float]] = { +_CUDA_FLOPS: dict[str, dict[Union[str, torch.dtype], float]] = { # Hopper # source: https://resources.nvidia.com/en-us-tensor-core "h100 nvl": { @@ -648,7 +648,7 @@ def _plugin_to_compute_dtype(plugin: "Precision") -> torch.dtype: T = TypeVar("T", bound=float) -class _MonotonicWindow(List[T]): +class _MonotonicWindow(list[T]): """Custom fixed size list that only supports right-append and ensures that all values increase monotonically.""" def __init__(self, maxlen: int) -> None: diff --git a/src/lightning/fabric/utilities/types.py b/src/lightning/fabric/utilities/types.py index 2e18dc89b05b2..1d7235fa36383 100644 --- a/src/lightning/fabric/utilities/types.py +++ b/src/lightning/fabric/utilities/types.py @@ -11,14 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections import defaultdict +from collections.abc import Iterator from pathlib import Path from typing import ( Any, Callable, - DefaultDict, - Dict, - Iterator, - List, Optional, Protocol, TypeVar, @@ -38,7 +36,7 @@ _PATH = Union[str, Path] _DEVICE = Union[torch.device, str, int] _MAP_LOCATION_TYPE = Optional[ - Union[_DEVICE, Callable[[UntypedStorage, str], Optional[UntypedStorage]], Dict[_DEVICE, _DEVICE]] + Union[_DEVICE, Callable[[UntypedStorage, str], Optional[UntypedStorage]], dict[_DEVICE, _DEVICE]] ] _PARAMETERS = Iterator[torch.nn.Parameter] @@ -57,9 +55,9 @@ class _Stateful(Protocol[_DictKey]): """This class is used to detect if an object is stateful using `isinstance(obj, _Stateful)`.""" - def state_dict(self) -> Dict[_DictKey, Any]: ... + def state_dict(self) -> dict[_DictKey, Any]: ... - def load_state_dict(self, state_dict: Dict[_DictKey, Any]) -> None: ... + def load_state_dict(self, state_dict: dict[_DictKey, Any]) -> None: ... @runtime_checkable @@ -86,10 +84,10 @@ def step(self, closure: Optional[Callable[[], float]] = ...) -> Optional[float]: class Optimizable(Steppable, Protocol): """To structurally type ``optimizer``""" - param_groups: List[Dict[Any, Any]] - defaults: Dict[Any, Any] - state: DefaultDict[Tensor, Any] + param_groups: list[dict[Any, Any]] + defaults: dict[Any, Any] + state: defaultdict[Tensor, Any] - def state_dict(self) -> Dict[str, Dict[Any, Any]]: ... 
+ def state_dict(self) -> dict[str, dict[Any, Any]]: ... - def load_state_dict(self, state_dict: Dict[str, Dict[Any, Any]]) -> None: ... + def load_state_dict(self, state_dict: dict[str, dict[Any, Any]]) -> None: ... diff --git a/src/lightning/fabric/utilities/warnings.py b/src/lightning/fabric/utilities/warnings.py index 62e5f5fc2ff17..b62bece384e32 100644 --- a/src/lightning/fabric/utilities/warnings.py +++ b/src/lightning/fabric/utilities/warnings.py @@ -15,7 +15,7 @@ import warnings from pathlib import Path -from typing import Optional, Type, Union +from typing import Optional, Union from lightning.fabric.utilities.rank_zero import LightningDeprecationWarning @@ -38,7 +38,7 @@ def disable_possible_user_warnings(module: str = "") -> None: def _custom_format_warning( - message: Union[Warning, str], category: Type[Warning], filename: str, lineno: int, line: Optional[str] = None + message: Union[Warning, str], category: type[Warning], filename: str, lineno: int, line: Optional[str] = None ) -> str: """Custom formatting that avoids an extra line in case warnings are emitted from the `rank_zero`-functions.""" if _is_path_in_lightning(Path(filename)): diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py index c57f1974a6bba..b593c9f22ed23 100644 --- a/src/lightning/fabric/wrappers.py +++ b/src/lightning/fabric/wrappers.py @@ -12,19 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect +from collections.abc import Generator, Iterator, Mapping from copy import deepcopy from functools import partial, wraps from types import MethodType from typing import ( Any, Callable, - Dict, - Generator, - Iterator, - List, - Mapping, Optional, - Tuple, TypeVar, Union, overload, @@ -48,14 +43,14 @@ from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin from lightning.fabric.utilities.types import Optimizable -T_destination = TypeVar("T_destination", bound=Dict[str, Any]) +T_destination = TypeVar("T_destination", bound=dict[str, Any]) _LIGHTNING_MODULE_STEP_METHODS = ("training_step", "validation_step", "test_step", "predict_step") _in_fabric_backward: bool = False class _FabricOptimizer: - def __init__(self, optimizer: Optimizer, strategy: Strategy, callbacks: Optional[List[Callable]] = None) -> None: + def __init__(self, optimizer: Optimizer, strategy: Strategy, callbacks: Optional[list[Callable]] = None) -> None: """FabricOptimizer is a thin wrapper around the :class:`~torch.optim.Optimizer` that delegates the optimizer step calls to the strategy. @@ -76,10 +71,10 @@ def __init__(self, optimizer: Optimizer, strategy: Strategy, callbacks: Optional def optimizer(self) -> Optimizer: return self._optimizer - def state_dict(self) -> Dict[str, Tensor]: + def state_dict(self) -> dict[str, Tensor]: return self._strategy.get_optimizer_state(self.optimizer) - def load_state_dict(self, state_dict: Dict[str, Tensor]) -> None: + def load_state_dict(self, state_dict: dict[str, Tensor]) -> None: self.optimizer.load_state_dict(state_dict) def step(self, closure: Optional[Callable] = None) -> Any: @@ -149,12 +144,12 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination: ... @overload - def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]: ... + def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> dict[str, Any]: ... 
@override def state_dict( self, destination: Optional[T_destination] = None, prefix: str = "", keep_vars: bool = False - ) -> Optional[Dict[str, Any]]: + ) -> Optional[dict[str, Any]]: return self._original_module.state_dict( destination=destination, # type: ignore[type-var] prefix=prefix, @@ -350,7 +345,7 @@ def _unwrap( return apply_to_collection(collection, dtype=tuple(types), function=_unwrap) -def _unwrap_compiled(obj: Union[Any, OptimizedModule]) -> Tuple[Union[Any, nn.Module], Optional[Dict[str, Any]]]: +def _unwrap_compiled(obj: Union[Any, OptimizedModule]) -> tuple[Union[Any, nn.Module], Optional[dict[str, Any]]]: """Removes the :class:`torch._dynamo.OptimizedModule` around the object if it is wrapped. Use this function before instance checks against e.g. :class:`_FabricModule`. @@ -366,7 +361,7 @@ def _unwrap_compiled(obj: Union[Any, OptimizedModule]) -> Tuple[Union[Any, nn.Mo return obj, None -def _to_compiled(module: nn.Module, compile_kwargs: Dict[str, Any]) -> OptimizedModule: +def _to_compiled(module: nn.Module, compile_kwargs: dict[str, Any]) -> OptimizedModule: return torch.compile(module, **compile_kwargs) # type: ignore[return-value] diff --git a/src/lightning/pytorch/accelerators/accelerator.py b/src/lightning/pytorch/accelerators/accelerator.py index 0490c2d86431c..9238071178a80 100644 --- a/src/lightning/pytorch/accelerators/accelerator.py +++ b/src/lightning/pytorch/accelerators/accelerator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC -from typing import Any, Dict +from typing import Any import lightning.pytorch as pl from lightning.fabric.accelerators.accelerator import Accelerator as _Accelerator @@ -34,7 +34,7 @@ def setup(self, trainer: "pl.Trainer") -> None: """ - def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: + def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: """Get stats for a given device. Args: diff --git a/src/lightning/pytorch/accelerators/cpu.py b/src/lightning/pytorch/accelerators/cpu.py index a85a959ab66e0..525071cbb377f 100644 --- a/src/lightning/pytorch/accelerators/cpu.py +++ b/src/lightning/pytorch/accelerators/cpu.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, List, Union +from typing import Any, Union import torch from lightning_utilities.core.imports import RequirementCache @@ -38,7 +38,7 @@ def setup_device(self, device: torch.device) -> None: raise MisconfigurationException(f"Device should be CPU, got {device} instead.") @override - def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: + def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: """Get CPU stats from ``psutil`` package.""" return get_cpu_stats() @@ -54,7 +54,7 @@ def parse_devices(devices: Union[int, str]) -> int: @staticmethod @override - def get_parallel_devices(devices: Union[int, str]) -> List[torch.device]: + def get_parallel_devices(devices: Union[int, str]) -> list[torch.device]: """Gets parallel devices for the Accelerator.""" devices = _parse_cpu_cores(devices) return [torch.device("cpu")] * devices @@ -89,7 +89,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No _PSUTIL_AVAILABLE = RequirementCache("psutil") -def get_cpu_stats() -> Dict[str, float]: +def get_cpu_stats() -> dict[str, float]: if not _PSUTIL_AVAILABLE: raise ModuleNotFoundError( f"Fetching CPU device stats requires `psutil` to be installed. {str(_PSUTIL_AVAILABLE)}" diff --git a/src/lightning/pytorch/accelerators/cuda.py b/src/lightning/pytorch/accelerators/cuda.py index 6df3bc6b468ee..a00b12a85a8dd 100644 --- a/src/lightning/pytorch/accelerators/cuda.py +++ b/src/lightning/pytorch/accelerators/cuda.py @@ -15,7 +15,7 @@ import os import shutil import subprocess -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union import torch from typing_extensions import override @@ -61,7 +61,7 @@ def set_nvidia_flags(local_rank: int) -> None: _log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]") @override - def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: + def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: """Gets stats for the given GPU device. Args: @@ -83,13 +83,13 @@ def teardown(self) -> None: @staticmethod @override - def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]: + def parse_devices(devices: Union[int, str, list[int]]) -> Optional[list[int]]: """Accelerator device parsing logic.""" return _parse_gpu_ids(devices, include_cuda=True) @staticmethod @override - def get_parallel_devices(devices: List[int]) -> List[torch.device]: + def get_parallel_devices(devices: list[int]) -> list[torch.device]: """Gets parallel devices for the Accelerator.""" return [torch.device("cuda", i) for i in devices] @@ -114,7 +114,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No ) -def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]: # pragma: no-cover +def get_nvidia_gpu_stats(device: _DEVICE) -> dict[str, float]: # pragma: no-cover """Get GPU stats including memory, fan speed, and temperature from nvidia-smi. Args: diff --git a/src/lightning/pytorch/accelerators/mps.py b/src/lightning/pytorch/accelerators/mps.py index 6efe6292de624..f7674989cc721 100644 --- a/src/lightning/pytorch/accelerators/mps.py +++ b/src/lightning/pytorch/accelerators/mps.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union import torch from typing_extensions import override @@ -43,7 +43,7 @@ def setup_device(self, device: torch.device) -> None: raise MisconfigurationException(f"Device should be MPS, got {device} instead.") @override - def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: + def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: """Get M1 (cpu + gpu) stats from ``psutil`` package.""" return get_device_stats() @@ -53,13 +53,13 @@ def teardown(self) -> None: @staticmethod @override - def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]: + def parse_devices(devices: Union[int, str, list[int]]) -> Optional[list[int]]: """Accelerator device parsing logic.""" return _parse_gpu_ids(devices, include_mps=True) @staticmethod @override - def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]: + def get_parallel_devices(devices: Union[int, str, list[int]]) -> list[torch.device]: """Gets parallel devices for the Accelerator.""" parsed_devices = MPSAccelerator.parse_devices(devices) assert parsed_devices is not None @@ -94,7 +94,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No _SWAP_PERCENT = "M1_swap_percent" -def get_device_stats() -> Dict[str, float]: +def get_device_stats() -> dict[str, float]: if not _PSUTIL_AVAILABLE: raise ModuleNotFoundError( f"Fetching MPS device stats requires `psutil` to be installed. {str(_PSUTIL_AVAILABLE)}" diff --git a/src/lightning/pytorch/accelerators/xla.py b/src/lightning/pytorch/accelerators/xla.py index 01ef7223efcef..10726b505448c 100644 --- a/src/lightning/pytorch/accelerators/xla.py +++ b/src/lightning/pytorch/accelerators/xla.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any from typing_extensions import override @@ -29,7 +29,7 @@ class XLAAccelerator(Accelerator, FabricXLAAccelerator): """ @override - def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: + def get_device_stats(self, device: _DEVICE) -> dict[str, Any]: """Gets stats for the given XLA device. Args: diff --git a/src/lightning/pytorch/callbacks/callback.py b/src/lightning/pytorch/callbacks/callback.py index 9311d49dcd804..3bfb609465a83 100644 --- a/src/lightning/pytorch/callbacks/callback.py +++ b/src/lightning/pytorch/callbacks/callback.py @@ -13,7 +13,7 @@ # limitations under the License. r"""Base class used to build new callbacks.""" -from typing import Any, Dict, Type +from typing import Any from torch import Tensor from torch.optim import Optimizer @@ -41,7 +41,7 @@ def state_key(self) -> str: return self.__class__.__qualname__ @property - def _legacy_state_key(self) -> Type["Callback"]: + def _legacy_state_key(self) -> type["Callback"]: """State key for checkpoints saved prior to version 1.5.0.""" return type(self) @@ -229,7 +229,7 @@ def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None: """Called when any trainer execution is interrupted by an exception.""" - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: """Called when saving a checkpoint, implement to generate callback's ``state_dict``. 
Returns: @@ -238,7 +238,7 @@ def state_dict(self) -> Dict[str, Any]: """ return {} - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: """Called when loading a checkpoint, implement to reload callback state given callback's ``state_dict``. Args: @@ -248,7 +248,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: pass def on_save_checkpoint( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: dict[str, Any] ) -> None: r"""Called when saving a checkpoint to give you a chance to store anything else you might want to save. @@ -260,7 +260,7 @@ def on_save_checkpoint( """ def on_load_checkpoint( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: dict[str, Any] ) -> None: r"""Called when loading a model checkpoint, use to reload state. diff --git a/src/lightning/pytorch/callbacks/device_stats_monitor.py b/src/lightning/pytorch/callbacks/device_stats_monitor.py index 64ea47d2897ed..6279dd13be4af 100644 --- a/src/lightning/pytorch/callbacks/device_stats_monitor.py +++ b/src/lightning/pytorch/callbacks/device_stats_monitor.py @@ -19,7 +19,7 @@ """ -from typing import Any, Dict, Optional +from typing import Any, Optional from typing_extensions import override @@ -158,5 +158,5 @@ def on_test_batch_end( self._get_and_log_device_stats(trainer, "on_test_batch_end") -def _prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]: +def _prefix_metric_keys(metrics_dict: dict[str, float], prefix: str, separator: str) -> dict[str, float]: return {prefix + separator + k: v for k, v in metrics_dict.items()} diff --git a/src/lightning/pytorch/callbacks/early_stopping.py b/src/lightning/pytorch/callbacks/early_stopping.py index d1212fe8cc2e7..78c4215f9ce23 100644 --- a/src/lightning/pytorch/callbacks/early_stopping.py +++ b/src/lightning/pytorch/callbacks/early_stopping.py @@ -20,7 +20,7 @@ """ import logging -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Optional import torch from torch import Tensor @@ -139,7 +139,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s # validation, then we run after validation instead of on train epoch end self._check_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1 - def _validate_condition_metric(self, logs: Dict[str, Tensor]) -> bool: + def _validate_condition_metric(self, logs: dict[str, Tensor]) -> bool: monitor_val = logs.get(self.monitor) error_msg = ( @@ -163,7 +163,7 @@ def monitor_op(self) -> Callable: return self.mode_dict[self.mode] @override - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "wait_count": self.wait_count, "stopped_epoch": self.stopped_epoch, @@ -172,7 +172,7 @@ def state_dict(self) -> Dict[str, Any]: } @override - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: self.wait_count = state_dict["wait_count"] self.stopped_epoch = state_dict["stopped_epoch"] self.best_score = state_dict["best_score"] @@ -215,7 +215,7 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None: if reason and self.verbose: self._log_info(trainer, reason, 
self.log_rank_zero_only) - def _evaluate_stopping_criteria(self, current: Tensor) -> Tuple[bool, Optional[str]]: + def _evaluate_stopping_criteria(self, current: Tensor) -> tuple[bool, Optional[str]]: should_stop = False reason = None if self.check_finite and not torch.isfinite(current): diff --git a/src/lightning/pytorch/callbacks/finetuning.py b/src/lightning/pytorch/callbacks/finetuning.py index 46a90986c091c..356ab221777ae 100644 --- a/src/lightning/pytorch/callbacks/finetuning.py +++ b/src/lightning/pytorch/callbacks/finetuning.py @@ -19,7 +19,8 @@ """ import logging -from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union +from collections.abc import Generator, Iterable +from typing import Any, Callable, Optional, Union import torch from torch.nn import Module, ModuleDict @@ -85,17 +86,17 @@ class BaseFinetuning(Callback): """ def __init__(self) -> None: - self._internal_optimizer_metadata: Dict[int, List[Dict[str, Any]]] = {} + self._internal_optimizer_metadata: dict[int, list[dict[str, Any]]] = {} self._restarting = False @override - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "internal_optimizer_metadata": self._internal_optimizer_metadata, } @override - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: self._restarting = True if "internal_optimizer_metadata" in state_dict: # noqa: SIM401 self._internal_optimizer_metadata = state_dict["internal_optimizer_metadata"] @@ -116,7 +117,7 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - self._restarting = False @staticmethod - def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]: + def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> list[Module]: """This function is used to flatten a module or an iterable of modules into a list of its leaf modules (modules with no children) and parent modules that have parameters directly themselves. @@ -215,7 +216,7 @@ def freeze(modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn: BaseFinetuning.freeze_module(mod) @staticmethod - def filter_on_optimizer(optimizer: Optimizer, params: Iterable) -> List: + def filter_on_optimizer(optimizer: Optimizer, params: Iterable) -> list: """This function is used to exclude any parameter which already exists in this optimizer. 
Args: @@ -285,7 +286,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s ) @staticmethod - def _apply_mapping_to_param_groups(param_groups: List[Dict[str, Any]], mapping: dict) -> List[Dict[str, Any]]: + def _apply_mapping_to_param_groups(param_groups: list[dict[str, Any]], mapping: dict) -> list[dict[str, Any]]: output = [] for g in param_groups: # skip params to save memory @@ -299,7 +300,7 @@ def _store( pl_module: "pl.LightningModule", opt_idx: int, num_param_groups: int, - current_param_groups: List[Dict[str, Any]], + current_param_groups: list[dict[str, Any]], ) -> None: mapping = {p: n for n, p in pl_module.named_parameters()} if opt_idx not in self._internal_optimizer_metadata: @@ -387,14 +388,14 @@ def __init__( self.previous_backbone_lr: Optional[float] = None @override - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "internal_optimizer_metadata": self._internal_optimizer_metadata, "previous_backbone_lr": self.previous_backbone_lr, } @override - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: self.previous_backbone_lr = state_dict["previous_backbone_lr"] super().load_state_dict(state_dict) diff --git a/src/lightning/pytorch/callbacks/gradient_accumulation_scheduler.py b/src/lightning/pytorch/callbacks/gradient_accumulation_scheduler.py index 20b1df29d18f5..50ddc10f0f661 100644 --- a/src/lightning/pytorch/callbacks/gradient_accumulation_scheduler.py +++ b/src/lightning/pytorch/callbacks/gradient_accumulation_scheduler.py @@ -20,7 +20,7 @@ """ -from typing import Any, Dict +from typing import Any from typing_extensions import override @@ -64,7 +64,7 @@ class GradientAccumulationScheduler(Callback): """ - def __init__(self, scheduling: Dict[int, int]): + def __init__(self, scheduling: dict[int, int]): super().__init__() if not scheduling: # empty dict error diff --git a/src/lightning/pytorch/callbacks/lr_monitor.py b/src/lightning/pytorch/callbacks/lr_monitor.py index 6a94c7ece70a3..ca2b4a866ee50 100644 --- a/src/lightning/pytorch/callbacks/lr_monitor.py +++ b/src/lightning/pytorch/callbacks/lr_monitor.py @@ -22,7 +22,7 @@ import itertools from collections import defaultdict -from typing import Any, DefaultDict, Dict, List, Literal, Optional, Set, Tuple, Type +from typing import Any, Literal, Optional import torch from torch.optim.optimizer import Optimizer @@ -104,9 +104,9 @@ def __init__( self.log_momentum = log_momentum self.log_weight_decay = log_weight_decay - self.lrs: Dict[str, List[float]] = {} - self.last_momentum_values: Dict[str, Optional[List[float]]] = {} - self.last_weight_decay_values: Dict[str, Optional[List[float]]] = {} + self.lrs: dict[str, list[float]] = {} + self.last_momentum_values: dict[str, Optional[list[float]]] = {} + self.last_weight_decay_values: dict[str, Optional[list[float]]] = {} @override def on_train_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: @@ -141,7 +141,7 @@ def _check_no_key(key: str) -> bool: ) # Find names for schedulers - names: List[List[str]] = [] + names: list[list[str]] = [] ( sched_hparam_keys, optimizers_with_scheduler, @@ -186,7 +186,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) for logger in trainer.loggers: logger.log_metrics(latest_stat, step=trainer.fit_loop.epoch_loop._batches_that_stepped) - def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, float]: + def _extract_stats(self, 
trainer: "pl.Trainer", interval: str) -> dict[str, float]: latest_stat = {} ( @@ -219,7 +219,7 @@ def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, floa return latest_stat - def _get_optimizer_stats(self, optimizer: Optimizer, names: List[str]) -> Dict[str, float]: + def _get_optimizer_stats(self, optimizer: Optimizer, names: list[str]) -> dict[str, float]: stats = {} param_groups = optimizer.param_groups use_betas = "betas" in optimizer.defaults @@ -236,12 +236,12 @@ def _get_optimizer_stats(self, optimizer: Optimizer, names: List[str]) -> Dict[s return stats - def _extract_lr(self, param_group: Dict[str, Any], name: str) -> Dict[str, Any]: + def _extract_lr(self, param_group: dict[str, Any], name: str) -> dict[str, Any]: lr = param_group["lr"] self.lrs[name].append(lr) return {name: lr} - def _remap_keys(self, names: List[List[str]], token: str = "/pg1") -> None: + def _remap_keys(self, names: list[list[str]], token: str = "/pg1") -> None: """This function is used the remap the keys if param groups for a given optimizer increased.""" for group_new_names in names: for new_name in group_new_names: @@ -251,7 +251,7 @@ def _remap_keys(self, names: List[List[str]], token: str = "/pg1") -> None: elif new_name not in self.lrs: self.lrs[new_name] = [] - def _extract_momentum(self, param_group: Dict[str, List], name: str, use_betas: bool) -> Dict[str, float]: + def _extract_momentum(self, param_group: dict[str, list], name: str, use_betas: bool) -> dict[str, float]: if not self.log_momentum: return {} @@ -259,7 +259,7 @@ def _extract_momentum(self, param_group: Dict[str, List], name: str, use_betas: self.last_momentum_values[name] = momentum return {name: momentum} - def _extract_weight_decay(self, param_group: Dict[str, Any], name: str) -> Dict[str, Any]: + def _extract_weight_decay(self, param_group: dict[str, Any], name: str) -> dict[str, Any]: """Extracts the weight decay statistics from a parameter group.""" if not self.log_weight_decay: return {} @@ -269,14 +269,14 @@ def _extract_weight_decay(self, param_group: Dict[str, Any], name: str) -> Dict[ return {name: weight_decay} def _add_prefix( - self, name: str, optimizer_cls: Type[Optimizer], seen_optimizer_types: DefaultDict[Type[Optimizer], int] + self, name: str, optimizer_cls: type[Optimizer], seen_optimizer_types: defaultdict[type[Optimizer], int] ) -> str: if optimizer_cls not in seen_optimizer_types: return name count = seen_optimizer_types[optimizer_cls] return name + f"-{count - 1}" if count > 1 else name - def _add_suffix(self, name: str, param_groups: List[Dict], param_group_index: int, use_names: bool = True) -> str: + def _add_suffix(self, name: str, param_groups: list[dict], param_group_index: int, use_names: bool = True) -> str: if len(param_groups) > 1: if not use_names: return f"{name}/pg{param_group_index + 1}" @@ -287,7 +287,7 @@ def _add_suffix(self, name: str, param_groups: List[Dict], param_group_index: in return f"{name}/{pg_name}" if pg_name else name return name - def _duplicate_param_group_names(self, param_groups: List[Dict]) -> Set[str]: + def _duplicate_param_group_names(self, param_groups: list[dict]) -> set[str]: names = [pg.get("name", f"pg{i}") for i, pg in enumerate(param_groups, start=1)] unique = set(names) if len(names) == len(unique): @@ -296,13 +296,13 @@ def _duplicate_param_group_names(self, param_groups: List[Dict]) -> Set[str]: def _find_names_from_schedulers( self, - lr_scheduler_configs: List[LRSchedulerConfig], - ) -> Tuple[List[List[str]], List[Optimizer], 
DefaultDict[Type[Optimizer], int]]: + lr_scheduler_configs: list[LRSchedulerConfig], + ) -> tuple[list[list[str]], list[Optimizer], defaultdict[type[Optimizer], int]]: # Create unique names in the case we have multiple of the same learning # rate scheduler + multiple parameter groups names = [] - seen_optimizers: List[Optimizer] = [] - seen_optimizer_types: DefaultDict[Type[Optimizer], int] = defaultdict(int) + seen_optimizers: list[Optimizer] = [] + seen_optimizer_types: defaultdict[type[Optimizer], int] = defaultdict(int) for config in lr_scheduler_configs: sch = config.scheduler name = config.name if config.name is not None else "lr-" + sch.optimizer.__class__.__name__ @@ -316,10 +316,10 @@ def _find_names_from_schedulers( def _find_names_from_optimizers( self, - optimizers: List[Any], - seen_optimizers: List[Optimizer], - seen_optimizer_types: DefaultDict[Type[Optimizer], int], - ) -> Tuple[List[List[str]], List[Optimizer]]: + optimizers: list[Any], + seen_optimizers: list[Optimizer], + seen_optimizer_types: defaultdict[type[Optimizer], int], + ) -> tuple[list[list[str]], list[Optimizer]]: names = [] optimizers_without_scheduler = [] @@ -342,10 +342,10 @@ def _check_duplicates_and_update_name( self, optimizer: Optimizer, name: str, - seen_optimizers: List[Optimizer], - seen_optimizer_types: DefaultDict[Type[Optimizer], int], + seen_optimizers: list[Optimizer], + seen_optimizer_types: defaultdict[type[Optimizer], int], lr_scheduler_config: Optional[LRSchedulerConfig], - ) -> List[str]: + ) -> list[str]: seen_optimizers.append(optimizer) optimizer_cls = type(optimizer) if lr_scheduler_config is None or lr_scheduler_config.name is None: diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 9587da0f4600b..85bfb65c0ea6e 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -27,7 +27,7 @@ from copy import deepcopy from datetime import timedelta from pathlib import Path -from typing import Any, Dict, Literal, Optional, Set, Union +from typing import Any, Literal, Optional, Union from weakref import proxy import torch @@ -241,7 +241,7 @@ def __init__( self._last_global_step_saved = 0 # no need to save when no steps were taken self._last_time_checked: Optional[float] = None self.current_score: Optional[Tensor] = None - self.best_k_models: Dict[str, Tensor] = {} + self.best_k_models: dict[str, Tensor] = {} self.kth_best_model_path = "" self.best_model_score: Optional[Tensor] = None self.best_model_path = "" @@ -335,7 +335,7 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul self._save_last_checkpoint(trainer, monitor_candidates) @override - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "monitor": self.monitor, "best_model_score": self.best_model_score, @@ -349,7 +349,7 @@ def state_dict(self) -> Dict[str, Any]: } @override - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: dirpath_from_ckpt = state_dict.get("dirpath", self.dirpath) if self.dirpath == dirpath_from_ckpt: @@ -367,7 +367,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self.best_model_path = state_dict["best_model_path"] - def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None: + def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: dict[str, Tensor]) 
-> None: if self.save_top_k == 0: return @@ -533,7 +533,7 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] = def _format_checkpoint_name( self, filename: Optional[str], - metrics: Dict[str, Tensor], + metrics: dict[str, Tensor], prefix: str = "", auto_insert_metric_name: bool = True, ) -> str: @@ -567,7 +567,7 @@ def _format_checkpoint_name( return filename def format_checkpoint_name( - self, metrics: Dict[str, Tensor], filename: Optional[str] = None, ver: Optional[int] = None + self, metrics: dict[str, Tensor], filename: Optional[str] = None, ver: Optional[int] = None ) -> str: """Generate a filename according to the defined template. @@ -637,7 +637,7 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> _PATH: return ckpt_path - def _find_last_checkpoints(self, trainer: "pl.Trainer") -> Set[str]: + def _find_last_checkpoints(self, trainer: "pl.Trainer") -> set[str]: # find all checkpoints in the folder ckpt_path = self.__resolve_ckpt_dir(trainer) last_pattern = rf"^{self.CHECKPOINT_NAME_LAST}(-(\d+))?" @@ -654,7 +654,7 @@ def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None: rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.") def _get_metric_interpolated_filepath_name( - self, monitor_candidates: Dict[str, Tensor], trainer: "pl.Trainer", del_filepath: Optional[str] = None + self, monitor_candidates: dict[str, Tensor], trainer: "pl.Trainer", del_filepath: Optional[str] = None ) -> str: filepath = self.format_checkpoint_name(monitor_candidates) @@ -666,7 +666,7 @@ def _get_metric_interpolated_filepath_name( return filepath - def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, Tensor]: + def _monitor_candidates(self, trainer: "pl.Trainer") -> dict[str, Tensor]: monitor_candidates = deepcopy(trainer.callback_metrics) # cast to int if necessary because `self.log("epoch", 123)` will convert it to float. 
if it's not a tensor # or does not exist we overwrite it as it's likely an error @@ -676,7 +676,7 @@ def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, Tensor]: monitor_candidates["step"] = step.int() if isinstance(step, Tensor) else torch.tensor(trainer.global_step) return monitor_candidates - def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None: + def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: dict[str, Tensor]) -> None: if not self.save_last: return @@ -697,7 +697,7 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[ if previous and self._should_remove_checkpoint(trainer, previous, filepath): self._remove_checkpoint(trainer, previous) - def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None: + def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: dict[str, Tensor]) -> None: assert self.monitor current = monitor_candidates.get(self.monitor) if self.check_monitor_top_k(trainer, current): @@ -708,7 +708,7 @@ def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Di step = monitor_candidates["step"] rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} was not in top {self.save_top_k}") - def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None: + def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: dict[str, Tensor]) -> None: filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer, self.best_model_path) # set the best model path before saving because it will be part of the state. previous, self.best_model_path = self.best_model_path, filepath @@ -718,7 +718,7 @@ def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidate self._remove_checkpoint(trainer, previous) def _update_best_and_save( - self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor] + self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: dict[str, Tensor] ) -> None: k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k diff --git a/src/lightning/pytorch/callbacks/model_summary.py b/src/lightning/pytorch/callbacks/model_summary.py index 89c31b2cc65e8..03f50d65bf1e9 100644 --- a/src/lightning/pytorch/callbacks/model_summary.py +++ b/src/lightning/pytorch/callbacks/model_summary.py @@ -23,7 +23,7 @@ """ import logging -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Union from typing_extensions import override @@ -54,7 +54,7 @@ class ModelSummary(Callback): def __init__(self, max_depth: int = 1, **summarize_kwargs: Any) -> None: self._max_depth: int = max_depth - self._summarize_kwargs: Dict[str, Any] = summarize_kwargs + self._summarize_kwargs: dict[str, Any] = summarize_kwargs @override def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: @@ -87,11 +87,11 @@ def _summary(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Un @staticmethod def summarize( - summary_data: List[Tuple[str, List[str]]], + summary_data: list[tuple[str, list[str]]], total_parameters: int, trainable_parameters: int, model_size: float, - total_training_modes: Dict[str, int], + total_training_modes: dict[str, int], **summarize_kwargs: Any, ) -> None: summary_table = _format_summary_table( diff --git a/src/lightning/pytorch/callbacks/prediction_writer.py 
b/src/lightning/pytorch/callbacks/prediction_writer.py index 7f782fb81c091..ce6342c7aa88d 100644 --- a/src/lightning/pytorch/callbacks/prediction_writer.py +++ b/src/lightning/pytorch/callbacks/prediction_writer.py @@ -18,7 +18,8 @@ Aids in saving predictions """ -from typing import Any, Literal, Optional, Sequence +from collections.abc import Sequence +from typing import Any, Literal, Optional from typing_extensions import override diff --git a/src/lightning/pytorch/callbacks/progress/progress_bar.py b/src/lightning/pytorch/callbacks/progress/progress_bar.py index 785bf65af4361..7cf6993b4414b 100644 --- a/src/lightning/pytorch/callbacks/progress/progress_bar.py +++ b/src/lightning/pytorch/callbacks/progress/progress_bar.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from typing_extensions import override @@ -176,7 +176,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s def get_metrics( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" - ) -> Dict[str, Union[int, str, float, Dict[str, float]]]: + ) -> dict[str, Union[int, str, float, dict[str, float]]]: r"""Combines progress bar metrics collected from the trainer with standard metrics from get_standard_metrics. Implement this to override the items displayed in the progress bar. @@ -207,7 +207,7 @@ def get_metrics(self, trainer, model): return {**standard_metrics, **pbar_metrics} -def get_standard_metrics(trainer: "pl.Trainer") -> Dict[str, Union[int, str]]: +def get_standard_metrics(trainer: "pl.Trainer") -> dict[str, Union[int, str]]: r"""Returns the standard metrics displayed in the progress bar. Currently, it only includes the version of the experiment when using a logger. @@ -219,7 +219,7 @@ def get_standard_metrics(trainer: "pl.Trainer") -> Dict[str, Union[int, str]]: Dictionary with the standard metrics to be displayed in the progress bar. """ - items_dict: Dict[str, Union[int, str]] = {} + items_dict: dict[str, Union[int, str]] = {} if trainer.loggers: from lightning.pytorch.loggers.utilities import _version diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 896de71267835..0a51d99ccb676 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
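# --- Illustrative sketch, not part of the patch ------------------------------
# The progress-bar hunks above only swap ``typing.Dict`` for the builtin
# ``dict`` in ``get_metrics`` / ``get_standard_metrics``; behaviour is
# unchanged on the Python 3.9+ floor this PR targets.  A user override written
# against the new annotations could look roughly like this.  ``TQDMProgressBar``
# and the ``"v_num"`` key come from Lightning's public API, not from this diff,
# so treat them as assumptions.
from typing import Union

import lightning.pytorch as pl
from lightning.pytorch.callbacks import TQDMProgressBar


class NoVersionProgressBar(TQDMProgressBar):
    def get_metrics(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ) -> dict[str, Union[int, str, float, dict[str, float]]]:
        # Start from the standard items and hide the experiment version.
        items = super().get_metrics(trainer, pl_module)
        items.pop("v_num", None)
        return items
# ------------------------------------------------------------------------------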
import math +from collections.abc import Generator from dataclasses import dataclass from datetime import timedelta -from typing import Any, Dict, Generator, Optional, Union, cast +from typing import Any, Optional, Union, cast from lightning_utilities.core.imports import RequirementCache from typing_extensions import override @@ -146,15 +147,15 @@ def __init__( metrics_format: str, ): self._trainer = trainer - self._tasks: Dict[Union[int, TaskID], Any] = {} + self._tasks: dict[Union[int, TaskID], Any] = {} self._current_task_id = 0 - self._metrics: Dict[Union[str, Style], Any] = {} + self._metrics: dict[Union[str, Style], Any] = {} self._style = style self._text_delimiter = text_delimiter self._metrics_format = metrics_format super().__init__() - def update(self, metrics: Dict[Any, Any]) -> None: + def update(self, metrics: dict[Any, Any]) -> None: # Called when metrics are ready to be rendered. # This is to prevent render from causing deadlock issues by requesting metrics # in separate threads. @@ -257,7 +258,7 @@ def __init__( refresh_rate: int = 1, leave: bool = False, theme: RichProgressBarTheme = RichProgressBarTheme(), - console_kwargs: Optional[Dict[str, Any]] = None, + console_kwargs: Optional[dict[str, Any]] = None, ) -> None: if not _RICH_AVAILABLE: raise ModuleNotFoundError( @@ -642,7 +643,7 @@ def configure_columns(self, trainer: "pl.Trainer") -> list: ProcessingSpeedColumn(style=self.theme.processing_speed), ] - def __getstate__(self) -> Dict: + def __getstate__(self) -> dict: state = self.__dict__.copy() # both the console and progress object can hold thread lock objects that are not pickleable state["progress"] = None diff --git a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py index cf9cd71614674..4ef260f00006d 100644 --- a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py +++ b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py @@ -15,7 +15,7 @@ import math import os import sys -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from typing_extensions import override @@ -115,7 +115,7 @@ def __init__(self, refresh_rate: int = 1, process_position: int = 0, leave: bool self._predict_progress_bar: Optional[_tqdm] = None self._leave = leave - def __getstate__(self) -> Dict: + def __getstate__(self) -> dict: # can't pickle the tqdm objects return {k: v if not isinstance(v, _tqdm) else None for k, v in vars(self).items()} diff --git a/src/lightning/pytorch/callbacks/pruning.py b/src/lightning/pytorch/callbacks/pruning.py index e83a9de06375c..1517ef6920b0d 100644 --- a/src/lightning/pytorch/callbacks/pruning.py +++ b/src/lightning/pytorch/callbacks/pruning.py @@ -18,9 +18,10 @@ import inspect import logging +from collections.abc import Sequence from copy import deepcopy from functools import partial -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Optional, Union import torch.nn.utils.prune as pytorch_prune from lightning_utilities.core.apply_func import apply_to_collection @@ -49,14 +50,14 @@ "random_unstructured": pytorch_prune.RandomUnstructured, } -_PARAM_TUPLE = Tuple[nn.Module, str] +_PARAM_TUPLE = tuple[nn.Module, str] _PARAM_LIST = Sequence[_PARAM_TUPLE] _MODULE_CONTAINERS = (LightningModule, nn.Sequential, nn.ModuleList, nn.ModuleDict) class _LayerRef(TypedDict): data: nn.Module - names: List[Tuple[int, str]] + names: list[tuple[int, str]] class ModelPruning(Callback): @@ -66,7 +67,7 @@ 
def __init__( self, pruning_fn: Union[Callable, str], parameters_to_prune: _PARAM_LIST = (), - parameter_names: Optional[List[str]] = None, + parameter_names: Optional[list[str]] = None, use_global_unstructured: bool = True, amount: Union[int, float, Callable[[int], Union[int, float]]] = 0.5, apply_pruning: Union[bool, Callable[[int], bool]] = True, @@ -165,8 +166,8 @@ def __init__( self._resample_parameters = resample_parameters self._prune_on_train_epoch_end = prune_on_train_epoch_end self._parameter_names = parameter_names or self.PARAMETER_NAMES - self._global_kwargs: Dict[str, Any] = {} - self._original_layers: Optional[Dict[int, _LayerRef]] = None + self._global_kwargs: dict[str, Any] = {} + self._original_layers: Optional[dict[int, _LayerRef]] = None self._pruning_method_name: Optional[str] = None for name in self._parameter_names: @@ -310,7 +311,7 @@ def _apply_local_pruning(self, amount: float) -> None: for module, name in self._parameters_to_prune: self.pruning_fn(module, name=name, amount=amount) # type: ignore[call-arg] - def _resolve_global_kwargs(self, amount: float) -> Dict[str, Any]: + def _resolve_global_kwargs(self, amount: float) -> dict[str, Any]: self._global_kwargs["amount"] = amount params = set(inspect.signature(self.pruning_fn).parameters) params.discard("self") @@ -322,7 +323,7 @@ def _apply_global_pruning(self, amount: float) -> None: ) @staticmethod - def _get_pruned_stats(module: nn.Module, name: str) -> Tuple[int, int]: + def _get_pruned_stats(module: nn.Module, name: str) -> tuple[int, int]: attr = f"{name}_mask" if not hasattr(module, attr): return 0, 1 @@ -345,7 +346,7 @@ def apply_pruning(self, amount: Union[int, float]) -> None: @rank_zero_only def _log_sparsity_stats( - self, prev: List[Tuple[int, int]], curr: List[Tuple[int, int]], amount: Union[int, float] = 0 + self, prev: list[tuple[int, int]], curr: list[tuple[int, int]], amount: Union[int, float] = 0 ) -> None: total_params = sum(p.numel() for layer, _ in self._parameters_to_prune for p in layer.parameters()) prev_total_zeros = sum(zeros for zeros, _ in prev) @@ -414,7 +415,7 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: LightningModule) -> Non rank_zero_debug("`ModelPruning.on_train_end`. Pruning is made permanent for this checkpoint") self.make_pruning_permanent(pl_module) - def _make_pruning_permanent_on_state_dict(self, pl_module: LightningModule) -> Dict[str, Any]: + def _make_pruning_permanent_on_state_dict(self, pl_module: LightningModule) -> dict[str, Any]: state_dict = pl_module.state_dict() # find the mask and the original weights. @@ -432,7 +433,7 @@ def move_to_cpu(tensor: Tensor) -> Tensor: return apply_to_collection(state_dict, Tensor, move_to_cpu) @override - def on_save_checkpoint(self, trainer: "pl.Trainer", pl_module: LightningModule, checkpoint: Dict[str, Any]) -> None: + def on_save_checkpoint(self, trainer: "pl.Trainer", pl_module: LightningModule, checkpoint: dict[str, Any]) -> None: if self._make_pruning_permanent: rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint") # manually prune the weights so training can keep going with the same buffers diff --git a/src/lightning/pytorch/callbacks/rich_model_summary.py b/src/lightning/pytorch/callbacks/rich_model_summary.py index c6c429b4bd2f5..e4027f0dedcb1 100644 --- a/src/lightning/pytorch/callbacks/rich_model_summary.py +++ b/src/lightning/pytorch/callbacks/rich_model_summary.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Tuple +from typing import Any from typing_extensions import override @@ -67,11 +67,11 @@ def __init__(self, max_depth: int = 1, **summarize_kwargs: Any) -> None: @staticmethod @override def summarize( - summary_data: List[Tuple[str, List[str]]], + summary_data: list[tuple[str, list[str]]], total_parameters: int, trainable_parameters: int, model_size: float, - total_training_modes: Dict[str, int], + total_training_modes: dict[str, int], **summarize_kwargs: Any, ) -> None: from rich import get_console diff --git a/src/lightning/pytorch/callbacks/spike.py b/src/lightning/pytorch/callbacks/spike.py index 725d6f64333a6..b006acd44dcdb 100644 --- a/src/lightning/pytorch/callbacks/spike.py +++ b/src/lightning/pytorch/callbacks/spike.py @@ -1,5 +1,6 @@ import os -from typing import Any, Mapping, Union +from collections.abc import Mapping +from typing import Any, Union import torch diff --git a/src/lightning/pytorch/callbacks/stochastic_weight_avg.py b/src/lightning/pytorch/callbacks/stochastic_weight_avg.py index 737084ced426d..5643a038e00c1 100644 --- a/src/lightning/pytorch/callbacks/stochastic_weight_avg.py +++ b/src/lightning/pytorch/callbacks/stochastic_weight_avg.py @@ -17,7 +17,7 @@ """ from copy import deepcopy -from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast +from typing import Any, Callable, Literal, Optional, Union, cast import torch from torch import Tensor, nn @@ -39,7 +39,7 @@ class StochasticWeightAveraging(Callback): def __init__( self, - swa_lrs: Union[float, List[float]], + swa_lrs: Union[float, list[float]], swa_epoch_start: Union[int, float] = 0.8, annealing_epochs: int = 10, annealing_strategy: Literal["cos", "linear"] = "cos", @@ -126,10 +126,10 @@ def __init__( self._average_model: Optional[pl.LightningModule] = None self._initialized = False self._swa_scheduler: Optional[LRScheduler] = None - self._scheduler_state: Optional[Dict] = None + self._scheduler_state: Optional[dict] = None self._init_n_averaged = 0 self._latest_update_epoch = -1 - self.momenta: Dict[nn.modules.batchnorm._BatchNorm, Optional[float]] = {} + self.momenta: dict[nn.modules.batchnorm._BatchNorm, Optional[float]] = {} self._max_epochs: int @property @@ -331,7 +331,7 @@ def avg_fn(averaged_model_parameter: Tensor, model_parameter: Tensor, num_averag return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1) @override - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return { "n_averaged": 0 if self.n_averaged is None else self.n_averaged.item(), "latest_update_epoch": self._latest_update_epoch, @@ -340,7 +340,7 @@ def state_dict(self) -> Dict[str, Any]: } @override - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: self._init_n_averaged = state_dict["n_averaged"] self._latest_update_epoch = state_dict["latest_update_epoch"] self._scheduler_state = state_dict["scheduler_state"] diff --git a/src/lightning/pytorch/callbacks/throughput_monitor.py b/src/lightning/pytorch/callbacks/throughput_monitor.py index a2d73d83184b1..a49610a912e57 100644 --- a/src/lightning/pytorch/callbacks/throughput_monitor.py +++ b/src/lightning/pytorch/callbacks/throughput_monitor.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
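# --- Illustrative sketch, not part of the patch ------------------------------
# The callback hunks above (ModelCheckpoint, StochasticWeightAveraging, ...)
# now type their ``state_dict`` / ``load_state_dict`` hooks as plain
# ``dict[str, Any]``.  A user-defined stateful callback follows the same
# pattern; the batch counter below is invented purely for illustration.
from typing import Any

from lightning.pytorch.callbacks import Callback


class BatchCounter(Callback):
    def __init__(self) -> None:
        self.batches_seen = 0

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None:
        self.batches_seen += 1

    def state_dict(self) -> dict[str, Any]:
        # Saved into the checkpoint under this callback's state key.
        return {"batches_seen": self.batches_seen}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.batches_seen = state_dict.get("batches_seen", 0)
# ------------------------------------------------------------------------------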
import time -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch from typing_extensions import override @@ -84,9 +84,9 @@ def __init__( self.batch_size_fn = batch_size_fn self.length_fn = length_fn self.available_flops: Optional[int] = None - self._throughputs: Dict[RunningStage, Throughput] = {} - self._t0s: Dict[RunningStage, float] = {} - self._lengths: Dict[RunningStage, int] = {} + self._throughputs: dict[RunningStage, Throughput] = {} + self._t0s: dict[RunningStage, float] = {} + self._lengths: dict[RunningStage, int] = {} @override def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None: diff --git a/src/lightning/pytorch/callbacks/timer.py b/src/lightning/pytorch/callbacks/timer.py index e1bed4adb9889..b6b74d280427c 100644 --- a/src/lightning/pytorch/callbacks/timer.py +++ b/src/lightning/pytorch/callbacks/timer.py @@ -20,7 +20,7 @@ import re import time from datetime import timedelta -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from typing_extensions import override @@ -83,7 +83,7 @@ class Timer(Callback): def __init__( self, - duration: Optional[Union[str, timedelta, Dict[str, int]]] = None, + duration: Optional[Union[str, timedelta, dict[str, int]]] = None, interval: str = Interval.step, verbose: bool = True, ) -> None: @@ -111,8 +111,8 @@ def __init__( self._duration = duration.total_seconds() if duration is not None else None self._interval = interval self._verbose = verbose - self._start_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage} - self._end_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage} + self._start_time: dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage} + self._end_time: dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage} self._offset = 0 def start_time(self, stage: str = RunningStage.TRAINING) -> Optional[float]: @@ -187,11 +187,11 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) - self._check_time_remaining(trainer) @override - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return {"time_elapsed": {stage.value: self.time_elapsed(stage) for stage in RunningStage}} @override - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: time_elapsed = state_dict.get("time_elapsed", {}) self._offset = time_elapsed.get(RunningStage.TRAINING.value, 0) diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index 26af335f7be93..664edf6214de3 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -14,9 +14,10 @@ import inspect import os import sys +from collections.abc import Iterable from functools import partial, update_wrapper from types import MethodType -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, TypeVar, Union +from typing import Any, Callable, Optional, TypeVar, Union import torch import yaml @@ -65,11 +66,11 @@ def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any # LightningCLI requires the ReduceLROnPlateau defined here, thus it shouldn't accept the one from pytorch: LRSchedulerTypeTuple = (LRScheduler, ReduceLROnPlateau) LRSchedulerTypeUnion = Union[LRScheduler, ReduceLROnPlateau] -LRSchedulerType = Union[Type[LRScheduler], 
Type[ReduceLROnPlateau]] +LRSchedulerType = Union[type[LRScheduler], type[ReduceLROnPlateau]] # Type aliases intended for convenience of CLI developers -ArgsType = Optional[Union[List[str], Dict[str, Any], Namespace]] +ArgsType = Optional[Union[list[str], dict[str, Any], Namespace]] OptimizerCallable = Callable[[Iterable], Optimizer] LRSchedulerCallable = Callable[[Optimizer], Union[LRScheduler, ReduceLROnPlateau]] @@ -99,24 +100,24 @@ def __init__( if not _JSONARGPARSE_SIGNATURES_AVAILABLE: raise ModuleNotFoundError(f"{_JSONARGPARSE_SIGNATURES_AVAILABLE}") super().__init__(*args, description=description, env_prefix=env_prefix, default_env=default_env, **kwargs) - self.callback_keys: List[str] = [] + self.callback_keys: list[str] = [] # separate optimizers and lr schedulers to know which were added - self._optimizers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {} - self._lr_schedulers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {} + self._optimizers: dict[str, tuple[Union[type, tuple[type, ...]], str]] = {} + self._lr_schedulers: dict[str, tuple[Union[type, tuple[type, ...]], str]] = {} def add_lightning_class_args( self, lightning_class: Union[ Callable[..., Union[Trainer, LightningModule, LightningDataModule, Callback]], - Type[Trainer], - Type[LightningModule], - Type[LightningDataModule], - Type[Callback], + type[Trainer], + type[LightningModule], + type[LightningDataModule], + type[Callback], ], nested_key: str, subclass_mode: bool = False, required: bool = True, - ) -> List[str]: + ) -> list[str]: """Adds arguments from a lightning class to a nested key of the parser. Args: @@ -153,7 +154,7 @@ def add_lightning_class_args( def add_optimizer_args( self, - optimizer_class: Union[Type[Optimizer], Tuple[Type[Optimizer], ...]] = (Optimizer,), + optimizer_class: Union[type[Optimizer], tuple[type[Optimizer], ...]] = (Optimizer,), nested_key: str = "optimizer", link_to: str = "AUTOMATIC", ) -> None: @@ -169,7 +170,7 @@ def add_optimizer_args( assert all(issubclass(o, Optimizer) for o in optimizer_class) else: assert issubclass(optimizer_class, Optimizer) - kwargs: Dict[str, Any] = {"instantiate": False, "fail_untyped": False, "skip": {"params"}} + kwargs: dict[str, Any] = {"instantiate": False, "fail_untyped": False, "skip": {"params"}} if isinstance(optimizer_class, tuple): self.add_subclass_arguments(optimizer_class, nested_key, **kwargs) else: @@ -178,7 +179,7 @@ def add_optimizer_args( def add_lr_scheduler_args( self, - lr_scheduler_class: Union[LRSchedulerType, Tuple[LRSchedulerType, ...]] = LRSchedulerTypeTuple, + lr_scheduler_class: Union[LRSchedulerType, tuple[LRSchedulerType, ...]] = LRSchedulerTypeTuple, nested_key: str = "lr_scheduler", link_to: str = "AUTOMATIC", ) -> None: @@ -195,7 +196,7 @@ def add_lr_scheduler_args( assert all(issubclass(o, LRSchedulerTypeTuple) for o in lr_scheduler_class) else: assert issubclass(lr_scheduler_class, LRSchedulerTypeTuple) - kwargs: Dict[str, Any] = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}} + kwargs: dict[str, Any] = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}} if isinstance(lr_scheduler_class, tuple): self.add_subclass_arguments(lr_scheduler_class, nested_key, **kwargs) else: @@ -305,14 +306,14 @@ class LightningCLI: def __init__( self, - model_class: Optional[Union[Type[LightningModule], Callable[..., LightningModule]]] = None, - datamodule_class: Optional[Union[Type[LightningDataModule], Callable[..., LightningDataModule]]] = None, - save_config_callback: 
Optional[Type[SaveConfigCallback]] = SaveConfigCallback, - save_config_kwargs: Optional[Dict[str, Any]] = None, - trainer_class: Union[Type[Trainer], Callable[..., Trainer]] = Trainer, - trainer_defaults: Optional[Dict[str, Any]] = None, + model_class: Optional[Union[type[LightningModule], Callable[..., LightningModule]]] = None, + datamodule_class: Optional[Union[type[LightningDataModule], Callable[..., LightningDataModule]]] = None, + save_config_callback: Optional[type[SaveConfigCallback]] = SaveConfigCallback, + save_config_kwargs: Optional[dict[str, Any]] = None, + trainer_class: Union[type[Trainer], Callable[..., Trainer]] = Trainer, + trainer_defaults: Optional[dict[str, Any]] = None, seed_everything_default: Union[bool, int] = True, - parser_kwargs: Optional[Union[Dict[str, Any], Dict[str, Dict[str, Any]]]] = None, + parser_kwargs: Optional[Union[dict[str, Any], dict[str, dict[str, Any]]]] = None, subclass_mode_model: bool = False, subclass_mode_data: bool = False, args: ArgsType = None, @@ -393,7 +394,7 @@ def __init__( if self.subcommand is not None: self._run_subcommand(self.subcommand) - def _setup_parser_kwargs(self, parser_kwargs: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def _setup_parser_kwargs(self, parser_kwargs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: subcommand_names = self.subcommands().keys() main_kwargs = {k: v for k, v in parser_kwargs.items() if k not in subcommand_names} subparser_kwargs = {k: v for k, v in parser_kwargs.items() if k in subcommand_names} @@ -409,12 +410,12 @@ def init_parser(self, **kwargs: Any) -> LightningArgumentParser: return parser def setup_parser( - self, add_subcommands: bool, main_kwargs: Dict[str, Any], subparser_kwargs: Dict[str, Any] + self, add_subcommands: bool, main_kwargs: dict[str, Any], subparser_kwargs: dict[str, Any] ) -> None: """Initialize and setup the parser, subcommands, and arguments.""" self.parser = self.init_parser(**main_kwargs) if add_subcommands: - self._subcommand_method_arguments: Dict[str, List[str]] = {} + self._subcommand_method_arguments: dict[str, list[str]] = {} self._add_subcommands(self.parser, **subparser_kwargs) else: self._add_arguments(self.parser) @@ -469,7 +470,7 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: """ @staticmethod - def subcommands() -> Dict[str, Set[str]]: + def subcommands() -> dict[str, set[str]]: """Defines the list of available subcommands and the arguments to skip.""" return { "fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"}, @@ -480,7 +481,7 @@ def subcommands() -> Dict[str, Set[str]]: def _add_subcommands(self, parser: LightningArgumentParser, **kwargs: Any) -> None: """Adds subcommands to the input parser.""" - self._subcommand_parsers: Dict[str, LightningArgumentParser] = {} + self._subcommand_parsers: dict[str, LightningArgumentParser] = {} parser_subcommands = parser.add_subcommands() # the user might have passed a builder function trainer_class = ( @@ -497,11 +498,11 @@ def _add_subcommands(self, parser: LightningArgumentParser, **kwargs: Any) -> No self._subcommand_parsers[subcommand] = subcommand_parser parser_subcommands.add_subcommand(subcommand, subcommand_parser, help=description) - def _prepare_subcommand_parser(self, klass: Type, subcommand: str, **kwargs: Any) -> LightningArgumentParser: + def _prepare_subcommand_parser(self, klass: type, subcommand: str, **kwargs: Any) -> LightningArgumentParser: parser = self.init_parser(**kwargs) self._add_arguments(parser) # subcommand 
arguments - skip: Set[Union[str, int]] = set(self.subcommands()[subcommand]) + skip: set[Union[str, int]] = set(self.subcommands()[subcommand]) added = parser.add_method_arguments(klass, subcommand, skip=skip) # need to save which arguments were added to pass them to the method later self._subcommand_method_arguments[subcommand] = added @@ -571,7 +572,7 @@ def instantiate_trainer(self, **kwargs: Any) -> Trainer: trainer_config = {**self._get(self.config_init, "trainer", default={}), **kwargs} return self._instantiate_trainer(trainer_config, extra_callbacks) - def _instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback]) -> Trainer: + def _instantiate_trainer(self, config: dict[str, Any], callbacks: list[Callback]) -> Trainer: key = "callbacks" if key in config: if config[key] is None: @@ -632,8 +633,8 @@ def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) - parser = self._parser(subcommand) def get_automatic( - class_type: Union[Type, Tuple[Type, ...]], register: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] - ) -> List[str]: + class_type: Union[type, tuple[type, ...]], register: dict[str, tuple[Union[type, tuple[type, ...]], str]] + ) -> list[str]: automatic = [] for key, (base_class, link_to) in register.items(): if not isinstance(base_class, tuple): @@ -704,7 +705,7 @@ def _run_subcommand(self, subcommand: str) -> None: if callable(after_fn): after_fn() - def _prepare_subcommand_kwargs(self, subcommand: str) -> Dict[str, Any]: + def _prepare_subcommand_kwargs(self, subcommand: str) -> dict[str, Any]: """Prepares the keyword arguments to pass to the subcommand to run.""" fn_kwargs = { k: v for k, v in self.config_init[subcommand].items() if k in self._subcommand_method_arguments[subcommand] @@ -730,26 +731,26 @@ def _set_seed(self) -> None: self.config["seed_everything"] = config_seed -def _class_path_from_class(class_type: Type) -> str: +def _class_path_from_class(class_type: type) -> str: return class_type.__module__ + "." + class_type.__name__ def _global_add_class_path( - class_type: Type, init_args: Optional[Union[Namespace, Dict[str, Any]]] = None -) -> Dict[str, Any]: + class_type: type, init_args: Optional[Union[Namespace, dict[str, Any]]] = None +) -> dict[str, Any]: if isinstance(init_args, Namespace): init_args = init_args.as_dict() return {"class_path": _class_path_from_class(class_type), "init_args": init_args or {}} -def _add_class_path_generator(class_type: Type) -> Callable[[Namespace], Dict[str, Any]]: - def add_class_path(init_args: Namespace) -> Dict[str, Any]: +def _add_class_path_generator(class_type: type) -> Callable[[Namespace], dict[str, Any]]: + def add_class_path(init_args: Namespace) -> dict[str, Any]: return _global_add_class_path(class_type, init_args) return add_class_path -def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any: +def instantiate_class(args: Union[Any, tuple[Any, ...]], init: dict[str, Any]) -> Any: """Instantiates a class with the given args and init. 
Args: @@ -790,7 +791,7 @@ def __init__(self, cli: LightningCLI, key: str) -> None: self.cli = cli self.key = key - def __call__(self, class_type: Type[ModuleType], *args: Any, **kwargs: Any) -> ModuleType: + def __call__(self, class_type: type[ModuleType], *args: Any, **kwargs: Any) -> ModuleType: hparams = self.cli.config_dump.get(self.key, {}) if "class_path" in hparams: # To make hparams backwards compatible, and so that it is the same irrespective of subclass_mode, the @@ -808,7 +809,7 @@ def __call__(self, class_type: Type[ModuleType], *args: Any, **kwargs: Any) -> M return class_type(*args, **kwargs) -def instantiate_module(class_type: Type[ModuleType], config: Dict[str, Any]) -> ModuleType: +def instantiate_module(class_type: type[ModuleType], config: dict[str, Any]) -> ModuleType: parser = ArgumentParser(exit_on_error=False) if "_class_path" in config: parser.add_subclass_arguments(class_type, "module", fail_untyped=False) diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index 6cb8f79f09284..0c7a9840e2219 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -14,7 +14,8 @@ """LightningDataModule for loading DataLoaders with ease.""" import inspect -from typing import IO, Any, Dict, Iterable, Optional, Union, cast +from collections.abc import Iterable +from typing import IO, Any, Optional, Union, cast from lightning_utilities import apply_to_collection from torch.utils.data import DataLoader, Dataset, IterableDataset @@ -147,7 +148,7 @@ def predict_dataloader() -> EVAL_DATALOADERS: datamodule.predict_dataloader = predict_dataloader # type: ignore[method-assign] return datamodule - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: """Called when saving a checkpoint, implement to generate and save datamodule state. Returns: @@ -156,7 +157,7 @@ def state_dict(self) -> Dict[str, Any]: """ return {} - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: """Called when loading a checkpoint, implement to reload datamodule state given datamodule state_dict. Args: diff --git a/src/lightning/pytorch/core/hooks.py b/src/lightning/pytorch/core/hooks.py index 5495a0262036d..0b0ab14244e38 100644 --- a/src/lightning/pytorch/core/hooks.py +++ b/src/lightning/pytorch/core/hooks.py @@ -13,7 +13,7 @@ # limitations under the License. """Various hooks to be used in the Lightning code.""" -from typing import Any, Dict, Optional +from typing import Any, Optional import torch from torch import Tensor @@ -670,7 +670,7 @@ def on_after_batch_transfer(self, batch, dataloader_idx): class CheckpointHooks: """Hooks to be used with Checkpointing.""" - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: r"""Called by Lightning to restore your model. If you saved something with :meth:`on_save_checkpoint` this is your chance to restore this. @@ -689,7 +689,7 @@ def on_load_checkpoint(self, checkpoint): """ - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None: r"""Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save. 
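# --- Illustrative sketch, not part of the patch ------------------------------
# ``CheckpointHooks.on_save_checkpoint`` / ``on_load_checkpoint`` above now
# receive the checkpoint as a plain ``dict[str, Any]``.  Overriding them in a
# ``LightningModule`` keeps that shape; the ``"extra_state"`` key used here is
# a made-up example, not something Lightning itself writes.
from typing import Any

import torch
from lightning.pytorch import LightningModule


class ModuleWithExtraState(LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(4, 4)
        self.extra_state: dict[str, Any] = {"seen_batches": 0}

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x)

    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        # Stored alongside the standard entries that Lightning writes itself.
        checkpoint["extra_state"] = self.extra_state

    def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        self.extra_state = checkpoint.get("extra_state", {"seen_batches": 0})
# ------------------------------------------------------------------------------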
diff --git a/src/lightning/pytorch/core/mixins/hparams_mixin.py b/src/lightning/pytorch/core/mixins/hparams_mixin.py index 94ece0039d4f4..3a01cd2fe9a7c 100644 --- a/src/lightning/pytorch/core/mixins/hparams_mixin.py +++ b/src/lightning/pytorch/core/mixins/hparams_mixin.py @@ -15,9 +15,10 @@ import inspect import types from argparse import Namespace +from collections.abc import Iterator, MutableMapping, Sequence from contextlib import contextmanager from contextvars import ContextVar -from typing import Any, Iterator, List, MutableMapping, Optional, Sequence, Union +from typing import Any, Optional, Union from lightning.fabric.utilities.data import AttributeDict from lightning.pytorch.utilities.parsing import save_hyperparameters @@ -41,7 +42,7 @@ def _given_hyperparameters_context(hparams: dict, instantiator: str) -> Iterator class HyperparametersMixin: - __jit_unused_properties__: List[str] = ["hparams", "hparams_initial"] + __jit_unused_properties__: list[str] = ["hparams", "hparams_initial"] def __init__(self) -> None: super().__init__() diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index d8374ef7ea5e8..f1d1da924eac4 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -16,6 +16,7 @@ import logging import numbers import weakref +from collections.abc import Generator, Mapping, Sequence from contextlib import contextmanager from io import BytesIO from pathlib import Path @@ -24,14 +25,8 @@ TYPE_CHECKING, Any, Callable, - Dict, - Generator, - List, Literal, - Mapping, Optional, - Sequence, - Tuple, Union, cast, overload, @@ -86,7 +81,7 @@ log = logging.getLogger(__name__) MODULE_OPTIMIZERS = Union[ - Optimizer, LightningOptimizer, _FabricOptimizer, List[Optimizer], List[LightningOptimizer], List[_FabricOptimizer] + Optimizer, LightningOptimizer, _FabricOptimizer, list[Optimizer], list[LightningOptimizer], list[_FabricOptimizer] ] @@ -100,7 +95,7 @@ class LightningModule( ): # Below is for property support of JIT # since none of these are important when using JIT, we are going to ignore them. - __jit_unused_properties__: List[str] = ( + __jit_unused_properties__: list[str] = ( [ "example_input_array", "on_gpu", @@ -132,19 +127,19 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._trainer: Optional[pl.Trainer] = None # attributes that can be set by user - self._example_input_array: Optional[Union[Tensor, Tuple, Dict]] = None + self._example_input_array: Optional[Union[Tensor, tuple, dict]] = None self._automatic_optimization: bool = True self._strict_loading: Optional[bool] = None # attributes used internally self._current_fx_name: Optional[str] = None - self._param_requires_grad_state: Dict[str, bool] = {} - self._metric_attributes: Optional[Dict[int, str]] = None - self._compiler_ctx: Optional[Dict[str, Any]] = None + self._param_requires_grad_state: dict[str, bool] = {} + self._metric_attributes: Optional[dict[int, str]] = None + self._compiler_ctx: Optional[dict[str, Any]] = None # attributes only used when using fabric self._fabric: Optional[lf.Fabric] = None - self._fabric_optimizers: List[_FabricOptimizer] = [] + self._fabric_optimizers: list[_FabricOptimizer] = [] # access to device mesh in `conigure_model()` hook self._device_mesh: Optional[DeviceMesh] = None @@ -152,10 +147,10 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: @overload def optimizers( self, use_pl_optimizer: Literal[True] = True - ) -> Union[LightningOptimizer, List[LightningOptimizer]]: ... 
+ ) -> Union[LightningOptimizer, list[LightningOptimizer]]: ... @overload - def optimizers(self, use_pl_optimizer: Literal[False]) -> Union[Optimizer, List[Optimizer]]: ... + def optimizers(self, use_pl_optimizer: Literal[False]) -> Union[Optimizer, list[Optimizer]]: ... @overload def optimizers(self, use_pl_optimizer: bool) -> MODULE_OPTIMIZERS: ... @@ -190,7 +185,7 @@ def optimizers(self, use_pl_optimizer: bool = True) -> MODULE_OPTIMIZERS: # multiple opts return opts - def lr_schedulers(self) -> Union[None, List[LRSchedulerPLType], LRSchedulerPLType]: + def lr_schedulers(self) -> Union[None, list[LRSchedulerPLType], LRSchedulerPLType]: """Returns the learning rate scheduler(s) that are being used during training. Useful for manual optimization. Returns: @@ -202,7 +197,7 @@ def lr_schedulers(self) -> Union[None, List[LRSchedulerPLType], LRSchedulerPLTyp return None # ignore other keys "interval", "frequency", etc. - lr_schedulers: List[LRSchedulerPLType] = [config.scheduler for config in self.trainer.lr_scheduler_configs] + lr_schedulers: list[LRSchedulerPLType] = [config.scheduler for config in self.trainer.lr_scheduler_configs] # single scheduler if len(lr_schedulers) == 1: @@ -240,7 +235,7 @@ def fabric(self, fabric: Optional["lf.Fabric"]) -> None: self._fabric = fabric @property - def example_input_array(self) -> Optional[Union[Tensor, Tuple, Dict]]: + def example_input_array(self) -> Optional[Union[Tensor, tuple, dict]]: """The example input array is a specification of what the module can consume in the :meth:`forward` method. The return type is interpreted as follows: @@ -255,7 +250,7 @@ def example_input_array(self) -> Optional[Union[Tensor, Tuple, Dict]]: return self._example_input_array @example_input_array.setter - def example_input_array(self, example: Optional[Union[Tensor, Tuple, Dict]]) -> None: + def example_input_array(self, example: Optional[Union[Tensor, tuple, dict]]) -> None: self._example_input_array = example @property @@ -318,7 +313,7 @@ def logger(self) -> Optional[Union[Logger, FabricLogger]]: return self._trainer.logger if self._trainer is not None else None @property - def loggers(self) -> Union[List[Logger], List[FabricLogger]]: + def loggers(self) -> Union[list[Logger], list[FabricLogger]]: """Reference to the list of loggers in the Trainer.""" if self._fabric is not None: return self._fabric.loggers @@ -599,7 +594,7 @@ def log_dict( if self._fabric is not None: return self._log_dict_through_fabric(dictionary=dictionary, logger=logger) - kwargs: Dict[str, bool] = {} + kwargs: dict[str, bool] = {} if isinstance(dictionary, MetricCollection): kwargs["keep_base"] = False @@ -665,8 +660,8 @@ def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor return value def all_gather( - self, data: Union[Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False - ) -> Union[Tensor, Dict, List, Tuple]: + self, data: Union[Tensor, dict, list, tuple], group: Optional[Any] = None, sync_grads: bool = False + ) -> Union[Tensor, dict, list, tuple]: r"""Gather tensors or collections of tensors from multiple processes. This method needs to be called on all processes and the tensors need to have the same shape across all @@ -1417,7 +1412,7 @@ def to_torchscript( method: Optional[str] = "script", example_inputs: Optional[Any] = None, **kwargs: Any, - ) -> Union[ScriptModule, Dict[str, ScriptModule]]: + ) -> Union[ScriptModule, dict[str, ScriptModule]]: """By default compiles the whole model to a :class:`~torch.jit.ScriptModule`. 
If you want to use tracing, please provided the argument ``method='trace'`` and make sure that either the `example_inputs` argument is provided, or the model has :attr:`example_input_array` set. If you would like to customize the modules that are @@ -1594,7 +1589,7 @@ def load_from_checkpoint( return cast(Self, loaded) @override - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: state = dict(self.__dict__) state["_trainer"] = None return state diff --git a/src/lightning/pytorch/core/optimizer.py b/src/lightning/pytorch/core/optimizer.py index 777dca0b51dfe..46126e212378e 100644 --- a/src/lightning/pytorch/core/optimizer.py +++ b/src/lightning/pytorch/core/optimizer.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Generator from contextlib import contextmanager from dataclasses import fields -from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union, overload +from typing import Any, Callable, Optional, Union, overload from weakref import proxy import torch @@ -172,7 +173,7 @@ def __getattr__(self, item: Any) -> Any: def _init_optimizers_and_lr_schedulers( model: "pl.LightningModule", -) -> Tuple[List[Optimizer], List[LRSchedulerConfig]]: +) -> tuple[list[Optimizer], list[LRSchedulerConfig]]: """Calls `LightningModule.configure_optimizers` and parses and validates the output.""" from lightning.pytorch.trainer import call @@ -197,8 +198,8 @@ def _init_optimizers_and_lr_schedulers( def _configure_optimizers( - optim_conf: Union[Dict[str, Any], List, Optimizer, Tuple], -) -> Tuple[List, List, Optional[str]]: + optim_conf: Union[dict[str, Any], list, Optimizer, tuple], +) -> tuple[list, list, Optional[str]]: optimizers, lr_schedulers = [], [] monitor = None @@ -246,7 +247,7 @@ def _configure_optimizers( return optimizers, lr_schedulers, monitor -def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> List[LRSchedulerConfig]: +def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> list[LRSchedulerConfig]: """Convert each scheduler into `LRSchedulerConfig` with relevant information, when using automatic optimization.""" lr_scheduler_configs = [] for scheduler in schedulers: @@ -301,7 +302,7 @@ def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str] return lr_scheduler_configs -def _configure_schedulers_manual_opt(schedulers: list) -> List[LRSchedulerConfig]: +def _configure_schedulers_manual_opt(schedulers: list) -> list[LRSchedulerConfig]: """Convert each scheduler into `LRSchedulerConfig` structure with relevant information, when using manual optimization.""" lr_scheduler_configs = [] @@ -326,7 +327,7 @@ def _configure_schedulers_manual_opt(schedulers: list) -> List[LRSchedulerConfig return lr_scheduler_configs -def _validate_scheduler_api(lr_scheduler_configs: List[LRSchedulerConfig], model: "pl.LightningModule") -> None: +def _validate_scheduler_api(lr_scheduler_configs: list[LRSchedulerConfig], model: "pl.LightningModule") -> None: for config in lr_scheduler_configs: scheduler = config.scheduler if not isinstance(scheduler, _Stateful): @@ -347,7 +348,7 @@ def _validate_scheduler_api(lr_scheduler_configs: List[LRSchedulerConfig], model ) -def _validate_multiple_optimizers_support(optimizers: List[Optimizer], model: "pl.LightningModule") -> None: +def 
_validate_multiple_optimizers_support(optimizers: list[Optimizer], model: "pl.LightningModule") -> None: if is_param_in_hook_signature(model.training_step, "optimizer_idx", explicit=True): raise RuntimeError( "Training with multiple optimizers is only supported with manual optimization. Remove the `optimizer_idx`" @@ -362,7 +363,7 @@ def _validate_multiple_optimizers_support(optimizers: List[Optimizer], model: "p ) -def _validate_optimizers_attached(optimizers: List[Optimizer], lr_scheduler_configs: List[LRSchedulerConfig]) -> None: +def _validate_optimizers_attached(optimizers: list[Optimizer], lr_scheduler_configs: list[LRSchedulerConfig]) -> None: for config in lr_scheduler_configs: if config.scheduler.optimizer not in optimizers: raise MisconfigurationException( @@ -370,7 +371,7 @@ def _validate_optimizers_attached(optimizers: List[Optimizer], lr_scheduler_conf ) -def _validate_optim_conf(optim_conf: Dict[str, Any]) -> None: +def _validate_optim_conf(optim_conf: dict[str, Any]) -> None: valid_keys = {"optimizer", "lr_scheduler", "monitor"} extra_keys = optim_conf.keys() - valid_keys if extra_keys: @@ -387,15 +388,15 @@ def __init__(self) -> None: super().__init__([torch.zeros(1)], {}) @override - def add_param_group(self, param_group: Dict[Any, Any]) -> None: + def add_param_group(self, param_group: dict[Any, Any]) -> None: pass # Do Nothing @override - def load_state_dict(self, state_dict: Dict[Any, Any]) -> None: + def load_state_dict(self, state_dict: dict[Any, Any]) -> None: pass # Do Nothing @override - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return {} # Return Empty @overload diff --git a/src/lightning/pytorch/core/saving.py b/src/lightning/pytorch/core/saving.py index 521192f500b53..09d888c56bdcd 100644 --- a/src/lightning/pytorch/core/saving.py +++ b/src/lightning/pytorch/core/saving.py @@ -22,7 +22,7 @@ from copy import deepcopy from enum import Enum from pathlib import Path -from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union +from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Union from warnings import warn import torch @@ -51,7 +51,7 @@ def _load_from_checkpoint( - cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]], + cls: Union[type["pl.LightningModule"], type["pl.LightningDataModule"]], checkpoint_path: Union[_PATH, IO], map_location: _MAP_LOCATION_TYPE = None, hparams_file: Optional[_PATH] = None, @@ -115,8 +115,8 @@ def _default_map_location(storage: "UntypedStorage", location: str) -> Optional[ def _load_state( - cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]], - checkpoint: Dict[str, Any], + cls: Union[type["pl.LightningModule"], type["pl.LightningDataModule"]], + checkpoint: dict[str, Any], strict: Optional[bool] = None, **cls_kwargs_new: Any, ) -> Union["pl.LightningModule", "pl.LightningDataModule"]: @@ -200,8 +200,8 @@ def _load_state( def _convert_loaded_hparams( - model_args: Dict[str, Any], hparams_type: Optional[Union[Callable, str]] = None -) -> Dict[str, Any]: + model_args: dict[str, Any], hparams_type: Optional[Union[Callable, str]] = None +) -> dict[str, Any]: """Convert hparams according given type in callable or string (past) format.""" # if not hparams type define if not hparams_type: @@ -243,7 +243,7 @@ def update_hparams(hparams: dict, updates: dict) -> None: hparams.update({k: v}) -def load_hparams_from_tags_csv(tags_csv: _PATH) -> Dict[str, Any]: +def load_hparams_from_tags_csv(tags_csv: _PATH) -> dict[str, Any]: """Load 
hparams from a file. >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here') @@ -281,7 +281,7 @@ def save_hparams_to_tags_csv(tags_csv: _PATH, hparams: Union[dict, Namespace]) - writer.writerow({"key": k, "value": v}) -def load_hparams_from_yaml(config_yaml: _PATH, use_omegaconf: bool = True) -> Dict[str, Any]: +def load_hparams_from_yaml(config_yaml: _PATH, use_omegaconf: bool = True) -> dict[str, Any]: """Load hparams from a file. Args: diff --git a/src/lightning/pytorch/demos/boring_classes.py b/src/lightning/pytorch/demos/boring_classes.py index fd2660228146e..589524e1960b2 100644 --- a/src/lightning/pytorch/demos/boring_classes.py +++ b/src/lightning/pytorch/demos/boring_classes.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Iterator, List, Optional, Tuple +from collections.abc import Iterator +from typing import Any, Optional import torch import torch.nn as nn @@ -35,7 +36,7 @@ def __init__(self, size: int, length: int): self.len = length self.data = torch.randn(length, size) - def __getitem__(self, index: int) -> Dict[str, Tensor]: + def __getitem__(self, index: int) -> dict[str, Tensor]: a = self.data[index] b = a + 2 return {"a": a, "b": b} @@ -134,7 +135,7 @@ def validation_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT: def test_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT: return {"y": self.step(batch)} - def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[LRScheduler]]: + def configure_optimizers(self) -> tuple[list[torch.optim.Optimizer], list[LRScheduler]]: optimizer = torch.optim.SGD(self.parameters(), lr=0.1) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) return [optimizer], [lr_scheduler] diff --git a/src/lightning/pytorch/demos/lstm.py b/src/lightning/pytorch/demos/lstm.py index 672b61ad0eff9..9432dd9acd1f8 100644 --- a/src/lightning/pytorch/demos/lstm.py +++ b/src/lightning/pytorch/demos/lstm.py @@ -5,7 +5,8 @@ """ -from typing import Iterator, List, Optional, Sized, Tuple +from collections.abc import Iterator, Sized +from typing import Optional import torch import torch.nn as nn @@ -37,14 +38,14 @@ def init_weights(self) -> None: nn.init.zeros_(self.decoder.bias) nn.init.uniform_(self.decoder.weight, -0.1, 0.1) - def forward(self, input: Tensor, hidden: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]: + def forward(self, input: Tensor, hidden: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: emb = self.drop(self.encoder(input)) output, hidden = self.rnn(emb, hidden) output = self.drop(output) decoded = self.decoder(output).view(-1, self.vocab_size) return F.log_softmax(decoded, dim=1), hidden - def init_hidden(self, batch_size: int) -> Tuple[Tensor, Tensor]: + def init_hidden(self, batch_size: int) -> tuple[Tensor, Tensor]: weight = next(self.parameters()) return ( weight.new_zeros(self.nlayers, batch_size, self.nhid), @@ -52,14 +53,14 @@ def init_hidden(self, batch_size: int) -> Tuple[Tensor, Tensor]: ) -class SequenceSampler(Sampler[List[int]]): +class SequenceSampler(Sampler[list[int]]): def __init__(self, dataset: Sized, batch_size: int) -> None: super().__init__() self.dataset = dataset self.batch_size = batch_size self.chunk_size = len(self.dataset) // self.batch_size - def __iter__(self) -> Iterator[List[int]]: + def __iter__(self) -> Iterator[list[int]]: n = len(self.dataset) for i in 
range(self.chunk_size): yield list(range(i, n - (n % self.batch_size), self.chunk_size)) @@ -72,12 +73,12 @@ class LightningLSTM(LightningModule): def __init__(self, vocab_size: int = 33278): super().__init__() self.model = SimpleLSTM(vocab_size=vocab_size) - self.hidden: Optional[Tuple[Tensor, Tensor]] = None + self.hidden: Optional[tuple[Tensor, Tensor]] = None def on_train_epoch_end(self) -> None: self.hidden = None - def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: + def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor: input, target = batch if self.hidden is None: self.hidden = self.model.init_hidden(input.size(0)) diff --git a/src/lightning/pytorch/demos/mnist_datamodule.py b/src/lightning/pytorch/demos/mnist_datamodule.py index 992527ab67296..73f46d4dc0986 100644 --- a/src/lightning/pytorch/demos/mnist_datamodule.py +++ b/src/lightning/pytorch/demos/mnist_datamodule.py @@ -16,7 +16,8 @@ import random import time import urllib -from typing import Any, Callable, Optional, Sized, Tuple, Union +from collections.abc import Sized +from typing import Any, Callable, Optional, Union from urllib.error import HTTPError from warnings import warn @@ -63,7 +64,7 @@ def __init__( data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME self.data, self.targets = self._try_load(os.path.join(self.cached_folder_path, data_file)) - def __getitem__(self, idx: int) -> Tuple[Tensor, int]: + def __getitem__(self, idx: int) -> tuple[Tensor, int]: img = self.data[idx].float().unsqueeze(0) target = int(self.targets[idx]) @@ -99,7 +100,7 @@ def _download(self, data_folder: str) -> None: urllib.request.urlretrieve(url, fpath) # noqa: S310 @staticmethod - def _try_load(path_data: str, trials: int = 30, delta: float = 1.0) -> Tuple[Tensor, Tensor]: + def _try_load(path_data: str, trials: int = 30, delta: float = 1.0) -> tuple[Tensor, Tensor]: """Resolving loading from the same time from multiple concurrent processes.""" res, exception = None, None assert trials, "at least some trial has to be set" diff --git a/src/lightning/pytorch/demos/transformer.py b/src/lightning/pytorch/demos/transformer.py index 58cf30cbca7b7..eca86b4cb4dc7 100644 --- a/src/lightning/pytorch/demos/transformer.py +++ b/src/lightning/pytorch/demos/transformer.py @@ -8,7 +8,7 @@ import math import os from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import Optional import torch import torch.nn as nn @@ -119,7 +119,7 @@ def vocab_size(self) -> int: def __len__(self) -> int: return len(self.data) // self.block_size - 1 - def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: + def __getitem__(self, index: int) -> tuple[Tensor, Tensor]: start = index * self.block_size end = start + self.block_size inputs = self.data[start:end] @@ -143,8 +143,8 @@ def download(destination: Path) -> None: class Dictionary: def __init__(self) -> None: - self.word2idx: Dict[str, int] = {} - self.idx2word: List[str] = [] + self.word2idx: dict[str, int] = {} + self.idx2word: list[str] = [] def add_word(self, word: str) -> int: if word not in self.word2idx: @@ -156,7 +156,7 @@ def __len__(self) -> int: return len(self.idx2word) -def tokenize(path: Path) -> Tuple[Tensor, Dictionary]: +def tokenize(path: Path) -> tuple[Tensor, Dictionary]: dictionary = Dictionary() assert os.path.exists(path) @@ -169,10 +169,10 @@ def tokenize(path: Path) -> Tuple[Tensor, Dictionary]: # Tokenize file content with open(path, encoding="utf8") as f: - idss: List[Tensor] = [] + 
idss: list[Tensor] = [] for line in f: words = line.split() + [""] - ids: List[int] = [] + ids: list[int] = [] for word in words: ids.append(dictionary.word2idx[word]) idss.append(torch.tensor(ids).type(torch.int64)) @@ -188,7 +188,7 @@ def __init__(self, vocab_size: int = 33278) -> None: def forward(self, inputs: Tensor, target: Tensor) -> Tensor: return self.model(inputs, target) - def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: + def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor: inputs, target = batch output = self(inputs, target) loss = torch.nn.functional.nll_loss(output, target.view(-1)) diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index 277af5c85f539..9c05317655129 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -19,7 +19,8 @@ import logging import os from argparse import Namespace -from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, Optional, Union from lightning_utilities.core.imports import RequirementCache from torch import Tensor @@ -305,7 +306,7 @@ def experiment(self) -> Union["Experiment", "ExistingExperiment", "OfflineExperi @override @rank_zero_only - def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: params = _convert_params(params) params = _flatten_dict(params) self.experiment.log_parameters(params) @@ -410,7 +411,7 @@ def version(self) -> str: return self._future_experiment_key - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: state = self.__dict__.copy() # Save the experiment id in case an experiment object already exists, diff --git a/src/lightning/pytorch/loggers/csv_logs.py b/src/lightning/pytorch/loggers/csv_logs.py index caca0c181c6ff..8606264dc3cdb 100644 --- a/src/lightning/pytorch/loggers/csv_logs.py +++ b/src/lightning/pytorch/loggers/csv_logs.py @@ -21,7 +21,7 @@ import os from argparse import Namespace -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from typing_extensions import override @@ -52,9 +52,9 @@ class ExperimentWriter(_FabricExperimentWriter): def __init__(self, log_dir: str) -> None: super().__init__(log_dir=log_dir) - self.hparams: Dict[str, Any] = {} + self.hparams: dict[str, Any] = {} - def log_hparams(self, params: Dict[str, Any]) -> None: + def log_hparams(self, params: dict[str, Any]) -> None: """Record hparams.""" self.hparams.update(params) @@ -144,7 +144,7 @@ def save_dir(self) -> str: @override @rank_zero_only - def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: params = _convert_params(params) self.experiment.log_hparams(params) diff --git a/src/lightning/pytorch/loggers/logger.py b/src/lightning/pytorch/loggers/logger.py index 40e8ed8c4a13e..668fe39cb67d2 100644 --- a/src/lightning/pytorch/loggers/logger.py +++ b/src/lightning/pytorch/loggers/logger.py @@ -18,7 +18,8 @@ import statistics from abc import ABC from collections import defaultdict -from typing import Any, Callable, Dict, Mapping, Optional, Sequence +from collections.abc import Mapping, Sequence +from typing import Any, Callable, Optional from typing_extensions import override @@ -101,7 +102,7 @@ def merge_dicts( # pragma: no cover 
dicts: Sequence[Mapping], agg_key_funcs: Optional[Mapping] = None, default_func: Callable[[Sequence[float]], float] = statistics.mean, -) -> Dict: +) -> dict: """Merge a sequence with dictionaries into one dictionary by aggregating the same keys with some given function. Args: @@ -137,7 +138,7 @@ def merge_dicts( # pragma: no cover """ agg_key_funcs = agg_key_funcs or {} keys = list(functools.reduce(operator.or_, [set(d.keys()) for d in dicts])) - d_out: Dict = defaultdict(dict) + d_out: dict = defaultdict(dict) for k in keys: fn = agg_key_funcs.get(k) values_to_agg = [v for v in [d_in.get(k) for d_in in dicts] if v is not None] diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py index ec990b634a6c4..e3d99987b7f58 100644 --- a/src/lightning/pytorch/loggers/mlflow.py +++ b/src/lightning/pytorch/loggers/mlflow.py @@ -21,9 +21,10 @@ import re import tempfile from argparse import Namespace +from collections.abc import Mapping from pathlib import Path from time import time -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Mapping, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union import yaml from lightning_utilities.core.imports import RequirementCache @@ -117,7 +118,7 @@ def __init__( experiment_name: str = "lightning_logs", run_name: Optional[str] = None, tracking_uri: Optional[str] = os.getenv("MLFLOW_TRACKING_URI"), - tags: Optional[Dict[str, Any]] = None, + tags: Optional[dict[str, Any]] = None, save_dir: Optional[str] = "./mlruns", log_model: Literal[True, False, "all"] = False, prefix: str = "", @@ -140,7 +141,7 @@ def __init__( self._run_id = run_id self.tags = tags self._log_model = log_model - self._logged_model_time: Dict[str, float] = {} + self._logged_model_time: dict[str, float] = {} self._checkpoint_callback: Optional[ModelCheckpoint] = None self._prefix = prefix self._artifact_location = artifact_location @@ -227,7 +228,7 @@ def experiment_id(self) -> Optional[str]: @override @rank_zero_only - def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: params = _convert_params(params) params = _flatten_dict(params) @@ -249,7 +250,7 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) from mlflow.entities import Metric metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) - metrics_list: List[Metric] = [] + metrics_list: list[Metric] = [] timestamp_ms = int(time() * 1000) for k, v in metrics.items(): diff --git a/src/lightning/pytorch/loggers/neptune.py b/src/lightning/pytorch/loggers/neptune.py index 691dbe0ba2ff7..a363f589b29b4 100644 --- a/src/lightning/pytorch/loggers/neptune.py +++ b/src/lightning/pytorch/loggers/neptune.py @@ -20,8 +20,9 @@ import logging import os from argparse import Namespace +from collections.abc import Generator from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Optional, Set, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union from lightning_utilities.core.imports import RequirementCache from torch import Tensor @@ -286,8 +287,8 @@ def _retrieve_run_data(self) -> None: self._run_name = "offline-name" @property - def _neptune_init_args(self) -> Dict: - args: Dict = {} + def _neptune_init_args(self) -> dict: + args: dict = {} # Backward compatibility in case of previous version retrieval with contextlib.suppress(AttributeError): args = 
self._neptune_run_kwargs @@ -337,13 +338,13 @@ def _verify_input_arguments( " parameters." ) - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: state = self.__dict__.copy() # Run instance can't be pickled state["_run_instance"] = None return state - def __setstate__(self, state: Dict[str, Any]) -> None: + def __setstate__(self, state: dict[str, Any]) -> None: import neptune self.__dict__ = state @@ -395,7 +396,7 @@ def run(self) -> "Run": @override @rank_zero_only @_catch_inactive - def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: r"""Log hyperparameters to the run. Hyperparameters will be logged under the "/hyperparams" namespace. @@ -443,7 +444,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: @override @rank_zero_only @_catch_inactive - def log_metrics(self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None) -> None: + def log_metrics(self, metrics: dict[str, Union[Tensor, float]], step: Optional[int] = None) -> None: """Log metrics (numeric values) in Neptune runs. Args: @@ -563,16 +564,16 @@ def _get_full_model_name(model_path: str, checkpoint_callback: Checkpoint) -> st return model_path.replace(os.sep, "/") @classmethod - def _get_full_model_names_from_exp_structure(cls, exp_structure: Dict[str, Any], namespace: str) -> Set[str]: + def _get_full_model_names_from_exp_structure(cls, exp_structure: dict[str, Any], namespace: str) -> set[str]: """Returns all paths to properties which were already logged in `namespace`""" - structure_keys: List[str] = namespace.split(cls.LOGGER_JOIN_CHAR) + structure_keys: list[str] = namespace.split(cls.LOGGER_JOIN_CHAR) for key in structure_keys: exp_structure = exp_structure[key] uploaded_models_dict = exp_structure return set(cls._dict_paths(uploaded_models_dict)) @classmethod - def _dict_paths(cls, d: Dict[str, Any], path_in_build: Optional[str] = None) -> Generator: + def _dict_paths(cls, d: dict[str, Any], path_in_build: Optional[str] = None) -> Generator: for k, v in d.items(): path = f"{path_in_build}/{k}" if path_in_build is not None else k if not isinstance(v, dict): diff --git a/src/lightning/pytorch/loggers/tensorboard.py b/src/lightning/pytorch/loggers/tensorboard.py index 88e026f6945e0..e70c89269b166 100644 --- a/src/lightning/pytorch/loggers/tensorboard.py +++ b/src/lightning/pytorch/loggers/tensorboard.py @@ -18,7 +18,7 @@ import os from argparse import Namespace -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from torch import Tensor from typing_extensions import override @@ -108,7 +108,7 @@ def __init__( f"{str(_TENSORBOARD_AVAILABLE)}" ) self._log_graph = log_graph and _TENSORBOARD_AVAILABLE - self.hparams: Union[Dict[str, Any], Namespace] = {} + self.hparams: Union[dict[str, Any], Namespace] = {} @property @override @@ -153,7 +153,7 @@ def save_dir(self) -> str: @override @rank_zero_only def log_hyperparams( - self, params: Union[Dict[str, Any], Namespace], metrics: Optional[Dict[str, Any]] = None + self, params: Union[dict[str, Any], Namespace], metrics: Optional[dict[str, Any]] = None ) -> None: """Record hyperparameters. TensorBoard logs with and without saved hyperparameters are incompatible, the hyperparameters are then not displayed in the TensorBoard. 
Please delete or move the previously saved logs to diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index 2ff9cbd24eca5..c763071af5644 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -14,7 +14,7 @@ """Utilities for loggers.""" from pathlib import Path -from typing import Any, List, Tuple, Union +from typing import Any, Union from torch import Tensor @@ -22,14 +22,14 @@ from lightning.pytorch.callbacks import Checkpoint -def _version(loggers: List[Any], separator: str = "_") -> Union[int, str]: +def _version(loggers: list[Any], separator: str = "_") -> Union[int, str]: if len(loggers) == 1: return loggers[0].version # Concatenate versions together, removing duplicates and preserving order return separator.join(dict.fromkeys(str(logger.version) for logger in loggers)) -def _scan_checkpoints(checkpoint_callback: Checkpoint, logged_model_time: dict) -> List[Tuple[float, str, float, str]]: +def _scan_checkpoints(checkpoint_callback: Checkpoint, logged_model_time: dict) -> list[tuple[float, str, float, str]]: """Return the checkpoints to be logged. Args: diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py index 20f8d02a7ab9b..2429748f73179 100644 --- a/src/lightning/pytorch/loggers/wandb.py +++ b/src/lightning/pytorch/loggers/wandb.py @@ -18,8 +18,9 @@ import os from argparse import Namespace +from collections.abc import Mapping from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, Optional, Union import torch.nn as nn from lightning_utilities.core.imports import RequirementCache @@ -320,7 +321,7 @@ def __init__( self._log_model = log_model self._prefix = prefix self._experiment = experiment - self._logged_model_time: Dict[str, float] = {} + self._logged_model_time: dict[str, float] = {} self._checkpoint_callback: Optional[ModelCheckpoint] = None # paths are processed as strings @@ -332,7 +333,7 @@ def __init__( project = project or os.environ.get("WANDB_PROJECT", "lightning_logs") # set wandb init arguments - self._wandb_init: Dict[str, Any] = { + self._wandb_init: dict[str, Any] = { "name": name, "project": project, "dir": save_dir or dir, @@ -348,7 +349,7 @@ def __init__( self._id = self._wandb_init.get("id") self._checkpoint_name = checkpoint_name - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: import wandb # Hack: If the 'spawn' launch method is used, the logger will get pickled and this `__getstate__` gets called. 
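The logger hunks above all apply the same PEP 585 change: with Python 3.9 as the floor, the built-in `dict`, `list`, and `tuple` are subscriptable in annotations, so the `typing.Dict`/`List`/`Tuple` aliases and their imports can be dropped while `Optional`, `Union`, and `Any` remain. A minimal before/after sketch; the function and data are illustrative, not taken from the diff:

```python
from typing import Any, Optional

# Pre-3.9 style (what the diff removes):
#     from typing import Dict, List, Tuple
#     def summarize(metrics: Dict[str, List[Tuple[int, float]]]) -> Dict[str, Any]: ...
# Python 3.9+ style (what the diff adopts): built-in generics, no extra imports.
def summarize(
    metrics: dict[str, list[tuple[int, float]]], prefix: Optional[str] = None
) -> dict[str, Any]:
    """Keep only the latest (step, value) pair per metric, optionally prefixing the keys."""
    latest = {name: values[-1][1] for name, values in metrics.items() if values}
    return {f"{prefix}/{key}" if prefix else key: value for key, value in latest.items()}


print(summarize({"loss": [(0, 1.2), (1, 0.9)]}, prefix="train"))  # {'train/loss': 0.9}
```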
@@ -421,7 +422,7 @@ def watch( @override @rank_zero_only - def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: params = _convert_params(params) params = _sanitize_callable_params(params) params = _convert_json_serializable(params) @@ -442,8 +443,8 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) def log_table( self, key: str, - columns: Optional[List[str]] = None, - data: Optional[List[List[Any]]] = None, + columns: Optional[list[str]] = None, + data: Optional[list[list[Any]]] = None, dataframe: Any = None, step: Optional[int] = None, ) -> None: @@ -461,8 +462,8 @@ def log_table( def log_text( self, key: str, - columns: Optional[List[str]] = None, - data: Optional[List[List[str]]] = None, + columns: Optional[list[str]] = None, + data: Optional[list[list[str]]] = None, dataframe: Any = None, step: Optional[int] = None, ) -> None: @@ -475,7 +476,7 @@ def log_text( self.log_table(key, columns, data, dataframe, step) @rank_zero_only - def log_image(self, key: str, images: List[Any], step: Optional[int] = None, **kwargs: Any) -> None: + def log_image(self, key: str, images: list[Any], step: Optional[int] = None, **kwargs: Any) -> None: """Log images (tensors, numpy arrays, PIL Images or file paths). Optional kwargs are lists passed to each image (ex: caption, masks, boxes). @@ -495,7 +496,7 @@ def log_image(self, key: str, images: List[Any], step: Optional[int] = None, **k self.log_metrics(metrics, step) # type: ignore[arg-type] @rank_zero_only - def log_audio(self, key: str, audios: List[Any], step: Optional[int] = None, **kwargs: Any) -> None: + def log_audio(self, key: str, audios: list[Any], step: Optional[int] = None, **kwargs: Any) -> None: r"""Log audios (numpy arrays, or file paths). Args: @@ -521,7 +522,7 @@ def log_audio(self, key: str, audios: List[Any], step: Optional[int] = None, **k self.log_metrics(metrics, step) # type: ignore[arg-type] @rank_zero_only - def log_video(self, key: str, videos: List[Any], step: Optional[int] = None, **kwargs: Any) -> None: + def log_video(self, key: str, videos: list[Any], step: Optional[int] = None, **kwargs: Any) -> None: """Log videos (numpy arrays, or file paths). 
Args: diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py index 78573c45aab76..d007466ee3b1c 100644 --- a/src/lightning/pytorch/loops/evaluation_loop.py +++ b/src/lightning/pytorch/loops/evaluation_loop.py @@ -15,8 +15,9 @@ import shutil import sys from collections import ChainMap, OrderedDict, defaultdict +from collections.abc import Iterable, Iterator from dataclasses import dataclass -from typing import Any, DefaultDict, Iterable, Iterator, List, Optional, Tuple, Union +from typing import Any, Optional, Union from lightning_utilities.core.apply_func import apply_to_collection from torch import Tensor @@ -67,17 +68,17 @@ def __init__( self.verbose = verbose self.inference_mode = inference_mode self.batch_progress = _BatchProgress() # across dataloaders - self._max_batches: List[Union[int, float]] = [] + self._max_batches: list[Union[int, float]] = [] self._results = _ResultCollection(training=False) - self._logged_outputs: List[_OUT_DICT] = [] + self._logged_outputs: list[_OUT_DICT] = [] self._has_run: bool = False self._trainer_fn = trainer_fn self._stage = stage self._data_source = _DataLoaderSource(None, f"{stage.dataloader_prefix}_dataloader") self._combined_loader: Optional[CombinedLoader] = None self._data_fetcher: Optional[_DataFetcher] = None - self._seen_batches_per_dataloader: DefaultDict[int, int] = defaultdict(int) + self._seen_batches_per_dataloader: defaultdict[int, int] = defaultdict(int) self._last_val_dl_reload_epoch = float("-inf") self._module_mode = _ModuleMode() self._restart_stage = RestartStage.NONE @@ -90,7 +91,7 @@ def num_dataloaders(self) -> int: return len(combined_loader.flattened) @property - def max_batches(self) -> List[Union[int, float]]: + def max_batches(self) -> list[Union[int, float]]: """The max number of batches to run per dataloader.""" max_batches = self._max_batches if not self.trainer.sanity_checking: @@ -114,7 +115,7 @@ def _is_sequential(self) -> bool: return self._combined_loader._mode == "sequential" @_no_grad_context - def run(self) -> List[_OUT_DICT]: + def run(self) -> list[_OUT_DICT]: self.setup_data() if self.skip: return [] @@ -280,7 +281,7 @@ def on_run_start(self) -> None: self._on_evaluation_start() self._on_evaluation_epoch_start() - def on_run_end(self) -> List[_OUT_DICT]: + def on_run_end(self) -> list[_OUT_DICT]: """Runs the ``_on_evaluation_epoch_end`` hook.""" # if `done` returned True before any iterations were done, this won't have been called in `on_advance_end` self.trainer._logger_connector.epoch_end_reached() @@ -508,7 +509,7 @@ def _verify_dataloader_idx_requirement(self) -> None: ) @staticmethod - def _get_keys(data: dict) -> Iterable[Tuple[str, ...]]: + def _get_keys(data: dict) -> Iterable[tuple[str, ...]]: for k, v in data.items(): if isinstance(v, dict): for new_key in apply_to_collection(v, dict, _EvaluationLoop._get_keys): @@ -527,7 +528,7 @@ def _find_value(data: dict, target: Iterable[str]) -> Optional[Any]: return _EvaluationLoop._find_value(result, rest) @staticmethod - def _print_results(results: List[_OUT_DICT], stage: str) -> None: + def _print_results(results: list[_OUT_DICT], stage: str) -> None: # remove the dl idx suffix results = [{k.split("/dataloader_idx_")[0]: v for k, v in result.items()} for result in results] metrics_paths = {k for keys in apply_to_collection(results, dict, _EvaluationLoop._get_keys) for k in keys} @@ -544,7 +545,7 @@ def _print_results(results: List[_OUT_DICT], stage: str) -> None: term_size = 
shutil.get_terminal_size(fallback=(120, 30)).columns or 120 max_length = int(min(max(len(max(metrics_strs, key=len)), len(max(headers, key=len)), 25), term_size / 2)) - rows: List[List[Any]] = [[] for _ in metrics_paths] + rows: list[list[Any]] = [[] for _ in metrics_paths] for result in results: for metric, row in zip(metrics_paths, rows): diff --git a/src/lightning/pytorch/loops/fetchers.py b/src/lightning/pytorch/loops/fetchers.py index e699321a4d23e..92ec95a9e2f58 100644 --- a/src/lightning/pytorch/loops/fetchers.py +++ b/src/lightning/pytorch/loops/fetchers.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Iterator, List, Optional +from collections.abc import Iterator +from typing import Any, Optional from typing_extensions import override @@ -97,7 +98,7 @@ def __init__(self, prefetch_batches: int = 1) -> None: if prefetch_batches < 0: raise ValueError("`prefetch_batches` should at least be 0.") self.prefetch_batches = prefetch_batches - self.batches: List[Any] = [] + self.batches: list[Any] = [] @override def __iter__(self) -> "_PrefetchDataFetcher": diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py index e20088acd0af3..31d6724a043a3 100644 --- a/src/lightning/pytorch/loops/fit_loop.py +++ b/src/lightning/pytorch/loops/fit_loop.py @@ -13,7 +13,7 @@ # limitations under the License. import logging from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union import torch from typing_extensions import override @@ -104,7 +104,7 @@ def __init__( self._data_source = _DataLoaderSource(None, "train_dataloader") self._combined_loader: Optional[CombinedLoader] = None - self._combined_loader_states_to_load: List[Dict[str, Any]] = [] + self._combined_loader_states_to_load: list[dict[str, Any]] = [] self._data_fetcher: Optional[_DataFetcher] = None self._last_train_dl_reload_epoch = float("-inf") self._restart_stage = RestartStage.NONE @@ -504,14 +504,14 @@ def teardown(self) -> None: self.epoch_loop.teardown() @override - def on_save_checkpoint(self) -> Dict: + def on_save_checkpoint(self) -> dict: state_dict = super().on_save_checkpoint() if self._combined_loader is not None and (loader_states := self._combined_loader._state_dicts()): state_dict["combined_loader"] = loader_states return state_dict @override - def on_load_checkpoint(self, state_dict: Dict) -> None: + def on_load_checkpoint(self, state_dict: dict) -> None: self._combined_loader_states_to_load = state_dict.get("combined_loader", []) super().on_load_checkpoint(state_dict) diff --git a/src/lightning/pytorch/loops/loop.py b/src/lightning/pytorch/loops/loop.py index 111377a222b3f..daad309cd75d4 100644 --- a/src/lightning/pytorch/loops/loop.py +++ b/src/lightning/pytorch/loops/loop.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional +from typing import Optional import lightning.pytorch as pl from lightning.pytorch.loops.progress import _BaseProgress @@ -41,7 +41,7 @@ def restarting(self, restarting: bool) -> None: def reset_restart_stage(self) -> None: pass - def on_save_checkpoint(self) -> Dict: + def on_save_checkpoint(self) -> dict: """Called when saving a model checkpoint, use to persist loop state. 
Returns: @@ -50,10 +50,10 @@ def on_save_checkpoint(self) -> Dict: """ return {} - def on_load_checkpoint(self, state_dict: Dict) -> None: + def on_load_checkpoint(self, state_dict: dict) -> None: """Called when loading a model checkpoint, use to reload loop state.""" - def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Dict: + def state_dict(self, destination: Optional[dict] = None, prefix: str = "") -> dict: """The state dict is determined by the state and progress of this loop and all its children. Args: @@ -77,7 +77,7 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Di def load_state_dict( self, - state_dict: Dict, + state_dict: dict, prefix: str = "", ) -> None: """Loads the state of this loop and all its children.""" @@ -88,7 +88,7 @@ def load_state_dict( self.restarting = True self._loaded_from_state_dict = True - def _load_from_state_dict(self, state_dict: Dict, prefix: str) -> None: + def _load_from_state_dict(self, state_dict: dict, prefix: str) -> None: for k, v in self.__dict__.items(): key = prefix + k if key not in state_dict: diff --git a/src/lightning/pytorch/loops/optimization/automatic.py b/src/lightning/pytorch/loops/optimization/automatic.py index 2ce6acab11a37..e19b5761c4d4b 100644 --- a/src/lightning/pytorch/loops/optimization/automatic.py +++ b/src/lightning/pytorch/loops/optimization/automatic.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections import OrderedDict +from collections.abc import Mapping from dataclasses import dataclass, field from functools import partial -from typing import Any, Callable, Dict, Mapping, Optional, OrderedDict +from typing import Any, Callable, Optional import torch from torch import Tensor @@ -46,7 +48,7 @@ class ClosureResult(OutputResult): closure_loss: Optional[Tensor] loss: Optional[Tensor] = field(init=False, default=None) - extra: Dict[str, Any] = field(default_factory=dict) + extra: dict[str, Any] = field(default_factory=dict) def __post_init__(self) -> None: self._clone_loss() @@ -83,7 +85,7 @@ def from_training_step_output(cls, training_step_output: STEP_OUTPUT, normalize: return cls(closure_loss, extra=extra) @override - def asdict(self) -> Dict[str, Any]: + def asdict(self) -> dict[str, Any]: return {"loss": self.loss, **self.extra} @@ -145,7 +147,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]: return self._result.loss -_OUTPUTS_TYPE = Dict[str, Any] +_OUTPUTS_TYPE = dict[str, Any] class _AutomaticOptimization(_Loop): diff --git a/src/lightning/pytorch/loops/optimization/closure.py b/src/lightning/pytorch/loops/optimization/closure.py index 4b550166b721e..e45262a067f52 100644 --- a/src/lightning/pytorch/loops/optimization/closure.py +++ b/src/lightning/pytorch/loops/optimization/closure.py @@ -13,7 +13,7 @@ # limitations under the License. 
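The loop hunks pair the built-in generics with a second convention: abstract container types (`Mapping`, `Iterator`, `Iterable`, `Generator`, ...) are imported from `collections.abc`, which is likewise subscriptable from Python 3.9 on, and concrete classes such as `collections.defaultdict` replace `typing.DefaultDict`. A small sketch under those assumptions; the names below are illustrative, not Lightning APIs:

```python
from collections import defaultdict
from collections.abc import Iterator, Mapping


def scaled(metrics: Mapping[str, float], factor: float = 2.0) -> Iterator[tuple[str, float]]:
    # Lazily yield (name, scaled value) pairs; the ABCs come straight from collections.abc.
    for name, value in metrics.items():
        yield name, value * factor


# typing.DefaultDict is likewise replaced by subscripting the real class.
per_epoch: defaultdict[int, list[float]] = defaultdict(list)
per_epoch[0].append(0.5)
print(dict(scaled({"acc": 0.5})), dict(per_epoch))  # {'acc': 1.0} {0: [0.5]}
```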
from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Dict, Generic, Optional, TypeVar +from typing import Any, Generic, Optional, TypeVar from lightning.pytorch.utilities.exceptions import MisconfigurationException @@ -22,7 +22,7 @@ @dataclass class OutputResult: - def asdict(self) -> Dict[str, Any]: + def asdict(self) -> dict[str, Any]: raise NotImplementedError diff --git a/src/lightning/pytorch/loops/optimization/manual.py b/src/lightning/pytorch/loops/optimization/manual.py index d8a4f1968c3b8..e1aabcbf42976 100644 --- a/src/lightning/pytorch/loops/optimization/manual.py +++ b/src/lightning/pytorch/loops/optimization/manual.py @@ -14,7 +14,7 @@ from collections import OrderedDict from contextlib import suppress from dataclasses import dataclass, field -from typing import Any, Dict +from typing import Any from torch import Tensor from typing_extensions import override @@ -40,7 +40,7 @@ class ManualResult(OutputResult): """ - extra: Dict[str, Any] = field(default_factory=dict) + extra: dict[str, Any] = field(default_factory=dict) @classmethod def from_training_step_output(cls, training_step_output: STEP_OUTPUT) -> "ManualResult": @@ -61,11 +61,11 @@ def from_training_step_output(cls, training_step_output: STEP_OUTPUT) -> "Manual return cls(extra=extra) @override - def asdict(self) -> Dict[str, Any]: + def asdict(self) -> dict[str, Any]: return self.extra -_OUTPUTS_TYPE = Dict[str, Any] +_OUTPUTS_TYPE = dict[str, Any] class _ManualOptimization(_Loop): diff --git a/src/lightning/pytorch/loops/prediction_loop.py b/src/lightning/pytorch/loops/prediction_loop.py index 9002e6280ffc6..7044ccea87a7f 100644 --- a/src/lightning/pytorch/loops/prediction_loop.py +++ b/src/lightning/pytorch/loops/prediction_loop.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import OrderedDict -from typing import Any, Iterator, List, Optional, Union +from collections.abc import Iterator +from typing import Any, Optional, Union import torch from lightning_utilities import WarningCache @@ -50,17 +51,17 @@ def __init__(self, trainer: "pl.Trainer", inference_mode: bool = True) -> None: super().__init__(trainer) self.inference_mode = inference_mode # dataloaders x batches x samples. 
used by PredictionWriter - self.epoch_batch_indices: List[List[List[int]]] = [] - self.current_batch_indices: List[int] = [] # used by PredictionWriter + self.epoch_batch_indices: list[list[list[int]]] = [] + self.current_batch_indices: list[int] = [] # used by PredictionWriter self.batch_progress = _Progress() # across dataloaders - self.max_batches: List[Union[int, float]] = [] + self.max_batches: list[Union[int, float]] = [] self._warning_cache = WarningCache() self._data_source = _DataLoaderSource(None, "predict_dataloader") self._combined_loader: Optional[CombinedLoader] = None self._data_fetcher: Optional[_DataFetcher] = None self._results = None # for `trainer._results` access - self._predictions: List[List[Any]] = [] # dataloaders x batches + self._predictions: list[list[Any]] = [] # dataloaders x batches self._return_predictions = False self._module_mode = _ModuleMode() @@ -82,7 +83,7 @@ def return_predictions(self, return_predictions: Optional[bool] = None) -> None: self._return_predictions = return_supported if return_predictions is None else return_predictions @property - def predictions(self) -> List[Any]: + def predictions(self) -> list[Any]: """The cached predictions.""" if self._predictions == []: return self._predictions @@ -297,7 +298,7 @@ def _build_step_args_from_hook_kwargs(self, hook_kwargs: OrderedDict, step_hook_ kwargs.pop("batch_idx", None) return tuple(kwargs.values()) - def _get_batch_indices(self, dataloader: object) -> List[List[int]]: # batches x samples + def _get_batch_indices(self, dataloader: object) -> list[list[int]]: # batches x samples """Returns a reference to the seen batch indices if the dataloader has a batch sampler wrapped by our :class:`~lightning.pytorch.overrides.distributed._IndexBatchSamplerWrapper`.""" batch_sampler = getattr(dataloader, "batch_sampler", None) diff --git a/src/lightning/pytorch/loops/progress.py b/src/lightning/pytorch/loops/progress.py index 6880b24f70c65..42e5de642aa32 100644 --- a/src/lightning/pytorch/loops/progress.py +++ b/src/lightning/pytorch/loops/progress.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from dataclasses import asdict, dataclass, field -from typing import Type from typing_extensions import override @@ -174,7 +173,7 @@ def increment_completed(self) -> None: self.current.completed += 1 @classmethod - def from_defaults(cls, tracker_cls: Type[_ReadyCompletedTracker], **kwargs: int) -> "_Progress": + def from_defaults(cls, tracker_cls: type[_ReadyCompletedTracker], **kwargs: int) -> "_Progress": """Utility function to easily create an instance from keyword arguments to both ``Tracker``s.""" return cls(total=tracker_cls(**kwargs), current=tracker_cls(**kwargs)) diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py index 1c749de3a1b6d..7cdf7888bbfe2 100644 --- a/src/lightning/pytorch/loops/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/training_epoch_loop.py @@ -14,7 +14,7 @@ import math from collections import OrderedDict from dataclasses import dataclass -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from typing_extensions import override @@ -390,13 +390,13 @@ def teardown(self) -> None: self.val_loop.teardown() @override - def on_save_checkpoint(self) -> Dict: + def on_save_checkpoint(self) -> dict: state_dict = super().on_save_checkpoint() state_dict["_batches_that_stepped"] = self._batches_that_stepped return state_dict @override - def on_load_checkpoint(self, state_dict: Dict) -> None: + def on_load_checkpoint(self, state_dict: dict) -> None: self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0) def _accumulated_batches_reached(self) -> bool: diff --git a/src/lightning/pytorch/loops/utilities.py b/src/lightning/pytorch/loops/utilities.py index 99ea5c4254d62..2aaf877c8913d 100644 --- a/src/lightning/pytorch/loops/utilities.py +++ b/src/lightning/pytorch/loops/utilities.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect -from contextlib import contextmanager -from typing import Any, Callable, ContextManager, Generator, Optional, Tuple, Type +from collections.abc import Generator +from contextlib import AbstractContextManager, contextmanager +from typing import Any, Callable, Optional import torch import torch.distributed as dist @@ -52,7 +53,7 @@ def _parse_loop_limits( min_epochs: Optional[int], max_epochs: Optional[int], trainer: "pl.Trainer", -) -> Tuple[int, int]: +) -> tuple[int, int]: """This utility computes the default values for the minimum and maximum number of steps and epochs given the values the user has selected. @@ -159,7 +160,7 @@ def _decorator(self: _Loop, *args: Any, **kwargs: Any) -> Any: raise TypeError(f"`{type(self).__name__}` needs to be a Loop.") if not hasattr(self, "inference_mode"): raise TypeError(f"`{type(self).__name__}.inference_mode` needs to be defined") - context_manager: Type[ContextManager] + context_manager: type[AbstractContextManager] if _distributed_is_initialized() and dist.get_backend() == "gloo": # gloo backend does not work properly. 
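The `_Progress.from_defaults` hunk above shows the same idea applied to class objects: `typing.Type[X]` becomes the built-in `type[X]`, which accepts the class itself or any subclass. A sketch of the pattern; the `Tracker` classes here are hypothetical stand-ins, not Lightning's progress trackers:

```python
from dataclasses import dataclass


@dataclass
class Tracker:
    ready: int = 0
    completed: int = 0


@dataclass
class StartedTracker(Tracker):
    started: int = 0


def from_defaults(tracker_cls: type[Tracker], **kwargs: int) -> tuple[Tracker, Tracker]:
    # `type[Tracker]` means "Tracker or a subclass of it", mirroring the diff's annotation.
    return tracker_cls(**kwargs), tracker_cls(**kwargs)


total, current = from_defaults(StartedTracker, started=1)
print(total)  # StartedTracker(ready=0, completed=0, started=1)
```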
# https://github.com/Lightning-AI/lightning/pull/12715/files#r854569110 @@ -181,7 +182,7 @@ def _decorator(self: _Loop, *args: Any, **kwargs: Any) -> Any: def _verify_dataloader_idx_requirement( - hooks: Tuple[str, ...], is_expected: bool, stage: RunningStage, pl_module: "pl.LightningModule" + hooks: tuple[str, ...], is_expected: bool, stage: RunningStage, pl_module: "pl.LightningModule" ) -> None: for hook in hooks: fx = getattr(pl_module, hook) diff --git a/src/lightning/pytorch/overrides/distributed.py b/src/lightning/pytorch/overrides/distributed.py index e4b65285538f1..196008b7ed29f 100644 --- a/src/lightning/pytorch/overrides/distributed.py +++ b/src/lightning/pytorch/overrides/distributed.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import itertools -from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Sized, Union, cast +from collections.abc import Iterable, Iterator, Sized +from typing import Any, Callable, Optional, Union, cast import torch from torch import Tensor @@ -27,7 +28,7 @@ def _find_tensors( obj: Union[Tensor, list, tuple, dict, Any], -) -> Union[List[Tensor], itertools.chain]: # pragma: no-cover +) -> Union[list[Tensor], itertools.chain]: # pragma: no-cover """Recursively find all tensors contained in the specified object.""" if isinstance(obj, Tensor): return [obj] @@ -201,7 +202,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: assert self.num_samples >= 1 or self.total_size == 0 @override - def __iter__(self) -> Iterator[List[int]]: + def __iter__(self) -> Iterator[list[int]]: if not isinstance(self.dataset, Sized): raise TypeError("The given dataset must implement the `__len__` method.") if self.shuffle: @@ -238,7 +239,7 @@ class _IndexBatchSamplerWrapper: def __init__(self, batch_sampler: _SizedIterable) -> None: # do not call super().__init__() on purpose - self.seen_batch_indices: List[List[int]] = [] + self.seen_batch_indices: list[list[int]] = [] self.__dict__ = { k: v @@ -246,9 +247,9 @@ def __init__(self, batch_sampler: _SizedIterable) -> None: if k not in ("__next__", "__iter__", "__len__", "__getstate__") } self._batch_sampler = batch_sampler - self._iterator: Optional[Iterator[List[int]]] = None + self._iterator: Optional[Iterator[list[int]]] = None - def __next__(self) -> List[int]: + def __next__(self) -> list[int]: assert self._iterator is not None batch = next(self._iterator) self.seen_batch_indices.append(batch) @@ -262,7 +263,7 @@ def __iter__(self) -> Self: def __len__(self) -> int: return len(self._batch_sampler) - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: state = self.__dict__.copy() state["_iterator"] = None # cannot pickle 'generator' object return state diff --git a/src/lightning/pytorch/plugins/io/wrapper.py b/src/lightning/pytorch/plugins/io/wrapper.py index 6e918b836e320..548bc1fb15cac 100644 --- a/src/lightning/pytorch/plugins/io/wrapper.py +++ b/src/lightning/pytorch/plugins/io/wrapper.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
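In `overrides/distributed.py` above, `Sized` is not only an annotation but also the target of a runtime `isinstance` check, so importing it from `collections.abc`, where the ABC actually lives, is the natural home for it. A minimal sketch with a hypothetical helper:

```python
from collections.abc import Sized
from typing import Any, Optional


def safe_len(obj: Any) -> Optional[int]:
    # Only call len() on objects that actually implement __len__.
    return len(obj) if isinstance(obj, Sized) else None


print(safe_len([1, 2, 3]), safe_len(x for x in range(3)))  # 3 None
```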
-from typing import Any, Dict, Optional +from typing import Any, Optional from typing_extensions import override @@ -66,7 +66,7 @@ def remove_checkpoint(self, *args: Any, **kwargs: Any) -> None: self.checkpoint_io.remove_checkpoint(*args, **kwargs) @override - def load_checkpoint(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: + def load_checkpoint(self, *args: Any, **kwargs: Any) -> dict[str, Any]: """Uses the base ``checkpoint_io`` to load the checkpoint.""" assert self.checkpoint_io is not None return self.checkpoint_io.load_checkpoint(*args, **kwargs) diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py index e63ccd6912b63..75e792af46b90 100644 --- a/src/lightning/pytorch/plugins/precision/amp.py +++ b/src/lightning/pytorch/plugins/precision/amp.py @@ -9,8 +9,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Generator from contextlib import contextmanager -from typing import Any, Callable, Dict, Generator, Literal, Optional, Union +from typing import Any, Callable, Literal, Optional, Union import torch from torch import Tensor @@ -121,12 +122,12 @@ def forward_context(self) -> Generator[None, None, None]: yield @override - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: if self.scaler is not None: return self.scaler.state_dict() return {} @override - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: if self.scaler is not None: self.scaler.load_state_dict(state_dict) diff --git a/src/lightning/pytorch/plugins/precision/deepspeed.py b/src/lightning/pytorch/plugins/precision/deepspeed.py index e1e90281cf3af..9225e3bb9e7be 100644 --- a/src/lightning/pytorch/plugins/precision/deepspeed.py +++ b/src/lightning/pytorch/plugins/precision/deepspeed.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import nullcontext -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Optional, Union +from contextlib import AbstractContextManager, nullcontext +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch from lightning_utilities import apply_to_collection @@ -80,13 +80,13 @@ def convert_input(self, data: Any) -> Any: return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype) @override - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: if "true" not in self.precision: return nullcontext() return _DtypeContextManager(self._desired_dtype) @override - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: return self.tensor_init_context() @override diff --git a/src/lightning/pytorch/plugins/precision/double.py b/src/lightning/pytorch/plugins/precision/double.py index 20f493bb7b2e2..efa1aa008a35e 100644 --- a/src/lightning/pytorch/plugins/precision/double.py +++ b/src/lightning/pytorch/plugins/precision/double.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from contextlib import contextmanager -from typing import Any, ContextManager, Generator, Literal +from collections.abc import Generator +from contextlib import AbstractContextManager, contextmanager +from typing import Any, Literal import torch import torch.nn as nn @@ -37,11 +38,11 @@ def convert_module(self, module: nn.Module) -> nn.Module: return module.double() @override - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: return _DtypeContextManager(torch.float64) @override - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: return self.tensor_init_context() @override diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py index e6c684967ed40..7029497c177cc 100644 --- a/src/lightning/pytorch/plugins/precision/fsdp.py +++ b/src/lightning/pytorch/plugins/precision/fsdp.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Optional +from contextlib import AbstractContextManager +from typing import TYPE_CHECKING, Any, Callable, Optional import torch from lightning_utilities import apply_to_collection @@ -109,15 +110,15 @@ def mixed_precision_config(self) -> "TorchMixedPrecision": ) @override - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: return _DtypeContextManager(self._desired_input_dtype) @override - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: return _DtypeContextManager(self.mixed_precision_config.param_dtype or torch.float32) @override - def forward_context(self) -> ContextManager: + def forward_context(self) -> AbstractContextManager: if "mixed" in self.precision: return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16)) return _DtypeContextManager(self._desired_input_dtype) @@ -166,12 +167,12 @@ def optimizer_step( # type: ignore[override] return closure_result @override - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: if self.scaler is not None: return self.scaler.state_dict() return {} @override - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: if self.scaler is not None: self.scaler.load_state_dict(state_dict) diff --git a/src/lightning/pytorch/plugins/precision/half.py b/src/lightning/pytorch/plugins/precision/half.py index 22dc29b580b95..fe9deb44c3653 100644 --- a/src/lightning/pytorch/plugins/precision/half.py +++ b/src/lightning/pytorch/plugins/precision/half.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
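The precision-plugin hunks swap `typing.ContextManager` for `contextlib.AbstractContextManager`, the class the typing alias refers to, which can be used (and subscripted) directly on Python 3.9+. A hedged sketch of the return-annotation pattern; `_FloatDtypeContext` below is an illustrative stand-in, not Lightning's `_DtypeContextManager`:

```python
from contextlib import AbstractContextManager, nullcontext

import torch


class _FloatDtypeContext(AbstractContextManager):
    """Temporarily switch the default torch dtype (illustrative stand-in)."""

    def __init__(self, dtype: torch.dtype) -> None:
        self._new = dtype
        self._old = torch.get_default_dtype()

    def __enter__(self) -> None:
        torch.set_default_dtype(self._new)

    def __exit__(self, *exc_info: object) -> None:
        torch.set_default_dtype(self._old)


def module_init_context(precision: str) -> AbstractContextManager:
    # Same shape as the plugin methods: return a context manager, possibly a no-op.
    return _FloatDtypeContext(torch.float64) if "64" in precision else nullcontext()


with module_init_context("64-true"):
    print(torch.empty(1).dtype)  # torch.float64
```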
-from contextlib import contextmanager -from typing import Any, ContextManager, Generator, Literal +from collections.abc import Generator +from contextlib import AbstractContextManager, contextmanager +from typing import Any, Literal import torch from lightning_utilities import apply_to_collection @@ -43,11 +44,11 @@ def convert_module(self, module: Module) -> Module: return module.to(dtype=self._desired_input_dtype) @override - def tensor_init_context(self) -> ContextManager: + def tensor_init_context(self) -> AbstractContextManager: return _DtypeContextManager(self._desired_input_dtype) @override - def module_init_context(self) -> ContextManager: + def module_init_context(self) -> AbstractContextManager: return self.tensor_init_context() @override diff --git a/src/lightning/pytorch/plugins/precision/precision.py b/src/lightning/pytorch/plugins/precision/precision.py index 51bdddb18f814..327fb2d4f5a27 100644 --- a/src/lightning/pytorch/plugins/precision/precision.py +++ b/src/lightning/pytorch/plugins/precision/precision.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import contextlib +from collections.abc import Generator from functools import partial -from typing import Any, Callable, Generator, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union import torch from torch import Tensor @@ -37,8 +38,8 @@ class Precision(FabricPrecision, CheckpointHooks): """ def connect( - self, model: Module, optimizers: List[Optimizer], lr_schedulers: List[Any] - ) -> Tuple[Module, List[Optimizer], List[Any]]: + self, model: Module, optimizers: list[Optimizer], lr_schedulers: list[Any] + ) -> tuple[Module, list[Optimizer], list[Any]]: """Connects this plugin to the accelerator and the training process.""" return model, optimizers, lr_schedulers diff --git a/src/lightning/pytorch/profilers/advanced.py b/src/lightning/pytorch/profilers/advanced.py index 467b47124eb60..41681fbd239f3 100644 --- a/src/lightning/pytorch/profilers/advanced.py +++ b/src/lightning/pytorch/profilers/advanced.py @@ -20,7 +20,7 @@ import pstats import tempfile from pathlib import Path -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Union from typing_extensions import override @@ -66,7 +66,7 @@ def __init__( If you attempt to stop recording an action which was never started. 
""" super().__init__(dirpath=dirpath, filename=filename) - self.profiled_actions: Dict[str, cProfile.Profile] = {} + self.profiled_actions: dict[str, cProfile.Profile] = {} self.line_count_restriction = line_count_restriction self.dump_stats = dump_stats @@ -89,9 +89,10 @@ def _dump_stats(self, action_name: str, profile: cProfile.Profile) -> None: dst_fs = get_filesystem(dst_filepath) dst_fs.mkdirs(self.dirpath, exist_ok=True) # temporarily save to local since pstats can only dump into a local file - with tempfile.TemporaryDirectory( - prefix="test", suffix=str(rank_zero_only.rank), dir=os.getcwd() - ) as tmp_dir, dst_fs.open(dst_filepath, "wb") as dst_file: + with ( + tempfile.TemporaryDirectory(prefix="test", suffix=str(rank_zero_only.rank), dir=os.getcwd()) as tmp_dir, + dst_fs.open(dst_filepath, "wb") as dst_file, + ): src_filepath = os.path.join(tmp_dir, "tmp.prof") profile.dump_stats(src_filepath) src_fs = get_filesystem(src_filepath) @@ -115,7 +116,7 @@ def teardown(self, stage: Optional[str]) -> None: super().teardown(stage=stage) self.profiled_actions = {} - def __reduce__(self) -> Tuple: + def __reduce__(self) -> tuple: # avoids `TypeError: cannot pickle 'cProfile.Profile' object` return ( self.__class__, diff --git a/src/lightning/pytorch/profilers/profiler.py b/src/lightning/pytorch/profilers/profiler.py index fb448321575a9..a09b703e606b8 100644 --- a/src/lightning/pytorch/profilers/profiler.py +++ b/src/lightning/pytorch/profilers/profiler.py @@ -16,9 +16,10 @@ import logging import os from abc import ABC, abstractmethod +from collections.abc import Generator from contextlib import contextmanager from pathlib import Path -from typing import Any, Callable, Dict, Generator, Optional, TextIO, Union +from typing import Any, Callable, Optional, TextIO, Union from lightning.fabric.utilities.cloud_io import get_filesystem @@ -115,7 +116,7 @@ def describe(self) -> None: self._output_file.flush() self.teardown(stage=self._stage) - def _stats_to_str(self, stats: Dict[str, str]) -> str: + def _stats_to_str(self, stats: dict[str, str]) -> str: stage = f"{self._stage.upper()} " if self._stage is not None else "" output = [stage + "Profiler Report"] for action, value in stats.items(): diff --git a/src/lightning/pytorch/profilers/pytorch.py b/src/lightning/pytorch/profilers/pytorch.py index a26b3d321d2e0..e264d5154feba 100644 --- a/src/lightning/pytorch/profilers/pytorch.py +++ b/src/lightning/pytorch/profilers/pytorch.py @@ -16,9 +16,10 @@ import inspect import logging import os +from contextlib import AbstractContextManager from functools import lru_cache, partial from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, List, Optional, Type, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch from torch import Tensor, nn @@ -65,8 +66,8 @@ class RegisterRecordFunction: def __init__(self, model: nn.Module) -> None: self._model = model - self._records: Dict[str, record_function] = {} - self._handles: Dict[str, List[RemovableHandle]] = {} + self._records: dict[str, record_function] = {} + self._handles: dict[str, list[RemovableHandle]] = {} def _start_recording_forward(self, _: nn.Module, input: Tensor, record_name: str) -> Tensor: # Add [pl][module] in name for pytorch profiler to recognize @@ -239,7 +240,7 @@ def __init__( row_limit: int = 20, sort_by_key: Optional[str] = None, record_module_names: bool = True, - table_kwargs: Optional[Dict[str, Any]] = None, + table_kwargs: Optional[dict[str, Any]] = None, 
**profiler_kwargs: Any, ) -> None: r"""This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of @@ -305,8 +306,8 @@ def __init__( self.function_events: Optional[EventList] = None self._lightning_module: Optional[LightningModule] = None # set by ProfilerConnector self._register: Optional[RegisterRecordFunction] = None - self._parent_profiler: Optional[ContextManager] = None - self._recording_map: Dict[str, record_function] = {} + self._parent_profiler: Optional[AbstractContextManager] = None + self._recording_map: dict[str, record_function] = {} self._start_action_name: Optional[str] = None self._schedule: Optional[ScheduleWrapper] = None @@ -400,8 +401,8 @@ def _default_schedule() -> Optional[Callable]: return torch.profiler.schedule(wait=1, warmup=1, active=3) return None - def _default_activities(self) -> List["ProfilerActivity"]: - activities: List[ProfilerActivity] = [] + def _default_activities(self) -> list["ProfilerActivity"]: + activities: list[ProfilerActivity] = [] if not _KINETO_AVAILABLE: return activities if _TORCH_GREATER_EQUAL_2_4: @@ -530,7 +531,7 @@ def _create_profilers(self) -> None: torch.profiler.profile if _KINETO_AVAILABLE else torch.autograd.profiler.profile ) - def _create_profiler(self, profiler: Type[_PROFILER]) -> _PROFILER: + def _create_profiler(self, profiler: type[_PROFILER]) -> _PROFILER: init_parameters = inspect.signature(profiler.__init__).parameters kwargs = {k: v for k, v in self._profiler_kwargs.items() if k in init_parameters} return profiler(**kwargs) diff --git a/src/lightning/pytorch/profilers/simple.py b/src/lightning/pytorch/profilers/simple.py index eef7b12892faa..8a53965e3f487 100644 --- a/src/lightning/pytorch/profilers/simple.py +++ b/src/lightning/pytorch/profilers/simple.py @@ -18,7 +18,7 @@ import time from collections import defaultdict from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Union import torch from typing_extensions import override @@ -27,10 +27,10 @@ log = logging.getLogger(__name__) -_TABLE_ROW_EXTENDED = Tuple[str, float, int, float, float] -_TABLE_DATA_EXTENDED = List[_TABLE_ROW_EXTENDED] -_TABLE_ROW = Tuple[str, float, float] -_TABLE_DATA = List[_TABLE_ROW] +_TABLE_ROW_EXTENDED = tuple[str, float, int, float, float] +_TABLE_DATA_EXTENDED = list[_TABLE_ROW_EXTENDED] +_TABLE_ROW = tuple[str, float, float] +_TABLE_DATA = list[_TABLE_ROW] class SimpleProfiler(Profiler): @@ -61,8 +61,8 @@ def __init__( if you attempt to stop recording an action which was never started. 
""" super().__init__(dirpath=dirpath, filename=filename) - self.current_actions: Dict[str, float] = {} - self.recorded_durations: Dict = defaultdict(list) + self.current_actions: dict[str, float] = {} + self.recorded_durations: dict = defaultdict(list) self.extended = extended self.start_time = time.monotonic() @@ -81,7 +81,7 @@ def stop(self, action_name: str) -> None: duration = end_time - start_time self.recorded_durations[action_name].append(duration) - def _make_report_extended(self) -> Tuple[_TABLE_DATA_EXTENDED, float, float]: + def _make_report_extended(self) -> tuple[_TABLE_DATA_EXTENDED, float, float]: total_duration = time.monotonic() - self.start_time report = [] diff --git a/src/lightning/pytorch/profilers/xla.py b/src/lightning/pytorch/profilers/xla.py index a85f3a1295e78..3e810fbc4f096 100644 --- a/src/lightning/pytorch/profilers/xla.py +++ b/src/lightning/pytorch/profilers/xla.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Dict from typing_extensions import override @@ -45,8 +44,8 @@ def __init__(self, port: int = 9012) -> None: raise ModuleNotFoundError(str(_XLA_AVAILABLE)) super().__init__(dirpath=None, filename=None) self.port = port - self._recording_map: Dict = {} - self._step_recoding_map: Dict = {} + self._recording_map: dict = {} + self._step_recoding_map: dict = {} self._start_trace: bool = False @override diff --git a/src/lightning/pytorch/serve/servable_module.py b/src/lightning/pytorch/serve/servable_module.py index f715f4b3cad9d..ed7a8a987898b 100644 --- a/src/lightning/pytorch/serve/servable_module.py +++ b/src/lightning/pytorch/serve/servable_module.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Tuple +from typing import Any, Callable import torch from torch import Tensor @@ -56,11 +56,11 @@ def configure_response(self): """ @abstractmethod - def configure_payload(self) -> Dict[str, Any]: + def configure_payload(self) -> dict[str, Any]: """Returns a request payload as a dictionary.""" @abstractmethod - def configure_serialization(self) -> Tuple[Dict[str, Callable], Dict[str, Callable]]: + def configure_serialization(self) -> tuple[dict[str, Callable], dict[str, Callable]]: """Returns a tuple of dictionaries. The first dictionary contains the name of the ``serve_step`` input variables name as its keys @@ -72,7 +72,7 @@ def configure_serialization(self) -> Tuple[Dict[str, Callable], Dict[str, Callab """ @abstractmethod - def serve_step(self, *args: Tensor, **kwargs: Tensor) -> Dict[str, Tensor]: + def serve_step(self, *args: Tensor, **kwargs: Tensor) -> dict[str, Tensor]: r"""Returns the predictions of your model as a dictionary. .. 
code-block:: python @@ -90,5 +90,5 @@ def serve_step(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: """ @abstractmethod - def configure_response(self) -> Dict[str, Any]: + def configure_response(self) -> dict[str, Any]: """Returns a response to validate the server response.""" diff --git a/src/lightning/pytorch/serve/servable_module_validator.py b/src/lightning/pytorch/serve/servable_module_validator.py index 0acab203dedff..dc92625da357d 100644 --- a/src/lightning/pytorch/serve/servable_module_validator.py +++ b/src/lightning/pytorch/serve/servable_module_validator.py @@ -2,7 +2,7 @@ import logging import time from multiprocessing import Process -from typing import Any, Dict, Literal, Optional +from typing import Any, Literal, Optional import requests import torch @@ -136,7 +136,7 @@ def successful(self) -> Optional[bool]: return self.resp.status_code == 200 if self.resp else None @override - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return {"successful": self.successful, "optimization": self.optimization, "server": self.server} @staticmethod @@ -157,7 +157,7 @@ def ping() -> bool: return True @app.post("/serve") - async def serve(payload: dict = Body(...)) -> Dict[str, Any]: + async def serve(payload: dict = Body(...)) -> dict[str, Any]: body = payload["body"] for key, deserializer in deserializers.items(): diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py index 9031b6ee177f3..fd3f66ef42471 100644 --- a/src/lightning/pytorch/strategies/ddp.py +++ b/src/lightning/pytorch/strategies/ddp.py @@ -14,7 +14,7 @@ import logging from contextlib import nullcontext from datetime import timedelta -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union import torch import torch.distributed @@ -71,7 +71,7 @@ class DDPStrategy(ParallelStrategy): def __init__( self, accelerator: Optional["pl.accelerators.Accelerator"] = None, - parallel_devices: Optional[List[torch.device]] = None, + parallel_devices: Optional[list[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[Precision] = None, @@ -133,7 +133,7 @@ def num_processes(self) -> int: @property @override - def distributed_sampler_kwargs(self) -> Dict[str, Any]: + def distributed_sampler_kwargs(self) -> dict[str, Any]: return {"num_replicas": (self.num_nodes * self.num_processes), "rank": self.global_rank} @property @@ -283,7 +283,7 @@ def configure_ddp(self) -> None: self.model = self._setup_model(self.model) self._register_ddp_hooks() - def determine_ddp_device_ids(self) -> Optional[List[int]]: + def determine_ddp_device_ids(self) -> Optional[list[int]]: if self.root_device.type == "cpu": return None return [self.root_device.index] diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index 1eaa5bab75fbe..4fa771114768d 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -17,9 +17,10 @@ import os import platform from collections import OrderedDict +from collections.abc import Generator, Mapping from contextlib import contextmanager from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Generator, List, Mapping, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Optional, Union import torch from torch.nn import Module @@ 
-102,9 +103,9 @@ def __init__( reduce_bucket_size: int = 200_000_000, zero_allow_untested_optimizer: bool = True, logging_batch_size_per_gpu: Union[str, int] = "auto", - config: Optional[Union[_PATH, Dict[str, Any]]] = None, + config: Optional[Union[_PATH, dict[str, Any]]] = None, logging_level: int = logging.WARN, - parallel_devices: Optional[List[torch.device]] = None, + parallel_devices: Optional[list[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, loss_scale: float = 0, initial_scale_power: int = 16, @@ -380,8 +381,8 @@ def restore_checkpoint_after_setup(self) -> bool: @override def _setup_model_and_optimizers( - self, model: Module, optimizers: List[Optimizer] - ) -> Tuple["deepspeed.DeepSpeedEngine", List[Optimizer]]: + self, model: Module, optimizers: list[Optimizer] + ) -> tuple["deepspeed.DeepSpeedEngine", list[Optimizer]]: """Setup a model and multiple optimizers together. Currently only a single optimizer is supported. @@ -411,7 +412,7 @@ def _setup_model_and_optimizer( model: Module, optimizer: Optional[Optimizer], lr_scheduler: Optional[Union[LRScheduler, ReduceLROnPlateau]] = None, - ) -> Tuple["deepspeed.DeepSpeedEngine", Optimizer]: + ) -> tuple["deepspeed.DeepSpeedEngine", Optimizer]: """Initialize one model and one optimizer with an optional learning rate scheduler. This calls ``deepspeed.initialize`` internally. @@ -452,7 +453,7 @@ def init_deepspeed(self) -> None: else: self._initialize_deepspeed_inference(self.model) - def _init_optimizers(self) -> Tuple[Optimizer, Optional[LRSchedulerConfig]]: + def _init_optimizers(self) -> tuple[Optimizer, Optional[LRSchedulerConfig]]: assert self.lightning_module is not None optimizers, lr_schedulers = _init_optimizers_and_lr_schedulers(self.lightning_module) if len(optimizers) > 1 or len(lr_schedulers) > 1: @@ -572,7 +573,7 @@ def _initialize_deepspeed_inference(self, model: Module) -> None: @property @override - def distributed_sampler_kwargs(self) -> Dict[str, int]: + def distributed_sampler_kwargs(self) -> dict[str, int]: return {"num_replicas": self.world_size, "rank": self.global_rank} @override @@ -608,7 +609,7 @@ def _multi_device(self) -> bool: return self.num_processes > 1 or self.num_nodes > 1 @override - def save_checkpoint(self, checkpoint: Dict, filepath: _PATH, storage_options: Optional[Any] = None) -> None: + def save_checkpoint(self, checkpoint: dict, filepath: _PATH, storage_options: Optional[Any] = None) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. 
Args: @@ -645,7 +646,7 @@ def save_checkpoint(self, checkpoint: Dict, filepath: _PATH, storage_options: Op self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint, tag="checkpoint") @override - def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: + def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: if self.load_full_weights and self.zero_stage_3: # Broadcast to ensure we load from the rank 0 checkpoint # This doesn't have to be the case when using deepspeed sharded checkpointing @@ -708,9 +709,9 @@ def _restore_zero_state(self, ckpt: Mapping[str, Any], strict: bool) -> None: assert self.lightning_module is not None def load(module: torch.nn.Module, prefix: str = "") -> None: - missing_keys: List[str] = [] - unexpected_keys: List[str] = [] - error_msgs: List[str] = [] + missing_keys: list[str] = [] + unexpected_keys: list[str] = [] + error_msgs: list[str] = [] state_dict = ckpt["state_dict"] # copy state_dict so _load_from_state_dict can modify it @@ -780,7 +781,7 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None: offload_optimizer_device="nvme", ) - def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Optional[Dict[str, Any]]: + def _load_config(self, config: Optional[Union[_PATH, dict[str, Any]]]) -> Optional[dict[str, Any]]: if config is None and self.DEEPSPEED_ENV_VAR in os.environ: rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable") config = os.environ[self.DEEPSPEED_ENV_VAR] @@ -841,7 +842,7 @@ def _create_default_config( overlap_events: bool, thread_count: int, **zero_kwargs: Any, - ) -> Dict: + ) -> dict: cfg = { "activation_checkpointing": { "partition_activations": partition_activations, diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index ab6e579c3071f..bfbf99e82934c 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -13,6 +13,7 @@ # limitations under the License. 
import logging import shutil +from collections.abc import Generator, Mapping from contextlib import contextmanager, nullcontext from datetime import timedelta from pathlib import Path @@ -20,15 +21,8 @@ TYPE_CHECKING, Any, Callable, - Dict, - Generator, - List, Literal, - Mapping, Optional, - Set, - Tuple, - Type, Union, ) @@ -88,7 +82,7 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision, ShardingStrategy from torch.distributed.fsdp.wrap import ModuleWrapPolicy - _POLICY = Union[Set[Type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy] + _POLICY = Union[set[type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy] _SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]] @@ -148,12 +142,12 @@ class FSDPStrategy(ParallelStrategy): """ strategy_name = "fsdp" - _registered_strategies: List[str] = [] + _registered_strategies: list[str] = [] def __init__( self, accelerator: Optional["pl.accelerators.Accelerator"] = None, - parallel_devices: Optional[List[torch.device]] = None, + parallel_devices: Optional[list[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[Precision] = None, @@ -162,11 +156,11 @@ def __init__( cpu_offload: Union[bool, "CPUOffload", None] = None, mixed_precision: Optional["MixedPrecision"] = None, auto_wrap_policy: Optional["_POLICY"] = None, - activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None, + activation_checkpointing: Optional[Union[type[Module], list[type[Module]]]] = None, activation_checkpointing_policy: Optional["_POLICY"] = None, sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD", state_dict_type: Literal["full", "sharded"] = "full", - device_mesh: Optional[Union[Tuple[int], "DeviceMesh"]] = None, + device_mesh: Optional[Union[tuple[int], "DeviceMesh"]] = None, **kwargs: Any, ) -> None: super().__init__( @@ -242,7 +236,7 @@ def precision_plugin(self, precision_plugin: Optional[FSDPPrecision]) -> None: @property @override - def distributed_sampler_kwargs(self) -> Dict: + def distributed_sampler_kwargs(self) -> dict: return {"num_replicas": (self.num_nodes * self.num_processes), "rank": self.global_rank} @property @@ -455,7 +449,7 @@ def reduce( return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op) return tensor - def _determine_device_ids(self) -> List[int]: + def _determine_device_ids(self) -> list[int]: return [self.root_device.index] @override @@ -481,7 +475,7 @@ def teardown(self) -> None: self.accelerator.teardown() @classmethod - def get_registered_strategies(cls) -> List[str]: + def get_registered_strategies(cls) -> list[str]: return cls._registered_strategies @classmethod @@ -505,7 +499,7 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None: cls._registered_strategies.append("fsdp_cpu_offload") @override - def lightning_module_state_dict(self) -> Dict[str, Any]: + def lightning_module_state_dict(self) -> dict[str, Any]: assert self.model is not None if self._state_dict_type == "sharded": state_dict_ctx = _get_sharded_state_dict_context(self.model) @@ -522,7 +516,7 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = Tr pass @override - def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: + def optimizer_state(self, optimizer: Optimizer) -> dict[str, Tensor]: from torch.distributed.fsdp import FullyShardedDataParallel 
as FSDP from torch.distributed.fsdp import OptimStateKeyType @@ -551,7 +545,7 @@ def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: @override def save_checkpoint( - self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None + self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None ) -> None: if storage_options is not None: raise TypeError( @@ -586,7 +580,7 @@ def save_checkpoint( raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}") @override - def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: + def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: # broadcast the path from rank 0 to ensure all the states are loaded from a common path path = Path(self.broadcast(checkpoint_path)) diff --git a/src/lightning/pytorch/strategies/launchers/multiprocessing.py b/src/lightning/pytorch/strategies/launchers/multiprocessing.py index 05e3fed561ccb..aa207a527814e 100644 --- a/src/lightning/pytorch/strategies/launchers/multiprocessing.py +++ b/src/lightning/pytorch/strategies/launchers/multiprocessing.py @@ -18,7 +18,7 @@ import tempfile from contextlib import suppress from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Literal, NamedTuple, Optional, Union +from typing import Any, Callable, Literal, NamedTuple, Optional, Union import torch import torch.backends.cudnn @@ -80,7 +80,7 @@ def __init__( f"The start method '{self._start_method}' is not available on this platform. Available methods are:" f" {', '.join(mp.get_all_start_methods())}" ) - self.procs: List[mp.Process] = [] + self.procs: list[mp.Process] = [] self._already_fit = False @property @@ -224,7 +224,7 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt return _WorkerOutput(best_model_path, weights_path, trainer.state, results, extra) - def get_extra_results(self, trainer: "pl.Trainer") -> Dict[str, Any]: + def get_extra_results(self, trainer: "pl.Trainer") -> dict[str, Any]: """Gather extra state from the Trainer and return it as a dictionary for sending back to the main process. To avoid issues with memory sharing, we convert tensors to bytes. @@ -242,7 +242,7 @@ def get_extra_results(self, trainer: "pl.Trainer") -> Dict[str, Any]: # send tensors as bytes to avoid issues with memory sharing return {"callback_metrics_bytes": buffer.getvalue()} - def update_main_process_results(self, trainer: "pl.Trainer", extra: Dict[str, Any]) -> None: + def update_main_process_results(self, trainer: "pl.Trainer", extra: dict[str, Any]) -> None: """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, we convert bytes back to ``torch.Tensor``. 
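For context on the ``get_extra_results`` / ``update_main_process_results`` pair touched above: the tensors in ``trainer.callback_metrics`` are shipped back to the main process as plain bytes to avoid tensor memory-sharing issues. The following is a minimal sketch of that round trip; the helper names are illustrative and not part of the launcher API.

.. code-block:: python

    import io

    import torch


    def _metrics_to_bytes(callback_metrics: dict[str, torch.Tensor]) -> bytes:
        # Serialize the tensors into a plain byte string so the payload can be
        # returned to the main process without relying on shared memory.
        buffer = io.BytesIO()
        torch.save(callback_metrics, buffer)
        return buffer.getvalue()


    def _bytes_to_metrics(payload: bytes) -> dict[str, torch.Tensor]:
        # Rebuild the tensors on the receiving side.
        return torch.load(io.BytesIO(payload))


    metrics = {"val_loss": torch.tensor(0.42)}
    restored = _bytes_to_metrics(_metrics_to_bytes(metrics))
    assert torch.equal(restored["val_loss"], metrics["val_loss"])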
@@ -265,7 +265,7 @@ def kill(self, signum: _SIGNUM) -> None: with suppress(ProcessLookupError): os.kill(proc.pid, signum) - def __getstate__(self) -> Dict: + def __getstate__(self) -> dict: state = self.__dict__.copy() state["procs"] = [] # SpawnProcess can't be pickled return state @@ -276,7 +276,7 @@ class _WorkerOutput(NamedTuple): weights_path: Optional[_PATH] trainer_state: TrainerState trainer_results: Any - extra: Dict[str, Any] + extra: dict[str, Any] @dataclass @@ -301,7 +301,7 @@ class _GlobalStateSnapshot: use_deterministic_algorithms: bool use_deterministic_algorithms_warn_only: bool cudnn_benchmark: bool - rng_states: Dict[str, Any] + rng_states: dict[str, Any] @classmethod def capture(cls) -> "_GlobalStateSnapshot": diff --git a/src/lightning/pytorch/strategies/launchers/subprocess_script.py b/src/lightning/pytorch/strategies/launchers/subprocess_script.py index d2035d03d2589..b7ec294c148d5 100644 --- a/src/lightning/pytorch/strategies/launchers/subprocess_script.py +++ b/src/lightning/pytorch/strategies/launchers/subprocess_script.py @@ -14,7 +14,7 @@ import logging import os import subprocess -from typing import Any, Callable, List, Optional +from typing import Any, Callable, Optional from lightning_utilities.core.imports import RequirementCache from typing_extensions import override @@ -77,7 +77,7 @@ def __init__(self, cluster_environment: ClusterEnvironment, num_processes: int, self.cluster_environment = cluster_environment self.num_processes = num_processes self.num_nodes = num_nodes - self.procs: List[subprocess.Popen] = [] # launched child subprocesses, does not include the launcher + self.procs: list[subprocess.Popen] = [] # launched child subprocesses, does not include the launcher @property @override diff --git a/src/lightning/pytorch/strategies/model_parallel.py b/src/lightning/pytorch/strategies/model_parallel.py index fb45166378c78..82fec205af731 100644 --- a/src/lightning/pytorch/strategies/model_parallel.py +++ b/src/lightning/pytorch/strategies/model_parallel.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import shutil +from collections.abc import Generator, Mapping from contextlib import contextmanager, nullcontext from datetime import timedelta from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Mapping, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, Optional, Union import torch from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only @@ -114,7 +115,7 @@ def num_processes(self) -> int: @property @override - def distributed_sampler_kwargs(self) -> Dict[str, Any]: + def distributed_sampler_kwargs(self) -> dict[str, Any]: assert self.device_mesh is not None data_parallel_mesh = self.device_mesh["data_parallel"] return {"num_replicas": data_parallel_mesh.size(), "rank": data_parallel_mesh.get_local_rank()} @@ -237,7 +238,7 @@ def reduce( return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op) return tensor - def _determine_device_ids(self) -> List[int]: + def _determine_device_ids(self) -> list[int]: return [self.root_device.index] @override @@ -249,7 +250,7 @@ def teardown(self) -> None: self.accelerator.teardown() @override - def lightning_module_state_dict(self) -> Dict[str, Any]: + def lightning_module_state_dict(self) -> dict[str, Any]: """Collects the state dict of the model. Only returns a non-empty state dict on rank 0 if ``save_distributed_checkpoint=False``. 
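The launcher's ``__getstate__`` override above follows the usual pickling pattern for objects that hold process handles: copy ``__dict__`` and blank out anything that cannot be serialized. A self-contained sketch of the same idea, with a hypothetical class standing in for the launcher:

.. code-block:: python

    import pickle
    import threading


    class _ToyLauncher:
        """Hypothetical stand-in mixing picklable and unpicklable attributes."""

        def __init__(self) -> None:
            self.num_processes = 4
            # Thread handles (like SpawnProcess objects) cannot be pickled.
            self.procs: list[threading.Thread] = [threading.Thread()]

        def __getstate__(self) -> dict:
            state = self.__dict__.copy()
            state["procs"] = []  # drop the unpicklable handles before serializing
            return state


    clone = pickle.loads(pickle.dumps(_ToyLauncher()))
    assert clone.num_processes == 4 and clone.procs == []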
@@ -267,7 +268,7 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = Tr pass @override - def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Any]: + def optimizer_state(self, optimizer: Optimizer) -> dict[str, Any]: """Collects the state of the given optimizer. Only returns a non-empty state dict on rank 0 if ``save_distributed_checkpoint=False``. @@ -296,7 +297,7 @@ def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: @override def save_checkpoint( - self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None + self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None ) -> None: if storage_options is not None: raise TypeError( @@ -328,7 +329,7 @@ def save_checkpoint( return super().save_checkpoint(checkpoint=checkpoint, filepath=path) @override - def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: + def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: # broadcast the path from rank 0 to ensure all the states are loaded from a common path path = Path(self.broadcast(checkpoint_path)) state = { diff --git a/src/lightning/pytorch/strategies/parallel.py b/src/lightning/pytorch/strategies/parallel.py index 5658438cd3f53..285d40706a5a9 100644 --- a/src/lightning/pytorch/strategies/parallel.py +++ b/src/lightning/pytorch/strategies/parallel.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod +from collections.abc import Generator from contextlib import contextmanager -from typing import Any, Dict, Generator, List, Optional +from typing import Any, Optional import torch from torch import Tensor @@ -33,7 +34,7 @@ class ParallelStrategy(Strategy, ABC): def __init__( self, accelerator: Optional["pl.accelerators.Accelerator"] = None, - parallel_devices: Optional[List[torch.device]] = None, + parallel_devices: Optional[list[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[Precision] = None, @@ -71,15 +72,15 @@ def is_global_zero(self) -> bool: return self.global_rank == 0 @property - def parallel_devices(self) -> Optional[List[torch.device]]: + def parallel_devices(self) -> Optional[list[torch.device]]: return self._parallel_devices @parallel_devices.setter - def parallel_devices(self, parallel_devices: Optional[List[torch.device]]) -> None: + def parallel_devices(self, parallel_devices: Optional[list[torch.device]]) -> None: self._parallel_devices = parallel_devices @property - def distributed_sampler_kwargs(self) -> Dict[str, Any]: + def distributed_sampler_kwargs(self) -> dict[str, Any]: return { "num_replicas": len(self.parallel_devices) if self.parallel_devices is not None else 0, "rank": self.global_rank, diff --git a/src/lightning/pytorch/strategies/strategy.py b/src/lightning/pytorch/strategies/strategy.py index 314007f497f59..0a0f52e906dd5 100644 --- a/src/lightning/pytorch/strategies/strategy.py +++ b/src/lightning/pytorch/strategies/strategy.py @@ -13,8 +13,9 @@ # limitations under the License. 
import logging from abc import ABC, abstractmethod +from collections.abc import Generator, Mapping from contextlib import contextmanager -from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Optional, TypeVar, Union import torch from torch import Tensor @@ -61,9 +62,9 @@ def __init__( self._model: Optional[Module] = None self._launcher: Optional[_Launcher] = None self._forward_redirection: _ForwardRedirection = _ForwardRedirection() - self._optimizers: List[Optimizer] = [] - self._lightning_optimizers: List[LightningOptimizer] = [] - self.lr_scheduler_configs: List[LRSchedulerConfig] = [] + self._optimizers: list[Optimizer] = [] + self._lightning_optimizers: list[LightningOptimizer] = [] + self.lr_scheduler_configs: list[LRSchedulerConfig] = [] @property def launcher(self) -> Optional[_Launcher]: @@ -99,11 +100,11 @@ def precision_plugin(self, precision_plugin: Optional[Precision]) -> None: self._precision_plugin = precision_plugin @property - def optimizers(self) -> List[Optimizer]: + def optimizers(self) -> list[Optimizer]: return self._optimizers @optimizers.setter - def optimizers(self, optimizers: List[Optimizer]) -> None: + def optimizers(self, optimizers: list[Optimizer]) -> None: self._optimizers = optimizers self._lightning_optimizers = [LightningOptimizer._to_lightning_optimizer(opt, self) for opt in optimizers] @@ -170,7 +171,7 @@ def setup_precision_plugin(self) -> None: self.optimizers = optimizers self.lr_scheduler_configs = lr_scheduler_configs - def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: + def optimizer_state(self, optimizer: Optimizer) -> dict[str, Tensor]: """Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom strategies. @@ -237,7 +238,7 @@ def optimizer_step( assert isinstance(model, pl.LightningModule) return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs) - def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]: + def _setup_model_and_optimizers(self, model: Module, optimizers: list[Optimizer]) -> tuple[Module, list[Optimizer]]: """Setup a model and multiple optimizers together. The returned objects are expected to be in the same order they were passed in. The default implementation will @@ -362,7 +363,7 @@ def lightning_module(self) -> Optional["pl.LightningModule"]: """Returns the pure LightningModule without potential wrappers.""" return self._lightning_module - def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: + def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: torch.cuda.empty_cache() return self.checkpoint_io.load_checkpoint(checkpoint_path) @@ -470,13 +471,13 @@ def handles_gradient_accumulation(self) -> bool: """Whether the strategy handles gradient accumulation internally.""" return False - def lightning_module_state_dict(self) -> Dict[str, Any]: + def lightning_module_state_dict(self) -> dict[str, Any]: """Returns model state.""" assert self.lightning_module is not None return self.lightning_module.state_dict() def save_checkpoint( - self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None + self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None ) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. 
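The import hunks above follow the one mechanical rule applied across this patch: container annotations move from ``typing.Dict``/``List``/``Tuple`` to the builtin generics of PEP 585, and abstract collection types such as ``Mapping`` or ``Generator`` are imported from ``collections.abc`` instead of ``typing``. A small before/after sketch with hypothetical signatures:

.. code-block:: python

    from collections.abc import Mapping
    from typing import Any, Optional


    # before: def dump_state(extra: Optional[Dict[str, Any]]) -> Tuple[Dict[str, Any], int]:
    def dump_state(extra: Optional[dict[str, Any]]) -> tuple[dict[str, Any], int]:
        """Hypothetical helper: merge extra state into a fresh state dict."""
        state: dict[str, Any] = {"version": 1, **(extra or {})}
        return state, len(state)


    # before: def flatten(batch: Mapping[str, Any]) -> List[Any]:  # Mapping came from typing
    def flatten(batch: Mapping[str, Any]) -> list[Any]:
        """Hypothetical helper: flatten a mapping's values into a list."""
        return list(batch.values())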
@@ -587,13 +588,13 @@ def _reset_optimizers_and_schedulers(self) -> None: self._lightning_optimizers = [] self.lr_scheduler_configs = [] - def __getstate__(self) -> Dict: + def __getstate__(self) -> dict: # `LightningOptimizer` overrides `self.__class__` so they cannot be pickled state = dict(vars(self)) # copy state["_lightning_optimizers"] = [] return state - def __setstate__(self, state: Dict) -> None: + def __setstate__(self, state: dict) -> None: self.__dict__ = state self.optimizers = self.optimizers # re-create the `_lightning_optimizers` diff --git a/src/lightning/pytorch/strategies/xla.py b/src/lightning/pytorch/strategies/xla.py index 56aae90c56897..faffb30d6256f 100644 --- a/src/lightning/pytorch/strategies/xla.py +++ b/src/lightning/pytorch/strategies/xla.py @@ -13,7 +13,7 @@ # limitations under the License. import io import os -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import torch from torch import Tensor @@ -49,7 +49,7 @@ class XLAStrategy(DDPStrategy): def __init__( self, accelerator: Optional["pl.accelerators.Accelerator"] = None, - parallel_devices: Optional[List[torch.device]] = None, + parallel_devices: Optional[list[torch.device]] = None, checkpoint_io: Optional[Union[XLACheckpointIO, _WrappingCheckpointIO]] = None, precision_plugin: Optional[XLAPrecision] = None, debug: bool = False, @@ -172,7 +172,7 @@ def _setup_model(self, model: Module) -> Module: # type: ignore @property @override - def distributed_sampler_kwargs(self) -> Dict[str, int]: + def distributed_sampler_kwargs(self) -> dict[str, int]: return {"num_replicas": self.world_size, "rank": self.global_rank} @override @@ -295,7 +295,7 @@ def set_world_ranks(self) -> None: @override def save_checkpoint( - self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None + self, checkpoint: dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None ) -> None: import torch_xla.core.xla_model as xm diff --git a/src/lightning/pytorch/trainer/call.py b/src/lightning/pytorch/trainer/call.py index 4c3bc5ef41bdd..012d1a2152aa3 100644 --- a/src/lightning/pytorch/trainer/call.py +++ b/src/lightning/pytorch/trainer/call.py @@ -14,7 +14,7 @@ import logging import signal from copy import deepcopy -from typing import Any, Callable, Dict, Optional, Type, Union +from typing import Any, Callable, Optional, Union from packaging.version import Version @@ -115,7 +115,11 @@ def _call_configure_model(trainer: "pl.Trainer") -> None: # we don't normally check for this before calling the hook. 
it is done here to avoid instantiating the context # managers if is_overridden("configure_model", trainer.lightning_module): - with trainer.strategy.tensor_init_context(), trainer.strategy.model_sharded_context(), trainer.precision_plugin.module_init_context(): # noqa: E501 + with ( + trainer.strategy.tensor_init_context(), + trainer.strategy.model_sharded_context(), + trainer.precision_plugin.module_init_context(), + ): _call_lightning_module_hook(trainer, "configure_model") @@ -222,7 +226,7 @@ def _call_callback_hooks( pl_module._current_fx_name = prev_fx_name -def _call_callbacks_state_dict(trainer: "pl.Trainer") -> Dict[str, dict]: +def _call_callbacks_state_dict(trainer: "pl.Trainer") -> dict[str, dict]: """Called when saving a model checkpoint, calls and returns every callback's `state_dict`, keyed by `Callback.state_key`.""" callback_state_dicts = {} @@ -233,7 +237,7 @@ def _call_callbacks_state_dict(trainer: "pl.Trainer") -> Dict[str, dict]: return callback_state_dicts -def _call_callbacks_on_save_checkpoint(trainer: "pl.Trainer", checkpoint: Dict[str, Any]) -> None: +def _call_callbacks_on_save_checkpoint(trainer: "pl.Trainer", checkpoint: dict[str, Any]) -> None: """Called when saving a model checkpoint, calls every callback's `on_save_checkpoint` hook.""" pl_module = trainer.lightning_module if pl_module: @@ -249,7 +253,7 @@ def _call_callbacks_on_save_checkpoint(trainer: "pl.Trainer", checkpoint: Dict[s pl_module._current_fx_name = prev_fx_name -def _call_callbacks_on_load_checkpoint(trainer: "pl.Trainer", checkpoint: Dict[str, Any]) -> None: +def _call_callbacks_on_load_checkpoint(trainer: "pl.Trainer", checkpoint: dict[str, Any]) -> None: """Called when loading a model checkpoint. Calls every callback's `on_load_checkpoint` hook. We have a dedicated function for this rather than using @@ -261,7 +265,7 @@ def _call_callbacks_on_load_checkpoint(trainer: "pl.Trainer", checkpoint: Dict[s prev_fx_name = pl_module._current_fx_name pl_module._current_fx_name = "on_load_checkpoint" - callback_states: Optional[Dict[Union[Type, str], Dict]] = checkpoint.get("callbacks") + callback_states: Optional[dict[Union[type, str], dict]] = checkpoint.get("callbacks") if callback_states is None: return @@ -285,9 +289,9 @@ def _call_callbacks_on_load_checkpoint(trainer: "pl.Trainer", checkpoint: Dict[s pl_module._current_fx_name = prev_fx_name -def _call_callbacks_load_state_dict(trainer: "pl.Trainer", checkpoint: Dict[str, Any]) -> None: +def _call_callbacks_load_state_dict(trainer: "pl.Trainer", checkpoint: dict[str, Any]) -> None: """Called when loading a model checkpoint, calls every callback's `load_state_dict`.""" - callback_states: Optional[Dict[Union[Type, str], Dict]] = checkpoint.get("callbacks") + callback_states: Optional[dict[Union[type, str], dict]] = checkpoint.get("callbacks") if callback_states is None: return diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index 06f3ee366bcaa..7af8f13fce38a 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -15,7 +15,7 @@ import logging import os from collections import Counter -from typing import Dict, List, Literal, Optional, Union +from typing import Literal, Optional, Union import torch @@ -74,11 +74,11 @@ class _AcceleratorConnector: def __init__( self, - devices: Union[List[int], str, int] = "auto", + devices: Union[list[int], str, int] = 
"auto", num_nodes: int = 1, accelerator: Union[str, Accelerator] = "auto", strategy: Union[str, Strategy] = "auto", - plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None, + plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]] = None, precision: Optional[_PRECISION_INPUT] = None, sync_batchnorm: bool = False, benchmark: Optional[bool] = None, @@ -123,7 +123,7 @@ def __init__( self._precision_flag: _PRECISION_INPUT_STR = "32-true" self._precision_plugin_flag: Optional[Precision] = None self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None - self._parallel_devices: List[Union[int, torch.device, str]] = [] + self._parallel_devices: list[Union[int, torch.device, str]] = [] self._layer_sync: Optional[LayerSync] = TorchSyncBatchNorm() if sync_batchnorm else None self.checkpoint_io: Optional[CheckpointIO] = None @@ -166,7 +166,7 @@ def _check_config_and_set_final_flags( strategy: Union[str, Strategy], accelerator: Union[str, Accelerator], precision: Optional[_PRECISION_INPUT], - plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]], + plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]], sync_batchnorm: bool, ) -> None: """This method checks: @@ -225,7 +225,7 @@ def _check_config_and_set_final_flags( precision_flag = _convert_precision_to_unified_args(precision) if plugins: - plugins_flags_types: Dict[str, int] = Counter() + plugins_flags_types: dict[str, int] = Counter() for plugin in plugins: if isinstance(plugin, Precision): self._precision_plugin_flag = plugin @@ -310,7 +310,7 @@ def _check_config_and_set_final_flags( self._accelerator_flag = "cuda" self._parallel_devices = self._strategy_flag.parallel_devices - def _check_device_config_and_set_final_flags(self, devices: Union[List[int], str, int], num_nodes: int) -> None: + def _check_device_config_and_set_final_flags(self, devices: Union[list[int], str, int], num_nodes: int) -> None: if not isinstance(num_nodes, int) or num_nodes < 1: raise ValueError(f"`num_nodes` must be a positive integer, but got {num_nodes}.") diff --git a/src/lightning/pytorch/trainer/connectors/callback_connector.py b/src/lightning/pytorch/trainer/connectors/callback_connector.py index 2f2b619290ae5..a60f907d9361b 100644 --- a/src/lightning/pytorch/trainer/connectors/callback_connector.py +++ b/src/lightning/pytorch/trainer/connectors/callback_connector.py @@ -14,8 +14,9 @@ import logging import os +from collections.abc import Sequence from datetime import timedelta -from typing import Dict, List, Optional, Sequence, Union +from typing import Optional, Union import lightning.pytorch as pl from lightning.fabric.utilities.registry import _load_external_callbacks @@ -46,12 +47,12 @@ def __init__(self, trainer: "pl.Trainer"): def on_trainer_init( self, - callbacks: Optional[Union[List[Callback], Callback]], + callbacks: Optional[Union[list[Callback], Callback]], enable_checkpointing: bool, enable_progress_bar: bool, default_root_dir: Optional[str], enable_model_summary: bool, - max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None, + max_time: Optional[Union[str, timedelta, dict[str, int]]] = None, ) -> None: # init folder paths for checkpoint + weights save callbacks self.trainer._default_root_dir = default_root_dir or os.getcwd() @@ -139,7 +140,7 @@ def _configure_progress_bar(self, enable_progress_bar: bool = True) -> None: progress_bar_callback = TQDMProgressBar() self.trainer.callbacks.append(progress_bar_callback) - def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, 
Dict[str, int]]] = None) -> None: + def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, dict[str, int]]] = None) -> None: if max_time is None: return if any(isinstance(cb, Timer) for cb in self.trainer.callbacks): @@ -195,7 +196,7 @@ def _attach_model_callbacks(self) -> None: trainer.callbacks = all_callbacks @staticmethod - def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]: + def _reorder_callbacks(callbacks: list[Callback]) -> list[Callback]: """Moves all the tuner specific callbacks at the beginning of the list and all the `ModelCheckpoint` callbacks to the end of the list. The sequential order within the group of checkpoint callbacks is preserved, as well as the order of all other callbacks. @@ -208,9 +209,9 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]: if there were any present in the input. """ - tuner_callbacks: List[Callback] = [] - other_callbacks: List[Callback] = [] - checkpoint_callbacks: List[Callback] = [] + tuner_callbacks: list[Callback] = [] + other_callbacks: list[Callback] = [] + checkpoint_callbacks: list[Callback] = [] for cb in callbacks: if isinstance(cb, (BatchSizeFinder, LearningRateFinder)): @@ -223,7 +224,7 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]: return tuner_callbacks + other_callbacks + checkpoint_callbacks -def _validate_callbacks_list(callbacks: List[Callback]) -> None: +def _validate_callbacks_list(callbacks: list[Callback]) -> None: stateful_callbacks = [cb for cb in callbacks if is_overridden("state_dict", instance=cb)] seen_callbacks = set() for callback in stateful_callbacks: diff --git a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py index a41f87d418ebe..71cc5a14686be 100644 --- a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py +++ b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py @@ -14,7 +14,7 @@ import logging import os import re -from typing import Any, Dict, Optional +from typing import Any, Optional import torch from fsspec.core import url_to_fs @@ -44,7 +44,7 @@ def __init__(self, trainer: "pl.Trainer") -> None: self._ckpt_path: Optional[_PATH] = None # flag to know if the user is changing the checkpoint path statefully. See `trainer.ckpt_path.setter` self._user_managed: bool = False - self._loaded_checkpoint: Dict[str, Any] = {} + self._loaded_checkpoint: dict[str, Any] = {} @property def _hpc_resume_path(self) -> Optional[str]: @@ -491,10 +491,10 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: call._call_lightning_module_hook(trainer, "on_save_checkpoint", checkpoint) return checkpoint - def _get_lightning_module_state_dict(self) -> Dict[str, Tensor]: + def _get_lightning_module_state_dict(self) -> dict[str, Tensor]: return self.trainer.strategy.lightning_module_state_dict() - def _get_loops_state_dict(self) -> Dict[str, Any]: + def _get_loops_state_dict(self) -> dict[str, Any]: return { "fit_loop": self.trainer.fit_loop.state_dict(), "validate_loop": self.trainer.validate_loop.state_dict(), diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py index 1e84a2ebd0244..3e5273085ed2b 100644 --- a/src/lightning/pytorch/trainer/connectors/data_connector.py +++ b/src/lightning/pytorch/trainer/connectors/data_connector.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os +from collections.abc import Iterable from dataclasses import dataclass, field -from typing import Any, Iterable, Optional, Tuple, Union +from typing import Any, Optional, Union import torch.multiprocessing as mp from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sampler, SequentialSampler @@ -342,7 +343,7 @@ class _DataHookSelector: model: "pl.LightningModule" datamodule: Optional["pl.LightningDataModule"] - _valid_hooks: Tuple[str, ...] = field( + _valid_hooks: tuple[str, ...] = field( default=("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer") ) diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py b/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py index 545749bc5b321..0dbdc4eaf76e1 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import Optional, Union from typing_extensions import TypedDict @@ -20,8 +20,8 @@ class _FxValidator: class _LogOptions(TypedDict): - allowed_on_step: Union[Tuple[bool], Tuple[bool, bool]] - allowed_on_epoch: Union[Tuple[bool], Tuple[bool, bool]] + allowed_on_step: Union[tuple[bool], tuple[bool, bool]] + allowed_on_epoch: Union[tuple[bool], tuple[bool, bool]] default_on_step: bool default_on_epoch: bool @@ -166,7 +166,7 @@ def check_logging(cls, fx_name: str) -> None: @classmethod def get_default_logging_levels( cls, fx_name: str, on_step: Optional[bool], on_epoch: Optional[bool] - ) -> Tuple[bool, bool]: + ) -> tuple[bool, bool]: """Return default logging levels for given hook.""" fx_config = cls.functions[fx_name] assert fx_config is not None @@ -191,7 +191,7 @@ def check_logging_levels(cls, fx_name: str, on_step: bool, on_epoch: bool) -> No @classmethod def check_logging_and_get_default_levels( cls, fx_name: str, on_step: Optional[bool], on_epoch: Optional[bool] - ) -> Tuple[bool, bool]: + ) -> tuple[bool, bool]: """Check if the given hook name is allowed to log and return logging levels.""" cls.check_logging(fx_name) on_step, on_epoch = cls.get_default_logging_levels(fx_name, on_step, on_epoch) diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py b/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py index c4ab11632b56b..ffc99a9772469 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Iterable, Optional, Union +from collections.abc import Iterable +from typing import Any, Optional, Union from lightning_utilities.core.apply_func import apply_to_collection from torch import Tensor diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py index 62cc7844d3897..fdde19aa80eea 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Generator from dataclasses import dataclass from functools import partial, wraps -from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union, cast +from typing import Any, Callable, Optional, Union, cast import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -32,8 +33,8 @@ from lightning.pytorch.utilities.warnings import PossibleUserWarning _VALUE = Union[Metric, Tensor] # Do not include scalars as they were converted to tensors -_OUT_DICT = Dict[str, Tensor] -_PBAR_DICT = Dict[str, float] +_OUT_DICT = dict[str, Tensor] +_PBAR_DICT = dict[str, float] class _METRICS(TypedDict): @@ -333,7 +334,7 @@ def __init__(self, training: bool) -> None: self.dataloader_idx: Optional[int] = None @property - def result_metrics(self) -> List[_ResultMetric]: + def result_metrics(self) -> list[_ResultMetric]: return list(self.values()) def _extract_batch_size(self, value: _ResultMetric, batch_size: Optional[int], meta: _Metadata) -> int: @@ -456,7 +457,7 @@ def valid_items(self) -> Generator: """This function is used to iterate over current valid metrics.""" return ((k, v) for k, v in self.items() if not v.has_reset and self.dataloader_idx == v.meta.dataloader_idx) - def _forked_name(self, result_metric: _ResultMetric, on_step: bool) -> Tuple[str, str]: + def _forked_name(self, result_metric: _ResultMetric, on_step: bool) -> tuple[str, str]: name = result_metric.meta.name forked_name = result_metric.meta.forked_name(on_step) add_dataloader_idx = result_metric.meta.add_dataloader_idx diff --git a/src/lightning/pytorch/trainer/connectors/signal_connector.py b/src/lightning/pytorch/trainer/connectors/signal_connector.py index 05a975326005f..e63fecd3897f2 100644 --- a/src/lightning/pytorch/trainer/connectors/signal_connector.py +++ b/src/lightning/pytorch/trainer/connectors/signal_connector.py @@ -5,7 +5,7 @@ import threading from subprocess import call from types import FrameType -from typing import Any, Callable, Dict, List, Set, Union +from typing import Any, Callable, Union import lightning.pytorch as pl from lightning.fabric.plugins.environments import SLURMEnvironment @@ -20,7 +20,7 @@ class _HandlersCompose: - def __init__(self, signal_handlers: Union[List[_HANDLER], _HANDLER]) -> None: + def __init__(self, signal_handlers: Union[list[_HANDLER], _HANDLER]) -> None: if not isinstance(signal_handlers, list): signal_handlers = [signal_handlers] self.signal_handlers = signal_handlers @@ -37,14 +37,14 @@ class _SignalConnector: def __init__(self, trainer: "pl.Trainer") -> None: self.received_sigterm = False self.trainer = trainer - self._original_handlers: Dict[_SIGNUM, _HANDLER] = {} + self._original_handlers: dict[_SIGNUM, _HANDLER] = {} def register_signal_handlers(self) -> None: 
self.received_sigterm = False self._original_handlers = self._get_current_signal_handlers() - sigusr_handlers: List[_HANDLER] = [] - sigterm_handlers: List[_HANDLER] = [self._sigterm_notifier_fn] + sigusr_handlers: list[_HANDLER] = [] + sigterm_handlers: list[_HANDLER] = [self._sigterm_notifier_fn] environment = self.trainer._accelerator_connector.cluster_environment if isinstance(environment, SLURMEnvironment) and environment.auto_requeue: @@ -123,7 +123,7 @@ def teardown(self) -> None: self._original_handlers = {} @staticmethod - def _get_current_signal_handlers() -> Dict[_SIGNUM, _HANDLER]: + def _get_current_signal_handlers() -> dict[_SIGNUM, _HANDLER]: """Collects the currently assigned signal handlers.""" valid_signals = _SignalConnector._valid_signals() if not _IS_WINDOWS: @@ -132,7 +132,7 @@ def _get_current_signal_handlers() -> Dict[_SIGNUM, _HANDLER]: return {signum: signal.getsignal(signum) for signum in valid_signals} @staticmethod - def _valid_signals() -> Set[signal.Signals]: + def _valid_signals() -> set[signal.Signals]: """Returns all valid signals supported on the current platform.""" return signal.valid_signals() @@ -145,7 +145,7 @@ def _register_signal(signum: _SIGNUM, handlers: _HANDLER) -> None: if threading.current_thread() is threading.main_thread(): signal.signal(signum, handlers) # type: ignore[arg-type] - def __getstate__(self) -> Dict: + def __getstate__(self) -> dict: state = self.__dict__.copy() state["_original_handlers"] = {} return state diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 23db90fd45b38..0509f28acb07a 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -23,9 +23,10 @@ import logging import math import os +from collections.abc import Generator, Iterable from contextlib import contextmanager from datetime import timedelta -from typing import Any, Dict, Generator, Iterable, List, Optional, Union +from typing import Any, Optional, Union from weakref import proxy import torch @@ -90,17 +91,17 @@ def __init__( *, accelerator: Union[str, Accelerator] = "auto", strategy: Union[str, Strategy] = "auto", - devices: Union[List[int], str, int] = "auto", + devices: Union[list[int], str, int] = "auto", num_nodes: int = 1, precision: Optional[_PRECISION_INPUT] = None, logger: Optional[Union[Logger, Iterable[Logger], bool]] = None, - callbacks: Optional[Union[List[Callback], Callback]] = None, + callbacks: Optional[Union[list[Callback], Callback]] = None, fast_dev_run: Union[int, bool] = False, max_epochs: Optional[int] = None, min_epochs: Optional[int] = None, max_steps: int = -1, min_steps: Optional[int] = None, - max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None, + max_time: Optional[Union[str, timedelta, dict[str, int]]] = None, limit_train_batches: Optional[Union[int, float]] = None, limit_val_batches: Optional[Union[int, float]] = None, limit_test_batches: Optional[Union[int, float]] = None, @@ -123,7 +124,7 @@ def __init__( profiler: Optional[Union[Profiler, str]] = None, detect_anomaly: bool = False, barebones: bool = False, - plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None, + plugins: Optional[Union[_PLUGIN_INPUT, list[_PLUGIN_INPUT]]] = None, sync_batchnorm: bool = False, reload_dataloaders_every_n_epochs: int = 0, default_root_dir: Optional[_PATH] = None, @@ -472,7 +473,7 @@ def __init__( setup._init_profiler(self, profiler) # init logger flags - self._loggers: List[Logger] + self._loggers: list[Logger] 
self._logger_connector.on_trainer_init(logger, log_every_n_steps) # init debugging flags @@ -1149,7 +1150,7 @@ def num_nodes(self) -> int: return getattr(self.strategy, "num_nodes", 1) @property - def device_ids(self) -> List[int]: + def device_ids(self) -> list[int]: """List of device indexes per node.""" devices = ( self.strategy.parallel_devices @@ -1176,15 +1177,15 @@ def lightning_module(self) -> "pl.LightningModule": return self.strategy.lightning_module # type: ignore[return-value] @property - def optimizers(self) -> List[Optimizer]: + def optimizers(self) -> list[Optimizer]: return self.strategy.optimizers @optimizers.setter - def optimizers(self, new_optims: List[Optimizer]) -> None: + def optimizers(self, new_optims: list[Optimizer]) -> None: self.strategy.optimizers = new_optims @property - def lr_scheduler_configs(self) -> List[LRSchedulerConfig]: + def lr_scheduler_configs(self) -> list[LRSchedulerConfig]: return self.strategy.lr_scheduler_configs @property @@ -1247,7 +1248,7 @@ def training_step(self, batch, batch_idx): return self.strategy.is_global_zero @property - def distributed_sampler_kwargs(self) -> Optional[Dict[str, Any]]: + def distributed_sampler_kwargs(self) -> Optional[dict[str, Any]]: if isinstance(self.strategy, ParallelStrategy): return self.strategy.distributed_sampler_kwargs return None @@ -1280,7 +1281,7 @@ def early_stopping_callback(self) -> Optional[EarlyStopping]: return callbacks[0] if len(callbacks) > 0 else None @property - def early_stopping_callbacks(self) -> List[EarlyStopping]: + def early_stopping_callbacks(self) -> list[EarlyStopping]: """A list of all instances of :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping` found in the Trainer.callbacks list.""" return [c for c in self.callbacks if isinstance(c, EarlyStopping)] @@ -1293,7 +1294,7 @@ def checkpoint_callback(self) -> Optional[Checkpoint]: return callbacks[0] if len(callbacks) > 0 else None @property - def checkpoint_callbacks(self) -> List[Checkpoint]: + def checkpoint_callbacks(self) -> list[Checkpoint]: """A list of all instances of :class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` found in the Trainer.callbacks list.""" return [c for c in self.callbacks if isinstance(c, Checkpoint)] @@ -1522,14 +1523,14 @@ def num_training_batches(self) -> Union[int, float]: return self.fit_loop.max_batches @property - def num_sanity_val_batches(self) -> List[Union[int, float]]: + def num_sanity_val_batches(self) -> list[Union[int, float]]: """The number of validation batches that will be used during the sanity-checking part of ``trainer.fit()``.""" max_batches = self.fit_loop.epoch_loop.val_loop.max_batches # re-compute the `min` in case this is called outside the sanity-checking stage return [min(self.num_sanity_val_steps, batches) for batches in max_batches] @property - def num_val_batches(self) -> List[Union[int, float]]: + def num_val_batches(self) -> list[Union[int, float]]: """The number of validation batches that will be used during ``trainer.fit()`` or ``trainer.validate()``.""" if self.state.fn == TrainerFn.VALIDATING: return self.validate_loop.max_batches @@ -1538,12 +1539,12 @@ def num_val_batches(self) -> List[Union[int, float]]: return self.fit_loop.epoch_loop.val_loop._max_batches @property - def num_test_batches(self) -> List[Union[int, float]]: + def num_test_batches(self) -> list[Union[int, float]]: """The number of test batches that will be used during ``trainer.test()``.""" return self.test_loop.max_batches @property - def 
num_predict_batches(self) -> List[Union[int, float]]: + def num_predict_batches(self) -> list[Union[int, float]]: """The number of prediction batches that will be used during ``trainer.predict()``.""" return self.predict_loop.max_batches @@ -1584,7 +1585,7 @@ def logger(self, logger: Optional[Logger]) -> None: self.loggers = [logger] @property - def loggers(self) -> List[Logger]: + def loggers(self) -> list[Logger]: """The list of :class:`~lightning.pytorch.loggers.logger.Logger` used. .. code-block:: python @@ -1596,7 +1597,7 @@ def loggers(self) -> List[Logger]: return self._loggers @loggers.setter - def loggers(self, loggers: Optional[List[Logger]]) -> None: + def loggers(self, loggers: Optional[list[Logger]]) -> None: self._loggers = loggers if loggers else [] @property diff --git a/src/lightning/pytorch/tuner/batch_size_scaling.py b/src/lightning/pytorch/tuner/batch_size_scaling.py index 6618f7e930ca1..99badd84bb8ad 100644 --- a/src/lightning/pytorch/tuner/batch_size_scaling.py +++ b/src/lightning/pytorch/tuner/batch_size_scaling.py @@ -15,7 +15,7 @@ import os import uuid from copy import deepcopy -from typing import Any, Dict, Optional, Tuple +from typing import Any, Optional import lightning.pytorch as pl from lightning.pytorch.utilities.memory import garbage_collection_cuda, is_oom_error @@ -98,7 +98,7 @@ def _scale_batch_size( return new_size -def __scale_batch_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]: +def __scale_batch_dump_params(trainer: "pl.Trainer") -> dict[str, Any]: dumped_params = { "loggers": trainer.loggers, "callbacks": trainer.callbacks, @@ -138,7 +138,7 @@ def __scale_batch_reset_params(trainer: "pl.Trainer", steps_per_trial: int) -> N loop.verbose = False -def __scale_batch_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: +def __scale_batch_restore_params(trainer: "pl.Trainer", params: dict[str, Any]) -> None: # TODO: There are more states that needs to be reset (#4512 and #4870) trainer.loggers = params["loggers"] trainer.callbacks = params["callbacks"] @@ -169,7 +169,7 @@ def _run_power_scaling( new_size: int, batch_arg_name: str, max_trials: int, - params: Dict[str, Any], + params: dict[str, Any], ) -> int: """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered.""" # this flag is used to determine whether the previously scaled batch size, right before OOM, was a success or not @@ -211,7 +211,7 @@ def _run_binary_scaling( new_size: int, batch_arg_name: str, max_trials: int, - params: Dict[str, Any], + params: dict[str, Any], ) -> int: """Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is encountered. @@ -276,7 +276,7 @@ def _adjust_batch_size( factor: float = 1.0, value: Optional[int] = None, desc: Optional[str] = None, -) -> Tuple[int, bool]: +) -> tuple[int, bool]: """Helper function for adjusting the batch size. 
Args: @@ -328,7 +328,7 @@ def _reset_dataloaders(trainer: "pl.Trainer") -> None: loop.epoch_loop.val_loop.setup_data() -def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: +def _try_loop_run(trainer: "pl.Trainer", params: dict[str, Any]) -> None: loop = trainer._active_loop assert loop is not None loop.load_state_dict(deepcopy(params["loop_state_dict"])) diff --git a/src/lightning/pytorch/tuner/lr_finder.py b/src/lightning/pytorch/tuner/lr_finder.py index d756d3d76597c..b50bedb10d53f 100644 --- a/src/lightning/pytorch/tuner/lr_finder.py +++ b/src/lightning/pytorch/tuner/lr_finder.py @@ -16,7 +16,7 @@ import os import uuid from copy import deepcopy -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import torch from lightning_utilities.core.imports import RequirementCache @@ -101,7 +101,7 @@ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int) - self.lr_max = lr_max self.num_training = num_training - self.results: Dict[str, Any] = {} + self.results: dict[str, Any] = {} self._total_batch_idx = 0 # for debug purpose def _exchange_scheduler(self, trainer: "pl.Trainer") -> None: @@ -310,7 +310,7 @@ def _lr_find( return lr_finder -def __lr_finder_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]: +def __lr_finder_dump_params(trainer: "pl.Trainer") -> dict[str, Any]: return { "optimizers": trainer.strategy.optimizers, "lr_scheduler_configs": trainer.strategy.lr_scheduler_configs, @@ -335,7 +335,7 @@ def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_sto trainer.limit_val_batches = num_training -def __lr_finder_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: +def __lr_finder_restore_params(trainer: "pl.Trainer", params: dict[str, Any]) -> None: trainer.strategy.optimizers = params["optimizers"] trainer.strategy.lr_scheduler_configs = params["lr_scheduler_configs"] trainer.callbacks = params["callbacks"] @@ -376,8 +376,8 @@ def __init__( self.num_training = num_training self.early_stop_threshold = early_stop_threshold self.beta = beta - self.losses: List[float] = [] - self.lrs: List[float] = [] + self.losses: list[float] = [] + self.lrs: list[float] = [] self.avg_loss = 0.0 self.best_loss = 0.0 self.progress_bar_refresh_rate = progress_bar_refresh_rate @@ -463,7 +463,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in super().__init__(optimizer, last_epoch) @override - def get_lr(self) -> List[float]: + def get_lr(self) -> list[float]: curr_iter = self.last_epoch + 1 r = curr_iter / self.num_iter @@ -475,7 +475,7 @@ def get_lr(self) -> List[float]: return val @property - def lr(self) -> Union[float, List[float]]: + def lr(self) -> Union[float, list[float]]: return self._lr @@ -500,7 +500,7 @@ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: in super().__init__(optimizer, last_epoch) @override - def get_lr(self) -> List[float]: + def get_lr(self) -> list[float]: curr_iter = self.last_epoch + 1 r = curr_iter / self.num_iter @@ -512,11 +512,11 @@ def get_lr(self) -> List[float]: return val @property - def lr(self) -> Union[float, List[float]]: + def lr(self) -> Union[float, list[float]]: return self._lr -def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: +def _try_loop_run(trainer: "pl.Trainer", params: dict[str, Any]) -> None: loop = trainer.fit_loop loop.load_state_dict(deepcopy(params["loop_state_dict"])) loop.restarting = False diff --git 
a/src/lightning/pytorch/utilities/_pytree.py b/src/lightning/pytorch/utilities/_pytree.py index f5f48b481c879..a0c7236cb27f1 100644 --- a/src/lightning/pytorch/utilities/_pytree.py +++ b/src/lightning/pytorch/utilities/_pytree.py @@ -1,4 +1,4 @@ -from typing import Any, List, Tuple +from typing import Any from torch.utils._pytree import SUPPORTED_NODES, LeafSpec, PyTree, TreeSpec, _get_node_type, tree_unflatten @@ -15,7 +15,7 @@ def _is_leaf_or_primitive_container(pytree: PyTree) -> bool: return all(isinstance(child, (int, float, str)) for child in child_pytrees) -def _tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]: +def _tree_flatten(pytree: PyTree) -> tuple[list[Any], TreeSpec]: """Copy of :func:`torch.utils._pytree.tree_flatten` using our custom leaf function.""" if _is_leaf_or_primitive_container(pytree): return [pytree], LeafSpec() @@ -24,8 +24,8 @@ def _tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]: flatten_fn = SUPPORTED_NODES[node_type].flatten_fn child_pytrees, context = flatten_fn(pytree) - result: List[Any] = [] - children_specs: List[TreeSpec] = [] + result: list[Any] = [] + children_specs: list[TreeSpec] = [] for child in child_pytrees: flat, child_spec = _tree_flatten(child) result += flat @@ -34,6 +34,6 @@ def _tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]: return result, TreeSpec(node_type, context, children_specs) -def _map_and_unflatten(fn: Any, values: List[Any], spec: TreeSpec) -> PyTree: +def _map_and_unflatten(fn: Any, values: list[Any], spec: TreeSpec) -> PyTree: """Utility function to apply a function and unflatten it.""" return tree_unflatten([fn(i) for i in values], spec) diff --git a/src/lightning/pytorch/utilities/argparse.py b/src/lightning/pytorch/utilities/argparse.py index eb7273b54a577..1e01297248ffa 100644 --- a/src/lightning/pytorch/utilities/argparse.py +++ b/src/lightning/pytorch/utilities/argparse.py @@ -19,12 +19,12 @@ from ast import literal_eval from contextlib import suppress from functools import wraps -from typing import Any, Callable, Type, TypeVar, cast +from typing import Any, Callable, TypeVar, cast _T = TypeVar("_T", bound=Callable[..., Any]) -def _parse_env_variables(cls: Type, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace: +def _parse_env_variables(cls: type, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace: """Parse environment arguments if they are defined. Examples: diff --git a/src/lightning/pytorch/utilities/combined_loader.py b/src/lightning/pytorch/utilities/combined_loader.py index 9b0ceb0288e87..9c89c998aa913 100644 --- a/src/lightning/pytorch/utilities/combined_loader.py +++ b/src/lightning/pytorch/utilities/combined_loader.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import contextlib -from collections.abc import Iterable -from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Tuple, Type, Union +from collections.abc import Iterable, Iterator +from typing import Any, Callable, Literal, Optional, Union from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter from typing_extensions import Self, TypedDict, override @@ -22,15 +22,15 @@ from lightning.fabric.utilities.types import _Stateful from lightning.pytorch.utilities._pytree import _map_and_unflatten, _tree_flatten, tree_unflatten -_ITERATOR_RETURN = Tuple[Any, int, int] # batch, batch_idx, dataloader_idx +_ITERATOR_RETURN = tuple[Any, int, int] # batch, batch_idx, dataloader_idx class _ModeIterator(Iterator[_ITERATOR_RETURN]): - def __init__(self, iterables: List[Iterable], limits: Optional[List[Union[int, float]]] = None) -> None: + def __init__(self, iterables: list[Iterable], limits: Optional[list[Union[int, float]]] = None) -> None: if limits is not None and len(limits) != len(iterables): raise ValueError(f"Mismatch in number of limits ({len(limits)}) and number of iterables ({len(iterables)})") self.iterables = iterables - self.iterators: List[Iterator] = [] + self.iterators: list[Iterator] = [] self._idx = 0 # what would be batch_idx self.limits = limits @@ -51,7 +51,7 @@ def reset(self) -> None: self.iterators = [] self._idx = 0 - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: state = self.__dict__.copy() # workaround an inconvenient `NotImplementedError`: @@ -65,9 +65,9 @@ def __getstate__(self) -> Dict[str, Any]: class _MaxSizeCycle(_ModeIterator): - def __init__(self, iterables: List[Iterable], limits: Optional[List[Union[int, float]]] = None) -> None: + def __init__(self, iterables: list[Iterable], limits: Optional[list[Union[int, float]]] = None) -> None: super().__init__(iterables, limits) - self._consumed: List[bool] = [] + self._consumed: list[bool] = [] @override def __next__(self) -> _ITERATOR_RETURN: @@ -121,7 +121,7 @@ def __len__(self) -> int: class _Sequential(_ModeIterator): - def __init__(self, iterables: List[Iterable], limits: Optional[List[Union[int, float]]] = None) -> None: + def __init__(self, iterables: list[Iterable], limits: Optional[list[Union[int, float]]] = None) -> None: super().__init__(iterables, limits) self._iterator_idx = 0 # what would be dataloader_idx @@ -206,8 +206,8 @@ def __len__(self) -> int: class _CombinationMode(TypedDict): - fn: Callable[[List[int]], int] - iterator: Type[_ModeIterator] + fn: Callable[[list[int]], int] + iterator: type[_ModeIterator] _SUPPORTED_MODES = { @@ -288,7 +288,7 @@ def __init__(self, iterables: Any, mode: _LITERAL_SUPPORTED_MODES = "min_size") self._flattened, self._spec = _tree_flatten(iterables) self._mode = mode self._iterator: Optional[_ModeIterator] = None - self._limits: Optional[List[Union[int, float]]] = None + self._limits: Optional[list[Union[int, float]]] = None @property def iterables(self) -> Any: @@ -306,12 +306,12 @@ def batch_sampler(self) -> Any: return _map_and_unflatten(lambda x: getattr(x, "batch_sampler", None), self.flattened, self._spec) @property - def flattened(self) -> List[Any]: + def flattened(self) -> list[Any]: """Return the flat list of iterables.""" return self._flattened @flattened.setter - def flattened(self, flattened: List[Any]) -> None: + def flattened(self, flattened: list[Any]) -> None: """Setter to conveniently update the list of iterables.""" if len(flattened) != len(self._flattened): raise 
ValueError( @@ -322,12 +322,12 @@ def flattened(self, flattened: List[Any]) -> None: self._flattened = flattened @property - def limits(self) -> Optional[List[Union[int, float]]]: + def limits(self) -> Optional[list[Union[int, float]]]: """Optional limits per iterator.""" return self._limits @limits.setter - def limits(self, limits: Optional[Union[int, float, List[Union[int, float]]]]) -> None: + def limits(self, limits: Optional[Union[int, float, list[Union[int, float]]]]) -> None: if isinstance(limits, (int, float)): limits = [limits] * len(self.flattened) elif isinstance(limits, list) and len(limits) != len(self.flattened): @@ -375,11 +375,11 @@ def _dataset_length(self) -> int: fn = _SUPPORTED_MODES[self._mode]["fn"] return fn(lengths) - def _state_dicts(self) -> List[Dict[str, Any]]: + def _state_dicts(self) -> list[dict[str, Any]]: """Returns the list of state dicts for iterables in `self.flattened` that are stateful.""" return [loader.state_dict() for loader in self.flattened if isinstance(loader, _Stateful)] - def _load_state_dicts(self, states: List[Dict[str, Any]]) -> None: + def _load_state_dicts(self, states: list[dict[str, Any]]) -> None: """Loads the state dicts for iterables in `self.flattened` that are stateful.""" if not states: return @@ -401,5 +401,5 @@ def _shutdown_workers_and_reset_iterator(dataloader: object) -> None: dataloader._iterator = None -def _get_iterables_lengths(iterables: List[Iterable]) -> List[Union[int, float]]: +def _get_iterables_lengths(iterables: list[Iterable]) -> list[Union[int, float]]: return [(float("inf") if (length := sized_len(iterable)) is None else length) for iterable in iterables] diff --git a/src/lightning/pytorch/utilities/consolidate_checkpoint.py b/src/lightning/pytorch/utilities/consolidate_checkpoint.py index 6f150bab0f23c..0dcf5879b6fc5 100644 --- a/src/lightning/pytorch/utilities/consolidate_checkpoint.py +++ b/src/lightning/pytorch/utilities/consolidate_checkpoint.py @@ -1,5 +1,5 @@ import re -from typing import Any, Dict +from typing import Any import torch @@ -7,7 +7,7 @@ from lightning.fabric.utilities.load import _load_distributed_checkpoint -def _format_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]: +def _format_checkpoint(checkpoint: dict[str, Any]) -> dict[str, Any]: """Converts the special FSDP checkpoint format to the standard format the Lightning Trainer can load.""" # Rename the model key checkpoint["state_dict"] = checkpoint.pop("model") diff --git a/src/lightning/pytorch/utilities/data.py b/src/lightning/pytorch/utilities/data.py index b58142b3a4012..5c14561f7aff9 100644 --- a/src/lightning/pytorch/utilities/data.py +++ b/src/lightning/pytorch/utilities/data.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
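For orientation, a minimal usage sketch of the public `CombinedLoader` whose internals are retyped above; the `"max_size_cycle"` mode and the `(batch, batch_idx, dataloader_idx)` triple follow `_SUPPORTED_MODES` and `_ITERATOR_RETURN` in this file, while the plain `range` iterables are illustrative stand-ins for dataloaders:

from lightning.pytorch.utilities import CombinedLoader

iterables = {"a": range(6), "b": range(4)}
combined = CombinedLoader(iterables, mode="max_size_cycle")
for batch, batch_idx, dataloader_idx in combined:
    # batch mirrors the input structure, e.g. {"a": 0, "b": 0} on the first step;
    # dataloader_idx stays 0 because this mode consumes all iterables in lockstep
    print(batch_idx, dataloader_idx, batch)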
import inspect +from collections.abc import Generator, Iterable, Mapping, Sized from dataclasses import fields -from typing import Any, Dict, Generator, Iterable, Mapping, Optional, Sized, Tuple, Union +from typing import Any, Optional, Union import torch from lightning_utilities.core.apply_func import is_dataclass_instance @@ -139,7 +140,7 @@ def _get_dataloader_init_args_and_kwargs( dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None, -) -> Tuple[Tuple[Any], Dict[str, Any]]: +) -> tuple[tuple[Any], dict[str, Any]]: if not isinstance(dataloader, DataLoader): raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`") @@ -233,7 +234,7 @@ def _dataloader_init_kwargs_resolve_sampler( dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None, -) -> Dict[str, Any]: +) -> dict[str, Any]: """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its re- instantiation. diff --git a/src/lightning/pytorch/utilities/grads.py b/src/lightning/pytorch/utilities/grads.py index f200d892474db..08a0230f759cf 100644 --- a/src/lightning/pytorch/utilities/grads.py +++ b/src/lightning/pytorch/utilities/grads.py @@ -13,13 +13,13 @@ # limitations under the License. """Utilities to describe gradients.""" -from typing import Dict, Union +from typing import Union import torch from torch.nn import Module -def grad_norm(module: Module, norm_type: Union[float, int, str], group_separator: str = "/") -> Dict[str, float]: +def grad_norm(module: Module, norm_type: Union[float, int, str], group_separator: str = "/") -> dict[str, float]: """Compute each parameter's gradient's norm and their overall norm. The overall norm is computed over all gradients together, as if they diff --git a/src/lightning/pytorch/utilities/migration/migration.py b/src/lightning/pytorch/utilities/migration/migration.py index 6a5a914bed9ba..5db942b29183f 100644 --- a/src/lightning/pytorch/utilities/migration/migration.py +++ b/src/lightning/pytorch/utilities/migration/migration.py @@ -31,17 +31,17 @@ """ import re -from typing import Any, Callable, Dict, List +from typing import Any, Callable from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch.callbacks.early_stopping import EarlyStopping from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint from lightning.pytorch.utilities.rank_zero import rank_zero_warn -_CHECKPOINT = Dict[str, Any] +_CHECKPOINT = dict[str, Any] -def _migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]: +def _migration_index() -> dict[str, list[Callable[[_CHECKPOINT], _CHECKPOINT]]]: """Migration functions returned here will get executed in the order they are listed.""" return { "0.10.0": [_migrate_model_checkpoint_early_stopping], @@ -133,7 +133,7 @@ def _migrate_loop_batches_that_stepped(checkpoint: _CHECKPOINT) -> _CHECKPOINT: return checkpoint -def _get_fit_loop_initial_state_1_6_0() -> Dict: +def _get_fit_loop_initial_state_1_6_0() -> dict: return { "epoch_loop.batch_loop.manual_loop.optim_step_progress": { "current": {"completed": 0, "ready": 0}, diff --git a/src/lightning/pytorch/utilities/migration/utils.py b/src/lightning/pytorch/utilities/migration/utils.py index 1537c2684fbe4..2c5656e1f1016 100644 --- a/src/lightning/pytorch/utilities/migration/utils.py +++ b/src/lightning/pytorch/utilities/migration/utils.py @@ -18,7 +18,7 @@ import threading import warnings from types import 
ModuleType, TracebackType -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Optional from packaging.version import Version from typing_extensions import override @@ -32,13 +32,13 @@ from lightning.pytorch.utilities.rank_zero import rank_zero_warn _log = logging.getLogger(__name__) -_CHECKPOINT = Dict[str, Any] +_CHECKPOINT = dict[str, Any] _lock = threading.Lock() def migrate_checkpoint( checkpoint: _CHECKPOINT, target_version: Optional[str] = None -) -> Tuple[_CHECKPOINT, Dict[str, List[str]]]: +) -> tuple[_CHECKPOINT, dict[str, list[str]]]: """Applies Lightning version migrations to a checkpoint dictionary. Args: @@ -121,7 +121,7 @@ class _FaultTolerantMode(LightningEnum): def __exit__( self, - exc_type: Optional[Type[BaseException]], + exc_type: Optional[type[BaseException]], exc_value: Optional[BaseException], exc_traceback: Optional[TracebackType], ) -> None: diff --git a/src/lightning/pytorch/utilities/model_helpers.py b/src/lightning/pytorch/utilities/model_helpers.py index 36adedf4a831b..44591aa7f4dc1 100644 --- a/src/lightning/pytorch/utilities/model_helpers.py +++ b/src/lightning/pytorch/utilities/model_helpers.py @@ -15,7 +15,7 @@ import inspect import logging import os -from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, Type, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar from lightning_utilities.core.imports import RequirementCache from torch import nn @@ -26,7 +26,7 @@ _log = logging.getLogger(__name__) -def is_overridden(method_name: str, instance: Optional[object] = None, parent: Optional[Type[object]] = None) -> bool: +def is_overridden(method_name: str, instance: Optional[object] = None, parent: Optional[type[object]] = None) -> bool: if instance is None: # if `self.lightning_module` was passed as instance, it can be `None` return False @@ -65,7 +65,7 @@ class _ModuleMode: """Captures the ``nn.Module.training`` (bool) mode of every submodule, and allows it to be restored later on.""" def __init__(self) -> None: - self.mode: Dict[str, bool] = {} + self.mode: dict[str, bool] = {} def capture(self, module: nn.Module) -> None: self.mode.clear() @@ -108,10 +108,10 @@ class _restricted_classmethod_impl(Generic[_T, _P, _R_co]): """Drop-in replacement for @classmethod, but raises an exception when the decorated method is called on an instance instead of a class type.""" - def __init__(self, method: Callable[Concatenate[Type[_T], _P], _R_co]) -> None: + def __init__(self, method: Callable[Concatenate[type[_T], _P], _R_co]) -> None: self.method = method - def __get__(self, instance: Optional[_T], cls: Type[_T]) -> Callable[_P, _R_co]: + def __get__(self, instance: Optional[_T], cls: type[_T]) -> Callable[_P, _R_co]: # The wrapper ensures that the method can be inspected, but not called on an instance @functools.wraps(self.method) def wrapper(*args: Any, **kwargs: Any) -> _R_co: diff --git a/src/lightning/pytorch/utilities/model_summary/model_summary.py b/src/lightning/pytorch/utilities/model_summary/model_summary.py index c40dc94568a51..6a5baf2c1e04a 100644 --- a/src/lightning/pytorch/utilities/model_summary/model_summary.py +++ b/src/lightning/pytorch/utilities/model_summary/model_summary.py @@ -17,7 +17,7 @@ import logging import math from collections import OrderedDict -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Optional, Union import torch import torch.nn as nn @@ -73,8 +73,8 @@ def __init__(self, module: nn.Module) -> None: super().__init__() 
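The change repeated across all of these hunks is the PEP 585 spelling of generics: container annotations come from the builtins (or `collections.abc`) instead of `typing`, which only works at runtime on Python >= 3.9 and is why `python_requires` is raised later in this diff. A before/after sketch with a hypothetical function:

# Before: typing aliases, the spelling Python 3.8 required for runtime subscription
from typing import Dict, List, Tuple

def split_metrics(raw: Dict[str, float]) -> Tuple[List[str], List[float]]:
    return list(raw), list(raw.values())

# After: builtin generics (and collections.abc for Iterable, Iterator, Mapping, ...)
def split_metrics_585(raw: dict[str, float]) -> tuple[list[str], list[float]]:
    return list(raw), list(raw.values())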
self._module = module self._hook_handle = self._register_hook() - self._in_size: Optional[Union[str, List]] = None - self._out_size: Optional[Union[str, List]] = None + self._in_size: Optional[Union[str, list]] = None + self._out_size: Optional[Union[str, list]] = None def __del__(self) -> None: self.detach_hook() @@ -121,11 +121,11 @@ def detach_hook(self) -> None: self._hook_handle.remove() @property - def in_size(self) -> Union[str, List]: + def in_size(self) -> Union[str, list]: return self._in_size or UNKNOWN_SIZE @property - def out_size(self) -> Union[str, List]: + def out_size(self) -> Union[str, list]: return self._out_size or UNKNOWN_SIZE @property @@ -221,8 +221,8 @@ def __init__(self, model: "pl.LightningModule", max_depth: int = 1) -> None: self._precision_megabytes = (precision / 8.0) * 1e-6 @property - def named_modules(self) -> List[Tuple[str, nn.Module]]: - mods: List[Tuple[str, nn.Module]] + def named_modules(self) -> list[tuple[str, nn.Module]]: + mods: list[tuple[str, nn.Module]] if self._max_depth == 0: mods = [] elif self._max_depth == 1: @@ -234,31 +234,31 @@ def named_modules(self) -> List[Tuple[str, nn.Module]]: return mods @property - def layer_names(self) -> List[str]: + def layer_names(self) -> list[str]: return list(self._layer_summary.keys()) @property - def layer_types(self) -> List[str]: + def layer_types(self) -> list[str]: return [layer.layer_type for layer in self._layer_summary.values()] @property - def in_sizes(self) -> List: + def in_sizes(self) -> list: return [layer.in_size for layer in self._layer_summary.values()] @property - def out_sizes(self) -> List: + def out_sizes(self) -> list: return [layer.out_size for layer in self._layer_summary.values()] @property - def param_nums(self) -> List[int]: + def param_nums(self) -> list[int]: return [layer.num_parameters for layer in self._layer_summary.values()] @property - def training_modes(self) -> List[bool]: + def training_modes(self) -> list[bool]: return [layer.training for layer in self._layer_summary.values()] @property - def total_training_modes(self) -> Dict[str, int]: + def total_training_modes(self) -> dict[str, int]: modes = [layer.training for layer in self._model.modules()] modes = modes[1:] # exclude the root module return {"train": modes.count(True), "eval": modes.count(False)} @@ -279,7 +279,7 @@ def total_layer_params(self) -> int: def model_size(self) -> float: return self.total_parameters * self._precision_megabytes - def summarize(self) -> Dict[str, LayerSummary]: + def summarize(self) -> dict[str, LayerSummary]: summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules) if self._model.example_input_array is not None: self._forward_example_input() @@ -318,7 +318,7 @@ def _forward_example_input(self) -> None: model(input_) mode.restore(model) - def _get_summary_data(self) -> List[Tuple[str, List[str]]]: + def _get_summary_data(self) -> list[tuple[str, list[str]]]: """Makes a summary listing with: Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size @@ -341,7 +341,7 @@ def _get_summary_data(self) -> List[Tuple[str, List[str]]]: return arrays - def _add_leftover_params_to_summary(self, arrays: List[Tuple[str, List[str]]], total_leftover_params: int) -> None: + def _add_leftover_params_to_summary(self, arrays: list[tuple[str, list[str]]], total_leftover_params: int) -> None: """Add summary of params not associated with module or layer to model summary.""" layer_summaries = dict(arrays) layer_summaries[" "].append(" ") @@ -368,7 
+368,7 @@ def __repr__(self) -> str: return str(self) -def parse_batch_shape(batch: Any) -> Union[str, List]: +def parse_batch_shape(batch: Any) -> Union[str, list]: if hasattr(batch, "shape"): return list(batch.shape) @@ -382,8 +382,8 @@ def _format_summary_table( total_parameters: int, trainable_parameters: int, model_size: float, - total_training_modes: Dict[str, int], - *cols: Tuple[str, List[str]], + total_training_modes: dict[str, int], + *cols: tuple[str, list[str]], ) -> str: """Takes in a number of arrays, each specifying a column in the summary table, and combines them all into one big string defining the summary table that are nicely formatted.""" diff --git a/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py b/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py index 57d9ae5024b58..5038aebf0db79 100644 --- a/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py +++ b/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py @@ -14,7 +14,6 @@ """Utilities that can be used with Deepspeed.""" from collections import OrderedDict -from typing import Dict, List, Tuple import torch from lightning_utilities.core.imports import RequirementCache @@ -54,7 +53,7 @@ def partitioned_size(p: Parameter) -> int: class DeepSpeedSummary(ModelSummary): @override - def summarize(self) -> Dict[str, DeepSpeedLayerSummary]: # type: ignore[override] + def summarize(self) -> dict[str, DeepSpeedLayerSummary]: # type: ignore[override] summary = OrderedDict((name, DeepSpeedLayerSummary(module)) for name, module in self.named_modules) if self._model.example_input_array is not None: self._forward_example_input() @@ -83,11 +82,11 @@ def trainable_parameters(self) -> int: ) @property - def parameters_per_layer(self) -> List[int]: + def parameters_per_layer(self) -> list[int]: return [layer.average_shard_parameters for layer in self._layer_summary.values()] @override - def _get_summary_data(self) -> List[Tuple[str, List[str]]]: + def _get_summary_data(self) -> list[tuple[str, list[str]]]: """Makes a summary listing with: Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size @@ -112,7 +111,7 @@ def _get_summary_data(self) -> List[Tuple[str, List[str]]]: return arrays @override - def _add_leftover_params_to_summary(self, arrays: List[Tuple[str, List[str]]], total_leftover_params: int) -> None: + def _add_leftover_params_to_summary(self, arrays: list[tuple[str, list[str]]], total_leftover_params: int) -> None: """Add summary of params not associated with module or layer to model summary.""" super()._add_leftover_params_to_summary(arrays, total_leftover_params) layer_summaries = dict(arrays) diff --git a/src/lightning/pytorch/utilities/parameter_tying.py b/src/lightning/pytorch/utilities/parameter_tying.py index 8680285c272e8..da0309b0626bb 100644 --- a/src/lightning/pytorch/utilities/parameter_tying.py +++ b/src/lightning/pytorch/utilities/parameter_tying.py @@ -18,17 +18,17 @@ """ -from typing import Dict, List, Optional +from typing import Optional from torch import nn -def find_shared_parameters(module: nn.Module) -> List[str]: +def find_shared_parameters(module: nn.Module) -> list[str]: """Returns a list of names of shared parameters set in the module.""" return _find_shared_parameters(module) -def _find_shared_parameters(module: nn.Module, tied_parameters: Optional[Dict] = None, prefix: str = "") -> List[str]: +def _find_shared_parameters(module: nn.Module, tied_parameters: Optional[dict] = None, 
prefix: str = "") -> list[str]: if tied_parameters is None: tied_parameters = {} for name, param in module._parameters.items(): diff --git a/src/lightning/pytorch/utilities/parsing.py b/src/lightning/pytorch/utilities/parsing.py index 0f4460a3d5144..16eef555291bd 100644 --- a/src/lightning/pytorch/utilities/parsing.py +++ b/src/lightning/pytorch/utilities/parsing.py @@ -17,8 +17,9 @@ import inspect import pickle import types +from collections.abc import MutableMapping, Sequence from dataclasses import fields, is_dataclass -from typing import Any, Dict, List, Literal, MutableMapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, Literal, Optional, Union from torch import nn @@ -48,7 +49,7 @@ def clean_namespace(hparams: MutableMapping) -> None: del hparams[k] -def parse_class_init_keys(cls: Type) -> Tuple[str, Optional[str], Optional[str]]: +def parse_class_init_keys(cls: type) -> tuple[str, Optional[str], Optional[str]]: """Parse key words for standard ``self``, ``*args`` and ``**kwargs``. Examples: @@ -60,7 +61,7 @@ def parse_class_init_keys(cls: Type) -> Tuple[str, Optional[str], Optional[str]] ('self', 'my_args', 'my_kwargs') """ - init_parameters = inspect.signature(cls.__init__).parameters + init_parameters = inspect.signature(cls.__init__).parameters # type: ignore[misc] # docs claims the params are always ordered # https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters init_params = list(init_parameters.values()) @@ -68,7 +69,7 @@ def parse_class_init_keys(cls: Type) -> Tuple[str, Optional[str], Optional[str]] n_self = init_params[0].name def _get_first_if_any( - params: List[inspect.Parameter], + params: list[inspect.Parameter], param_type: Literal[inspect._ParameterKind.VAR_POSITIONAL, inspect._ParameterKind.VAR_KEYWORD], ) -> Optional[str]: for p in params: @@ -82,13 +83,13 @@ def _get_first_if_any( return n_self, n_args, n_kwargs -def get_init_args(frame: types.FrameType) -> Dict[str, Any]: # pragma: no-cover +def get_init_args(frame: types.FrameType) -> dict[str, Any]: # pragma: no-cover """For backwards compatibility: #16369.""" _, local_args = _get_init_args(frame) return local_args -def _get_init_args(frame: types.FrameType) -> Tuple[Optional[Any], Dict[str, Any]]: +def _get_init_args(frame: types.FrameType) -> tuple[Optional[Any], dict[str, Any]]: _, _, _, local_vars = inspect.getargvalues(frame) if "__class__" not in local_vars: return None, {} @@ -109,10 +110,10 @@ def _get_init_args(frame: types.FrameType) -> Tuple[Optional[Any], Dict[str, Any def collect_init_args( frame: types.FrameType, - path_args: List[Dict[str, Any]], + path_args: list[dict[str, Any]], inside: bool = False, - classes: Tuple[Type, ...] = (), -) -> List[Dict[str, Any]]: + classes: tuple[type, ...] = (), +) -> list[dict[str, Any]]: """Recursively collects the arguments passed to the child constructors in the inheritance tree. Args: @@ -147,7 +148,7 @@ def save_hyperparameters( *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None, - given_hparams: Optional[Dict[str, Any]] = None, + given_hparams: Optional[dict[str, Any]] = None, ) -> None: """See :meth:`~lightning.pytorch.LightningModule.save_hyperparameters`""" @@ -232,14 +233,14 @@ class AttributeDict(_AttributeDict): """ -def _lightning_get_all_attr_holders(model: "pl.LightningModule", attribute: str) -> List[Any]: +def _lightning_get_all_attr_holders(model: "pl.LightningModule", attribute: str) -> list[Any]: """Special attribute finding for Lightning. 
Gets all of the objects or dicts that holds attribute. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. """ - holders: List[Any] = [] + holders: list[Any] = [] # Check if attribute in model if hasattr(model, attribute): diff --git a/src/lightning/pytorch/utilities/seed.py b/src/lightning/pytorch/utilities/seed.py index 4ba9e7f0f960f..7250ba59366c2 100644 --- a/src/lightning/pytorch/utilities/seed.py +++ b/src/lightning/pytorch/utilities/seed.py @@ -13,8 +13,8 @@ # limitations under the License. """Utilities to help with reproducibility of models.""" +from collections.abc import Generator from contextlib import contextmanager -from typing import Generator from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states diff --git a/src/lightning/pytorch/utilities/testing/_runif.py b/src/lightning/pytorch/utilities/testing/_runif.py index 03b3afd61b875..9c46913681143 100644 --- a/src/lightning/pytorch/utilities/testing/_runif.py +++ b/src/lightning/pytorch/utilities/testing/_runif.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional, Tuple +from typing import Optional from lightning_utilities.core.imports import RequirementCache @@ -42,7 +42,7 @@ def _runif_reasons( psutil: bool = False, sklearn: bool = False, onnx: bool = False, -) -> Tuple[List[str], Dict[str, bool]]: +) -> tuple[list[str], dict[str, bool]]: """Construct reasons for pytest skipif. Args: diff --git a/src/lightning/pytorch/utilities/types.py b/src/lightning/pytorch/utilities/types.py index c1b971e924a52..8fccfa79c976a 100644 --- a/src/lightning/pytorch/utilities/types.py +++ b/src/lightning/pytorch/utilities/types.py @@ -17,19 +17,13 @@ - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`) """ +from collections.abc import Generator, Iterator, Mapping, Sequence from contextlib import contextmanager from dataclasses import dataclass from typing import ( Any, - Generator, - Iterator, - List, - Mapping, Optional, Protocol, - Sequence, - Tuple, - Type, TypedDict, Union, runtime_checkable, @@ -47,8 +41,8 @@ _NUMBER = Union[int, float] _METRIC = Union[Metric, Tensor, _NUMBER] STEP_OUTPUT = Optional[Union[Tensor, Mapping[str, Any]]] -_EVALUATE_OUTPUT = List[Mapping[str, float]] # 1 dict per DataLoader -_PREDICT_OUTPUT = Union[List[Any], List[List[Any]]] +_EVALUATE_OUTPUT = list[Mapping[str, float]] # 1 dict per DataLoader +_PREDICT_OUTPUT = Union[list[Any], list[list[Any]]] TRAIN_DATALOADERS = Any # any iterable or collection of iterables EVAL_DATALOADERS = Any # any iterable or collection of iterables @@ -60,7 +54,7 @@ class DistributedDataParallel(Protocol): def __init__( self, module: torch.nn.Module, - device_ids: Optional[List[Union[int, torch.device]]] = None, + device_ids: Optional[list[Union[int, torch.device]]] = None, output_device: Optional[Union[int, torch.device]] = None, dim: int = 0, broadcast_buffers: bool = True, @@ -79,7 +73,7 @@ def no_sync(self) -> Generator: ... 
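The frame-inspection helpers retyped above (`collect_init_args`, `_get_init_args`) back `save_hyperparameters()`; a minimal sketch of the call they support, with a hypothetical module and arguments:

import torch
import lightning.pytorch as pl

class LitClassifier(pl.LightningModule):
    def __init__(self, hidden_dim: int = 128, lr: float = 1e-3):
        super().__init__()
        self.save_hyperparameters()  # collects hidden_dim and lr from the constructor frame
        self.layer = torch.nn.Linear(hidden_dim, 10)

model = LitClassifier(hidden_dim=64)
print(model.hparams.hidden_dim, model.hparams.lr)  # 64 0.001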
# todo: improve LRSchedulerType naming/typing LRSchedulerTypeTuple = (LRScheduler, ReduceLROnPlateau) LRSchedulerTypeUnion = Union[LRScheduler, ReduceLROnPlateau] -LRSchedulerType = Union[Type[LRScheduler], Type[ReduceLROnPlateau]] +LRSchedulerType = Union[type[LRScheduler], type[ReduceLROnPlateau]] LRSchedulerPLType = Union[LRScheduler, ReduceLROnPlateau] @@ -119,7 +113,7 @@ class OptimizerLRSchedulerConfig(TypedDict): Union[ Optimizer, Sequence[Optimizer], - Tuple[Sequence[Optimizer], Sequence[Union[LRSchedulerTypeUnion, LRSchedulerConfig]]], + tuple[Sequence[Optimizer], Sequence[Union[LRSchedulerTypeUnion, LRSchedulerConfig]]], OptimizerLRSchedulerConfig, Sequence[OptimizerLRSchedulerConfig], ] diff --git a/src/lightning/pytorch/utilities/upgrade_checkpoint.py b/src/lightning/pytorch/utilities/upgrade_checkpoint.py index 87ad6031f9f24..04cf000283d77 100644 --- a/src/lightning/pytorch/utilities/upgrade_checkpoint.py +++ b/src/lightning/pytorch/utilities/upgrade_checkpoint.py @@ -16,7 +16,6 @@ from argparse import ArgumentParser, Namespace from pathlib import Path from shutil import copyfile -from typing import List import torch from tqdm import tqdm @@ -29,7 +28,7 @@ def _upgrade(args: Namespace) -> None: path = Path(args.path).absolute() extension: str = args.extension if args.extension.startswith(".") else f".{args.extension}" - files: List[Path] = [] + files: list[Path] = [] if not path.exists(): _log.error( diff --git a/src/lightning_fabric/__setup__.py b/src/lightning_fabric/__setup__.py index 8fe0bc0937ef5..a55e1f2332f37 100644 --- a/src/lightning_fabric/__setup__.py +++ b/src/lightning_fabric/__setup__.py @@ -3,7 +3,7 @@ from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path from types import ModuleType -from typing import Any, Dict +from typing import Any from pkg_resources import parse_requirements from setuptools import find_packages @@ -29,7 +29,7 @@ def _load_assistant() -> ModuleType: return _load_py_module("assistant", location) -def _prepare_extras() -> Dict[str, Any]: +def _prepare_extras() -> dict[str, Any]: assistant = _load_assistant() # https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-extras # Define package extras. These are only installed if you specify them. @@ -49,7 +49,7 @@ def _prepare_extras() -> Dict[str, Any]: return extras -def _setup_args() -> Dict[str, Any]: +def _setup_args() -> dict[str, Any]: assistant = _load_assistant() about = _load_py_module("about", os.path.join(_PACKAGE_ROOT, "__about__.py")) version = _load_py_module("version", os.path.join(_PACKAGE_ROOT, "__version__.py")) @@ -73,7 +73,7 @@ def _setup_args() -> Dict[str, Any]: "include_package_data": True, "zip_safe": False, "keywords": ["deep learning", "pytorch", "AI"], - "python_requires": ">=3.8", + "python_requires": ">=3.9", "setup_requires": ["wheel"], "install_requires": assistant.load_requirements( _PATH_REQUIREMENTS, unfreeze="none" if _FREEZE_REQUIREMENTS else "all" @@ -105,7 +105,6 @@ def _setup_args() -> Dict[str, Any]: # Specify the Python versions you support here. In particular, ensure # that you indicate whether you support Python 2, Python 3 or both. 
"Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", diff --git a/src/pytorch_lightning/__setup__.py b/src/pytorch_lightning/__setup__.py index 7eedace6cac93..6677b469ba1de 100644 --- a/src/pytorch_lightning/__setup__.py +++ b/src/pytorch_lightning/__setup__.py @@ -3,7 +3,7 @@ from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path from types import ModuleType -from typing import Any, Dict +from typing import Any from pkg_resources import parse_requirements from setuptools import find_packages @@ -29,7 +29,7 @@ def _load_assistant() -> ModuleType: return _load_py_module("assistant", location) -def _prepare_extras() -> Dict[str, Any]: +def _prepare_extras() -> dict[str, Any]: assistant = _load_assistant() # https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-extras # Define package extras. These are only installed if you specify them. @@ -49,7 +49,7 @@ def _prepare_extras() -> Dict[str, Any]: return extras -def _setup_args() -> Dict[str, Any]: +def _setup_args() -> dict[str, Any]: assistant = _load_assistant() about = _load_py_module("about", os.path.join(_PACKAGE_ROOT, "__about__.py")) version = _load_py_module("version", os.path.join(_PACKAGE_ROOT, "__version__.py")) @@ -80,7 +80,7 @@ def _setup_args() -> Dict[str, Any]: "long_description_content_type": "text/markdown", "zip_safe": False, "keywords": ["deep learning", "pytorch", "AI"], - "python_requires": ">=3.8", + "python_requires": ">=3.9", "setup_requires": ["wheel"], # TODO: aggregate pytorch and lite requirements as we include its source code directly in this package. # this is not a problem yet because lite's base requirements are all included in pytorch's base requirements @@ -107,7 +107,6 @@ def _setup_args() -> Dict[str, Any]: "Operating System :: OS Independent", # Specify the Python versions you support here. 
"Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", diff --git a/tests/tests_fabric/accelerators/test_cuda.py b/tests/tests_fabric/accelerators/test_cuda.py index e323ada908cd1..0aed3675d93e1 100644 --- a/tests/tests_fabric/accelerators/test_cuda.py +++ b/tests/tests_fabric/accelerators/test_cuda.py @@ -121,27 +121,32 @@ def test_tf32_message(_, __, ___, caplog, monkeypatch): def test_find_usable_cuda_devices_error_handling(): """Test error handling for edge cases when using `find_usable_cuda_devices`.""" # Asking for GPUs if no GPUs visible - with mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=0), pytest.raises( - ValueError, match="You requested to find 2 devices but there are no visible CUDA" + with ( + mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=0), + pytest.raises(ValueError, match="You requested to find 2 devices but there are no visible CUDA"), ): find_usable_cuda_devices(2) # Asking for more GPUs than are visible - with mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=1), pytest.raises( - ValueError, match="this machine only has 1 GPUs" + with ( + mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=1), + pytest.raises(ValueError, match="this machine only has 1 GPUs"), ): find_usable_cuda_devices(2) # All GPUs are unusable tensor_mock = Mock(side_effect=RuntimeError) # simulate device placement fails - with mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2), mock.patch( - "lightning.fabric.accelerators.cuda.torch.tensor", tensor_mock - ), pytest.raises(RuntimeError, match=escape("The devices [0, 1] are occupied by other processes")): + with ( + mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2), + mock.patch("lightning.fabric.accelerators.cuda.torch.tensor", tensor_mock), + pytest.raises(RuntimeError, match=escape("The devices [0, 1] are occupied by other processes")), + ): find_usable_cuda_devices(2) # Request for as many GPUs as there are, no error should be raised - with mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=5), mock.patch( - "lightning.fabric.accelerators.cuda.torch.tensor" + with ( + mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=5), + mock.patch("lightning.fabric.accelerators.cuda.torch.tensor"), ): assert find_usable_cuda_devices(-1) == [0, 1, 2, 3, 4] diff --git a/tests/tests_fabric/accelerators/test_registry.py b/tests/tests_fabric/accelerators/test_registry.py index e8f39b6e83406..2544df1e01ff8 100644 --- a/tests/tests_fabric/accelerators/test_registry.py +++ b/tests/tests_fabric/accelerators/test_registry.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict +from typing import Any import torch from lightning.fabric.accelerators import ACCELERATOR_REGISTRY, Accelerator @@ -30,7 +30,7 @@ def __init__(self, param1, param2): def setup_device(self, device: torch.device) -> None: pass - def get_device_stats(self, device: torch.device) -> Dict[str, Any]: + def get_device_stats(self, device: torch.device) -> dict[str, Any]: pass def teardown(self) -> None: diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py index 446994167d0a1..889decd517b12 100644 --- a/tests/tests_fabric/conftest.py +++ b/tests/tests_fabric/conftest.py @@ -14,8 +14,8 @@ import os import sys import threading +from concurrent.futures.process import _ExecutorManagerThread from pathlib import Path -from typing import List from unittest.mock import Mock import lightning.fabric @@ -25,9 +25,6 @@ from lightning.fabric.strategies.launchers.subprocess_script import _ChildProcessObserver from lightning.fabric.utilities.distributed import _destroy_dist_connection -if sys.version_info >= (3, 9): - from concurrent.futures.process import _ExecutorManagerThread - @pytest.fixture(autouse=True) def preserve_global_rank_variable(): @@ -200,7 +197,7 @@ def leave_no_artifacts_behind(): assert not difference, f"Test left artifacts behind: {difference}" -def pytest_collection_modifyitems(items: List[pytest.Function], config: pytest.Config) -> None: +def pytest_collection_modifyitems(items: list[pytest.Function], config: pytest.Config) -> None: """An adaptation of `tests/tests_pytorch/conftest.py::pytest_collection_modifyitems`""" initial_size = len(items) conditions = [] diff --git a/tests/tests_fabric/helpers/datasets.py b/tests/tests_fabric/helpers/datasets.py index 211e1f36a9ab5..ee14b21dc546c 100644 --- a/tests/tests_fabric/helpers/datasets.py +++ b/tests/tests_fabric/helpers/datasets.py @@ -1,4 +1,4 @@ -from typing import Iterator +from collections.abc import Iterator import torch from torch import Tensor diff --git a/tests/tests_fabric/plugins/collectives/test_torch_collective.py b/tests/tests_fabric/plugins/collectives/test_torch_collective.py index c8deb9d335161..b4c223e770282 100644 --- a/tests/tests_fabric/plugins/collectives/test_torch_collective.py +++ b/tests/tests_fabric/plugins/collectives/test_torch_collective.py @@ -29,13 +29,16 @@ @contextlib.contextmanager def check_destroy_group(): - with mock.patch( - "lightning.fabric.plugins.collectives.torch_collective.TorchCollective.new_group", - wraps=TorchCollective.new_group, - ) as mock_new, mock.patch( - "lightning.fabric.plugins.collectives.torch_collective.TorchCollective.destroy_group", - wraps=TorchCollective.destroy_group, - ) as mock_destroy: + with ( + mock.patch( + "lightning.fabric.plugins.collectives.torch_collective.TorchCollective.new_group", + wraps=TorchCollective.new_group, + ) as mock_new, + mock.patch( + "lightning.fabric.plugins.collectives.torch_collective.TorchCollective.destroy_group", + wraps=TorchCollective.destroy_group, + ) as mock_destroy, + ): yield # 0 to account for tests that mock distributed # -1 to account for destroying the default process group @@ -155,9 +158,10 @@ def test_repeated_create_and_destroy(): with pytest.raises(RuntimeError, match="TorchCollective` already owns a group"): collective.create_group() - with mock.patch.dict("torch.distributed.distributed_c10d._pg_map", {collective.group: ("", None)}), mock.patch( - "torch.distributed.destroy_process_group" - ) as destroy_mock: + with ( + 
mock.patch.dict("torch.distributed.distributed_c10d._pg_map", {collective.group: ("", None)}), + mock.patch("torch.distributed.destroy_process_group") as destroy_mock, + ): collective.teardown() # this would be called twice if `init_process_group` wasn't patched. once for the group and once for the default # group @@ -300,9 +304,11 @@ def test_collective_manages_default_group(): assert TorchCollective.manages_default_group - with mock.patch.object(collective, "_group") as mock_group, mock.patch.dict( - "torch.distributed.distributed_c10d._pg_map", {mock_group: ("", None)} - ), mock.patch("torch.distributed.destroy_process_group") as destroy_mock: + with ( + mock.patch.object(collective, "_group") as mock_group, + mock.patch.dict("torch.distributed.distributed_c10d._pg_map", {mock_group: ("", None)}), + mock.patch("torch.distributed.destroy_process_group") as destroy_mock, + ): collective.teardown() destroy_mock.assert_called_once_with(mock_group) diff --git a/tests/tests_fabric/plugins/environments/test_lsf.py b/tests/tests_fabric/plugins/environments/test_lsf.py index b444f6fc4d781..4e60d968dc953 100644 --- a/tests/tests_fabric/plugins/environments/test_lsf.py +++ b/tests/tests_fabric/plugins/environments/test_lsf.py @@ -41,8 +41,9 @@ def test_empty_lsb_djob_rankfile(): def test_missing_lsb_job_id(tmp_path): """Test an error when the job id cannot be found.""" - with mock.patch.dict(os.environ, {"LSB_DJOB_RANKFILE": _make_rankfile(tmp_path)}), pytest.raises( - ValueError, match="Could not find job id in environment variable LSB_JOBID" + with ( + mock.patch.dict(os.environ, {"LSB_DJOB_RANKFILE": _make_rankfile(tmp_path)}), + pytest.raises(ValueError, match="Could not find job id in environment variable LSB_JOBID"), ): LSFEnvironment() diff --git a/tests/tests_fabric/plugins/environments/test_slurm.py b/tests/tests_fabric/plugins/environments/test_slurm.py index 73457ede41298..f237478a533f4 100644 --- a/tests/tests_fabric/plugins/environments/test_slurm.py +++ b/tests/tests_fabric/plugins/environments/test_slurm.py @@ -155,8 +155,9 @@ def test_srun_variable_validation(): """Test that we raise useful errors when `srun` variables are misconfigured.""" with mock.patch.dict(os.environ, {"SLURM_NTASKS": "1"}): SLURMEnvironment() - with mock.patch.dict(os.environ, {"SLURM_NTASKS": "2"}), pytest.raises( - RuntimeError, match="You set `--ntasks=2` in your SLURM" + with ( + mock.patch.dict(os.environ, {"SLURM_NTASKS": "2"}), + pytest.raises(RuntimeError, match="You set `--ntasks=2` in your SLURM"), ): SLURMEnvironment() diff --git a/tests/tests_fabric/strategies/launchers/test_multiprocessing.py b/tests/tests_fabric/strategies/launchers/test_multiprocessing.py index b63d443098eac..6c595fba7acab 100644 --- a/tests/tests_fabric/strategies/launchers/test_multiprocessing.py +++ b/tests/tests_fabric/strategies/launchers/test_multiprocessing.py @@ -100,8 +100,11 @@ def test_check_for_bad_cuda_fork(mp_mock, _, start_method): def test_check_for_missing_main_guard(): launcher = _MultiProcessingLauncher(strategy=Mock(), start_method="spawn") - with mock.patch( - "lightning.fabric.strategies.launchers.multiprocessing.mp.current_process", - return_value=Mock(_inheriting=True), # pretend that main is importing itself - ), pytest.raises(RuntimeError, match="requires that your script guards the main"): + with ( + mock.patch( + "lightning.fabric.strategies.launchers.multiprocessing.mp.current_process", + return_value=Mock(_inheriting=True), # pretend that main is importing itself + ), + pytest.raises(RuntimeError, 
match="requires that your script guards the main"), + ): launcher.launch(function=Mock()) diff --git a/tests/tests_fabric/strategies/test_ddp.py b/tests/tests_fabric/strategies/test_ddp.py index 56d9875dfefed..b98d5f8226dc2 100644 --- a/tests/tests_fabric/strategies/test_ddp.py +++ b/tests/tests_fabric/strategies/test_ddp.py @@ -58,9 +58,12 @@ def test_ddp_no_backward_sync(): strategy = DDPStrategy() assert isinstance(strategy._backward_sync_control, _DDPBackwardSyncControl) - with pytest.raises( - TypeError, match="is only possible if the module passed to .* is wrapped in `DistributedDataParallel`" - ), strategy._backward_sync_control.no_backward_sync(Mock(), True): + with ( + pytest.raises( + TypeError, match="is only possible if the module passed to .* is wrapped in `DistributedDataParallel`" + ), + strategy._backward_sync_control.no_backward_sync(Mock(), True), + ): pass module = MagicMock(spec=DistributedDataParallel) diff --git a/tests/tests_fabric/strategies/test_deepspeed_integration.py b/tests/tests_fabric/strategies/test_deepspeed_integration.py index 3be535effa078..4811599ed05ab 100644 --- a/tests/tests_fabric/strategies/test_deepspeed_integration.py +++ b/tests/tests_fabric/strategies/test_deepspeed_integration.py @@ -404,9 +404,11 @@ def test_deepspeed_init_module_with_stages_1_2(stage, empty_init): fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy, precision="bf16-true") fabric.launch() - with mock.patch("deepspeed.zero.Init") as zero_init_mock, mock.patch( - "torch.Tensor.uniform_" - ) as init_mock, fabric.init_module(empty_init=empty_init): + with ( + mock.patch("deepspeed.zero.Init") as zero_init_mock, + mock.patch("torch.Tensor.uniform_") as init_mock, + fabric.init_module(empty_init=empty_init), + ): model = BoringModel() zero_init_mock.assert_called_with(enabled=False, remote_device=None, config_dict_or_path=ANY) diff --git a/tests/tests_fabric/strategies/test_fsdp.py b/tests/tests_fabric/strategies/test_fsdp.py index 0c46e7ac1763c..cb6542cdb6243 100644 --- a/tests/tests_fabric/strategies/test_fsdp.py +++ b/tests/tests_fabric/strategies/test_fsdp.py @@ -133,9 +133,12 @@ def test_no_backward_sync(): strategy = FSDPStrategy() assert isinstance(strategy._backward_sync_control, _FSDPBackwardSyncControl) - with pytest.raises( - TypeError, match="is only possible if the module passed to .* is wrapped in `FullyShardedDataParallel`" - ), strategy._backward_sync_control.no_backward_sync(Mock(), True): + with ( + pytest.raises( + TypeError, match="is only possible if the module passed to .* is wrapped in `FullyShardedDataParallel`" + ), + strategy._backward_sync_control.no_backward_sync(Mock(), True), + ): pass module = MagicMock(spec=FullyShardedDataParallel) @@ -172,9 +175,12 @@ def __init__(self): assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy) strategy._parallel_devices = [torch.device("cuda", 0)] - with mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock), mock.patch( - "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing" - ) as apply_mock: + with ( + mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock), + mock.patch( + "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing" + ) as apply_mock, + ): wrapped = strategy.setup_module(Model()) apply_mock.assert_called_with(wrapped, checkpoint_wrapper_fn=ANY, **strategy._activation_checkpointing_kwargs) diff --git 
a/tests/tests_fabric/strategies/test_xla_fsdp.py b/tests/tests_fabric/strategies/test_xla_fsdp.py index e2864b684c4a7..879a55cf77f34 100644 --- a/tests/tests_fabric/strategies/test_xla_fsdp.py +++ b/tests/tests_fabric/strategies/test_xla_fsdp.py @@ -48,9 +48,12 @@ def test_xla_fsdp_no_backward_sync(): strategy = XLAFSDPStrategy() assert isinstance(strategy._backward_sync_control, _XLAFSDPBackwardSyncControl) - with pytest.raises( - TypeError, match="is only possible if the module passed to .* is wrapped in `XlaFullyShardedDataParallel`" - ), strategy._backward_sync_control.no_backward_sync(object(), True): + with ( + pytest.raises( + TypeError, match="is only possible if the module passed to .* is wrapped in `XlaFullyShardedDataParallel`" + ), + strategy._backward_sync_control.no_backward_sync(object(), True), + ): pass module = MagicMock(spec=XlaFullyShardedDataParallel) diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py index 08d6dbb45ed91..8a6e9206b3df5 100644 --- a/tests/tests_fabric/test_connector.py +++ b/tests/tests_fabric/test_connector.py @@ -15,7 +15,7 @@ import os import sys from contextlib import nullcontext -from typing import Any, Dict +from typing import Any from unittest import mock from unittest.mock import Mock @@ -165,7 +165,7 @@ class Accel(Accelerator): def setup_device(self, device: torch.device) -> None: pass - def get_device_stats(self, device: torch.device) -> Dict[str, Any]: + def get_device_stats(self, device: torch.device) -> dict[str, Any]: pass def teardown(self) -> None: @@ -960,28 +960,33 @@ def test_arguments_from_environment_collision(): with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu"}): _Connector(accelerator="cuda") - with mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_CLI_USED": "1"}), pytest.raises( - ValueError, match="`Fabric\\(accelerator='cuda', ...\\)` but .* `--accelerator=cpu`" + with ( + mock.patch.dict(os.environ, {"LT_ACCELERATOR": "cpu", "LT_CLI_USED": "1"}), + pytest.raises(ValueError, match="`Fabric\\(accelerator='cuda', ...\\)` but .* `--accelerator=cpu`"), ): _Connector(accelerator="cuda") - with mock.patch.dict(os.environ, {"LT_STRATEGY": "ddp", "LT_CLI_USED": "1"}), pytest.raises( - ValueError, match="`Fabric\\(strategy='ddp_spawn', ...\\)` but .* `--strategy=ddp`" + with ( + mock.patch.dict(os.environ, {"LT_STRATEGY": "ddp", "LT_CLI_USED": "1"}), + pytest.raises(ValueError, match="`Fabric\\(strategy='ddp_spawn', ...\\)` but .* `--strategy=ddp`"), ): _Connector(strategy="ddp_spawn") - with mock.patch.dict(os.environ, {"LT_DEVICES": "2", "LT_CLI_USED": "1"}), pytest.raises( - ValueError, match="`Fabric\\(devices=3, ...\\)` but .* `--devices=2`" + with ( + mock.patch.dict(os.environ, {"LT_DEVICES": "2", "LT_CLI_USED": "1"}), + pytest.raises(ValueError, match="`Fabric\\(devices=3, ...\\)` but .* `--devices=2`"), ): _Connector(devices=3) - with mock.patch.dict(os.environ, {"LT_NUM_NODES": "3", "LT_CLI_USED": "1"}), pytest.raises( - ValueError, match="`Fabric\\(num_nodes=2, ...\\)` but .* `--num_nodes=3`" + with ( + mock.patch.dict(os.environ, {"LT_NUM_NODES": "3", "LT_CLI_USED": "1"}), + pytest.raises(ValueError, match="`Fabric\\(num_nodes=2, ...\\)` but .* `--num_nodes=3`"), ): _Connector(num_nodes=2) - with mock.patch.dict(os.environ, {"LT_PRECISION": "16-mixed", "LT_CLI_USED": "1"}), pytest.raises( - ValueError, match="`Fabric\\(precision='64-true', ...\\)` but .* `--precision=16-mixed`" + with ( + mock.patch.dict(os.environ, {"LT_PRECISION": "16-mixed", "LT_CLI_USED": "1"}), + 
pytest.raises(ValueError, match="`Fabric\\(precision='64-true', ...\\)` but .* `--precision=16-mixed`"), ): _Connector(precision="64-true") diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py index 70d04d5431404..7bb6b29eceaf2 100644 --- a/tests/tests_fabric/test_fabric.py +++ b/tests/tests_fabric/test_fabric.py @@ -746,9 +746,10 @@ def test_no_backward_sync(): # pretend that the strategy does not support skipping backward sync fabric._strategy = Mock(spec=ParallelStrategy, _backward_sync_control=None) - with pytest.warns( - PossibleUserWarning, match="The `ParallelStrategy` does not support skipping the" - ), fabric.no_backward_sync(model): + with ( + pytest.warns(PossibleUserWarning, match="The `ParallelStrategy` does not support skipping the"), + fabric.no_backward_sync(model), + ): pass # for single-device strategies, it becomes a no-op without warning diff --git a/tests/tests_fabric/utilities/test_consolidate_checkpoint.py b/tests/tests_fabric/utilities/test_consolidate_checkpoint.py index 216b77e6b9299..2584aab8bdc2e 100644 --- a/tests/tests_fabric/utilities/test_consolidate_checkpoint.py +++ b/tests/tests_fabric/utilities/test_consolidate_checkpoint.py @@ -41,8 +41,9 @@ def test_parse_cli_args(args, expected): def test_process_cli_args(tmp_path, caplog, monkeypatch): # PyTorch version < 2.3 monkeypatch.setattr(lightning.fabric.utilities.consolidate_checkpoint, "_TORCH_GREATER_EQUAL_2_3", False) - with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises( - SystemExit + with ( + caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), + pytest.raises(SystemExit), ): _process_cli_args(Namespace()) assert "requires PyTorch >= 2.3." 
in caplog.text @@ -51,8 +52,9 @@ def test_process_cli_args(tmp_path, caplog, monkeypatch): # Checkpoint does not exist checkpoint_folder = Path("does/not/exist") - with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises( - SystemExit + with ( + caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), + pytest.raises(SystemExit), ): _process_cli_args(Namespace(checkpoint_folder=checkpoint_folder)) assert f"checkpoint folder does not exist: {checkpoint_folder}" in caplog.text @@ -61,8 +63,9 @@ def test_process_cli_args(tmp_path, caplog, monkeypatch): # Checkpoint exists but is not a folder file = tmp_path / "checkpoint_file" file.touch() - with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises( - SystemExit + with ( + caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), + pytest.raises(SystemExit), ): _process_cli_args(Namespace(checkpoint_folder=file)) assert "checkpoint path must be a folder" in caplog.text @@ -71,8 +74,9 @@ def test_process_cli_args(tmp_path, caplog, monkeypatch): # Checkpoint exists but is not an FSDP checkpoint folder = tmp_path / "checkpoint_folder" folder.mkdir() - with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises( - SystemExit + with ( + caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), + pytest.raises(SystemExit), ): _process_cli_args(Namespace(checkpoint_folder=folder)) assert "Only FSDP-sharded checkpoints saved with Lightning are supported" in caplog.text @@ -89,8 +93,9 @@ def test_process_cli_args(tmp_path, caplog, monkeypatch): # Checkpoint is a FSDP folder, output file already exists file = tmp_path / "ouput_file" file.touch() - with caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), pytest.raises( - SystemExit + with ( + caplog.at_level(logging.ERROR, logger="lightning.fabric.utilities.consolidate_checkpoint"), + pytest.raises(SystemExit), ): _process_cli_args(Namespace(checkpoint_folder=folder, output_file=file)) assert "path for the converted checkpoint already exists" in caplog.text diff --git a/tests/tests_fabric/utilities/test_distributed.py b/tests/tests_fabric/utilities/test_distributed.py index cc6c23bddbd7b..f5a78a1529a52 100644 --- a/tests/tests_fabric/utilities/test_distributed.py +++ b/tests/tests_fabric/utilities/test_distributed.py @@ -215,9 +215,10 @@ def test_infinite_barrier(): # distributed available barrier = _InfiniteBarrier() - with mock.patch( - "lightning.fabric.utilities.distributed._distributed_is_initialized", return_value=True - ), mock.patch("lightning.fabric.utilities.distributed.torch.distributed") as dist_mock: + with ( + mock.patch("lightning.fabric.utilities.distributed._distributed_is_initialized", return_value=True), + mock.patch("lightning.fabric.utilities.distributed.torch.distributed") as dist_mock, + ): barrier.__enter__() dist_mock.new_group.assert_called_once() assert barrier.barrier == barrier.group.monitored_barrier diff --git a/tests/tests_fabric/utilities/test_throughput.py b/tests/tests_fabric/utilities/test_throughput.py index eefadb285af02..d410d0766d97b 100644 --- a/tests/tests_fabric/utilities/test_throughput.py +++ b/tests/tests_fabric/utilities/test_throughput.py @@ -39,8 +39,9 @@ def test_get_available_flops(xla_available): with pytest.warns(match="not found for 'CocoNut"), 
mock.patch("torch.cuda.get_device_name", return_value="CocoNut"): assert get_available_flops(torch.device("cuda"), torch.bfloat16) is None - with pytest.warns(match="t4' does not support torch.bfloat"), mock.patch( - "torch.cuda.get_device_name", return_value="t4" + with ( + pytest.warns(match="t4' does not support torch.bfloat"), + mock.patch("torch.cuda.get_device_name", return_value="t4"), ): assert get_available_flops(torch.device("cuda"), torch.bfloat16) is None diff --git a/tests/tests_pytorch/accelerators/test_common.py b/tests/tests_pytorch/accelerators/test_common.py index 7654125af3e44..6967bffd9ffa2 100644 --- a/tests/tests_pytorch/accelerators/test_common.py +++ b/tests/tests_pytorch/accelerators/test_common.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any import torch from lightning.pytorch import Trainer @@ -24,7 +24,7 @@ class TestAccelerator(Accelerator): def setup_device(self, device: torch.device) -> None: pass - def get_device_stats(self, device: torch.device) -> Dict[str, Any]: + def get_device_stats(self, device: torch.device) -> dict[str, Any]: pass def teardown(self) -> None: diff --git a/tests/tests_pytorch/accelerators/test_cpu.py b/tests/tests_pytorch/accelerators/test_cpu.py index cd34fe3f2b318..844556064621d 100644 --- a/tests/tests_pytorch/accelerators/test_cpu.py +++ b/tests/tests_pytorch/accelerators/test_cpu.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Any, Dict, Union +from typing import Any, Union from unittest.mock import Mock import lightning.pytorch as pl @@ -53,7 +53,7 @@ def setup(self, trainer: "pl.Trainer") -> None: def restore_checkpoint_after_setup(self) -> bool: return restore_after_pre_setup - def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: + def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> dict[str, Any]: assert self.setup_called == restore_after_pre_setup return super().load_checkpoint(checkpoint_path) diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py index b8d3d6d36c075..89c1effe839a8 100644 --- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py @@ -141,9 +141,12 @@ def on_train_start(self) -> None: model = TestModel() - with mock.patch( - "lightning.pytorch.callbacks.progress.rich_progress.Progress.stop", autospec=True - ) as mock_progress_stop, pytest.raises(SystemExit): + with ( + mock.patch( + "lightning.pytorch.callbacks.progress.rich_progress.Progress.stop", autospec=True + ) as mock_progress_stop, + pytest.raises(SystemExit), + ): progress_bar = RichProgressBar() trainer = Trainer( default_root_dir=tmp_path, diff --git a/tests/tests_pytorch/callbacks/test_device_stats_monitor.py b/tests/tests_pytorch/callbacks/test_device_stats_monitor.py index aacee958faa45..f1d999f1df61a 100644 --- a/tests/tests_pytorch/callbacks/test_device_stats_monitor.py +++ b/tests/tests_pytorch/callbacks/test_device_stats_monitor.py @@ -14,7 +14,7 @@ import csv import os import re -from typing import Dict, Optional +from typing import Optional from unittest import mock from unittest.mock import Mock @@ -40,7 +40,7 @@ def test_device_stats_gpu_from_torch(tmp_path): class DebugLogger(CSVLogger): 
@rank_zero_only - def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: + def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None) -> None: fields = [ "allocated_bytes.all.freed", "inactive_split.all.peak", @@ -74,7 +74,7 @@ def test_device_stats_cpu(cpu_stats_mock, tmp_path, cpu_stats): CPU_METRIC_KEYS = (_CPU_VM_PERCENT, _CPU_SWAP_PERCENT, _CPU_PERCENT) class DebugLogger(CSVLogger): - def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: + def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None) -> None: enabled = cpu_stats is not False for f in CPU_METRIC_KEYS: has_cpu_metrics = any(f in h for h in metrics) diff --git a/tests/tests_pytorch/callbacks/test_early_stopping.py b/tests/tests_pytorch/callbacks/test_early_stopping.py index 75f331a9401c7..a3d56bb0135c3 100644 --- a/tests/tests_pytorch/callbacks/test_early_stopping.py +++ b/tests/tests_pytorch/callbacks/test_early_stopping.py @@ -15,7 +15,7 @@ import math import os import pickle -from typing import List, Optional +from typing import Optional from unittest import mock from unittest.mock import Mock @@ -407,7 +407,7 @@ def on_train_end(self) -> None: ) def test_multiple_early_stopping_callbacks( tmp_path, - callbacks: List[EarlyStopping], + callbacks: list[EarlyStopping], expected_stop_epoch: int, check_on_train_epoch_end: bool, strategy: str, diff --git a/tests/tests_pytorch/callbacks/test_model_summary.py b/tests/tests_pytorch/callbacks/test_model_summary.py index b42907dc9a38d..215176ee2376b 100644 --- a/tests/tests_pytorch/callbacks/test_model_summary.py +++ b/tests/tests_pytorch/callbacks/test_model_summary.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Tuple +from typing import Any from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelSummary @@ -45,7 +45,7 @@ def test_custom_model_summary_callback_summarize(tmp_path): class CustomModelSummary(ModelSummary): @staticmethod def summarize( - summary_data: List[Tuple[str, List[str]]], + summary_data: list[tuple[str, list[str]]], total_parameters: int, trainable_parameters: int, model_size: float, diff --git a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py index d57ac76f04400..8d3a1800e1fa2 100644 --- a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py +++ b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py @@ -13,8 +13,9 @@ # limitations under the License. 
import logging import os +from contextlib import AbstractContextManager from pathlib import Path -from typing import ContextManager, Optional +from typing import Optional from unittest import mock import pytest @@ -382,5 +383,5 @@ def test_misconfiguration_error_with_sharded_model(tmp_path, strategy: str): trainer.fit(model) -def _backward_patch(trainer: Trainer) -> ContextManager: +def _backward_patch(trainer: Trainer) -> AbstractContextManager: return mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward) diff --git a/tests/tests_pytorch/callbacks/test_throughput_monitor.py b/tests/tests_pytorch/callbacks/test_throughput_monitor.py index a74efba75813b..4867134a85642 100644 --- a/tests/tests_pytorch/callbacks/test_throughput_monitor.py +++ b/tests/tests_pytorch/callbacks/test_throughput_monitor.py @@ -43,8 +43,9 @@ def test_throughput_monitor_fit(tmp_path): ) # these timing results are meant to precisely match the `test_throughput_monitor` test in fabric timings = [0.0] + [0.5 + i for i in range(1, 6)] - with mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100), mock.patch( - "time.perf_counter", side_effect=timings + with ( + mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100), + mock.patch("time.perf_counter", side_effect=timings), ): trainer.fit(model) @@ -179,8 +180,9 @@ def test_throughput_monitor_fit_gradient_accumulation(log_every_n_steps, tmp_pat enable_progress_bar=False, ) timings = [0.0] + [0.5 + i for i in range(1, 11)] - with mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100), mock.patch( - "time.perf_counter", side_effect=timings + with ( + mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100), + mock.patch("time.perf_counter", side_effect=timings), ): trainer.fit(model) diff --git a/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py b/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py index 3ae7d6be4995e..c07400eaf8446 100644 --- a/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py @@ -92,9 +92,10 @@ def test_trainer_save_checkpoint_storage_options(tmp_path, xla_available): io_mock.assert_called_with(ANY, instance_path, storage_options=None) checkpoint_mock = Mock() - with mock.patch.object(trainer.strategy, "save_checkpoint") as save_mock, mock.patch.object( - trainer._checkpoint_connector, "dump_checkpoint", return_value=checkpoint_mock - ) as dump_mock: + with ( + mock.patch.object(trainer.strategy, "save_checkpoint") as save_mock, + mock.patch.object(trainer._checkpoint_connector, "dump_checkpoint", return_value=checkpoint_mock) as dump_mock, + ): trainer.save_checkpoint(instance_path, True) dump_mock.assert_called_with(True) save_mock.assert_called_with(checkpoint_mock, instance_path, storage_options=None) diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py index 78e81c7c5fa26..ea5207516cad1 100644 --- a/tests/tests_pytorch/conftest.py +++ b/tests/tests_pytorch/conftest.py @@ -15,10 +15,10 @@ import signal import sys import threading +from concurrent.futures.process import _ExecutorManagerThread from functools import partial from http.server import SimpleHTTPRequestHandler from pathlib import Path -from typing import List from unittest.mock import Mock import lightning.fabric @@ -35,9 +35,6 @@ from tests_pytorch import _PATH_DATASETS -if 
sys.version_info >= (3, 9): - from concurrent.futures.process import _ExecutorManagerThread - @pytest.fixture(scope="session") def datadir(): @@ -323,7 +320,7 @@ def leave_no_artifacts_behind(): assert not difference, f"Test left artifacts behind: {difference}" -def pytest_collection_modifyitems(items: List[pytest.Function], config: pytest.Config) -> None: +def pytest_collection_modifyitems(items: list[pytest.Function], config: pytest.Config) -> None: initial_size = len(items) conditions = [] filtered, skipped = 0, 0 diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py index 65fccb691a33d..5f468156be716 100644 --- a/tests/tests_pytorch/core/test_datamodules.py +++ b/tests/tests_pytorch/core/test_datamodules.py @@ -14,7 +14,7 @@ import pickle from argparse import Namespace from dataclasses import dataclass -from typing import Any, Dict +from typing import Any from unittest import mock from unittest.mock import Mock, PropertyMock, call @@ -187,10 +187,10 @@ def validation_step(self, batch, batch_idx): return out class CustomBoringDataModule(BoringDataModule): - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: return {"my": "state_dict"} - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: self.my_state_dict = state_dict dm = CustomBoringDataModule() diff --git a/tests/tests_pytorch/core/test_lightning_optimizer.py b/tests/tests_pytorch/core/test_lightning_optimizer.py index 8ab6eca907ce6..b25b7ae648a3a 100644 --- a/tests/tests_pytorch/core/test_lightning_optimizer.py +++ b/tests/tests_pytorch/core/test_lightning_optimizer.py @@ -110,9 +110,10 @@ def configure_optimizers(self): default_root_dir=tmp_path, limit_train_batches=8, limit_val_batches=1, max_epochs=1, enable_model_summary=False ) - with patch.multiple(torch.optim.SGD, zero_grad=DEFAULT, step=DEFAULT) as sgd, patch.multiple( - torch.optim.Adam, zero_grad=DEFAULT, step=DEFAULT - ) as adam: + with ( + patch.multiple(torch.optim.SGD, zero_grad=DEFAULT, step=DEFAULT) as sgd, + patch.multiple(torch.optim.Adam, zero_grad=DEFAULT, step=DEFAULT) as adam, + ): trainer.fit(model) assert sgd["step"].call_count == 4 diff --git a/tests/tests_pytorch/core/test_metric_result_integration.py b/tests/tests_pytorch/core/test_metric_result_integration.py index 004d979fd1b18..dcb3f71c7499c 100644 --- a/tests/tests_pytorch/core/test_metric_result_integration.py +++ b/tests/tests_pytorch/core/test_metric_result_integration.py @@ -625,8 +625,9 @@ def test_logger_sync_dist(distributed_env, log_val): else nullcontext() ) - with warning_ctx( - PossibleUserWarning, match=r"recommended to use `self.log\('bar', ..., sync_dist=True\)`" - ), patch_ctx: + with ( + warning_ctx(PossibleUserWarning, match=r"recommended to use `self.log\('bar', ..., sync_dist=True\)`"), + patch_ctx, + ): value = _ResultCollection._get_cache(result_metric, on_step=False) assert value == 0.5 diff --git a/tests/tests_pytorch/helpers/datasets.py b/tests/tests_pytorch/helpers/datasets.py index 9b1d4ec7353cb..014fb374e5d5e 100644 --- a/tests/tests_pytorch/helpers/datasets.py +++ b/tests/tests_pytorch/helpers/datasets.py @@ -16,7 +16,8 @@ import random import time import urllib.request -from typing import Optional, Sequence, Tuple +from collections.abc import Sequence +from typing import Optional import torch from torch import Tensor @@ -63,7 +64,7 @@ def __init__( data_file = self.TRAIN_FILE_NAME if self.train else 
self.TEST_FILE_NAME self.data, self.targets = self._try_load(os.path.join(self.cached_folder_path, data_file)) - def __getitem__(self, idx: int) -> Tuple[Tensor, int]: + def __getitem__(self, idx: int) -> tuple[Tensor, int]: img = self.data[idx].float().unsqueeze(0) target = int(self.targets[idx]) diff --git a/tests/tests_pytorch/loggers/test_logger.py b/tests/tests_pytorch/loggers/test_logger.py index 4d74e046c590f..dcdd504fd4660 100644 --- a/tests/tests_pytorch/loggers/test_logger.py +++ b/tests/tests_pytorch/loggers/test_logger.py @@ -14,7 +14,7 @@ import pickle from argparse import Namespace from copy import deepcopy -from typing import Any, Dict, Optional +from typing import Any, Optional from unittest.mock import patch import numpy as np @@ -252,12 +252,12 @@ def __init__(self, param_one, param_two): @patch("lightning.pytorch.loggers.tensorboard.TensorBoardLogger.log_hyperparams") def test_log_hyperparams_key_collision(_, tmp_path): class TestModel(BoringModel): - def __init__(self, hparams: Dict[str, Any]) -> None: + def __init__(self, hparams: dict[str, Any]) -> None: super().__init__() self.save_hyperparameters(hparams) class TestDataModule(BoringDataModule): - def __init__(self, hparams: Dict[str, Any]) -> None: + def __init__(self, hparams: dict[str, Any]) -> None: super().__init__() self.save_hyperparameters(hparams) diff --git a/tests/tests_pytorch/loops/optimization/test_automatic_loop.py b/tests/tests_pytorch/loops/optimization/test_automatic_loop.py index 0ea6290586f55..2fb04d0d9d8d1 100644 --- a/tests/tests_pytorch/loops/optimization/test_automatic_loop.py +++ b/tests/tests_pytorch/loops/optimization/test_automatic_loop.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Iterator, Mapping from contextlib import nullcontext -from typing import Dict, Generic, Iterator, Mapping, TypeVar +from typing import Generic, TypeVar import pytest import torch @@ -49,8 +50,8 @@ def test_closure_result_apply_accumulation(): class OutputMapping(Generic[T], Mapping[str, T]): - def __init__(self, d: Dict[str, T]) -> None: - self.d: Dict[str, T] = d + def __init__(self, d: dict[str, T]) -> None: + self.d: dict[str, T] = d def __iter__(self) -> Iterator[str]: return iter(self.d) diff --git a/tests/tests_pytorch/loops/test_fetchers.py b/tests/tests_pytorch/loops/test_fetchers.py index 763a6ded14447..75b25e3d98fd8 100644 --- a/tests/tests_pytorch/loops/test_fetchers.py +++ b/tests/tests_pytorch/loops/test_fetchers.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import Counter -from typing import Any, Iterator +from collections.abc import Iterator +from typing import Any import pytest import torch diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index 8d94275b5b245..1820ca3568173 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os +from collections.abc import Iterator from copy import deepcopy from dataclasses import dataclass -from typing import Any, Dict, Iterator +from typing import Any from unittest.mock import ANY, Mock import pytest @@ -87,10 +88,10 @@ def advance(self) -> None: self.outputs.append(value) - def state_dict(self) -> Dict: + def state_dict(self) -> dict: return {"iteration_count": self.iteration_count, "outputs": self.outputs} - def load_state_dict(self, state_dict: Dict) -> None: + def load_state_dict(self, state_dict: dict) -> None: self.iteration_count = state_dict["iteration_count"] self.outputs = state_dict["outputs"] @@ -140,10 +141,10 @@ def advance(self) -> None: return loop.run() - def on_save_checkpoint(self) -> Dict: + def on_save_checkpoint(self) -> dict: return {"a": self.a} - def on_load_checkpoint(self, state_dict: Dict) -> None: + def on_load_checkpoint(self, state_dict: dict) -> None: self.a = state_dict["a"] trainer = Trainer() diff --git a/tests/tests_pytorch/models/test_restore.py b/tests/tests_pytorch/models/test_restore.py index fe7e3fbbab357..64f70b176a971 100644 --- a/tests/tests_pytorch/models/test_restore.py +++ b/tests/tests_pytorch/models/test_restore.py @@ -15,8 +15,9 @@ import logging as log import os import pickle +from collections.abc import Mapping from copy import deepcopy -from typing import Generic, Mapping, TypeVar +from typing import Generic, TypeVar import cloudpickle import pytest diff --git a/tests/tests_pytorch/overrides/test_distributed.py b/tests/tests_pytorch/overrides/test_distributed.py index 29eb6d6d6d511..3e2fba54bcd03 100644 --- a/tests/tests_pytorch/overrides/test_distributed.py +++ b/tests/tests_pytorch/overrides/test_distributed.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable +from collections.abc import Iterable import pytest import torch diff --git a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py index 185a767d9e8c9..58baa47e7a620 100644 --- a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py +++ b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py @@ -13,7 +13,7 @@ # limitations under the License. 
import os from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any, Optional from unittest.mock import MagicMock, Mock import torch @@ -27,10 +27,10 @@ class CustomCheckpointIO(CheckpointIO): - def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: + def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: torch.save(checkpoint, path) - def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> Dict[str, Any]: + def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> dict[str, Any]: return torch.load(path, weights_only=True) def remove_checkpoint(self, path: _PATH) -> None: diff --git a/tests/tests_pytorch/serve/test_servable_module_validator.py b/tests/tests_pytorch/serve/test_servable_module_validator.py index 7c883c419ea82..ec4dd8825c8ea 100644 --- a/tests/tests_pytorch/serve/test_servable_module_validator.py +++ b/tests/tests_pytorch/serve/test_servable_module_validator.py @@ -1,5 +1,3 @@ -from typing import Dict - import pytest import torch from lightning.pytorch import Trainer @@ -21,7 +19,7 @@ def serialize(x): return {"x": deserialize}, {"output": serialize} - def serve_step(self, x: Tensor) -> Dict[str, Tensor]: + def serve_step(self, x: Tensor) -> dict[str, Tensor]: assert torch.equal(x, torch.arange(32, dtype=torch.float)) return {"output": torch.tensor([0, 1])} diff --git a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py index b0462c0105a9f..394d827058987 100644 --- a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py +++ b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py @@ -209,10 +209,13 @@ def test_memory_sharing_disabled(tmp_path): def test_check_for_missing_main_guard(): launcher = _MultiProcessingLauncher(strategy=Mock(), start_method="spawn") - with mock.patch( - "lightning.pytorch.strategies.launchers.multiprocessing.mp.current_process", - return_value=Mock(_inheriting=True), # pretend that main is importing itself - ), pytest.raises(RuntimeError, match="requires that your script guards the main"): + with ( + mock.patch( + "lightning.pytorch.strategies.launchers.multiprocessing.mp.current_process", + return_value=Mock(_inheriting=True), # pretend that main is importing itself + ), + pytest.raises(RuntimeError, match="requires that your script guards the main"), + ): launcher.launch(function=Mock()) diff --git a/tests/tests_pytorch/strategies/test_custom_strategy.py b/tests/tests_pytorch/strategies/test_custom_strategy.py index 7f7d018f634e1..347dacbd9a811 100644 --- a/tests/tests_pytorch/strategies/test_custom_strategy.py +++ b/tests/tests_pytorch/strategies/test_custom_strategy.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any import pytest import torch diff --git a/tests/tests_pytorch/strategies/test_deepspeed.py b/tests/tests_pytorch/strategies/test_deepspeed.py index be9428ff7533c..73697ea131545 100644 --- a/tests/tests_pytorch/strategies/test_deepspeed.py +++ b/tests/tests_pytorch/strategies/test_deepspeed.py @@ -15,7 +15,7 @@ import json import os from re import escape -from typing import Any, Dict +from typing import Any from unittest import mock from unittest.mock import ANY, Mock @@ -48,7 +48,7 @@ def configure_model(self) -> None: if self.layer is None: self.layer = torch.nn.Linear(32, 2) - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: self.configure_model() @@ -73,7 +73,7 @@ def configure_model(self) -> None: if self.layer is None: self.layer = torch.nn.Linear(32, 2) - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: self.configure_model() @property @@ -623,7 +623,7 @@ def configure_optimizers(self): lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99) return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}] - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: if not hasattr(self, "model"): self.configure_model() diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index aec01b83e956a..2aee68f7ae733 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -444,9 +444,12 @@ def __init__(self): strategy._parallel_devices = [torch.device("cuda", 0)] strategy._lightning_module = model strategy._process_group = Mock() - with mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock), mock.patch( - "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing" - ) as apply_mock: + with ( + mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock), + mock.patch( + "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing" + ) as apply_mock, + ): wrapped = strategy._setup_model(model) apply_mock.assert_called_with(wrapped, checkpoint_wrapper_fn=ANY, **strategy._activation_checkpointing_kwargs) diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index cdec778afbfb8..4fc836c764833 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -20,7 +20,7 @@ from contextlib import ExitStack, contextmanager, redirect_stdout from io import StringIO from pathlib import Path -from typing import Callable, List, Optional, Union +from typing import Callable, Optional, Union from unittest import mock from unittest.mock import ANY @@ -127,7 +127,7 @@ def _model_builder(model_param: int) -> Model: def _trainer_builder( - limit_train_batches: int, fast_dev_run: bool = False, callbacks: Optional[Union[List[Callback], Callback]] = None + limit_train_batches: int, fast_dev_run: bool = False, callbacks: Optional[Union[list[Callback], Callback]] = None ) -> Trainer: return Trainer(limit_train_batches=limit_train_batches, fast_dev_run=fast_dev_run, callbacks=callbacks) @@ -409,8 +409,9 @@ def test_lightning_cli_config_and_subclass_mode(cleandir): with open(config_path, "w") as f: 
f.write(yaml.dump(input_config)) - with mock.patch("sys.argv", ["any.py", "--config", config_path]), mock_subclasses( - LightningDataModule, DataDirDataModule + with ( + mock.patch("sys.argv", ["any.py", "--config", config_path]), + mock_subclasses(LightningDataModule, DataDirDataModule), ): cli = LightningCLI( BoringModel, @@ -461,9 +462,12 @@ def test_lightning_cli_help(): cli_args = ["any.py", "fit", "--data.help=DataDirDataModule"] out = StringIO() - with mock.patch("sys.argv", cli_args), redirect_stdout(out), mock_subclasses( - LightningDataModule, DataDirDataModule - ), pytest.raises(SystemExit): + with ( + mock.patch("sys.argv", cli_args), + redirect_stdout(out), + mock_subclasses(LightningDataModule, DataDirDataModule), + pytest.raises(SystemExit), + ): any_model_any_data_cli() assert ("--data.data_dir" in out.getvalue()) or ("--data.init_args.data_dir" in out.getvalue()) @@ -522,7 +526,7 @@ def __init__(self, submodule1: LightningModule, submodule2: LightningModule, mai @pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason=str(_TORCHVISION_AVAILABLE)) def test_lightning_cli_torch_modules(cleandir): class TestModule(BoringModel): - def __init__(self, activation: torch.nn.Module = None, transform: Optional[List[torch.nn.Module]] = None): + def __init__(self, activation: torch.nn.Module = None, transform: Optional[list[torch.nn.Module]] = None): super().__init__() self.activation = activation self.transform = transform @@ -609,8 +613,9 @@ def on_fit_start(self): def test_cli_distributed_save_config_callback(cleandir, logger, strategy): from torch.multiprocessing import ProcessRaisedException - with mock.patch("sys.argv", ["any.py", "fit"]), pytest.raises( - (MisconfigurationException, ProcessRaisedException), match=r"Error on fit start" + with ( + mock.patch("sys.argv", ["any.py", "fit"]), + pytest.raises((MisconfigurationException, ProcessRaisedException), match=r"Error on fit start"), ): LightningCLI( EarlyExitTestModel, @@ -710,12 +715,14 @@ def train_dataloader(self): ... 
from lightning.pytorch.trainer.configuration_validator import __verify_train_val_loop_configuration - with mock.patch("sys.argv", ["any.py", "fit", "--optimizer=Adam"]), mock.patch( - "lightning.pytorch.Trainer._run_stage" - ) as run, mock.patch( - "lightning.pytorch.trainer.configuration_validator.__verify_train_val_loop_configuration", - wraps=__verify_train_val_loop_configuration, - ) as verify: + with ( + mock.patch("sys.argv", ["any.py", "fit", "--optimizer=Adam"]), + mock.patch("lightning.pytorch.Trainer._run_stage") as run, + mock.patch( + "lightning.pytorch.trainer.configuration_validator.__verify_train_val_loop_configuration", + wraps=__verify_train_val_loop_configuration, + ) as verify, + ): cli = LightningCLI(BoringModel) run.assert_called_once() verify.assert_called_once_with(cli.trainer, cli.model) @@ -1088,15 +1095,18 @@ def __init__(self, foo, bar=5): @_xfail_python_ge_3_11_9 def test_lightning_cli_model_short_arguments(): - with mock.patch("sys.argv", ["any.py", "fit", "--model=BoringModel"]), mock.patch( - "lightning.pytorch.Trainer._fit_impl" - ) as run, mock_subclasses(LightningModule, BoringModel, TestModel): + with ( + mock.patch("sys.argv", ["any.py", "fit", "--model=BoringModel"]), + mock.patch("lightning.pytorch.Trainer._fit_impl") as run, + mock_subclasses(LightningModule, BoringModel, TestModel), + ): cli = LightningCLI(trainer_defaults={"fast_dev_run": 1}) assert isinstance(cli.model, BoringModel) run.assert_called_once_with(cli.model, ANY, ANY, ANY, ANY) - with mock.patch("sys.argv", ["any.py", "--model=TestModel", "--model.foo", "123"]), mock_subclasses( - LightningModule, BoringModel, TestModel + with ( + mock.patch("sys.argv", ["any.py", "--model=TestModel", "--model.foo", "123"]), + mock_subclasses(LightningModule, BoringModel, TestModel), ): cli = LightningCLI(run=False) assert isinstance(cli.model, TestModel) @@ -1114,15 +1124,18 @@ def __init__(self, foo, bar=5): @_xfail_python_ge_3_11_9 def test_lightning_cli_datamodule_short_arguments(): # with set model - with mock.patch("sys.argv", ["any.py", "fit", "--data=BoringDataModule"]), mock.patch( - "lightning.pytorch.Trainer._fit_impl" - ) as run, mock_subclasses(LightningDataModule, BoringDataModule): + with ( + mock.patch("sys.argv", ["any.py", "fit", "--data=BoringDataModule"]), + mock.patch("lightning.pytorch.Trainer._fit_impl") as run, + mock_subclasses(LightningDataModule, BoringDataModule), + ): cli = LightningCLI(BoringModel, trainer_defaults={"fast_dev_run": 1}) assert isinstance(cli.datamodule, BoringDataModule) run.assert_called_once_with(ANY, ANY, ANY, cli.datamodule, ANY) - with mock.patch("sys.argv", ["any.py", "--data=MyDataModule", "--data.foo", "123"]), mock_subclasses( - LightningDataModule, MyDataModule + with ( + mock.patch("sys.argv", ["any.py", "--data=MyDataModule", "--data.foo", "123"]), + mock_subclasses(LightningDataModule, MyDataModule), ): cli = LightningCLI(BoringModel, run=False) assert isinstance(cli.datamodule, MyDataModule) @@ -1130,17 +1143,22 @@ def test_lightning_cli_datamodule_short_arguments(): assert cli.datamodule.bar == 5 # with configurable model - with mock.patch("sys.argv", ["any.py", "fit", "--model", "BoringModel", "--data=BoringDataModule"]), mock.patch( - "lightning.pytorch.Trainer._fit_impl" - ) as run, mock_subclasses(LightningModule, BoringModel), mock_subclasses(LightningDataModule, BoringDataModule): + with ( + mock.patch("sys.argv", ["any.py", "fit", "--model", "BoringModel", "--data=BoringDataModule"]), + mock.patch("lightning.pytorch.Trainer._fit_impl") 
as run, + mock_subclasses(LightningModule, BoringModel), + mock_subclasses(LightningDataModule, BoringDataModule), + ): cli = LightningCLI(trainer_defaults={"fast_dev_run": 1}) assert isinstance(cli.model, BoringModel) assert isinstance(cli.datamodule, BoringDataModule) run.assert_called_once_with(cli.model, ANY, ANY, cli.datamodule, ANY) - with mock.patch("sys.argv", ["any.py", "--model", "BoringModel", "--data=MyDataModule"]), mock_subclasses( - LightningModule, BoringModel - ), mock_subclasses(LightningDataModule, MyDataModule): + with ( + mock.patch("sys.argv", ["any.py", "--model", "BoringModel", "--data=MyDataModule"]), + mock_subclasses(LightningModule, BoringModel), + mock_subclasses(LightningDataModule, MyDataModule), + ): cli = LightningCLI(run=False) assert isinstance(cli.model, BoringModel) assert isinstance(cli.datamodule, MyDataModule) @@ -1307,9 +1325,10 @@ def __init__(self, opt1_config: dict, opt2_config: dict, sch_config: dict): def test_lightning_cli_config_with_subcommand(): config = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}} - with mock.patch("sys.argv", ["any.py", f"--config={config}"]), mock.patch( - "lightning.pytorch.Trainer.test", autospec=True - ) as test_mock: + with ( + mock.patch("sys.argv", ["any.py", f"--config={config}"]), + mock.patch("lightning.pytorch.Trainer.test", autospec=True) as test_mock, + ): cli = LightningCLI(BoringModel) test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar") @@ -1322,9 +1341,10 @@ def test_lightning_cli_config_before_subcommand(): "test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}, } - with mock.patch("sys.argv", ["any.py", f"--config={config}", "test"]), mock.patch( - "lightning.pytorch.Trainer.test", autospec=True - ) as test_mock: + with ( + mock.patch("sys.argv", ["any.py", f"--config={config}", "test"]), + mock.patch("lightning.pytorch.Trainer.test", autospec=True) as test_mock, + ): cli = LightningCLI(BoringModel) test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar") @@ -1334,9 +1354,10 @@ def test_lightning_cli_config_before_subcommand(): assert save_config_callback.config.trainer.limit_test_batches == 1 assert save_config_callback.parser.subcommand == "test" - with mock.patch("sys.argv", ["any.py", f"--config={config}", "validate"]), mock.patch( - "lightning.pytorch.Trainer.validate", autospec=True - ) as validate_mock: + with ( + mock.patch("sys.argv", ["any.py", f"--config={config}", "validate"]), + mock.patch("lightning.pytorch.Trainer.validate", autospec=True) as validate_mock, + ): cli = LightningCLI(BoringModel) validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo") @@ -1351,17 +1372,19 @@ def test_lightning_cli_config_before_subcommand_two_configs(): config1 = {"validate": {"trainer": {"limit_val_batches": 1}, "verbose": False, "ckpt_path": "barfoo"}} config2 = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}} - with mock.patch("sys.argv", ["any.py", f"--config={config1}", f"--config={config2}", "test"]), mock.patch( - "lightning.pytorch.Trainer.test", autospec=True - ) as test_mock: + with ( + mock.patch("sys.argv", ["any.py", f"--config={config1}", f"--config={config2}", "test"]), + mock.patch("lightning.pytorch.Trainer.test", autospec=True) as test_mock, + ): cli = LightningCLI(BoringModel) test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, 
ckpt_path="foobar") assert cli.trainer.limit_test_batches == 1 - with mock.patch("sys.argv", ["any.py", f"--config={config1}", f"--config={config2}", "validate"]), mock.patch( - "lightning.pytorch.Trainer.validate", autospec=True - ) as validate_mock: + with ( + mock.patch("sys.argv", ["any.py", f"--config={config1}", f"--config={config2}", "validate"]), + mock.patch("lightning.pytorch.Trainer.validate", autospec=True) as validate_mock, + ): cli = LightningCLI(BoringModel) validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo") @@ -1370,9 +1393,10 @@ def test_lightning_cli_config_before_subcommand_two_configs(): def test_lightning_cli_config_after_subcommand(): config = {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"} - with mock.patch("sys.argv", ["any.py", "test", f"--config={config}"]), mock.patch( - "lightning.pytorch.Trainer.test", autospec=True - ) as test_mock: + with ( + mock.patch("sys.argv", ["any.py", "test", f"--config={config}"]), + mock.patch("lightning.pytorch.Trainer.test", autospec=True) as test_mock, + ): cli = LightningCLI(BoringModel) test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar") @@ -1382,9 +1406,10 @@ def test_lightning_cli_config_after_subcommand(): def test_lightning_cli_config_before_and_after_subcommand(): config1 = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}} config2 = {"trainer": {"fast_dev_run": 1}, "verbose": False, "ckpt_path": "foobar"} - with mock.patch("sys.argv", ["any.py", f"--config={config1}", "test", f"--config={config2}"]), mock.patch( - "lightning.pytorch.Trainer.test", autospec=True - ) as test_mock: + with ( + mock.patch("sys.argv", ["any.py", f"--config={config1}", "test", f"--config={config2}"]), + mock.patch("lightning.pytorch.Trainer.test", autospec=True) as test_mock, + ): cli = LightningCLI(BoringModel) test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=False, ckpt_path="foobar") @@ -1406,17 +1431,19 @@ def test_lightning_cli_parse_kwargs_with_subcommands(cleandir): "validate": {"default_config_files": [str(validate_config_path)]}, } - with mock.patch("sys.argv", ["any.py", "fit"]), mock.patch( - "lightning.pytorch.Trainer.fit", autospec=True - ) as fit_mock: + with ( + mock.patch("sys.argv", ["any.py", "fit"]), + mock.patch("lightning.pytorch.Trainer.fit", autospec=True) as fit_mock, + ): cli = LightningCLI(BoringModel, parser_kwargs=parser_kwargs) fit_mock.assert_called() assert cli.trainer.limit_train_batches == 2 assert cli.trainer.limit_val_batches == 1.0 - with mock.patch("sys.argv", ["any.py", "validate"]), mock.patch( - "lightning.pytorch.Trainer.validate", autospec=True - ) as validate_mock: + with ( + mock.patch("sys.argv", ["any.py", "validate"]), + mock.patch("lightning.pytorch.Trainer.validate", autospec=True) as validate_mock, + ): cli = LightningCLI(BoringModel, parser_kwargs=parser_kwargs) validate_mock.assert_called() assert cli.trainer.limit_train_batches == 1.0 @@ -1434,9 +1461,10 @@ def __init__(self, foo: int, *args, **kwargs): config_path.write_text(str(config)) parser_kwargs = {"default_config_files": [str(config_path)]} - with mock.patch("sys.argv", ["any.py", "fit"]), mock.patch( - "lightning.pytorch.Trainer.fit", autospec=True - ) as fit_mock: + with ( + mock.patch("sys.argv", ["any.py", "fit"]), + mock.patch("lightning.pytorch.Trainer.fit", autospec=True) as fit_mock, + ): cli = LightningCLI(Model, parser_kwargs=parser_kwargs) 
fit_mock.assert_called() assert cli.model.foo == 123 diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py index 65c5777e28fed..9e947e0723dcd 100644 --- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py @@ -15,7 +15,7 @@ import os import sys from contextlib import nullcontext -from typing import Any, Dict +from typing import Any from unittest import mock from unittest.mock import Mock @@ -178,7 +178,7 @@ class Accel(Accelerator): def setup_device(self, device: torch.device) -> None: pass - def get_device_stats(self, device: torch.device) -> Dict[str, Any]: + def get_device_stats(self, device: torch.device) -> dict[str, Any]: pass def teardown(self) -> None: diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py index a820a3d6ee786..ca5690ed20f41 100644 --- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sized from re import escape -from typing import Sized from unittest import mock from unittest.mock import Mock diff --git a/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py b/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py index 41d9301e4847b..af7cecdb21a08 100644 --- a/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from unittest.mock import Mock import lightning.pytorch as pl @@ -38,7 +38,7 @@ def __init__(self): def experiment(self) -> Any: return self.exp - def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): + def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None): self.logs.update(metrics) def version(self) -> Union[int, str]: @@ -144,7 +144,7 @@ def __init__(self): self.buffer = {} self.logs = {} - def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: + def log_metrics(self, metrics: dict[str, float], step: Optional[int] = None) -> None: self.buffer.update(metrics) def finalize(self, status: str) -> None: diff --git a/tests/tests_pytorch/trainer/optimization/test_optimizers.py b/tests/tests_pytorch/trainer/optimization/test_optimizers.py index 451557d084dc7..ac660b6651be5 100644 --- a/tests/tests_pytorch/trainer/optimization/test_optimizers.py +++ b/tests/tests_pytorch/trainer/optimization/test_optimizers.py @@ -592,9 +592,10 @@ def configure_optimizers(self): limit_train_batches=limit_train_batches, limit_val_batches=0, ) - with mock.patch.object(CustomEpochScheduler, "step") as mock_method_epoch, mock.patch.object( - torch.optim.lr_scheduler.StepLR, "step" - ) as mock_method_step: + with ( + mock.patch.object(CustomEpochScheduler, "step") as mock_method_epoch, + mock.patch.object(torch.optim.lr_scheduler.StepLR, "step") as mock_method_step, + ): trainer.fit(model) assert mock_method_epoch.mock_calls == [call(epoch=e) for e in range(max_epochs)] diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index 8946fb4ed9481..d66f3aafee5df 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -1887,8 +1887,9 @@ def training_step(self, batch, batch_idx): model = NanModel() trainer = Trainer(default_root_dir=tmp_path, detect_anomaly=True) - with pytest.raises(RuntimeError, match=r"returned nan values in its 0th output."), pytest.warns( - UserWarning, match=r".*Error detected in.* Traceback of forward call that caused the error.*" + with ( + pytest.raises(RuntimeError, match=r"returned nan values in its 0th output."), + pytest.warns(UserWarning, match=r".*Error detected in.* Traceback of forward call that caused the error.*"), ): trainer.fit(model) @@ -2067,8 +2068,9 @@ def on_fit_start(self): raise exception trainer = Trainer(default_root_dir=tmp_path) - with mock.patch("lightning.pytorch.strategies.strategy.Strategy.on_exception") as on_exception_mock, suppress( - Exception, SystemExit + with ( + mock.patch("lightning.pytorch.strategies.strategy.Strategy.on_exception") as on_exception_mock, + suppress(Exception, SystemExit), ): trainer.fit(ExceptionModel()) on_exception_mock.assert_called_once_with(exception) diff --git a/tests/tests_pytorch/utilities/test_combined_loader.py b/tests/tests_pytorch/utilities/test_combined_loader.py index 74f5c1330a089..43a146c6eb089 100644 --- a/tests/tests_pytorch/utilities/test_combined_loader.py +++ b/tests/tests_pytorch/utilities/test_combined_loader.py @@ -13,7 +13,8 @@ # limitations under the License. 
 import math
 import pickle
-from typing import Any, NamedTuple, Sequence, get_args
+from collections.abc import Sequence
+from typing import Any, NamedTuple, get_args
 from unittest.mock import Mock

 import pytest
diff --git a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py
index 1bdac616e7b34..2ef1ecd4fe3e5 100644
--- a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py
+++ b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py
@@ -33,8 +33,9 @@ def test_upgrade_checkpoint_file_missing(tmp_path, caplog):

     # path to non-empty directory, but no checkpoints with matching extension
     file.touch()
-    with mock.patch("sys.argv", ["upgrade_checkpoint.py", str(tmp_path), "--extension", ".other"]), caplog.at_level(
-        logging.ERROR
+    with (
+        mock.patch("sys.argv", ["upgrade_checkpoint.py", str(tmp_path), "--extension", ".other"]),
+        caplog.at_level(logging.ERROR),
     ):
         with pytest.raises(SystemExit):
             upgrade_main()
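
The recurring change across the test files above is purely syntactic: `with` statements that previously stacked several context managers by letting the line wrap inside a trailing call parenthesis are rewritten with the whole manager list wrapped in parentheses, one manager per line, with a trailing comma. This form is official syntax from Python 3.10 and is already accepted by CPython 3.9's parser. A minimal sketch of the before/after with stand-in context managers, not the Lightning fixtures::

    from contextlib import contextmanager

    @contextmanager
    def stub(name):
        # stand-in for caplog.at_level / pytest.raises / mock.patch in the tests above
        print("enter", name)
        yield name
        print("exit", name)

    # old style: the second manager hangs off an open call parenthesis
    with stub("first"), stub(
        "second"
    ) as managed:
        assert managed == "second"

    # new style: parenthesized manager list, one per line, trailing comma allowed
    with (
        stub("first"),
        stub("second") as managed,
    ):
        assert managed == "second"

The parenthesized form keeps each manager (and its `as` binding) on its own line, which is why the multi-patch test setups above no longer need the awkward mid-call line breaks.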
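
The typing changes follow PEP 585: `Dict`, `List`, and `Tuple` imported from `typing` are replaced by the built-in `dict`, `list`, and `tuple`, which are subscriptable in annotations from Python 3.9, so only names such as `Any`, `Optional`, and `Union` still need a `typing` import. A simplified, standalone version of the annotation style used in hooks like `log_metrics` and `summarize` above; the bodies here are illustrative only::

    from typing import Any, Optional

    def log_metrics(metrics: dict[str, float], step: Optional[int] = None) -> None:
        # was: metrics: Dict[str, float] -- body is illustrative, not the logger code from the diff
        for name, value in metrics.items():
            print(step, name, value)

    def summarize(summary_data: list[tuple[str, list[str]]]) -> dict[str, Any]:
        # was: summary_data: List[Tuple[str, List[str]]]
        return {layer: columns for layer, columns in summary_data}

    log_metrics({"loss": 0.25}, step=1)
    print(summarize([("layer0", ["Linear", "32 x 2"])]))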
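
Likewise, the abstract container types (`Iterable`, `Iterator`, `Mapping`, `Sequence`, `Sized`) are now imported from `collections.abc` rather than `typing`; the `typing` aliases are deprecated by PEP 585, and the ABCs themselves are subscriptable on Python 3.9+. A small illustrative helper, not taken from the Lightning code base::

    from collections.abc import Iterator, Mapping, Sequence

    def batched(items: Sequence[int], size: int) -> Iterator[list[int]]:
        # yields the sequence in chunks of `size`
        for start in range(0, len(items), size):
            yield list(items[start:start + size])

    def sorted_items(mapping: Mapping[str, float]) -> list[tuple[str, float]]:
        return sorted(mapping.items())

    assert list(batched([1, 2, 3, 4, 5], 2)) == [[1, 2], [3, 4], [5]]
    assert sorted_items({"b": 2.0, "a": 1.0}) == [("a", 1.0), ("b", 2.0)]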
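
The `typing.ContextManager` return annotation in `test_stochastic_weight_avg.py` becomes `contextlib.AbstractContextManager`, the runtime class that the deprecated `typing` alias points to. A self-contained sketch of the same pattern, using a stand-in `Strategy` class instead of the real one::

    from contextlib import AbstractContextManager
    from unittest import mock

    class Strategy:
        # stand-in for lightning.pytorch.strategies.Strategy
        def backward(self, loss):
            return loss

    def _backward_patch(strategy: Strategy) -> AbstractContextManager:
        # returns the unentered patcher, annotated with the ABC instead of typing.ContextManager
        return mock.patch.object(Strategy, "backward", wraps=strategy.backward)

    with _backward_patch(Strategy()) as patched:
        Strategy().backward(1.0)
        patched.assert_called_once_with(1.0)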
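
Finally, `conftest.py` no longer guards the `_ExecutorManagerThread` import behind `sys.version_info >= (3, 9)`: that private helper class only exists in `concurrent.futures.process` from Python 3.9, which is what the guard was for, and with 3.9 as the minimum supported version the import can be unconditional. A simplified sketch of how such a class can be used to ignore the executor's bookkeeping thread when checking for leaked threads; this is not the actual fixture body::

    import threading
    from concurrent.futures import ProcessPoolExecutor
    from concurrent.futures.process import _ExecutorManagerThread  # private API, Python 3.9+

    def non_executor_threads() -> list[threading.Thread]:
        # filter out the pool's manager thread so it does not count as a leak
        return [t for t in threading.enumerate() if not isinstance(t, _ExecutorManagerThread)]

    if __name__ == "__main__":
        with ProcessPoolExecutor(max_workers=1) as pool:
            assert pool.submit(abs, -1).result() == 1
            print(len(non_executor_threads()), "non-executor threads running")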