[WIP] Adopt PEP 563, 585, and 604 for src/lightning/pytorch #17779

Closed · wants to merge 4 commits
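For context: PEP 563 (`from __future__ import annotations`) defers annotation evaluation, which is what allows PEP 585 builtin generics (`dict[str, Any]`, `list[int]`) and PEP 604 unions (`int | str`, `X | None`) to appear in annotations even on Python versions that predate native support for that syntax. A minimal sketch of the before/after style (illustrative only, not taken from the diff below):

```python
from __future__ import annotations  # PEP 563: annotations are stored as strings, never evaluated

from typing import Any

# Old style, as removed throughout this PR:
#   def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]: ...
#   def get_device_stats(device) -> Dict[str, Any]: ...


def parse_devices(devices: int | str | list[int]) -> list[int] | None:
    """PEP 604 unions and PEP 585 builtin generics in annotations."""
    return None


def get_device_stats(device: Any) -> dict[str, Any]:
    """PEP 585 builtin generic in the return annotation."""
    return {}
```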
2 changes: 1 addition & 1 deletion requirements/pytorch/extra.txt
@@ -5,6 +5,6 @@
matplotlib>3.1, <3.6.2
omegaconf >=2.0.5, <2.4.0
hydra-core >=1.0.5, <1.4.0
-jsonargparse[signatures] >=4.18.0, <4.22.0
+jsonargparse[signatures] @ https://github.com/omni-us/jsonargparse/zipball/future-annotations
rich >=12.3.0, <=13.0.1
tensorboardX >=2.2, <=2.6 # min version is set by torch.onnx missing attribute
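The pin above moves jsonargparse from a released range to its `future-annotations` branch. jsonargparse builds CLI arguments by introspecting parameter annotations, and once PEP 563 is active those annotations arrive as plain strings that have to be resolved back into types; on older interpreters a PEP 604 string such as `'int | None'` cannot even be evaluated. A standard-library sketch of the effect (hypothetical function, not jsonargparse's actual resolution code):

```python
from __future__ import annotations

import typing


def configure(workers: int | None = None, tags: list[str] | None = None) -> None:
    """Hypothetical stand-in for a function exposed through the Lightning CLI."""


# With PEP 563 active, the raw annotations are plain strings:
print(configure.__annotations__)
# {'workers': 'int | None', 'tags': 'list[str] | None', 'return': 'None'}

# Introspection-based tools must resolve those strings back into types.
# This succeeds on Python >= 3.10, but on 3.8 evaluating 'int | None'
# raises TypeError, the kind of gap a dedicated jsonargparse branch
# would need to paper over.
print(typing.get_type_hints(configure))
```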
8 changes: 5 additions & 3 deletions src/lightning/pytorch/accelerators/accelerator.py
@@ -11,8 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import annotations

from abc import ABC
-from typing import Any, Dict
+from typing import Any

import lightning.pytorch as pl
from lightning.fabric.accelerators.accelerator import Accelerator as _Accelerator
@@ -25,14 +27,14 @@ class Accelerator(_Accelerator, ABC):
.. warning:: Writing your own accelerator is an :ref:`experimental <versioning:Experimental API>` feature.
"""

-def setup(self, trainer: "pl.Trainer") -> None:
+def setup(self, trainer: pl.Trainer) -> None:
"""Setup plugins for the trainer fit and creates optimizers.

Args:
trainer: the trainer instance
"""

-def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
+def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
"""Get stats for a given device.

Args:
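A second effect of PEP 563, visible in the `setup` signature above and throughout the rest of the diff: forward references such as `pl.Trainer` no longer need to be quoted, because annotations are never evaluated at definition time. A self-contained illustration with hypothetical classes (not Lightning's):

```python
from __future__ import annotations


class Trainer:
    def connect(self, accelerator: Accelerator) -> None:
        # `Accelerator` is defined further down the module; without PEP 563
        # this annotation would have to be the string literal "Accelerator".
        self.accelerator = accelerator


class Accelerator:
    def setup(self, trainer: Trainer) -> None:
        """Mirrors the unquoted `trainer: pl.Trainer` hints in this PR."""
```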
12 changes: 7 additions & 5 deletions src/lightning/pytorch/accelerators/cpu.py
@@ -11,7 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Union
+from __future__ import annotations
+
+from typing import Any

import torch
from lightning_utilities.core.imports import RequirementCache
@@ -35,20 +37,20 @@ def setup_device(self, device: torch.device) -> None:
if device.type != "cpu":
raise MisconfigurationException(f"Device should be CPU, got {device} instead.")

-def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
+def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
"""Get CPU stats from ``psutil`` package."""
return get_cpu_stats()

def teardown(self) -> None:
pass

@staticmethod
-def parse_devices(devices: Union[int, str, List[int]]) -> int:
+def parse_devices(devices: int | str | list[int]) -> int:
"""Accelerator device parsing logic."""
return _parse_cpu_cores(devices)

@staticmethod
-def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
+def get_parallel_devices(devices: int | str | list[int]) -> list[torch.device]:
"""Gets parallel devices for the Accelerator."""
devices = _parse_cpu_cores(devices)
return [torch.device("cpu")] * devices
@@ -80,7 +82,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
_PSUTIL_AVAILABLE = RequirementCache("psutil")


-def get_cpu_stats() -> Dict[str, float]:
+def get_cpu_stats() -> dict[str, float]:
if not _PSUTIL_AVAILABLE:
raise ModuleNotFoundError(
f"Fetching CPU device stats requires `psutil` to be installed. {str(_PSUTIL_AVAILABLE)}"
14 changes: 8 additions & 6 deletions src/lightning/pytorch/accelerators/cuda.py
@@ -11,11 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import annotations

import logging
import os
import shutil
import subprocess
-from typing import Any, Dict, List, Optional, Union
+from typing import Any

import torch

@@ -44,7 +46,7 @@ def setup_device(self, device: torch.device) -> None:
_check_cuda_matmul_precision(device)
torch.cuda.set_device(device)

-def setup(self, trainer: "pl.Trainer") -> None:
+def setup(self, trainer: pl.Trainer) -> None:
# TODO refactor input from trainer to local_rank @four4fish
self.set_nvidia_flags(trainer.local_rank)
_clear_cuda_memory()
@@ -57,7 +59,7 @@ def set_nvidia_flags(local_rank: int) -> None:
devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
_log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]")

-def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
+def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
"""Gets stats for the given GPU device.

Args:
@@ -76,12 +78,12 @@ def teardown(self) -> None:
_clear_cuda_memory()

@staticmethod
-def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]:
+def parse_devices(devices: int | str | list[int]) -> list[int] | None:
"""Accelerator device parsing logic."""
return _parse_gpu_ids(devices, include_cuda=True)

@staticmethod
-def get_parallel_devices(devices: List[int]) -> List[torch.device]:
+def get_parallel_devices(devices: list[int]) -> list[torch.device]:
"""Gets parallel devices for the Accelerator."""
return [torch.device("cuda", i) for i in devices]

@@ -103,7 +105,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
)


-def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]: # pragma: no-cover
+def get_nvidia_gpu_stats(device: _DEVICE) -> dict[str, float]: # pragma: no-cover
"""Get GPU stats including memory, fan speed, and temperature from nvidia-smi.

Args:
12 changes: 7 additions & 5 deletions src/lightning/pytorch/accelerators/mps.py
@@ -11,7 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional, Union
+from __future__ import annotations
+
+from typing import Any

import torch

@@ -39,20 +41,20 @@ def setup_device(self, device: torch.device) -> None:
if device.type != "mps":
raise MisconfigurationException(f"Device should be MPS, got {device} instead.")

-def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
+def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
"""Get M1 (cpu + gpu) stats from ``psutil`` package."""
return get_device_stats()

def teardown(self) -> None:
pass

@staticmethod
-def parse_devices(devices: Union[int, str, List[int]]) -> Optional[List[int]]:
+def parse_devices(devices: int | str | list[int]) -> list[int] | None:
"""Accelerator device parsing logic."""
return _parse_gpu_ids(devices, include_mps=True)

@staticmethod
-def get_parallel_devices(devices: Union[int, str, List[int]]) -> List[torch.device]:
+def get_parallel_devices(devices: int | str | list[int]) -> list[torch.device]:
"""Gets parallel devices for the Accelerator."""
parsed_devices = MPSAccelerator.parse_devices(devices)
assert parsed_devices is not None
@@ -84,7 +86,7 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
_SWAP_PERCENT = "M1_swap_percent"


-def get_device_stats() -> Dict[str, float]:
+def get_device_stats() -> dict[str, float]:
if not _PSUTIL_AVAILABLE:
raise ModuleNotFoundError(
f"Fetching MPS device stats requires `psutil` to be installed. {str(_PSUTIL_AVAILABLE)}"
6 changes: 4 additions & 2 deletions src/lightning/pytorch/accelerators/xla.py
@@ -11,7 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict
+from __future__ import annotations
+
+from typing import Any

from lightning.fabric.accelerators import _AcceleratorRegistry
from lightning.fabric.accelerators.xla import XLAAccelerator as FabricXLAAccelerator
@@ -25,7 +27,7 @@ class XLAAccelerator(Accelerator, FabricXLAAccelerator):
.. warning:: Use of this accelerator beyond import and instantiation is experimental.
"""

-def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
+def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
"""Gets stats for the given XLA device.

Args:
16 changes: 8 additions & 8 deletions src/lightning/pytorch/callbacks/batch_size_finder.py
@@ -18,7 +18,7 @@
Finds optimal batch size
"""

-from typing import Optional
+from __future__ import annotations

import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
@@ -119,15 +119,15 @@ def __init__(
if mode not in self.SUPPORTED_MODES:
raise ValueError(f"`mode` should be either of {self.SUPPORTED_MODES}")

-self.optimal_batch_size: Optional[int] = init_val
+self.optimal_batch_size: int | None = init_val
self._mode = mode
self._steps_per_trial = steps_per_trial
self._init_val = init_val
self._max_trials = max_trials
self._batch_arg_name = batch_arg_name
self._early_exit = False

-def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
+def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str | None = None) -> None:
if trainer._accelerator_connector.is_distributed:
raise MisconfigurationException("The Batch size finder is not supported with distributed strategies.")
# TODO: check if this can be enabled (#4040)
@@ -167,7 +167,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O
" If this is not the intended behavior, please remove either one."
)

-def scale_batch_size(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+def scale_batch_size(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
new_size = _scale_batch_size(
trainer,
self._mode,
@@ -181,17 +181,17 @@ def scale_batch_size(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule
if self._early_exit:
raise _TunerExitException()

-def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
self.scale_batch_size(trainer, pl_module)

-def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
if trainer.sanity_checking or trainer.state.fn != "validate":
return

self.scale_batch_size(trainer, pl_module)

-def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+def on_test_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
self.scale_batch_size(trainer, pl_module)

-def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+def on_predict_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
self.scale_batch_size(trainer, pl_module)