diff --git a/.actions/assistant.py b/.actions/assistant.py
index 92239a406d475..54d401b3bb120 100644
--- a/.actions/assistant.py
+++ b/.actions/assistant.py
@@ -86,6 +86,7 @@ def adjust(self, unfreeze: str) -> str:
'arrow>=1.2.0'
>>> _RequirementWithComment("arrow").adjust("major")
'arrow'
+
"""
out = str(self)
if self.strict:
@@ -115,6 +116,7 @@ def _parse_requirements(strs: Union[str, Iterable[str]]) -> Iterator[_Requiremen
>>> txt = '\\n'.join(txt)
>>> [r.adjust('none') for r in _parse_requirements(txt)]
['this', 'example', 'foo # strict', 'thing']
+
"""
lines = yield_lines(strs)
pip_argument = None
@@ -149,6 +151,7 @@ def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str
>>> path_req = os.path.join(_PROJECT_ROOT, "requirements")
>>> load_requirements(path_req, "docs.txt", unfreeze="major") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
['sphinx<...]
+
"""
assert unfreeze in {"none", "major", "all"}
path = Path(path_dir) / file_name
@@ -165,6 +168,7 @@ def load_readme_description(path_dir: str, homepage: str, version: str) -> str:
>>> load_readme_description(_PROJECT_ROOT, "", "") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
'...PyTorch Lightning is just organized PyTorch...'
+
"""
path_readme = os.path.join(path_dir, "README.md")
with open(path_readme, encoding="utf-8") as fo:
@@ -244,6 +248,7 @@ def _load_aggregate_requirements(req_dir: str = "requirements", freeze_requireme
"""Load all base requirements from all particular packages and prune duplicates.
>>> _load_aggregate_requirements(os.path.join(_PROJECT_ROOT, "requirements"))
+
"""
requires = [
load_requirements(d, unfreeze="none" if freeze_requirements else "major")
@@ -300,6 +305,7 @@ def _replace_imports(lines: List[str], mapping: List[Tuple[str, str]], lightning
'http://pytorch_lightning.ai', \
'from lightning_fabric import __version__', \
'@lightning.ai']
+
"""
out = lines[:]
for source_import, target_import in mapping:
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 0442ca53c760d..370b18b19fe92 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -60,7 +60,8 @@ repos:
rev: v1.7.3
hooks:
- id: docformatter
- args: [--in-place, --wrap-summaries=115, --wrap-descriptions=120]
+ additional_dependencies: [tomli]
+ args: ["--in-place"]
- repo: https://github.com/asottile/yesqa
rev: v1.5.0
diff --git a/docs/source-app/examples/file_server/app.py b/docs/source-app/examples/file_server/app.py
index 3afba48469dd6..ed14df22d97af 100644
--- a/docs/source-app/examples/file_server/app.py
+++ b/docs/source-app/examples/file_server/app.py
@@ -23,6 +23,7 @@ def __init__(
drive: The drive can share data inside your application.
base_dir: The local directory where the data will be stored.
chunk_size: The quantity of bytes to download/upload at once.
+
"""
super().__init__(
cloud_build_config=L.BuildConfig(["flask, flask-cors"]),
@@ -238,4 +239,5 @@ def test_file_server_in_cloud():
# 2. By calling logs = get_logs_fn(),
# you get all the logs currently on the admin page.
+
"""
diff --git a/docs/source-app/examples/github_repo_runner/app.py b/docs/source-app/examples/github_repo_runner/app.py
index 486055f7a877c..b670baa4bf444 100644
--- a/docs/source-app/examples/github_repo_runner/app.py
+++ b/docs/source-app/examples/github_repo_runner/app.py
@@ -36,6 +36,7 @@ def __init__(
script_args: The arguments to be provided to the script.
requirements: The python requirements tp run the script.
cloud_compute: The object to select the cloud instance.
+
"""
super().__init__(
script_path=script_path,
diff --git a/docs/source-app/examples/model_server_app/locust_component.py b/docs/source-app/examples/model_server_app/locust_component.py
index 4351506f5f669..432336adf83b3 100644
--- a/docs/source-app/examples/model_server_app/locust_component.py
+++ b/docs/source-app/examples/model_server_app/locust_component.py
@@ -10,6 +10,7 @@ def __init__(self, num_users: int = 100):
Arguments:
num_users: Number of users emulated by Locust
+
"""
# Note: Using the default port 8089 of Locust.
super().__init__(
diff --git a/docs/source-app/examples/model_server_app/model_server.py b/docs/source-app/examples/model_server_app/model_server.py
index 8562c63d8c9f1..f571f613de25e 100644
--- a/docs/source-app/examples/model_server_app/model_server.py
+++ b/docs/source-app/examples/model_server_app/model_server.py
@@ -18,6 +18,7 @@ class MLServer(LightningWork):
Example: "mlserver_sklearn.SKLearnModel".
Learn more here: $ML_SERVER_URL/tree/master/runtimes
workers: Number of server worker.
+
"""
def __init__(
@@ -51,6 +52,7 @@ def run(self, model_path: Path):
Arguments:
model_path: The path to the trained model.
+
"""
# 1: Use the host and port at runtime so it works in the cloud.
# $ML_SERVER_URL/blob/master/mlserver/settings.py#L50
diff --git a/examples/app/hpo/utils.py b/examples/app/hpo/utils.py
index a08fda2f61dff..9d27c726b0e5e 100644
--- a/examples/app/hpo/utils.py
+++ b/examples/app/hpo/utils.py
@@ -15,6 +15,7 @@ def download_data(url: str, path: str = "data/", verbose: bool = False) -> None:
Usage:
download_file('http://web4host.net/5MB.zip')
+
"""
if url == "NEED_TO_BE_CREATED":
raise NotImplementedError
diff --git a/examples/app/layout/app.py b/examples/app/layout/app.py
index 7048f62a94aee..0e9efabba7960 100644
--- a/examples/app/layout/app.py
+++ b/examples/app/layout/app.py
@@ -5,6 +5,7 @@
lightning run app examples/layout/demo.py
This starts one server for each flow that returns a UI. Access the UI at the link printed in the terminal.
+
"""
import os
diff --git a/examples/fabric/build_your_own_trainer/trainer.py b/examples/fabric/build_your_own_trainer/trainer.py
index 69895b6498539..3e991de74e2cd 100644
--- a/examples/fabric/build_your_own_trainer/trainer.py
+++ b/examples/fabric/build_your_own_trainer/trainer.py
@@ -137,6 +137,7 @@ def fit(
If not specified, no validation will run.
ckpt_path: Path to previous checkpoints to resume training from.
If specified, will always look for the latest checkpoint within the given directory.
+
"""
self.fabric.launch()
@@ -207,6 +208,7 @@ def train_loop(
If greater then the number of batches in the ``train_loader``, this has no effect.
scheduler_cfg: The learning rate scheduler configuration.
Have a look at :meth:`lightning.pytorch.LightninModule.configure_optimizers` for supported values.
+
"""
self.fabric.call("on_train_epoch_start")
iterable = self.progbar_wrapper(
@@ -268,6 +270,7 @@ def val_loop(
val_loader: The dataloader yielding the validation batches.
limit_batches: Limits the batches during this validation epoch.
If greater then the number of batches in the ``val_loader``, this has no effect.
+
"""
# no validation if val_loader wasn't passed
if val_loader is None:
@@ -311,13 +314,14 @@ def val_loop(
torch.set_grad_enabled(True)
def training_step(self, model: L.LightningModule, batch: Any, batch_idx: int) -> torch.Tensor:
- """A single training step, running forward and backward. The optimizer step is called separately, as this
- is given as a closure to the optimizer step.
+ """A single training step, running forward and backward. The optimizer step is called separately, as this is
+ given as a closure to the optimizer step.
Args:
model: the lightning module to train
batch: the batch to run the forward on
batch_idx: index of the current batch w.r.t the current epoch
+
"""
outputs: Union[torch.Tensor, Mapping[str, Any]] = model.training_step(batch, batch_idx=batch_idx)
@@ -347,6 +351,7 @@ def step_scheduler(
Have a look at :meth:`lightning.pytorch.LightningModule.configure_optimizers` for supported values.
level: whether we are trying to step on epoch- or step-level
current_value: Holds the current_epoch if ``level==epoch``, else holds the ``global_step``
+
"""
# no scheduler
@@ -395,6 +400,7 @@ def progbar_wrapper(self, iterable: Iterable, total: int, **kwargs: Any):
Args:
iterable: the iterable to wrap with tqdm
total: the total length of the iterable, necessary in case the number of batches was limited.
+
"""
if self.fabric.is_global_zero:
return tqdm(iterable, total=total, **kwargs)
@@ -406,6 +412,7 @@ def load(self, state: Optional[Mapping], path: str) -> None:
Args:
state: a mapping contaning model, optimizer and lr scheduler
path: the path to load the checkpoint from
+
"""
if state is None:
state = {}
@@ -458,6 +465,7 @@ def _parse_optimizers_schedulers(
Args:
configure_optim_output: The output of ``configure_optimizers``.
For supported values, please refer to :meth:`lightning.pytorch.LightningModule.configure_optimizers`.
+
"""
_lr_sched_defaults = {"interval": "epoch", "frequency": 1, "monitor": "val_loss"}
@@ -511,6 +519,7 @@ def _format_iterable(
prog_bar: a progressbar (on global rank zero) or an iterable (every other rank).
candidates: the values to add as postfix strings to the progressbar.
prefix: the prefix to add to each of these values.
+
"""
if isinstance(prog_bar, tqdm) and candidates is not None:
postfix_str = ""
diff --git a/examples/fabric/image_classifier/train_fabric.py b/examples/fabric/image_classifier/train_fabric.py
index 5f4d9313c6723..c05a35fc8da91 100644
--- a/examples/fabric/image_classifier/train_fabric.py
+++ b/examples/fabric/image_classifier/train_fabric.py
@@ -25,6 +25,7 @@
Accelerate your training loop by setting the ``--accelerator``, ``--strategy``, ``--devices`` options directly from
the command line. See ``lightning run model --help`` or learn more from the documentation:
https://lightning.ai/docs/fabric.
+
"""
import argparse
diff --git a/examples/pytorch/basics/autoencoder.py b/examples/pytorch/basics/autoencoder.py
index 006397f8e9e65..377579fccde71 100644
--- a/examples/pytorch/basics/autoencoder.py
+++ b/examples/pytorch/basics/autoencoder.py
@@ -14,6 +14,7 @@
"""MNIST autoencoder example.
To run: python autoencoder.py --trainer.max_epochs=50
+
"""
from os import path
from typing import Optional, Tuple
diff --git a/examples/pytorch/basics/backbone_image_classifier.py b/examples/pytorch/basics/backbone_image_classifier.py
index 65cf036f7022e..589f632bace59 100644
--- a/examples/pytorch/basics/backbone_image_classifier.py
+++ b/examples/pytorch/basics/backbone_image_classifier.py
@@ -14,6 +14,7 @@
"""MNIST backbone image classifier example.
To run: python backbone_image_classifier.py --trainer.max_epochs=50
+
"""
from os import path
from typing import Optional
diff --git a/examples/pytorch/basics/profiler_example.py b/examples/pytorch/basics/profiler_example.py
index 0c429d29170d4..5fe4004946c29 100644
--- a/examples/pytorch/basics/profiler_example.py
+++ b/examples/pytorch/basics/profiler_example.py
@@ -20,6 +20,7 @@
* With PyTorch Tensorboard Profiler (Instructions are here: https://github.com/pytorch/kineto/tree/master/tb_plugin)
1. pip install tensorboard torch-tb-profiler
2. tensorboard --logdir={FOLDER}
+
"""
from os import path
diff --git a/examples/pytorch/domain_templates/computer_vision_fine_tuning.py b/examples/pytorch/domain_templates/computer_vision_fine_tuning.py
index 4bfd9de3840f1..f55e6aa73a4b9 100644
--- a/examples/pytorch/domain_templates/computer_vision_fine_tuning.py
+++ b/examples/pytorch/domain_templates/computer_vision_fine_tuning.py
@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Computer vision example on Transfer Learning. This computer vision example illustrates how one could fine-tune a
-pre-trained network (by default, a ResNet50 is used) using pytorch-lightning. For the sake of this example, the
-'cats and dogs dataset' (~60MB, see `DATA_URL` below) and the proposed network (denoted by `TransferLearningModel`,
-see below) is trained for 15 epochs.
+pre-trained network (by default, a ResNet50 is used) using pytorch-lightning. For the sake of this example, the 'cats
+and dogs dataset' (~60MB, see `DATA_URL` below) and the proposed network (denoted by `TransferLearningModel`, see
+below) is trained for 15 epochs.
The training consists of three stages.
@@ -37,6 +37,7 @@
To run:
python computer_vision_fine_tuning.py fit
+
"""
import logging
@@ -97,6 +98,7 @@ def __init__(self, dl_path: Union[str, Path] = "data", num_workers: int = 0, bat
dl_path: root directory where to download the data
num_workers: number of CPU workers
batch_size: number of sample in a batch
+
"""
super().__init__()
@@ -174,6 +176,7 @@ def __init__(
milestones: List of two epochs milestones
lr: Initial learning rate
lr_scheduler_gamma: Factor by which the learning rate is reduced at each milestone
+
"""
super().__init__()
self.backbone = backbone
@@ -209,6 +212,7 @@ def forward(self, x):
"""Forward pass.
Returns logits.
+
"""
# 1. Feature extraction:
x = self.feature_extractor(x)
diff --git a/examples/pytorch/domain_templates/generative_adversarial_net.py b/examples/pytorch/domain_templates/generative_adversarial_net.py
index 734d625629e0b..e31dec12446dc 100644
--- a/examples/pytorch/domain_templates/generative_adversarial_net.py
+++ b/examples/pytorch/domain_templates/generative_adversarial_net.py
@@ -16,6 +16,7 @@
After a few epochs, launch TensorBoard to see the images being generated at every batch:
tensorboard --logdir default
+
"""
from argparse import ArgumentParser, Namespace
diff --git a/examples/pytorch/domain_templates/imagenet.py b/examples/pytorch/domain_templates/imagenet.py
index 0d7275f58ddb7..553b500e09f62 100644
--- a/examples/pytorch/domain_templates/imagenet.py
+++ b/examples/pytorch/domain_templates/imagenet.py
@@ -28,6 +28,7 @@
python imagenet.py --help
python imagenet.py fit --help
+
"""
import os
from typing import Optional
diff --git a/examples/pytorch/domain_templates/reinforce_learn_Qnet.py b/examples/pytorch/domain_templates/reinforce_learn_Qnet.py
index 0f3e455b736c9..3d1e5d016190b 100644
--- a/examples/pytorch/domain_templates/reinforce_learn_Qnet.py
+++ b/examples/pytorch/domain_templates/reinforce_learn_Qnet.py
@@ -29,6 +29,7 @@
[1] https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-
Second-Edition/blob/master/Chapter06/02_dqn_pong.py
+
"""
import argparse
@@ -54,6 +55,7 @@ class DQN(nn.Module):
DQN(
(net): Sequential(...)
)
+
"""
def __init__(self, obs_size: int, n_actions: int, hidden_size: int = 128):
@@ -79,6 +81,7 @@ class ReplayBuffer:
>>> ReplayBuffer(5) # doctest: +ELLIPSIS
<...reinforce_learn_Qnet.ReplayBuffer object at ...>
+
"""
def __init__(self, capacity: int) -> None:
@@ -96,6 +99,7 @@ def append(self, experience: Experience) -> None:
Args:
experience: tuple (state, action, reward, done, new_state)
+
"""
self.buffer.append(experience)
@@ -117,6 +121,7 @@ class RLDataset(IterableDataset):
>>> RLDataset(ReplayBuffer(5)) # doctest: +ELLIPSIS
<...reinforce_learn_Qnet.RLDataset object at ...>
+
"""
def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None:
@@ -141,6 +146,7 @@ class Agent:
>>> buffer = ReplayBuffer(10)
>>> Agent(env, buffer) # doctest: +ELLIPSIS
<...reinforce_learn_Qnet.Agent object at ...>
+
"""
def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None:
@@ -168,6 +174,7 @@ def get_action(self, net: nn.Module, epsilon: float, device: str) -> int:
Returns:
action
+
"""
if np.random.random() < epsilon:
action = self.env.action_space.sample()
@@ -194,6 +201,7 @@ def play_step(self, net: nn.Module, epsilon: float = 0.0, device: str = "cpu") -
Returns:
reward, done
+
"""
action = self.get_action(net, epsilon, device)
@@ -222,6 +230,7 @@ class DQNLightning(LightningModule):
(net): Sequential(...)
)
)
+
"""
def __init__(
@@ -270,6 +279,7 @@ def populate(self, steps: int = 1000) -> None:
Args:
steps: number of random steps to populate the buffer with
+
"""
for i in range(steps):
self.agent.play_step(self.net, epsilon=1.0)
@@ -282,6 +292,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
Returns:
q values
+
"""
return self.net(x)
@@ -293,6 +304,7 @@ def dqn_mse_loss(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor
Returns:
loss
+
"""
states, actions, rewards, dones, next_states = batch
@@ -308,8 +320,8 @@ def dqn_mse_loss(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor
return nn.MSELoss()(state_action_values, expected_state_action_values)
def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> OrderedDict:
- """Carries out a single step through the environment to update the replay buffer. Then calculates loss
- based on the minibatch received.
+ """Carries out a single step through the environment to update the replay buffer. Then calculates loss based on
+ the minibatch received.
Args:
batch: current mini batch of replay data
@@ -317,6 +329,7 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> O
Returns:
Training loss and log metrics
+
"""
device = self.get_device(batch)
epsilon = max(self.eps_end, self.eps_start - (self.global_step + 1) / self.eps_last_frame)
diff --git a/examples/pytorch/domain_templates/reinforce_learn_ppo.py b/examples/pytorch/domain_templates/reinforce_learn_ppo.py
index b68fcf720bb6e..16aa5ebc8675c 100644
--- a/examples/pytorch/domain_templates/reinforce_learn_ppo.py
+++ b/examples/pytorch/domain_templates/reinforce_learn_ppo.py
@@ -26,6 +26,7 @@
[1] https://github.com/openai/baselines/blob/master/baselines/ppo2/ppo2.py
[2] https://github.com/openai/spinningup
[3] https://github.com/sid-sundrani/ppo_lightning
+
"""
import argparse
from typing import Callable, Iterator, List, Tuple
@@ -52,8 +53,7 @@ def create_mlp(input_shape: Tuple[int], n_actions: int, hidden_size: int = 128):
class ActorCategorical(nn.Module):
- """Policy network, for discrete action spaces, which returns a distribution and an action given an
- observation."""
+ """Policy network, for discrete action spaces, which returns a distribution and an action given an observation."""
def __init__(self, actor_net):
"""
@@ -81,6 +81,7 @@ def get_log_prob(self, pi: Categorical, actions: torch.Tensor):
Returns:
log probability of the action under pi
+
"""
return pi.log_prob(actions)
@@ -117,6 +118,7 @@ def get_log_prob(self, pi: Normal, actions: torch.Tensor):
Returns:
log probability of the action under pi
+
"""
return pi.log_prob(actions).sum(axis=-1)
@@ -127,6 +129,7 @@ class ExperienceSourceDataset(IterableDataset):
Basic experience source dataset. Takes a generate_batch function that returns an iterator. The logic for the
experience source and how the batch is generated is defined the Lightning model itself
+
"""
def __init__(self, generate_batch: Callable):
@@ -144,6 +147,7 @@ class PPOLightning(LightningModule):
Train:
trainer = Trainer()
trainer.fit(model)
+
"""
def __init__(
@@ -231,6 +235,7 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Te
Returns:
Tuple of policy and action
+
"""
pi, action = self.actor(x)
value = self.critic(x)
@@ -245,6 +250,7 @@ def discount_rewards(self, rewards: List[float], discount: float) -> List[float]
Returns:
list of discounted rewards/advantages
+
"""
assert isinstance(rewards[0], float)
@@ -267,6 +273,7 @@ def calc_advantage(self, rewards: List[float], values: List[float], last_value:
Returns:
list of advantages
+
"""
rews = rewards + [last_value]
vals = values + [last_value]
@@ -373,6 +380,7 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor]):
Args:
batch: batch of replay buffer/trajectory data
+
"""
state, action, old_logp, qval, adv = batch
@@ -405,8 +413,7 @@ def configure_optimizers(self) -> List[Optimizer]:
return optimizer_actor, optimizer_critic
def optimizer_step(self, *args, **kwargs):
- """Run 'nb_optim_iters' number of iterations of gradient descent on actor and critic for each data
- sample."""
+ """Run 'nb_optim_iters' number of iterations of gradient descent on actor and critic for each data sample."""
for _ in range(self.nb_optim_iters):
super().optimizer_step(*args, **kwargs)
diff --git a/examples/pytorch/domain_templates/semantic_segmentation.py b/examples/pytorch/domain_templates/semantic_segmentation.py
index eb816756ce6c4..286084d5ff084 100644
--- a/examples/pytorch/domain_templates/semantic_segmentation.py
+++ b/examples/pytorch/domain_templates/semantic_segmentation.py
@@ -31,8 +31,7 @@
def _create_synth_kitti_dataset(path_dir: str, image_dims: tuple = (1024, 512)):
- """Create synthetic dataset with random images, just to simulate that the dataset have been already
- downloaded."""
+ """Create synthetic dataset with random images, just to simulate that the dataset have been already downloaded."""
path_dir_images = os.path.join(path_dir, KITTI.IMAGE_PATH)
path_dir_masks = os.path.join(path_dir, KITTI.MASK_PATH)
for p_dir in (path_dir_images, path_dir_masks):
@@ -65,6 +64,7 @@ class KITTI(Dataset):
In the `get_item` function, images and masks are resized to the given `img_size`, masks are
encoded using `encode_segmap`, and given `transform` (if any) are applied to the image only
(mask does not usually require transforms, but they can be implemented in a similar way).
+
"""
IMAGE_PATH = os.path.join("training", "image_2")
@@ -154,6 +154,7 @@ class UNet(nn.Module):
(5): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
)
)
+
"""
def __init__(self, num_classes: int = 19, num_layers: int = 5, features_start: int = 64, bilinear: bool = False):
@@ -200,6 +201,7 @@ class DoubleConv(nn.Module):
DoubleConv(
(net): Sequential(...)
)
+
"""
def __init__(self, in_ch: int, out_ch: int):
@@ -229,6 +231,7 @@ class Down(nn.Module):
)
)
)
+
"""
def __init__(self, in_ch: int, out_ch: int):
@@ -240,8 +243,8 @@ def forward(self, x):
class Up(nn.Module):
- """Upsampling (by either bilinear interpolation or transpose convolutions) followed by concatenation of feature
- map from contracting path, followed by double 3x3 convolution.
+ """Upsampling (by either bilinear interpolation or transpose convolutions) followed by concatenation of feature map
+ from contracting path, followed by double 3x3 convolution.
>>> Up(8, 4) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Up(
@@ -250,6 +253,7 @@ class Up(nn.Module):
(net): Sequential(...)
)
)
+
"""
def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False):
@@ -306,6 +310,7 @@ class SegModel(LightningModule):
)
)
)
+
"""
def __init__(
diff --git a/pyproject.toml b/pyproject.toml
index 8aa9903f36114..27334b81220e6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -50,6 +50,13 @@ skip = ["_notebooks"]
line-length = 120
exclude = '(_notebooks/.*)'
+[tool.docformatter]
+recursive = true
+# this needs to be shorter as some docstrings are r"""...
+wrap-summaries = 119
+wrap-descriptions = 120
+blank = true
+
[tool.ruff]
line-length = 120
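
The new ``[tool.docformatter]`` table above is what drives most of this diff: the docformatter pre-commit hook now reads its options from ``pyproject.toml`` (presumably why it gains a ``tomli`` dependency in place of the old command-line args), ``wrap-summaries = 119`` leaves headroom under the 120-character line length for ``r"""`` prefixes, and ``blank = true`` adds a blank line before every closing ``"""``. A minimal sketch of the resulting layout, using a hypothetical function rather than code from this repository:

    import json


    def load_config(path: str) -> dict:
        """Load a configuration file and return its contents as a dictionary.

        The summary stays on a single line of at most 119 characters, descriptions wrap at 120, and
        ``blank = true`` produces the trailing blank line before the closing quotes, i.e. the repeated
        "+" line seen throughout the rest of this diff.

        """
        # hypothetical helper, shown only to illustrate the docstring style enforced by the new settings
        with open(path, encoding="utf-8") as fh:
            return json.load(fh)
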
diff --git a/requirements/collect_env_details.py b/requirements/collect_env_details.py
index 3dd2b8d642f5d..5e6f9ba3dd350 100644
--- a/requirements/collect_env_details.py
+++ b/requirements/collect_env_details.py
@@ -14,6 +14,7 @@
"""Diagnose your system and show basic information.
This server mainly to get detail info for better bug reporting.
+
"""
import os
diff --git a/setup.py b/setup.py
index 043308ebb683c..d5d92d8228ab1 100755
--- a/setup.py
+++ b/setup.py
@@ -38,6 +38,7 @@
compared against PyPI registry
b) with a parameterization build desired packages in to standard `dist/` folder
c) validate packages and publish to PyPI
+
"""
import contextlib
import glob
diff --git a/src/lightning/app/api/http_methods.py b/src/lightning/app/api/http_methods.py
index 8cab27096b657..dc3dc32bee952 100644
--- a/src/lightning/app/api/http_methods.py
+++ b/src/lightning/app/api/http_methods.py
@@ -58,6 +58,7 @@ def request(self, request: Request) -> OutputRequestModel:
def configure_api(self):
return [Post("/api/v1/request", self.request)]
+
"""
_body: Optional[str] = None
@@ -116,6 +117,7 @@ def __init__(
route: The path used to route the requests
method: The associated flow method
timeout: The time in seconds taken before raising a timeout exception.
+
"""
self.route = route
self.attached_to_flow = hasattr(method, "__self__")
diff --git a/src/lightning/app/cli/cmd_install.py b/src/lightning/app/cli/cmd_install.py
index 05ead42b42c3d..b43aa3f88fac9 100644
--- a/src/lightning/app/cli/cmd_install.py
+++ b/src/lightning/app/cli/cmd_install.py
@@ -566,6 +566,7 @@ def _install_app_from_source(
If true, overwrite the app directory without asking if it already exists
git_sha:
The git_sha for checking out the git repo of the app.
+
"""
if not cwd:
diff --git a/src/lightning/app/cli/commands/logs.py b/src/lightning/app/cli/commands/logs.py
index eba1746cbdd2d..4587987ae5f17 100644
--- a/src/lightning/app/cli/commands/logs.py
+++ b/src/lightning/app/cli/commands/logs.py
@@ -46,6 +46,7 @@ def logs(app_name: str, components: List[str], follow: bool) -> None:
Print logs only from selected works:
$ lightning show logs my-application root.work_a root.work_b
+
"""
_show_logs(app_name, components, follow)
diff --git a/src/lightning/app/cli/component-template/tests/test_placeholdername_component.py b/src/lightning/app/cli/component-template/tests/test_placeholdername_component.py
index e1b30e1c11b6b..8c7dad2fe76ea 100644
--- a/src/lightning/app/cli/component-template/tests/test_placeholdername_component.py
+++ b/src/lightning/app/cli/component-template/tests/test_placeholdername_component.py
@@ -2,6 +2,7 @@
1. Init the component.
2. call .run()
+
"""
from placeholdername.component import TemplateComponent
diff --git a/src/lightning/app/cli/connect/app.py b/src/lightning/app/cli/connect/app.py
index 76de20256f7b4..ebad9b12973f7 100644
--- a/src/lightning/app/cli/connect/app.py
+++ b/src/lightning/app/cli/connect/app.py
@@ -57,6 +57,7 @@ def connect_app(app_name_or_id: str):
\b
# once done, disconnect and go back to the standard lightning CLI commands
lightning disconnect
+
"""
from lightning.app.utilities.commands.base import _download_command
diff --git a/src/lightning/app/cli/lightning_cli_delete.py b/src/lightning/app/cli/lightning_cli_delete.py
index c48bee07ca628..179e5b6fc365d 100644
--- a/src/lightning/app/cli/lightning_cli_delete.py
+++ b/src/lightning/app/cli/lightning_cli_delete.py
@@ -98,6 +98,7 @@ def delete_app(app_name: str, skip_user_confirm_prompt: bool) -> None:
Deleting an app also deletes all app websites, works, artifacts, and logs. This permanently removes any record of
the app as well as all any of its associated resources and data. This does not affect any resources and data
associated with other Lightning apps on your account.
+
"""
console = Console()
diff --git a/src/lightning/app/cli/lightning_cli_launch.py b/src/lightning/app/cli/lightning_cli_launch.py
index 8cf56453d86f9..c171fd7b946f1 100644
--- a/src/lightning/app/cli/lightning_cli_launch.py
+++ b/src/lightning/app/cli/lightning_cli_launch.py
@@ -40,10 +40,11 @@ def launch() -> None:
@click.option("--host", help="Application running host", default=APP_SERVER_HOST, type=str)
@click.option("--port", help="Application running port", default=APP_SERVER_PORT, type=int)
def run_server(file: str, queue_id: str, host: str, port: int) -> None:
- """It takes the application file as input, build the application object and then use that to run the
- application server.
+ """It takes the application file as input, build the application object and then use that to run the application
+ server.
This is used by the cloud runners to start the status server for the application
+
"""
logger.debug(f"Run Server: {file} {queue_id} {host} {port}")
start_application_server(file, host, port, queue_id=queue_id)
@@ -54,10 +55,11 @@ def run_server(file: str, queue_id: str, host: str, port: int) -> None:
@click.option("--queue-id", help="ID for identifying queue", default="", type=str)
@click.option("--base-url", help="Base url at which the app server is hosted", default="")
def run_flow(file: str, queue_id: str, base_url: str) -> None:
- """It takes the application file as input, build the application object, proxy all the work components and then
- run the application flow defined in the root component.
+ """It takes the application file as input, build the application object, proxy all the work components and then run
+ the application flow defined in the root component.
It does exactly what a singleprocess dispatcher would do but with proxied work components.
+
"""
logger.debug(f"Run Flow: {file} {queue_id} {base_url}")
run_lightning_flow(file, queue_id=queue_id, base_url=base_url)
@@ -68,8 +70,8 @@ def run_flow(file: str, queue_id: str, base_url: str) -> None:
@click.option("--work-name", type=str)
@click.option("--queue-id", help="ID for identifying queue", default="", type=str)
def run_work(file: str, work_name: str, queue_id: str) -> None:
- """Unlike other entrypoints, this command will take the file path or module details for a work component and
- run that by fetching the states from the queues."""
+ """Unlike other entrypoints, this command will take the file path or module details for a work component and run
+ that by fetching the states from the queues."""
logger.debug(f"Run Work: {file} {work_name} {queue_id}")
run_lightning_work(
file=file,
@@ -109,10 +111,11 @@ def run_flow_and_servers(
port: int,
flow_port: Tuple[Tuple[str, int]],
) -> None:
- """It takes the application file as input, build the application object and then use that to run the
- application flow defined in the root component, the application server and all the flow frontends.
+ """It takes the application file as input, build the application object and then use that to run the application
+ flow defined in the root component, the application server and all the flow frontends.
This is used by the cloud runners to start the flow, the status server and all frontends for the application
+
"""
logger.debug(f"Run Flow: {file} {queue_id} {base_url}")
logger.debug(f"Run Server: {file} {queue_id} {host} {port}.")
diff --git a/src/lightning/app/cli/pl-app-template/core/components/logger/tensorboard.py b/src/lightning/app/cli/pl-app-template/core/components/logger/tensorboard.py
index 6f5b2eb563c6d..a2935140a2eff 100644
--- a/src/lightning/app/cli/pl-app-template/core/components/logger/tensorboard.py
+++ b/src/lightning/app/cli/pl-app-template/core/components/logger/tensorboard.py
@@ -13,6 +13,7 @@ def __init__(self, log_dir: Path, sync_every_n_seconds: int = 5) -> None:
Args:
log_dir: The path to the directory where the TensorBoard log-files will appear.
sync_every_n_seconds: How often to sync the log directory (given as an argument to the run method)
+
"""
super().__init__()
self.worker = TensorBoardWorker(log_dir=log_dir, sync_every_n_seconds=sync_every_n_seconds)
diff --git a/src/lightning/app/components/database/client.py b/src/lightning/app/components/database/client.py
index 01643afbfe73f..81f0862918934 100644
--- a/src/lightning/app/components/database/client.py
+++ b/src/lightning/app/components/database/client.py
@@ -29,6 +29,7 @@ def _configure_session() -> Session:
"""Configures the session for GET and POST requests.
It enables a generous retrial strategy that waits for the application server to connect.
+
"""
retry_strategy = Retry(
# wait time between retries increases exponentially according to: backoff_factor * (2 ** (retry - 1))
diff --git a/src/lightning/app/components/database/server.py b/src/lightning/app/components/database/server.py
index 6da7710cfa4f0..3fbf75f01ac85 100644
--- a/src/lightning/app/components/database/server.py
+++ b/src/lightning/app/components/database/server.py
@@ -158,6 +158,7 @@ class CounterModel(SQLModel, table=True):
# RIGHT THERE ! You need to use Field and Column with the `pydantic_column_type` utility.
kv: List[KeyValuePair] = Field(..., sa_column=Column(pydantic_column_type(List[KeyValuePair])))
+
"""
super().__init__(parallel=True, cloud_build_config=BuildConfig(["sqlmodel"]))
self.db_filename = db_filename
diff --git a/src/lightning/app/components/database/utilities.py b/src/lightning/app/components/database/utilities.py
index dd31c12da6192..e4561d9245c2e 100644
--- a/src/lightning/app/components/database/utilities.py
+++ b/src/lightning/app/components/database/utilities.py
@@ -52,6 +52,7 @@ def _pydantic_column_type(pydantic_type: Any) -> Any:
class TrialConfig(SQLModel, table=False):
...
params: Dict[str, Union[Dict[str, float]] = Field(sa_column=Column(pydantic_column_type[Dict[str, float]))
+
"""
class PydanticJSONType(TypeDecorator, Generic[T]):
diff --git a/src/lightning/app/components/multi_node/base.py b/src/lightning/app/components/multi_node/base.py
index a300918452d48..1da8ef34e390f 100644
--- a/src/lightning/app/components/multi_node/base.py
+++ b/src/lightning/app/components/multi_node/base.py
@@ -67,6 +67,7 @@ def run(
running locally.
work_args: Arguments to be provided to the work on instantiation.
work_kwargs: Keywords arguments to be provided to the work on instantiation.
+
"""
super().__init__()
if num_nodes > 1 and not is_running_in_cloud():
diff --git a/src/lightning/app/components/python/popen.py b/src/lightning/app/components/python/popen.py
index 34b50140c5d98..9e585e2c214f8 100644
--- a/src/lightning/app/components/python/popen.py
+++ b/src/lightning/app/components/python/popen.py
@@ -70,6 +70,7 @@ def __init__(
.. literalinclude:: ../../../../examples/app/components/python/component_popen.py
:language: python
+
"""
super().__init__(**kwargs)
if not os.path.exists(script_path):
diff --git a/src/lightning/app/components/python/tracer.py b/src/lightning/app/components/python/tracer.py
index e0bde7c0b5e09..9048bceaec073 100644
--- a/src/lightning/app/components/python/tracer.py
+++ b/src/lightning/app/components/python/tracer.py
@@ -117,6 +117,7 @@ def __init__(
.. literalinclude:: ../../../../examples/app/components/python/app.py
:language: python
+
"""
super().__init__(**kwargs)
self.script_path = str(script_path)
diff --git a/src/lightning/app/components/serve/auto_scaler.py b/src/lightning/app/components/serve/auto_scaler.py
index 165a6422e45e4..f0838398a1aa3 100644
--- a/src/lightning/app/components/serve/auto_scaler.py
+++ b/src/lightning/app/components/serve/auto_scaler.py
@@ -114,8 +114,8 @@ async def num_requests() -> int:
class _LoadBalancer(LightningWork):
- r"""The LoadBalancer is a LightningWork component that collects the requests and sends them to the prediciton
- API asynchronously using RoundRobin scheduling. It also performs auto batching of the incoming requests.
+ r"""The LoadBalancer is a LightningWork component that collects the requests and sends them to the prediciton API
+ asynchronously using RoundRobin scheduling. It also performs auto batching of the incoming requests.
After enabling you will require to send username and password from the request header for the private endpoints.
@@ -131,6 +131,7 @@ class _LoadBalancer(LightningWork):
api_name: The name to be displayed on the UI. Normally, it is the name of the work class
cold_start_proxy: The proxy service to use while the work is cold starting.
**kwargs: Arguments passed to :func:`LightningWork.init` like ``CloudCompute``, ``BuildConfig``, etc.
+
"""
@requires(["aiohttp"])
@@ -236,6 +237,7 @@ async def consumer(self):
Two instances of this function should not be running with shared `_state_server` as that would create race
conditions
+
"""
while True:
await asyncio.sleep(0.05)
@@ -293,6 +295,7 @@ def _has_processing_capacity(self):
"""This function checks if we have processing capacity for one more request or not.
Depends on the value from here, we decide whether we should proxy the request or not
+
"""
if not self._fastapi_app:
return False
@@ -385,6 +388,7 @@ def update_servers(self, server_works: List[LightningWork]):
"""Updates works that load balancer distributes requests to.
AutoScaler uses this method to increase/decrease the number of works.
+
"""
old_server_urls = set(self.servers)
current_server_urls = {
@@ -623,6 +627,7 @@ def add_work(self, work) -> str:
Returns:
The name of the new work attribute.
+
"""
work_attribute = uuid.uuid4().hex
work_attribute = f"worker_{self.num_replicas}_{str(work_attribute)}"
@@ -665,6 +670,7 @@ def scale(self, replicas: int, metrics: dict) -> int:
Returns:
The target number of running works. The value will be adjusted after this method runs
so that it satisfies ``min_replicas<=replicas<=max_replicas``.
+
"""
pending_requests = metrics["pending_requests"]
active_or_pending_works = replicas + metrics["pending_works"]
diff --git a/src/lightning/app/components/serve/cold_start_proxy.py b/src/lightning/app/components/serve/cold_start_proxy.py
index 1be5315f76cc8..6ab829d87571f 100644
--- a/src/lightning/app/components/serve/cold_start_proxy.py
+++ b/src/lightning/app/components/serve/cold_start_proxy.py
@@ -35,6 +35,7 @@ class ColdStartProxy:
Args:
proxy_url (str): The url of the proxy service
+
"""
@requires(["aiohttp"])
@@ -46,12 +47,13 @@ def __init__(self, proxy_url: str):
async def handle_request(self, request: BaseModel) -> Any:
"""This method is called when the request is received while the work is cold starting. The default
- implementation of this method is to forward the request body to the proxy service with POST method but the
- user can override this method to handle the request in any way.
+ implementation of this method is to forward the request body to the proxy service with POST method but the user
+ can override this method to handle the request in any way.
Args:
request (BaseModel): The request body, a pydantic model that is being
forwarded by load balancer which is a FastAPI service
+
"""
try:
async with aiohttp.ClientSession() as session:
diff --git a/src/lightning/app/components/serve/gradio_server.py b/src/lightning/app/components/serve/gradio_server.py
index dc9f2d7847415..4dd6ae8139fed 100644
--- a/src/lightning/app/components/serve/gradio_server.py
+++ b/src/lightning/app/components/serve/gradio_server.py
@@ -77,6 +77,7 @@ def build_model(self) -> Any:
"""Override to instantiate and return your model.
The model would be accessible under self.model
+
"""
def run(self, *args: Any, **kwargs: Any):
diff --git a/src/lightning/app/components/serve/python_server.py b/src/lightning/app/components/serve/python_server.py
index 9e70688725fe2..518c296285218 100644
--- a/src/lightning/app/components/serve/python_server.py
+++ b/src/lightning/app/components/serve/python_server.py
@@ -211,6 +211,7 @@ def predict(self, request):
... return {"prediction": self._model(request.image)}
...
>>> app = LightningApp(SimpleServer())
+
"""
super().__init__(parallel=True, **kwargs)
if not issubclass(input_type, BaseModel):
@@ -228,6 +229,7 @@ def setup(self, *args: Any, **kwargs: Any) -> None:
Note that this will be called exactly once on every work machines. So if you have multiple machines for serving,
this will be called on each of them.
+
"""
return
@@ -243,6 +245,7 @@ def predict(self, request: Any) -> Any:
This method must be overriden by the user with the prediction logic. The pre/post processing, actual prediction
using the model(s) etc goes here
+
"""
pass
@@ -325,6 +328,7 @@ def run(self, *args: Any, **kwargs: Any) -> Any:
"""Run method takes care of configuring and setting up a FastAPI server behind the scenes.
Normally, you don't need to override this method.
+
"""
self.setup(*args, **kwargs)
diff --git a/src/lightning/app/components/serve/serve.py b/src/lightning/app/components/serve/serve.py
index 5cf1ad6dbcee4..0ae40302293f3 100644
--- a/src/lightning/app/components/serve/serve.py
+++ b/src/lightning/app/components/serve/serve.py
@@ -67,6 +67,7 @@ def __init__(
host: Address to be used to serve the model.
port: Port to be used to serve the model.
workers: Number of workers for the uvicorn. Warning, this won't work if your subclass takes more arguments.
+
"""
super().__init__(parallel=True, host=host, port=port)
if input and input not in _DESERIALIZER:
@@ -151,8 +152,8 @@ def configure_layout(self) -> str:
def _maybe_create_instance() -> Optional[ModelInferenceAPI]:
- """This function tries to re-create the user `ModelInferenceAPI` if the environment associated with multi
- workers are present."""
+ """This function tries to re-create the user `ModelInferenceAPI` if the environment associated with multi workers
+ is present."""
render_fn_name = os.getenv("LIGHTNING_MODEL_INFERENCE_API_CLASS_NAME", None)
render_fn_module_file = os.getenv("LIGHTNING_MODEL_INFERENCE_API_FILE", None)
if render_fn_name is None or render_fn_module_file is None:
diff --git a/src/lightning/app/components/serve/streamlit.py b/src/lightning/app/components/serve/streamlit.py
index 9c39ce23d55c7..8e9d2d34ae589 100644
--- a/src/lightning/app/components/serve/streamlit.py
+++ b/src/lightning/app/components/serve/streamlit.py
@@ -50,6 +50,7 @@ def build_model(self) -> Any:
"""Optionally override to instantiate and return your model.
The model will be accessible under ``self.model``.
+
"""
return None
diff --git a/src/lightning/app/components/serve/types/type.py b/src/lightning/app/components/serve/types/type.py
index 54d36d8afacc7..157940a60f8e7 100644
--- a/src/lightning/app/components/serve/types/type.py
+++ b/src/lightning/app/components/serve/types/type.py
@@ -28,4 +28,5 @@ def deserialize(self, *args: Any, **kwargs: Any): # pragma: no cover
"""Take the inputs from the network and deserilize/convert them them.
Output from this method will go to the exposed method as arguments.
+
"""
diff --git a/src/lightning/app/components/training.py b/src/lightning/app/components/training.py
index 7c47bea8ebdc3..5ffe843869e99 100644
--- a/src/lightning/app/components/training.py
+++ b/src/lightning/app/components/training.py
@@ -159,6 +159,7 @@ def __init__(
cloud_compute: The cloud compute object used in the cloud.
sanity_serving: Whether to validate that the model correctly implements
the ServableModule API
+
"""
super().__init__()
self.script_path = script_path
diff --git a/src/lightning/app/core/api.py b/src/lightning/app/core/api.py
index d4f076005373e..b34073d310a2c 100644
--- a/src/lightning/app/core/api.py
+++ b/src/lightning/app/core/api.py
@@ -270,8 +270,8 @@ async def post_delta(
x_lightning_session_uuid: Optional[str] = Header(None), # type: ignore[assignment]
x_lightning_session_id: Optional[str] = Header(None), # type: ignore[assignment]
) -> Optional[Dict]:
- """This endpoint is used to make an update to the app state using delta diff, mainly used by streamlit to
- update the state."""
+ """This endpoint is used to make an update to the app state using delta diff, mainly used by streamlit to update
+ the state."""
if x_lightning_session_uuid is None:
raise Exception("Missing X-Lightning-Session-UUID header")
diff --git a/src/lightning/app/core/app.py b/src/lightning/app/core/app.py
index 778404bf47c94..e3116eca3784e 100644
--- a/src/lightning/app/core/app.py
+++ b/src/lightning/app/core/app.py
@@ -383,8 +383,7 @@ def _collect_deltas_from_ui_and_work_queues(self) -> List[Union[Delta, _APIReque
return deltas
def maybe_apply_changes(self) -> Optional[bool]:
- """Get the deltas from both the flow queue and the work queue, merge the two deltas and update the
- state."""
+ """Get the deltas from both the flow queue and the work queue, merge the two deltas and update the state."""
self._send_flow_to_work_deltas(self.state)
if not self.collect_changes:
@@ -503,6 +502,7 @@ def _run(self) -> bool:
"""Entry point of the LightningApp.
This would be dispatched by the Runtime objects.
+
"""
self._original_state = deepcopy(self.state)
done = False
diff --git a/src/lightning/app/core/flow.py b/src/lightning/app/core/flow.py
index ff21a010070bf..c600f080158f8 100644
--- a/src/lightning/app/core/flow.py
+++ b/src/lightning/app/core/flow.py
@@ -396,6 +396,7 @@ def _exit(self, end_msg: str = "") -> None:
.. deprecated:: 1.9.0
This function is deprecated and will be removed in 2.0.0. Use :meth:`stop` instead.
+
"""
warnings.warn(
DeprecationWarning(
@@ -411,6 +412,7 @@ def _is_state_attribute(name: str) -> bool:
(prefixed by '__') attributes are not.
Exceptions are listed in the `_INTERNAL_STATE_VARS` class variable.
+
"""
return name in LightningFlow._INTERNAL_STATE_VARS or not name.startswith("_")
@@ -487,6 +489,7 @@ def run(self):
+
"""
if not user_key:
frame = cast(FrameType, inspect.currentframe()).f_back
@@ -626,6 +629,7 @@ def configure_layout(self):
+
"""
return [{"name": name, "content": component} for (name, component) in self.flows.items()]
@@ -639,6 +643,7 @@ def experimental_iterate(self, iterable: Iterable, run_once: bool = True, user_k
run_once: Whether to run the entire iteration only once.
Otherwise, it would restart from the beginning.
user_key: Key to be used to track the caching mechanism.
+
"""
if not isinstance(iterable, Iterable):
raise TypeError(f"An iterable should be provided to `self.iterate` method. Found {iterable}")
@@ -708,6 +713,7 @@ def my_remote_method(self, name):
.. code-block:: bash
lightning my_command_name --args name=my_own_name
+
"""
raise NotImplementedError
@@ -741,6 +747,7 @@ def configure_api(self):
Once the app is running, you can access the Swagger UI of the app
under the ``/docs`` route.
+
"""
raise NotImplementedError
@@ -805,6 +812,7 @@ def load_state_dict(self, flow_state, children_states, strict) -> None:
children_states: The state of the dynamic children of this flow.
strict: Whether to raise an exception if a dynamic
children hasn't been re-created.
+
"""
self.set_state(flow_state, recurse=False)
direct_children_states = {k: v for k, v in children_states.items() if "." not in k}
diff --git a/src/lightning/app/core/queues.py b/src/lightning/app/core/queues.py
index 18e02dd989b2e..1900f961a7ab1 100644
--- a/src/lightning/app/core/queues.py
+++ b/src/lightning/app/core/queues.py
@@ -182,6 +182,7 @@ def get(self, timeout: Optional[float] = None) -> Any:
timeout:
Read timeout in seconds, in case of input timeout is 0, the `self.default_timeout` is used.
A timeout of None can be used to block indefinitely.
+
"""
pass
@@ -190,6 +191,7 @@ def is_running(self) -> bool:
"""Returns True if the queue is running, False otherwise.
Child classes should override this property and implement custom logic as required
+
"""
return True
@@ -286,6 +288,7 @@ def get(self, timeout: Optional[float] = None) -> Any:
timeout:
Read timeout in seconds, in case of input timeout is 0, the `self.default_timeout` is used.
A timeout of None can be used to block indefinitely.
+
"""
if timeout is None:
# this means it's blocking in redis
@@ -464,6 +467,7 @@ def _split_app_id_and_queue_name(queue_name: str) -> Tuple[str, str]:
This can be brittle, as if the queue name creation logic changes, the response values from here wouldn't be
accurate. Remove this eventually and let the Queue class take app id and name of the queue as arguments
+
"""
if "_" not in queue_name:
return "", queue_name
diff --git a/src/lightning/app/core/work.py b/src/lightning/app/core/work.py
index f5416c0cfae3d..15a762d6f9ad2 100644
--- a/src/lightning/app/core/work.py
+++ b/src/lightning/app/core/work.py
@@ -124,6 +124,7 @@ def __init__(
+
"""
from lightning.app.runners.backends.backend import Backend
@@ -212,6 +213,7 @@ def internal_ip(self) -> str:
By default, this attribute returns the empty string and the ip address will only be returned once the work runs.
Locally, the address is 127.0.0.1 and in the cloud it will be determined by the cluster.
+
"""
return self._internal_ip
@@ -221,6 +223,7 @@ def public_ip(self) -> str:
By default, this attribute returns the empty string and the ip address will only be returned once the work runs.
Locally, this address is undefined (empty string) and in the cloud it will be determined by the cluster.
+
"""
return self._public_ip
@@ -234,6 +237,7 @@ def _is_state_attribute(name: str) -> bool:
(prefixed by '__') attributes are not.
Exceptions are listed in the `_INTERNAL_STATE_VARS` class variable.
+
"""
return name in LightningWork._INTERNAL_STATE_VARS or not name.startswith("_")
@@ -247,6 +251,7 @@ def display_name(self) -> str:
"""Returns the display name of the LightningWork in the cloud.
The display name needs to set before the run method of the work is called.
+
"""
return self._display_name
@@ -269,6 +274,7 @@ def parallel(self) -> bool:
"""Whether to run in parallel mode or not.
When parallel is False, the flow waits for the work to finish.
+
"""
return self._parallel
@@ -325,6 +331,7 @@ def status(self) -> WorkStatus:
"""Return the current status of the work.
All statuses are stored in the state.
+
"""
call_hash = self._calls[CacheCallsKeys.LATEST_CALL_HASH]
if call_hash in self._calls:
@@ -628,6 +635,7 @@ def run(self, *args: Any, **kwargs: Any) -> None:
Raises:
LightningPlatformException: If resource exceeds platform quotas or other constraints.
+
"""
def on_exception(self, exception: BaseException) -> None:
@@ -636,8 +644,7 @@ def on_exception(self, exception: BaseException) -> None:
raise exception
def _aggregate_status_timeout(self, statuses: List[Dict]) -> WorkStatus:
- """Method used to return the first request and the total count of timeout after the latest succeeded
- status."""
+ """Method used to return the first request and the total count of timeout after the latest succeeded status."""
succeeded_statuses = [
status_idx for status_idx, status in enumerate(statuses) if status["stage"] == WorkStageStatus.SUCCEEDED
]
@@ -653,6 +660,7 @@ def on_exit(self) -> None:
"""Override this hook to add your logic when the work is exiting.
Note: This hook is not guaranteed to be called when running in the cloud.
+
"""
pass
@@ -660,6 +668,7 @@ def stop(self) -> None:
"""Stops LightingWork component and shuts down hardware provisioned via L.CloudCompute.
This can only be called from a ``LightningFlow``.
+
"""
if not self._backend:
raise RuntimeError(f"Only the `LightningFlow` can request this work ({self.name!r}) to stop.")
@@ -675,6 +684,7 @@ def delete(self) -> None:
"""Delete LightingWork component and shuts down hardware provisioned via L.CloudCompute.
Locally, the work.delete() behaves as work.stop().
+
"""
if not self._backend:
raise Exception(
@@ -755,4 +765,5 @@ def configure_layout(self):
returned URL can depend on the state. This is not the case if the work returns a
:class:`~lightning.app.frontend.frontend.Frontend`. These need to be provided at the time of app creation
in order for the runtime to start the server.
+
"""
diff --git a/src/lightning/app/frontend/frontend.py b/src/lightning/app/frontend/frontend.py
index d2d24b56561a6..2d87d4ebcd7c6 100644
--- a/src/lightning/app/frontend/frontend.py
+++ b/src/lightning/app/frontend/frontend.py
@@ -23,6 +23,7 @@ class Frontend(ABC):
"""Base class for any frontend that gets exposed by LightningFlows.
The flow attribute will be set by the app while bootstrapping.
+
"""
def __init__(self) -> None:
@@ -48,6 +49,7 @@ def start_server(self, host: str, port: int, root_path: str = "") -> None:
def start_server(self, host, port, root_path=""):
self._process = subprocess.Popen(["flask", "run" "--host", host, "--port", str(port)])
+
"""
@abstractmethod
@@ -62,4 +64,5 @@ def stop_server(self) -> None:
def stop_server(self):
self._process.kill()
+
"""
diff --git a/src/lightning/app/frontend/just_py/just_py.py b/src/lightning/app/frontend/just_py/just_py.py
index 7d5ac44306bd3..11a9d55799544 100644
--- a/src/lightning/app/frontend/just_py/just_py.py
+++ b/src/lightning/app/frontend/just_py/just_py.py
@@ -81,6 +81,7 @@ def webpage():
app = LightningApp(Flow())
+
"""
def __init__(self, render_fn: Callable) -> None:
diff --git a/src/lightning/app/frontend/panel/__init__.py b/src/lightning/app/frontend/panel/__init__.py
index bb67ee156832f..96cb5550ce3c6 100644
--- a/src/lightning/app/frontend/panel/__init__.py
+++ b/src/lightning/app/frontend/panel/__init__.py
@@ -1,5 +1,4 @@
-"""The PanelFrontend and AppStateWatcher make it easy to create Lightning Apps with the Panel data app
-framework."""
+"""The PanelFrontend and AppStateWatcher make it easy to create Lightning Apps with the Panel data app framework."""
from lightning.app.frontend.panel.app_state_watcher import AppStateWatcher
from lightning.app.frontend.panel.panel_frontend import PanelFrontend
diff --git a/src/lightning/app/frontend/panel/app_state_comm.py b/src/lightning/app/frontend/panel/app_state_comm.py
index eb1f0187862d4..9fec245f7b2f9 100644
--- a/src/lightning/app/frontend/panel/app_state_comm.py
+++ b/src/lightning/app/frontend/panel/app_state_comm.py
@@ -94,6 +94,7 @@ def _watch_app_state(callback: Callable):
def handle_state_change():
print("The App State changed.")
watch_app_state(handle_state_change)
+
"""
_CALLBACKS.append(callback)
_start_websocket()
diff --git a/src/lightning/app/frontend/panel/app_state_watcher.py b/src/lightning/app/frontend/panel/app_state_watcher.py
index 528a19accede8..3612fb114775c 100644
--- a/src/lightning/app/frontend/panel/app_state_watcher.py
+++ b/src/lightning/app/frontend/panel/app_state_watcher.py
@@ -72,6 +72,7 @@ def update(state):
Pydantic which additionally provides powerful and unique features for building reactive apps.
Please note the ``AppStateWatcher`` is a singleton, i.e., only one instance is instantiated
+
"""
state: AppState = ClassSelector(
diff --git a/src/lightning/app/frontend/panel/panel_frontend.py b/src/lightning/app/frontend/panel/panel_frontend.py
index f4a5c68f57054..d0f4ead1cae8f 100644
--- a/src/lightning/app/frontend/panel/panel_frontend.py
+++ b/src/lightning/app/frontend/panel/panel_frontend.py
@@ -35,6 +35,7 @@ def _has_panel_autoreload() -> bool:
"""Returns True if the PANEL_AUTORELOAD environment variable is set to 'yes' or 'true'.
Please note the casing of value does not matter
+
"""
return os.environ.get("PANEL_AUTORELOAD", "no").lower() in ["yes", "y", "true"]
@@ -98,6 +99,7 @@ def configure_layout(self):
For development you can get Panel autoreload by setting the ``PANEL_AUTORELOAD``
environment variable to 'yes', i.e. run
``PANEL_AUTORELOAD=yes lightning run app app_basic.py``
+
"""
@requires("panel")
diff --git a/src/lightning/app/frontend/panel/panel_serve_render_fn.py b/src/lightning/app/frontend/panel/panel_serve_render_fn.py
index df6a83d713d59..0c06cccd068db 100644
--- a/src/lightning/app/frontend/panel/panel_serve_render_fn.py
+++ b/src/lightning/app/frontend/panel/panel_serve_render_fn.py
@@ -26,6 +26,7 @@
.. code-block:: bash
python panel_serve_render_fn
+
"""
import inspect
import os
diff --git a/src/lightning/app/frontend/stream_lit.py b/src/lightning/app/frontend/stream_lit.py
index 1ce03b997eadc..0cdc37296d931 100644
--- a/src/lightning/app/frontend/stream_lit.py
+++ b/src/lightning/app/frontend/stream_lit.py
@@ -61,6 +61,7 @@ def my_streamlit_ui(state):
st.write("Hello from streamlit!")
st.write(state.counter)
+
"""
@requires("streamlit")
diff --git a/src/lightning/app/frontend/streamlit_base.py b/src/lightning/app/frontend/streamlit_base.py
index 189bbba82b768..b03628b4495b0 100644
--- a/src/lightning/app/frontend/streamlit_base.py
+++ b/src/lightning/app/frontend/streamlit_base.py
@@ -14,6 +14,7 @@
"""This file gets run by streamlit, which we launch within Lightning.
From here, we will call the render function that the user provided in ``configure_layout``.
+
"""
import os
import pydoc
diff --git a/src/lightning/app/frontend/utils.py b/src/lightning/app/frontend/utils.py
index 1399046686c23..80898f1213d67 100644
--- a/src/lightning/app/frontend/utils.py
+++ b/src/lightning/app/frontend/utils.py
@@ -37,6 +37,7 @@ def _get_flow_state(flow: str) -> AppState:
Returns:
AppState: An AppState scoped to the current Flow.
+
"""
app_state = AppState()
app_state._request_state() # pylint: disable=protected-access
@@ -54,6 +55,7 @@ def _get_frontend_environment(flow: str, render_fn_or_file: Callable | str, port
Returns:
os._Environ: An environment
+
"""
env = os.environ.copy()
env["LIGHTNING_FLOW_NAME"] = flow
diff --git a/src/lightning/app/frontend/web.py b/src/lightning/app/frontend/web.py
index f6fa639d3e252..2e7d9f3f2f8e3 100644
--- a/src/lightning/app/frontend/web.py
+++ b/src/lightning/app/frontend/web.py
@@ -44,6 +44,7 @@ class StaticWebFrontend(Frontend):
def configure_layout(self):
return StaticWebFrontend("path/to/folder/to/serve")
+
"""
def __init__(self, serve_dir: str) -> None:
@@ -102,8 +103,7 @@ def _start_server(
def _get_log_config(log_file: str) -> dict:
- """Returns a logger configuration in the format expected by uvicorn that sends all logs to the given
- logfile."""
+ """Returns a logger configuration in the format expected by uvicorn that sends all logs to the given logfile."""
# Modified from the default config found in uvicorn.config.LOGGING_CONFIG
return {
"version": 1,
diff --git a/src/lightning/app/launcher/launcher.py b/src/lightning/app/launcher/launcher.py
index 8f00731161dfc..3d2a0668891d3 100644
--- a/src/lightning/app/launcher/launcher.py
+++ b/src/lightning/app/launcher/launcher.py
@@ -109,6 +109,7 @@ def run_lightning_work(
It is organized under cloud runtime to indicate that it will be used by the cloud runner but otherwise, no cloud
specific logic is being implemented here
+
"""
logger.debug(f"Run Lightning Work {file} {work_name} {queue_id}")
@@ -231,6 +232,7 @@ def serve_frontend(file: str, flow_name: str, host: str, port: int):
It is organized under cloud runtime to indicate that it will be used by the cloud runner but otherwise, no cloud
specific logic is being implemented here.
+
"""
_set_frontend_context()
logger.debug(f"Run Serve Frontend {file} {flow_name} {host} {port}")
@@ -340,15 +342,16 @@ def _sigterm_handler(*_):
def _get_frontends_from_app(entrypoint_file):
- """This function is used to get the frontends from the app. It will be used to start the frontends in a
- separate process if the backend cannot provide flow_names_and_ports. This is useful if the app cannot be loaded
- locally to set the frontend before dispatching to the cloud. The backend exposes by default 10 ports from 8081
- if the app.spec.frontends is not set.
+ """This function is used to get the frontends from the app. It will be used to start the frontends in a separate
+ process if the backend cannot provide flow_names_and_ports. This is useful if the app cannot be loaded locally to
+ set the frontend before dispatching to the cloud. The backend exposes by default 10 ports from 8081 if the
+ app.spec.frontends is not set.
NOTE: frontend_name are sorted to ensure that they get consistent ports.
:param entrypoint_file: The entrypoint file for the app
:return: A list of tuples of the form (frontend_name, port_number)
+
"""
app = load_app_from_file(entrypoint_file)
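The docstring above describes pairing each frontend flow with one of the default ports starting at 8081. A minimal sketch of that idea, assuming a plain dict of frontends keyed by flow name (the helper name and constants are illustrative, not the launcher's actual code):

.. code-block:: python

    from typing import Dict, List, Tuple

    START_PORT = 8081  # assumption: mirrors the "10 ports from 8081" note in the docstring
    NUM_DEFAULT_PORTS = 10


    def assign_frontend_ports(frontends: Dict[str, object]) -> List[Tuple[str, int]]:
        """Pair each frontend flow name with a port; names are sorted for consistent assignment."""
        names = sorted(frontends)
        if len(names) > NUM_DEFAULT_PORTS:
            raise RuntimeError(f"Only {NUM_DEFAULT_PORTS} ports are exposed by default")
        return [(name, START_PORT + i) for i, name in enumerate(names)]


    print(assign_frontend_ports({"root.flow_b": None, "root.flow_a": None}))
    # [('root.flow_a', 8081), ('root.flow_b', 8082)]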
diff --git a/src/lightning/app/launcher/lightning_backend.py b/src/lightning/app/launcher/lightning_backend.py
index 1e3c096e45cf1..fc100b9ead338 100644
--- a/src/lightning/app/launcher/lightning_backend.py
+++ b/src/lightning/app/launcher/lightning_backend.py
@@ -252,6 +252,7 @@ def update_work_statuses(self, works: List[LightningWork]) -> None:
Normally, the Lightning frameworks communicates statuses through the queues, but while the Work instance is
being provisionied, the queues don't exist yet and hence we need to make API calls directly to the backend to
fetch the status and update it in the states.
+
"""
if not works:
return
@@ -305,6 +306,7 @@ def stop_all_works(self, works: List[LightningWork]) -> None:
"""Stop resources for all LightningWorks in this app.
The Works are stopped rather than deleted so that they can be inspected for debugging.
+
"""
cloud_works = self._get_cloud_work_specs(self.client)
diff --git a/src/lightning/app/plugin/plugin.py b/src/lightning/app/plugin/plugin.py
index db66a4ad245ba..67d15d0c27a86 100644
--- a/src/lightning/app/plugin/plugin.py
+++ b/src/lightning/app/plugin/plugin.py
@@ -60,6 +60,7 @@ def run_job(self, name: str, app_entrypoint: str, env_vars: Dict[str, str] = {})
Returns:
The relative URL of the created job.
+
"""
from lightning.app.runners.backends.cloud import CloudBackend
from lightning.app.runners.cloud import CloudRuntime
diff --git a/src/lightning/app/runners/cloud.py b/src/lightning/app/runners/cloud.py
index 0101b7a7208ea..d742c9d459160 100644
--- a/src/lightning/app/runners/cloud.py
+++ b/src/lightning/app/runners/cloud.py
@@ -198,8 +198,8 @@ def cloudspace_dispatch(
cluster_id: str,
source_app: Optional[str] = None,
) -> str:
- """Slim dispatch for creating runs from a cloudspace. This dispatch avoids resolution of some properties
- such as the project and cluster IDs that are instead passed directly.
+ """Slim dispatch for creating runs from a cloudspace. This dispatch avoids resolution of some properties such
+ as the project and cluster IDs that are instead passed directly.
Args:
project_id: The ID of the project.
@@ -214,6 +214,7 @@ def cloudspace_dispatch(
Returns:
The URL of the created job.
+
"""
# Dispatch in four phases: resolution, validation, spec creation, API transactions
# Resolution
@@ -432,6 +433,7 @@ def _resolve_config(self, name: Optional[str], load: bool = True) -> AppConfig:
"""Find and load the config file if it exists (otherwise create an empty config).
Override the name if provided.
+
"""
config_file = _get_config_file(self.entrypoint)
cloudspace_config = AppConfig.load_from_file(config_file) if config_file.exists() and load else AppConfig()
@@ -611,6 +613,7 @@ def _resolve_needs_credits(project: V1Membership):
"""Check if the user likely needs credits to run the app with its hardware.
Returns False if user has 1 or more credits.
+
"""
balance = project.balance
if balance is None:
@@ -698,8 +701,8 @@ def _validate_mounts(self) -> None:
raise RuntimeError(f"Unknown mount protocol `{mount.protocol}` for work `{work.name}`.")
def _get_flow_servers(self) -> List[V1Flowserver]:
- """Collect a spec for each flow that contains a frontend so that the backend knows for which flows it needs
- to start servers."""
+ """Collect a spec for each flow that contains a frontend so that the backend knows for which flows it needs to
+ start servers."""
flow_servers: List[V1Flowserver] = []
for flow_name in self.app.frontends:
flow_server = V1Flowserver(name=flow_name)
@@ -889,8 +892,7 @@ def _get_auth(credentials: str) -> Optional[V1LightningAuth]:
def _get_env_vars(
env_vars: Dict[str, str], secrets: Dict[str, str], run_app_comment_commands: bool
) -> List[V1EnvVar]:
- """Generate the list of environment variable specs for the app, including variables set by the
- framework."""
+ """Generate the list of environment variable specs for the app, including variables set by the framework."""
v1_env_vars = [V1EnvVar(name=k, value=v) for k, v in env_vars.items()]
if len(secrets.values()) > 0:
@@ -929,6 +931,7 @@ def _api_create_cloudspace_if_not_exists(
"""Create the cloudspace if it doesn't exist.
Return the cloudspace ID.
+
"""
if existing_cloudspace is None:
cloudspace_body = ProjectIdCloudspacesBody(name=name, can_download_source_code=True)
@@ -980,6 +983,7 @@ def _api_transfer_run_instance(
"""Transfer an existing instance to the given run ID and update its specification.
Return the instance.
+
"""
run_instance = self.backend.client.lightningapp_instance_service_update_lightningapp_instance_release(
project_id=project_id,
diff --git a/src/lightning/app/runners/multiprocess.py b/src/lightning/app/runners/multiprocess.py
index 93f091f8705ab..2db62668937fc 100644
--- a/src/lightning/app/runners/multiprocess.py
+++ b/src/lightning/app/runners/multiprocess.py
@@ -39,6 +39,7 @@ class MultiProcessRuntime(Runtime):
The MultiProcessRuntime will generate 1 process for each :class:`~lightning.app.core.work.LightningWork` and attach
queues to enable communication between the different processes.
+
"""
backend: Union[str, Backend] = "multiprocessing"
diff --git a/src/lightning/app/runners/runtime.py b/src/lightning/app/runners/runtime.py
index 375d1d16a57b0..d710cae985a43 100644
--- a/src/lightning/app/runners/runtime.py
+++ b/src/lightning/app/runners/runtime.py
@@ -66,6 +66,7 @@ def dispatch(
run_app_comment_commands: whether to parse commands from the entrypoint file and execute them before app startup
enable_basic_auth: whether to enable basic authentication for the app
(use credentials in the format username:password as an argument)
+
"""
from lightning.app.runners.runtime_type import RuntimeType
from lightning.app.utilities.component import _set_flow_context
diff --git a/src/lightning/app/source_code/copytree.py b/src/lightning/app/source_code/copytree.py
index 592d3d7deec94..ba51af98affc4 100644
--- a/src/lightning/app/source_code/copytree.py
+++ b/src/lightning/app/source_code/copytree.py
@@ -34,9 +34,9 @@ def _copytree(
dirs_exist_ok=False,
dry_run=False,
) -> List[str]:
- """Vendor in from `shutil.copytree` to support ignoring files recursively based on `.lightningignore`, like
- `git` does with `.gitignore`. Also removed a few checks from the original copytree related to symlink checks.
- Differences between original and this function are.
+ """Vendor in from `shutil.copytree` to support ignoring files recursively based on `.lightningignore`, like `git`
+ does with `.gitignore`. Also removed a few checks from the original copytree related to symlink checks. Differences
+ between original and this function are.
1. It supports a list of ignore function instead of a single one in the
original. We can use this for filtering out files based on nested
@@ -66,6 +66,7 @@ def _copytree(
If exception(s) occur, an Error is raised with a list of reasons.
+
"""
files_copied = []
@@ -146,8 +147,8 @@ def _parse_lightningignore(lines: Tuple[str]) -> Set[str]:
def _read_lightningignore(path: Path) -> Set[str]:
- """Reads ignore file and filter and empty lines. This will also remove patterns that start with a `/`. That's
- done to allow `glob` to simulate the behavior done by `git` where it interprets that as a root path.
+ """Reads ignore file and filter and empty lines. This will also remove patterns that start with a `/`. That's done
+ to allow `glob` to simulate the behavior done by `git` where it interprets that as a root path.
Parameters
----------
@@ -158,6 +159,7 @@ def _read_lightningignore(path: Path) -> Set[str]:
-------
Set[str]
Set of unique lines.
+
"""
raw_lines = path.open().readlines()
return _parse_lightningignore(raw_lines)
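As a rough illustration of the `.lightningignore` handling described above (empty lines dropped, leading `/` stripped so `glob` can treat patterns as relative), here is a hedged, standalone sketch; it is not the real `_parse_lightningignore`, whose details may differ:

.. code-block:: python

    from typing import Iterable, Set


    def parse_ignore_lines(lines: Iterable[str]) -> Set[str]:
        """Drop blank lines and strip a leading '/' so glob-style matching treats patterns as relative."""
        patterns = set()
        for line in lines:
            line = line.strip()
            if not line:
                continue
            patterns.add(line.lstrip("/"))
        return patterns


    print(parse_ignore_lines(["/data\n", "\n", "*.ckpt\n"]))  # {'data', '*.ckpt'}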
diff --git a/src/lightning/app/source_code/hashing.py b/src/lightning/app/source_code/hashing.py
index 362d32f2592b7..6cd823e9a074a 100644
--- a/src/lightning/app/source_code/hashing.py
+++ b/src/lightning/app/source_code/hashing.py
@@ -33,6 +33,7 @@ def _get_hash(files: List[str], algorithm: str = "blake2", chunk_num_blocks: int
----------
[1] https://crypto.stackexchange.com/questions/70101/blake2-vs-md5-for-checksum-file-integrity
[2] https://stackoverflow.com/questions/1131220/get-md5-hash-of-big-files-in-python
+
"""
# validate input
if algorithm == "blake2":
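The `_get_hash` docstring above points at blake2 and chunked reads for hashing large files. A self-contained sketch of that approach using only the standard library (the chunking parameter and the sorting of file paths are assumptions, not the function's exact behavior):

.. code-block:: python

    import hashlib
    from typing import List


    def files_digest(files: List[str], chunk_num_blocks: int = 128) -> str:
        """Hash many files by feeding fixed-size chunks into a single blake2b digest."""
        h = hashlib.blake2b(digest_size=32)
        chunk_size = chunk_num_blocks * h.block_size  # read in multiples of the hash block size
        for path in sorted(files):  # sort so the digest does not depend on listing order
            with open(path, "rb") as f:
                while chunk := f.read(chunk_size):
                    h.update(chunk)
        return h.hexdigest()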
diff --git a/src/lightning/app/source_code/local.py b/src/lightning/app/source_code/local.py
index b01e0c4c0a885..bfad7eb442a1d 100644
--- a/src/lightning/app/source_code/local.py
+++ b/src/lightning/app/source_code/local.py
@@ -133,6 +133,7 @@ def upload(self, url: str) -> None:
packaged repository files which have a size > 2GB.
This limitation should be removed during the datastore upload redesign
+
"""
if self.package_path.stat().st_size > 2e9:
raise OSError(
diff --git a/src/lightning/app/source_code/tar.py b/src/lightning/app/source_code/tar.py
index 7ca93c798b8b6..c3aca1ae317c7 100644
--- a/src/lightning/app/source_code/tar.py
+++ b/src/lightning/app/source_code/tar.py
@@ -36,6 +36,7 @@ def _get_dir_size_and_count(source_dir: str, prefix: Optional[str] = None) -> Tu
-------
Tuple[int, int]
Size in megabytes and file count
+
"""
size = 0
count = 0
@@ -61,6 +62,7 @@ class _TarResults:
The total size of the original directory files in bytes
after_size: int
The total size of the compressed and tarred split files in bytes
+
"""
before_size: int
@@ -70,13 +72,12 @@ class _TarResults:
def _get_split_size(
total_size: int, minimum_split_size: int = 1024 * 1000 * 20, max_split_count: int = MAX_SPLIT_COUNT
) -> int:
- """Calculate the split size we should use to split the multipart upload of an object to a bucket. We are
- limited to 1000 max parts as the way we are using ListMultipartUploads. More info
- https://github.com/gridai/grid/pull/5267
+ """Calculate the split size we should use to split the multipart upload of an object to a bucket. We are limited
+ to 1000 max parts as the way we are using ListMultipartUploads. More info https://github.com/gridai/grid/pull/5267
https://docs.aws.amazon.com/AmazonS3/latest/userguide/mpuoverview.html#mpu-process
https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListMultipartUploads.html
- https://github.com/psf/requests/issues/2717#issuecomment-724725392 Python or requests has a limit of 2**31
- bytes for a single file upload.
+ https://github.com/psf/requests/issues/2717#issuecomment-724725392 Python or requests has a limit of 2**31 bytes
+ for a single file upload.
Parameters
----------
@@ -91,6 +92,7 @@ def _get_split_size(
-------
int
Split size
+
"""
max_size = max_split_count * (1 << 31) # max size per part limited by Requests or urllib as shown in ref above
if total_size > max_size:
@@ -123,6 +125,7 @@ def _tar_path(source_path: str, target_file: str, compression: bool = False) ->
-------
TarResults
Results that holds file counts and sizes
+
"""
if os.path.isdir(source_path):
before_size, _ = _get_dir_size_and_count(source_path)
@@ -149,6 +152,7 @@ def _tar_path_python(source_path: str, target_file: str, compression: bool = Fal
Target tar file
compression: bool, default False
Enable compression, which is disabled by default.
+
"""
file_mode = "w:gz" if compression else "w:"
@@ -172,6 +176,7 @@ def _tar_path_subprocess(source_path: str, target_file: str, compression: bool =
Target tar file
compression: bool, default False
Enable compression, which is disabled by default.
+
"""
# Only add compression when users explicitly request it.
# We do this because it takes too long to compress
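To make the constraints in the `_get_split_size` docstring concrete (at most 1000 parts via ListMultipartUploads, a 2**31-byte per-part ceiling from requests/urllib, and a roughly 20 MB minimum split), here is an illustrative calculation; it is a sketch of the trade-off, not the function itself:

.. code-block:: python

    MAX_PARTS = 1000               # ListMultipartUploads limitation referenced above
    PART_CAP = 1 << 31             # per-part ceiling imposed by requests/urllib
    MIN_SPLIT = 1024 * 1000 * 20   # roughly 20 MB minimum split size


    def split_size(total_size: int) -> int:
        """Pick a part size that keeps the upload under MAX_PARTS parts and PART_CAP bytes per part."""
        if total_size > MAX_PARTS * PART_CAP:
            raise ValueError("Object too large to upload within these limits")
        if total_size // MAX_PARTS > MIN_SPLIT:
            # Grow the part size just enough to stay under the part-count limit.
            return -(-total_size // MAX_PARTS)  # ceiling division
        return MIN_SPLIT


    print(split_size(10 * 1024 * 1024))  # small archive -> the minimum split size
    print(split_size(100 * 1024 ** 3))   # 100 GiB -> roughly 107 MB parts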
diff --git a/src/lightning/app/source_code/uploader.py b/src/lightning/app/source_code/uploader.py
index 306ee96d7eda5..82336c7b0b96f 100644
--- a/src/lightning/app/source_code/uploader.py
+++ b/src/lightning/app/source_code/uploader.py
@@ -35,6 +35,7 @@ class FileUploader:
Size of all files to upload
name: str
Name of this upload to display progress
+
"""
workers: int = 8
@@ -72,6 +73,7 @@ def upload_data(self, url: str, data: bytes, retries: int, disconnect_retry_wait
-------
str
ETag from response
+
"""
disconnect_retries = retries
while disconnect_retries > 0:
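The `upload_data` docstring above describes retrying a part upload on disconnects and returning the response ETag. A hedged sketch of that pattern with `requests` (the presigned-URL and ETag-header details are assumptions about a typical multipart upload, not the uploader's exact code):

.. code-block:: python

    import time

    import requests


    def upload_part(url: str, data: bytes, retries: int = 3, disconnect_retry_wait_seconds: float = 5.0) -> str:
        """PUT one part to a (presumably presigned) URL and return the ETag, retrying on connection drops."""
        for attempt in range(retries):
            try:
                response = requests.put(url, data=data)
                response.raise_for_status()
                return response.headers["ETag"]
            except requests.ConnectionError:
                if attempt == retries - 1:
                    raise
                time.sleep(disconnect_retry_wait_seconds)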
diff --git a/src/lightning/app/storage/copier.py b/src/lightning/app/storage/copier.py
index 144c2335a72db..3619654cfb825 100644
--- a/src/lightning/app/storage/copier.py
+++ b/src/lightning/app/storage/copier.py
@@ -52,6 +52,7 @@ class _Copier(Thread):
will send requests to this queue.
copy_response_queue: A queue connecting the central StorageOrchestrator with the Copier. The Copier
will send a response to this queue whenever a requested copy has finished.
+
"""
def __init__(
@@ -116,6 +117,7 @@ def _copy_files(
interpreted as a folder as well. If the source is a file, the destination path is interpreted as a file too.
Files in a folder are copied recursively and efficiently using multiple threads.
+
"""
if fs is None:
fs = _filesystem()
diff --git a/src/lightning/app/storage/drive.py b/src/lightning/app/storage/drive.py
index 0cdb31004696b..d90c5a0d41137 100644
--- a/src/lightning/app/storage/drive.py
+++ b/src/lightning/app/storage/drive.py
@@ -46,6 +46,7 @@ def __init__(
component_name: The component name which owns this drive.
When not provided, it is automatically inferred by Lightning.
root_folder: This is the folder from where the Drive perceives the data (e.g this acts as a mount dir).
+
"""
if id.startswith("s3://"):
raise ValueError(
@@ -96,6 +97,7 @@ def put(self, path: str) -> None:
Arguments:
path: The relative path to your files to be added to the Drive.
+
"""
if not self.component_name:
raise Exception("The component name needs to be known to put a path to the Drive.")
@@ -121,6 +123,7 @@ def list(self, path: Optional[str] = ".", component_name: Optional[str] = None)
path: The relative path you want to list files from the Drive.
component_name: By default, the Drive lists files across all components.
If you provide a component name, the listing is specific to this component.
+
"""
if _is_flow_context():
raise Exception("The flow isn't allowed to list files from a Drive.")
@@ -165,6 +168,7 @@ def get(
If you provide a component name, the matching is specific to this component.
timeout: Whether to wait for the files to be available if not created yet.
overwrite: Whether to override the provided path if it exists.
+
"""
if _is_flow_context():
raise Exception("The flow isn't allowed to get files from a Drive.")
@@ -207,11 +211,12 @@ def get(
self._get(self.fs, match, pathlib.Path(os.path.join(self.root_folder, path)).resolve(), overwrite=overwrite)
def delete(self, path: str) -> None:
- """This method enables to delete files under the provided path from the Drive in a blocking fashion. Only
- the component which added a file can delete them.
+ """This method enables to delete files under the provided path from the Drive in a blocking fashion. Only the
+ component which added a file can delete them.
Arguments:
path: The relative path you want to delete files from the Drive.
+
"""
if not self.component_name:
raise Exception("The component name needs to be known to delete a path to the Drive.")
diff --git a/src/lightning/app/storage/filesystem.py b/src/lightning/app/storage/filesystem.py
index 141a29e8a1bc5..943a6a750bd2b 100644
--- a/src/lightning/app/storage/filesystem.py
+++ b/src/lightning/app/storage/filesystem.py
@@ -42,6 +42,7 @@ def put(self, src_path: str, dst_path: str, put_fn: Callable = _copy_files) -> N
src_path: The path to your files locally
dst_path: The path to your files transfered in the shared storage.
put_fn: The method to use to put files in the shared storage.
+
"""
if not os.path.exists(Path(src_path).resolve()):
raise FileExistsError(f"The provided path {src_path} doesn't exist")
@@ -66,6 +67,7 @@ def get(self, src_path: str, dst_path: str, overwrite: bool = True, get_fn: Call
src_path: The path to your files in the shared storage
dst_path: The path to your files transfered locally
get_fn: The method to use to put files in the shared storage.
+
"""
if not src_path.startswith("/"):
raise Exception(f"The provided destination {src_path} needs to start with `/`.")
@@ -80,6 +82,7 @@ def listdir(self, path: str) -> List[str]:
Arguments:
path: The path to files to list.
+
"""
if not path.startswith("/"):
raise Exception(f"The provided destination {path} needs to start with `/`.")
@@ -104,6 +107,7 @@ def walk(self, path: str) -> List[str]:
Arguments:
path: The path to files to list.
+
"""
if not path.startswith("/"):
raise Exception(f"The provided destination {path} needs to start with `/`.")
diff --git a/src/lightning/app/storage/mount.py b/src/lightning/app/storage/mount.py
index efe922aa97e0e..8142b4574a8b0 100644
--- a/src/lightning/app/storage/mount.py
+++ b/src/lightning/app/storage/mount.py
@@ -34,6 +34,7 @@ class Mount:
mount_path: An absolute directory path in the work where external data source should
be mounted as a filesystem. This path should not already exist in your codebase.
If not included, then the root_dir will be set to `/data/`
+
"""
source: str = ""
diff --git a/src/lightning/app/storage/orchestrator.py b/src/lightning/app/storage/orchestrator.py
index 406539f4c750b..43ce7b76e5101 100644
--- a/src/lightning/app/storage/orchestrator.py
+++ b/src/lightning/app/storage/orchestrator.py
@@ -47,6 +47,7 @@ class StorageOrchestrator(Thread):
put requests on this queue for the file-transfer thread to complete.
copy_response_queues: A dictionary of Queues where each Queue connects to one Work. The queue is expected to
contain the completion response from the file-transfer thread running in the Work process.
+
"""
def __init__(
diff --git a/src/lightning/app/storage/path.py b/src/lightning/app/storage/path.py
index f0b7ee9560a9d..0ecb79a7f5d44 100644
--- a/src/lightning/app/storage/path.py
+++ b/src/lightning/app/storage/path.py
@@ -53,6 +53,7 @@ class Path(PathlibPath):
Args:
*args: Accepts the same arguments as in :class:`pathlib.Path`
**kwargs: Accepts the same keyword arguments as in :class:`pathlib.Path`
+
"""
@classmethod
@@ -105,6 +106,7 @@ def origin_name(self) -> str:
"""The name of the LightningWork where this path was first created.
Attaching a Path to a LightningWork will automatically make it the `origin`.
+
"""
from lightning.app.core.work import LightningWork
@@ -115,6 +117,7 @@ def consumer_name(self) -> str:
"""The name of the LightningWork where this path is being accessed.
By default, this is the same as the :attr:`origin_name`.
+
"""
from lightning.app.core.work import LightningWork
@@ -125,6 +128,7 @@ def hash(self) -> Optional[str]:
"""The hash of this Path uniquely identifies the file path and the associated origin Work.
Returns ``None`` if the origin is not defined, i.e., this Path did not yet get attached to a LightningWork.
+
"""
if self._origin is None:
return None
@@ -152,6 +156,7 @@ def exists(self) -> bool:
If you strictly want to check local existence only, use :meth:`exists_local` instead. If you strictly want
to check existence on the remote (regardless of whether the file exists locally or not), use
:meth:`exists_remote`.
+
"""
return self.exists_local() or (self._origin and self.exists_remote())
@@ -164,6 +169,7 @@ def exists_remote(self) -> bool:
Raises:
RuntimeError: If the path is not attached to any Work (origin undefined).
+
"""
# Fail early if we need to check the remote but an origin is not defined
if not self._origin or self._request_queue is None or self._response_queue is None:
@@ -272,6 +278,7 @@ def _attach_work(self, work: "LightningWork") -> None:
Args:
work: LightningWork to be attached to this Path.
+
"""
if self._origin is None:
# Can become an owner only if there is not already one
@@ -374,11 +381,11 @@ def _is_lit_path(path: Union[str, Path]) -> bool:
def _shared_local_mount_path() -> pathlib.Path:
- """Returns the shared directory through which the Copier threads move files from one Work filesystem to
- another.
+ """Returns the shared directory through which the Copier threads move files from one Work filesystem to another.
The shared directory can be set via the environment variable ``SHARED_MOUNT_DIRECTORY`` and should be pointing to a
directory that all Works have mounted (shared filesystem).
+
"""
path = pathlib.Path(os.environ.get("SHARED_MOUNT_DIRECTORY", ".shared"))
path.mkdir(parents=True, exist_ok=True)
@@ -397,6 +404,7 @@ def _shared_storage_path() -> pathlib.Path:
The shared path gets set by the environment. Locally, it is pointing to a directory determined by the
``SHARED_MOUNT_DIRECTORY`` environment variable. In the cloud, the shared path will point to a S3 bucket. All Works
have access to this shared dropbox.
+
"""
storage_path = os.getenv("LIGHTNING_STORAGE_PATH", "")
if storage_path != "":
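A short sketch of the origin/consumer relationship described above: the Work that creates the `Path` becomes its origin, and another Work accessing it becomes a consumer (attribute and class names below are illustrative):

.. code-block:: python

    import lightning as L
    from lightning.app.storage import Path


    class Trainer(L.LightningWork):
        def __init__(self):
            super().__init__()
            # Attaching the Path to this Work makes the Work its origin.
            self.checkpoint_path = Path("model.ckpt")

        def run(self):
            with open(self.checkpoint_path, "w") as f:
                f.write("weights")


    class Evaluator(L.LightningWork):
        def run(self, checkpoint_path: Path):
            # Accessed from a different Work, this Work becomes the consumer;
            # exists() checks locally first and falls back to the remote origin.
            print(checkpoint_path.exists())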
diff --git a/src/lightning/app/storage/payload.py b/src/lightning/app/storage/payload.py
index 05d3463488941..255a60c3fe69b 100644
--- a/src/lightning/app/storage/payload.py
+++ b/src/lightning/app/storage/payload.py
@@ -61,6 +61,7 @@ def hash(self) -> Optional[str]:
"""The hash of this Payload uniquely identifies the payload and the associated origin Work.
Returns ``None`` if the origin is not defined, i.e., this Path did not yet get attached to a LightningWork.
+
"""
if self._origin is None:
return None
@@ -72,6 +73,7 @@ def origin_name(self) -> str:
"""The name of the LightningWork where this payload was first created.
Attaching a Payload to a LightningWork will automatically make it the `origin`.
+
"""
from lightning.app.core.work import LightningWork
@@ -82,6 +84,7 @@ def consumer_name(self) -> str:
"""The name of the LightningWork where this payload is being accessed.
By default, this is the same as the :attr:`origin_name`.
+
"""
from lightning.app.core.work import LightningWork
@@ -107,6 +110,7 @@ def _attach_work(self, work: "LightningWork") -> None:
Args:
work: LightningWork to be attached to this Payload.
+
"""
if self._origin is None:
# Can become an owner only if there is not already one
@@ -130,6 +134,7 @@ def exists_remote(self):
Raises:
RuntimeError: If the payload is not attached to any Work (origin undefined).
+
"""
# Fail early if we need to check the remote but an origin is not defined
if not self._origin or self._request_queue is None or self._response_queue is None:
diff --git a/src/lightning/app/testing/helpers.py b/src/lightning/app/testing/helpers.py
index 99a81b2523fcf..c1eafce723e68 100644
--- a/src/lightning/app/testing/helpers.py
+++ b/src/lightning/app/testing/helpers.py
@@ -62,6 +62,7 @@ class _RunIf:
@pytest.mark.parametrize("arg1", [1, 2.0])
def test_wrapper(arg1):
assert arg1 > 0.0
+
"""
def __new__(
@@ -155,6 +156,7 @@ class EmptyFlow(LightningFlow):
"""A LightningFlow that implements all abstract methods to do nothing.
Useful for mocking in tests.
+
"""
def run(self):
@@ -165,6 +167,7 @@ class EmptyWork(LightningWork):
"""A LightningWork that implements all abstract methods to do nothing.
Useful for mocking in tests.
+
"""
def run(self):
diff --git a/src/lightning/app/testing/testing.py b/src/lightning/app/testing/testing.py
index 93a6be199794e..f873355ab0686 100644
--- a/src/lightning/app/testing/testing.py
+++ b/src/lightning/app/testing/testing.py
@@ -503,6 +503,7 @@ def delete_cloud_lightning_apps(name=None):
"""Cleanup cloud apps that start with the name test-{PR_NUMBER}-{TEST_APP_NAME}.
PR_NUMBER and TEST_APP_NAME are environment variables.
+
"""
client = LightningClient()
diff --git a/src/lightning/app/utilities/app_commands.py b/src/lightning/app/utilities/app_commands.py
index 2db33d9ebbadb..e3e8af50d3e23 100644
--- a/src/lightning/app/utilities/app_commands.py
+++ b/src/lightning/app/utilities/app_commands.py
@@ -44,6 +44,7 @@ def _extract_commands_from_file(file_name: str) -> CommandLines:
"""Extract all lines at the top of the file which contain commands to execute.
The return struct contains a list of commands to execute with the corresponding line number the command executed on.
+
"""
cl = CommandLines(
file=file_name,
@@ -83,6 +84,7 @@ def _execute_app_commands(cl: CommandLines) -> None:
"""Open a subprocess shell to execute app commands.
The calling app environment is used in the current environment the code is running in
+
"""
for command, line_number in zip(cl.commands, cl.line_numbers):
logger.info(f"Running app setup command: {command}")
@@ -116,6 +118,7 @@ def run_app_commands(file: str) -> None:
foo! bar <--- not a command import lightning <--- not a command, end parsing.
where `echo "hello world" && pip install foo` would be executed in the current running environment.
+
"""
cl = _extract_commands_from_file(file_name=file)
if len(cl.commands) == 0:
diff --git a/src/lightning/app/utilities/app_helpers.py b/src/lightning/app/utilities/app_helpers.py
index 018de7ab32d1f..7878fac4e3f2a 100644
--- a/src/lightning/app/utilities/app_helpers.py
+++ b/src/lightning/app/utilities/app_helpers.py
@@ -57,8 +57,7 @@ class StateEntry:
class StateStore(ABC):
- """Base class of State store that provides simple key, value store to keep track of app state, served app
- state."""
+ """Base class of State store that provides simple key, value store to keep track of app state, served app state."""
@abstractmethod
def __init__(self):
@@ -352,6 +351,7 @@ def _walk_to_component(
"""Returns a generator that runs through the tree starting from the root down to the given component.
At each node, yields parent and child as a tuple.
+
"""
from lightning.app.structures import Dict, List
@@ -469,6 +469,7 @@ def _load_state_dict(root_flow: "LightningFlow", state: Dict[str, Any], strict:
root_flow: The flow at the top of the component tree.
state: The collected state dict.
strict: Whether to validate all components have been re-created.
+
"""
# 1: Reload the state of the existing works
for w in root_flow.works():
diff --git a/src/lightning/app/utilities/app_logs.py b/src/lightning/app/utilities/app_logs.py
index 903bc615bad6a..446418f9b18e2 100644
--- a/src/lightning/app/utilities/app_logs.py
+++ b/src/lightning/app/utilities/app_logs.py
@@ -49,6 +49,7 @@ def _push_log_events_to_read_queue_callback(component_name: str, read_queue: que
"""Pushes _LogEvents from websocket to read_queue.
Returns callback function used with `on_message_callback` of websocket.WebSocketApp.
+
"""
def callback(ws_app: WebSocketApp, msg: str):
diff --git a/src/lightning/app/utilities/auth.py b/src/lightning/app/utilities/auth.py
index b29801c9fdb36..2ccb2d068109f 100644
--- a/src/lightning/app/utilities/auth.py
+++ b/src/lightning/app/utilities/auth.py
@@ -42,6 +42,7 @@ def _credential_string_to_basic_auth_params(credential_string: str) -> Dict[str,
"""Returns the name/ID pair for each given Secret name.
Raises a `ValueError` if any of the given Secret names do not exist.
+
"""
if credential_string.count(":") != 1:
raise ValueError(
diff --git a/src/lightning/app/utilities/cli_helpers.py b/src/lightning/app/utilities/cli_helpers.py
index 6280fdecad86a..ad9c835ab46ec 100644
--- a/src/lightning/app/utilities/cli_helpers.py
+++ b/src/lightning/app/utilities/cli_helpers.py
@@ -129,6 +129,7 @@ def __init__(
Arguments:
app_id_or_name_or_url: An identified for the app.
use_cache: Whether to load the openapi spec from the cache.
+
"""
self.app_id_or_name_or_url = app_id_or_name_or_url
self.url = None
diff --git a/src/lightning/app/utilities/clusters.py b/src/lightning/app/utilities/clusters.py
index a083e41c7110b..663ba66d456e9 100644
--- a/src/lightning/app/utilities/clusters.py
+++ b/src/lightning/app/utilities/clusters.py
@@ -25,6 +25,7 @@ def _get_default_cluster(client: LightningClient, project_id: str) -> str:
"""This utility implements a minimal version of the cluster selection logic used in the cloud.
TODO: This should be requested directly from the platform.
+
"""
cluster_bindings = client.projects_service_list_project_cluster_bindings(project_id=project_id).clusters
diff --git a/src/lightning/app/utilities/component.py b/src/lightning/app/utilities/component.py
index 75c4c09d5671b..bb4bc5ecfd606 100644
--- a/src/lightning/app/utilities/component.py
+++ b/src/lightning/app/utilities/component.py
@@ -36,6 +36,7 @@ def _convert_paths_after_init(root: "LightningFlow"):
This is necessary because at the time of instantiating the component, its full affiliation is not known and Paths
that get passed to other componenets during ``__init__`` are otherwise not able to reference their origin or
consumer.
+
"""
from lightning.app import LightningFlow, LightningWork
from lightning.app.storage import Path
@@ -52,6 +53,7 @@ def _sanitize_state(state: Dict[str, Any]) -> Dict[str, Any]:
"""Utility function to sanitize the state of a component.
Sanitization enables the state to be deep-copied and hashed.
+
"""
from lightning.app.storage import Drive, Path
from lightning.app.storage.payload import _BasePayload
@@ -132,6 +134,7 @@ def _context(ctx: str) -> Generator[None, None, None]:
The context is used to determine whether the current process is running for a LightningFlow or for a LightningWork.
See also :func:`_get_context`, :func:`_set_context`. For internal use only.
+
"""
prev = _get_context()
_set_context(ctx)
diff --git a/src/lightning/app/utilities/data_structures.py b/src/lightning/app/utilities/data_structures.py
index 626e381d736b0..495c43fd0ea24 100644
--- a/src/lightning/app/utilities/data_structures.py
+++ b/src/lightning/app/utilities/data_structures.py
@@ -29,6 +29,7 @@ class AttributeDict(Dict):
"key2": abc
"my-key": 3.14
"new_key": 42
+
"""
def __getattr__(self, key: str) -> Optional[Any]:
diff --git a/src/lightning/app/utilities/exceptions.py b/src/lightning/app/utilities/exceptions.py
index fb63a6ce7bea7..3bb5eced46ed7 100644
--- a/src/lightning/app/utilities/exceptions.py
+++ b/src/lightning/app/utilities/exceptions.py
@@ -29,6 +29,7 @@ class _ApiExceptionHandler(Group):
However, if the ApiException cannot be decoded, or is not
a 4xx error, the original ApiException will be re-raised.
+
"""
def invoke(self, ctx: Context) -> Any:
@@ -81,6 +82,7 @@ class LightningPlatformException(Exception): # pragma: no cover
It gets raised by the Lightning Launcher on the platform side when the app is running in the cloud, and is useful
when framework or user code needs to catch exceptions specific to the platform, e.g., when resources exceed quotas.
+
"""
diff --git a/src/lightning/app/utilities/git.py b/src/lightning/app/utilities/git.py
index aa3294aa19b8f..1293a2a5095fa 100644
--- a/src/lightning/app/utilities/git.py
+++ b/src/lightning/app/utilities/git.py
@@ -26,6 +26,7 @@ def execute_git_command(args: List[str], cwd=None) -> str:
-------
output: str
String combining stdout and stderr.
+
"""
process = subprocess.run(["git"] + args, capture_output=True, text=True, cwd=cwd, check=False)
return process.stdout.strip() + process.stderr.strip()
@@ -61,6 +62,7 @@ def check_if_remote_head_is_different() -> Union[bool, None]:
This only compares the local SHA to the HEAD commit of a given branch. This check won't be used if user isn't in a
HEAD locally.
+
"""
# Check SHA values.
local_sha = execute_git_command(["rev-parse", "@"])
@@ -78,6 +80,7 @@ def has_uncommitted_files() -> bool:
"""Checks if user has uncommited files in local repository.
If there are uncommited files, then show a prompt indicating that uncommited files exist locally.
+
"""
files = execute_git_command(["update-index", "--refresh"])
return bool(files)
diff --git a/src/lightning/app/utilities/imports.py b/src/lightning/app/utilities/imports.py
index 7c4c43a3c2ddf..092917660fa46 100644
--- a/src/lightning/app/utilities/imports.py
+++ b/src/lightning/app/utilities/imports.py
@@ -32,6 +32,7 @@ def _get_extras(extras: str) -> str:
"""Get the given extras as a space delimited string.
Used by the platform to install cloud extras in the cloud.
+
"""
from lightning.app import __package_name__
diff --git a/src/lightning/app/utilities/introspection.py b/src/lightning/app/utilities/introspection.py
index 394c5da593a19..87200576e6936 100644
--- a/src/lightning/app/utilities/introspection.py
+++ b/src/lightning/app/utilities/introspection.py
@@ -22,13 +22,14 @@
class LightningVisitor(ast.NodeVisitor):
- """Base class for visitor that finds class definitions based on class inheritance. Derived classes are expected
- to define class_name and implement the analyze_class_def method.
+ """Base class for visitor that finds class definitions based on class inheritance. Derived classes are expected to
+ define class_name and implement the analyze_class_def method.
Attributes
----------
class_name: str
Name of class to identify, to be defined in subclasses.
+
"""
class_name: Optional[str] = None
@@ -63,6 +64,7 @@ class LightningModuleVisitor(LightningVisitor):
Names of methods that are part of the LightningModule API.
hooks: Set[str]
Names of hooks that are part of the LightningModule API.
+
"""
class_name: Optional[str] = "LightningModule"
@@ -132,6 +134,7 @@ class LightningDataModuleVisitor(LightningVisitor):
Name of class to identify.
methods: Set[str]
Names of methods that are part of the LightningDataModule API.
+
"""
class_name = "LightningDataModule"
@@ -155,6 +158,7 @@ class LightningLoggerVisitor(LightningVisitor):
Name of class to identify.
methods: Set[str]
Names of methods that are part of the Logger API.
+
"""
class_name = "Logger"
@@ -171,6 +175,7 @@ class LightningCallbackVisitor(LightningVisitor):
Name of class to identify.
methods: Set[str]
Names of methods that are part of the Logger API.
+
"""
class_name = "Callback"
@@ -223,6 +228,7 @@ class LightningStrategyVisitor(LightningVisitor):
Name of class to identify.
methods: Set[str]
Names of methods that are part of the Logger API.
+
"""
class_name = "Strategy"
@@ -282,6 +288,7 @@ class Scanner:
glob_pattern: str
Glob pattern to use when looking for files in the path,
applied when path is a directory. Default is "**/*.py".
+
"""
# TODO: Finalize introspecting the methods from all the discovered methods.
@@ -341,6 +348,7 @@ def scan(self) -> List[Dict[str, str]]:
List[Dict[str, Any]]
List of dicts containing all metadata required
to import modules found.
+
"""
modules_found: Dict[str, List[Dict[str, Any]]] = {}
diff --git a/src/lightning/app/utilities/layout.py b/src/lightning/app/utilities/layout.py
index 553bcd6e91fd6..e1a61539a0fdd 100644
--- a/src/lightning/app/utilities/layout.py
+++ b/src/lightning/app/utilities/layout.py
@@ -26,6 +26,7 @@ def _add_comment_to_literal_code(method, contains, comment):
"""Inspects a method's code and adds a message to it.
This is a nice to have, so if it fails for some reason, it shouldn't affect the program.
+
"""
try:
lines = inspect.getsource(method)
diff --git a/src/lightning/app/utilities/load_app.py b/src/lightning/app/utilities/load_app.py
index f22aafa3ceabb..4504704c3d2e5 100644
--- a/src/lightning/app/utilities/load_app.py
+++ b/src/lightning/app/utilities/load_app.py
@@ -61,6 +61,7 @@ def _load_objects_from_file(
raise_exception: If ``True`` exceptions will be raised, otherwise exceptions will trigger system exit.
mock_imports: If ``True`` imports of missing packages will be replaced with a mock. This can allow the object to
be loaded without installing dependencies.
+
"""
# Taken from StreamLit: https://github.com/streamlit/streamlit/blob/develop/lib/streamlit/script_runner.py#L313
@@ -110,6 +111,7 @@ def load_app_from_file(
Arguments:
filepath: The path to the file containing the LightningApp.
raise_exception: If True, raise an exception if the app cannot be loaded.
+
"""
from lightning.app.core.app import LightningApp
@@ -142,6 +144,7 @@ def open_python_file(filename):
In Python 3, we would like all files to be opened with utf-8 encoding. However, some author like to specify PEP263
headers in their source files with their own encodings. In that case, we should respect the author's encoding.
+
"""
import tokenize
@@ -204,6 +207,7 @@ def _patch_sys_path(append):
Args:
append: The value to append to the path.
+
"""
if append in sys.path:
yield
diff --git a/src/lightning/app/utilities/login.py b/src/lightning/app/utilities/login.py
index bc2d5d713c4bc..3db7d1cb3b7bc 100644
--- a/src/lightning/app/utilities/login.py
+++ b/src/lightning/app/utilities/login.py
@@ -60,6 +60,7 @@ def load(self) -> bool:
Returns
----------
True if credentials are available.
+
"""
if not self.secrets_file.exists():
logger.debug("Credentials file not found.")
@@ -117,6 +118,7 @@ def authenticate(self) -> Optional[str]:
Returns
----------
authorization header to use when authentication completes.
+
"""
if not self.load():
# First try to authenticate from env
diff --git a/src/lightning/app/utilities/logs_socket_api.py b/src/lightning/app/utilities/logs_socket_api.py
index 98d95bfa965a4..8bc8fa47d1812 100644
--- a/src/lightning/app/utilities/logs_socket_api.py
+++ b/src/lightning/app/utilities/logs_socket_api.py
@@ -79,6 +79,7 @@ def print_log_msg(ws_app, msg):
Returns:
WebSocketApp of the wanted socket
+
"""
_token = self._get_api_token()
clean_ws_host = urlparse(self.api_client.configuration.host).netloc
diff --git a/src/lightning/app/utilities/name_generator.py b/src/lightning/app/utilities/name_generator.py
index 28c43c241ce8e..c57a65f63a3d8 100644
--- a/src/lightning/app/utilities/name_generator.py
+++ b/src/lightning/app/utilities/name_generator.py
@@ -1353,6 +1353,7 @@ def get_unique_name():
'meek-ardinghelli-4506'
>>> get_unique_name()
'truthful-dijkstra-2286'
+
"""
adjective, surname, i = choice(_adjectives), choice(_surnames), randint(0, 9999) # noqa: S311
return f"{adjective}-{surname}-{i}"
diff --git a/src/lightning/app/utilities/network.py b/src/lightning/app/utilities/network.py
index 3d80479c3fe53..314631b5593f8 100644
--- a/src/lightning/app/utilities/network.py
+++ b/src/lightning/app/utilities/network.py
@@ -95,6 +95,7 @@ def _configure_session() -> Session:
"""Configures the session for GET and POST requests.
It enables a generous retrial strategy that waits for the application server to connect.
+
"""
retry_strategy = Retry(
# wait time between retries increases exponentially according to: backoff_factor * (2 ** (retry - 1))
@@ -124,10 +125,10 @@ def _get_next_backoff_time(num_retries: int, backoff_value: float = 0.5) -> floa
def _retry_wrapper(self, func: Callable, max_tries: Optional[int] = None) -> Callable:
- """Returns the function decorated by a wrapper that retries the call several times if a connection error
- occurs.
+ """Returns the function decorated by a wrapper that retries the call several times if a connection error occurs.
The retries follow an exponential backoff.
+
"""
@wraps(func)
@@ -175,6 +176,7 @@ class LightningClient(GridRestClient):
Args:
retry: Whether API calls should follow a retry mechanism with exponential backoff.
max_tries: Maximum number of attempts (or -1 to retry forever).
+
"""
def __init__(self, retry: bool = True, max_tries: Optional[int] = None) -> None:
@@ -275,5 +277,6 @@ def log_function(self, message: str, *args, **kwargs: Any):
We enabled customisation here instead of just using `logger.debug` because HTTP logging can be very noisy, but
it is crucial for finding bugs when we have them
+
"""
pass
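The retry behaviour above (exponential backoff with wait time `backoff_factor * (2 ** (retry - 1))`) can be sketched as a small standalone decorator; this is illustrative only and not `_retry_wrapper` itself:

.. code-block:: python

    import time
    from functools import wraps


    def retry_with_backoff(max_tries: int = 5, backoff_factor: float = 0.5):
        """Retry a callable on ConnectionError, sleeping backoff_factor * 2**(retry - 1) between attempts."""

        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                for retry in range(1, max_tries + 1):
                    try:
                        return func(*args, **kwargs)
                    except ConnectionError:
                        if retry == max_tries:
                            raise
                        time.sleep(backoff_factor * (2 ** (retry - 1)))

            return wrapper

        return decorator


    @retry_with_backoff(max_tries=3)
    def flaky_call():
        ...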
diff --git a/src/lightning/app/utilities/openapi.py b/src/lightning/app/utilities/openapi.py
index c79501bc6b498..f210c3cd47b04 100644
--- a/src/lightning/app/utilities/openapi.py
+++ b/src/lightning/app/utilities/openapi.py
@@ -49,6 +49,7 @@ def create_openapi_object(json_obj: Dict, target: Any):
Lightning AI uses the target object to make new objects from the given JSON spec so the target must be a valid
object.
+
"""
if not isinstance(json_obj, dict):
raise TypeError("json_obj must be a dictionary")
diff --git a/src/lightning/app/utilities/packaging/app_config.py b/src/lightning/app/utilities/packaging/app_config.py
index f22ffa99d186e..57177344a8fe0 100644
--- a/src/lightning/app/utilities/packaging/app_config.py
+++ b/src/lightning/app/utilities/packaging/app_config.py
@@ -29,6 +29,7 @@ class AppConfig:
Args:
name: Optional name of the application. If not provided, auto-generates a new name.
+
"""
name: str = field(default_factory=get_unique_name)
@@ -56,6 +57,7 @@ def load_from_dir(cls, directory: Union[str, pathlib.Path]) -> "AppConfig":
Args:
directory: Path to a folder which contains the '.lightning' config file to load.
+
"""
return cls.load_from_file(pathlib.Path(directory, _APP_CONFIG_FILENAME))
@@ -65,6 +67,7 @@ def _get_config_file(source_path: Union[str, pathlib.Path]) -> pathlib.Path:
Args:
source_path: A path to a folder or a file.
+
"""
source_path = pathlib.Path(source_path).absolute()
if source_path.is_file():
diff --git a/src/lightning/app/utilities/packaging/build_config.py b/src/lightning/app/utilities/packaging/build_config.py
index fb247309d8b65..8a580da71d3f4 100644
--- a/src/lightning/app/utilities/packaging/build_config.py
+++ b/src/lightning/app/utilities/packaging/build_config.py
@@ -81,6 +81,7 @@ class BuildConfig:
image: The base image that the work runs on. This should be a publicly accessible image from a registry that
doesn't enforce rate limits (such as DockerHub) to pull this image, otherwise your application will not
start.
+
"""
requirements: List[str] = field(default_factory=list)
@@ -111,6 +112,7 @@ def build_commands(self):
return ["apt-get install libsparsehash-dev"]
BuildConfig(requirements=["git+https://github.com/mit-han-lab/torchsparse.git@v1.4.0"])
+
"""
return []
diff --git a/src/lightning/app/utilities/packaging/cloud_compute.py b/src/lightning/app/utilities/packaging/cloud_compute.py
index 58ac06afdbe65..246c04b148863 100644
--- a/src/lightning/app/utilities/packaging/cloud_compute.py
+++ b/src/lightning/app/utilities/packaging/cloud_compute.py
@@ -88,6 +88,7 @@ class CloudCompute:
interruptible: Whether to run on a interruptible machine e.g the machine can be stopped
at any time by the providers. This is also known as spot or preemptible machines.
Compared to on-demand machines, they tend to be cheaper.
+
"""
name: str = "default"
diff --git a/src/lightning/app/utilities/packaging/lightning_utils.py b/src/lightning/app/utilities/packaging/lightning_utils.py
index 3852c941ed676..c49a6d2d88031 100644
--- a/src/lightning/app/utilities/packaging/lightning_utils.py
+++ b/src/lightning/app/utilities/packaging/lightning_utils.py
@@ -108,10 +108,11 @@ def get_dist_path_if_editable_install(project_name) -> str:
def _prepare_lightning_wheels_and_requirements(root: Path, package_name: str = "lightning") -> Optional[Callable]:
- """This function determines if lightning is installed in editable mode (for developers) and packages the
- current lightning source along with the app.
+ """This function determines if lightning is installed in editable mode (for developers) and packages the current
+ lightning source along with the app.
For normal users who install via PyPi or Conda, then this function does not do anything.
+
"""
if not get_dist_path_if_editable_install(package_name):
return None
diff --git a/src/lightning/app/utilities/proxies.py b/src/lightning/app/utilities/proxies.py
index 39d33785068ac..9ef301f31fc77 100644
--- a/src/lightning/app/utilities/proxies.py
+++ b/src/lightning/app/utilities/proxies.py
@@ -152,6 +152,7 @@ def _validate_call_args(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) ->
Currently, this performs a check against strings that look like filesystem paths and may need to be wrapped with
a Lightning Path by the user.
+
"""
def warn_if_pathlike(obj: Union[os.PathLike, str]):
@@ -172,8 +173,8 @@ def warn_if_pathlike(obj: Union[os.PathLike, str]):
@staticmethod
def _process_call_args(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
- """Processes all positional and keyword arguments before they get passed to the caller queue and sent to
- the LightningWork.
+ """Processes all positional and keyword arguments before they get passed to the caller queue and sent to the
+ LightningWork.
Currently, this method only applies sanitization to Lightning Path objects.
@@ -183,6 +184,7 @@ def _process_call_args(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Tuple[T
Returns:
The positional and keyword arguments in the same order they were passed in.
+
"""
def sanitize(obj: Union[Path, Drive]) -> Union[Path, Dict]:
@@ -200,8 +202,8 @@ def sanitize(obj: Union[Path, Drive]) -> Union[Path, Dict]:
@staticmethod
def _convert_hashable(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
- """Processes all positional and keyword arguments before they get passed to the caller queue and sent to
- the LightningWork.
+ """Processes all positional and keyword arguments before they get passed to the caller queue and sent to the
+ LightningWork.
Currently, this method only applies sanitization to Hashable Objects.
@@ -211,6 +213,7 @@ def _convert_hashable(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Tuple[Tu
Returns:
The positional and keyword arguments in the same order they were passed in.
+
"""
from lightning.app.utilities.types import Hashable
@@ -221,9 +224,9 @@ def sanitize(obj: Hashable) -> Union[Path, Dict]:
class WorkStateObserver(Thread):
- """This thread runs alongside LightningWork and periodically checks for state changes. If the state changed
- from one interval to the next, it will compute the delta and add it to the queue which is connected to the
- Flow. This enables state changes to be captured that are not triggered through a setattr call.
+ """This thread runs alongside LightningWork and periodically checks for state changes. If the state changed from
+ one interval to the next, it will compute the delta and add it to the queue which is connected to the Flow. This
+ enables state changes to be captured that are not triggered through a setattr call.
Args:
work: The LightningWork for which the state should be monitored
@@ -238,6 +241,7 @@ class Work(LightningWork):
def run(self):
# This update gets sent to the Flow once the thread compares the new state with the previous one
self.list.append(1)
+
"""
def __init__(
@@ -690,6 +694,7 @@ def persist_artifacts(work: "LightningWork") -> None:
storage.
Files that don't exist or do not originate from the given Work will be skipped.
+
"""
artifact_paths = [getattr(work, name) for name in work._paths]
# only copy files that belong to this Work, i.e., when the path's origin refers to the current Work
diff --git a/src/lightning/app/utilities/secrets.py b/src/lightning/app/utilities/secrets.py
index 9ba758d4d26d6..57347c92a1476 100644
--- a/src/lightning/app/utilities/secrets.py
+++ b/src/lightning/app/utilities/secrets.py
@@ -22,6 +22,7 @@ def _names_to_ids(secret_names: Iterable[str]) -> Dict[str, str]:
"""Returns the name/ID pair for each given Secret name.
Raises a `ValueError` if any of the given Secret names do not exist.
+
"""
lightning_client = LightningClient()
diff --git a/src/lightning/app/utilities/state.py b/src/lightning/app/utilities/state.py
index 014eaa4fdf931..df187776d6b64 100644
--- a/src/lightning/app/utilities/state.py
+++ b/src/lightning/app/utilities/state.py
@@ -93,6 +93,7 @@ def __init__(
my_affiliation: A tuple describing the affiliation this app state represents. When storing a state dict
on this AppState, this affiliation will be used to reduce the scope of the given state.
plugin: A plugin to handle authorization.
+
"""
self._use_localhost = "LIGHTNING_APP_STATE_URL" not in os.environ
self._host = host or ("http://127.0.0.1" if self._use_localhost else None)
@@ -132,6 +133,7 @@ def _find_state_under_affiliation(state, my_affiliation: Tuple[str, ...]) -> Dic
For example, if the affiliation is ``("root", "subflow")``, then the returned state will be
``state["flows"]["subflow"]``.
+
"""
children_state = state
for name in my_affiliation:
diff --git a/src/lightning/app/utilities/tracer.py b/src/lightning/app/utilities/tracer.py
index c5a0e56b01264..fe44f91947305 100644
--- a/src/lightning/app/utilities/tracer.py
+++ b/src/lightning/app/utilities/tracer.py
@@ -101,6 +101,7 @@ def add_traced(self, cls, method_name, stack_level=1, pre_fn=None, post_fn=None)
Optionally provide two functions that will execute prior to and after the method. The functions also have a
chance to modify the input arguments and the return values of the methods.
+
"""
self.methods.append((cls, method_name, stack_level, pre_fn, post_fn))
@@ -108,6 +109,7 @@ def _instrument(self):
"""Modify classes by wrapping methods that need to be traced.
Initialize the output trace dict.
+
"""
self.res = {}
for cls, method, stack_level, pre_fn, post_fn in self.methods:
@@ -163,6 +165,7 @@ def trace(self, *args: Any, init_globals=None) -> Optional[dict]:
"""Execute the command-line arguments in args after instrumenting for tracing.
Restore the classes to their initial state after tracing.
+
"""
args = list(args)
script = args[0]
diff --git a/src/lightning/app/utilities/tree.py b/src/lightning/app/utilities/tree.py
index 5dafaee6bf60b..69ae7d144b15e 100644
--- a/src/lightning/app/utilities/tree.py
+++ b/src/lightning/app/utilities/tree.py
@@ -26,6 +26,7 @@ def breadth_first(root: "Component", types: Type["ComponentTuple"] = None):
Arguments:
root: The root component of the tree
types: If provided, only the component types in this list will be visited.
+
"""
yield from _BreadthFirstVisitor(root, types)
diff --git a/src/lightning/data/backends.py b/src/lightning/data/backends.py
index 1d4dbcccd451f..5fcb32eb7877c 100644
--- a/src/lightning/data/backends.py
+++ b/src/lightning/data/backends.py
@@ -29,6 +29,7 @@ def get_aws_credentials() -> "RefreshableCredentials":
Returns:
credentials object to be used for file reading
+
"""
from botocore.credentials import InstanceMetadataProvider
from botocore.utils import InstanceMetadataFetcher
diff --git a/src/lightning/data/datasets/base.py b/src/lightning/data/datasets/base.py
index d0ebc163402ad..c8ee89ac96872 100644
--- a/src/lightning/data/datasets/base.py
+++ b/src/lightning/data/datasets/base.py
@@ -11,6 +11,7 @@ class _Dataset(TorchDataset):
Args:
backend: storage location of the data_source. current options are "s3" or "local"
+
"""
def __init__(self, backend: Literal["local", "s3"] = "local"):
@@ -31,6 +32,7 @@ def open(self, file: str, mode: str = "r", kwargs_for_open: Any = {}, **kwargs:
Returns:
A stream object of the file.
+
"""
return OpenCloudFileObj(
path=file, mode=mode, kwargs_for_open={**self.backend.credentials(), **kwargs_for_open}, **kwargs
diff --git a/src/lightning/data/datasets/env.py b/src/lightning/data/datasets/env.py
index 4e923f34735f8..51a9f21271e81 100644
--- a/src/lightning/data/datasets/env.py
+++ b/src/lightning/data/datasets/env.py
@@ -10,6 +10,7 @@ class _DistributedEnv:
Args:
world_size: The number of total distributed training processes
global_rank: The rank of the current process within this pool of training processes
+
"""
def __init__(self, world_size: int, global_rank: int):
@@ -24,6 +25,7 @@ def detect(cls) -> "_DistributedEnv":
This detection may not work in processes spawned from the distributed processes (e.g. DataLoader workers)
as the distributed framework won't be initialized there.
It will default to 1 distributed process in this case.
+
"""
if torch.distributed.is_available() and torch.distributed.is_initialized():
world_size = torch.distributed.get_world_size()
@@ -50,6 +52,7 @@ class _WorkerEnv:
Args:
world_size: The number of dataloader workers for the current training process
rank: The rank of the current worker within the number of workers
+
"""
def __init__(self, world_size: int, rank: int):
@@ -63,6 +66,7 @@ def detect(cls) -> "_WorkerEnv":
Note:
This only works reliably within a dataloader worker as otherwise the necessary information won't be present.
In such a case it will default to 1 worker
+
"""
worker_info = get_worker_info()
num_workers = worker_info.num_workers if worker_info is not None else 1
@@ -83,6 +87,7 @@ class Environment:
Args:
dist_env: The distributed environment (distributed worldsize and global rank)
worker_env: The worker environment (number of workers, worker rank)
+
"""
def __init__(self, dist_env: Optional[_DistributedEnv], worker_env: Optional[_WorkerEnv]):
@@ -105,6 +110,7 @@ def from_args(
num_workers: The number of workers per distributed training process
current_worker_rank: The rank of the current worker within the number of workers of
the current training process
+
"""
dist_env = _DistributedEnv(dist_world_size, global_rank)
worker_env = _WorkerEnv(num_workers, current_worker_rank)
@@ -117,6 +123,7 @@ def num_shards(self) -> int:
Note:
This may not be accurate in a non-dataloader-worker process like the main training process
as it doesn't necessarily know about the number of dataloader workers.
+
"""
assert self.worker_env is not None
assert self.dist_env is not None
@@ -129,6 +136,7 @@ def shard_rank(self) -> int:
Note:
This may not be accurate in a non-dataloader-worker process like the main training process as it
doesn't necessarily know about the number of dataloader workers.
+
"""
assert self.worker_env is not None
assert self.dist_env is not None
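A toy model of how `num_shards` and `shard_rank` relate the distributed and worker environments described above: every (training process, dataloader worker) pair is one shard. The formulas are a common convention and an assumption here, not a quote of the implementation:

.. code-block:: python

    class ShardEnv:
        """Toy view: shards = distributed processes x dataloader workers."""

        def __init__(self, dist_world_size: int, global_rank: int, num_workers: int, worker_rank: int):
            self.dist_world_size = dist_world_size
            self.global_rank = global_rank
            self.num_workers = num_workers
            self.worker_rank = worker_rank

        @property
        def num_shards(self) -> int:
            return self.dist_world_size * self.num_workers

        @property
        def shard_rank(self) -> int:
            # Each (process, worker) pair gets a unique shard index.
            return self.global_rank * self.num_workers + self.worker_rank


    env = ShardEnv(dist_world_size=2, global_rank=1, num_workers=4, worker_rank=2)
    print(env.num_shards, env.shard_rank)  # 8 6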
diff --git a/src/lightning/data/datasets/index.py b/src/lightning/data/datasets/index.py
index 9a0b81b42f6a8..34ecaaf095cc4 100644
--- a/src/lightning/data/datasets/index.py
+++ b/src/lightning/data/datasets/index.py
@@ -10,6 +10,7 @@ def get_index(s3_connection_path: str, index_file_path: str) -> bool:
Returns:
Returns True is the index got created and False if it wasn't
+
"""
if s3_connection_path.startswith("/data/"):
@@ -82,6 +83,7 @@ def _get_index(data_connection_path: str, index_file_path: str) -> bool:
Returns:
True if the index retrieved
+
"""
PROJECT_ID_ENV = "LCP_ID"
diff --git a/src/lightning/data/datasets/iterable.py b/src/lightning/data/datasets/iterable.py
index f97b16f89e20f..54388138bcf79 100644
--- a/src/lightning/data/datasets/iterable.py
+++ b/src/lightning/data/datasets/iterable.py
@@ -29,6 +29,7 @@ class _Chunk:
chunk_data: The original data contained by this chunk
chunk_size: The number of samples contained in this chunk
start_index: the index from where to start sampling the chunk (already retrieved samples)
+
"""
def __init__(self, chunk_data: Any, chunk_size: int, start_index: int = 0):
@@ -62,8 +63,8 @@ def index_permutations(self) -> Tuple[int, ...]:
class LightningIterableDataset(_StatefulIterableDataset, _Dataset):
- """An iterable dataset that can be resumed mid-epoch, implements chunking and sharding of chunks. The behavior
- of this dataset can be customized with the following hooks:
+ """An iterable dataset that can be resumed mid-epoch, implements chunking and sharding of chunks. The behavior of
+ this dataset can be customized with the following hooks:
- ``prepare_chunk`` gives the possibility to prepare the chunk one iteration before its actually loaded
(e.g. download from s3).
@@ -100,6 +101,7 @@ class LightningIterableDataset(_StatefulIterableDataset, _Dataset):
Note:
Order of data is only guaranteed when resuming with the same distributed settings and the same number of
workers. Everything else leads to different sharding and therefore results in different data order.
+
"""
def __init__(
@@ -156,11 +158,12 @@ def __init__(
@abstractmethod
def load_chunk(self, chunk: Any) -> Any:
- """Implement this to load a single chunk into memory. This could e.g. mean loading the file that has
- previously been downloaded from s3.
+ """Implement this to load a single chunk into memory. This could e.g. mean loading the file that has previously
+ been downloaded from s3.
Args:
chunk: The chunk that should be currently loaded
+
"""
@abstractmethod
@@ -171,6 +174,7 @@ def load_sample_from_chunk(self, chunk: Any, index: int) -> Any:
Args:
chunk: The chunk the sample should be retrieved from
index: The index of the current sample to retrieve within the chunk.
+
"""
def prepare_chunk(self, chunk: Any) -> None:
@@ -178,6 +182,7 @@ def prepare_chunk(self, chunk: Any) -> None:
Args:
chunk: the chunk data to prepare.
+
"""
def __iter__(self) -> "LightningIterableDataset":
@@ -185,6 +190,7 @@ def __iter__(self) -> "LightningIterableDataset":
Before that, detects the env if necessary, shuffles chunks, shards the data and shuffles sample orders within
chunks.
+
"""
self._curr_chunk_index = self._start_index_chunk
self._curr_sample_index = self._start_index_sample
@@ -206,6 +212,7 @@ def __next__(self) -> Any:
"""Returns the next sample.
If necessary, this also loads the new chunks.
+
"""
self._check_if_sharded()
self._ensure_chunks_loaded()
@@ -243,6 +250,7 @@ def state_dict(self, returned_samples: int, num_workers: int) -> Dict[str, Any]:
returned_samples: the total number of samples returned by the dataloader(s) (across all distributed
training processes).
num_workers: number of dataloader workers per distributed training process.
+
"""
# compute indices locally again since other workers may have different offsets
@@ -275,6 +283,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
Note:
Some of the changes only take effect when creating a new iterator
+
"""
state_dict = deepcopy(state_dict)
self._start_index_chunk = state_dict.pop("current_chunk")
@@ -322,6 +331,7 @@ def _apply_sharding(self) -> None:
"""Shards the chunks if necessary.
No-op if already sharded
+
"""
if not self._local_chunks:
num_shards = self._env.num_shards
@@ -365,6 +375,7 @@ def _shuffle_if_necessary(
first_chunk_index: The point to which the generator should be replayed
shuffle_chunk_order: Whether to shuffle the order of chunks
shuffle_sample_order: Whether to shuffle the order of samples within a chunk
+
"""
# re-seed generator
if self._generator is not None and self._initial_generator_state is not None:
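To make the two required hooks concrete, here is a minimal, hypothetical subclass sketch; the chunk layout (one local ``.pt`` file of stacked samples per chunk) and the class name are assumptions for illustration, not part of this diff.

    # Hypothetical sketch: chunks are assumed to be local .pt files holding a tensor of samples.
    import torch

    from lightning.data.datasets.iterable import LightningIterableDataset


    class TensorChunkDataset(LightningIterableDataset):
        def load_chunk(self, chunk):
            # `chunk` is assumed to be a file path prepared beforehand (e.g. downloaded from s3).
            return torch.load(chunk)

        def load_sample_from_chunk(self, chunk, index):
            # Index into the already-loaded chunk tensor.
            return chunk[index]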
diff --git a/src/lightning/data/datasets/mapping.py b/src/lightning/data/datasets/mapping.py
index 046824cb2d38a..811c8fa909532 100644
--- a/src/lightning/data/datasets/mapping.py
+++ b/src/lightning/data/datasets/mapping.py
@@ -15,6 +15,7 @@ class LightningDataset(_Dataset, ABC):
data_source: path of data directory. ex. s3://mybucket/path
backend: storage location of the data_source. current options are "s3" or "local"
path_to_index_file: path to index file that lists all file contents of the data_source.
+
"""
def __init__(
@@ -36,6 +37,7 @@ def get_index(self) -> Any:
Returns:
The contents of the index file (all the file paths in the data_source)
+
"""
if not os.path.isfile(self.index_file):
get_index(self.data_source, self.index_file)
@@ -49,6 +51,7 @@ def __getitem__(self, idx: int) -> Any:
Returns:
The loaded item
+
"""
file_path = self.files[idx]
@@ -66,6 +69,7 @@ def load_sample(self, file_path: str, stream: OpenCloudFileObj) -> Any:
"""Loads each sample in the dataset.
Any data prep/cleaning logic goes here. For ex. image transformations, text cleaning, etc.
+
"""
pass
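As a rough illustration of the ``LightningDataset`` interface above, only ``load_sample`` needs to be implemented; the class name and the idea of returning raw text are assumptions for this sketch.

    from lightning.data.datasets.mapping import LightningDataset


    class TextFileDataset(LightningDataset):
        def load_sample(self, file_path, stream):
            # `stream` is the already-opened file object (OpenCloudFileObj);
            # any data prep/cleaning logic would go here.
            return stream.read()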
diff --git a/src/lightning/data/fileio.py b/src/lightning/data/fileio.py
index a6e2ec887d301..abacb1acfe755 100644
--- a/src/lightning/data/fileio.py
+++ b/src/lightning/data/fileio.py
@@ -21,6 +21,7 @@ def path_to_url(path: str, bucket_name: str, bucket_root_path: str = "/") -> str
Returns:
Full S3 url path
+
"""
if not path.startswith(bucket_root_path):
raise ValueError(f"Cannot create a path from {path} relative to {bucket_root_path}")
@@ -36,6 +37,7 @@ def open_single_file(
Returns:
The opened file stream.
+
"""
from torchdata.datapipes.iter import FSSpecFileOpener, IterableWrapper
@@ -54,6 +56,7 @@ def open_single_file_with_retry(
Returns:
The opened file stream.
+
"""
from torchdata.datapipes.iter import FSSpecFileOpener, IterableWrapper
@@ -83,6 +86,7 @@ class OpenCloudFileObj:
mode: An optional string that specifies the mode in which the file is opened (``"r"`` by default).
kwargs_for_open: Optional Dict to specify kwargs for opening files (``fs.open()``).
+
"""
def __init__(
diff --git a/src/lightning/fabric/_graveyard/tpu.py b/src/lightning/fabric/_graveyard/tpu.py
index 2a45f928b3567..b38cfc470840c 100644
--- a/src/lightning/fabric/_graveyard/tpu.py
+++ b/src/lightning/fabric/_graveyard/tpu.py
@@ -34,6 +34,7 @@ class SingleTPUStrategy(SingleDeviceXLAStrategy):
"""Legacy class.
Use :class:`~lightning.fabric.strategies.single_xla.SingleDeviceXLAStrategy` instead.
+
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
@@ -50,6 +51,7 @@ class TPUAccelerator(XLAAccelerator):
"""Legacy class.
Use :class:`~lightning.fabric.accelerators.xla.XLAAccelerator` instead.
+
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
@@ -63,6 +65,7 @@ class TPUPrecision(XLAPrecision):
"""Legacy class.
Use :class:`~lightning.fabric.plugins.precision.xla.XLAPrecision` instead.
+
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
@@ -76,6 +79,7 @@ class TPUBf16Precision(XLABf16Precision):
"""Legacy class.
Use :class:`~lightning.fabric.plugins.precision.xlabf16.XLABf16Precision` instead.
+
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
diff --git a/src/lightning/fabric/accelerators/accelerator.py b/src/lightning/fabric/accelerators/accelerator.py
index f843f05f21ccf..3a8aa85ad041d 100644
--- a/src/lightning/fabric/accelerators/accelerator.py
+++ b/src/lightning/fabric/accelerators/accelerator.py
@@ -25,6 +25,7 @@ class Accelerator(ABC):
An Accelerator is meant to deal with one type of hardware.
.. warning:: Writing your own accelerator is an :ref:`experimental ` feature.
+
"""
@abstractmethod
diff --git a/src/lightning/fabric/accelerators/cuda.py b/src/lightning/fabric/accelerators/cuda.py
index 2f19529597f2e..c79c943afcdf1 100644
--- a/src/lightning/fabric/accelerators/cuda.py
+++ b/src/lightning/fabric/accelerators/cuda.py
@@ -88,6 +88,7 @@ def find_usable_cuda_devices(num_devices: int = -1) -> List[int]:
Warning:
If multiple processes call this function at the same time, there can be race conditions in the case where
both processes determine that the device is unoccupied, leading to one of them crashing later on.
+
"""
visible_devices = _get_all_visible_cuda_devices()
if not visible_devices:
@@ -128,6 +129,7 @@ def _get_all_visible_cuda_devices() -> List[int]:
Devices masked by the environment variable ``CUDA_VISIBLE_DEVICES`` won't be returned here. For example, assume you
have 8 physical GPUs. If ``CUDA_VISIBLE_DEVICES="1,3,6"``, then this function will return the list ``[0, 1, 2]``
because these are the three visible GPUs after applying the mask ``CUDA_VISIBLE_DEVICES``.
+
"""
return list(range(num_cuda_devices()))
@@ -135,8 +137,7 @@ def _get_all_visible_cuda_devices() -> List[int]:
# TODO: Remove once minimum supported PyTorch version is 2.0
@contextmanager
def _patch_cuda_is_available() -> Generator:
- """Context manager that safely patches :func:`torch.cuda.is_available` with its NVML-based version if
- possible."""
+ """Context manager that safely patches :func:`torch.cuda.is_available` with its NVML-based version if possible."""
if hasattr(torch._C, "_cuda_getDeviceCount") and _device_count_nvml() >= 0 and not _TORCH_GREATER_EQUAL_2_0:
# we can safely patch is_available if both torch has CUDA compiled and the NVML count is succeeding
# otherwise, patching is_available could lead to attribute errors or infinite recursion
@@ -156,6 +157,7 @@ def num_cuda_devices() -> int:
Unlike :func:`torch.cuda.device_count`, this function does its best not to create a CUDA context for fork support,
if the platform allows it.
+
"""
if _TORCH_GREATER_EQUAL_2_0:
return torch.cuda.device_count()
@@ -171,6 +173,7 @@ def is_cuda_available() -> bool:
Unlike :func:`torch.cuda.is_available`, this function does its best not to create a CUDA context for fork support,
if the platform allows it.
+
"""
# We set `PYTORCH_NVML_BASED_CUDA_CHECK=1` in lightning.fabric.__init__.py
return torch.cuda.is_available() if _TORCH_GREATER_EQUAL_2_0 else num_cuda_devices() > 0
@@ -311,6 +314,7 @@ def _device_count_nvml() -> int:
"""Return number of devices as reported by NVML taking CUDA_VISIBLE_DEVICES into account.
Negative value is returned if NVML discovery or initialization has failed.
+
"""
visible_devices = _parse_visible_devices()
if not visible_devices:
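A small usage sketch of the CUDA helpers documented in this file; the device count and the number of free devices obviously depend on the machine, and the lookup may raise if fewer usable devices than requested are found.

    from lightning.fabric.accelerators.cuda import (
        find_usable_cuda_devices,
        is_cuda_available,
        num_cuda_devices,
    )

    if is_cuda_available():
        print(f"visible CUDA devices: {num_cuda_devices()}")
        # Subject to the race condition mentioned above if several processes probe at once.
        print(f"usable devices: {find_usable_cuda_devices(1)}")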
diff --git a/src/lightning/fabric/accelerators/mps.py b/src/lightning/fabric/accelerators/mps.py
index cb6ffddbd9e66..1126f01d1eada 100644
--- a/src/lightning/fabric/accelerators/mps.py
+++ b/src/lightning/fabric/accelerators/mps.py
@@ -26,6 +26,7 @@ class MPSAccelerator(Accelerator):
"""Accelerator for Metal Apple Silicon GPU devices.
.. warning:: Use of this accelerator beyond import and instantiation is experimental.
+
"""
def setup_device(self, device: torch.device) -> None:
diff --git a/src/lightning/fabric/accelerators/registry.py b/src/lightning/fabric/accelerators/registry.py
index f8d79dc1b602d..68b2b98f45989 100644
--- a/src/lightning/fabric/accelerators/registry.py
+++ b/src/lightning/fabric/accelerators/registry.py
@@ -41,6 +41,7 @@ def __init__(self, a, b):
or
AcceleratorRegistry.register("sota", SOTAAccelerator, description="Custom sota accelerator", a=1, b=True)
+
"""
def register(
@@ -59,6 +60,7 @@ def register(
description : accelerator description
override : overrides the registered accelerator, if True
init_params: parameters to initialize the accelerator
+
"""
if not (name is None or isinstance(name, str)):
raise TypeError(f"`name` must be a str, found {name}")
@@ -87,6 +89,7 @@ def get(self, name: str, default: Optional[Any] = None) -> Any:
Args:
name (str): the name that identifies an accelerator, e.g. "gpu"
+
"""
if name in self:
data = self[name]
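For context, a quick sketch of how the registry's ``get`` is typically used; it assumes the built-in accelerators have already been registered, which happens when ``lightning.fabric.accelerators`` is imported.

    from lightning.fabric.accelerators import ACCELERATOR_REGISTRY

    # The registry behaves like a dict keyed by name, e.g. "cpu", "gpu", ...
    if "cpu" in ACCELERATOR_REGISTRY:
        cpu_accelerator = ACCELERATOR_REGISTRY.get("cpu")
        print(type(cpu_accelerator).__name__)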
diff --git a/src/lightning/fabric/accelerators/xla.py b/src/lightning/fabric/accelerators/xla.py
index fbe29eee641ec..207eefe5338ed 100644
--- a/src/lightning/fabric/accelerators/xla.py
+++ b/src/lightning/fabric/accelerators/xla.py
@@ -26,6 +26,7 @@ class XLAAccelerator(Accelerator):
"""Accelerator for XLA devices, normally TPUs.
.. warning:: Use of this accelerator beyond import and instantiation is experimental.
+
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
diff --git a/src/lightning/fabric/cli.py b/src/lightning/fabric/cli.py
index abcb8f195abcb..24943a528b47e 100644
--- a/src/lightning/fabric/cli.py
+++ b/src/lightning/fabric/cli.py
@@ -33,8 +33,8 @@
def _get_supported_strategies() -> List[str]:
- """Returns strategy choices from the registry, with the ones removed that are incompatible to be launched from
- the CLI or ones that require further configuration by the user."""
+ """Returns strategy choices from the registry, with the ones removed that are incompatible to be launched from the
+ CLI or ones that require further configuration by the user."""
available_strategies = STRATEGY_REGISTRY.available_strategies()
excluded = r".*(spawn|fork|notebook|xla|tpu|offload).*"
return [strategy for strategy in available_strategies if not re.match(excluded, strategy)]
@@ -122,6 +122,7 @@ def _run_model(**kwargs: Any) -> None:
SCRIPT_ARGS are the remaining arguments that you can pass to the script itself and are expected to be parsed
there.
+
"""
script_args = list(kwargs.pop("script_args", []))
main(args=Namespace(**kwargs), script_args=script_args)
@@ -131,6 +132,7 @@ def _set_env_variables(args: Namespace) -> None:
"""Set the environment variables for the new processes.
The Fabric connector will parse the arguments set here.
+
"""
os.environ["LT_CLI_USED"] = "1"
if args.accelerator is not None:
diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py
index 401d6d491284a..0852e1e179bd1 100644
--- a/src/lightning/fabric/connector.py
+++ b/src/lightning/fabric/connector.py
@@ -98,6 +98,7 @@ class _Connector:
priorities which to take when:
A. Class > str
B. Strategy > Accelerator/precision/plugins
+
"""
def __init__(
@@ -182,6 +183,7 @@ def _check_config_and_set_final_flags(
4. plugins: The list of plugins may contain a Precision plugin, CheckpointIO, ClusterEnvironment and others.
Additionally, other flags such as `precision` can populate the list with the
corresponding plugin instances.
+
"""
if plugins is not None:
plugins = [plugins] if not isinstance(plugins, list) else plugins
@@ -390,8 +392,8 @@ def _choose_strategy(self) -> Union[Strategy, str]:
return "ddp"
def _check_strategy_and_fallback(self) -> None:
- """Checks edge cases when the strategy selection was a string input, and we need to fall back to a
- different choice depending on other parameters or the environment."""
+ """Checks edge cases when the strategy selection was a string input, and we need to fall back to a different
+ choice depending on other parameters or the environment."""
# current fallback and check logic only apply to user pass in str config and object config
# TODO this logic should apply to both str and object config
strategy_flag = "" if isinstance(self._strategy_flag, Strategy) else self._strategy_flag
diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py
index 2a9e6088b40ea..0f229a5538c54 100644
--- a/src/lightning/fabric/fabric.py
+++ b/src/lightning/fabric/fabric.py
@@ -96,6 +96,7 @@ class Fabric:
can be invoked through :meth:`~lightning.fabric.fabric.Fabric.call` by the user.
loggers: A single logger or a list of loggers. See :meth:`~lightning.fabric.fabric.Fabric.log` for more
information.
+
"""
def __init__(
@@ -147,6 +148,7 @@ def device(self) -> torch.device:
"""The current device this process runs on.
Use this to create tensors directly on the device if needed.
+
"""
return self._strategy.root_device
@@ -189,6 +191,7 @@ def run(self, *args: Any, **kwargs: Any) -> Any:
"""All the code inside this run method gets accelerated by Fabric.
You can pass arbitrary arguments to this function when overriding it.
+
"""
def setup(
@@ -207,6 +210,7 @@ def setup(
Returns:
The tuple containing wrapped module and the optimizers, in the same order they were passed in.
+
"""
self._validate_setup(module, optimizers)
original_module = module
@@ -264,6 +268,7 @@ def setup_module(self, module: nn.Module, move_to_device: bool = True) -> _Fabri
Returns:
The wrapped model.
+
"""
self._validate_setup_module(module)
original_module = module
@@ -300,6 +305,7 @@ def setup_optimizers(self, *optimizers: Optimizer) -> Union[_FabricOptimizer, Tu
Returns:
The wrapped optimizer(s).
+
"""
self._validate_setup_optimizers(optimizers)
optimizers = [self._strategy.setup_optimizer(optimizer) for optimizer in optimizers]
@@ -326,6 +332,7 @@ def setup_dataloaders(
Returns:
The wrapped dataloaders, in the same order they were passed in.
+
"""
self._validate_setup_dataloaders(dataloaders)
dataloaders = [
@@ -353,6 +360,7 @@ def _setup_dataloader(
Returns:
The wrapped dataloader.
+
"""
if use_distributed_sampler and self._requires_distributed_sampler(dataloader):
sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs)
@@ -419,6 +427,7 @@ def clip_gradients(
norm_type: The type of norm if `max_norm` was passed. Can be ``'inf'`` for infinity norm.
Default is the 2-norm.
error_if_nonfinite: An error is raised if the total norm of the gradients is NaN or infinite.
+
"""
if clip_val is not None and max_norm is not None:
raise ValueError(
@@ -444,6 +453,7 @@ def autocast(self) -> Generator[None, None, None]:
Use this only if the `forward` method of your model does not cover all operations you wish to run with the
chosen precision setting.
+
"""
with self._precision.forward_context():
yield
@@ -461,8 +471,8 @@ def to_device(self, obj: Any) -> Any:
...
def to_device(self, obj: Union[nn.Module, Tensor, Any]) -> Union[nn.Module, Tensor, Any]:
- """Move a :class:`torch.nn.Module` or a collection of tensors to the current device, if it is not already
- on that device.
+ """Move a :class:`torch.nn.Module` or a collection of tensors to the current device, if it is not already on
+ that device.
Args:
obj: An object to move to the device. Can be an instance of :class:`torch.nn.Module`, a tensor, or a
@@ -470,6 +480,7 @@ def to_device(self, obj: Union[nn.Module, Tensor, Any]) -> Union[nn.Module, Tens
Returns:
A reference to the object that was moved to the new device.
+
"""
if isinstance(obj, nn.Module):
self._accelerator.setup_device(self.device)
@@ -482,6 +493,7 @@ def print(self, *args: Any, **kwargs: Any) -> None:
process in each machine.
Arguments passed to this method are forwarded to the Python built-in :func:`print` function.
+
"""
if self.local_rank == 0:
print(*args, **kwargs)
@@ -492,6 +504,7 @@ def barrier(self, name: Optional[str] = None) -> None:
Use this to synchronize all parallel processes, but only if necessary, otherwise the overhead of synchronization
will cause your program to slow down. This method needs to be called on all processes. Failing to do so will
cause your program to stall forever.
+
"""
self._validate_launched()
self._strategy.barrier(name=name)
@@ -508,6 +521,7 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
Return:
The transferred data, the same value on every rank.
+
"""
self._validate_launched()
return self._strategy.broadcast(obj, src=src)
@@ -528,6 +542,7 @@ def all_gather(
Return:
A tensor of shape (world_size, batch, ...), or if the input was a collection
the output will also be a collection with tensors of this shape.
+
"""
self._validate_launched()
group = group if group is not None else torch.distributed.group.WORLD
@@ -554,6 +569,7 @@ def all_reduce(
Return:
A tensor of the same shape as the input with values reduced pointwise across processes. The same is
applied to tensors in a collection if a collection is given as input.
+
"""
self._validate_launched()
group = group if group is not None else torch.distributed.group.WORLD
@@ -573,6 +589,7 @@ def rank_zero_first(self, local: bool = False) -> Generator:
with fabric.rank_zero_first():
dataset = MNIST("datasets/", download=True)
+
"""
rank = self.local_rank if local else self.global_rank
if rank > 0:
@@ -604,6 +621,7 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Gener
module: The module for which to control the gradient synchronization.
enabled: Whether the context manager is enabled or not. ``True`` means skip the sync, ``False`` means do not
skip.
+
"""
module = _unwrap_compiled(module)
if not isinstance(module, _FabricModule):
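A common use of ``no_backward_sync`` is gradient accumulation; a minimal sketch, assuming ``fabric``, ``model``, ``optimizer`` and ``dataloader`` were already set up through ``Fabric`` as documented above.

    # Skip the cross-process gradient sync on all but every 4th step (gradient accumulation).
    for step, batch in enumerate(dataloader):
        is_accumulating = (step + 1) % 4 != 0
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            loss = model(batch).sum()  # placeholder loss for the sketch
            fabric.backward(loss)
        if not is_accumulating:
            optimizer.step()
            optimizer.zero_grad()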
@@ -633,6 +651,7 @@ def sharded_model(self) -> Generator:
"""Instantiate a model under this context manager to prepare it for model-parallel sharding.
.. deprecated:: This context manager is deprecated in favor of :meth:`init_module`, use it instead.
+
"""
rank_zero_deprecation("`Fabric.sharded_model()` is deprecated in favor of `Fabric.init_module()`.")
if isinstance(self.strategy, _Sharded):
@@ -643,10 +662,11 @@ def sharded_model(self) -> Generator:
@contextmanager
def init_tensor(self) -> Generator:
- """Tensors that you instantiate under this context manager will be created on the device right away and
- have the right data type depending on the precision setting in Fabric.
+ """Tensors that you instantiate under this context manager will be created on the device right away and have
+ the right data type depending on the precision setting in Fabric.
The automatic device placement under this context manager is only supported with PyTorch 2.0 and newer.
+
"""
if not _TORCH_GREATER_EQUAL_2_0 and self.device.type != "cpu":
rank_zero_warn(
@@ -670,6 +690,7 @@ def init_module(self, empty_init: Optional[bool] = None) -> Generator:
empty_init: Whether to initialize the model with empty weights (uninitialized memory).
If ``None``, the strategy will decide. Some strategies may not support all options.
Set this to ``True`` if you are loading a checkpoint into a large model. Requires `torch >= 1.13`.
+
"""
if not _TORCH_GREATER_EQUAL_2_0 and self.device.type != "cpu":
rank_zero_warn(
@@ -700,6 +721,7 @@ def save(
filter: An optional dictionary containing filter callables that return a boolean indicating whether the
given item should be saved (``True``) or filtered out (``False``). Each filter key should match a
state key, where its filter will be applied to the ``state_dict`` generated.
+
"""
if filter is not None:
if not isinstance(filter, dict):
@@ -734,6 +756,7 @@ def load(
Returns:
The remaining items that were not restored into the given state dictionary. If no state dictionary is
given, the full checkpoint will be returned.
+
"""
unwrapped_state = _unwrap_objects(state)
remainder = self._strategy.load_checkpoint(path=path, state=unwrapped_state, strict=strict)
@@ -760,6 +783,7 @@ def load_raw(self, path: Union[str, Path], obj: Union[nn.Module, Optimizer], str
obj: A :class:`~torch.nn.Module` or :class:`~torch.optim.Optimizer` instance.
strict: Whether to enforce that the keys in the module's state-dict match the keys in the checkpoint.
Does not apply to optimizers.
+
"""
obj = _unwrap_objects(obj)
self._strategy.load_checkpoint(path=path, state=obj, strict=strict)
@@ -782,6 +806,7 @@ def launch(self, function: Callable[["Fabric"], Any] = _do_nothing, *args: Any,
``launch()`` from your code.
``launch()`` is a no-op when called multiple times and no function is passed in.
+
"""
if _is_using_cli():
raise RuntimeError(
@@ -825,6 +850,7 @@ def on_train_epoch_end(self, results):
fabric = Fabric(callbacks=[MyCallback()])
fabric.call("on_train_epoch_end", results={...})
+
"""
for callback in self._callbacks:
method = getattr(callback, hook_name, None)
@@ -853,6 +879,7 @@ def log(self, name: str, value: Any, step: Optional[int] = None) -> None:
graph automatically.
step: Optional step number. Most Logger implementations auto-increment the step value by one with every
log call. You can specify your own value here.
+
"""
self.log_dict(metrics={name: value}, step=step)
@@ -864,6 +891,7 @@ def log_dict(self, metrics: Mapping[str, Any], step: Optional[int] = None) -> No
Any :class:`torch.Tensor` in the dictionary get detached from the graph automatically.
step: Optional step number. Most Logger implementations auto-increment this value by one with every
log call. You can specify your own value here.
+
"""
metrics = convert_tensors_to_scalars(metrics)
for logger in self._loggers:
@@ -874,6 +902,7 @@ def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None)
"""Helper function to seed everything without explicitly importing Lightning.
See :func:`lightning.fabric.utilities.seed.seed_everything` for more details.
+
"""
if workers is None:
# Lightning sets `workers=False` by default to avoid breaking reproducibility, but since this is a new
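Tying several of the ``Fabric`` methods documented in this file together, a minimal self-contained training sketch; the toy model, data, and ``accelerator="cpu"`` choice are illustrative only.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from lightning.fabric import Fabric
    from lightning.fabric.utilities.seed import seed_everything

    seed_everything(42)
    fabric = Fabric(accelerator="cpu", devices=1)
    fabric.launch()

    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    model, optimizer = fabric.setup(model, optimizer)
    dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
    dataloader = fabric.setup_dataloaders(DataLoader(dataset, batch_size=16))

    for epoch in range(2):
        for x, y in dataloader:
            optimizer.zero_grad()
            loss = torch.nn.functional.mse_loss(model(x), y)
            fabric.backward(loss)
            fabric.clip_gradients(model, optimizer, max_norm=1.0)
            optimizer.step()
        fabric.log_dict({"loss": loss.item(), "epoch": epoch})
    fabric.print("done")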
diff --git a/src/lightning/fabric/loggers/csv_logs.py b/src/lightning/fabric/loggers/csv_logs.py
index fd0c69c3ab439..d55860056e271 100644
--- a/src/lightning/fabric/loggers/csv_logs.py
+++ b/src/lightning/fabric/loggers/csv_logs.py
@@ -49,6 +49,7 @@ class CSVLogger(Logger):
logger = CSVLogger("path/to/logs/root", name="my_model")
logger.log_metrics({"loss": 0.235, "acc": 0.75})
logger.finalize("success")
+
"""
LOGGER_JOIN_CHAR = "-"
@@ -77,6 +78,7 @@ def name(self) -> str:
Returns:
The name of the experiment.
+
"""
return self._name
@@ -86,6 +88,7 @@ def version(self) -> Union[int, str]:
Returns:
The version of the experiment if it is specified, else the next version.
+
"""
if self._version is None:
self._version = self._get_next_version()
@@ -102,6 +105,7 @@ def log_dir(self) -> str:
By default, it is named ``'version_${self.version}'`` but it can be overridden by passing a string value for the
constructor's version parameter instead of ``None`` or an int.
+
"""
# create a pseudo standard path
version = self.version if isinstance(self.version, str) else f"version_{self.version}"
@@ -110,12 +114,12 @@ def log_dir(self) -> str:
@property
@rank_zero_experiment
def experiment(self) -> "_ExperimentWriter":
- """Actual ExperimentWriter object. To use ExperimentWriter features anywhere in your code, do the
- following.
+ """Actual ExperimentWriter object. To use ExperimentWriter features anywhere in your code, do the following.
Example::
self.logger.experiment.some_experiment_writer_function()
+
"""
if self._experiment is not None:
return self._experiment
@@ -177,6 +181,7 @@ class _ExperimentWriter:
Args:
log_dir: Directory for the experiment logs
+
"""
NAME_METRICS_FILE = "metrics.csv"
diff --git a/src/lightning/fabric/loggers/logger.py b/src/lightning/fabric/loggers/logger.py
index 8efa974bd4860..5647ab9c1c7a2 100644
--- a/src/lightning/fabric/loggers/logger.py
+++ b/src/lightning/fabric/loggers/logger.py
@@ -39,8 +39,8 @@ def version(self) -> Optional[Union[int, str]]:
@property
def root_dir(self) -> Optional[str]:
- """Return the root directory where all versions of an experiment get saved, or `None` if the logger does
- not save data locally."""
+ """Return the root directory where all versions of an experiment get saved, or `None` if the logger does not
+ save data locally."""
return None
@property
@@ -61,6 +61,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
Args:
metrics: Dictionary with metric names as keys and measured quantities as values
step: Step number at which the metrics should be recorded
+
"""
pass
@@ -72,6 +73,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace], *args: Any,
params: :class:`~argparse.Namespace` or `Dict` containing the hyperparameters
args: Optional positional arguments, depends on the specific logger being used
kwargs: Optional keyword arguments, depends on the specific logger being used
+
"""
def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None:
@@ -80,6 +82,7 @@ def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None
Args:
model: the model with an implementation of ``forward``.
input_array: input passes to `model.forward`
+
"""
pass
@@ -91,6 +94,7 @@ def finalize(self, status: str) -> None:
Args:
status: Status that the experiment finished with (e.g. success, failed, aborted)
+
"""
self.save()
diff --git a/src/lightning/fabric/loggers/tensorboard.py b/src/lightning/fabric/loggers/tensorboard.py
index 3708ac9b73d1b..8098881c7d84b 100644
--- a/src/lightning/fabric/loggers/tensorboard.py
+++ b/src/lightning/fabric/loggers/tensorboard.py
@@ -75,6 +75,7 @@ class TensorBoardLogger(Logger):
logger.log_hyperparams({"epochs": 5, "optimizer": "Adam"})
logger.log_metrics({"acc": 0.75})
logger.finalize("success")
+
"""
LOGGER_JOIN_CHAR = "-"
@@ -113,6 +114,7 @@ def name(self) -> str:
Returns:
The name of the experiment.
+
"""
return self._name
@@ -122,6 +124,7 @@ def version(self) -> Union[int, str]:
Returns:
The experiment version if specified else the next version.
+
"""
if self._version is None:
self._version = self._get_next_version()
@@ -133,6 +136,7 @@ def root_dir(self) -> str:
Returns:
The local path to the save directory where the TensorBoard experiments are saved.
+
"""
return self._root_dir
@@ -142,6 +146,7 @@ def log_dir(self) -> str:
By default, it is named ``'version_${self.version}'`` but it can be overridden by passing a string value for the
constructor's version parameter instead of ``None`` or an int.
+
"""
version = self.version if isinstance(self.version, str) else f"version_{self.version}"
log_dir = os.path.join(self.root_dir, self.name, version)
@@ -157,6 +162,7 @@ def sub_dir(self) -> Optional[str]:
Returns:
The local path to the sub directory where the TensorBoard experiments are saved.
+
"""
return self._sub_dir
@@ -168,6 +174,7 @@ def experiment(self) -> "SummaryWriter":
Example::
logger.experiment.some_tensorboard_function()
+
"""
if self._experiment is not None:
return self._experiment
@@ -210,12 +217,13 @@ def log_hyperparams( # type: ignore[override]
self, params: Union[Dict[str, Any], Namespace], metrics: Optional[Dict[str, Any]] = None
) -> None:
"""Record hyperparameters. TensorBoard logs with and without saved hyperparameters are incompatible, the
- hyperparameters are then not displayed in the TensorBoard. Please delete or move the previously saved logs
- to display the new ones with hyperparameters.
+ hyperparameters are then not displayed in the TensorBoard. Please delete or move the previously saved logs to
+ display the new ones with hyperparameters.
Args:
params: a dictionary-like container with the hyperparameters
metrics: Dictionary with metric names as keys and measured quantities as values
+
"""
params = _convert_params(params)
diff --git a/src/lightning/fabric/plugins/collectives/collective.py b/src/lightning/fabric/plugins/collectives/collective.py
index 9a2c399883ea3..5c655f189c561 100644
--- a/src/lightning/fabric/plugins/collectives/collective.py
+++ b/src/lightning/fabric/plugins/collectives/collective.py
@@ -13,6 +13,7 @@ class Collective(ABC):
Supports communications between multiple processes and multiple nodes. A collective owns a group.
.. warning:: This is an :ref:`experimental ` feature which is still in development.
+
"""
def __init__(self) -> None:
@@ -120,6 +121,7 @@ def create_group(self, **kwargs: Any) -> Self:
This assumes that :meth:`~lightning.fabric.plugins.collectives.Collective.init_group` has been
called already by the user.
+
"""
if self._group is not None:
raise RuntimeError(f"`{type(self).__name__}` already owns a group.")
diff --git a/src/lightning/fabric/plugins/collectives/single_device.py b/src/lightning/fabric/plugins/collectives/single_device.py
index 88c24c489235c..fee7a05f79e96 100644
--- a/src/lightning/fabric/plugins/collectives/single_device.py
+++ b/src/lightning/fabric/plugins/collectives/single_device.py
@@ -10,6 +10,7 @@ class SingleDeviceCollective(Collective):
"""Support for collective operations on a single device (no-op).
.. warning:: This is an :ref:`experimental ` feature which is still in development.
+
"""
@property
diff --git a/src/lightning/fabric/plugins/collectives/torch_collective.py b/src/lightning/fabric/plugins/collectives/torch_collective.py
index 05c2a5b4b2cd5..50b9a4997554a 100644
--- a/src/lightning/fabric/plugins/collectives/torch_collective.py
+++ b/src/lightning/fabric/plugins/collectives/torch_collective.py
@@ -21,6 +21,7 @@ class TorchCollective(Collective):
"""Collective operations using `torch.distributed `__.
.. warning:: This is an :ref:`experimental ` feature which is still in development.
+
"""
manages_default_group = False
diff --git a/src/lightning/fabric/plugins/environments/kubeflow.py b/src/lightning/fabric/plugins/environments/kubeflow.py
index 967c4682f2184..3b14fbccae23d 100644
--- a/src/lightning/fabric/plugins/environments/kubeflow.py
+++ b/src/lightning/fabric/plugins/environments/kubeflow.py
@@ -28,6 +28,7 @@ class KubeflowEnvironment(ClusterEnvironment):
.. _PyTorchJob: https://www.kubeflow.org/docs/components/training/pytorch/
.. _Kubeflow: https://www.kubeflow.org
+
"""
@property
diff --git a/src/lightning/fabric/plugins/environments/lightning.py b/src/lightning/fabric/plugins/environments/lightning.py
index efb4968cc6950..8a717e3bf4f7e 100644
--- a/src/lightning/fabric/plugins/environments/lightning.py
+++ b/src/lightning/fabric/plugins/environments/lightning.py
@@ -32,6 +32,7 @@ class LightningEnvironment(ClusterEnvironment):
If the main address and port are not provided, the default environment will choose them
automatically. It is recommended to use this default environment for single-node distributed
training as it provides a convenient way to launch the training script.
+
"""
def __init__(self) -> None:
@@ -46,6 +47,7 @@ def creates_processes_externally(self) -> bool:
If at least :code:`LOCAL_RANK` is available as an environment variable, Lightning assumes the user acts as the
process launcher/job scheduler and Lightning will not launch new processes.
+
"""
return "LOCAL_RANK" in os.environ
@@ -93,6 +95,7 @@ def find_free_network_port() -> int:
It is useful in single-node training when we don't want to connect to a real main node but have to set the
`MASTER_PORT` environment variable.
+
"""
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 0))
diff --git a/src/lightning/fabric/plugins/environments/lsf.py b/src/lightning/fabric/plugins/environments/lsf.py
index 8500a3e40fd93..2eb089f58065a 100644
--- a/src/lightning/fabric/plugins/environments/lsf.py
+++ b/src/lightning/fabric/plugins/environments/lsf.py
@@ -44,6 +44,7 @@ class LSFEnvironment(ClusterEnvironment):
``JSM_NAMESPACE_RANK``
The global rank for the task. This environment variable is set by ``jsrun``
+
"""
def __init__(self) -> None:
@@ -128,6 +129,7 @@ def _get_node_rank(self) -> int:
The node rank is determined by the position of the current node in the list of hosts used in the job. This is
calculated by reading all hosts from ``LSB_DJOB_RANKFILE`` and finding this node's hostname in the list.
+
"""
hosts = self._read_hosts()
count: Dict[str, int] = {}
@@ -143,6 +145,7 @@ def _read_hosts() -> List[str]:
LSF uses the Job Step Manager (JSM) to manage job steps. Job steps are executed by the JSM from "launch" nodes.
Each job is assigned a launch node. This launch node will be the first node in the list contained in
``LSB_DJOB_RANKFILE``.
+
"""
var = "LSB_DJOB_RANKFILE"
rankfile = os.environ.get(var)
@@ -161,6 +164,7 @@ def _get_main_address(self) -> str:
"""A helper for getting the main address.
The main address is assigned to the first node in the list of nodes used for the job.
+
"""
hosts = self._read_hosts()
return hosts[0]
@@ -170,6 +174,7 @@ def _get_main_port() -> int:
"""A helper function for accessing the main port.
Uses the LSF job ID so all ranks can compute the main port.
+
"""
# check for user-specified main port
if "MASTER_PORT" in os.environ:
diff --git a/src/lightning/fabric/plugins/environments/mpi.py b/src/lightning/fabric/plugins/environments/mpi.py
index a518da5f66c0b..e40fe8b027790 100644
--- a/src/lightning/fabric/plugins/environments/mpi.py
+++ b/src/lightning/fabric/plugins/environments/mpi.py
@@ -31,6 +31,7 @@ class MPIEnvironment(ClusterEnvironment):
"""An environment for running on clusters with processes created through MPI.
Requires the installation of the `mpi4py` package. See also: https://github.com/mpi4py/mpi4py
+
"""
def __init__(self) -> None:
diff --git a/src/lightning/fabric/plugins/environments/slurm.py b/src/lightning/fabric/plugins/environments/slurm.py
index 97fceb0218284..b951ebedd5ae4 100644
--- a/src/lightning/fabric/plugins/environments/slurm.py
+++ b/src/lightning/fabric/plugins/environments/slurm.py
@@ -39,6 +39,7 @@ class SLURMEnvironment(ClusterEnvironment):
rescheduled gets determined by the owner of this plugin.
requeue_signal: The signal that SLURM will send to indicate that the job should be requeued. Defaults to
SIGUSR1 on Unix.
+
"""
def __init__(self, auto_requeue: bool = True, requeue_signal: Optional[signal.Signals] = None) -> None:
@@ -149,6 +150,7 @@ def resolve_root_node_address(nodes: str) -> str:
- a space-separated list of host names, e.g., 'host0 host1 host3' yields 'host0' as the root
- a comma-separated list of host names, e.g., 'host0,host1,host3' yields 'host0' as the root
- the range notation with brackets, e.g., 'host[5-9]' yields 'host5' as the root
+
"""
nodes = re.sub(r"\[(.*?)[,-].*\]", "\\1", nodes) # Take the first node of every node range
nodes = re.sub(r"\[(.*?)\]", "\\1", nodes) # handle special case where node range is single number
@@ -161,6 +163,7 @@ def _validate_srun_used() -> None:
Parallel jobs (multi-GPU, multi-node) in SLURM are launched by prepending `srun` in front of the Python command.
Not doing so will result in processes hanging, which is a frequent user error. Lightning will emit a warning if
`srun` is found but not used.
+
"""
if _IS_WINDOWS:
return
@@ -176,12 +179,12 @@ def _validate_srun_used() -> None:
@staticmethod
def _validate_srun_variables() -> None:
- """Checks for conflicting or incorrectly set variables set through `srun` and raises a useful error
- message.
+ """Checks for conflicting or incorrectly set variables set through `srun` and raises a useful error message.
Right now, we only check for the most common user errors. See
`the srun docs `_
for a complete list of supported srun variables.
+
"""
ntasks = int(os.environ.get("SLURM_NTASKS", "1"))
if ntasks > 1 and "SLURM_NTASKS_PER_NODE" not in os.environ:
diff --git a/src/lightning/fabric/plugins/environments/xla.py b/src/lightning/fabric/plugins/environments/xla.py
index 0aa4671684fd0..97657087be380 100644
--- a/src/lightning/fabric/plugins/environments/xla.py
+++ b/src/lightning/fabric/plugins/environments/xla.py
@@ -25,6 +25,7 @@ class XLAEnvironment(ClusterEnvironment):
A list of environment variables set by XLA can be found
`here `_.
+
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
diff --git a/src/lightning/fabric/plugins/io/checkpoint_io.py b/src/lightning/fabric/plugins/io/checkpoint_io.py
index 44e9f596d4f1c..93e3a67b7bea6 100644
--- a/src/lightning/fabric/plugins/io/checkpoint_io.py
+++ b/src/lightning/fabric/plugins/io/checkpoint_io.py
@@ -42,6 +42,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio
checkpoint: dict containing model and trainer state
path: write-target path
storage_options: Optional parameters when saving the model/training states.
+
"""
@abstractmethod
@@ -54,6 +55,7 @@ def load_checkpoint(self, path: _PATH, map_location: Optional[Any] = None) -> Di
locations.
Returns: The loaded checkpoint.
+
"""
@abstractmethod
@@ -62,6 +64,7 @@ def remove_checkpoint(self, path: _PATH) -> None:
Args:
path: Path to checkpoint
+
"""
def teardown(self) -> None:
diff --git a/src/lightning/fabric/plugins/io/torch_io.py b/src/lightning/fabric/plugins/io/torch_io.py
index 646f29ffcbbe4..31f90d2b69f22 100644
--- a/src/lightning/fabric/plugins/io/torch_io.py
+++ b/src/lightning/fabric/plugins/io/torch_io.py
@@ -25,10 +25,11 @@
class TorchCheckpointIO(CheckpointIO):
- """CheckpointIO that utilizes :func:`torch.save` and :func:`torch.load` to save and load checkpoints
- respectively, common for most use cases.
+ """CheckpointIO that utilizes :func:`torch.save` and :func:`torch.load` to save and load checkpoints respectively,
+ common for most use cases.
.. warning:: This is an :ref:`experimental ` feature.
+
"""
def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
@@ -42,6 +43,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio
Raises:
TypeError:
If ``storage_options`` arg is passed in
+
"""
if storage_options is not None:
raise TypeError(
@@ -82,6 +84,7 @@ def remove_checkpoint(self, path: _PATH) -> None:
Args:
path: Path to checkpoint
+
"""
fs = get_filesystem(path)
if fs.exists(path):
diff --git a/src/lightning/fabric/plugins/io/xla.py b/src/lightning/fabric/plugins/io/xla.py
index 509bdcdc6e568..4a5c3ef96bdd6 100644
--- a/src/lightning/fabric/plugins/io/xla.py
+++ b/src/lightning/fabric/plugins/io/xla.py
@@ -28,6 +28,7 @@ class XLACheckpointIO(TorchCheckpointIO):
"""CheckpointIO that utilizes :func:`xm.save` to save checkpoints for TPU training strategies.
.. warning:: This is an :ref:`experimental ` feature.
+
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
@@ -46,6 +47,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio
Raises:
TypeError:
If ``storage_options`` arg is passed in
+
"""
if storage_options is not None:
raise TypeError(
diff --git a/src/lightning/fabric/plugins/precision/deepspeed.py b/src/lightning/fabric/plugins/precision/deepspeed.py
index d51baa5939804..6a1ebb17beb3d 100644
--- a/src/lightning/fabric/plugins/precision/deepspeed.py
+++ b/src/lightning/fabric/plugins/precision/deepspeed.py
@@ -43,6 +43,7 @@ class DeepSpeedPrecision(Precision):
Raises:
ValueError:
If unsupported ``precision`` is provided.
+
"""
def __init__(self, precision: _PRECISION_INPUT) -> None:
diff --git a/src/lightning/fabric/plugins/precision/double.py b/src/lightning/fabric/plugins/precision/double.py
index 05419a5d008b3..8a2623a141dca 100644
--- a/src/lightning/fabric/plugins/precision/double.py
+++ b/src/lightning/fabric/plugins/precision/double.py
@@ -36,6 +36,7 @@ def init_context(self) -> Generator[None, None, None]:
"""Instantiate module parameters or tensors in the precision type this plugin handles.
This is optional and depends on the precision limitations during optimization.
+
"""
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.float64)
@@ -47,6 +48,7 @@ def forward_context(self) -> Generator[None, None, None]:
"""A context manager to change the default tensor type.
See: :meth:`torch.set_default_dtype`
+
"""
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.float64)
diff --git a/src/lightning/fabric/plugins/precision/fsdp.py b/src/lightning/fabric/plugins/precision/fsdp.py
index 826415568d20e..5c211dd025215 100644
--- a/src/lightning/fabric/plugins/precision/fsdp.py
+++ b/src/lightning/fabric/plugins/precision/fsdp.py
@@ -47,6 +47,7 @@ class FSDPPrecision(Precision):
Raises:
ValueError:
If unsupported ``precision`` is provided.
+
"""
def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None) -> None:
@@ -107,6 +108,7 @@ def init_context(self) -> Generator[None, None, None]:
"""A context manager to change the default tensor type when initializing module parameters or tensors.
See: :meth:`torch.set_default_dtype`
+
"""
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(self.mixed_precision_config.param_dtype)
diff --git a/src/lightning/fabric/plugins/precision/half.py b/src/lightning/fabric/plugins/precision/half.py
index aa9ac52ffd2b3..e4b011c57f70b 100644
--- a/src/lightning/fabric/plugins/precision/half.py
+++ b/src/lightning/fabric/plugins/precision/half.py
@@ -28,6 +28,7 @@ class HalfPrecision(Precision):
Args:
precision: Whether to use ``torch.float16`` (``'16-true'``) or ``torch.bfloat16`` (``'bf16-true'``).
+
"""
precision: Literal["bf16-true", "16-true"] = "16-true"
@@ -44,6 +45,7 @@ def init_context(self) -> Generator[None, None, None]:
"""A context manager to change the default tensor type when initializing module parameters or tensors.
See: :meth:`torch.set_default_dtype`
+
"""
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(self._desired_input_dtype)
@@ -52,10 +54,10 @@ def init_context(self) -> Generator[None, None, None]:
@contextmanager
def forward_context(self) -> Generator[None, None, None]:
- """A context manager to change the default tensor type when tensors get created during the module's
- forward.
+ """A context manager to change the default tensor type when tensors get created during the module's forward.
See: :meth:`torch.set_default_dtype`
+
"""
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(self._desired_input_dtype)
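To illustrate the two context managers above, a small sketch that uses ``HalfPrecision`` directly; in practice you would normally just pass ``precision="bf16-true"`` to ``Fabric``.

    import torch

    from lightning.fabric.plugins.precision.half import HalfPrecision

    plugin = HalfPrecision(precision="bf16-true")

    # Parameters created inside init_context get the plugin's dtype (bfloat16 here).
    with plugin.init_context():
        layer = torch.nn.Linear(4, 4)
    print(layer.weight.dtype)  # torch.bfloat16

    # forward_context applies the same default dtype to tensors created during forward.
    with plugin.forward_context():
        out = layer(torch.randn(2, 4))
    print(out.dtype)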
diff --git a/src/lightning/fabric/plugins/precision/precision.py b/src/lightning/fabric/plugins/precision/precision.py
index 1add95da884ad..76017b32c7c22 100644
--- a/src/lightning/fabric/plugins/precision/precision.py
+++ b/src/lightning/fabric/plugins/precision/precision.py
@@ -33,6 +33,7 @@ class Precision:
"""Base class for all plugins handling the precision-specific parts of the training.
The class attribute precision must be overwritten in child classes. The default value reflects fp32 training.
+
"""
precision: _PRECISION_INPUT_STR = "32-true"
@@ -41,6 +42,7 @@ def convert_module(self, module: Module) -> Module:
"""Convert the module parameters to the precision type this plugin handles.
This is optional and depends on the precision limitations during optimization.
+
"""
return module
@@ -49,6 +51,7 @@ def init_context(self) -> Generator[None, None, None]:
"""Instantiate module parameters or tensors in the precision type this plugin handles.
This is optional and depends on the precision limitations during optimization.
+
"""
yield
@@ -62,6 +65,7 @@ def convert_input(self, data: Any) -> Any:
This is a no-op in the base precision plugin, since we assume the data already has the desired type (default is
torch.float32).
+
"""
return data
@@ -70,6 +74,7 @@ def convert_output(self, data: Any) -> Any:
This is a no-op in the base precision plugin, since we assume the data already has the desired type (default is
torch.float32).
+
"""
return data
@@ -79,6 +84,7 @@ def pre_backward(self, tensor: Tensor, module: Optional[Module]) -> Any:
Args:
tensor: The tensor that will be used for backpropagation
module: The module that was involved in producing the tensor and whose parameters need the gradients
+
"""
def backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None:
@@ -87,6 +93,7 @@ def backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs
Args:
tensor: The tensor that will be used for backpropagation
model: The module that was involved in producing the tensor and whose parameters need the gradients
+
"""
tensor.backward(*args, **kwargs)
@@ -96,6 +103,7 @@ def post_backward(self, tensor: Tensor, module: Optional[Module]) -> Any:
Args:
tensor: The tensor that will be used for backpropagation
module: The module that was involved in producing the tensor and whose parameters need the gradients
+
"""
def optimizer_step(
@@ -110,6 +118,7 @@ def main_params(self, optimizer: Optimizer) -> _PARAMETERS:
"""The main params of the model.
Returns the plain model params here. May be different in other precision plugins.
+
"""
for group in optimizer.param_groups:
yield from group["params"]
@@ -122,6 +131,7 @@ def state_dict(self) -> Dict[str, Any]:
Returns:
A dictionary containing precision plugin state.
+
"""
return {}
@@ -131,6 +141,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
Args:
state_dict: the precision plugin state returned by ``state_dict``.
+
"""
pass
@@ -138,4 +149,5 @@ def teardown(self) -> None:
"""This method is called to teardown the training process.
It is the right place to release memory and free other resources.
+
"""
diff --git a/src/lightning/fabric/plugins/precision/transformer_engine.py b/src/lightning/fabric/plugins/precision/transformer_engine.py
index 3d8d0c4ccfdc9..23035b80b56b6 100644
--- a/src/lightning/fabric/plugins/precision/transformer_engine.py
+++ b/src/lightning/fabric/plugins/precision/transformer_engine.py
@@ -53,6 +53,7 @@ class TransformerEnginePrecision(Precision):
Support for FP8 in the linear layers with `precision='transformer-engine'` is currently limited to tensors with
shapes where the dimensions are divisible by 8 and 16 respectively. You might want to add padding to your inputs
to conform to this restriction.
+
"""
precision: Literal["transformer-engine"] = "transformer-engine"
diff --git a/src/lightning/fabric/strategies/ddp.py b/src/lightning/fabric/strategies/ddp.py
index c6c65ef4ddd1b..a912c5d438cd8 100644
--- a/src/lightning/fabric/strategies/ddp.py
+++ b/src/lightning/fabric/strategies/ddp.py
@@ -136,6 +136,7 @@ def all_reduce(
Return:
reduced value, except when the input was not a tensor, in which case the output remains unchanged
+
"""
if isinstance(tensor, Tensor):
return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py
index 15bb69817768b..0fe3aa6e1b299 100644
--- a/src/lightning/fabric/strategies/deepspeed.py
+++ b/src/lightning/fabric/strategies/deepspeed.py
@@ -220,6 +220,7 @@ def __init__(
load_full_weights: True when loading a single checkpoint file containing the model state dict
when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards
per worker.
+
"""
if not _DEEPSPEED_AVAILABLE:
raise ImportError(
@@ -313,6 +314,7 @@ def setup_module_and_optimizers(
Return:
The model wrapped into a :class:`deepspeed.DeepSpeedEngine` and a list with a single
deepspeed optimizer.
+
"""
if len(optimizers) != 1:
raise ValueError(
@@ -328,6 +330,7 @@ def setup_module(self, module: Module) -> "deepspeed.DeepSpeedEngine":
"""Set up a module for inference (no optimizers).
For training, see :meth:`setup_module_and_optimizers`.
+
"""
self._deepspeed_engine, _ = self._initialize_engine(module)
return self._deepspeed_engine
@@ -336,6 +339,7 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
"""Optimizers can only be set up jointly with the model in this strategy.
Please use :meth:`setup_module_and_optimizers` to set up both module and optimizer together.
+
"""
raise NotImplementedError(self._err_msg_joint_setup_required())
@@ -386,6 +390,7 @@ def save_checkpoint(
ValueError:
When no :class:`deepspeed.DeepSpeedEngine` objects were found in the state, or when multiple
:class:`deepspeed.DeepSpeedEngine` objects were found.
+
"""
if storage_options is not None:
raise TypeError(
@@ -451,6 +456,7 @@ def load_checkpoint(
RuntimeError:
If DeepSpeed was unable to load the checkpoint due to missing files or because the checkpoint is
not in the expected DeepSpeed format.
+
"""
if isinstance(state, (Module, Optimizer)) or self.load_full_weights and self.zero_stage_3:
# This code path to enables loading a checkpoint from a non-deepspeed checkpoint or from
@@ -565,6 +571,7 @@ def _initialize_engine(
"""Initialize one model and one optimizer with an optional learning rate scheduler.
This calls :func:`deepspeed.initialize` internally.
+
"""
import deepspeed
@@ -720,12 +727,13 @@ def _create_default_config(
return cfg
def _restore_zero_state(self, module: Module, ckpt: Mapping[str, Any]) -> None:
- """Overrides the normal load_state_dict behaviour in PyTorch to ensure we gather parameters that may be
- sharded across processes before loading the state dictionary when using ZeRO stage 3. This is then
- automatically synced across processes.
+ """Overrides the normal load_state_dict behaviour in PyTorch to ensure we gather parameters that may be sharded
+ across processes before loading the state dictionary when using ZeRO stage 3. This is then automatically synced
+ across processes.
Args:
ckpt: The ckpt file.
+
"""
import deepspeed
diff --git a/src/lightning/fabric/strategies/dp.py b/src/lightning/fabric/strategies/dp.py
index d19d1a14d7757..99beeca150012 100644
--- a/src/lightning/fabric/strategies/dp.py
+++ b/src/lightning/fabric/strategies/dp.py
@@ -28,8 +28,8 @@
class DataParallelStrategy(ParallelStrategy):
- """Implements data-parallel training in a single process, i.e., the model gets replicated to each device and
- each gets a split of the data."""
+ """Implements data-parallel training in a single process, i.e., the model gets replicated to each device and each
+ gets a split of the data."""
def __init__(
self,
diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py
index f5d2c5a8eff47..07236abb4066d 100644
--- a/src/lightning/fabric/strategies/fsdp.py
+++ b/src/lightning/fabric/strategies/fsdp.py
@@ -137,6 +137,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
a folder with as many files as the world size.
\**kwargs: See available parameters in :class:`torch.distributed.fsdp.FullyShardedDataParallel`.
+
"""
def __init__(
@@ -296,6 +297,7 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
This setup method doesn't modify the optimizer or wrap the optimizer. The only thing it currently does is verify
that the optimizer was created after the model was wrapped with :meth:`setup_module` with a reference to the
flattened parameters.
+
"""
if _TORCH_GREATER_EQUAL_2_0:
return optimizer
@@ -414,6 +416,7 @@ def save_checkpoint(
optimizer state and other metadata. If the state-dict-type is ``'sharded'``, the checkpoint gets saved as a
directory containing one file per process, with model- and optimizer shards stored per file. Additionally, it
creates a metadata file `meta.pt` with the rest of the user's state (only saved from rank 0).
+
"""
if not _TORCH_GREATER_EQUAL_2_0:
raise NotImplementedError(
@@ -511,6 +514,7 @@ def load_checkpoint(
The strategy currently only supports saving and loading sharded checkpoints which are stored in form of a
directory of multiple files rather than a single file.
+
"""
if not _TORCH_GREATER_EQUAL_2_0:
raise NotImplementedError(
@@ -846,8 +850,7 @@ def _load_raw_module_state_from_path(path: Path, module: Module, strict: bool =
def _load_raw_module_state(state_dict: Dict[str, Any], module: Module, strict: bool = True) -> None:
- """Loads the state dict into the module by gathering all weights first and then and writing back to each
- shard."""
+ """Loads the state dict into the module by gathering all weights first and then and writing back to each shard."""
with _get_full_state_dict_context(module, rank0_only=False):
module.load_state_dict(state_dict, strict=strict)
@@ -879,6 +882,7 @@ def _apply_optimizers_during_fsdp_backward(
By moving optimizer step invocation into the backward call we can free
gradients earlier and reduce peak memory.
+
"""
from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
from torch.distributed.fsdp._traversal_utils import _get_fsdp_handles
diff --git a/src/lightning/fabric/strategies/launchers/launcher.py b/src/lightning/fabric/strategies/launchers/launcher.py
index f261f81124d5f..c22a14633eb76 100644
--- a/src/lightning/fabric/strategies/launchers/launcher.py
+++ b/src/lightning/fabric/strategies/launchers/launcher.py
@@ -23,6 +23,7 @@ class _Launcher(ABC):
Subclass this class and override any of the relevant methods to provide a custom implementation depending on
cluster environment, hardware, strategy, etc.
+
"""
@property
diff --git a/src/lightning/fabric/strategies/launchers/multiprocessing.py b/src/lightning/fabric/strategies/launchers/multiprocessing.py
index 9d90f7953b0da..66c766ecb54ae 100644
--- a/src/lightning/fabric/strategies/launchers/multiprocessing.py
+++ b/src/lightning/fabric/strategies/launchers/multiprocessing.py
@@ -50,6 +50,7 @@ class _MultiProcessingLauncher(_Launcher):
- 'fork': Preferable for IPython/Jupyter environments where 'spawn' is not available. Not available on
the Windows platform for example.
- 'forkserver': Alternative implementation to 'fork'.
+
"""
def __init__(
@@ -82,6 +83,7 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
function: The entry point for all launched processes.
*args: Optional positional arguments to be passed to the given function.
**kwargs: Optional keyword arguments to be passed to the given function.
+
"""
if self._start_method in ("fork", "forkserver"):
_check_bad_cuda_fork()
@@ -143,6 +145,7 @@ class _GlobalStateSnapshot:
# in worker process
snapshot.restore()
+
"""
use_deterministic_algorithms: bool
@@ -152,8 +155,7 @@ class _GlobalStateSnapshot:
@classmethod
def capture(cls) -> "_GlobalStateSnapshot":
- """Capture a few global states from torch, numpy, etc., that we want to restore in a spawned worker
- process."""
+ """Capture a few global states from torch, numpy, etc., that we want to restore in a spawned worker process."""
return cls(
use_deterministic_algorithms=torch.are_deterministic_algorithms_enabled(),
use_deterministic_algorithms_warn_only=torch.is_deterministic_algorithms_warn_only_enabled(),
@@ -175,6 +177,7 @@ def _check_bad_cuda_fork() -> None:
The error message replaces PyTorch's 'Cannot re-initialize CUDA in forked subprocess' with helpful advice for
Lightning users.
+
"""
if not torch.cuda.is_initialized():
return
diff --git a/src/lightning/fabric/strategies/launchers/subprocess_script.py b/src/lightning/fabric/strategies/launchers/subprocess_script.py
index 27d23b859f9ef..f171b3435ef59 100644
--- a/src/lightning/fabric/strategies/launchers/subprocess_script.py
+++ b/src/lightning/fabric/strategies/launchers/subprocess_script.py
@@ -65,6 +65,7 @@ class _SubprocessScriptLauncher(_Launcher):
cluster_environment: A cluster environment that provides access to world size, node rank, etc.
num_processes: The number of processes to launch in the current node.
num_nodes: The total number of nodes that participate in this process group.
+
"""
def __init__(
@@ -91,6 +92,7 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
It is up to the implementation of this function to synchronize the processes, e.g., with barriers.
*args: Optional positional arguments to be passed to the given function.
**kwargs: Optional keyword arguments to be passed to the given function.
+
"""
if not self.cluster_environment.creates_processes_externally:
self._call_children_scripts()
diff --git a/src/lightning/fabric/strategies/launchers/xla.py b/src/lightning/fabric/strategies/launchers/xla.py
index a538886018b74..d4a17256dbd4c 100644
--- a/src/lightning/fabric/strategies/launchers/xla.py
+++ b/src/lightning/fabric/strategies/launchers/xla.py
@@ -27,8 +27,8 @@
class _XLALauncher(_Launcher):
- r"""Launches processes that run a given function in parallel on XLA supported hardware, and joins them all at
- the end.
+ r"""Launches processes that run a given function in parallel on XLA supported hardware, and joins them all at the
+ end.
The main process in which this launcher is invoked creates N so-called worker processes (using the
`torch_xla` :func:`xmp.spawn`) that run the given function.
@@ -40,6 +40,7 @@ class _XLALauncher(_Launcher):
Args:
strategy: A reference to the strategy that is used together with this launcher
+
"""
def __init__(self, strategy: Union["XLAStrategy", "XLAFSDPStrategy"]) -> None:
@@ -62,6 +63,7 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
function: The entry point for all launched processes.
*args: Optional positional arguments to be passed to the given function.
**kwargs: Optional keyword arguments to be passed to the given function.
+
"""
using_pjrt = _using_pjrt()
return_queue: Union[queue.Queue, mp.SimpleQueue]
diff --git a/src/lightning/fabric/strategies/registry.py b/src/lightning/fabric/strategies/registry.py
index 92c0417062d10..7956f7a95e7e7 100644
--- a/src/lightning/fabric/strategies/registry.py
+++ b/src/lightning/fabric/strategies/registry.py
@@ -40,6 +40,7 @@ def __init__(self, a, b):
or
StrategyRegistry.register("lightning", LightningStrategy, description="Super fast", a=1, b=True)
+
"""
def register(
@@ -58,6 +59,7 @@ def register(
description : strategy description
override : overrides the registered strategy, if True
init_params: parameters to initialize the strategy
+
"""
if not (name is None or isinstance(name, str)):
raise TypeError(f"`name` must be a str, found {name}")
@@ -86,6 +88,7 @@ def get(self, name: str, default: Optional[Any] = None) -> Any:
Args:
name (str): the name that identifies a strategy, e.g. "deepspeed_stage_3"
+
"""
if name in self:
data = self[name]
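For the `register`/`get` contract documented in the hunk above, a standalone sketch of the same idea (a toy dict-based registry, not the actual registry implementation)::

    from typing import Any, Callable, Optional


    class ToyRegistry(dict):
        # Toy stand-in mirroring the documented contract: register() stores the
        # class plus init parameters, get() instantiates the class on lookup.

        def register(self, name: str, cls: Callable, description: str = "", **init_params: Any) -> None:
            self[name] = {"cls": cls, "description": description, "init_params": init_params}

        def get(self, name: str, default: Optional[Any] = None) -> Any:
            if name in self:
                data = self[name]
                return data["cls"](**data["init_params"])
            return default


    registry = ToyRegistry()
    registry.register("example", dict, description="illustrative entry", a=1, b=True)
    assert registry.get("example") == {"a": 1, "b": True}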
diff --git a/src/lightning/fabric/strategies/single_device.py b/src/lightning/fabric/strategies/single_device.py
index 59ccf0810cd8f..3edda3faf7dc1 100644
--- a/src/lightning/fabric/strategies/single_device.py
+++ b/src/lightning/fabric/strategies/single_device.py
@@ -56,8 +56,8 @@ def module_to_device(self, module: Module) -> None:
module.to(self.root_device)
def all_reduce(self, tensor: Any | Tensor, *args: Any, **kwargs: Any) -> Any | Tensor:
- """Reduces a tensor from several distributed processes to one aggregated tensor. As this plugin only
- operates with a single device, the reduction is simply the identity.
+ """Reduces a tensor from several distributed processes to one aggregated tensor. As this plugin only operates
+ with a single device, the reduction is simply the identity.
Args:
tensor: the tensor to sync and reduce
@@ -66,6 +66,7 @@ def all_reduce(self, tensor: Any | Tensor, *args: Any, **kwargs: Any) -> Any | T
Return:
the unmodified input as reduction is not needed for single process operation
+
"""
return tensor
diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py
index 3e1913a6c1ad2..3aaa8db0daca0 100644
--- a/src/lightning/fabric/strategies/strategy.py
+++ b/src/lightning/fabric/strategies/strategy.py
@@ -102,6 +102,7 @@ def setup_environment(self) -> None:
This must be called by the framework at the beginning of every process, before any distributed communication
takes place.
+
"""
assert self.accelerator is not None
self.accelerator.setup_device(self.root_device)
@@ -111,6 +112,7 @@ def process_dataloader(self, dataloader: DataLoader) -> DataLoader:
Args:
dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader`
+
"""
return dataloader
@@ -131,6 +133,7 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> Generator:
Args:
empty_init: Whether to initialize the model with empty weights (uninitialized memory).
If ``None``, the strategy will decide. Some strategies may not support all options.
+
"""
empty_init_context = _EmptyInit(enabled=bool(empty_init)) if _TORCH_GREATER_EQUAL_1_13 else nullcontext()
with empty_init_context, self.tensor_init_context():
@@ -143,6 +146,7 @@ def setup_module_and_optimizers(
The returned objects are expected to be in the same order they were passed in. The default implementation will
call :meth:`setup_module` and :meth:`setup_optimizer` on the inputs.
+
"""
module = self.setup_module(module)
optimizers = [self.setup_optimizer(optimizer) for optimizer in optimizers]
@@ -169,6 +173,7 @@ def batch_to_device(self, batch: Any, device: Optional[torch.device] = None) ->
Args:
batch: The batch of samples to move to the correct device
device: The target device
+
"""
device = device or self.root_device
return move_data_to_device(batch, device)
@@ -189,6 +194,7 @@ def optimizer_step(
Args:
optimizer: the optimizer performing the step
**kwargs: Any extra arguments to ``optimizer.step``
+
"""
return self.precision.optimizer_step(optimizer, **kwargs)
@@ -200,6 +206,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
tensor: the tensor to all_gather
group: the process group to gather results from
sync_grads: flag that allows users to synchronize gradients for all_gather op
+
"""
@abstractmethod
@@ -216,6 +223,7 @@ def all_reduce(
group: the process group to reduce
reduce_op: the reduction operation. Defaults to 'mean'.
Can also be a string 'sum' or ReduceOp.
+
"""
@abstractmethod
@@ -224,6 +232,7 @@ def barrier(self, name: Optional[str] = None) -> None:
Args:
name: an optional name to pass into barrier.
+
"""
@abstractmethod
@@ -233,6 +242,7 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
Args:
obj: the object to broadcast
src: source rank
+
"""
def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
@@ -256,6 +266,7 @@ def save_checkpoint(
filter: An optional dictionary containing filter callables that return a boolean indicating whether the
given item should be saved (``True``) or filtered out (``False``). Each filter key should match a
state key, where its filter will be applied to the ``state_dict`` generated.
+
"""
state = self._convert_stateful_objects_in_state(state, filter=filter or {})
if self.is_global_zero:
@@ -275,6 +286,7 @@ def get_optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
"""Returns state of an optimizer.
Allows for syncing/collating optimizer state from processes in custom plugins.
+
"""
if hasattr(optimizer, "consolidate_state_dict"):
# there are optimizers like PyTorch's ZeroRedundancyOptimizer that shard their
@@ -307,6 +319,7 @@ def load_checkpoint(
Returns:
The remaining items that were not restored into the given state dictionary. If no state dictionary is
given, the full checkpoint will be returned.
+
"""
torch.cuda.empty_cache()
checkpoint = self.checkpoint_io.load_checkpoint(path)
@@ -338,6 +351,7 @@ def teardown(self) -> None:
"""This method is called to teardown the training process.
It is the right place to release memory and free other resources.
+
"""
self.precision.teardown()
assert self.accelerator is not None
@@ -397,6 +411,7 @@ class _BackwardSyncControl(ABC):
The most common use-case is gradient accumulation. If a :class:`Strategy` implements this interface, the user can
implement their gradient accumulation loop very efficiently by disabling redundant gradient synchronization.
+
"""
@contextmanager
@@ -405,20 +420,21 @@ def no_backward_sync(self, module: Module) -> Generator:
"""Blocks the synchronization of gradients during the backward pass.
This is a context manager. It is only effective if it wraps a call to `.backward()`.
+
"""
class _Sharded(ABC):
- """Mixin-interface for any :class:`Strategy` that wants to expose functionality for sharding model
- parameters."""
+ """Mixin-interface for any :class:`Strategy` that wants to expose functionality for sharding model parameters."""
@abstractmethod
@contextmanager
def module_sharded_context(self) -> Generator:
- """A context manager that goes over the instantiation of an :class:`torch.nn.Module` and handles sharding
- of parameters on creation.
+ """A context manager that goes over the instantiation of an :class:`torch.nn.Module` and handles sharding of
+ parameters on creation.
By sharding layers directly on instantiation, one can reduce peak memory usage and initialization time.
+
"""
yield
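The `_BackwardSyncControl` interface above is what makes gradient accumulation cheap; the user-facing entry point is `Fabric.no_backward_sync`. A minimal sketch on a single CPU process (the model, loss, and batch sizes are arbitrary; on a single device the context is simply a no-op)::

    import torch
    from lightning.fabric import Fabric

    fabric = Fabric(accelerator="cpu", devices=1)
    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    model, optimizer = fabric.setup(model, optimizer)

    accumulate_grad_batches = 4
    for i in range(16):
        batch = torch.randn(8, 4)
        is_accumulating = (i + 1) % accumulate_grad_batches != 0
        # Skip the expensive gradient all-reduce while accumulating.
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            loss = model(batch).sum()  # illustrative loss
            fabric.backward(loss)
        if not is_accumulating:
            optimizer.step()
            optimizer.zero_grad()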
diff --git a/src/lightning/fabric/strategies/xla.py b/src/lightning/fabric/strategies/xla.py
index f5050889a58c5..955b8bfda08c5 100644
--- a/src/lightning/fabric/strategies/xla.py
+++ b/src/lightning/fabric/strategies/xla.py
@@ -151,6 +151,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
sync_grads: flag that allows users to synchronize gradients for the all-gather operation.
Return:
A tensor of shape (world_size, ...)
+
"""
if not self._launched:
return tensor
@@ -246,6 +247,7 @@ def save_checkpoint(
storage_options: Additional options for the ``CheckpointIO`` plugin
filter: An optional dictionary of the same format as ``state`` mapping keys to callables that return a
boolean indicating whether the given parameter should be saved (``True``) or filtered out (``False``).
+
"""
import torch_xla.core.xla_model as xm
diff --git a/src/lightning/fabric/strategies/xla_fsdp.py b/src/lightning/fabric/strategies/xla_fsdp.py
index 81e269a794019..e59c9843e09ff 100644
--- a/src/lightning/fabric/strategies/xla_fsdp.py
+++ b/src/lightning/fabric/strategies/xla_fsdp.py
@@ -188,6 +188,7 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
This setup method doesn't modify the optimizer or wrap the optimizer. The only thing it currently does is verify
that the optimizer was created after the model was wrapped with :meth:`setup_module` with a reference to the
flattened parameters.
+
"""
if _TORCH_GREATER_EQUAL_2_0:
return optimizer
@@ -210,12 +211,13 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
)
def optimizer_step(self, optimizer: Optimizable, **kwargs: Any) -> Any:
- """Overrides default tpu optimizer_step since FSDP should not call
- `torch_xla.core.xla_model.optimizer_step`. Performs the actual optimizer step.
+ """Overrides default tpu optimizer_step since FSDP should not call `torch_xla.core.xla_model.optimizer_step`.
+ Performs the actual optimizer step.
Args:
optimizer: the optimizer performing the step
**kwargs: Any extra arguments to ``optimizer.step``
+
"""
loss = optimizer.step(**kwargs)
import torch_xla.core.xla_model as xm
@@ -251,6 +253,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
sync_grads: flag that allows users to synchronize gradients for the all-gather operation.
Return:
A tensor of shape (world_size, ...)
+
"""
if not self._launched:
return tensor
@@ -342,6 +345,7 @@ def save_checkpoint(
If the user specifies sharded checkpointing, the directory will contain one file per process, with model- and
optimizer shards stored per file. If the user specifies full checkpointing, the directory will contain a
consolidated checkpoint combining all of the sharded checkpoints.
+
"""
if not _TORCH_GREATER_EQUAL_2_0:
raise NotImplementedError(
@@ -421,6 +425,7 @@ def load_checkpoint(
The strategy currently only supports saving and loading sharded checkpoints which are stored in form of a
directory of multiple files rather than a single file.
+
"""
if not _TORCH_GREATER_EQUAL_2_0:
raise NotImplementedError(
diff --git a/src/lightning/fabric/utilities/apply_func.py b/src/lightning/fabric/utilities/apply_func.py
index 1feedef96e18f..33231ccd19f99 100644
--- a/src/lightning/fabric/utilities/apply_func.py
+++ b/src/lightning/fabric/utilities/apply_func.py
@@ -56,6 +56,7 @@ class _TransferableDataType(ABC):
... return self
>>> isinstance(CustomObject(), _TransferableDataType)
True
+
"""
@classmethod
@@ -113,6 +114,7 @@ def convert_tensors_to_scalars(data: Any) -> Any:
Raises:
ValueError:
If tensors inside ``metrics`` contains multiple elements, hence preventing conversion to a scalar.
+
"""
def to_item(value: Tensor) -> Union[int, float, bool]:
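A short usage sketch of `convert_tensors_to_scalars` as documented above, imported from the module this hunk touches::

    import torch
    from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars

    metrics = {"loss": torch.tensor(0.25), "epoch": 3}
    print(convert_tensors_to_scalars(metrics))  # -> {'loss': 0.25, 'epoch': 3}
    # Tensors with more than one element cannot be collapsed and raise a ValueError.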
diff --git a/src/lightning/fabric/utilities/cloud_io.py b/src/lightning/fabric/utilities/cloud_io.py
index 17d5d33c7e060..4979e5db7c7f5 100644
--- a/src/lightning/fabric/utilities/cloud_io.py
+++ b/src/lightning/fabric/utilities/cloud_io.py
@@ -34,6 +34,7 @@ def _load(
Args:
path_or_url: Path or URL of the checkpoint.
map_location: a function, ``torch.device``, string or a dict specifying how to remap storage locations.
+
"""
if not isinstance(path_or_url, (str, Path)):
# any sort of BytesIO or similar
@@ -65,6 +66,7 @@ def _atomic_save(checkpoint: Dict[str, Any], filepath: Union[str, Path]) -> None
accepts.
filepath: The path to which the checkpoint will be saved.
This points to the file that the checkpoint will be stored in.
+
"""
bytesbuffer = io.BytesIO()
torch.save(checkpoint, bytesbuffer)
@@ -107,6 +109,7 @@ def _is_dir(fs: AbstractFileSystem, path: Union[str, Path], strict: bool = False
strict: A flag specific to Object Storage platforms. If set to ``False``, any non-existing path is considered
as a valid directory-like path. In such cases, the directory (and any non-existing parent directories)
will be created on the fly. Defaults to False.
+
"""
# Object storage fsspec's are inconsistent with other file systems because they do not have real directories,
# see for instance https://gcsfs.readthedocs.io/en/latest/api.html?highlight=makedirs#gcsfs.core.GCSFileSystem.mkdir
diff --git a/src/lightning/fabric/utilities/data.py b/src/lightning/fabric/utilities/data.py
index 7b30e0944ad14..e6e53034b92ef 100644
--- a/src/lightning/fabric/utilities/data.py
+++ b/src/lightning/fabric/utilities/data.py
@@ -175,8 +175,8 @@ def _dataloader_init_kwargs_resolve_sampler(
sampler: Union[Sampler, Iterable],
disallow_batch_sampler: bool = False,
) -> Dict[str, Any]:
- """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its
- re-instantiation."""
+ """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its re-
+ instantiation."""
batch_sampler = getattr(dataloader, "batch_sampler")
if batch_sampler is not None:
@@ -362,6 +362,7 @@ def _replace_dunder_methods(base_cls: Type, store_explicit_arg: Optional[str] =
"""This context manager is used to add support for re-instantiation of custom (subclasses) of `base_cls`.
It patches the ``__init__``, ``__setattr__`` and ``__delattr__`` methods.
+
"""
classes = get_all_subclasses(base_cls) | {base_cls}
for cls in classes:
@@ -399,6 +400,7 @@ def _replace_value_in_saved_args(
"""Tries to replace an argument value in a saved list of args and kwargs.
Returns a tuple indicating success of the operation and modified saved args and kwargs
+
"""
if replace_key in arg_names:
diff --git a/src/lightning/fabric/utilities/device_dtype_mixin.py b/src/lightning/fabric/utilities/device_dtype_mixin.py
index 40a171134e849..cb5590c098cf0 100644
--- a/src/lightning/fabric/utilities/device_dtype_mixin.py
+++ b/src/lightning/fabric/utilities/device_dtype_mixin.py
@@ -55,8 +55,8 @@ def to(self, *args: Any, **kwargs: Any) -> Self:
def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self:
"""Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers
- different objects. So it should be called before constructing optimizer if the module will live on GPU
- while being optimized.
+ different objects. So it should be called before constructing optimizer if the module will live on GPU while
+ being optimized.
Arguments:
device: If specified, all parameters will be copied to that device. If `None`, the current CUDA device
@@ -64,6 +64,7 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self:
Returns:
Module: self
+
"""
if device is None:
device = torch.device("cuda", torch.cuda.current_device())
diff --git a/src/lightning/fabric/utilities/device_parser.py b/src/lightning/fabric/utilities/device_parser.py
index 65e363cb06d65..2aa8872e87812 100644
--- a/src/lightning/fabric/utilities/device_parser.py
+++ b/src/lightning/fabric/utilities/device_parser.py
@@ -113,8 +113,8 @@ def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[in
def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: bool = False) -> List[int]:
- """Checks that each of the GPUs in the list is actually available. Raises a MisconfigurationException if any of
- the GPUs is not available.
+ """Checks that each of the GPUs in the list is actually available. Raises a MisconfigurationException if any of the
+ GPUs is not available.
Args:
gpus: List of ints corresponding to GPU indices
@@ -125,6 +125,7 @@ def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps:
Raises:
MisconfigurationException:
If machine has fewer available GPUs than requested.
+
"""
if sum((include_cuda, include_mps)) == 0:
raise ValueError("At least one gpu type should be specified!")
@@ -172,6 +173,7 @@ def _check_unique(device_ids: List[int]) -> None:
Raises:
MisconfigurationException:
If ``device_ids`` of GPUs aren't unique
+
"""
if len(device_ids) != len(set(device_ids)):
raise MisconfigurationException("Device ID's (GPU) must be unique.")
@@ -186,6 +188,7 @@ def _check_data_type(device_ids: object) -> None:
Raises:
TypeError:
If ``device_ids`` of GPU/TPUs aren't ``int``, ``str`` or sequence of ``int```
+
"""
msg = "Device IDs (GPU/TPU) must be an int, a string, a sequence of ints, but you passed"
if device_ids is None:
diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py
index 53615664512b1..c7f52161c4c38 100644
--- a/src/lightning/fabric/utilities/distributed.py
+++ b/src/lightning/fabric/utilities/distributed.py
@@ -39,6 +39,7 @@ def _gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Ten
Return:
gathered_result: List with size equal to the process group where
gathered_result[i] corresponds to result tensor from process i
+
"""
if group is None:
group = torch.distributed.group.WORLD
@@ -98,6 +99,7 @@ def _sync_ddp_if_available(
Return:
reduced value
+
"""
if torch.distributed.is_initialized():
return _sync_ddp(result, group=group, reduce_op=reduce_op)
@@ -115,6 +117,7 @@ def _sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[U
Return:
reduced value
+
"""
divide_by_world_size = False
@@ -197,6 +200,7 @@ def _all_gather_ddp_if_available(
Return:
A tensor of shape (world_size, batch, ...)
+
"""
if not torch.distributed.is_initialized():
return tensor
@@ -213,8 +217,8 @@ def _init_dist_connection(
world_size: Optional[int] = None,
**kwargs: Any,
) -> None:
- """Utility function to initialize distributed connection by setting env variables and initializing the
- distributed process group.
+ """Utility function to initialize distributed connection by setting env variables and initializing the distributed
+ process group.
Args:
cluster_environment: ``ClusterEnvironment`` instance
@@ -226,6 +230,7 @@ def _init_dist_connection(
Raises:
RuntimeError:
If ``torch.distributed`` is not available
+
"""
if not torch.distributed.is_available():
raise RuntimeError("torch.distributed is not available. Cannot initialize distributed process group")
diff --git a/src/lightning/fabric/utilities/init.py b/src/lightning/fabric/utilities/init.py
index 52cf496cf023f..2031c74c1d3f6 100644
--- a/src/lightning/fabric/utilities/init.py
+++ b/src/lightning/fabric/utilities/init.py
@@ -30,6 +30,7 @@ class _EmptyInit(TorchFunctionMode):
with _EmptyInit():
model = BigModel()
model.load_state_dict(torch.load("checkpoint.pt"))
+
"""
def __init__(self, enabled: bool = True) -> None:
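`_EmptyInit` above is internal; the corresponding public entry point is `Fabric.init_module`, which wraps the strategy's `module_init_context` documented earlier. A minimal sketch, assuming a checkpoint exists at the hypothetical path ``checkpoint.pt``::

    import torch
    from lightning.fabric import Fabric

    fabric = Fabric(accelerator="cpu", devices=1)
    with fabric.init_module(empty_init=True):
        model = torch.nn.Linear(512, 512)  # weights left uninitialized, memory only allocated
    # Loading a full state dict afterwards makes the skipped initialization irrelevant.
    model.load_state_dict(torch.load("checkpoint.pt"))  # "checkpoint.pt" is a hypothetical path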
diff --git a/src/lightning/fabric/utilities/logger.py b/src/lightning/fabric/utilities/logger.py
index db726ede05522..c3874262ca17e 100644
--- a/src/lightning/fabric/utilities/logger.py
+++ b/src/lightning/fabric/utilities/logger.py
@@ -27,6 +27,7 @@ def _convert_params(params: Optional[Union[Dict[str, Any], Namespace]]) -> Dict[
Returns:
params as a dictionary
+
"""
# in case converting from namespace
if isinstance(params, Namespace):
@@ -46,6 +47,7 @@ def _sanitize_callable_params(params: Dict[str, Any]) -> Dict[str, Any]:
Returns:
dictionary with all callables sanitized
+
"""
def _sanitize_callable(val: Any) -> Any:
@@ -81,6 +83,7 @@ def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent
{'a/b': 123}
>>> _flatten_dict({5: {'a': 123}})
{'5/a': 123}
+
"""
result: Dict[str, Any] = {}
for k, v in params.items():
@@ -114,6 +117,7 @@ def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
'list': '[1, 2, 3]',
'namespace': 'Namespace(foo=3)',
'string': 'abc'}
+
"""
for k in params:
# convert relevant np scalars to python types first (instead of str)
@@ -136,6 +140,7 @@ def _add_prefix(
Returns:
Dictionary with prefix and separator inserted before each key
+
"""
if not prefix:
return metrics
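For the `_add_prefix` behavior described above ("prefix and separator inserted before each key"), a standalone sketch rather than an import of the private helper::

    from typing import Any, Dict


    def add_prefix(metrics: Dict[str, Any], prefix: str, separator: str = "/") -> Dict[str, Any]:
        # Standalone sketch of the documented behavior: an empty prefix means no change.
        if not prefix:
            return metrics
        return {f"{prefix}{separator}{key}": value for key, value in metrics.items()}


    assert add_prefix({"loss": 0.1, "acc": 0.9}, "train") == {"train/loss": 0.1, "train/acc": 0.9}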
diff --git a/src/lightning/fabric/utilities/registry.py b/src/lightning/fabric/utilities/registry.py
index 609101fbb949b..4c1a1f7bb0ede 100644
--- a/src/lightning/fabric/utilities/registry.py
+++ b/src/lightning/fabric/utilities/registry.py
@@ -42,6 +42,7 @@ def _load_external_callbacks(group: str) -> List[Any]:
Return:
A list of all callbacks collected from external factories.
+
"""
if _PYTHON_GREATER_EQUAL_3_8_0:
from importlib.metadata import entry_points
diff --git a/src/lightning/fabric/utilities/seed.py b/src/lightning/fabric/utilities/seed.py
index c3c6852a76697..425db5ec354ff 100644
--- a/src/lightning/fabric/utilities/seed.py
+++ b/src/lightning/fabric/utilities/seed.py
@@ -17,8 +17,8 @@
def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
- """Function that sets seed for pseudo-random number generators in: pytorch, numpy, python.random In addition,
- sets the following environment variables:
+ """Function that sets seed for pseudo-random number generators in: pytorch, numpy, python.random In addition, sets
+ the following environment variables:
- `PL_GLOBAL_SEED`: will be passed to spawned subprocesses (e.g. ddp_spawn backend).
- `PL_SEED_WORKERS`: (optional) is set to 1 if ``workers=True``.
@@ -31,6 +31,7 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
Trainer with a ``worker_init_fn``. If the user already provides such a function
for their dataloaders, setting this argument will have no influence. See also:
:func:`~lightning.fabric.utilities.seed.pl_worker_init_function`.
+
"""
if seed is None:
env_seed = os.environ.get("PL_GLOBAL_SEED")
@@ -70,6 +71,7 @@ def reset_seed() -> None:
"""Reset the seed to the value that :func:`lightning.fabric.utilities.seed.seed_everything` previously set.
If :func:`lightning.fabric.utilities.seed.seed_everything` is unused, this function will do nothing.
+
"""
seed = os.environ.get("PL_GLOBAL_SEED", None)
if seed is None:
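A minimal usage sketch of the two seeding helpers touched in this hunk::

    from lightning.fabric.utilities.seed import reset_seed, seed_everything

    seed_everything(42, workers=True)  # also sets PL_GLOBAL_SEED=42 and PL_SEED_WORKERS=1
    # ... code that consumes randomness (sanity checks, tuning sweeps, ...) ...
    reset_seed()  # re-applies the seed that seed_everything() stored in PL_GLOBAL_SEED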
diff --git a/src/lightning/fabric/utilities/spike.py b/src/lightning/fabric/utilities/spike.py
index 3a118eb56df33..0d840d5ec1ead 100644
--- a/src/lightning/fabric/utilities/spike.py
+++ b/src/lightning/fabric/utilities/spike.py
@@ -35,6 +35,7 @@ class SpikeDetection:
exclude_batches_path: Where to save the file that contains the batches to exclude.
Will default to current directory.
finite_only: If set to ``False``, consider non-finite values like NaN, inf and -inf a spike as well.
+
"""
def __init__(
diff --git a/src/lightning/fabric/utilities/testing/_runif.py b/src/lightning/fabric/utilities/testing/_runif.py
index de940810a5f4c..906e9019fb739 100644
--- a/src/lightning/fabric/utilities/testing/_runif.py
+++ b/src/lightning/fabric/utilities/testing/_runif.py
@@ -57,6 +57,7 @@ def _runif_reasons(
This requires that the ``PL_RUN_STANDALONE_TESTS=1`` environment variable is set.
deepspeed: Require that microsoft/DeepSpeed is installed.
dynamo: Require that `torch.dynamo` is supported.
+
"""
reasons = []
kwargs = {} # used in conftest.py::pytest_collection_modifyitems
diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py
index 861acff348252..28732f9264876 100644
--- a/src/lightning/fabric/wrappers.py
+++ b/src/lightning/fabric/wrappers.py
@@ -39,14 +39,15 @@
class _FabricOptimizer:
def __init__(self, optimizer: Optimizer, strategy: Strategy, callbacks: Optional[List[Callable]] = None) -> None:
- """FabricOptimizer is a thin wrapper around the :class:`~torch.optim.Optimizer` that delegates the
- optimizer step calls to the strategy.
+ """FabricOptimizer is a thin wrapper around the :class:`~torch.optim.Optimizer` that delegates the optimizer
+ step calls to the strategy.
The underlying wrapped optimizer object can be accessed via the property :attr:`optimizer`.
Args:
optimizer: The optimizer to wrap
strategy: Reference to the strategy for handling the optimizer step
+
"""
# `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would
# not want to call on destruction of the `_FabricOptimizer
@@ -96,6 +97,7 @@ def __init__(
original_module: The original, unmodified module as passed into the
:meth:`lightning.fabric.fabric.Fabric.setup` method. This is needed when attribute lookup
on this wrapper should pass through to the original module.
+
"""
super().__init__()
self._forward_module = forward_module
@@ -108,8 +110,7 @@ def module(self) -> nn.Module:
return self._original_module or self._forward_module
def forward(self, *args: Any, **kwargs: Any) -> Any:
- """Casts all inputs to the right precision and handles autocast for operations in the module forward
- method."""
+ """Casts all inputs to the right precision and handles autocast for operations in the module forward method."""
args, kwargs = self._precision.convert_input((args, kwargs))
with self._precision.forward_context():
@@ -218,13 +219,14 @@ def __setattr__(self, name: str, value: Any) -> None:
class _FabricDataLoader:
def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None) -> None:
- """The FabricDataLoader is a wrapper for the :class:`~torch.utils.data.DataLoader`. It moves the data to
- the device automatically if the device is specified.
+ """The FabricDataLoader is a wrapper for the :class:`~torch.utils.data.DataLoader`. It moves the data to the
+ device automatically if the device is specified.
Args:
dataloader: The dataloader to wrap
device: The device to which the data should be moved. By default the device is `None` and no data
transfers will be made (identical behavior as :class:`~torch.utils.data.DataLoader`).
+
"""
self.__dict__.update(dataloader.__dict__)
self._dataloader = dataloader
@@ -277,6 +279,7 @@ def _unwrap_compiled(obj: Any) -> Any:
"""Removes the :class:`torch._dynamo.OptimizedModule` around the object if it is wrapped.
Use this function before instance checks against e.g. :class:`_FabricModule`.
+
"""
if not _TORCH_GREATER_EQUAL_2_0:
return obj
@@ -296,6 +299,7 @@ def is_wrapped(obj: object) -> bool:
Args:
obj: The object to test.
+
"""
obj = _unwrap_compiled(obj)
return isinstance(obj, (_FabricModule, _FabricOptimizer, _FabricDataLoader))
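A minimal sketch tying the wrappers above together: objects returned by `Fabric.setup` are wrapped, and `is_wrapped` detects that (single CPU process assumed)::

    import torch
    from lightning.fabric import Fabric
    from lightning.fabric.wrappers import is_wrapped

    fabric = Fabric(accelerator="cpu", devices=1)
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    model, optimizer = fabric.setup(model, optimizer)
    assert is_wrapped(model) and is_wrapped(optimizer)  # _FabricModule / _FabricOptimizer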
diff --git a/src/lightning/pytorch/_graveyard/tpu.py b/src/lightning/pytorch/_graveyard/tpu.py
index 602dc585345bb..dde172973543d 100644
--- a/src/lightning/pytorch/_graveyard/tpu.py
+++ b/src/lightning/pytorch/_graveyard/tpu.py
@@ -35,6 +35,7 @@ class SingleTPUStrategy(SingleDeviceXLAStrategy):
"""Legacy class.
Use :class:`~lightning.pytorch.strategies.single_xla.SingleDeviceXLAStrategy` instead.
+
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
@@ -51,6 +52,7 @@ class TPUAccelerator(XLAAccelerator):
"""Legacy class.
Use :class:`~lightning.pytorch.accelerators.xla.XLAAccelerator` instead.
+
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
@@ -64,6 +66,7 @@ class TPUPrecisionPlugin(XLAPrecisionPlugin):
"""Legacy class.
Use :class:`~lightning.pytorch.plugins.precision.xla.XLAPrecisionPlugin` instead.
+
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
@@ -78,6 +81,7 @@ class TPUBf16PrecisionPlugin(XLABf16PrecisionPlugin):
"""Legacy class.
Use :class:`~lightning.pytorch.plugins.precision.xlabf16.XLABf16PrecisionPlugin` instead.
+
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
diff --git a/src/lightning/pytorch/accelerators/accelerator.py b/src/lightning/pytorch/accelerators/accelerator.py
index 3f78b1f667c06..0490c2d86431c 100644
--- a/src/lightning/pytorch/accelerators/accelerator.py
+++ b/src/lightning/pytorch/accelerators/accelerator.py
@@ -23,6 +23,7 @@ class Accelerator(_Accelerator, ABC):
"""The Accelerator base class for Lightning PyTorch.
.. warning:: Writing your own accelerator is an :ref:`experimental ` feature.
+
"""
def setup(self, trainer: "pl.Trainer") -> None:
@@ -30,6 +31,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
Args:
trainer: the trainer instance
+
"""
def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
@@ -40,5 +42,6 @@ def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
Returns:
Dictionary of device stats
+
"""
raise NotImplementedError
diff --git a/src/lightning/pytorch/accelerators/cuda.py b/src/lightning/pytorch/accelerators/cuda.py
index 9161bc92f8d0c..b1d621ad07785 100644
--- a/src/lightning/pytorch/accelerators/cuda.py
+++ b/src/lightning/pytorch/accelerators/cuda.py
@@ -69,6 +69,7 @@ def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
Raises:
FileNotFoundError:
If nvidia-smi installation not found
+
"""
return torch.cuda.memory_stats(device)
@@ -115,6 +116,7 @@ def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]: # pragma: no-cov
Raises:
FileNotFoundError:
If nvidia-smi installation not found
+
"""
nvidia_smi_path = shutil.which("nvidia-smi")
if nvidia_smi_path is None:
diff --git a/src/lightning/pytorch/accelerators/mps.py b/src/lightning/pytorch/accelerators/mps.py
index 03ba218604128..f25ed82f16566 100644
--- a/src/lightning/pytorch/accelerators/mps.py
+++ b/src/lightning/pytorch/accelerators/mps.py
@@ -28,6 +28,7 @@ class MPSAccelerator(Accelerator):
"""Accelerator for Metal Apple Silicon GPU devices.
.. warning:: Use of this accelerator beyond import and instantiation is experimental.
+
"""
def setup_device(self, device: torch.device) -> None:
diff --git a/src/lightning/pytorch/accelerators/xla.py b/src/lightning/pytorch/accelerators/xla.py
index fe9c1261c9b39..e1ef449e79310 100644
--- a/src/lightning/pytorch/accelerators/xla.py
+++ b/src/lightning/pytorch/accelerators/xla.py
@@ -23,6 +23,7 @@ class XLAAccelerator(Accelerator, FabricXLAAccelerator):
"""Accelerator for XLA devices, normally TPUs.
.. warning:: Use of this accelerator beyond import and instantiation is experimental.
+
"""
def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
@@ -33,6 +34,7 @@ def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]:
Returns:
A dictionary mapping the metrics (free memory and peak memory) to their values.
+
"""
import torch_xla.core.xla_model as xm
diff --git a/src/lightning/pytorch/callbacks/callback.py b/src/lightning/pytorch/callbacks/callback.py
index 197433dc647cb..447b7dda9455d 100644
--- a/src/lightning/pytorch/callbacks/callback.py
+++ b/src/lightning/pytorch/callbacks/callback.py
@@ -26,6 +26,7 @@ class Callback:
r"""Abstract base class used to build new callbacks.
Subclass this class and override any of the relevant hooks
+
"""
@property
@@ -35,6 +36,7 @@ def state_key(self) -> str:
Used to store and retrieve a callback's state from the checkpoint dictionary by
``checkpoint["callbacks"][state_key]``. Implementations of a callback need to provide a unique state key if 1)
the callback has state and 2) it is desired to maintain the state of multiple instances of that callback.
+
"""
return self.__class__.__qualname__
@@ -44,11 +46,12 @@ def _legacy_state_key(self) -> Type["Callback"]:
return type(self)
def _generate_state_key(self, **kwargs: Any) -> str:
- """Formats a set of key-value pairs into a state key string with the callback class name prefixed. Useful
- for defining a :attr:`state_key`.
+ """Formats a set of key-value pairs into a state key string with the callback class name prefixed. Useful for
+ defining a :attr:`state_key`.
Args:
**kwargs: A set of key-value pairs. Must be serializable to :class:`str`.
+
"""
return f"{self.__class__.__qualname__}{repr(kwargs)}"
@@ -83,6 +86,7 @@ def on_train_batch_end(
Note:
The value ``outputs["loss"]`` here will be the normalized value w.r.t ``accumulate_grad_batches`` of the
loss returned from ``training_step``.
+
"""
def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -114,6 +118,7 @@ def on_train_epoch_end(self, trainer, pl_module):
pl_module.log("training_epoch_mean", epoch_mean)
# free up the memory
pl_module.training_step_outputs.clear()
+
"""
def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -249,6 +254,7 @@ def on_save_checkpoint(
trainer: the current :class:`~lightning.pytorch.trainer.Trainer` instance.
pl_module: the current :class:`~lightning.pytorch.core.module.LightningModule` instance.
checkpoint: the checkpoint dictionary that will be saved.
+
"""
def on_load_checkpoint(
@@ -260,6 +266,7 @@ def on_load_checkpoint(
trainer: the current :class:`~lightning.pytorch.trainer.Trainer` instance.
pl_module: the current :class:`~lightning.pytorch.core.module.LightningModule` instance.
checkpoint: the full checkpoint dictionary that got loaded by the Trainer.
+
"""
def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", loss: Tensor) -> None:
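A short sketch of the `state_key` / `_generate_state_key` pattern described above; `MetricTracker` and its `monitor` argument are illustrative, not an existing Lightning callback::

    from lightning.pytorch.callbacks import Callback


    class MetricTracker(Callback):
        # Hypothetical callback: only the state_key pattern matters here.
        def __init__(self, monitor: str) -> None:
            self.monitor = monitor

        @property
        def state_key(self) -> str:
            # Lets two instances (e.g. monitor="val_loss" and monitor="val_acc")
            # keep separate entries in checkpoint["callbacks"].
            return self._generate_state_key(monitor=self.monitor)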
diff --git a/src/lightning/pytorch/callbacks/checkpoint.py b/src/lightning/pytorch/callbacks/checkpoint.py
index 301761049be74..3f241278b74c8 100644
--- a/src/lightning/pytorch/callbacks/checkpoint.py
+++ b/src/lightning/pytorch/callbacks/checkpoint.py
@@ -6,4 +6,5 @@ class Checkpoint(Callback):
Expert users may want to subclass it when writing a custom :class:`~lightning.pytorch.callbacks.Checkpoint`
callback, so that the trainer recognizes the custom class as a checkpointing callback.
+
"""
diff --git a/src/lightning/pytorch/callbacks/early_stopping.py b/src/lightning/pytorch/callbacks/early_stopping.py
index d6996e408164a..23c5a2a5337da 100644
--- a/src/lightning/pytorch/callbacks/early_stopping.py
+++ b/src/lightning/pytorch/callbacks/early_stopping.py
@@ -14,6 +14,7 @@
r"""Early Stopping ^^^^^^^^^^^^^^
Monitor a metric and stop training when it stops improving.
+
"""
import logging
from typing import Any, Callable, Dict, Optional, Tuple
@@ -80,6 +81,7 @@ class EarlyStopping(Callback):
*monitor, mode*
Read more: :ref:`Persisting Callback State `
+
"""
mode_dict = {"min": torch.lt, "max": torch.gt}
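A minimal usage sketch for `EarlyStopping`, assuming the LightningModule logs a ``val_loss`` metric during validation::

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import EarlyStopping

    early_stop = EarlyStopping(monitor="val_loss", mode="min", patience=3)
    trainer = Trainer(callbacks=[early_stop])
    # trainer.fit(model) then stops once "val_loss" has not improved for 3 validation checks.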
diff --git a/src/lightning/pytorch/callbacks/finetuning.py b/src/lightning/pytorch/callbacks/finetuning.py
index 9d19cc8f86a94..26386551bd88b 100644
--- a/src/lightning/pytorch/callbacks/finetuning.py
+++ b/src/lightning/pytorch/callbacks/finetuning.py
@@ -74,6 +74,7 @@ class BaseFinetuning(Callback):
... optimizer=optimizer,
... train_bn=True,
... )
+
"""
def __init__(self) -> None:
@@ -106,14 +107,15 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
@staticmethod
def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]:
- """This function is used to flatten a module or an iterable of modules into a list of its leaf modules
- (modules with no children) and parent modules that have parameters directly themselves.
+ """This function is used to flatten a module or an iterable of modules into a list of its leaf modules (modules
+ with no children) and parent modules that have parameters directly themselves.
Args:
modules: A given module or an iterable of modules
Returns:
List of modules
+
"""
if isinstance(modules, ModuleDict):
modules = modules.values()
@@ -142,6 +144,7 @@ def filter_params(
requires_grad: Whether to create a generator for trainable or non-trainable parameters.
Returns:
Generator
+
"""
modules = BaseFinetuning.flatten_modules(modules)
for mod in modules:
@@ -158,6 +161,7 @@ def make_trainable(modules: Union[Module, Iterable[Union[Module, Iterable]]]) ->
Args:
modules: A given module or an iterable of modules
+
"""
modules = BaseFinetuning.flatten_modules(modules)
for module in modules:
@@ -173,6 +177,7 @@ def freeze_module(module: Module) -> None:
Args:
module: A given module
+
"""
if isinstance(module, _BatchNorm):
module.track_running_stats = False
@@ -190,6 +195,7 @@ def freeze(modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn:
Returns:
None
+
"""
modules = BaseFinetuning.flatten_modules(modules)
for mod in modules:
@@ -208,6 +214,7 @@ def filter_on_optimizer(optimizer: Optimizer, params: Iterable) -> List:
Returns:
List of parameters not contained in this optimizer param groups
+
"""
out_params = []
removed_params = []
@@ -245,6 +252,7 @@ def unfreeze_and_add_param_group(
initial_denom_lr: If no lr is provided, the learning from the first param group will be used
and divided by `initial_denom_lr`.
train_bn: Whether to train the BatchNormalization layers.
+
"""
BaseFinetuning.make_trainable(modules)
params_lr = optimizer.param_groups[0]["lr"] if lr is None else float(lr)
@@ -338,6 +346,7 @@ class BackboneFinetuning(BaseFinetuning):
>>> multiplicative = lambda epoch: 1.5
>>> backbone_finetuning = BackboneFinetuning(200, multiplicative)
>>> trainer = Trainer(callbacks=[backbone_finetuning])
+
"""
def __init__(
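A sketch of a `BaseFinetuning` subclass built from the hooks documented above; it assumes the LightningModule exposes a ``backbone`` attribute and the Lightning 2.x ``finetune_function(pl_module, current_epoch, optimizer)`` signature::

    from lightning.pytorch.callbacks import BaseFinetuning


    class MilestoneFinetuning(BaseFinetuning):
        # Sketch: keep `pl_module.backbone` frozen until `unfreeze_at_epoch`.
        def __init__(self, unfreeze_at_epoch: int = 10) -> None:
            super().__init__()
            self._unfreeze_at_epoch = unfreeze_at_epoch

        def freeze_before_training(self, pl_module) -> None:
            self.freeze(pl_module.backbone)

        def finetune_function(self, pl_module, current_epoch, optimizer) -> None:
            if current_epoch == self._unfreeze_at_epoch:
                self.unfreeze_and_add_param_group(
                    modules=pl_module.backbone, optimizer=optimizer, train_bn=True
                )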
diff --git a/src/lightning/pytorch/callbacks/gradient_accumulation_scheduler.py b/src/lightning/pytorch/callbacks/gradient_accumulation_scheduler.py
index 9c0b1a741f53d..1a18454b55291 100644
--- a/src/lightning/pytorch/callbacks/gradient_accumulation_scheduler.py
+++ b/src/lightning/pytorch/callbacks/gradient_accumulation_scheduler.py
@@ -60,6 +60,7 @@ class GradientAccumulationScheduler(Callback):
# because epoch (key) should be zero-indexed.
>>> accumulator = GradientAccumulationScheduler(scheduling={4: 2})
>>> trainer = Trainer(callbacks=[accumulator])
+
"""
def __init__(self, scheduling: Dict[int, int]):
@@ -99,8 +100,7 @@ def get_accumulate_grad_batches(self, epoch: int) -> int:
return accumulate_grad_batches
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
- """Performns a configuration validation before training starts and raises errors for incompatible
- settings."""
+ """Performns a configuration validation before training starts and raises errors for incompatible settings."""
if not pl_module.automatic_optimization:
raise RuntimeError(
diff --git a/src/lightning/pytorch/callbacks/lambda_function.py b/src/lightning/pytorch/callbacks/lambda_function.py
index e062656313eab..45d7764a1b8ca 100644
--- a/src/lightning/pytorch/callbacks/lambda_function.py
+++ b/src/lightning/pytorch/callbacks/lambda_function.py
@@ -14,6 +14,7 @@
r"""Lambda Callback ^^^^^^^^^^^^^^^
Create a simple callback on the fly using lambda functions.
+
"""
from typing import Callable, Optional
@@ -32,6 +33,7 @@ class LambdaCallback(Callback):
>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import LambdaCallback
>>> trainer = Trainer(callbacks=[LambdaCallback(setup=lambda *args: print('setup'))])
+
"""
def __init__(
diff --git a/src/lightning/pytorch/callbacks/lr_monitor.py b/src/lightning/pytorch/callbacks/lr_monitor.py
index d938db61d6d09..d823cf5d52010 100644
--- a/src/lightning/pytorch/callbacks/lr_monitor.py
+++ b/src/lightning/pytorch/callbacks/lr_monitor.py
@@ -84,6 +84,7 @@ def configure_optimizer(self):
)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, ...)
return [optimizer], [lr_scheduler]
+
"""
def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = False) -> None:
@@ -95,12 +96,13 @@ def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool =
self.lrs: Dict[str, List[float]] = {}
def on_train_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
- """Called before training, determines unique names for all lr schedulers in the case of multiple of the
- same type or in the case of multiple parameter groups.
+ """Called before training, determines unique names for all lr schedulers in the case of multiple of the same
+ type or in the case of multiple parameter groups.
Raises:
MisconfigurationException:
If ``Trainer`` has no ``logger``.
+
"""
if not trainer.loggers:
raise MisconfigurationException(
diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py
index 8a75e9e137653..08518a863caa1 100644
--- a/src/lightning/pytorch/callbacks/model_checkpoint.py
+++ b/src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -569,6 +569,7 @@ def format_checkpoint_name(
>>> ckpt = ModelCheckpoint(filename='{step}')
>>> os.path.basename(ckpt.format_checkpoint_name(dict(step=0)))
'step=0.ckpt'
+
"""
filename = filename or self.filename
filename = self._format_checkpoint_name(filename, metrics, auto_insert_metric_name=self.auto_insert_metric_name)
@@ -588,6 +589,7 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> _PATH:
3. The ``Trainer``'s ``default_root_dir`` if the trainer has no loggers
The path gets extended with subdirectory "checkpoints".
+
"""
if self.dirpath is not None:
# short circuit if dirpath was passed to ModelCheckpoint
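A short sketch of `format_checkpoint_name` with a metric in the filename; the printed path is indicative and assumes the default ``auto_insert_metric_name=True``::

    from lightning.pytorch.callbacks import ModelCheckpoint

    ckpt = ModelCheckpoint(
        dirpath="checkpoints/",
        filename="{epoch}-{val_loss:.2f}",
        monitor="val_loss",
        save_top_k=3,
    )
    print(ckpt.format_checkpoint_name({"epoch": 2, "val_loss": 0.1234}))
    # expected to print something like: checkpoints/epoch=2-val_loss=0.12.ckpt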
diff --git a/src/lightning/pytorch/callbacks/model_summary.py b/src/lightning/pytorch/callbacks/model_summary.py
index 4fc788ba09226..870d5a73a2a70 100644
--- a/src/lightning/pytorch/callbacks/model_summary.py
+++ b/src/lightning/pytorch/callbacks/model_summary.py
@@ -47,6 +47,7 @@ class ModelSummary(Callback):
>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import ModelSummary
>>> trainer = Trainer(callbacks=[ModelSummary(max_depth=1)])
+
"""
def __init__(self, max_depth: int = 1, **summarize_kwargs: Any) -> None:
diff --git a/src/lightning/pytorch/callbacks/on_exception_checkpoint.py b/src/lightning/pytorch/callbacks/on_exception_checkpoint.py
index 760e774a256e8..0b5a953cd7bf8 100644
--- a/src/lightning/pytorch/callbacks/on_exception_checkpoint.py
+++ b/src/lightning/pytorch/callbacks/on_exception_checkpoint.py
@@ -41,6 +41,7 @@ class OnExceptionCheckpoint(Checkpoint):
>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import OnExceptionCheckpoint
>>> trainer = Trainer(callbacks=[OnExceptionCheckpoint(".")])
+
"""
FILE_EXTENSION = ".ckpt"
diff --git a/src/lightning/pytorch/callbacks/prediction_writer.py b/src/lightning/pytorch/callbacks/prediction_writer.py
index 0f19c771027d5..74ee0b85a7e3b 100644
--- a/src/lightning/pytorch/callbacks/prediction_writer.py
+++ b/src/lightning/pytorch/callbacks/prediction_writer.py
@@ -99,6 +99,7 @@ def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
trainer = Trainer(accelerator="gpu", strategy="ddp", devices=8, callbacks=[pred_writer])
model = BoringModel()
trainer.predict(model, return_predictions=False)
+
"""
def __init__(self, write_interval: Literal["batch", "epoch", "batch_and_epoch"] = "batch") -> None:
diff --git a/src/lightning/pytorch/callbacks/progress/progress_bar.py b/src/lightning/pytorch/callbacks/progress/progress_bar.py
index 7a2b57be17bda..b20dda1cd1ec9 100644
--- a/src/lightning/pytorch/callbacks/progress/progress_bar.py
+++ b/src/lightning/pytorch/callbacks/progress/progress_bar.py
@@ -19,9 +19,9 @@
class ProgressBar(Callback):
- r"""The base class for progress bars in Lightning. It is a :class:`~lightning.pytorch.callbacks.Callback` that
- keeps track of the batch progress in the :class:`~lightning.pytorch.trainer.trainer.Trainer`. You should
- implement your highly custom progress bars with this as the base class.
+ r"""The base class for progress bars in Lightning. It is a :class:`~lightning.pytorch.callbacks.Callback` that keeps
+ track of the batch progress in the :class:`~lightning.pytorch.trainer.trainer.Trainer`. You should implement your
+ highly custom progress bars with this as the base class.
Example::
@@ -42,6 +42,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
bar = LitProgressBar()
trainer = Trainer(callbacks=[bar])
+
"""
def __init__(self) -> None:
@@ -80,6 +81,7 @@ def total_train_batches(self) -> Union[int, float]:
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the training
dataloader is of infinite size.
+
"""
return self.trainer.num_training_batches
@@ -89,6 +91,7 @@ def total_val_batches_current_dataloader(self) -> Union[int, float]:
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the validation
dataloader is of infinite size.
+
"""
batches = self.trainer.num_sanity_val_batches if self.trainer.sanity_checking else self.trainer.num_val_batches
if isinstance(batches, list):
@@ -102,6 +105,7 @@ def total_test_batches_current_dataloader(self) -> Union[int, float]:
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the test dataloader is
of infinite size.
+
"""
batches = self.trainer.num_test_batches
if isinstance(batches, list):
@@ -115,6 +119,7 @@ def total_predict_batches_current_dataloader(self) -> Union[int, float]:
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the predict dataloader
is of infinite size.
+
"""
assert self._current_eval_dataloader_idx is not None
return self.trainer.num_predict_batches[self._current_eval_dataloader_idx]
@@ -125,6 +130,7 @@ def total_val_batches(self) -> Union[int, float]:
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the predict dataloader
is of infinite size.
+
"""
if not self.trainer.fit_loop.epoch_loop._should_check_val_epoch():
return 0
@@ -152,6 +158,7 @@ def enable(self) -> None:
The :class:`~lightning.pytorch.trainer.trainer.Trainer` will call this in e.g. pre-training
routines like the :ref:`learning rate finder `.
to temporarily enable and disable the training progress bar.
+
"""
raise NotImplementedError
@@ -167,8 +174,8 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s
def get_metrics(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
) -> Dict[str, Union[int, str, float, Dict[str, float]]]:
- r"""Combines progress bar metrics collected from the trainer with standard metrics from
- get_standard_metrics. Implement this to override the items displayed in the progress bar.
+ r"""Combines progress bar metrics collected from the trainer with standard metrics from get_standard_metrics.
+ Implement this to override the items displayed in the progress bar.
Here is an example of how to override the defaults:
@@ -182,6 +189,7 @@ def get_metrics(self, trainer, model):
Return:
Dictionary with the items to be displayed in the progress bar.
+
"""
standard_metrics = get_standard_metrics(trainer)
pbar_metrics = trainer.progress_bar_metrics
@@ -206,6 +214,7 @@ def get_standard_metrics(trainer: "pl.Trainer") -> Dict[str, Union[int, str]]:
Return:
Dictionary with the standard metrics to be displayed in the progress bar.
+
"""
items_dict: Dict[str, Union[int, str]] = {}
if trainer.loggers:
diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py
index b6934678c99d3..48aee9673e4c4 100644
--- a/src/lightning/pytorch/callbacks/progress/rich_progress.py
+++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py
@@ -191,6 +191,7 @@ class RichProgressBarTheme:
metrics: Style for the metrics
https://rich.readthedocs.io/en/stable/style.html
+
"""
description: Union[str, Style] = "white"
@@ -234,6 +235,7 @@ class RichProgressBar(ProgressBar):
PyCharm users will need to enable “emulate terminal” in output console option in
run/debug configuration to see styled output.
Reference: https://rich.readthedocs.io/en/latest/introduction.html#requirements
+
"""
def __init__(
diff --git a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py
index d938b925717cc..fc36f81e8299b 100644
--- a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py
+++ b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py
@@ -407,6 +407,7 @@ def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
"""The tqdm doesn't support inf/nan values.
We have to convert it to None.
+
"""
if x is None or math.isinf(x) or math.isnan(x):
return None
diff --git a/src/lightning/pytorch/callbacks/pruning.py b/src/lightning/pytorch/callbacks/pruning.py
index 5cffe6171707e..83e430fa1a79c 100644
--- a/src/lightning/pytorch/callbacks/pruning.py
+++ b/src/lightning/pytorch/callbacks/pruning.py
@@ -74,8 +74,8 @@ def __init__(
verbose: int = 0,
prune_on_train_epoch_end: bool = True,
) -> None:
- """Model pruning Callback, using PyTorch's prune utilities. This callback is responsible of pruning
- networks parameters during training.
+ """Model pruning Callback, using PyTorch's prune utilities. This callback is responsible of pruning networks
+ parameters during training.
To learn more about pruning with PyTorch, please take a look at
`this tutorial `_.
@@ -152,6 +152,7 @@ def __init__(
if ``pruning_norm`` is not provided when ``"ln_structured"``,
if ``pruning_fn`` is neither ``str`` nor :class:`torch.nn.utils.prune.BasePruningMethod`, or
if ``amount`` is none of ``int``, ``float`` and ``Callable``.
+
"""
self._use_global_unstructured = use_global_unstructured
@@ -235,6 +236,7 @@ def _create_pruning_fn(self, pruning_fn: str, **kwargs: Any) -> Union[Callable,
IF use_global_unstructured, pruning_fn will be resolved into its associated ``PyTorch BasePruningMethod`` ELSE,
pruning_fn will be resolved into its function counterpart from `torch.nn.utils.prune`.
+
"""
pruning_meth = (
_PYTORCH_PRUNING_METHOD[pruning_fn]
@@ -259,6 +261,7 @@ def make_pruning_permanent(self, module: nn.Module) -> None:
"""Removes pruning buffers from any pruned modules.
Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/utils/prune.py#L1118-L1122
+
"""
for _, module in module.named_modules():
for k in list(module._forward_pre_hooks):
@@ -286,6 +289,7 @@ def apply_lottery_ticket_hypothesis(self) -> None:
This function implements the step 4.
The ``resample_parameters`` argument can be used to reset the parameters with a new :math:`\theta_z \sim \mathcal{D}_\theta`
+
""" # noqa: E501
assert self._original_layers is not None
for d in self._original_layers.values():
diff --git a/src/lightning/pytorch/callbacks/rich_model_summary.py b/src/lightning/pytorch/callbacks/rich_model_summary.py
index f68c98259b112..0a5d7d286ce12 100644
--- a/src/lightning/pytorch/callbacks/rich_model_summary.py
+++ b/src/lightning/pytorch/callbacks/rich_model_summary.py
@@ -23,8 +23,8 @@
class RichModelSummary(ModelSummary):
- r"""Generates a summary of all layers in a :class:`~lightning.pytorch.core.module.LightningModule` with `rich
- text formatting `_.
+ r"""Generates a summary of all layers in a :class:`~lightning.pytorch.core.module.LightningModule` with `rich text
+ formatting `_.
Install it with pip:
@@ -56,6 +56,7 @@ class RichModelSummary(ModelSummary):
Raises:
ModuleNotFoundError:
If required `rich` package is not installed on the device.
+
"""
def __init__(self, max_depth: int = 1, **summarize_kwargs: Any) -> None:
diff --git a/src/lightning/pytorch/callbacks/stochastic_weight_avg.py b/src/lightning/pytorch/callbacks/stochastic_weight_avg.py
index dc6ae074a31b4..10a4d2fd195c6 100644
--- a/src/lightning/pytorch/callbacks/stochastic_weight_avg.py
+++ b/src/lightning/pytorch/callbacks/stochastic_weight_avg.py
@@ -89,6 +89,7 @@ def __init__(
device: if provided, the averaged model will be stored on the ``device``.
When None is provided, it will infer the `device` from ``pl_module``.
(default: ``"cpu"``)
+
"""
err_msg = "swa_epoch_start should be a >0 integer or a float between 0 and 1."
diff --git a/src/lightning/pytorch/callbacks/timer.py b/src/lightning/pytorch/callbacks/timer.py
index bb6245dbb00ea..36c0ee9719daa 100644
--- a/src/lightning/pytorch/callbacks/timer.py
+++ b/src/lightning/pytorch/callbacks/timer.py
@@ -33,8 +33,8 @@ class Interval(LightningEnum):
class Timer(Callback):
- """The Timer callback tracks the time spent in the training, validation, and test loops and interrupts the
- Trainer if the given time limit for the training loop is reached.
+ """The Timer callback tracks the time spent in the training, validation, and test loops and interrupts the Trainer
+ if the given time limit for the training loop is reached.
Args:
duration: A string in the format DD:HH:MM:SS (days, hours, minutes seconds), or a :class:`datetime.timedelta`,
@@ -69,6 +69,7 @@ class Timer(Callback):
timer.time_elapsed("train")
timer.start_time("validate")
timer.end_time("test")
+
"""
def __init__(
diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py
index 7c8c6f5afc6eb..95f105402b608 100644
--- a/src/lightning/pytorch/cli.py
+++ b/src/lightning/pytorch/cli.py
@@ -88,6 +88,7 @@ def __init__(
description: Description of the tool shown when running ``--help``.
env_prefix: Prefix for environment variables. Set ``default_env=True`` to enable env parsing.
default_env: Whether to parse environment variables.
+
"""
if not _JSONARGPARSE_SIGNATURES_AVAILABLE:
raise ModuleNotFoundError(f"{_JSONARGPARSE_SIGNATURES_AVAILABLE}")
@@ -120,6 +121,7 @@ def add_lightning_class_args(
Returns:
A list with the names of the class arguments added.
+
"""
if callable(lightning_class) and not isinstance(lightning_class, type):
lightning_class = class_from_function(lightning_class)
@@ -155,6 +157,7 @@ def add_optimizer_args(
optimizer_class: Any subclass of :class:`torch.optim.Optimizer`. Use tuple to allow subclasses.
nested_key: Name of the nested namespace to store arguments.
link_to: Dot notation of a parser key to set arguments or AUTOMATIC.
+
"""
if isinstance(optimizer_class, tuple):
assert all(issubclass(o, Optimizer) for o in optimizer_class)
@@ -180,6 +183,7 @@ def add_lr_scheduler_args(
tuple to allow subclasses.
nested_key: Name of the nested namespace to store arguments.
link_to: Dot notation of a parser key to set arguments or AUTOMATIC.
+
"""
if isinstance(lr_scheduler_class, tuple):
assert all(issubclass(o, LRSchedulerTypeTuple) for o in lr_scheduler_class)
@@ -206,6 +210,7 @@ class SaveConfigCallback(Callback):
Raises:
RuntimeError: If the config file already exists in the directory to avoid overwriting a previous run
+
"""
def __init__(
@@ -284,6 +289,7 @@ def save_config(self, trainer, pl_module, stage):
worry about ranks or race conditions. Since it only runs on rank zero, any collective call will make the
process hang waiting for a broadcast. If you need to make collective calls, implement the setup method
instead.
+
"""
@@ -306,8 +312,8 @@ def __init__(
run: bool = True,
auto_configure_optimizers: bool = True,
) -> None:
- """Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which
- are called / instantiated using a parsed configuration file and / or command line args.
+ """Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which are
+ called / instantiated using a parsed configuration file and / or command line args.
Parsing of configuration from environment variables can be enabled by setting ``parser_kwargs={"default_env":
True}``. A full configuration yaml would be parsed from ``PL_CONFIG`` if set. Individual settings are so parsed
@@ -345,6 +351,7 @@ def __init__(
``dict`` or ``jsonargparse.Namespace``.
run: Whether subcommands should be added to run a :class:`~lightning.pytorch.trainer.trainer.Trainer`
method. If set to ``False``, the trainer and model classes will be instantiated only.
+
"""
self.save_config_callback = save_config_callback
self.save_config_kwargs = save_config_kwargs or {}
@@ -450,6 +457,7 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
Args:
parser: The parser object to which arguments can be added
+
"""
@staticmethod
@@ -533,6 +541,7 @@ def instantiate_trainer(self, **kwargs: Any) -> Trainer:
Args:
kwargs: Any custom trainer arguments.
+
"""
extra_callbacks = [self._get(self.config_init, c) for c in self._parser(self.subcommand).callback_keys]
trainer_config = {**self._get(self.config_init, "trainer", default={}), **kwargs}
@@ -580,6 +589,7 @@ def configure_optimizers(
lightning_module: A reference to the model.
optimizer: The optimizer.
lr_scheduler: The learning rate scheduler (if used).
+
"""
if lr_scheduler is None:
return optimizer
@@ -591,8 +601,8 @@ def configure_optimizers(
return [optimizer], [lr_scheduler]
def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None:
- """Overrides the model's :meth:`~lightning.pytorch.core.module.LightningModule.configure_optimizers` method
- if a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC'."""
+ """Overrides the model's :meth:`~lightning.pytorch.core.module.LightningModule.configure_optimizers` method if
+ a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC'."""
if not self.auto_configure_optimizers:
return
@@ -725,6 +735,7 @@ def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -
Returns:
The instantiated class object.
+
"""
kwargs = init.get("init_args", {})
if not isinstance(args, tuple):
diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py
index b556eefc30ef2..bda6dac711d06 100644
--- a/src/lightning/pytorch/core/datamodule.py
+++ b/src/lightning/pytorch/core/datamodule.py
@@ -28,8 +28,8 @@
class LightningDataModule(DataHooks, HyperparametersMixin):
- """A DataModule standardizes the training, val, test splits, data preparation and transforms. The main
- advantage is consistent data splits, data preparation and transforms across models.
+ """A DataModule standardizes the training, val, test splits, data preparation and transforms. The main advantage is
+ consistent data splits, data preparation and transforms across models.
Example::
@@ -62,6 +62,7 @@ def teardown(self):
# clean up state after the trainer stops, delete files...
# called on every process in DDP
...
+
"""
name: Optional[str] = None
@@ -98,6 +99,7 @@ def from_datasets(
data will be loaded in the main process. Number of CPUs available. This parameter gets forwarded to the
``__init__`` if the datamodule has such a name defined in its signature.
**datamodule_kwargs: Additional parameters that get passed down to the datamodule's ``__init__``.
+
"""
def dataloader(ds: Dataset, shuffle: bool = False) -> DataLoader:
@@ -142,6 +144,7 @@ def state_dict(self) -> Dict[str, Any]:
Returns:
A dictionary containing datamodule state.
+
"""
return {}
@@ -150,6 +153,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
Args:
state_dict: the datamodule state returned by ``state_dict``.
+
"""
pass
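As a quick illustration of the hooks documented above, a minimal DataModule sketch (the class name, random dataset, and batch size are illustrative; whatever ``state_dict`` returns is saved with the checkpoint and handed back through ``load_state_dict`` on restore)::

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import lightning as L

    class RandomDataModule(L.LightningDataModule):
        def __init__(self, batch_size: int = 32):
            super().__init__()
            self.batch_size = batch_size

        def setup(self, stage: str) -> None:
            # called on every process; build the splits here
            self.train_set = TensorDataset(torch.randn(256, 32), torch.randint(0, 2, (256,)))

        def train_dataloader(self):
            return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)

        def state_dict(self):
            # included in the checkpoint ...
            return {"batch_size": self.batch_size}

        def load_state_dict(self, state_dict):
            # ... and restored from it
            self.batch_size = state_dict["batch_size"]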
diff --git a/src/lightning/pytorch/core/hooks.py b/src/lightning/pytorch/core/hooks.py
index 5c491c1aefaef..e923faa2a4f3c 100644
--- a/src/lightning/pytorch/core/hooks.py
+++ b/src/lightning/pytorch/core/hooks.py
@@ -31,12 +31,14 @@ def on_fit_start(self) -> None:
"""Called at the very beginning of fit.
If on DDP it is called on every process
+
"""
def on_fit_end(self) -> None:
"""Called at the very end of fit.
If on DDP it is called on every process
+
"""
def on_train_start(self) -> None:
@@ -71,6 +73,7 @@ def on_train_batch_start(self, batch: Any, batch_idx: int) -> Optional[int]:
Args:
batch: The batched data as it is returned by the training DataLoader.
batch_idx: the index of the batch
+
"""
def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None:
@@ -80,6 +83,7 @@ def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -
outputs: The outputs of training_step(x)
batch: The batched data as it is returned by the training DataLoader.
batch_idx: the index of the batch
+
"""
def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
@@ -89,6 +93,7 @@ def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx:
batch: The batched data as it is returned by the validation DataLoader.
batch_idx: the index of the batch
dataloader_idx: the index of the dataloader
+
"""
def on_validation_batch_end(
@@ -101,6 +106,7 @@ def on_validation_batch_end(
batch: The batched data as it is returned by the validation DataLoader.
batch_idx: the index of the batch
dataloader_idx: the index of the dataloader
+
"""
def on_test_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
@@ -110,6 +116,7 @@ def on_test_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int =
batch: The batched data as it is returned by the test DataLoader.
batch_idx: the index of the batch
dataloader_idx: the index of the dataloader
+
"""
def on_test_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
@@ -120,6 +127,7 @@ def on_test_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, da
batch: The batched data as it is returned by the test DataLoader.
batch_idx: the index of the batch
dataloader_idx: the index of the dataloader
+
"""
def on_predict_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
@@ -129,6 +137,7 @@ def on_predict_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int
batch: The batched data as it is returned by the test DataLoader.
batch_idx: the index of the batch
dataloader_idx: the index of the dataloader
+
"""
def on_predict_batch_end(self, outputs: Optional[Any], batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
@@ -139,6 +148,7 @@ def on_predict_batch_end(self, outputs: Optional[Any], batch: Any, batch_idx: in
batch: The batched data as it is returned by the prediction DataLoader.
batch_idx: the index of the batch
dataloader_idx: the index of the dataloader
+
"""
def on_validation_model_eval(self) -> None:
@@ -188,6 +198,7 @@ def on_train_epoch_end(self):
self.log("training_epoch_mean", epoch_mean)
# free up the memory
self.training_step_outputs.clear()
+
"""
def on_validation_epoch_start(self) -> None:
@@ -274,6 +285,7 @@ def configure_sharded_model(self) -> None:
"""Deprecated.
Use :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` instead.
+
"""
def configure_model(self) -> None:
@@ -286,6 +298,7 @@ def configure_model(self) -> None:
This hook is called during each of fit/val/test/predict stages in the same process, so ensure that
implementation of this hook is idempotent.
+
"""
@@ -308,8 +321,8 @@ def __init__(self) -> None:
def prepare_data(self) -> None:
"""Use this to download and prepare data. Downloading and saving data with multiple processes (distributed
- settings) will result in corrupted data. Lightning ensures this method is called only within a single
- process, so you can safely add your downloading logic within.
+ settings) will result in corrupted data. Lightning ensures this method is called only within a single process,
+ so you can safely add your downloading logic within.
.. warning:: DO NOT set state to the model (use ``setup`` instead)
since this is NOT called on every device
@@ -359,12 +372,13 @@ def __init__(self):
model.val_dataloader()
model.test_dataloader()
model.predict_dataloader()
+
"""
def setup(self, stage: str) -> None:
- """Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when
- you need to build models dynamically or adjust something about them. This hook is called on every process
- when using DDP.
+ """Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you
+ need to build models dynamically or adjust something about them. This hook is called on every process when
+ using DDP.
Args:
stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``
@@ -385,6 +399,7 @@ def prepare_data(self):
def setup(self, stage):
data = load_data(...)
self.l1 = nn.Linear(28, data.num_classes)
+
"""
def teardown(self, stage: str) -> None:
@@ -392,6 +407,7 @@ def teardown(self, stage: str) -> None:
Args:
stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``
+
"""
def train_dataloader(self) -> TRAIN_DATALOADERS:
@@ -419,6 +435,7 @@ def train_dataloader(self) -> TRAIN_DATALOADERS:
Note:
Lightning tries to add the correct sampler for distributed and arbitrary hardware.
There is no need to set it yourself.
+
"""
raise MisconfigurationException("`train_dataloader` must be implemented to be used with the Lightning Trainer")
@@ -448,6 +465,7 @@ def test_dataloader(self) -> EVAL_DATALOADERS:
Note:
If you don't need a test dataset and a :meth:`test_step`, you don't need to implement
this method.
+
"""
raise MisconfigurationException("`test_dataloader` must be implemented to be used with the Lightning Trainer")
@@ -474,6 +492,7 @@ def val_dataloader(self) -> EVAL_DATALOADERS:
Note:
If you don't need a validation dataset and a :meth:`validation_step`, you don't need to
implement this method.
+
"""
raise MisconfigurationException("`val_dataloader` must be implemented to be used with the Lightning Trainer")
@@ -494,14 +513,15 @@ def predict_dataloader(self) -> EVAL_DATALOADERS:
Return:
A :class:`torch.utils.data.DataLoader` or a sequence of them specifying prediction samples.
+
"""
raise MisconfigurationException(
"`predict_dataloader` must be implemented to be used with the Lightning Trainer"
)
def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any:
- """Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors wrapped in a custom
- data structure.
+ """Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors wrapped in a custom data
+ structure.
The data types listed below (and any arbitrary nesting of them) are supported out of the box:
@@ -548,6 +568,7 @@ def transfer_batch_to_device(self, batch, device, dataloader_idx):
See Also:
- :meth:`move_data_to_device`
- :meth:`apply_to_collection`
+
"""
return move_data_to_device(batch, device)
@@ -575,6 +596,7 @@ def on_before_batch_transfer(self, batch, dataloader_idx):
See Also:
- :meth:`on_after_batch_transfer`
- :meth:`transfer_batch_to_device`
+
"""
return batch
@@ -606,6 +628,7 @@ def on_after_batch_transfer(self, batch, dataloader_idx):
See Also:
- :meth:`on_before_batch_transfer`
- :meth:`transfer_batch_to_device`
+
"""
return batch
@@ -614,8 +637,8 @@ class CheckpointHooks:
"""Hooks to be used with Checkpointing."""
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
- r"""Called by Lightning to restore your model. If you saved something with :meth:`on_save_checkpoint` this
- is your chance to restore this.
+ r"""Called by Lightning to restore your model. If you saved something with :meth:`on_save_checkpoint` this is
+ your chance to restore this.
Args:
checkpoint: Loaded checkpoint
@@ -629,11 +652,12 @@ def on_load_checkpoint(self, checkpoint):
Note:
Lightning auto-restores global step, epoch, and train state including amp scaling.
There is no need for you to restore anything regarding training.
+
"""
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
- r"""Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want
- to save.
+ r"""Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to
+ save.
Args:
checkpoint: The full checkpoint dictionary before it gets dumped to a file.
@@ -649,4 +673,5 @@ def on_save_checkpoint(self, checkpoint):
Lightning saves all aspects of training (epoch, global step, etc...)
including amp scaling.
There is no need for you to store anything about training.
+
"""
diff --git a/src/lightning/pytorch/core/mixins/hparams_mixin.py b/src/lightning/pytorch/core/mixins/hparams_mixin.py
index d30caeda6b59c..ca6ad172e0725 100644
--- a/src/lightning/pytorch/core/mixins/hparams_mixin.py
+++ b/src/lightning/pytorch/core/mixins/hparams_mixin.py
@@ -132,11 +132,12 @@ def _to_hparams_dict(hp: Union[MutableMapping, Namespace, str]) -> Union[Mutable
@property
def hparams(self) -> Union[AttributeDict, MutableMapping]:
- """The collection of hyperparameters saved with :meth:`save_hyperparameters`. It is mutable by the user.
- For the frozen set of initial hyperparameters, use :attr:`hparams_initial`.
+ """The collection of hyperparameters saved with :meth:`save_hyperparameters`. It is mutable by the user. For
+ the frozen set of initial hyperparameters, use :attr:`hparams_initial`.
Returns:
Mutable hyperparameters dictionary
+
"""
if not hasattr(self, "_hparams"):
self._hparams = AttributeDict()
@@ -149,6 +150,7 @@ def hparams_initial(self) -> AttributeDict:
Returns:
AttributeDict: immutable initial hyperparameters
+
"""
if not hasattr(self, "_hparams_initial"):
return AttributeDict()
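For reference, a short sketch of how these two properties behave together with ``save_hyperparameters`` (class and argument names are illustrative)::

    import lightning as L

    class LitClassifier(L.LightningModule):
        def __init__(self, lr: float = 1e-3, hidden_dim: int = 128):
            super().__init__()
            # records the __init__ arguments so they appear in self.hparams
            self.save_hyperparameters()

    model = LitClassifier(lr=3e-4)
    model.hparams.lr = 1e-4        # hparams is mutable
    print(model.hparams_initial)   # the frozen initial values are unaffected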
diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index b5ab6604c3c4d..a4237a7352682 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -152,6 +152,7 @@ def optimizers(self, use_pl_optimizer: bool = True) -> MODULE_OPTIMIZERS:
Returns:
A single optimizer, or a list of optimizers in case multiple ones are present.
+
"""
if self._fabric:
opts: MODULE_OPTIMIZERS = self._fabric_optimizers
@@ -171,12 +172,12 @@ def optimizers(self, use_pl_optimizer: bool = True) -> MODULE_OPTIMIZERS:
return opts
def lr_schedulers(self) -> Union[None, List[LRSchedulerPLType], LRSchedulerPLType]:
- """Returns the learning rate scheduler(s) that are being used during training. Useful for manual
- optimization.
+ """Returns the learning rate scheduler(s) that are being used during training. Useful for manual optimization.
Returns:
A single scheduler, or a list of schedulers in case multiple ones are present, or ``None`` if no
schedulers were returned in :meth:`configure_optimizers`.
+
"""
if not self.trainer.lr_scheduler_configs:
return None
@@ -224,8 +225,8 @@ def fabric(self, fabric: Optional["lf.Fabric"]) -> None:
@property
def example_input_array(self) -> Optional[Union[Tensor, Tuple, Dict]]:
- """The example input array is a specification of what the module can consume in the :meth:`forward` method.
- The return type is interpreted as follows:
+ """The example input array is a specification of what the module can consume in the :meth:`forward` method. The
+ return type is interpreted as follows:
- Single tensor: It is assumed the model takes a single argument, i.e.,
``model.forward(model.example_input_array)``
@@ -233,6 +234,7 @@ def example_input_array(self) -> Optional[Union[Tensor, Tuple, Dict]]:
``model.forward(*model.example_input_array)``
- Dict: The input array represents named keyword arguments, i.e.,
``model.forward(**model.example_input_array)``
+
"""
return self._example_input_array
@@ -250,6 +252,7 @@ def global_step(self) -> int:
"""Total training batches seen across all epochs.
If no Trainer is attached, this property is 0.

+
"""
return self.trainer.global_step if self._trainer else 0
@@ -333,6 +336,7 @@ def print(self, *args: Any, **kwargs: Any) -> None:
def forward(self, x):
self.print(x, 'in forward')
+
"""
if self.trainer.is_global_zero:
progress_bar = self.trainer.progress_bar_callback
@@ -389,6 +393,7 @@ def log(
:class:`torchmetrics.Metric` in your model. This is found automatically if it is a model attribute.
rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which
would produce a deadlock as not all processes would perform this log call.
+
"""
if self._fabric is not None:
self._log_dict_through_fabric(dictionary={name: value}, logger=logger)
@@ -551,6 +556,7 @@ def log_dict(
but some data structures might need to explicitly provide it.
rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which
would produce a deadlock as not all processes would perform this log call.
+
"""
if self._fabric is not None:
return self._log_dict_through_fabric(dictionary=dictionary, logger=logger)
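A brief sketch of how ``log`` and ``log_dict`` are typically called from a step method (metric names and values are illustrative)::

    import torch
    import lightning as L

    class LitModel(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.cross_entropy(self.layer(x), y)
            # a single scalar, shown in the progress bar and forwarded to the logger
            self.log("train_loss", loss, prog_bar=True)
            # several values at once; the keyword arguments apply to every entry
            self.log_dict({"batch_size": float(x.size(0)), "lr": 1e-3})
            return loss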
@@ -630,6 +636,7 @@ def all_gather(
Return:
A tensor of shape (world_size, batch, ...), or if the input was a collection
the output will also be a collection with tensors of this shape.
+
"""
group = group if group is not None else torch.distributed.group.WORLD
all_gather = self.trainer.strategy.all_gather
@@ -645,6 +652,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
Return:
Your model's output
+
"""
return super().forward(*args, **kwargs)
@@ -698,12 +706,13 @@ def training_step(self, batch, batch_idx):
Note:
When ``accumulate_grad_batches`` > 1, the loss returned here will be automatically
normalized by ``accumulate_grad_batches`` internally.
+
"""
rank_zero_warn("`training_step` must be implemented to be used with the Lightning Trainer")
def validation_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
- r"""Operates on a single batch of data from the validation set. In this step you'd might generate examples
- or calculate anything of interest like accuracy.
+ r"""Operates on a single batch of data from the validation set. In this step you might generate examples or
+ calculate anything of interest like accuracy.
Args:
batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
@@ -767,6 +776,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0):
When the :meth:`validation_step` is called, the model has been put in eval mode
and PyTorch gradients have been disabled. At the end of validation,
the model goes back to training mode and gradients are enabled.
+
"""
def test_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
@@ -835,11 +845,12 @@ def test_step(self, batch, batch_idx, dataloader_idx=0):
When the :meth:`test_step` is called, the model has been put in eval mode and
PyTorch gradients have been disabled. At the end of the test epoch, the model goes back
to training mode and gradients are enabled.
+
"""
def predict_step(self, *args: Any, **kwargs: Any) -> Any:
- """Step function called during :meth:`~lightning.pytorch.trainer.trainer.Trainer.predict`. By default, it
- calls :meth:`~lightning.pytorch.core.module.LightningModule.forward`. Override to add any processing logic.
+ """Step function called during :meth:`~lightning.pytorch.trainer.trainer.Trainer.predict`. By default, it calls
+ :meth:`~lightning.pytorch.core.module.LightningModule.forward`. Override to add any processing logic.
The :meth:`~lightning.pytorch.core.module.LightningModule.predict_step` is used
to scale inference on multi-devices.
@@ -871,6 +882,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
model = MyModel()
trainer = Trainer(accelerator="gpu", devices=2)
predictions = trainer.predict(model, dm)
+
"""
# For backwards compatibility
batch = kwargs.get("batch", args[0])
@@ -897,9 +909,9 @@ def configure_callbacks(self):
return []
def configure_optimizers(self) -> Any:
- r"""Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need
- one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only
- works in the manual optimization mode.
+ r"""Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one.
+ But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in
+ the manual optimization mode.
Return:
Any of these 6 options.
@@ -994,6 +1006,7 @@ def configure_optimizers(self):
- If you use multiple optimizers, you will have to switch to 'manual optimization' mode and step them
yourself.
- If you need to control how often the optimizer steps, override the :meth:`optimizer_step` hook.
+
"""
rank_zero_warn("`configure_optimizers` must be implemented to be used with the Lightning Trainer")
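One of the six documented return formats, sketched as a method body (the optimizer and scheduler choices are illustrative)::

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
        # a dict with an optional "lr_scheduler" entry
        return {"optimizer": optimizer, "lr_scheduler": scheduler}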
@@ -1017,6 +1030,7 @@ def training_step(...):
loss: The tensor on which to compute gradients. Must have a graph attached.
*args: Additional positional arguments to be forwarded to :meth:`~torch.Tensor.backward`
**kwargs: Additional keyword arguments to be forwarded to :meth:`~torch.Tensor.backward`
+
"""
if self._fabric:
self._fabric.backward(loss, *args, **kwargs)
@@ -1025,8 +1039,8 @@ def training_step(...):
self.trainer.strategy.backward(loss, None, *args, **kwargs)
def backward(self, loss: Tensor, *args: Any, **kwargs: Any) -> None:
- """Called to perform backward on the loss returned in :meth:`training_step`. Override this hook with your
- own implementation if you need to.
+ """Called to perform backward on the loss returned in :meth:`training_step`. Override this hook with your own
+ implementation if you need to.
Args:
loss: The loss tensor returned by :meth:`training_step`. If gradient accumulation is used, the loss here
@@ -1036,6 +1050,7 @@ def backward(self, loss: Tensor, *args: Any, **kwargs: Any) -> None:
def backward(self, loss):
loss.backward()
+
"""
if self._fabric:
self._fabric.backward(loss, *args, **kwargs)
@@ -1043,13 +1058,14 @@ def backward(self, loss):
loss.backward(*args, **kwargs)
def toggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> None:
- """Makes sure only the gradients of the current optimizer's parameters are calculated in the training step
- to prevent dangling gradients in multiple-optimizer setup.
+ """Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to
+ prevent dangling gradients in multiple-optimizer setup.
It works with :meth:`untoggle_optimizer` to make sure ``param_requires_grad_state`` is properly reset.
Args:
optimizer: The optimizer to toggle.
+
"""
# Iterate over all optimizer parameters to preserve their `requires_grad` information
# in case these are pre-defined during `configure_optimizers`
@@ -1075,6 +1091,7 @@ def untoggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) ->
Args:
optimizer: The optimizer to untoggle.
+
"""
for opt in self.trainer.optimizers:
if not (opt is optimizer or (isinstance(optimizer, LightningOptimizer) and opt is optimizer.optimizer)):
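A sketch of the intended toggle/untoggle pattern under manual optimization, assuming two optimizers were returned from ``configure_optimizers``; ``generator_loss`` is a hypothetical helper::

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # manual optimization

    def training_step(self, batch, batch_idx):
        opt_g, opt_d = self.optimizers()

        # only the parameters owned by opt_g keep requires_grad=True inside this block
        self.toggle_optimizer(opt_g)
        loss_g = self.generator_loss(batch)  # hypothetical helper
        self.manual_backward(loss_g)
        opt_g.step()
        opt_g.zero_grad()
        self.untoggle_optimizer(opt_g)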
@@ -1106,6 +1123,7 @@ def clip_gradients(
gradient_clip_val: The value at which to clip gradients.
gradient_clip_algorithm: The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"``
to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm.
+
"""
if self.fabric is not None:
@@ -1177,6 +1195,7 @@ def configure_gradient_clipping(self, optimizer, gradient_clip_val, gradient_cli
gradient_clip_val=gradient_clip_val,
gradient_clip_algorithm=gradient_clip_algorithm
)
+
"""
self.clip_gradients(
optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm
@@ -1219,8 +1238,8 @@ def optimizer_step(
optimizer: Union[Optimizer, LightningOptimizer],
optimizer_closure: Optional[Callable[[], Any]] = None,
) -> None:
- r"""Override this method to adjust the default way the :class:`~lightning.pytorch.trainer.trainer.Trainer`
- calls the optimizer.
+ r"""Override this method to adjust the default way the :class:`~lightning.pytorch.trainer.trainer.Trainer` calls
+ the optimizer.
By default, Lightning calls ``step()`` and ``zero_grad()`` as shown in the example.
This method (and ``zero_grad()``) won't be called during the accumulation phase when
@@ -1249,6 +1268,7 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
lr_scale = min(1.0, float(self.trainer.global_step + 1) / 500.0)
for pg in optimizer.param_groups:
pg["lr"] = lr_scale * self.learning_rate
+
"""
optimizer.step(closure=optimizer_closure)
@@ -1281,6 +1301,7 @@ def freeze(self) -> None:
model = MyLightningModule(...)
model.freeze()
+
"""
for param in self.parameters():
param.requires_grad = False
@@ -1294,6 +1315,7 @@ def unfreeze(self) -> None:
model = MyLightningModule(...)
model.unfreeze()
+
"""
for param in self.parameters():
param.requires_grad = True
@@ -1329,6 +1351,7 @@ def forward(self, x):
model = SimpleModel()
input_sample = torch.randn(1, 64)
model.to_onnx("export.onnx", input_sample, export_params=True)
+
"""
if _TORCH_GREATER_EQUAL_2_0 and not _ONNX_AVAILABLE:
raise ModuleNotFoundError(
@@ -1533,6 +1556,7 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
"""Adds ShardedTensor state dict hooks if ShardedTensors are supported.
These hooks ensure that ShardedTensors are included when saving, and are loaded into the LightningModule correctly.
+
"""
if _TORCH_GREATER_EQUAL_2_1:
# ShardedTensor is deprecated in favor of DistributedTensor
diff --git a/src/lightning/pytorch/core/optimizer.py b/src/lightning/pytorch/core/optimizer.py
index afd2ca01a98d5..e90ff0be84a6b 100644
--- a/src/lightning/pytorch/core/optimizer.py
+++ b/src/lightning/pytorch/core/optimizer.py
@@ -34,8 +34,8 @@ def do_nothing_closure() -> None:
class LightningOptimizer:
- """This class is used to wrap the user optimizers and handle properly the backward and optimizer_step logic
- across accelerators, AMP, accumulate_grad_batches."""
+ """This class is used to wrap the user optimizers and handle properly the backward and optimizer_step logic across
+ accelerators, AMP, accumulate_grad_batches."""
def __init__(self, optimizer: Optimizer):
# copy most of the `Optimizer` methods into this instance. `__del__` is skipped in case the optimizer has
@@ -73,6 +73,7 @@ def toggle_model(self, sync_grad: bool = True) -> Generator[None, None, None]:
When performing gradient accumulation, there is no need to perform grad synchronization
during the accumulation phase.
Setting `sync_grad` to False will block this synchronization and improve performance.
+
"""
# local import here to avoid circular import
from lightning.pytorch.loops.utilities import _block_parallel_sync_behavior
@@ -144,6 +145,7 @@ def closure_dis():
with opt_dis.toggle_model(sync_grad=accumulated_grad_batches):
opt_dis.step(closure=closure_dis)
+
"""
self._on_before_step()
@@ -237,8 +239,7 @@ def _configure_optimizers(
def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> List[LRSchedulerConfig]:
- """Convert each scheduler into `LRSchedulerConfig` with relevant information, when using automatic
- optimization."""
+ """Convert each scheduler into `LRSchedulerConfig` with relevant information, when using automatic optimization."""
lr_scheduler_configs = []
for scheduler in schedulers:
if isinstance(scheduler, dict):
diff --git a/src/lightning/pytorch/core/saving.py b/src/lightning/pytorch/core/saving.py
index 9b32058a81ac9..1fb5e0d46bfe9 100644
--- a/src/lightning/pytorch/core/saving.py
+++ b/src/lightning/pytorch/core/saving.py
@@ -215,6 +215,7 @@ def update_hparams(hparams: dict, updates: dict) -> None:
Args:
hparams: the original params and also target object
updates: new params to be used as update
+
"""
for k, v in updates.items():
# if missing, add the key
@@ -240,6 +241,7 @@ def load_hparams_from_tags_csv(tags_csv: _PATH) -> Dict[str, Any]:
>>> vars(hparams) == hparams_new
True
>>> os.remove(path_csv)
+
"""
fs = get_filesystem(tags_csv)
if not fs.exists(tags_csv):
@@ -282,6 +284,7 @@ def load_hparams_from_yaml(config_yaml: _PATH, use_omegaconf: bool = True) -> Di
>>> vars(hparams) == hparams_new
True
>>> os.remove(path_yaml)
+
"""
fs = get_filesystem(config_yaml)
if not fs.exists(config_yaml):
diff --git a/src/lightning/pytorch/demos/boring_classes.py b/src/lightning/pytorch/demos/boring_classes.py
index 918e6c973734f..3dd7bd8b1afc8 100644
--- a/src/lightning/pytorch/demos/boring_classes.py
+++ b/src/lightning/pytorch/demos/boring_classes.py
@@ -105,6 +105,7 @@ class BoringModel(LightningModule):
class TestModel(BoringModel):
def training_step(self, ...):
... # do your own thing
+
"""
def __init__(self) -> None:
diff --git a/src/lightning/pytorch/demos/mnist_datamodule.py b/src/lightning/pytorch/demos/mnist_datamodule.py
index 63c36e108a4b0..5f4528a9398d3 100644
--- a/src/lightning/pytorch/demos/mnist_datamodule.py
+++ b/src/lightning/pytorch/demos/mnist_datamodule.py
@@ -149,6 +149,7 @@ class MNISTDataModule(LightningDataModule):
>>> MNISTDataModule() # doctest: +ELLIPSIS
<...mnist_datamodule.MNISTDataModule object at ...>
+
"""
name = "mnist"
diff --git a/src/lightning/pytorch/demos/transformer.py b/src/lightning/pytorch/demos/transformer.py
index c8e2d6bb88e1c..13e220759f549 100644
--- a/src/lightning/pytorch/demos/transformer.py
+++ b/src/lightning/pytorch/demos/transformer.py
@@ -2,6 +2,7 @@
Code is adapted from the PyTorch examples at
https://github.com/pytorch/examples/blob/main/word_language_model
+
"""
import math
import os
diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py
index e2a095e2e6062..27afe86730065 100644
--- a/src/lightning/pytorch/loggers/comet.py
+++ b/src/lightning/pytorch/loggers/comet.py
@@ -204,6 +204,7 @@ def __init__(self, *args, **kwarg):
If required Comet package is not installed on the device.
MisconfigurationException:
If neither ``api_key`` nor ``save_dir`` are passed as arguments.
+
"""
LOGGER_JOIN_CHAR = "-"
@@ -357,6 +358,7 @@ def save_dir(self) -> Optional[str]:
Returns:
The path to the save directory.
+
"""
return self._save_dir
@@ -366,6 +368,7 @@ def name(self) -> str:
Returns:
The project name if it is specified, else "comet-default".
+
"""
# Don't create an experiment if we don't have one
if self._experiment is not None and self._experiment.project_name is not None:
@@ -389,6 +392,7 @@ def version(self) -> str:
4. future experiment key.
If none are present, generates a new guid.
+
"""
# Don't create an experiment if we don't have one
if self._experiment is not None:
diff --git a/src/lightning/pytorch/loggers/csv_logs.py b/src/lightning/pytorch/loggers/csv_logs.py
index beeeb905c3dd6..18cd83960d512 100644
--- a/src/lightning/pytorch/loggers/csv_logs.py
+++ b/src/lightning/pytorch/loggers/csv_logs.py
@@ -45,6 +45,7 @@ class ExperimentWriter(_FabricExperimentWriter):
Args:
log_dir: Directory for the experiment logs
+
"""
NAME_HPARAMS_FILE = "hparams.yaml"
@@ -82,6 +83,7 @@ class CSVLogger(Logger, FabricCSVLogger):
directory for existing versions, then automatically assigns the next available version.
prefix: A string to put at the beginning of metric keys.
flush_logs_every_n_steps: How often to flush logs to disk (defaults to every 100 steps).
+
"""
LOGGER_JOIN_CHAR = "-"
@@ -109,6 +111,7 @@ def root_dir(self) -> str:
If the experiment name parameter is an empty string, no experiment subdirectory is used and the checkpoint will
be saved in "save_dir/version"
+
"""
return os.path.join(self.save_dir, self.name)
@@ -118,6 +121,7 @@ def log_dir(self) -> str:
By default, it is named ``'version_${self.version}'`` but it can be overridden by passing a string value for the
constructor's version parameter instead of ``None`` or an int.
+
"""
# create a pseudo standard path
version = self.version if isinstance(self.version, str) else f"version_{self.version}"
@@ -129,6 +133,7 @@ def save_dir(self) -> str:
Returns:
The path to current directory where logs are saved.
+
"""
return self._save_dir
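Typical usage, sketched with illustrative directory and experiment names::

    import lightning as L
    from lightning.pytorch.loggers import CSVLogger

    # metrics end up under <save_dir>/<name>/version_<n>/ as described by the properties above
    logger = CSVLogger(save_dir="logs/", name="my_experiment", flush_logs_every_n_steps=100)
    trainer = L.Trainer(logger=logger, max_epochs=1)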
diff --git a/src/lightning/pytorch/loggers/logger.py b/src/lightning/pytorch/loggers/logger.py
index 52a51ab8eb74a..59ff16ac99cc5 100644
--- a/src/lightning/pytorch/loggers/logger.py
+++ b/src/lightning/pytorch/loggers/logger.py
@@ -36,6 +36,7 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
Args:
checkpoint_callback: the model checkpoint callback instance
+
"""
pass
@@ -50,6 +51,7 @@ class DummyLogger(Logger):
"""Dummy logger for internal use.
It is useful if we want to disable the user's logger for a feature, but still ensure that user code can run
+
"""
def __init__(self) -> None:
@@ -96,8 +98,7 @@ def merge_dicts( # pragma: no cover
agg_key_funcs: Optional[Mapping] = None,
default_func: Callable[[Sequence[float]], float] = np.mean,
) -> Dict:
- """Merge a sequence with dictionaries into one dictionary by aggregating the same keys with some given
- function.
+ """Merge a sequence with dictionaries into one dictionary by aggregating the same keys with some given function.
Args:
dicts:
@@ -128,6 +129,7 @@ def merge_dicts( # pragma: no cover
'c': 1,
'd': {'d1': 3, 'd2': 3, 'd3': 3, 'd4': {'d5': 1}},
'v': 2.3}
+
"""
agg_key_funcs = agg_key_funcs or {}
keys = list(functools.reduce(operator.or_, [set(d.keys()) for d in dicts]))
diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py
index 56aca09013200..9834875da2de8 100644
--- a/src/lightning/pytorch/loggers/mlflow.py
+++ b/src/lightning/pytorch/loggers/mlflow.py
@@ -130,6 +130,7 @@ def any_lightning_module_function_or_hook(self):
Raises:
ModuleNotFoundError:
If required MLFlow package is not installed on the device.
+
"""
LOGGER_JOIN_CHAR = "-"
@@ -218,6 +219,7 @@ def run_id(self) -> Optional[str]:
Returns:
The run id.
+
"""
_ = self.experiment
return self._run_id
@@ -228,6 +230,7 @@ def experiment_id(self) -> Optional[str]:
Returns:
The experiment id.
+
"""
_ = self.experiment
return self._experiment_id
@@ -295,6 +298,7 @@ def save_dir(self) -> Optional[str]:
Return:
Local path to the root experiment directory if the tracking uri is local.
Otherwise returns `None`.
+
"""
if self._tracking_uri.startswith(LOCAL_FILE_URI_PREFIX):
return self._tracking_uri.lstrip(LOCAL_FILE_URI_PREFIX)
@@ -306,6 +310,7 @@ def name(self) -> Optional[str]:
Returns:
The experiment id.
+
"""
return self.experiment_id
@@ -315,6 +320,7 @@ def version(self) -> Optional[str]:
Returns:
The run id.
+
"""
return self.run_id
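A minimal usage sketch, assuming a local ``file:`` tracking URI (names are illustrative)::

    import lightning as L
    from lightning.pytorch.loggers import MLFlowLogger

    # with a local "file:" tracking uri, save_dir points at the root experiment directory
    mlf_logger = MLFlowLogger(experiment_name="my_experiment", tracking_uri="file:./ml-runs")
    trainer = L.Trainer(logger=mlf_logger)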
diff --git a/src/lightning/pytorch/loggers/neptune.py b/src/lightning/pytorch/loggers/neptune.py
index 3e5f4f6566211..b6e8b701721e8 100644
--- a/src/lightning/pytorch/loggers/neptune.py
+++ b/src/lightning/pytorch/loggers/neptune.py
@@ -223,6 +223,7 @@ def any_lightning_module_function_or_hook(self):
If the required Neptune package is not installed.
ValueError:
If an argument passed to the logger's constructor is incorrect.
+
"""
LOGGER_JOIN_CHAR = "/"
@@ -413,6 +414,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: #
)
neptune_logger.log_hyperparams(PARAMS)
+
"""
params = _convert_params(params)
params = _sanitize_callable_params(params)
@@ -431,6 +433,7 @@ def log_metrics( # type: ignore[override]
Args:
metrics: Dictionary with metric names as keys and measured quantities as values.
step: Step number at which the metrics should be recorded, currently ignored.
+
"""
if rank_zero_only.rank != 0:
raise ValueError("run tried to log from global_rank != 0")
@@ -476,6 +479,7 @@ def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
Args:
checkpoint_callback: the model checkpoint callback instance
+
"""
if not self._log_model_checkpoints:
return
@@ -560,5 +564,6 @@ def version(self) -> Optional[str]:
"""Return the experiment version.
It's Neptune Run's short_id
+
"""
return self._run_short_id
diff --git a/src/lightning/pytorch/loggers/tensorboard.py b/src/lightning/pytorch/loggers/tensorboard.py
index c99ca2ad02e2a..567f80b4a81a2 100644
--- a/src/lightning/pytorch/loggers/tensorboard.py
+++ b/src/lightning/pytorch/loggers/tensorboard.py
@@ -96,6 +96,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger):
>>> tbl.log_metrics({"acc": 0.9})
>>> tbl.finalize("success")
>>> shutil.rmtree(tmp)
+
"""
NAME_HPARAMS_FILE = "hparams.yaml"
@@ -133,6 +134,7 @@ def root_dir(self) -> str:
If the experiment name parameter is an empty string, no experiment subdirectory is used and the checkpoint will
be saved in "save_dir/version"
+
"""
return os.path.join(super().root_dir, self.name)
@@ -142,6 +144,7 @@ def log_dir(self) -> str:
By default, it is named ``'version_${self.version}'`` but it can be overridden by passing a string value for the
constructor's version parameter instead of ``None`` or an int.
+
"""
# create a pseudo standard path ala test-tube
version = self.version if isinstance(self.version, str) else f"version_{self.version}"
@@ -158,6 +161,7 @@ def save_dir(self) -> str:
Returns:
The local path to the save directory where the TensorBoard experiments are saved.
+
"""
return self._root_dir
@@ -166,12 +170,13 @@ def log_hyperparams( # type: ignore[override]
self, params: Union[Dict[str, Any], Namespace], metrics: Optional[Dict[str, Any]] = None
) -> None:
"""Record hyperparameters. TensorBoard logs with and without saved hyperparameters are incompatible, the
- hyperparameters are then not displayed in the TensorBoard. Please delete or move the previously saved logs
- to display the new ones with hyperparameters.
+ hyperparameters are then not displayed in the TensorBoard. Please delete or move the previously saved logs to
+ display the new ones with hyperparameters.
Args:
params: a dictionary-like container with the hyperparameters
metrics: Dictionary with metric names as keys and measured quantities as values
+
"""
params = _convert_params(params)
@@ -233,6 +238,7 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
Args:
checkpoint_callback: the model checkpoint callback instance
+
"""
pass
diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py
index ddc9e24749318..588826aa4135f 100644
--- a/src/lightning/pytorch/loggers/utilities.py
+++ b/src/lightning/pytorch/loggers/utilities.py
@@ -35,6 +35,7 @@ def _scan_checkpoints(checkpoint_callback: Checkpoint, logged_model_time: dict)
Args:
checkpoint_callback: Checkpoint callback reference.
logged_model_time: dictionary containing the logged model times.
+
"""
# get checkpoints to be saved with associated score
checkpoints = {}
diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py
index df676230ed5e7..9f9738a753efb 100644
--- a/src/lightning/pytorch/loggers/wandb.py
+++ b/src/lightning/pytorch/loggers/wandb.py
@@ -279,6 +279,7 @@ def any_lightning_module_function_or_hook(self):
If required WandB package is not installed on the device.
MisconfigurationException:
If both ``log_model`` and ``offline`` are set to ``True``.
+
"""
LOGGER_JOIN_CHAR = "-"
@@ -436,6 +437,7 @@ def log_table(
"""Log a Table containing any object type (text, image, audio, video, molecule, html, etc).
Can be defined either with `columns` and `data` or with `dataframe`.
+
"""
metrics = {key: wandb.Table(columns=columns, data=data, dataframe=dataframe)}
@@ -453,6 +455,7 @@ def log_text(
"""Log text as a Table.
Can be defined either with `columns` and `data` or with `dataframe`.
+
"""
self.log_table(key, columns, data, dataframe, step)
@@ -462,6 +465,7 @@ def log_image(self, key: str, images: List[Any], step: Optional[int] = None, **k
"""Log images (tensors, numpy arrays, PIL Images or file paths).
Optional kwargs are lists passed to each image (ex: caption, masks, boxes).
+
"""
if not isinstance(images, list):
raise TypeError(f'Expected a list as "images", found {type(images)}')
@@ -479,6 +483,7 @@ def save_dir(self) -> Optional[str]:
Returns:
The path to the save directory.
+
"""
return self._save_dir
@@ -489,6 +494,7 @@ def name(self) -> Optional[str]:
Returns:
The name of the project the current experiment belongs to. This name is not the same as `wandb.Run`'s
name. To access wandb's internal experiment name, use ``logger.experiment.name`` instead.
+
"""
return self._project
@@ -498,6 +504,7 @@ def version(self) -> Optional[str]:
Returns:
The id of the experiment if the experiment exists else the id given to the constructor.
+
"""
# don't create an experiment if we don't have one
return self._experiment.id if self._experiment else self._id
@@ -527,6 +534,7 @@ def download_artifact(
Returns:
The path to the downloaded artifact.
+
"""
if wandb.run is not None and use_artifact:
artifact = wandb.run.use_artifact(artifact)
@@ -546,6 +554,7 @@ def use_artifact(self, artifact: str, artifact_type: Optional[str] = None) -> "w
Returns:
wandb Artifact object for the artifact.
+
"""
return self.experiment.use_artifact(artifact, type=artifact_type)
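A short sketch of the media-logging helpers above (project name, tensors, and captions are illustrative)::

    import torch
    from lightning.pytorch.loggers import WandbLogger

    logger = WandbLogger(project="my-project")

    # images: tensors, numpy arrays, PIL Images or file paths; extra kwargs are per-image lists
    imgs = [torch.rand(3, 32, 32), torch.rand(3, 32, 32)]
    logger.log_image(key="samples", images=imgs, caption=["first", "second"])

    # tabular text, defined either via columns/data or a dataframe
    logger.log_text(key="translations", columns=["input", "output"], data=[["hello", "bonjour"]])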
diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py
index 14f39e5c0fa08..38a6b8803f4dd 100644
--- a/src/lightning/pytorch/loops/evaluation_loop.py
+++ b/src/lightning/pytorch/loops/evaluation_loop.py
@@ -89,6 +89,7 @@ def max_batches(self) -> Union[int, float, List[Union[int, float]]]:
"""In "sequential" mode, the max number of batches to run per dataloader.
Otherwise, the max batches to run.
+
"""
max_batches = self._max_batches
if not self.trainer.sanity_checking:
@@ -377,6 +378,7 @@ def _evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> N
batch: The current batch to run through the step.
batch_idx: The index of the current batch
dataloader_idx: the index of the dataloader producing the current batch
+
"""
trainer = self.trainer
@@ -431,6 +433,7 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int
Returns:
the dictionary containing all the keyword arguments for the step
+
"""
step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)])
if dataloader_idx is not None:
diff --git a/src/lightning/pytorch/loops/fetchers.py b/src/lightning/pytorch/loops/fetchers.py
index 8df73c891c920..9c526c9b6bcb2 100644
--- a/src/lightning/pytorch/loops/fetchers.py
+++ b/src/lightning/pytorch/loops/fetchers.py
@@ -78,6 +78,7 @@ class _PrefetchDataFetcher(_DataFetcher):
Args:
prefetch_batches: Number of batches to pre-fetch. Pre-fetching at least 1 batch is necessary to properly track
whether a batch is the last one (available with :attr:`self.done`) when the length is not available.
+
"""
def __init__(self, prefetch_batches: int = 1) -> None:
diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py
index 16c11ba45c677..3f4047d8f7ffd 100644
--- a/src/lightning/pytorch/loops/fit_loop.py
+++ b/src/lightning/pytorch/loops/fit_loop.py
@@ -69,6 +69,7 @@ class _FitLoop(_Loop):
Args:
min_epochs: The minimum number of epochs
max_epochs: The maximum number of epochs, can be set -1 to turn this limit off
+
"""
def __init__(
diff --git a/src/lightning/pytorch/loops/loop.py b/src/lightning/pytorch/loops/loop.py
index 2a3bf1dfc4a9b..56d520800c447 100644
--- a/src/lightning/pytorch/loops/loop.py
+++ b/src/lightning/pytorch/loops/loop.py
@@ -42,6 +42,7 @@ def on_save_checkpoint(self) -> Dict:
Returns:
The current loop state.
+
"""
return {}
@@ -55,6 +56,7 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Di
destination: An existing dictionary to update with this loop's state. By default a new dictionary
is returned.
prefix: A prefix for each key in the state dictionary
+
"""
if destination is None:
destination = {}
diff --git a/src/lightning/pytorch/loops/optimization/automatic.py b/src/lightning/pytorch/loops/optimization/automatic.py
index f86de295d80e4..26c5d404272b3 100644
--- a/src/lightning/pytorch/loops/optimization/automatic.py
+++ b/src/lightning/pytorch/loops/optimization/automatic.py
@@ -40,6 +40,7 @@ class ClosureResult(OutputResult):
closure_loss: The loss with a graph attached.
loss: A detached copy of the closure loss.
extra: Any keys other than the loss returned.
+
"""
closure_loss: Optional[Tensor]
@@ -158,6 +159,7 @@ def run(self, optimizer: Optimizer, kwargs: OrderedDict) -> _OUTPUTS_TYPE:
Args:
kwargs: the kwargs passed down to the hooks
optimizer: the optimizer
+
"""
closure = self._make_closure(kwargs, optimizer)
@@ -203,6 +205,7 @@ def _make_zero_grad_fn(self, batch_idx: int, optimizer: Optimizer) -> Optional[C
"""Build a `zero_grad` function that zeroes the gradients before back-propagation.
Returns ``None`` in the case backward needs to be skipped.
+
"""
if self._skip_backward:
return None
@@ -218,10 +221,11 @@ def zero_grad_fn() -> None:
return zero_grad_fn
def _make_backward_fn(self, optimizer: Optimizer) -> Optional[Callable[[Tensor], None]]:
- """Build a `backward` function that handles back-propagation through the output produced by the
- `training_step` function.
+ """Build a `backward` function that handles back-propagation through the output produced by the `training_step`
+ function.
Returns ``None`` in the case backward needs to be skipped.
+
"""
if self._skip_backward:
return None
@@ -242,6 +246,7 @@ def _optimizer_step(
batch_idx: the index of the current batch
train_step_and_backward_closure: the closure function performing the train step and computing the
gradients. By default, called by the optimizer (if possible)
+
"""
trainer = self.trainer
@@ -285,6 +290,7 @@ def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer)
Args:
batch_idx: the index of the current batch
optimizer: the current optimizer
+
"""
trainer = self.trainer
call._call_lightning_module_hook(trainer, "optimizer_zero_grad", trainer.current_epoch, batch_idx, optimizer)
@@ -298,6 +304,7 @@ def _training_step(self, kwargs: OrderedDict) -> ClosureResult:
Returns:
A ``ClosureResult`` containing the training step output.
+
"""
trainer = self.trainer
diff --git a/src/lightning/pytorch/loops/optimization/closure.py b/src/lightning/pytorch/loops/optimization/closure.py
index ec85a96e54042..4b550166b721e 100644
--- a/src/lightning/pytorch/loops/optimization/closure.py
+++ b/src/lightning/pytorch/loops/optimization/closure.py
@@ -35,6 +35,7 @@ class AbstractClosure(ABC, Generic[T]):
This class provides a simple abstraction making the instance of this class callable like a function while capturing
the closure result and caching it.
+
"""
def __init__(self) -> None:
@@ -46,6 +47,7 @@ def consume_result(self) -> T:
Once accessed, the internal reference gets reset and the consumer will have to hold on to the reference as long
as necessary.
+
"""
if self._result is None:
raise MisconfigurationException(
diff --git a/src/lightning/pytorch/loops/optimization/manual.py b/src/lightning/pytorch/loops/optimization/manual.py
index 79dd4360dc169..01998ae5ff9da 100644
--- a/src/lightning/pytorch/loops/optimization/manual.py
+++ b/src/lightning/pytorch/loops/optimization/manual.py
@@ -72,6 +72,7 @@ class _ManualOptimization(_Loop):
This loop is a trivial case because it performs only a single iteration (calling directly into the module's
:meth:`~lightning.pytorch.core.module.LightningModule.training_step`) and passing through the output(s).
+
"""
output_result_cls = ManualResult
@@ -102,6 +103,7 @@ def advance(self, kwargs: OrderedDict) -> None:
Args:
kwargs: The kwargs passed down to the hooks.
+
"""
trainer = self.trainer
diff --git a/src/lightning/pytorch/loops/prediction_loop.py b/src/lightning/pytorch/loops/prediction_loop.py
index 227d3246bba3c..6df419f412173 100644
--- a/src/lightning/pytorch/loops/prediction_loop.py
+++ b/src/lightning/pytorch/loops/prediction_loop.py
@@ -208,6 +208,7 @@ def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None
batch: the current batch to run the prediction on
batch_idx: the index of the current batch
dataloader_idx: the index of the dataloader producing the current batch
+
"""
trainer = self.trainer
batch = trainer.lightning_module._on_before_batch_transfer(batch, dataloader_idx=dataloader_idx)
diff --git a/src/lightning/pytorch/loops/progress.py b/src/lightning/pytorch/loops/progress.py
index 788f97bbc6b3f..8ff12ba3781d1 100644
--- a/src/lightning/pytorch/loops/progress.py
+++ b/src/lightning/pytorch/loops/progress.py
@@ -45,6 +45,7 @@ class _ReadyCompletedTracker(_BaseProgress):
completed: Intended to be incremented after the event completes (e.g. after ``on_*_end`` runs).
These attributes should be increased in order, that is, :attr:`ready` first and :attr:`completed` last.
+
"""
ready: int = 0
@@ -60,6 +61,7 @@ def reset_on_restart(self) -> None:
If there is a failure before all attributes are increased, restore the attributes to the last fully completed
value.
+
"""
self.ready = self.completed
@@ -74,6 +76,7 @@ class _StartedTracker(_ReadyCompletedTracker):
completed: Intended to be incremented after the event completes (e.g. after ``on_*_end`` runs).
These attributes should be increased in order, that is, :attr:`ready` first and :attr:`completed` last.
+
"""
started: int = 0
@@ -98,6 +101,7 @@ class _ProcessedTracker(_StartedTracker):
completed: Intended to be incremented after the event completes (e.g. after ``on_*_end`` runs).
These attributes should be increased in order, that is, :attr:`ready` first and :attr:`completed` last.
+
"""
processed: int = 0
@@ -118,6 +122,7 @@ class _Progress(_BaseProgress):
Args:
total: Intended to track the total progress of an event.
current: Intended to track the current progress of an event.
+
"""
total: _ReadyCompletedTracker = field(default_factory=_ProcessedTracker)
@@ -177,6 +182,7 @@ class _BatchProgress(_Progress):
total: Tracks the total batch progress.
current: Tracks the current batch progress.
is_last_batch: Whether the batch is the last one. This is useful for iterable datasets.
+
"""
is_last_batch: bool = False
@@ -203,6 +209,7 @@ class _SchedulerProgress(_Progress):
Args:
total: Tracks the total scheduler progress.
current: Tracks the current scheduler progress.
+
"""
total: _ReadyCompletedTracker = field(default_factory=_ReadyCompletedTracker)
@@ -216,6 +223,7 @@ class _OptimizerProgress(_BaseProgress):
Args:
step: Tracks ``optimizer.step`` calls.
zero_grad: Tracks ``optimizer.zero_grad`` calls.
+
"""
step: _Progress = field(default_factory=lambda: _Progress.from_defaults(_ReadyCompletedTracker))
@@ -244,6 +252,7 @@ class _OptimizationProgress(_BaseProgress):
Args:
optimizer: Tracks optimizer progress.
+
"""
optimizer: _OptimizerProgress = field(default_factory=_OptimizerProgress)
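For illustration only, the documented increment order on one of these private trackers (a sketch, not public API)::

    from lightning.pytorch.loops.progress import _BatchProgress

    p = _BatchProgress()
    # counters are bumped in order around a single batch
    p.increment_ready()
    p.increment_started()
    p.increment_processed()
    p.increment_completed()
    assert p.total.completed == p.current.completed == 1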
diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py
index b9d205070e0e3..46835f4677796 100644
--- a/src/lightning/pytorch/loops/training_epoch_loop.py
+++ b/src/lightning/pytorch/loops/training_epoch_loop.py
@@ -181,6 +181,7 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
Raises:
StopIteration: When the epoch is canceled by the user returning -1
+
"""
if self.restarting and self._should_check_val_fx():
# skip training and run validation in `on_advance_end`
@@ -288,8 +289,7 @@ def _num_ready_batches_reached(self) -> bool:
return epoch_finished_on_ready or self.batch_progress.is_last_batch
def _should_accumulate(self) -> bool:
- """Checks if the optimizer step should be performed or gradients should be accumulated for the current
- step."""
+ """Checks if the optimizer step should be performed or gradients should be accumulated for the current step."""
accumulation_done = self._accumulated_batches_reached()
# Lightning steps on the final batch
is_final_batch = self._num_ready_batches_reached()
@@ -312,6 +312,7 @@ def _update_learning_rates(self, interval: str, update_plateau_schedulers: bool)
This is used so non-plateau schedulers can be updated before running validation. Checkpoints are
commonly saved during validation, however, on-plateau schedulers might monitor a validation metric
so they have to be updated separately.
+
"""
trainer = self.trainer
@@ -413,6 +414,7 @@ def _build_kwargs(self, kwargs: OrderedDict, batch: Any, batch_idx: int) -> Orde
Returns:
The kwargs passed down to the hooks.
+
"""
kwargs["batch"] = batch
training_step_fx = getattr(self.trainer.lightning_module, "training_step")
diff --git a/src/lightning/pytorch/loops/utilities.py b/src/lightning/pytorch/loops/utilities.py
index b449355582c46..618717a69be43 100644
--- a/src/lightning/pytorch/loops/utilities.py
+++ b/src/lightning/pytorch/loops/utilities.py
@@ -40,6 +40,7 @@ def check_finite_loss(loss: Optional[Tensor]) -> None:
Args:
loss: the loss value to check to be finite
+
"""
if loss is not None and not torch.isfinite(loss).all():
raise ValueError(f"The loss returned in `training_step` is {loss}.")
@@ -52,8 +53,8 @@ def _parse_loop_limits(
max_epochs: Optional[int],
trainer: "pl.Trainer",
) -> Tuple[int, int]:
- """This utility computes the default values for the minimum and maximum number of steps and epochs given the
- values the user has selected.
+ """This utility computes the default values for the minimum and maximum number of steps and epochs given the values
+ the user has selected.
Args:
min_steps: Minimum number of steps.
@@ -64,6 +65,7 @@ def _parse_loop_limits(
Returns:
The parsed limits, with default values being set for the ones that the user did not specify.
+
"""
if max_epochs is None:
if max_steps == -1 and not any(isinstance(cb, Timer) for cb in trainer.callbacks):
@@ -89,8 +91,8 @@ def _parse_loop_limits(
@contextmanager
def _block_parallel_sync_behavior(strategy: Strategy, block: bool = True) -> Generator[None, None, None]:
- """Blocks synchronization in :class:`~lightning.pytorch.strategies.parallel.ParallelStrategy`. This is useful
- for example when accumulating gradients to reduce communication when it is not needed.
+ """Blocks synchronization in :class:`~lightning.pytorch.strategies.parallel.ParallelStrategy`. This is useful for
+ example when accumulating gradients to reduce communication when it is not needed.
Args:
strategy: the strategy instance to use.
@@ -98,6 +100,7 @@ def _block_parallel_sync_behavior(strategy: Strategy, block: bool = True) -> Gen
Returns:
context manager with sync behaviour off
+
"""
if isinstance(strategy, ParallelStrategy) and block:
with strategy.block_backward_sync():
@@ -115,6 +118,7 @@ def _is_max_limit_reached(current: int, maximum: int = -1) -> bool:
Returns:
bool: whether the limit has been reached
+
"""
return maximum != -1 and current >= maximum
diff --git a/src/lightning/pytorch/overrides/distributed.py b/src/lightning/pytorch/overrides/distributed.py
index 1480163dc57c0..9b86b5db3055d 100644
--- a/src/lightning/pytorch/overrides/distributed.py
+++ b/src/lightning/pytorch/overrides/distributed.py
@@ -142,6 +142,7 @@ def _register_ddp_comm_hook(
ddp_comm_hook=powerSGD.powerSGD_hook,
ddp_comm_wrapper=default.fp16_compress_wrapper,
)
+
"""
if ddp_comm_hook is None:
return
@@ -191,15 +192,16 @@ def _sync_module_states(module: torch.nn.Module) -> None:
class UnrepeatedDistributedSampler(DistributedSampler):
- """A fork of the PyTorch DistributedSampler that doesn't repeat data, instead allowing the number of batches
- per process to be off-by-one from each other. This makes this sampler usable for predictions (it's
- deterministic and doesn't require shuffling). It is potentially unsafe to use this sampler for training,
- because during training the DistributedDataParallel syncs buffers on each forward pass, so it could freeze if
- one of the processes runs one fewer batch. During prediction, buffers are only synced on the first batch, so
- this is safe to use as long as each process runs at least one batch. We verify this in an assert.
+ """A fork of the PyTorch DistributedSampler that doesn't repeat data, instead allowing the number of batches per
+ process to be off-by-one from each other. This makes this sampler usable for predictions (it's deterministic and
+ doesn't require shuffling). It is potentially unsafe to use this sampler for training, because during training the
+ DistributedDataParallel syncs buffers on each forward pass, so it could freeze if one of the processes runs one
+ fewer batch. During prediction, buffers are only synced on the first batch, so this is safe to use as long as each
+ process runs at least one batch. We verify this in an assert.
Taken from https://github.com/jpuigcerver/PyLaia/blob/v1.0.0/laia/data/unpadded_distributed_sampler.py and
https://github.com/pytorch/pytorch/issues/25162#issuecomment-634146002
+
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
diff --git a/src/lightning/pytorch/plugins/layer_sync.py b/src/lightning/pytorch/plugins/layer_sync.py
index e777eb04ea8d1..faa1ab23f1f2c 100644
--- a/src/lightning/pytorch/plugins/layer_sync.py
+++ b/src/lightning/pytorch/plugins/layer_sync.py
@@ -33,10 +33,10 @@ def revert(self, model: Module) -> Module:
class TorchSyncBatchNorm(LayerSync):
- """A plugin that wraps all batch normalization layers of a model with synchronization logic for
- multiprocessing.
+ """A plugin that wraps all batch normalization layers of a model with synchronization logic for multiprocessing.
This plugin has no effect in single-device operation.
+
"""
def apply(self, model: Module) -> Module:
@@ -50,6 +50,7 @@ def apply(self, model: Module) -> Module:
Return:
LightningModule with batchnorm layers synchronized within the process groups.
+
"""
return torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -61,6 +62,7 @@ def revert(self, model: Module) -> Module:
Return:
LightningModule with regular batchnorm layers that will no longer sync across processes.
+
"""
# Code adapted from https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547
# Original author: Kapil Yedidi (@kapily)
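A minimal sketch of the plugin applied to a plain module (layer sizes are illustrative; in single-device runs this has no effect, as noted above)::

    import torch
    from lightning.pytorch.plugins.layer_sync import TorchSyncBatchNorm

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))
    plugin = TorchSyncBatchNorm()
    model = plugin.apply(model)   # batchnorm layers converted to SyncBatchNorm
    model = plugin.revert(model)  # converted back to regular batchnorm layers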
diff --git a/src/lightning/pytorch/plugins/precision/deepspeed.py b/src/lightning/pytorch/plugins/precision/deepspeed.py
index 510b4c7d428ed..99e291b7338d9 100644
--- a/src/lightning/pytorch/plugins/precision/deepspeed.py
+++ b/src/lightning/pytorch/plugins/precision/deepspeed.py
@@ -50,6 +50,7 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin):
Raises:
ValueError:
If unsupported ``precision`` is provided.
+
"""
def __init__(self, precision: _PRECISION_INPUT) -> None:
@@ -104,6 +105,7 @@ def backward( # type: ignore[override]
optimizer: ignored for DeepSpeed
\*args: additional positional arguments for the :meth:`deepspeed.DeepSpeedEngine.backward` call
\**kwargs: additional keyword arguments for the :meth:`deepspeed.DeepSpeedEngine.backward` call
+
"""
if is_overridden("backward", model):
warning_cache.warn(
diff --git a/src/lightning/pytorch/plugins/precision/double.py b/src/lightning/pytorch/plugins/precision/double.py
index df2387ba80326..a6cc31d9b6974 100644
--- a/src/lightning/pytorch/plugins/precision/double.py
+++ b/src/lightning/pytorch/plugins/precision/double.py
@@ -39,6 +39,7 @@ def init_context(self) -> Generator[None, None, None]:
"""A context manager to change the default tensor type when initializing module parameters or tensors.
See: :meth:`torch.set_default_dtype`
+
"""
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.float64)
@@ -50,6 +51,7 @@ def forward_context(self) -> Generator[None, None, None]:
"""A context manager to change the default tensor type.
See: :meth:`torch.set_default_dtype`
+
"""
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.float64)
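Note (illustrative): the init_context/forward_context managers above follow a save-set-restore pattern around torch.set_default_dtype. A standalone sketch of that pattern (the helper name default_dtype is made up):

from contextlib import contextmanager
from typing import Generator

import torch

@contextmanager
def default_dtype(dtype: torch.dtype) -> Generator[None, None, None]:
    # Temporarily change the global default dtype, then restore the previous one.
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)

with default_dtype(torch.float64):
    weights = torch.zeros(4, 4)  # created as float64
print(weights.dtype, torch.get_default_dtype())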
diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py
index 505496989971a..befa2d9bd0856 100644
--- a/src/lightning/pytorch/plugins/precision/fsdp.py
+++ b/src/lightning/pytorch/plugins/precision/fsdp.py
@@ -29,6 +29,7 @@ class FSDPMixedPrecisionPlugin(MixedPrecisionPlugin):
"""AMP for Fully Sharded Data Parallel (FSDP) Training.
.. warning:: This is an :ref:`experimental ` feature.
+
"""
def __init__(
@@ -79,6 +80,7 @@ def init_context(self) -> Generator[None, None, None]:
"""A context manager to change the default tensor type when initializing module parameters or tensors.
See: :meth:`torch.set_default_dtype`
+
"""
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(self.mixed_precision_config.param_dtype)
@@ -90,5 +92,6 @@ def forward_context(self) -> Generator[None, None, None]:
"""For FSDP, this context manager is a no-op since conversion is already handled internally.
See: https://pytorch.org/docs/stable/fsdp.html for more details on mixed precision.
+
"""
yield
diff --git a/src/lightning/pytorch/plugins/precision/half.py b/src/lightning/pytorch/plugins/precision/half.py
index dcafa3b33fd53..9e2ed0a6a536e 100644
--- a/src/lightning/pytorch/plugins/precision/half.py
+++ b/src/lightning/pytorch/plugins/precision/half.py
@@ -28,6 +28,7 @@ class HalfPrecisionPlugin(PrecisionPlugin):
Args:
precision: Whether to use ``torch.float16`` (``'16-true'``) or ``torch.bfloat16`` (``'bf16-true'``).
+
"""
precision: Literal["bf16-true", "16-true"] = "16-true"
@@ -44,6 +45,7 @@ def init_context(self) -> Generator[None, None, None]:
"""A context manager to change the default tensor type when initializing module parameters or tensors.
See: :meth:`torch.set_default_dtype`
+
"""
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(self._desired_input_dtype)
@@ -52,10 +54,10 @@ def init_context(self) -> Generator[None, None, None]:
@contextmanager
def forward_context(self) -> Generator[None, None, None]:
- """A context manager to change the default tensor type when tensors get created during the module's
- forward.
+ """A context manager to change the default tensor type when tensors get created during the module's forward.
See: :meth:`torch.set_default_tensor_type`
+
"""
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(self._desired_input_dtype)
diff --git a/src/lightning/pytorch/plugins/precision/precision_plugin.py b/src/lightning/pytorch/plugins/precision/precision_plugin.py
index 89fa734013083..0c083f9427b75 100644
--- a/src/lightning/pytorch/plugins/precision/precision_plugin.py
+++ b/src/lightning/pytorch/plugins/precision/precision_plugin.py
@@ -32,6 +32,7 @@ class PrecisionPlugin(FabricPrecision, CheckpointHooks):
"""Base class for all plugins handling the precision-specific parts of the training.
The class attribute precision must be overwritten in child classes. The default value reflects fp32 training.
+
"""
def connect(
@@ -63,6 +64,7 @@ def backward( # type: ignore[override]
\*args: Positional arguments intended for the actual function that performs the backward, like
:meth:`~torch.Tensor.backward`.
\**kwargs: Keyword arguments for the same purpose as ``*args``.
+
"""
model.backward(tensor, *args, **kwargs)
diff --git a/src/lightning/pytorch/profilers/advanced.py b/src/lightning/pytorch/profilers/advanced.py
index cc1600af34784..f758439646d0a 100644
--- a/src/lightning/pytorch/profilers/advanced.py
+++ b/src/lightning/pytorch/profilers/advanced.py
@@ -25,10 +25,11 @@
class AdvancedProfiler(Profiler):
- """This profiler uses Python's cProfiler to record more detailed information about time spent in each function
- call recorded during a given action.
+ """This profiler uses Python's cProfiler to record more detailed information about time spent in each function call
+ recorded during a given action.
The output is quite verbose and you should only use this if you want very detailed reports.
+
"""
def __init__(
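Note (illustrative): AdvancedProfiler builds on the standard-library cProfile and pstats modules. A minimal sketch of recording one action and printing a per-function report, roughly the kind of verbose output described above:

import cProfile
import io
import pstats

profiler = cProfile.Profile()

profiler.enable()
sum(i * i for i in range(100_000))  # stand-in for the profiled action
profiler.disable()

# Render a verbose, per-function report sorted by cumulative time.
stream = io.StringIO()
pstats.Stats(profiler, stream=stream).sort_stats("cumulative").print_stats(10)
print(stream.getvalue())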
diff --git a/src/lightning/pytorch/profilers/base.py b/src/lightning/pytorch/profilers/base.py
index 5bf3b0e6e812c..be2c50a8d0bbd 100644
--- a/src/lightning/pytorch/profilers/base.py
+++ b/src/lightning/pytorch/profilers/base.py
@@ -20,6 +20,7 @@ class PassThroughProfiler(Profiler):
"""This class should be used when you don't want the (small) overhead of profiling.
The Trainer uses this class by default.
+
"""
def start(self, action_name: str) -> None:
diff --git a/src/lightning/pytorch/profilers/profiler.py b/src/lightning/pytorch/profilers/profiler.py
index 5bc23251a873a..d7f168b3d2aa1 100644
--- a/src/lightning/pytorch/profilers/profiler.py
+++ b/src/lightning/pytorch/profilers/profiler.py
@@ -62,6 +62,7 @@ def profile(self, action_name: str) -> Generator:
The profiler will start once you've entered the context and will automatically
stop once you exit the code block.
+
"""
try:
self.start(action_name)
@@ -134,6 +135,7 @@ def teardown(self, stage: Optional[str]) -> None:
"""Execute arbitrary post-profiling tear-down steps.
Closes the currently open file and stream.
+
"""
self._write_stream = None
if self._output_file is not None:
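Note (illustrative): the profile() context manager above guarantees that stop() runs even when the profiled block raises. A toy profiler sketch (the class name ToyProfiler is made up) showing the same start/stop pairing:

import time
from contextlib import contextmanager
from typing import Dict, Generator, List

class ToyProfiler:
    # Records wall-clock durations per action name; illustrative only.
    def __init__(self) -> None:
        self._starts: Dict[str, float] = {}
        self.durations: Dict[str, List[float]] = {}

    def start(self, action_name: str) -> None:
        self._starts[action_name] = time.monotonic()

    def stop(self, action_name: str) -> None:
        elapsed = time.monotonic() - self._starts.pop(action_name)
        self.durations.setdefault(action_name, []).append(elapsed)

    @contextmanager
    def profile(self, action_name: str) -> Generator[None, None, None]:
        # Start on entering the block, always stop on exit, even if the body raises.
        try:
            self.start(action_name)
            yield
        finally:
            self.stop(action_name)

prof = ToyProfiler()
with prof.profile("training_step"):
    time.sleep(0.01)
print(prof.durations)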
diff --git a/src/lightning/pytorch/profilers/pytorch.py b/src/lightning/pytorch/profilers/pytorch.py
index fe3ab1c18968c..0b486f1aa587d 100644
--- a/src/lightning/pytorch/profilers/pytorch.py
+++ b/src/lightning/pytorch/profilers/pytorch.py
@@ -44,8 +44,7 @@
class RegisterRecordFunction:
- """While profiling autograd operations, this class will add labels for module names around the forward
- function.
+ """While profiling autograd operations, this class will add labels for module names around the forward function.
The Lightning PyTorch Profiler will activate this feature automatically. It can be deactivated as follows:
@@ -60,6 +59,7 @@ class RegisterRecordFunction:
from lightning.pytorch import Trainer, seed_everything
with RegisterRecordFunction(model):
out = model(batch)
+
"""
def __init__(self, model: nn.Module) -> None:
@@ -288,6 +288,7 @@ def __init__(
If arg ``sort_by_key`` is not present in ``AVAILABLE_SORT_KEYS``.
If arg ``schedule`` is not a ``Callable``.
If arg ``schedule`` does not return a ``torch.profiler.ProfilerAction``.
+
"""
super().__init__(dirpath=dirpath, filename=filename)
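Note (illustrative): RegisterRecordFunction labels each module's forward with torch's record_function ranges. A minimal sketch of that underlying primitive, applied by hand to a single forward pass (the label string is arbitrary):

import torch
from torch import nn
from torch.profiler import profile, record_function

model = nn.Linear(8, 2)
batch = torch.randn(4, 8)

with profile() as prof:
    # Wrap the forward pass in a named range so it shows up in the profiler output.
    with record_function("[forward]: my_linear"):
        out = model(batch)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))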
diff --git a/src/lightning/pytorch/profilers/simple.py b/src/lightning/pytorch/profilers/simple.py
index 3af44d4178ab5..528290545e9ab 100644
--- a/src/lightning/pytorch/profilers/simple.py
+++ b/src/lightning/pytorch/profilers/simple.py
@@ -32,8 +32,8 @@
class SimpleProfiler(Profiler):
- """This profiler simply records the duration of actions (in seconds) and reports the mean duration of each
- action and the total time spent over the entire training run."""
+ """This profiler simply records the duration of actions (in seconds) and reports the mean duration of each action
+ and the total time spent over the entire training run."""
def __init__(
self,
diff --git a/src/lightning/pytorch/profilers/xla.py b/src/lightning/pytorch/profilers/xla.py
index b6ebe70fd283e..2d1db1d3e5e15 100644
--- a/src/lightning/pytorch/profilers/xla.py
+++ b/src/lightning/pytorch/profilers/xla.py
@@ -31,12 +31,13 @@ class XLAProfiler(Profiler):
}
def __init__(self, port: int = 9012) -> None:
- """XLA Profiler will help you debug and optimize training workload performance for your models using Cloud
- TPU performance tools.
+ """XLA Profiler will help you debug and optimize training workload performance for your models using Cloud TPU
+ performance tools.
Args:
port: the port to start the profiler server on. An exception is
raised if the provided port is invalid or busy.
+
"""
if not _XLA_AVAILABLE:
raise ModuleNotFoundError(str(_XLA_AVAILABLE))
diff --git a/src/lightning/pytorch/serve/servable_module.py b/src/lightning/pytorch/serve/servable_module.py
index 33efa9956a16f..f715f4b3cad9d 100644
--- a/src/lightning/pytorch/serve/servable_module.py
+++ b/src/lightning/pytorch/serve/servable_module.py
@@ -52,6 +52,7 @@ def configure_response(self):
)
trainer.fit(ServableBoringModel())
assert serve_cb.resp.json() == {"output": [0, 1]}
+
"""
@abstractmethod
@@ -67,6 +68,7 @@ def configure_serialization(self) -> Tuple[Dict[str, Callable], Dict[str, Callab
The second dictionary contains the names of the ``serve_step`` output variables as its keys
and the associated serialization function (e.g. a function to convert tensors into a payload).
+
"""
@abstractmethod
@@ -84,6 +86,7 @@ def serve_step(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
Return:
- ``dict`` - A dictionary with their associated tensors.
+
"""
@abstractmethod
diff --git a/src/lightning/pytorch/serve/servable_module_validator.py b/src/lightning/pytorch/serve/servable_module_validator.py
index e6db99091eca6..9e669a0455e1c 100644
--- a/src/lightning/pytorch/serve/servable_module_validator.py
+++ b/src/lightning/pytorch/serve/servable_module_validator.py
@@ -36,6 +36,7 @@ class ServableModuleValidator(Callback):
port: The port associated with the server.
timeout: Timeout period in seconds, that the process should wait for the server to start.
exit_on_failure: Whether to exit the process on failure.
+
"""
def __init__(
diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py
index fd1f72ab341be..55708768a6ca2 100644
--- a/src/lightning/pytorch/strategies/ddp.py
+++ b/src/lightning/pytorch/strategies/ddp.py
@@ -258,6 +258,7 @@ def optimizer_step(
closure: closure calculating the loss value
model: reference to the model, optionally defining optimizer step related hooks
**kwargs: Any extra arguments to ``optimizer.step``
+
"""
optimizer_output = super().optimizer_step(optimizer, closure, model, **kwargs)
@@ -323,6 +324,7 @@ def reduce(
Return:
reduced value, except when the input was not a tensor, in which case the output remains unchanged
+
"""
if isinstance(tensor, Tensor):
return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py
index 82e9b108fe91d..8b9da45cd1d1c 100644
--- a/src/lightning/pytorch/strategies/deepspeed.py
+++ b/src/lightning/pytorch/strategies/deepspeed.py
@@ -246,6 +246,7 @@ def __init__(
load_full_weights: True when loading a single checkpoint file containing the model state dict
when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards
per worker.
+
"""
if not _DEEPSPEED_AVAILABLE:
raise MisconfigurationException(
@@ -389,6 +390,7 @@ def _setup_model_and_optimizers(
Return:
The model wrapped into a :class:`deepspeed.DeepSpeedEngine` and a list with a single
deepspeed optimizer.
+
"""
if len(optimizers) != 1:
raise ValueError(
@@ -414,6 +416,7 @@ def _setup_model_and_optimizer(
"""Initialize one model and one optimizer with an optional learning rate scheduler.
This calls :func:`deepspeed.initialize` internally.
+
"""
import deepspeed
@@ -577,6 +580,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
Args:
trainer: the Trainer, these optimizers should be connected to
+
"""
if trainer.state.fn != TrainerFn.FITTING:
return
@@ -739,6 +743,7 @@ def save_checkpoint(self, checkpoint: Dict, filepath: _PATH, storage_options: Op
Raises:
TypeError:
If ``storage_options`` arg is passed in
+
"""
# broadcast the filepath from rank 0 to ensure all the states are saved in a common filepath
filepath = self.broadcast(filepath)
@@ -808,12 +813,13 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
self._restore_zero_state(checkpoint)
def _restore_zero_state(self, ckpt: Mapping[str, Any]) -> None:
- """Overrides the normal load_state_dict behaviour in PyTorch to ensure we gather parameters that may be
- sharded across processes before loading the state dictionary when using ZeRO stage 3. This is then
- automatically synced across processes.
+ """Overrides the normal load_state_dict behaviour in PyTorch to ensure we gather parameters that may be sharded
+ across processes before loading the state dictionary when using ZeRO stage 3. This is then automatically synced
+ across processes.
Args:
ckpt: The ckpt file.
+
"""
import deepspeed
diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py
index 9555d105eeb71..c160daba6f456 100644
--- a/src/lightning/pytorch/strategies/fsdp.py
+++ b/src/lightning/pytorch/strategies/fsdp.py
@@ -117,6 +117,7 @@ class FSDPStrategy(ParallelStrategy):
Also accepts a :class:`torch.distributed.fsdp.ShardingStrategy` enum value.
\**kwargs: See available parameters in :class:`torch.distributed.fsdp.FullyShardedDataParallel`.
+
"""
strategy_name = "fsdp"
@@ -172,6 +173,7 @@ def lightning_module_state_dict(self) -> Dict[str, Any]:
To avoid OOM, the returned parameters will only be returned on rank 0 and on CPU. All other ranks get an empty
dict.
+
"""
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.fsdp.api import FullStateDictConfig, StateDictType
@@ -386,6 +388,7 @@ def reduce(
Return:
reduced value, except when the input was not a tensor, in which case the output remains unchanged
+
"""
if isinstance(tensor, Tensor):
return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
diff --git a/src/lightning/pytorch/strategies/launchers/multiprocessing.py b/src/lightning/pytorch/strategies/launchers/multiprocessing.py
index ebd30a19ee287..6b087db5e4da4 100644
--- a/src/lightning/pytorch/strategies/launchers/multiprocessing.py
+++ b/src/lightning/pytorch/strategies/launchers/multiprocessing.py
@@ -60,6 +60,7 @@ class _MultiProcessingLauncher(_Launcher):
- 'fork': Preferable for IPython/Jupyter environments where 'spawn' is not available. Not available on
the Windows platform for example.
- 'forkserver': Alternative implementation to 'fork'.
+
"""
def __init__(
@@ -93,6 +94,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
trainer: Optional reference to the :class:`~lightning.pytorch.trainer.trainer.Trainer` for which
a selected set of attributes get restored in the main process after processes join.
**kwargs: Optional keyword arguments to be passed to the given function.
+
"""
if self._start_method in ("fork", "forkserver"):
_check_bad_cuda_fork()
@@ -198,8 +200,8 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt
return _WorkerOutput(best_model_path, weights_path, trainer.state, results, extra)
def get_extra_results(self, trainer: "pl.Trainer") -> Dict[str, Any]:
- """Gather extra state from the Trainer and return it as a dictionary for sending back to the main process.
- To avoid issues with memory sharing, we cast the data to numpy.
+ """Gather extra state from the Trainer and return it as a dictionary for sending back to the main process. To
+ avoid issues with memory sharing, we cast the data to numpy.
Args:
trainer: reference to the Trainer.
@@ -207,6 +209,7 @@ def get_extra_results(self, trainer: "pl.Trainer") -> Dict[str, Any]:
Returns:
A dictionary with items to send back to the main process where :meth:`update_main_process_results` will
process this output.
+
"""
callback_metrics: dict = apply_to_collection(
trainer.callback_metrics, Tensor, lambda x: x.cpu().numpy()
@@ -263,6 +266,7 @@ class _GlobalStateSnapshot:
# in worker process
snapshot.restore()
+
"""
use_deterministic_algorithms: bool
@@ -272,8 +276,7 @@ class _GlobalStateSnapshot:
@classmethod
def capture(cls) -> "_GlobalStateSnapshot":
- """Capture a few global states from torch, numpy, etc., that we want to restore in a spawned worker
- process."""
+ """Capture a few global states from torch, numpy, etc., that we want to restore in a spawned worker process."""
return cls(
use_deterministic_algorithms=torch.are_deterministic_algorithms_enabled(),
use_deterministic_algorithms_warn_only=torch.is_deterministic_algorithms_warn_only_enabled(),
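Note (illustrative): _GlobalStateSnapshot captures global torch settings in the main process and restores them in the spawned worker. A reduced sketch (the class name ToyGlobalStateSnapshot is made up) covering only a subset of that state:

from dataclasses import dataclass

import torch
from torch import Tensor

@dataclass
class ToyGlobalStateSnapshot:
    use_deterministic_algorithms: bool
    rng_state: Tensor

    @classmethod
    def capture(cls) -> "ToyGlobalStateSnapshot":
        # Read the current global settings in the main process.
        return cls(
            use_deterministic_algorithms=torch.are_deterministic_algorithms_enabled(),
            rng_state=torch.get_rng_state(),
        )

    def restore(self) -> None:
        # Re-apply the captured settings inside the worker process.
        torch.use_deterministic_algorithms(self.use_deterministic_algorithms)
        torch.set_rng_state(self.rng_state)

snapshot = ToyGlobalStateSnapshot.capture()  # in the main process
# ... pass `snapshot` to the worker process ...
snapshot.restore()  # in the worker process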
diff --git a/src/lightning/pytorch/strategies/launchers/subprocess_script.py b/src/lightning/pytorch/strategies/launchers/subprocess_script.py
index 0a0170f2eb39e..5afdcbec2fc0f 100644
--- a/src/lightning/pytorch/strategies/launchers/subprocess_script.py
+++ b/src/lightning/pytorch/strategies/launchers/subprocess_script.py
@@ -67,6 +67,7 @@ class _SubprocessScriptLauncher(_Launcher):
cluster_environment: A cluster environment that provides access to world size, node rank, etc.
num_processes: The number of processes to launch in the current node.
num_nodes: The total number of nodes that participate in this process group.
+
"""
def __init__(self, cluster_environment: ClusterEnvironment, num_processes: int, num_nodes: int) -> None:
@@ -89,6 +90,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
*args: Optional positional arguments to be passed to the given function.
trainer: Optional reference to the :class:`~lightning.pytorch.trainer.trainer.Trainer`.
**kwargs: Optional keyword arguments to be passed to the given function.
+
"""
if not self.cluster_environment.creates_processes_externally:
self._call_children_scripts()
diff --git a/src/lightning/pytorch/strategies/launchers/xla.py b/src/lightning/pytorch/strategies/launchers/xla.py
index 961bc9bbb6188..032ec26150a5b 100644
--- a/src/lightning/pytorch/strategies/launchers/xla.py
+++ b/src/lightning/pytorch/strategies/launchers/xla.py
@@ -31,8 +31,8 @@
class _XLALauncher(_MultiProcessingLauncher):
- r"""Launches processes that run a given function in parallel on XLA supported hardware, and joins them all at
- the end.
+ r"""Launches processes that run a given function in parallel on XLA supported hardware, and joins them all at the
+ end.
The main process in which this launcher is invoked creates N so-called worker processes (using the
`torch_xla` :func:`xmp.spawn`) that run the given function.
@@ -44,6 +44,7 @@ class _XLALauncher(_MultiProcessingLauncher):
Args:
strategy: A reference to the strategy that is used together with this launcher
+
"""
def __init__(self, strategy: "pl.strategies.XLAStrategy") -> None:
@@ -67,6 +68,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
trainer: Optional reference to the :class:`~lightning.pytorch.trainer.trainer.Trainer` for which
a selected set of attributes get restored in the main process after processes join.
**kwargs: Optional keyword arguments to be passed to the given function.
+
"""
using_pjrt = _using_pjrt()
# pjrt requires that the queue is serializable
diff --git a/src/lightning/pytorch/strategies/parallel.py b/src/lightning/pytorch/strategies/parallel.py
index 30439cdbd4171..33dcd4be0baa5 100644
--- a/src/lightning/pytorch/strategies/parallel.py
+++ b/src/lightning/pytorch/strategies/parallel.py
@@ -112,6 +112,7 @@ def block_backward_sync(self) -> Generator:
This is useful for skipping sync when accumulating gradients, reducing communication overhead
Returns: context manager with sync behaviour off
+
"""
if isinstance(self.model, pl.utilities.types.DistributedDataParallel):
with self.model.no_sync():
diff --git a/src/lightning/pytorch/strategies/single_device.py b/src/lightning/pytorch/strategies/single_device.py
index 8083cccec8c52..a9809abe7c430 100644
--- a/src/lightning/pytorch/strategies/single_device.py
+++ b/src/lightning/pytorch/strategies/single_device.py
@@ -57,6 +57,7 @@ def reduce(self, tensor: Any | Tensor, *args: Any, **kwargs: Any) -> Any | Tenso
Return:
the unmodified input as reduction is not needed for single process operation
+
"""
return tensor
diff --git a/src/lightning/pytorch/strategies/strategy.py b/src/lightning/pytorch/strategies/strategy.py
index 1b575e0d6b40a..6fb5636c2f0d9 100644
--- a/src/lightning/pytorch/strategies/strategy.py
+++ b/src/lightning/pytorch/strategies/strategy.py
@@ -120,6 +120,7 @@ def setup_environment(self) -> None:
This is called before the LightningModule/DataModule setup hook which allows the user to access the accelerator
environment before setup is complete.
+
"""
assert self.accelerator is not None
self.accelerator.setup_device(self.root_device)
@@ -129,6 +130,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
Args:
trainer: the Trainer, these optimizers should be connected to
+
"""
if trainer.state.fn != TrainerFn.FITTING:
return
@@ -140,6 +142,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
Args:
trainer: the trainer instance
+
"""
assert self.accelerator is not None
self.accelerator.setup(trainer)
@@ -161,6 +164,7 @@ def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
"""Returns state of an optimizer.
Allows for syncing/collating optimizer state from processes in custom strategies.
+
"""
if isinstance(optimizer, LightningOptimizer):
optimizer = optimizer._optimizer
@@ -189,6 +193,7 @@ def backward(
\*args: Positional arguments that get passed down to the precision plugin's backward, intended as arguments
for the actual function that performs the backward, like :meth:`~torch.Tensor.backward`.
\**kwargs: Keyword arguments for the same purpose as ``*args``.
+
"""
self.pre_backward(closure_loss)
assert self.lightning_module is not None
@@ -215,6 +220,7 @@ def optimizer_step(
closure: closure calculating the loss value
model: reference to the model, optionally defining optimizer step related hooks
\**kwargs: Keyword arguments to ``optimizer.step``
+
"""
model = model or self.lightning_module
# TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed
@@ -226,6 +232,7 @@ def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]
The returned objects are expected to be in the same order they were passed in. The default implementation will
call :meth:`_setup_model` and :meth:`_setup_optimizer` on the inputs.
+
"""
# TODO: standardize this across all plugins in Lightning and Fabric. Related refactor: #7324
model = self._setup_model(model)
@@ -252,6 +259,7 @@ def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dat
batch: The batch of samples to move to the correct device
device: The target device
dataloader_idx: The index of the dataloader to which the batch belongs.
+
"""
model = self.lightning_module
device = device or self.root_device
@@ -287,6 +295,7 @@ def reduce(
group: the process group to reduce
reduce_op: the reduction operation. Defaults to 'mean'.
Can also be a string 'sum' or ReduceOp.
+
"""
@abstractmethod
@@ -295,6 +304,7 @@ def barrier(self, name: Optional[str] = None) -> None:
Args:
name: an optional name to pass into barrier.
+
"""
@abstractmethod
@@ -304,6 +314,7 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
Args:
obj: the object to broadcast
src: source rank
+
"""
@abstractmethod
@@ -314,6 +325,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
tensor: the tensor to all_gather
group: the process group to gather results from
sync_grads: flag that allows users to synchronize gradients for all_gather op
+
"""
def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
@@ -358,6 +370,7 @@ def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
"""The actual training step.
See :meth:`~lightning.pytorch.core.module.LightningModule.training_step` for more details
+
"""
args, kwargs = self.precision_plugin.convert_input((args, kwargs))
assert self.lightning_module is not None
@@ -371,6 +384,7 @@ def post_training_step(self) -> None:
"""This hook is deprecated.
Override :meth:`training_step` instead.
+
"""
pass
@@ -378,6 +392,7 @@ def validation_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
"""The actual validation step.
See :meth:`~lightning.pytorch.core.module.LightningModule.validation_step` for more details
+
"""
args, kwargs = self.precision_plugin.convert_input((args, kwargs))
assert self.lightning_module is not None
@@ -391,6 +406,7 @@ def test_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
"""The actual test step.
See :meth:`~lightning.pytorch.core.module.LightningModule.test_step` for more details
+
"""
args, kwargs = self.precision_plugin.convert_input((args, kwargs))
assert self.lightning_module is not None
@@ -404,6 +420,7 @@ def predict_step(self, *args: Any, **kwargs: Any) -> Any:
"""The actual predict step.
See :meth:`~lightning.pytorch.core.module.LightningModule.predict_step` for more details
+
"""
args, kwargs = self.precision_plugin.convert_input((args, kwargs))
assert self.lightning_module is not None
@@ -418,16 +435,18 @@ def process_dataloader(self, dataloader: object) -> object:
Args:
dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader`
+
"""
return dataloader
@property
def restore_checkpoint_after_setup(self) -> bool:
- """Override to delay restoring from checkpoint till after the setup phase has completed. This is useful
- when the strategy requires all the setup hooks to run before loading checkpoint.
+ """Override to delay restoring from checkpoint till after the setup phase has completed. This is useful when
+ the strategy requires all the setup hooks to run before loading checkpoint.
Returns:
If ``True``, restore checkpoint after strategy setup.
+
"""
return False
@@ -436,6 +455,7 @@ def lightning_restore_optimizer(self) -> bool:
"""Override to disable Lightning restoring optimizers/schedulers.
This is useful for strategies which manage restoring optimizers/schedulers.
+
"""
return True
@@ -458,6 +478,7 @@ def save_checkpoint(
checkpoint: dict containing model and trainer state
filepath: write-target file's path
storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin
+
"""
if self.is_global_zero:
self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
@@ -467,6 +488,7 @@ def remove_checkpoint(self, filepath: _PATH) -> None:
Args:
filepath: Path to checkpoint
+
"""
if self.is_global_zero:
self.checkpoint_io.remove_checkpoint(filepath)
@@ -478,6 +500,7 @@ def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[No
Args:
empty_init: Whether to initialize the model with empty weights (uninitialized memory).
If ``None``, the strategy will decide. Some strategies may not support all options.
+
"""
device_context = self.root_device if _TORCH_GREATER_EQUAL_2_0 else nullcontext()
empty_init_context = _EmptyInit(enabled=bool(empty_init)) if _TORCH_GREATER_EQUAL_1_13 else nullcontext()
@@ -486,11 +509,11 @@ def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[No
@contextmanager
def model_sharded_context(self) -> Generator[None, None, None]:
- """Provide hook to create modules in a distributed aware context. This is useful for when we'd like to
- shard the model instantly, which is useful for extremely large models which can save memory and
- initialization time.
+ """Provide hook to create modules in a distributed aware context. This is useful for when we'd like to shard
+ the model instantly, which is useful for extremely large models which can save memory and initialization time.
Returns: Model parallel context.
+
"""
yield
@@ -498,6 +521,7 @@ def teardown(self) -> None:
"""This method is called to teardown the training process.
It is the right place to release memory and free other resources.
+
"""
_optimizers_to_device(self.optimizers, torch.device("cpu"))
@@ -568,6 +592,7 @@ class _ForwardRedirection:
"""Implements the `forward-redirection`.
A method call to a wrapped module gets rerouted through the wrapper's `forward` method instead.
+
"""
def __call__(
@@ -584,6 +609,7 @@ def __call__(
`forward` method instead.
**kwargs: The keyword arguments to the method `method_name`. They will get passed to a patched
`forward` method instead.
+
"""
assert method_name != "forward"
original_forward = original_module.forward
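Note (illustrative): the forward-redirection above works by temporarily pointing the wrapped module's forward at the requested method, so the call still flows through the wrapper's own forward and whatever hooks it installs. A self-contained sketch with toy classes (none of them are the Lightning ones):

import torch
from torch import Tensor, nn

class Inner(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        return x + 1

    def validation_step(self, x: Tensor) -> Tensor:
        return x * 10

class Wrapper(nn.Module):
    # Stand-in for DDP/FSDP: its forward would add hooks/communication around the inner call.
    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

def call_through_wrapper(wrapper: nn.Module, inner: nn.Module, method_name: str, *args):
    original_forward = inner.forward
    # Reroute the inner forward to the requested method, call the wrapper so its own
    # forward (and any hooks) still runs, then restore the original forward.
    inner.forward = getattr(inner, method_name)
    try:
        return wrapper(*args)
    finally:
        inner.forward = original_forward

inner = Inner()
print(call_through_wrapper(Wrapper(inner), inner, "validation_step", torch.tensor(3.0)))  # tensor(30.)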
diff --git a/src/lightning/pytorch/strategies/xla.py b/src/lightning/pytorch/strategies/xla.py
index 99ce29b7c61fc..39458a8cd7e13 100644
--- a/src/lightning/pytorch/strategies/xla.py
+++ b/src/lightning/pytorch/strategies/xla.py
@@ -266,6 +266,7 @@ def remove_checkpoint(self, filepath: _PATH) -> None:
Args:
filepath: Path to checkpoint
+
"""
if self.local_rank == 0:
self.checkpoint_io.remove_checkpoint(filepath)
@@ -279,6 +280,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
sync_grads: flag that allows users to synchronize gradients for the all-gather operation.
Return:
A tensor of shape (world_size, ...)
+
"""
if not self._launched:
return tensor
diff --git a/src/lightning/pytorch/trainer/call.py b/src/lightning/pytorch/trainer/call.py
index 1647794f23183..2eab1bac09c0f 100644
--- a/src/lightning/pytorch/trainer/call.py
+++ b/src/lightning/pytorch/trainer/call.py
@@ -29,13 +29,14 @@
def _call_and_handle_interrupt(trainer: "pl.Trainer", trainer_fn: Callable, *args: Any, **kwargs: Any) -> Any:
- r"""Error handling, intended to be used only for main trainer function entry points (fit, validate, test,
- predict) as all errors should funnel through them.
+ r"""Error handling, intended to be used only for main trainer function entry points (fit, validate, test, predict)
+ as all errors should funnel through them.
Args:
trainer_fn: one of (fit, validate, test, predict)
*args: positional arguments to be passed to the `trainer_fn`
**kwargs: keyword arguments to be passed to `trainer_fn`
+
"""
try:
if trainer.strategy.launcher is not None:
@@ -243,6 +244,7 @@ def _call_callbacks_on_load_checkpoint(trainer: "pl.Trainer", checkpoint: Dict[s
Calls every callback's `on_load_checkpoint` hook. We have a dedicated function for this rather than using
`_call_callback_hooks` because we have special logic for getting callback_states.
+
"""
pl_module = trainer.lightning_module
if pl_module:
diff --git a/src/lightning/pytorch/trainer/configuration_validator.py b/src/lightning/pytorch/trainer/configuration_validator.py
index bebe781f2e43b..4a9c9c45c4728 100644
--- a/src/lightning/pytorch/trainer/configuration_validator.py
+++ b/src/lightning/pytorch/trainer/configuration_validator.py
@@ -27,6 +27,7 @@ def _verify_loop_configurations(trainer: "pl.Trainer") -> None:
Args:
trainer: Lightning Trainer. Its `lightning_module` (the model) to check the configuration.
+
"""
model = trainer.lightning_module
diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
index 0e52a4b2fb375..2bd55e03265f3 100644
--- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
@@ -114,6 +114,7 @@ def __init__(
priorities which to take when:
A. Class > str
B. Strategy > Accelerator/precision/plugins
+
"""
self.use_distributed_sampler = use_distributed_sampler
_set_torch_flags(deterministic=deterministic, benchmark=benchmark)
@@ -187,6 +188,7 @@ def _check_config_and_set_final_flags(
4. plugins: The list of plugins may contain a Precision plugin, CheckpointIO, ClusterEnvironment and others.
Additionally, other flags such as `precision` or `sync_batchnorm` can populate the list with the
corresponding plugin instances.
+
"""
if plugins is not None:
plugins = [plugins] if not isinstance(plugins, list) else plugins
@@ -456,8 +458,8 @@ def _choose_strategy(self) -> Union[Strategy, str]:
return "ddp"
def _check_strategy_and_fallback(self) -> None:
- """Checks edge cases when the strategy selection was a string input, and we need to fall back to a
- different choice depending on other parameters or the environment."""
+ """Checks edge cases when the strategy selection was a string input, and we need to fall back to a different
+ choice depending on other parameters or the environment."""
# current fallback and check logic only apply to user pass in str config and object config
# TODO this logic should apply to both str and object config
strategy_flag = "" if isinstance(self._strategy_flag, Strategy) else self._strategy_flag
diff --git a/src/lightning/pytorch/trainer/connectors/callback_connector.py b/src/lightning/pytorch/trainer/connectors/callback_connector.py
index d649755172658..bc4ae50164462 100644
--- a/src/lightning/pytorch/trainer/connectors/callback_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/callback_connector.py
@@ -161,6 +161,7 @@ def _attach_model_callbacks(self) -> None:
callbacks already present in the trainer callbacks list, it will replace them.
In addition, all :class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` callbacks
will be pushed to the end of the list, ensuring they run last.
+
"""
trainer = self.trainer
@@ -187,9 +188,9 @@ def _attach_model_callbacks(self) -> None:
@staticmethod
def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
- """Moves all the tuner specific callbacks at the beginning of the list and all the `ModelCheckpoint`
- callbacks to the end of the list. The sequential order within the group of checkpoint callbacks is
- preserved, as well as the order of all other callbacks.
+ """Moves all the tuner specific callbacks at the beginning of the list and all the `ModelCheckpoint` callbacks
+ to the end of the list. The sequential order within the group of checkpoint callbacks is preserved, as well as
+ the order of all other callbacks.
Args:
callbacks: A list of callbacks.
@@ -197,6 +198,7 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
Return:
A new list in which the first elements are tuner specific callbacks and last elements are ModelCheckpoints
if there were any present in the input.
+
"""
tuner_callbacks: List[Callback] = []
other_callbacks: List[Callback] = []
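Note (illustrative): _reorder_callbacks is a stable partition: tuner callbacks first, ModelCheckpoint callbacks last, relative order preserved within each group. A toy sketch (the classes below are local stand-ins, not the Lightning ones):

from typing import List

class Callback: ...
class BatchSizeFinder(Callback): ...   # stands in for the tuner-specific callbacks
class ModelCheckpoint(Callback): ...

def reorder(callbacks: List[Callback]) -> List[Callback]:
    # List comprehensions keep the original order, so the partition is stable.
    tuner = [c for c in callbacks if isinstance(c, BatchSizeFinder)]
    checkpoints = [c for c in callbacks if isinstance(c, ModelCheckpoint)]
    other = [c for c in callbacks if c not in tuner and c not in checkpoints]
    return tuner + other + checkpoints

cbs = [ModelCheckpoint(), Callback(), BatchSizeFinder(), ModelCheckpoint()]
print([type(c).__name__ for c in reorder(cbs)])
# ['BatchSizeFinder', 'Callback', 'ModelCheckpoint', 'ModelCheckpoint']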
diff --git a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py
index ed54a9b4ebb77..7b66048c46e17 100644
--- a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py
@@ -71,6 +71,7 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
2. from fault-tolerant auto-saved checkpoint if found
3. from `checkpoint_path` file if provided
4. don't restore
+
"""
self._ckpt_path = checkpoint_path
if not checkpoint_path:
@@ -209,8 +210,7 @@ def _parse_ckpt_path(
return ckpt_path
def resume_end(self) -> None:
- """Signal the connector that all states have resumed and memory for the checkpoint object can be
- released."""
+ """Signal the connector that all states have resumed and memory for the checkpoint object can be released."""
assert self.trainer.state.fn is not None
if self._ckpt_path:
message = "Restored all states" if self.trainer.state.fn == TrainerFn.FITTING else "Loaded model weights"
@@ -235,6 +235,7 @@ def restore(self, checkpoint_path: Optional[_PATH] = None) -> None:
Args:
checkpoint_path: Path to a PyTorch Lightning checkpoint file.
+
"""
self.resume_start(checkpoint_path)
@@ -266,6 +267,7 @@ def restore_model(self) -> None:
Hooks are called first to give the LightningModule a chance to modify the contents, then finally the model gets
updated with the loaded weights.
+
"""
if not self._loaded_checkpoint:
return
@@ -281,6 +283,7 @@ def restore_training_state(self) -> None:
"""Restore the trainer state from the pre-loaded checkpoint.
This includes the precision settings, loop progress, optimizer states and learning rate scheduler states.
+
"""
if not self._loaded_checkpoint:
return
@@ -320,6 +323,7 @@ def restore_loops(self) -> None:
"""Restores the loop progress from the pre-loaded checkpoint.
Calls hooks on the loops to give it a chance to restore its state from the checkpoint.
+
"""
if not self._loaded_checkpoint:
return
@@ -420,6 +424,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
something_cool_i_want_to_save: anything you define through model.on_save_checkpoint
LightningDataModule.__class__.__qualname__: pl DataModule's state
}
+
"""
trainer = self.trainer
model = trainer.lightning_module
@@ -507,6 +512,7 @@ def __max_ckpt_version_in_folder(dir_path: _PATH, name_key: str = "ckpt_") -> Op
name_key: file name prefix
Returns:
None if no-corresponding-file else maximum suffix number
+
"""
# check directory existence
fs, uri = url_to_fs(str(dir_path))
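Note (illustrative): __max_ckpt_version_in_folder scans a directory for files carrying a numeric suffix after the name_key prefix and returns the largest one, or None if no such file exists. A simplified local-filesystem sketch (the real implementation goes through fsspec):

import os
import re
from typing import Optional

def max_ckpt_version_in_folder(dir_path: str, name_key: str = "ckpt_") -> Optional[int]:
    # Example: a folder containing "ckpt_1.ckpt" and "ckpt_12.ckpt" yields 12.
    versions = []
    for name in os.listdir(dir_path):
        match = re.search(re.escape(name_key) + r"(\d+)", name)
        if match:
            versions.append(int(match.group(1)))
    return max(versions) if versions else None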
diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py
index 0b8726cbd11c6..424cfe7cf6d14 100644
--- a/src/lightning/pytorch/trainer/connectors/data_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -185,6 +185,7 @@ def _prepare_dataloader(self, dataloader: object, shuffle: bool, mode: RunningSt
- Injecting a `DistributedDataSamplerWrapper` into the `DataLoader` if on a distributed environment
- Wrapping the dataloader based on strategy-specific logic
+
"""
# don't do anything if it's not a dataloader
if not isinstance(dataloader, DataLoader):
@@ -289,6 +290,7 @@ class _DataLoaderSource:
instance: A LightningModule, LightningDataModule, or (a collection of) iterable(s).
name: A name for this dataloader source. If the instance is a module, the name corresponds to the hook
that returns the desired dataloader(s).
+
"""
instance: Optional[Union[TRAIN_DATALOADERS, EVAL_DATALOADERS, "pl.LightningModule", "pl.LightningDataModule"]]
@@ -298,6 +300,7 @@ def dataloader(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:
"""Returns the dataloader from the source.
If the source is a module, the method with the corresponding :attr:`name` gets called.
+
"""
if isinstance(self.instance, pl.LightningModule):
return call._call_lightning_module_hook(self.instance.trainer, self.name, pl_module=self.instance)
@@ -311,6 +314,7 @@ def is_defined(self) -> bool:
"""Returns whether the source dataloader can be retrieved or not.
If the source is a module it checks that the method with given :attr:`name` is overridden.
+
"""
return not self.is_module() or is_overridden(self.name, self.instance)
@@ -318,6 +322,7 @@ def is_module(self) -> bool:
"""Returns whether the DataLoader source is a LightningModule or a LightningDataModule.
It does not check whether ``*_dataloader`` methods are actually overridden.
+
"""
return isinstance(self.instance, (pl.LightningModule, pl.LightningDataModule))
@@ -327,6 +332,7 @@ def _request_dataloader(data_source: _DataLoaderSource) -> Union[TRAIN_DATALOADE
Returns:
The requested dataloader
+
"""
with _replace_dunder_methods(DataLoader, "dataset"), _replace_dunder_methods(BatchSampler):
# under this context manager, the arguments passed to `DataLoader.__init__` will be captured and saved as
diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py b/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py
index 773257f119f15..61dd62cf9a48c 100644
--- a/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py
@@ -86,6 +86,7 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None:
metrics: Metric values
step: Step for which metrics should be logged. Default value is `self.global_step` during training or
the total validation / test log step count during validation and testing.
+
"""
if not self.trainer.loggers or not metrics:
return
diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
index 06ce4d021d9e1..c74f34dbe1644 100644
--- a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
+++ b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
@@ -403,6 +403,7 @@ def register_key(self, key: str, meta: _Metadata, value: _VALUE) -> None:
"""Create one _ResultMetric object per value.
Value can be provided as a nested collection
+
"""
metric = _ResultMetric(meta, isinstance(value, Tensor)).to(value.device)
self[key] = metric
@@ -493,6 +494,7 @@ def reset(self, metrics: Optional[bool] = None, fx: Optional[str] = None) -> Non
if False, only ``torch.Tensors`` are reset,
if ``None``, both are.
fx: Function to reset
+
"""
for item in self.values():
requested_type = metrics is None or metrics ^ item.is_tensor
diff --git a/src/lightning/pytorch/trainer/connectors/signal_connector.py b/src/lightning/pytorch/trainer/connectors/signal_connector.py
index 7e6b7cd0c5e91..a5a0a5b69368e 100644
--- a/src/lightning/pytorch/trainer/connectors/signal_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/signal_connector.py
@@ -139,6 +139,7 @@ def _valid_signals() -> Set[signal.Signals]:
Behaves identically to :func:`signal.valid_signals` in Python 3.8+ and implements the equivalent behavior for
older Python versions.
+
"""
if _PYTHON_GREATER_EQUAL_3_8_0:
return signal.valid_signals()
diff --git a/src/lightning/pytorch/trainer/states.py b/src/lightning/pytorch/trainer/states.py
index 73b7cb71dcf82..d386538f416fb 100644
--- a/src/lightning/pytorch/trainer/states.py
+++ b/src/lightning/pytorch/trainer/states.py
@@ -53,6 +53,7 @@ class RunningStage(LightningEnum):
- ``TrainerFn.VALIDATING`` - ``RunningStage.VALIDATING``
- ``TrainerFn.TESTING`` - ``RunningStage.TESTING``
- ``TrainerFn.PREDICTING`` - ``RunningStage.PREDICTING``
+
"""
TRAINING = "train"
diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index ba668009a53ef..1dfc680c85f99 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -301,6 +301,7 @@ def __init__(
MisconfigurationException:
If ``gradient_clip_algorithm`` is invalid.
If ``track_grad_norm`` is not a positive number or inf.
+
"""
super().__init__()
log.debug(f"{self.__class__.__name__}: Initializing trainer with parameters: {locals()}")
@@ -533,6 +534,7 @@ def fit(
:class:`torch._dynamo.OptimizedModule` for torch versions greater than or equal to 2.0.0 .
For more information about multiple dataloaders, see this :ref:`section `.
+
"""
model = _maybe_unwrap_optimized(model)
self.strategy._lightning_module = model
@@ -626,6 +628,7 @@ def validate(
RuntimeError:
If a compiled ``model`` is passed and the strategy is not supported.
+
"""
if model is None:
# do we still have a reference from a previous call?
@@ -697,8 +700,8 @@ def test(
verbose: bool = True,
datamodule: Optional[LightningDataModule] = None,
) -> _EVALUATE_OUTPUT:
- r"""Perform one evaluation epoch over the test set. It's separated from fit to make sure you never run on
- your test set until you want to.
+ r"""Perform one evaluation epoch over the test set. It's separated from fit to make sure you never run on your
+ test set until you want to.
Args:
model: The model to test.
@@ -734,6 +737,7 @@ def test(
RuntimeError:
If a compiled ``model`` is passed and the strategy is not supported.
+
"""
if model is None:
# do we still have a reference from a previous call?
@@ -843,6 +847,7 @@ def predict(
If a compiled ``model`` is passed and the strategy is not supported.
See :ref:`Lightning inference section` for more.
+
"""
if model is None:
# do we still have a reference from a previous call?
@@ -1003,8 +1008,8 @@ def _run(
return results
def _teardown(self) -> None:
- """This is the Trainer's internal teardown, unrelated to the `teardown` hooks in LightningModule and
- Callback; those are handled by :meth:`_call_teardown_hook`."""
+ """This is the Trainer's internal teardown, unrelated to the `teardown` hooks in LightningModule and Callback;
+ those are handled by :meth:`_call_teardown_hook`."""
self.strategy.teardown()
loop = self._active_loop
# loop should never be `None` here but it can because we don't know the trainer stage with `ddp_spawn`
@@ -1075,8 +1080,8 @@ def __setup_profiler(self) -> None:
@contextmanager
def init_module(self, empty_init: Optional[bool] = None) -> Generator:
- """Tensors that you instantiate under this context manager will be created on the device right away and
- have the right data type depending on the precision setting in the Trainer.
+ """Tensors that you instantiate under this context manager will be created on the device right away and have
+ the right data type depending on the precision setting in the Trainer.
The parameters and tensors get created on the device and with the right data type right away without wasting
memory being allocated unnecessarily. The automatic device placement under this context manager is only
@@ -1086,6 +1091,7 @@ def init_module(self, empty_init: Optional[bool] = None) -> Generator:
empty_init: Whether to initialize the model with empty weights (uninitialized memory).
If ``None``, the strategy will decide. Some strategies may not support all options.
Set this to ``True`` if you are loading a checkpoint into a large model. Requires `torch >= 1.13`.
+
"""
if not _TORCH_GREATER_EQUAL_2_0 and self.strategy.root_device.type != "cpu":
rank_zero_warn(
@@ -1113,6 +1119,7 @@ def print(self, *args: Any, **kwargs: Any) -> None:
process in each machine.
Arguments passed to this method are forwarded to the Python built-in :func:`print` function.
+
"""
if self.local_rank == 0:
print(*args, **kwargs)
@@ -1210,6 +1217,7 @@ def model(self) -> Optional[torch.nn.Module]:
To access the pure LightningModule, use
:meth:`~lightning.pytorch.trainer.trainer.Trainer.lightning_module` instead.
+
"""
return self.strategy.model
@@ -1228,6 +1236,7 @@ def log_dir(self) -> Optional[str]:
def training_step(self, batch, batch_idx):
img = ...
save_img(img, self.trainer.log_dir)
+
"""
if len(self.loggers) > 0:
if not isinstance(self.loggers[0], TensorBoardLogger):
@@ -1249,6 +1258,7 @@ def is_global_zero(self) -> bool:
def training_step(self, batch, batch_idx):
if self.trainer.is_global_zero:
print("in node 0, accelerator 0")
+
"""
return self.strategy.is_global_zero
@@ -1272,6 +1282,7 @@ def default_root_dir(self) -> str:
"""The default location to save artifacts of loggers, checkpoints etc.
It is used as a fallback if logger or checkpoint callback do not define specific save paths.
+
"""
if get_filesystem(self._default_root_dir).protocol == "file":
return os.path.normpath(self._default_root_dir)
@@ -1286,8 +1297,8 @@ def early_stopping_callback(self) -> Optional[EarlyStopping]:
@property
def early_stopping_callbacks(self) -> List[EarlyStopping]:
- """A list of all instances of :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping` found in
- the Trainer.callbacks list."""
+ """A list of all instances of :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping` found in the
+ Trainer.callbacks list."""
return [c for c in self.callbacks if isinstance(c, EarlyStopping)]
@property
@@ -1299,8 +1310,8 @@ def checkpoint_callback(self) -> Optional[Checkpoint]:
@property
def checkpoint_callbacks(self) -> List[Checkpoint]:
- """A list of all instances of :class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` found
- in the Trainer.callbacks list."""
+ """A list of all instances of :class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` found in
+ the Trainer.callbacks list."""
return [c for c in self.callbacks if isinstance(c, Checkpoint)]
@property
@@ -1334,6 +1345,7 @@ def ckpt_path(self, ckpt_path: Optional[_PATH]) -> None:
# you will be in charge of resetting this
trainer.ckpt_path = None
trainer.test(model)
+
"""
self._checkpoint_connector._ckpt_path = ckpt_path
self._checkpoint_connector._user_managed = bool(ckpt_path)
@@ -1351,6 +1363,7 @@ def save_checkpoint(
Raises:
AttributeError:
If the model is not attached to the Trainer before calling this method.
+
"""
if self.model is None:
raise AttributeError(
@@ -1422,6 +1435,7 @@ def sanity_checking(self) -> bool:
"""Whether sanity checking is running.
Useful to disable some hooks, logging or callbacks during the sanity checking.
+
"""
return self.state.stage == RunningStage.SANITY_CHECKING
@@ -1449,6 +1463,7 @@ def global_step(self) -> int:
"""The number of optimizer steps taken (does not reset each epoch).
This includes multiple optimizers (if enabled).
+
"""
return self.fit_loop.epoch_loop.global_step
@@ -1588,6 +1603,7 @@ def loggers(self) -> List[Logger]:
for logger in trainer.loggers:
logger.log_metrics({"foo": 1.0})
+
"""
return self._loggers
@@ -1607,6 +1623,7 @@ def training_step(self, batch, batch_idx):
callback_metrics = trainer.callback_metrics
assert callback_metrics["a_val"] == 2.0
+
"""
return self._logger_connector.callback_metrics
@@ -1616,6 +1633,7 @@ def logged_metrics(self) -> _OUT_DICT:
This includes metrics logged via :meth:`~lightning.pytorch.core.module.LightningModule.log` with the
:paramref:`~lightning.pytorch.core.module.LightningModule.log.logger` argument set.
+
"""
return self._logger_connector.logged_metrics
@@ -1625,6 +1643,7 @@ def progress_bar_metrics(self) -> _PBAR_DICT:
This includes metrics logged via :meth:`~lightning.pytorch.core.module.LightningModule.log` with the
:paramref:`~lightning.pytorch.core.module.LightningModule.log.prog_bar` argument set.
+
"""
return self._logger_connector.progress_bar_metrics
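Note (illustrative): the init_module docstring above corresponds to the documented usage pattern sketched below, assuming a Lightning >= 2.0 environment; BigModel is a made-up example module, and the actual device/dtype of the created parameters depends on the Trainer's accelerator and precision settings:

import lightning.pytorch as pl
from torch import nn

class BigModel(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(4096, 4096)

trainer = pl.Trainer(accelerator="auto", devices=1)

# Parameters are created directly on the target device with the precision-appropriate
# dtype, instead of being allocated on CPU in float32 and moved afterwards.
with trainer.init_module():
    model = BigModel()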
diff --git a/src/lightning/pytorch/tuner/batch_size_scaling.py b/src/lightning/pytorch/tuner/batch_size_scaling.py
index 2b3ec7ef7f33a..e8ab5afbaa6e2 100644
--- a/src/lightning/pytorch/tuner/batch_size_scaling.py
+++ b/src/lightning/pytorch/tuner/batch_size_scaling.py
@@ -57,6 +57,7 @@ def _scale_batch_size(
- ``model``
- ``model.hparams``
- ``trainer.datamodule`` (the datamodule passed to the tune method)
+
"""
if trainer.fast_dev_run:
rank_zero_warn("Skipping batch size scaler since `fast_dev_run` is enabled.")
@@ -212,10 +213,10 @@ def _run_binary_scaling(
max_trials: int,
params: Dict[str, Any],
) -> int:
- """Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is
- encountered.
+ """Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is encountered.
Hereafter, the batch size is further refined using a binary search
+
"""
low = 1
high = None
@@ -289,6 +290,7 @@ def _adjust_batch_size(
Returns:
The new batch size for the next trial and a bool that signals whether the
new value is different than the previous batch size.
+
"""
model = trainer.lightning_module
batch_size = lightning_getattr(model, batch_arg_name)
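Note (illustrative): the doubling-then-binary-search behaviour described above, reduced to a standalone sketch; run_trial is a hypothetical probe standing in for "run a few steps and catch the OOM error":

from typing import Callable

def scale_batch_size(run_trial: Callable[[int], bool], init_val: int = 2, max_trials: int = 25) -> int:
    # Double the batch size until run_trial reports an OOM, then binary-search the gap.
    # run_trial(batch_size) returns True if the batch size fits in memory, False otherwise.
    low, high = 1, None
    size = init_val
    for _ in range(max_trials):
        if run_trial(size):
            low = size
            size = size * 2 if high is None else (low + high) // 2
        else:
            high = size
            size = (low + high) // 2
        if high is not None and high - low <= 1:
            break
    return low

# Toy probe: pretend anything above 48 samples per batch runs out of memory.
print(scale_batch_size(lambda bs: bs <= 48))  # 48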
diff --git a/src/lightning/pytorch/tuner/lr_finder.py b/src/lightning/pytorch/tuner/lr_finder.py
index f83c390d271c3..b7c3aba5f4f07 100644
--- a/src/lightning/pytorch/tuner/lr_finder.py
+++ b/src/lightning/pytorch/tuner/lr_finder.py
@@ -88,6 +88,7 @@ class _LRFinder:
# Get suggestion
lr = lr_finder.suggestion()
+
"""
def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int) -> None:
@@ -172,8 +173,8 @@ def plot(self, suggest: bool = False, show: bool = False, ax: Optional["Axes"] =
return fig
def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]:
- """This will propose a suggestion for an initial learning rate based on the point with the steepest
- negative gradient.
+ """This will propose a suggestion for an initial learning rate based on the point with the steepest negative
+ gradient.
Args:
skip_begin: how many samples to skip in the beginning; helps to avoid too naive estimates
@@ -182,6 +183,7 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]
Returns:
The suggested initial learning rate to use, or `None` if a suggestion is not possible due to too few
loss samples.
+
"""
losses = torch.tensor(self.results["loss"][skip_begin:-skip_end])
losses = losses[torch.isfinite(losses)]
@@ -215,8 +217,8 @@ def _lr_find(
update_attr: bool = False,
attr_name: str = "",
) -> Optional[_LRFinder]:
- """Enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in
- picking a good starting learning rate.
+ """Enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in picking
+ a good starting learning rate.
Args:
trainer: A Trainer instance.
@@ -235,6 +237,7 @@ def _lr_find(
update_attr: Whether to update the learning rate attribute or not.
attr_name: Name of the attribute which stores the learning rate. The names 'learning_rate' or 'lr' get
automatically detected. Otherwise, set the name here.
+
"""
if trainer.fast_dev_run:
rank_zero_warn("Skipping learning rate finder since `fast_dev_run` is enabled.")
@@ -342,8 +345,8 @@ def __lr_finder_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) ->
class _LRCallback(Callback):
- """Special callback used by the learning rate finder. This callback logs the learning rate before each batch
- and logs the corresponding loss after each batch.
+ """Special callback used by the learning rate finder. This callback logs the learning rate before each batch and
+ logs the corresponding loss after each batch.
Args:
num_training: number of iterations done by the learning rate finder
@@ -355,6 +358,7 @@ class _LRCallback(Callback):
beta: smoothing value, the loss being logged is a running average of
loss values logged until now. ``beta`` controls the forget rate i.e.
if ``beta=0`` all past information is ignored.
+
"""
def __init__(
@@ -443,6 +447,7 @@ class _LinearLR(_TORCH_LRSCHEDULER):
num_iter: the number of iterations over which the test occurs.
last_epoch: the index of last epoch. Default: -1.
+
"""
def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1):
@@ -478,6 +483,7 @@ class _ExponentialLR(_TORCH_LRSCHEDULER):
num_iter: the number of iterations over which the test occurs.
last_epoch: the index of last epoch. Default: -1.
+
"""
def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1):
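Note (illustrative): suggestion() picks the point of steepest descent on the loss-vs-learning-rate curve, skipping the noisy start and the divergent tail. A simplified sketch using a plain finite difference rather than the exact implementation; the demo data below is made up:

import torch

def suggest_lr(lrs, losses, skip_begin: int = 10, skip_end: int = 1):
    # Pick the learning rate where the finite-difference slope of the loss curve is most negative.
    loss = torch.tensor(losses[skip_begin:-skip_end], dtype=torch.float32)
    if loss.numel() < 2:
        return None  # not enough points to estimate a slope
    slope = loss[1:] - loss[:-1]
    idx = int(torch.argmin(slope)) + skip_begin
    return lrs[idx]

lrs = [10 ** (-6 + 0.05 * i) for i in range(100)]
losses = [1.0] * 20 + [1.0 - 0.02 * i for i in range(60)] + [5.0] * 20
print(f"{suggest_lr(lrs, losses):.2e}")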
diff --git a/src/lightning/pytorch/tuner/tuning.py b/src/lightning/pytorch/tuner/tuning.py
index 53b7b45210ef9..2fd03995497fd 100644
--- a/src/lightning/pytorch/tuner/tuning.py
+++ b/src/lightning/pytorch/tuner/tuning.py
@@ -39,8 +39,8 @@ def scale_batch_size(
max_trials: int = 25,
batch_arg_name: str = "batch_size",
) -> Optional[int]:
- """Iteratively try to find the largest batch size for a given model that does not give an out of memory
- (OOM) error.
+ """Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM)
+ error.
Args:
model: Model to tune.
@@ -71,6 +71,7 @@ def scale_batch_size(
- ``model``
- ``model.hparams``
- ``trainer.datamodule`` (the datamodule passed to the tune method)
+
"""
_check_tuner_configuration(train_dataloaders, val_dataloaders, dataloaders, method)
_check_scale_batch_size_configuration(self._trainer)
@@ -149,6 +150,7 @@ def lr_find(
MisconfigurationException:
If learning rate/lr in ``model`` or ``model.hparams`` isn't overridden,
or if you are using more than one optimizer.
+
"""
if method != "fit":
raise MisconfigurationException("method='fit' is an invalid configuration to run lr finder.")
diff --git a/src/lightning/pytorch/utilities/argparse.py b/src/lightning/pytorch/utilities/argparse.py
index 888a3b3755e1e..72d550ac108c7 100644
--- a/src/lightning/pytorch/utilities/argparse.py
+++ b/src/lightning/pytorch/utilities/argparse.py
@@ -38,6 +38,7 @@ def _parse_env_variables(cls: Type, template: str = "PL_%(cls_name)s_%(cls_argum
>>> _parse_env_variables(Trainer)
Namespace(devices=42)
>>> del os.environ["PL_TRAINER_DEVICES"]
+
"""
env_args = {}
for arg_name in inspect.signature(cls).parameters:
diff --git a/src/lightning/pytorch/utilities/combined_loader.py b/src/lightning/pytorch/utilities/combined_loader.py
index 0e012dbae145b..d7194b919638f 100644
--- a/src/lightning/pytorch/utilities/combined_loader.py
+++ b/src/lightning/pytorch/utilities/combined_loader.py
@@ -238,6 +238,7 @@ class CombinedLoader(Iterable):
tensor([0, 1, 2, 3, 4]) batch_idx=0 dataloader_idx=1
tensor([5, 6, 7, 8, 9]) batch_idx=1 dataloader_idx=1
tensor([10, 11, 12, 13, 14]) batch_idx=2 dataloader_idx=1
+
"""
def __init__(self, iterables: Any, mode: _LITERAL_SUPPORTED_MODES = "min_size") -> None:
diff --git a/src/lightning/pytorch/utilities/compile.py b/src/lightning/pytorch/utilities/compile.py
index ba9cd9c93f4c8..ea2dc146bfb36 100644
--- a/src/lightning/pytorch/utilities/compile.py
+++ b/src/lightning/pytorch/utilities/compile.py
@@ -79,6 +79,7 @@ def to_uncompiled(model: Union["pl.LightningModule", "torch._dynamo.OptimizedMod
returned by ``from_compiled``.
Note: this method will in-place modify the ``LightningModule`` that is passed in.
+
"""
if not _TORCH_GREATER_EQUAL_2_0:
raise ModuleNotFoundError("`to_uncompiled` requires torch>=2.0")
diff --git a/src/lightning/pytorch/utilities/data.py b/src/lightning/pytorch/utilities/data.py
index ce69cdc4d664a..a42e3053a27d0 100644
--- a/src/lightning/pytorch/utilities/data.py
+++ b/src/lightning/pytorch/utilities/data.py
@@ -234,14 +234,15 @@ def _dataloader_init_kwargs_resolve_sampler(
mode: Optional[RunningStage] = None,
disallow_batch_sampler: bool = False,
) -> Dict[str, Any]:
- """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its
- re-instantiation.
+ """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its re-
+ instantiation.
If the dataloader is being used for prediction, the sampler will be wrapped into an `_IndexBatchSamplerWrapper`, so
Lightning can keep track of its indices.
If there are multiple devices in IPU mode, it is necessary to disallow BatchSampler that isn't instantiated
automatically, since `poptorch.DataLoader` will try to increase the batch_size
+
"""
is_predicting = mode == RunningStage.PREDICTING
batch_sampler = getattr(dataloader, "batch_sampler")
diff --git a/src/lightning/pytorch/utilities/exceptions.py b/src/lightning/pytorch/utilities/exceptions.py
index cb06cc7572647..b1ca189f69d82 100644
--- a/src/lightning/pytorch/utilities/exceptions.py
+++ b/src/lightning/pytorch/utilities/exceptions.py
@@ -25,6 +25,7 @@ class SIGTERMException(SystemExit):
For example, you could use the :class:`lightning.pytorch.callbacks.fault_tolerance.OnExceptionCheckpoint` callback
that saves a checkpoint for you when this exception is raised.
+
"""
diff --git a/src/lightning/pytorch/utilities/grads.py b/src/lightning/pytorch/utilities/grads.py
index c6f0d062df209..21f737f8b69d2 100644
--- a/src/lightning/pytorch/utilities/grads.py
+++ b/src/lightning/pytorch/utilities/grads.py
@@ -35,6 +35,7 @@ def grad_norm(module: Module, norm_type: Union[float, int, str], group_separator
norms: The dictionary of p-norms of each parameter's gradient and
a special entry for the total p-norm of the gradients viewed
as a single vector.
+
"""
norm_type = float(norm_type)
if norm_type <= 0:
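
A common place to call `grad_norm` is the `on_before_optimizer_step` hook, logging the dictionary it returns. The sketch below reflects typical usage rather than anything in this patch; the key names shown in the comment are approximate.

    import torch
    import lightning.pytorch as pl

    from lightning.pytorch.utilities.grads import grad_norm


    class LitModel(pl.LightningModule):
        # training_step / configure_optimizers omitted for brevity
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def on_before_optimizer_step(self, optimizer):
            # e.g. {"grad_2.0_norm/layer.weight": ..., "grad_2.0_norm_total": ...}
            norms = grad_norm(self, norm_type=2)
            self.log_dict(norms)
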
diff --git a/src/lightning/pytorch/utilities/memory.py b/src/lightning/pytorch/utilities/memory.py
index 0922b63e0bb5b..698de64d48add 100644
--- a/src/lightning/pytorch/utilities/memory.py
+++ b/src/lightning/pytorch/utilities/memory.py
@@ -34,6 +34,7 @@ def recursive_detach(in_dict: Any, to_cpu: bool = False) -> Any:
Return:
out_dict: Dictionary with detached tensors
+
"""
def detach_and_move(t: Tensor, to_cpu: bool) -> Tensor:
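
For illustration, `recursive_detach` applied to a nested structure of tensors behaves roughly as below; the dictionary layout is arbitrary.

    import torch

    from lightning.pytorch.utilities.memory import recursive_detach

    x = torch.randn(2, 2, requires_grad=True)
    outputs = {"loss": (x * 2).sum(), "nested": {"activations": x * 3}}

    detached = recursive_detach(outputs, to_cpu=True)
    assert detached["loss"].requires_grad is False
    assert detached["nested"]["activations"].device.type == "cpu"
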
diff --git a/src/lightning/pytorch/utilities/migration/migration.py b/src/lightning/pytorch/utilities/migration/migration.py
index 40803650e5c42..f935a1061c81f 100644
--- a/src/lightning/pytorch/utilities/migration/migration.py
+++ b/src/lightning/pytorch/utilities/migration/migration.py
@@ -27,6 +27,7 @@
cp model.ckpt model.ckpt.backup
python -m lightning.pytorch.utilities.upgrade_checkpoint model.ckpt
+
"""
import re
from typing import Any, Callable, Dict, List
@@ -60,6 +61,7 @@ def _migrate_model_checkpoint_early_stopping(checkpoint: _CHECKPOINT) -> _CHECKP
Version: 0.10.0
Commit: a5d1176
+
"""
keys_mapping = {
"checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"),
@@ -87,6 +89,7 @@ def _migrate_loop_global_step_to_progress_tracking(checkpoint: _CHECKPOINT) -> _
Version: 1.6.0
Commit: c67b075
PR: #13645, #11805
+
"""
global_step = checkpoint["global_step"]
checkpoint.setdefault("loops", {"fit_loop": _get_fit_loop_initial_state_1_6_0()})
@@ -107,6 +110,7 @@ def _migrate_loop_current_epoch_to_progress_tracking(checkpoint: _CHECKPOINT) ->
Version: 1.6.0
Commit: aea96e4
PR: #11805
+
"""
epoch = checkpoint["epoch"]
checkpoint.setdefault("loops", {"fit_loop": _get_fit_loop_initial_state_1_6_0()})
@@ -121,6 +125,7 @@ def _migrate_loop_batches_that_stepped(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
Version: 1.6.5
Commit: c67b075
PR: #13645
+
"""
global_step = checkpoint["global_step"]
checkpoint["loops"]["fit_loop"]["epoch_loop.state_dict"].setdefault("_batches_that_stepped", global_step)
@@ -218,6 +223,7 @@ def _drop_apex_amp_state(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
Version: 2.0.0
Commit: e544676ff434ed96c6dd3b4e73a708bcb27ebcf1
PR: #16149
+
"""
key = "amp_scaling_state"
if key in checkpoint:
@@ -234,6 +240,7 @@ def _migrate_loop_structure_after_tbptt_removal(checkpoint: _CHECKPOINT) -> _CHE
Version: 2.0.0
Commit: 7807454
PR: #16337, #16172
+
"""
if "loops" not in checkpoint:
return checkpoint
@@ -265,13 +272,13 @@ def _migrate_loop_structure_after_tbptt_removal(checkpoint: _CHECKPOINT) -> _CHE
def _migrate_loop_structure_after_optimizer_loop_removal(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
- """Adjusts the loop structure since it changed when the support for multiple optimizers in automatic
- optimization mode was removed. There is no longer a loop over optimizer, and hence no position to store for
- resuming the loop.
+ """Adjusts the loop structure since it changed when the support for multiple optimizers in automatic optimization
+ mode was removed. There is no longer a loop over optimizers, and hence no position to store for resuming the loop.
Version: 2.0.0
Commit: 6a56586
PR: #16539, #16598
+
"""
if "loops" not in checkpoint:
return checkpoint
diff --git a/src/lightning/pytorch/utilities/migration/utils.py b/src/lightning/pytorch/utilities/migration/utils.py
index 49ae9132630e9..3929ea4a47cfd 100644
--- a/src/lightning/pytorch/utilities/migration/utils.py
+++ b/src/lightning/pytorch/utilities/migration/utils.py
@@ -48,6 +48,7 @@ def migrate_checkpoint(
Note:
The migration happens in-place. We specifically avoid copying the dict to avoid memory spikes for large
checkpoints and objects that do not support being deep-copied.
+
"""
ckpt_version = _get_version(checkpoint)
if Version(ckpt_version) > Version(pl.__version__):
@@ -91,6 +92,7 @@ class pl_legacy_patch:
with pl_legacy_patch():
torch.load("path/to/legacy/checkpoint.ckpt")
+
"""
def __enter__(self) -> "pl_legacy_patch":
@@ -135,6 +137,7 @@ def _pl_migrate_checkpoint(checkpoint: _CHECKPOINT, checkpoint_path: Optional[_P
"""Applies Lightning version migrations to a checkpoint dictionary and prints infos for the user.
This function is used by the Lightning Trainer when resuming from a checkpoint.
+
"""
old_version = _get_version(checkpoint)
checkpoint, migrations = migrate_checkpoint(checkpoint)
@@ -182,6 +185,7 @@ class _RedirectingUnpickler(pickle._Unpickler):
In legacy versions of Lightning, callback classes got pickled into the checkpoint. These classes are defined in the
`pytorch_lightning` package but need to be loaded from `lightning.pytorch`.
+
"""
def find_class(self, module: str, name: str) -> Any:
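
Combining the utilities in this file, upgrading a legacy checkpoint outside the Trainer could look like the sketch below; the checkpoint path is a placeholder and the top-level import of these names from `lightning.pytorch.utilities.migration` is an assumption.

    import torch

    from lightning.pytorch.utilities.migration import migrate_checkpoint, pl_legacy_patch

    ckpt_path = "path/to/legacy/checkpoint.ckpt"  # placeholder

    # redirect legacy `pytorch_lightning` pickles while loading
    with pl_legacy_patch():
        checkpoint = torch.load(ckpt_path)

    # upgrade the dict in place; `migrations` lists the migration functions that ran
    checkpoint, migrations = migrate_checkpoint(checkpoint)
    torch.save(checkpoint, ckpt_path)
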
diff --git a/src/lightning/pytorch/utilities/model_summary/model_summary.py b/src/lightning/pytorch/utilities/model_summary/model_summary.py
index 4476ac5b25fab..8c50181f3b086 100644
--- a/src/lightning/pytorch/utilities/model_summary/model_summary.py
+++ b/src/lightning/pytorch/utilities/model_summary/model_summary.py
@@ -38,8 +38,8 @@
class LayerSummary:
- """Summary class for a single layer in a :class:`~lightning.pytorch.core.module.LightningModule`. It collects
- the following information:
+ """Summary class for a single layer in a :class:`~lightning.pytorch.core.module.LightningModule`. It collects the
+ following information:
- Type of the layer (e.g. Linear, BatchNorm1d, ...)
- Input shape
@@ -65,6 +65,7 @@ class LayerSummary:
Args:
module: A module to summarize
+
"""
def __init__(self, module: nn.Module) -> None:
@@ -78,13 +79,13 @@ def __del__(self) -> None:
self.detach_hook()
def _register_hook(self) -> Optional[RemovableHandle]:
- """Registers a hook on the module that computes the input- and output size(s) on the first forward pass. If
- the hook is called, it will remove itself from the from the module, meaning that recursive models will only
- record their input- and output shapes once. Registering hooks on :class:`~torch.jit.ScriptModule` is not
- supported.
+ """Registers a hook on the module that computes the input- and output size(s) on the first forward pass. If the
+ hook is called, it will remove itself from the module, meaning that recursive models will only record
+ their input- and output shapes once. Registering hooks on :class:`~torch.jit.ScriptModule` is not supported.
Return:
A handle for the installed hook, or ``None`` if registering the hook is not possible.
+
"""
def hook(_: nn.Module, inp: Any, out: Any) -> None:
@@ -116,6 +117,7 @@ def detach_hook(self) -> None:
"""Removes the forward hook if it was not already removed in the forward pass.
Will be called after the summary is created.
+
"""
if self._hook_handle is not None:
self._hook_handle.remove()
@@ -194,6 +196,7 @@ class ModelSummary:
0 Non-trainable params
132 K Total params
0.530 Total estimated model params size (MB)
+
"""
def __init__(self, model: "pl.LightningModule", max_depth: int = 1) -> None:
@@ -303,6 +306,7 @@ def _get_summary_data(self) -> List[Tuple[str, List[str]]]:
"""Makes a summary listing with:
Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size
+
"""
arrays = [
(" ", list(map(str, range(len(self._layer_summary))))),
@@ -361,8 +365,8 @@ def _format_summary_table(
model_size: float,
*cols: Tuple[str, List[str]],
) -> str:
- """Takes in a number of arrays, each specifying a column in the summary table, and combines them all into one
- big string defining the summary table that are nicely formatted."""
+ """Takes in a number of arrays, each specifying a column in the summary table, and combines them all into one big
+ nicely formatted string defining the summary table."""
n_rows = len(cols[0][1])
n_cols = 1 + len(cols)
@@ -425,6 +429,7 @@ def get_human_readable_count(number: int) -> str:
Return:
A string formatted according to the pattern described above.
+
"""
assert number >= 0
labels = PARAMETER_NUM_UNITS
@@ -463,5 +468,6 @@ def summarize(lightning_module: "pl.LightningModule", max_depth: int = 1) -> Mod
Return:
The model summary object
+
"""
return ModelSummary(lightning_module, max_depth=max_depth)
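
For reference, the summary utilities patched here can also be used directly, without a Trainer; the demo `BoringModel` below is just a stand-in for a real LightningModule.

    from lightning.pytorch.demos.boring_classes import BoringModel
    from lightning.pytorch.utilities.model_summary import ModelSummary, summarize

    model = BoringModel()

    # functional form, equivalent to ModelSummary(model, max_depth=2)
    summary = summarize(model, max_depth=2)
    print(summary)  # renders the table assembled by _format_summary_table
    print(summary.total_parameters)
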
diff --git a/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py b/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py
index fc84fbeb54b89..b9a4993941a40 100644
--- a/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py
+++ b/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py
@@ -85,6 +85,7 @@ def _get_summary_data(self) -> List[Tuple[str, List[str]]]:
"""Makes a summary listing with:
Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size
+
"""
arrays = [
(" ", list(map(str, range(len(self._layer_summary))))),
diff --git a/src/lightning/pytorch/utilities/parameter_tying.py b/src/lightning/pytorch/utilities/parameter_tying.py
index 9b12b456db451..5f5ea505dc05a 100644
--- a/src/lightning/pytorch/utilities/parameter_tying.py
+++ b/src/lightning/pytorch/utilities/parameter_tying.py
@@ -15,6 +15,7 @@
Reference:
https://github.com/pytorch/fairseq/blob/1f7ef9ed1e1061f8c7f88f8b94c7186834398690/fairseq/trainer.py#L110-L118
+
"""
from typing import Dict, List, Optional
diff --git a/src/lightning/pytorch/utilities/parsing.py b/src/lightning/pytorch/utilities/parsing.py
index 9958bc9f249a7..7ed77af18930a 100644
--- a/src/lightning/pytorch/utilities/parsing.py
+++ b/src/lightning/pytorch/utilities/parsing.py
@@ -120,6 +120,7 @@ def collect_init_args(
A list of dictionaries where each dictionary contains the arguments passed to the
constructor at that level. The last entry corresponds to the constructor call of the
most specific class in the hierarchy.
+
"""
_, _, _, local_vars = inspect.getargvalues(frame)
# frame.f_back must be of a type types.FrameType for get_init_args/collect_init_args due to mypy
@@ -216,6 +217,7 @@ class AttributeDict(Dict):
"key2": abc
"my-key": 3.14
"new_key": 42
+
"""
def __getattr__(self, key: str) -> Optional[Any]:
@@ -241,6 +243,7 @@ def _lightning_get_all_attr_holders(model: "pl.LightningModule", attribute: str)
Gets all of the objects or dicts that hold the attribute. Checks for the attribute in the model namespace, the old hparams
namespace/dict, and the datamodule.
+
"""
holders: List[Any] = []
@@ -269,6 +272,7 @@ def _lightning_get_first_attr_holder(model: "pl.LightningModule", attribute: str
Gets the object or dict that holds the attribute, or None. Checks for the attribute in the model namespace, the old hparams
namespace/dict, and the datamodule, returns the last one that has it.
+
"""
holders = _lightning_get_all_attr_holders(model, attribute)
if len(holders) == 0:
@@ -281,18 +285,20 @@ def lightning_hasattr(model: "pl.LightningModule", attribute: str) -> bool:
"""Special hasattr for Lightning.
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule.
+
"""
return _lightning_get_first_attr_holder(model, attribute) is not None
def lightning_getattr(model: "pl.LightningModule", attribute: str) -> Optional[Any]:
- """Special getattr for Lightning. Checks for attribute in model namespace, the old hparams namespace/dict, and
- the datamodule.
+ """Special getattr for Lightning. Checks for attribute in model namespace, the old hparams namespace/dict, and the
+ datamodule.
Raises:
AttributeError:
If ``model`` doesn't have ``attribute`` in any of
model namespace, the hparams namespace/dict, and the datamodule.
+
"""
holder = _lightning_get_first_attr_holder(model, attribute)
if holder is None:
@@ -307,13 +313,14 @@ def lightning_getattr(model: "pl.LightningModule", attribute: str) -> Optional[A
def lightning_setattr(model: "pl.LightningModule", attribute: str, value: Any) -> None:
- """Special setattr for Lightning. Checks for attribute in model namespace and the old hparams namespace/dict.
- Will also set the attribute on datamodule, if it exists.
+ """Special setattr for Lightning. Checks for attribute in model namespace and the old hparams namespace/dict. Will
+ also set the attribute on datamodule, if it exists.
Raises:
AttributeError:
If ``model`` doesn't have ``attribute`` in any of
model namespace, the hparams namespace/dict, and the datamodule.
+
"""
holders = _lightning_get_all_attr_holders(model, attribute)
if len(holders) == 0:
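
The three helpers above resolve an attribute across the model, its `hparams`, and the attached datamodule. A short sketch of that lookup, assuming a module that stores `batch_size` via `save_hyperparameters()`:

    import lightning.pytorch as pl

    from lightning.pytorch.utilities.parsing import (
        lightning_getattr,
        lightning_hasattr,
        lightning_setattr,
    )


    class LitModel(pl.LightningModule):
        def __init__(self, batch_size: int = 16):
            super().__init__()
            self.save_hyperparameters()  # stores `batch_size` in self.hparams


    model = LitModel()
    assert lightning_hasattr(model, "batch_size")      # found via hparams
    assert lightning_getattr(model, "batch_size") == 16
    lightning_setattr(model, "batch_size", 64)         # updates every holder
    assert model.hparams.batch_size == 64
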
diff --git a/src/lightning/pytorch/utilities/seed.py b/src/lightning/pytorch/utilities/seed.py
index 50c88dbad6936..10badab69c32f 100644
--- a/src/lightning/pytorch/utilities/seed.py
+++ b/src/lightning/pytorch/utilities/seed.py
@@ -38,6 +38,7 @@ def isolate_rng(include_cuda: bool = True) -> Generator[None, None, None]:
[tensor([0.7576]), tensor([0.2793]), tensor([0.4031])]
>>> torch.rand(1)
tensor([0.7576])
+
"""
states = _collect_rng_states(include_cuda)
yield
diff --git a/src/lightning/pytorch/utilities/testing/_runif.py b/src/lightning/pytorch/utilities/testing/_runif.py
index 732bc26cf5e8e..3c67260a88bed 100644
--- a/src/lightning/pytorch/utilities/testing/_runif.py
+++ b/src/lightning/pytorch/utilities/testing/_runif.py
@@ -65,6 +65,7 @@ def _runif_reasons(
psutil: Require that psutil is installed.
sklearn: Require that scikit-learn is installed.
onnx: Require that onnx is installed.
+
"""
reasons, kwargs = FabricRunIf(
diff --git a/src/lightning/store/store.py b/src/lightning/store/store.py
index ede05d5400b8a..f2389f8b8aa63 100644
--- a/src/lightning/store/store.py
+++ b/src/lightning/store/store.py
@@ -37,6 +37,7 @@ def upload_model(
The version of the model to be uploaded. If not provided, default will be latest (not overridden).
progress_bar:
A progress bar to show the uploading status. Disable this if not needed, by setting to `False`.
+
"""
client = _Client()
user = client.auth_service_get_user()
@@ -71,6 +72,7 @@ def download_model(
The version of the model to be downloaded. If not provided, default will be latest (not overridden).
progress_bar:
Show progress on download.
+
"""
client = _Client()
download_url = client.models_store_download_model(name=name, version=version).download_url
@@ -82,6 +84,7 @@ def list_models() -> List[V1Model]:
Returns:
A list of model objects.
+
"""
client = _Client()
# TODO: Allow passing this
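
The store helpers in this hunk are typically driven as in the sketch below. Treat it as an assumption: only the parameters visible in this diff (`name`, `version`, `progress_bar`) are used, the model name is a placeholder, and the top-level `lightning.store` imports are presumed to re-export these functions.

    from lightning.store import download_model, list_models

    # list the models available to the authenticated user (V1Model objects)
    for model in list_models():
        print(model.name)  # assumes V1Model exposes a `name` field

    # fetch a model by name; version defaults to the latest one
    download_model(name="org/placeholder-model", progress_bar=True)
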
diff --git a/tests/integrations_app/flagship/test_flashy.py b/tests/integrations_app/flagship/test_flashy.py
index c40cb155c2564..a69ed1fc5d61d 100644
--- a/tests/integrations_app/flagship/test_flashy.py
+++ b/tests/integrations_app/flagship/test_flashy.py
@@ -20,6 +20,7 @@ def validate_app_functionalities(app_page: "Page") -> None:
https://github.com/Lightning-AI/LAI-Flashy-App/blob/main/tests/test_app_gallery.py#L205
app_page: The UI page of the app to be validated.
+
"""
while True:
with contextlib.suppress(playwright._impl._api_types.Error, playwright._impl._api_types.TimeoutError):
diff --git a/tests/parity_pytorch/test_sync_batchnorm_parity.py b/tests/parity_pytorch/test_sync_batchnorm_parity.py
index 11aca5651055c..7a0c0658e3fc6 100644
--- a/tests/parity_pytorch/test_sync_batchnorm_parity.py
+++ b/tests/parity_pytorch/test_sync_batchnorm_parity.py
@@ -52,8 +52,8 @@ def train_dataloader(self):
@RunIf(min_cuda_gpus=2, standalone=True)
def test_sync_batchnorm_parity(tmpdir):
- """Test parity between 1) Training a synced batch-norm layer on 2 GPUs with batch size B per device 2) Training
- a batch-norm layer on CPU with twice the batch size."""
+ """Test parity between 1) Training a synced batch-norm layer on 2 GPUs with batch size B per device 2) Training a
+ batch-norm layer on CPU with twice the batch size."""
seed_everything(3)
# 2 GPUS, batch size = 4 per GPU => total batch size = 8
model = SyncBNModule(batch_size=4)
diff --git a/tests/tests_app/cli/test_cmd_launch.py b/tests/tests_app/cli/test_cmd_launch.py
index b1fdf89ac9606..9610c463c9b3c 100644
--- a/tests/tests_app/cli/test_cmd_launch.py
+++ b/tests/tests_app/cli/test_cmd_launch.py
@@ -27,6 +27,7 @@ def test_run_frontend(monkeypatch):
dispatcher.
This CLI call is made by Lightning AI and is not meant to be invoked by the user directly.
+
"""
runner = CliRunner()
diff --git a/tests/tests_app/cli/test_run_app.py b/tests/tests_app/cli/test_run_app.py
index c3ca5c1ac0955..a457b7b03e541 100644
--- a/tests/tests_app/cli/test_run_app.py
+++ b/tests/tests_app/cli/test_run_app.py
@@ -68,6 +68,7 @@ def test_lightning_run_app_cloud(mock_dispatch: mock.MagicMock, open_ui, caplog,
It does so by checking that click.launch is called with the right URL when --open-ui is true, and also checks the
call to `dispatch` for the right arguments.
+
"""
monkeypatch.setattr("lightning.app.runners.cloud.logger", logging.getLogger())
@@ -116,6 +117,7 @@ def test_lightning_run_app_cloud_with_run_app_commands(mock_dispatch: mock.Magic
It does so by checking that click.launch is called with the right URL when --open-ui is true, and also checks the
call to `dispatch` for the right arguments.
+
"""
monkeypatch.setattr("lightning.app.runners.cloud.logger", logging.getLogger())
@@ -182,6 +184,7 @@ def test_lightning_run_app_enable_basic_auth_passed(mock_dispatch: mock.MagicMoc
"""This test just validates the command has ran properly when --enable-basic-auth argument is passed.
It checks the call to `dispatch` for the right arguments.
+
"""
monkeypatch.setattr("lightning.app.runners.cloud.logger", logging.getLogger())
diff --git a/tests/tests_app/components/python/test_python.py b/tests/tests_app/components/python/test_python.py
index 1475ec57dc61c..14cfb9b9c409d 100644
--- a/tests/tests_app/components/python/test_python.py
+++ b/tests/tests_app/components/python/test_python.py
@@ -77,8 +77,8 @@ def test_tracer_python_script_with_kwargs():
def test_tracer_component_with_code():
- """This test ensures the Tracer Component gets the latest code from the code object that is provided and
- arguments are cleaned."""
+ """This test ensures the Tracer Component gets the latest code from the code object that is provided and arguments
+ are cleaned."""
drive = Drive("lit://code")
drive.component_name = "something"
@@ -125,8 +125,8 @@ def test_tracer_component_with_code():
def test_tracer_component_with_code_in_dir(tmp_path):
- """This test ensures the Tracer Component gets the latest code from the code object that is provided and
- arguments are cleaned."""
+ """This test ensures the Tracer Component gets the latest code from the code object that is provided and arguments
+ are cleaned."""
drive = Drive("lit://code")
drive.component_name = "something"
diff --git a/tests/tests_app/conftest.py b/tests/tests_app/conftest.py
index 1a03b6e356bfb..01e7c1cb17ff0 100644
--- a/tests/tests_app/conftest.py
+++ b/tests/tests_app/conftest.py
@@ -95,6 +95,7 @@ def caplog(caplog):
"""Workaround for https://github.com/pytest-dev/pytest/issues/3697.
Setting ``filterwarnings`` with pytest breaks ``caplog`` when ``not logger.propagate``.
+
"""
import logging
@@ -119,14 +120,15 @@ def caplog(caplog):
@pytest.fixture()
def patch_constants(request):
- """This fixture can be used with indirect parametrization to patch values in `lightning.app.core.constants` for
- the duration of a test.
+ """This fixture can be used with indirect parametrization to patch values in `lightning.app.core.constants` for the
+ duration of a test.
Example::
@pytest.mark.parametrize("patch_constants", [{"LIGHTNING_CLOUDSPACE_HOST": "any"}], indirect=True)
def test_my_stuff(patch_constants):
...
+
"""
# Set constants
old_constants = {}
diff --git a/tests/tests_app/core/test_lightning_api.py b/tests/tests_app/core/test_lightning_api.py
index b35d06477238b..367e822b9de08 100644
--- a/tests/tests_app/core/test_lightning_api.py
+++ b/tests/tests_app/core/test_lightning_api.py
@@ -181,8 +181,8 @@ def maybe_apply_changes(self):
# FIXME: This test doesn't assert anything
@pytest.mark.skip(reason="TODO: Resolve flaky test.")
def test_app_stage_from_frontend():
- """This test validates that delta from the `api_delta_queue` manipulating the ['app_state']['stage'] would
- start and stop the app."""
+ """This test validates that delta from the `api_delta_queue` manipulating the ['app_state']['stage'] would start
+ and stop the app."""
app = AppStageTestingApp(FlowA(), log_level="debug")
app.stage = AppStage.BLOCKING
MultiProcessRuntime(app, start_server=True).dispatch()
@@ -193,6 +193,7 @@ def test_update_publish_state_and_maybe_refresh_ui():
- receives the state from the `publish_state_queue` and populates the app_state_store
- receives a notification to refresh the UI and makes a GET Request (streamlit).
+
"""
app = AppStageTestingApp(FlowA(), log_level="debug")
publish_state_queue = _MockQueue("publish_state_queue")
@@ -215,6 +216,7 @@ async def test_start_server(x_lightning_type, monkeypatch):
- the state on GET /api/v1/state
- push a delta when making a POST request to /api/v1/state
+
"""
class InfiniteQueue(_MockQueue):
diff --git a/tests/tests_app/core/test_lightning_app.py b/tests/tests_app/core/test_lightning_app.py
index b42103fcb7439..aa0c97f65fa9f 100644
--- a/tests/tests_app/core/test_lightning_app.py
+++ b/tests/tests_app/core/test_lightning_app.py
@@ -1020,8 +1020,8 @@ def test_non_updated_flow(caplog):
def test_debug_mode_logging():
- """This test validates the DEBUG messages are collected when activated by the LightningApp(debug=True) and
- cleanup once finished."""
+ """This test validates the DEBUG messages are collected when activated by the LightningApp(debug=True) and cleanup
+ once finished."""
from lightning.app.core.app import _console
diff --git a/tests/tests_app/core/test_lightning_flow.py b/tests/tests_app/core/test_lightning_flow.py
index 36388de3de1c9..dd89bb739f21e 100644
--- a/tests/tests_app/core/test_lightning_flow.py
+++ b/tests/tests_app/core/test_lightning_flow.py
@@ -77,8 +77,7 @@ def run(self):
],
)
def test_unsupported_attribute_declaration_outside_init_or_run(name, value):
- """Test that LightningFlow attributes (with a few exceptions) are not allowed to be declared outside
- __init__."""
+ """Test that LightningFlow attributes (with a few exceptions) are not allowed to be declared outside __init__."""
flow = EmptyFlow()
with pytest.raises(AttributeError, match=f"Cannot set attributes that were not defined in __init__: {name}"):
setattr(flow, name, value)
@@ -102,8 +101,8 @@ def test_unsupported_attribute_declaration_outside_init_or_run(name, value):
)
@pytest.mark.parametrize("defined", [False, True])
def test_unsupported_attribute_declaration_inside_run(defined, name, value):
- """Test that LightningFlow attributes can set LightningFlow or LightningWork inside its run method, but
- everything else needs to be defined in the __init__ method."""
+ """Test that LightningFlow attributes can set LightningFlow or LightningWork inside its run method, but everything
+ else needs to be defined in the __init__ method."""
class Flow(LightningFlow):
def __init__(self):
@@ -163,8 +162,8 @@ def run(self):
],
)
def test_supported_attribute_declaration_outside_init(name, value):
- """Test the custom LightningFlow setattr implementation for the few reserved attributes that are allowed to be
- set from outside __init__."""
+ """Test the custom LightningFlow setattr implementation for the few reserved attributes that are allowed to be set
+ from outside __init__."""
flow = EmptyFlow()
setattr(flow, name, value)
assert getattr(flow, name) == value
diff --git a/tests/tests_app/core/test_lightning_work.py b/tests/tests_app/core/test_lightning_work.py
index 3b7437f95ba0a..d7870d3e09164 100644
--- a/tests/tests_app/core/test_lightning_work.py
+++ b/tests/tests_app/core/test_lightning_work.py
@@ -70,8 +70,7 @@ def run(self, *args, **kwargs):
def test_forgot_to_call_init():
- """This test validates the error message for user registering state without calling __init__ is
- comprehensible."""
+ """This test validates the error message for user registering state without calling __init__ is comprehensible."""
class W(LightningWork):
def __init__(self):
@@ -110,8 +109,8 @@ def test_unsupported_attribute_declaration_outside_init(name, value):
],
)
def test_supported_attribute_declaration_outside_init(name, value):
- """Test the custom LightningWork setattr implementation for the few reserved attributes that are allowed to be
- set from outside __init__."""
+ """Test the custom LightningWork setattr implementation for the few reserved attributes that are allowed to be set
+ from outside __init__."""
flow = EmptyWork()
setattr(flow, name, value)
assert getattr(flow, name) == value
diff --git a/tests/tests_app/core/test_queues.py b/tests/tests_app/core/test_queues.py
index 8dd6d7d3a0b32..40abe438a3fb7 100644
--- a/tests/tests_app/core/test_queues.py
+++ b/tests/tests_app/core/test_queues.py
@@ -21,6 +21,7 @@ def test_queue_api(queue_type, monkeypatch):
"""Test the Queue API.
This test runs all the Queue implementations, but we monkeypatch the Redis Queues to avoid external interaction.
+
"""
import redis
diff --git a/tests/tests_app/frontend/panel/test_app_state_watcher.py b/tests/tests_app/frontend/panel/test_app_state_watcher.py
index a9c23b1619331..21faeba49592f 100644
--- a/tests/tests_app/frontend/panel/test_app_state_watcher.py
+++ b/tests/tests_app/frontend/panel/test_app_state_watcher.py
@@ -4,6 +4,7 @@
- to access and change the App state.
This is particularly useful for the PanelFrontend, but can be used by other Frontends too.
+
"""
# pylint: disable=protected-access
import os
@@ -38,6 +39,7 @@ def test_init(flow_state_state: dict):
- the .state is set
- the .state is scoped to the flow state
+
"""
# When
app = AppStateWatcher()
@@ -54,6 +56,7 @@ def test_update_flow_state(flow_state_state: dict):
"""We can update the state.
- the .state is scoped to the flow state
+
"""
app = AppStateWatcher()
org_state = app.state
@@ -67,6 +70,7 @@ def test_is_singleton():
It's key that __new__ and __init__ of AppStateWatcher are only called once. See
https://github.com/holoviz/param/issues/643
+
"""
# When
app1 = AppStateWatcher()
diff --git a/tests/tests_app/frontend/panel/test_panel_serve_render_fn.py b/tests/tests_app/frontend/panel/test_panel_serve_render_fn.py
index 3244c07af0ae5..8e8bc4d415936 100644
--- a/tests/tests_app/frontend/panel/test_panel_serve_render_fn.py
+++ b/tests/tests_app/frontend/panel/test_panel_serve_render_fn.py
@@ -1,6 +1,7 @@
"""The panel_serve_render_fn_or_file file gets run by Python to launch a Panel Server with Lightning.
These tests are for serving a render_fn function.
+
"""
import inspect
import os
@@ -41,6 +42,7 @@ def test_get_view_fn_args():
"""We have a helper get_view_fn function that create a function for our view.
If the render_fn provides an argument an AppStateWatcher is provided as argument
+
"""
result = _get_render_fn()
assert isinstance(result(), AppStateWatcher)
@@ -61,6 +63,7 @@ def test_get_view_fn_no_args():
"""We have a helper get_view_fn function that create a function for our view.
If the render_fn provides an argument an AppStateWatcher is provided as argument
+
"""
result = _get_render_fn()
assert result() == "no_args"
diff --git a/tests/tests_app/plugin/test_plugin.py b/tests/tests_app/plugin/test_plugin.py
index eaa75a60dc562..fef33ace92ca8 100644
--- a/tests/tests_app/plugin/test_plugin.py
+++ b/tests/tests_app/plugin/test_plugin.py
@@ -38,8 +38,8 @@ def raise_for_status(self):
def mock_requests_get(valid_url, return_value):
- """Used to replace `requests.get` with a function that returns the given value for the given valid URL and
- raises otherwise."""
+ """Used to replace `requests.get` with a function that returns the given value for the given valid URL and raises
+ otherwise."""
def inner(url):
if url == valid_url:
diff --git a/tests/tests_app/runners/test_cloud.py b/tests/tests_app/runners/test_cloud.py
index af569b5ccc6ab..8df770d611b92 100644
--- a/tests/tests_app/runners/test_cloud.py
+++ b/tests/tests_app/runners/test_cloud.py
@@ -254,6 +254,7 @@ def test_running_deleted_app(self, tmpdir, cloud_backend, project_id):
"""Deleted apps show up in list apps but not in list instances.
This tests that we don't try to recreate a previously deleted app.
+
"""
entrypoint = Path(tmpdir) / "entrypoint.py"
entrypoint.touch()
@@ -1881,10 +1882,11 @@ def test_print_specs(tmpdir, caplog, monkeypatch, print_format, expected):
def test_incompatible_cloud_compute_and_build_config(monkeypatch):
- """Test that an exception is raised when a build config has a custom image defined, but the cloud compute is
- the default.
+ """Test that an exception is raised when a build config has a custom image defined, but the cloud compute is the
+ default.
This combination is not supported by the platform.
+
"""
mock_client = mock.MagicMock()
cloud_backend = mock.MagicMock(client=mock_client)
diff --git a/tests/tests_app/storage/test_copier.py b/tests/tests_app/storage/test_copier.py
index fd4e274b91f62..f16ce57c7fd80 100644
--- a/tests/tests_app/storage/test_copier.py
+++ b/tests/tests_app/storage/test_copier.py
@@ -45,8 +45,8 @@ def test_copier_copies_all_files(fs_mock, stat_mock, dir_mock, tmpdir):
@mock.patch("lightning.app.storage.path.pathlib.Path.is_dir")
@mock.patch("lightning.app.storage.path.pathlib.Path.stat")
def test_copier_handles_exception(stat_mock, dir_mock, monkeypatch):
- """Test that the Copier captures exceptions from the file copy and forwards them through the queue without
- raising it."""
+ """Test that the Copier captures exceptions from the file copy and forwards them through the queue without raising
+ it."""
stat_mock().st_size = 0
dir_mock.return_value = False
copy_request_queue = _MockQueue()
diff --git a/tests/tests_app/storage/test_path.py b/tests/tests_app/storage/test_path.py
index 423a4e7117e3e..56bf5dbc1ea4a 100644
--- a/tests/tests_app/storage/test_path.py
+++ b/tests/tests_app/storage/test_path.py
@@ -483,8 +483,8 @@ def run(self, src_path_0, local_path_0, nested_local_path, kwarg_path=None, nest
def test_path_as_argument_to_run_method():
- """Test that Path objects can be passed as arguments to the run() method of a Work in various ways such that
- the origin, consumer and queues get automatically attached."""
+ """Test that Path objects can be passed as arguments to the run() method of a Work in various ways such that the
+ origin, consumer and queues get automatically attached."""
root = RunPathFlow()
app = LightningApp(root)
MultiProcessRuntime(app, start_server=False).dispatch()
@@ -621,8 +621,8 @@ def test_path_response_not_matching_reqeuest(tmpdir):
def test_path_exists(tmpdir):
- """Test that the Path.exists() behaves as expected: First it should check if the file exists locally, and if
- not, send a message to the orchestrator to eventually check the existenc on the origin Work."""
+ """Test that the Path.exists() behaves as expected: First it should check if the file exists locally, and if not,
+ send a message to the orchestrator to eventually check the existenc on the origin Work."""
# Local Path (no Work queues attached)
assert not Path("file").exists()
assert Path(tmpdir).exists()
diff --git a/tests/tests_app/utilities/test_introspection.py b/tests/tests_app/utilities/test_introspection.py
index b3371e2348565..ce5d54f4e0746 100644
--- a/tests/tests_app/utilities/test_introspection.py
+++ b/tests/tests_app/utilities/test_introspection.py
@@ -38,8 +38,8 @@ def test_introspection_lightning():
@_RunIf(pl=True)
def test_introspection_lightning_overrides():
- """This test validates the scanner can find all the subclasses from primitives classes from PyTorch Lightning
- in the provided files."""
+ """This test validates the scanner can find all the subclasses from primitives classes from PyTorch Lightning in
+ the provided files."""
scanner = Scanner(str(os.path.join(_PROJECT_ROOT, "tests/tests_app/core/scripts/lightning_cli.py")))
scan = scanner.scan()
assert set(scan) == {"LightningDataModule", "LightningModule"}
diff --git a/tests/tests_app/utilities/test_proxies.py b/tests/tests_app/utilities/test_proxies.py
index d62cf2d6a2344..1134574f24814 100644
--- a/tests/tests_app/utilities/test_proxies.py
+++ b/tests/tests_app/utilities/test_proxies.py
@@ -288,8 +288,7 @@ def test_proxy_timeout():
@mock.patch("lightning.app.utilities.proxies._Copier")
def test_path_argument_to_transfer(*_):
- """Test that any Lightning Path objects passed to the run method get transferred automatically (if they
- exist)."""
+ """Test that any Lightning Path objects passed to the run method get transferred automatically (if they exist)."""
class TransferPathWork(LightningWork):
def run(self, *args, **kwargs):
@@ -372,8 +371,7 @@ def run(self, *args, **kwargs):
)
@mock.patch("lightning.app.utilities.proxies._Copier")
def test_path_attributes_to_transfer(_, origin, exists_remote, expected_get):
- """Test that any Lightning Path objects passed to the run method get transferred automatically (if they
- exist)."""
+ """Test that any Lightning Path objects passed to the run method get transferred automatically (if they exist)."""
path_mock = Mock()
path_mock.origin_name = origin
path_mock.exists_remote = Mock(return_value=exists_remote)
@@ -518,8 +516,8 @@ def run(self):
def test_work_state_observer():
- """Tests that the WorkStateObserver sends deltas to the queue when state residuals remain that haven't been
- handled by the setattr."""
+ """Tests that the WorkStateObserver sends deltas to the queue when state residuals remain that haven't been handled
+ by the setattr."""
class WorkWithoutSetattr(LightningWork):
def __init__(self):
diff --git a/tests/tests_fabric/accelerators/test_cuda.py b/tests/tests_fabric/accelerators/test_cuda.py
index 9408ad4292683..37ae024335c2e 100644
--- a/tests/tests_fabric/accelerators/test_cuda.py
+++ b/tests/tests_fabric/accelerators/test_cuda.py
@@ -71,8 +71,8 @@ def test_set_cuda_device(_, set_device_mock):
@mock.patch("torch.cuda.is_available", return_value=True)
@mock.patch("torch.cuda.device_count", return_value=100)
def test_num_cuda_devices_without_nvml(*_):
- """Test that if NVML can't be loaded, our helper functions fall back to the default implementation for
- determining CUDA availability."""
+ """Test that if NVML can't be loaded, our helper functions fall back to the default implementation for determining
+ CUDA availability."""
num_cuda_devices.cache_clear()
assert is_cuda_available()
assert num_cuda_devices() == 100
diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py
index 77f78da4b3e11..b1dbc8963393e 100644
--- a/tests/tests_fabric/conftest.py
+++ b/tests/tests_fabric/conftest.py
@@ -121,6 +121,7 @@ def caplog(caplog):
"""Workaround for https://github.com/pytest-dev/pytest/issues/3697.
Setting ``filterwarnings`` with pytest breaks ``caplog`` when ``not logger.propagate``.
+
"""
import logging
diff --git a/tests/tests_fabric/plugins/environments/test_mpi.py b/tests/tests_fabric/plugins/environments/test_mpi.py
index fb32b80a2af16..0f02572c5c95a 100644
--- a/tests/tests_fabric/plugins/environments/test_mpi.py
+++ b/tests/tests_fabric/plugins/environments/test_mpi.py
@@ -70,8 +70,7 @@ def test_default_attributes(monkeypatch):
def test_init_local_comm(monkeypatch):
- """Test that it can determine the node rank and local rank based on the hostnames of all participating
- nodes."""
+ """Test that it can determine the node rank and local rank based on the hostnames of all participating nodes."""
# pretend mpi4py is available
monkeypatch.setattr(lightning.fabric.plugins.environments.mpi, "_MPI4PY_AVAILABLE", True)
mpi4py_mock = MagicMock()
diff --git a/tests/tests_fabric/plugins/precision/test_deepspeed_integration.py b/tests/tests_fabric/plugins/precision/test_deepspeed_integration.py
index 76226ea2f24a4..e989534343b8c 100644
--- a/tests/tests_fabric/plugins/precision/test_deepspeed_integration.py
+++ b/tests/tests_fabric/plugins/precision/test_deepspeed_integration.py
@@ -28,6 +28,7 @@ def test_deepspeed_precision_choice(_, precision):
"""Test to ensure precision plugin is correctly chosen.
DeepSpeed handles precision via custom DeepSpeedPrecision.
+
"""
connector = _Connector(
accelerator="auto",
diff --git a/tests/tests_fabric/strategies/test_ddp.py b/tests/tests_fabric/strategies/test_ddp.py
index c66782b61d9cc..117d0bb3dbd65 100644
--- a/tests/tests_fabric/strategies/test_ddp.py
+++ b/tests/tests_fabric/strategies/test_ddp.py
@@ -74,8 +74,7 @@ def test_ddp_no_backward_sync():
@mock.patch("lightning.fabric.strategies.ddp.DistributedDataParallel")
def test_ddp_extra_kwargs(ddp_mock):
- """Test that additional kwargs passed to the DDPStrategy get passed down to the DistributedDataParallel
- wrapper."""
+ """Test that additional kwargs passed to the DDPStrategy get passed down to the DistributedDataParallel wrapper."""
module = torch.nn.Linear(1, 1)
strategy = DDPStrategy(parallel_devices=[torch.device("cpu"), torch.device("cpu")])
strategy.setup_module(module)
diff --git a/tests/tests_fabric/strategies/test_deepspeed.py b/tests/tests_fabric/strategies/test_deepspeed.py
index 84cd86ffc3bb5..2edded9ca8e87 100644
--- a/tests/tests_fabric/strategies/test_deepspeed.py
+++ b/tests/tests_fabric/strategies/test_deepspeed.py
@@ -399,6 +399,7 @@ def test_validate_parallel_devices_indices(device_indices):
"""Test that the strategy validates that it doesn't support selecting specific devices by index.
DeepSpeed doesn't support it and needs the index to match to the local rank of the process.
+
"""
strategy = DeepSpeedStrategy(
accelerator=CUDAAccelerator(), parallel_devices=[torch.device("cuda", i) for i in device_indices]
diff --git a/tests/tests_fabric/strategies/test_deepspeed_integration.py b/tests/tests_fabric/strategies/test_deepspeed_integration.py
index ce879e050f7f5..3db926a6e978d 100644
--- a/tests/tests_fabric/strategies/test_deepspeed_integration.py
+++ b/tests/tests_fabric/strategies/test_deepspeed_integration.py
@@ -245,6 +245,7 @@ def test_deepspeed_env_variables_on_platforms(_, deepspeed_dist_mock, platform):
"""Test to ensure that we set up distributed communication correctly.
When using Windows, ranks environment variables should not be set, and DeepSpeed should handle this.
+
"""
fabric = Fabric(strategy=DeepSpeedStrategy(stage=3))
strategy = fabric._strategy
diff --git a/tests/tests_fabric/strategies/test_fsdp.py b/tests/tests_fabric/strategies/test_fsdp.py
index dd782d636f09c..71de0b25348ac 100644
--- a/tests/tests_fabric/strategies/test_fsdp.py
+++ b/tests/tests_fabric/strategies/test_fsdp.py
@@ -361,8 +361,7 @@ def test_fsdp_load_unknown_checkpoint_type(tmp_path):
@RunIf(min_torch="2.0.0")
def test_fsdp_load_raw_checkpoint_validate_single_file(tmp_path):
- """Test that we validate the given checkpoint is a single file when loading a raw PyTorch state-dict
- checkpoint."""
+ """Test that we validate the given checkpoint is a single file when loading a raw PyTorch state-dict checkpoint."""
strategy = FSDPStrategy()
model = Mock(spec=nn.Module)
path = tmp_path / "folder"
@@ -451,6 +450,7 @@ def guard_region(self, name: str):
This is confusing (since it logs "FAILED"), but more importantly the orphan rank will continue trying to execute
the rest of the test suite. So instead we add calls to `os._exit` which actually forces the process to shut
down.
+
"""
success = False
try:
diff --git a/tests/tests_fabric/strategies/test_strategy.py b/tests/tests_fabric/strategies/test_strategy.py
index 5eb21198473b5..ae101e4272500 100644
--- a/tests/tests_fabric/strategies/test_strategy.py
+++ b/tests/tests_fabric/strategies/test_strategy.py
@@ -155,8 +155,7 @@ def test_load_checkpoint_strict_loading(tmp_path):
def test_load_checkpoint_non_strict_loading(tmp_path):
- """Test that no error is raised if `strict=False` and state is requested that does not exist in the
- checkpoint."""
+ """Test that no error is raised if `strict=False` and state is requested that does not exist in the checkpoint."""
strategy = SingleDeviceStrategy() # surrogate class to test implementation in base class
# objects with initial state
diff --git a/tests/tests_fabric/test_cli.py b/tests/tests_fabric/test_cli.py
index 164882d013bf1..7db7fcf8dab24 100644
--- a/tests/tests_fabric/test_cli.py
+++ b/tests/tests_fabric/test_cli.py
@@ -69,8 +69,8 @@ def test_cli_env_vars_strategy(_, strategy, monkeypatch, fake_script):
def test_cli_get_supported_strategies():
- """Test to ensure that when new strategies get added, we must consider updating the list of supported ones in
- the CLI."""
+ """Test to ensure that when new strategies get added, we must consider updating the list of supported ones in the
+ CLI."""
if _TORCH_GREATER_EQUAL_1_12 and torch.distributed.is_available():
assert len(_get_supported_strategies()) == 7
assert "fsdp" in _get_supported_strategies()
diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py
index ae99397de3280..045ab64d1e002 100644
--- a/tests/tests_fabric/test_fabric.py
+++ b/tests/tests_fabric/test_fabric.py
@@ -120,8 +120,8 @@ def test_setup_compiled_module(setup_method):
@pytest.mark.parametrize("move_to_device", [True, False])
@pytest.mark.parametrize("setup_method", ["setup", "setup_module"])
def test_setup_module_move_to_device(setup_method, move_to_device, accelerator, initial_device, target_device):
- """Test that `move_to_device` leads to parameters being moved to the correct device and that the device
- attributes on the wrapper are updated."""
+ """Test that `move_to_device` leads to parameters being moved to the correct device and that the device attributes
+ on the wrapper are updated."""
initial_device = torch.device(initial_device)
target_device = torch.device(target_device)
expected_device = target_device if move_to_device else initial_device
@@ -149,8 +149,7 @@ def test_setup_module_move_to_device(setup_method, move_to_device, accelerator,
@pytest.mark.parametrize("move_to_device", [True, False])
@pytest.mark.parametrize("setup_method", ["setup", "setup_module"])
def test_setup_module_parameters_on_different_devices(setup_method, move_to_device):
- """Test that a warning is emitted when model parameters are on a different device prior to calling
- `setup()`."""
+ """Test that a warning is emitted when model parameters are on a different device prior to calling `setup()`."""
device0 = torch.device("cpu")
device1 = torch.device("cuda", 0)
@@ -262,8 +261,7 @@ def test_setup_optimizers_twice_fails():
@pytest.mark.parametrize("strategy_cls", [DeepSpeedStrategy, XLAStrategy])
def test_setup_optimizers_not_supported(strategy_cls):
- """Test that `setup_optimizers` validates the strategy supports setting up model and optimizers
- independently."""
+ """Test that `setup_optimizers` validates the strategy supports setting up model and optimizers independently."""
fabric = Fabric()
fabric._launched = True # pretend we have launched multiple processes
model = nn.Linear(1, 2)
@@ -275,8 +273,7 @@ def test_setup_optimizers_not_supported(strategy_cls):
@RunIf(min_cuda_gpus=1, min_torch="2.1")
def test_setup_optimizer_on_meta_device():
- """Test that the setup-methods validate that the optimizer doesn't have references to meta-device
- parameters."""
+ """Test that the setup-methods validate that the optimizer doesn't have references to meta-device parameters."""
fabric = Fabric(strategy="fsdp", devices=1)
fabric._launched = True # pretend we have launched multiple processes
with fabric.init_module(empty_init=True):
@@ -350,8 +347,8 @@ def run(_):
def test_setup_dataloaders_raises_for_unknown_custom_args():
- """Test that an error raises when custom dataloaders with unknown arguments are created from outside Fabric's
- run method."""
+ """Test that an error raises when custom dataloaders with unknown arguments are created from outside Fabric's run
+ method."""
class CustomDataLoader(DataLoader):
def __init__(self, new_arg, *args, **kwargs):
@@ -508,8 +505,7 @@ def test_seed_everything():
],
)
def test_setup_dataloaders_replace_custom_sampler(strategy):
- """Test that asking to replace a custom sampler results in an error when a distributed sampler would be
- needed."""
+ """Test that asking to replace a custom sampler results in an error when a distributed sampler would be needed."""
custom_sampler = Mock(spec=Sampler)
dataloader = DataLoader(Mock(), sampler=custom_sampler)
@@ -744,8 +740,7 @@ def run(self):
def test_module_sharding_context():
- """Test that the sharding context manager gets applied when the strategy supports it and is a no-op
- otherwise."""
+ """Test that the sharding context manager gets applied when the strategy supports it and is a no-op otherwise."""
fabric = Fabric()
fabric._strategy = MagicMock(spec=DDPStrategy, module_sharded_context=Mock())
with pytest.warns(DeprecationWarning, match="sharded_model"), fabric.sharded_model():
diff --git a/tests/tests_fabric/test_wrappers.py b/tests/tests_fabric/test_wrappers.py
index 79b0a33c576f5..1b0f7333db177 100644
--- a/tests/tests_fabric/test_wrappers.py
+++ b/tests/tests_fabric/test_wrappers.py
@@ -267,8 +267,8 @@ class DeviceModule(_DeviceDtypeModuleMixin):
def test_fabric_dataloader_iterator():
- """Test that the iteration over a FabricDataLoader wraps the iterator of the underlying dataloader (no
- automatic device placement)."""
+ """Test that the iteration over a FabricDataLoader wraps the iterator of the underlying dataloader (no automatic
+ device placement)."""
dataloader = DataLoader(range(5), batch_size=2)
fabric_dataloader = _FabricDataLoader(dataloader)
assert len(fabric_dataloader) == len(dataloader) == 3
diff --git a/tests/tests_fabric/utilities/test_data.py b/tests/tests_fabric/utilities/test_data.py
index 533e52c299a52..82b601f36676d 100644
--- a/tests/tests_fabric/utilities/test_data.py
+++ b/tests/tests_fabric/utilities/test_data.py
@@ -48,13 +48,13 @@ def test_has_len():
def test_replace_dunder_methods_multiple_loaders_without_init():
"""In case of a class, that inherits from a class that we are patching, but doesn't define its own `__init__`
method (the one we are wrapping), it can happen, that `hasattr(cls, "__old__init__")` is True because of parent
- class, but it is impossible to delete, because that method is owned by parent class. Furthermore, the error
- occured only sometimes because it depends on the order in which we are iterating over a set of classes we are
- patching.
+ class, but it is impossible to delete, because that method is owned by parent class. Furthermore, the error occurred
+ only sometimes because it depends on the order in which we are iterating over a set of classes we are patching.
This test simulates the behavior by generating sufficient number of dummy classes, which do not define `__init__`
and are children of `DataLoader`. We are testing that a) context manager `_replace_dunder_method` exits cleanly, and
b) the mechanism checking for presence of `__old__init__` works as expected.
+
"""
classes = [DataLoader]
for i in range(100):
@@ -253,10 +253,11 @@ def __init__(self, dataset, *args, batch_size=10, **kwargs):
def test_replace_dunder_methods_attrs():
- """This test checks, that all the calls from setting and deleting attributes within `_replace_dunder_methods`
- are correctly preserved even after reinstantiation.
+ """This test checks, that all the calls from setting and deleting attributes within `_replace_dunder_methods` are
+ correctly preserved even after reinstantiation.
It also includes a custom `__setattr__`
+
"""
class Loader(DataLoader):
@@ -413,11 +414,12 @@ def __init__(self, randomize, *args, **kwargs):
def test_custom_batch_sampler():
- """This test asserts, that custom `BatchSampler`, with all the arguments, that are required in order to
- properly reinstantiate the class, is invoked properly.
+ """This test asserts, that custom `BatchSampler`, with all the arguments, that are required in order to properly
+ reinstantiate the class, is invoked properly.
It also asserts, that during the reinstantiation, the wrapper of `__init__` method is not present anymore, therefore
not setting `__pl_saved_{args,arg_names,kwargs}` attributes.
+
"""
class MyBatchSampler(BatchSampler):
@@ -456,8 +458,7 @@ def __init__(self, sampler, extra_arg, drop_last=True):
def test_custom_batch_sampler_no_sampler():
- """Tests whether appropriate error is raised when the custom `BatchSampler` does not support sampler
- argument."""
+ """Tests whether appropriate error is raised when the custom `BatchSampler` does not support sampler argument."""
class MyBatchSampler(BatchSampler):
# Custom batch sampler, without sampler argument.
@@ -511,10 +512,10 @@ def test_dataloader_kwargs_replacement_with_iterable_dataset():
def test_dataloader_kwargs_replacement_with_array_default_comparison():
- """Test that the comparison of attributes and default argument values works with arrays (truth value
- ambiguous).
+ """Test that the comparison of attributes and default argument values works with arrays (truth value ambiguous).
Regression test for issue #15408.
+
"""
dataset = RandomDataset(5, 100)
diff --git a/tests/tests_fabric/utilities/test_device_dtype_mixin.py b/tests/tests_fabric/utilities/test_device_dtype_mixin.py
index bb1570eed0e20..1261ca5e0accb 100644
--- a/tests/tests_fabric/utilities/test_device_dtype_mixin.py
+++ b/tests/tests_fabric/utilities/test_device_dtype_mixin.py
@@ -36,8 +36,8 @@ def __init__(self) -> None:
)
@RunIf(min_cuda_gpus=1)
def test_submodules_device_and_dtype(dst_device_str, dst_type):
- """Test that the device and dtype property updates propagate through mixed nesting of regular nn.Modules and
- the special modules of type DeviceDtypeModuleMixin (e.g. Metric or LightningModule)."""
+ """Test that the device and dtype property updates propagate through mixed nesting of regular nn.Modules and the
+ special modules of type DeviceDtypeModuleMixin (e.g. Metric or LightningModule)."""
dst_device = torch.device(dst_device_str)
model = TopModule()
assert model.device == torch.device("cpu")
diff --git a/tests/tests_fabric/utilities/test_logger.py b/tests/tests_fabric/utilities/test_logger.py
index c5286d0dbe29d..4cecc4658c145 100644
--- a/tests/tests_fabric/utilities/test_logger.py
+++ b/tests/tests_fabric/utilities/test_logger.py
@@ -78,6 +78,7 @@ def test_sanitize_callable_params():
"""Callback function are not serializiable.
Therefore, we get them a chance to return something and if the returned type is not accepted, return None.
+
"""
def return_something():
diff --git a/tests/tests_fabric/utilities/test_warnings.py b/tests/tests_fabric/utilities/test_warnings.py
index a0165961cc69a..e39d52556bf24 100644
--- a/tests/tests_fabric/utilities/test_warnings.py
+++ b/tests/tests_fabric/utilities/test_warnings.py
@@ -14,6 +14,7 @@
"""Test that the warnings actually appear and they have the correct `stacklevel`
Needs to be run outside of `pytest` as it captures all the warnings.
+
"""
from contextlib import redirect_stderr
from io import StringIO
@@ -39,16 +40,16 @@
cache.deprecation("test7")
output = stderr.getvalue()
- assert "test_warnings.py:29: UserWarning: test1" in output
- assert "test_warnings.py:30: DeprecationWarning: test2" in output
+ assert "test_warnings.py:30: UserWarning: test1" in output
+ assert "test_warnings.py:31: DeprecationWarning: test2" in output
- assert "test_warnings.py:32: UserWarning: test3" in output
- assert "test_warnings.py:33: DeprecationWarning: test4" in output
+ assert "test_warnings.py:33: UserWarning: test3" in output
+ assert "test_warnings.py:34: DeprecationWarning: test4" in output
- assert "test_warnings.py:35: LightningDeprecationWarning: test5" in output
+ assert "test_warnings.py:36: LightningDeprecationWarning: test5" in output
- assert "test_warnings.py:38: UserWarning: test6" in output
- assert "test_warnings.py:39: LightningDeprecationWarning: test7" in output
+ assert "test_warnings.py:39: UserWarning: test6" in output
+ assert "test_warnings.py:40: LightningDeprecationWarning: test7" in output
# check that logging is properly configured
import logging
diff --git a/tests/tests_pytorch/accelerators/test_cpu.py b/tests/tests_pytorch/accelerators/test_cpu.py
index e724652a076ed..ac5b3443ae667 100644
--- a/tests/tests_pytorch/accelerators/test_cpu.py
+++ b/tests/tests_pytorch/accelerators/test_cpu.py
@@ -39,8 +39,8 @@ def test_get_device_stats(tmpdir):
@pytest.mark.parametrize("restore_after_pre_setup", [True, False])
def test_restore_checkpoint_after_pre_setup(tmpdir, restore_after_pre_setup):
- """Test to ensure that if restore_checkpoint_after_setup is True, then we only load the state after pre-
- dispatch is called."""
+ """Test to ensure that if restore_checkpoint_after_setup is True, then we only load the state after pre- dispatch
+ is called."""
class TestPlugin(SingleDeviceStrategy):
setup_called = False
diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
index d4c8e85509b0d..7c6a75ea3bd90 100644
--- a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
+++ b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
@@ -353,6 +353,7 @@ def test_train_progress_bar_update_amount(
At the end of the epoch, the progress must not overshoot if the number of steps is not divisible by the refresh
rate.
+
"""
model = BoringModel()
progress_bar = TQDMProgressBar(refresh_rate=refresh_rate)
diff --git a/tests/tests_pytorch/callbacks/test_early_stopping.py b/tests/tests_pytorch/callbacks/test_early_stopping.py
index 375734dead18f..e43122f4022fe 100644
--- a/tests/tests_pytorch/callbacks/test_early_stopping.py
+++ b/tests/tests_pytorch/callbacks/test_early_stopping.py
@@ -63,6 +63,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
https://github.com/Lightning-AI/lightning/issues/1464
https://github.com/Lightning-AI/lightning/issues/1463
+
"""
seed_everything(42)
model = ClassificationModel()
diff --git a/tests/tests_pytorch/callbacks/test_finetuning_callback.py b/tests/tests_pytorch/callbacks/test_finetuning_callback.py
index 129b434e6b391..7fbb7f580b170 100644
--- a/tests/tests_pytorch/callbacks/test_finetuning_callback.py
+++ b/tests/tests_pytorch/callbacks/test_finetuning_callback.py
@@ -209,8 +209,7 @@ def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: O
def test_base_finetuning_internal_optimizer_metadata(tmpdir):
- """Test the param_groups updates are properly saved within the internal state of the BaseFinetuning
- Callbacks."""
+ """Test the param_groups updates are properly saved within the internal state of the BaseFinetuning Callbacks."""
seed_everything(42)
@@ -325,8 +324,7 @@ def configure_optimizers(self):
def test_callbacks_restore(tmpdir):
- """Test callbacks restore is called after optimizers have been re-created but before optimizer states
- reload."""
+ """Test callbacks restore is called after optimizers have been re-created but before optimizer states reload."""
chk = ModelCheckpoint(dirpath=tmpdir, save_last=True)
model = FinetuningBoringModel()
@@ -400,8 +398,7 @@ def forward(self, x):
def test_callbacks_restore_backbone(tmpdir):
- """Test callbacks restore is called after optimizers have been re-created but before optimizer states
- reload."""
+ """Test callbacks restore is called after optimizers have been re-created but before optimizer states reload."""
ckpt = ModelCheckpoint(dirpath=tmpdir, save_last=True)
trainer = Trainer(
diff --git a/tests/tests_pytorch/callbacks/test_prediction_writer.py b/tests/tests_pytorch/callbacks/test_prediction_writer.py
index aba956c414ed9..59c1a145b1a15 100644
--- a/tests/tests_pytorch/callbacks/test_prediction_writer.py
+++ b/tests/tests_pytorch/callbacks/test_prediction_writer.py
@@ -37,8 +37,7 @@ def test_prediction_writer_invalid_write_interval():
def test_prediction_writer_hook_call_intervals():
- """Test that the `write_on_batch_end` and `write_on_epoch_end` hooks get invoked based on the defined
- interval."""
+ """Test that the `write_on_batch_end` and `write_on_epoch_end` hooks get invoked based on the defined interval."""
DummyPredictionWriter.write_on_batch_end = Mock()
DummyPredictionWriter.write_on_epoch_end = Mock()
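For orientation, a minimal writer exercising these two hooks might look roughly like the sketch below (assuming lightning>=2.0; the output file layout is hypothetical, not what the test uses).

import os
import torch
from lightning.pytorch.callbacks import BasePredictionWriter

class SimplePredictionWriter(BasePredictionWriter):
    def __init__(self, output_dir: str, write_interval: str = "batch"):
        super().__init__(write_interval)
        self.output_dir = output_dir

    def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx):
        # invoked when write_interval is "batch" or "batch_and_epoch"
        torch.save(prediction, os.path.join(self.output_dir, f"pred_{dataloader_idx}_{batch_idx}.pt"))

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        # invoked when write_interval is "epoch" or "batch_and_epoch"
        torch.save(predictions, os.path.join(self.output_dir, "predictions.pt"))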
diff --git a/tests/tests_pytorch/callbacks/test_pruning.py b/tests/tests_pytorch/callbacks/test_pruning.py
index d206fadb593ad..61676864afd66 100644
--- a/tests/tests_pytorch/callbacks/test_pruning.py
+++ b/tests/tests_pytorch/callbacks/test_pruning.py
@@ -281,8 +281,8 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent: bool
def test_permanent_when_model_is_saved_multiple_times(
tmpdir, caplog, prune_on_train_epoch_end, save_on_train_epoch_end
):
- """When a model is saved multiple times and make_permanent=True, we need to make sure a copy is pruned and not
- the trained model if we want to continue with the same pruning buffers."""
+ """When a model is saved multiple times and make_permanent=True, we need to make sure a copy is pruned and not the
+ trained model if we want to continue with the same pruning buffers."""
if prune_on_train_epoch_end and save_on_train_epoch_end:
pytest.xfail(
"Pruning sets the `grad_fn` of the parameters so we can't save"
diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
index c8172e413ccf2..cccbb3252ec47 100644
--- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
+++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
@@ -85,8 +85,8 @@ def mock(key):
def test_model_checkpoint_score_and_ckpt(
tmpdir, validation_step_none: bool, val_dataloaders_none: bool, monitor: str, reduce_lr_on_plateau: bool
):
- """Test that when a model checkpoint is saved, it saves with the correct score appended to ckpt_path and
- checkpoint data."""
+ """Test that when a model checkpoint is saved, it saves with the correct score appended to ckpt_path and checkpoint
+ data."""
max_epochs = 3
limit_train_batches = 5
limit_val_batches = 7
@@ -190,8 +190,8 @@ def on_validation_epoch_end(self):
def test_model_checkpoint_score_and_ckpt_val_check_interval(
tmpdir, val_check_interval, reduce_lr_on_plateau, epoch_aligned
):
- """Test that when a model checkpoint is saved, it saves with the correct score appended to ckpt_path and
- checkpoint data with val_check_interval."""
+ """Test that when a model checkpoint is saved, it saves with the correct score appended to ckpt_path and checkpoint
+ data with val_check_interval."""
seed_everything(0)
max_epochs = 3
limit_train_batches = 12
@@ -1131,8 +1131,8 @@ def __init__(self, hparams):
def test_ckpt_version_after_rerun_new_trainer(tmpdir):
- """Check that previous checkpoints are renamed to have the correct version suffix when new trainer instances
- are used."""
+ """Check that previous checkpoints are renamed to have the correct version suffix when new trainer instances are
+ used."""
epochs = 2
for i in range(epochs):
mc = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, monitor="epoch", filename="{epoch}")
@@ -1158,8 +1158,8 @@ def test_ckpt_version_after_rerun_new_trainer(tmpdir):
def test_ckpt_version_after_rerun_same_trainer(tmpdir):
- """Check that previous checkpoints are renamed to have the correct version suffix when the same trainer
- instance is used."""
+ """Check that previous checkpoints are renamed to have the correct version suffix when the same trainer instance is
+ used."""
mc = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, monitor="epoch", filename="test")
mc.STARTING_VERSION = 9
trainer = Trainer(
@@ -1303,8 +1303,8 @@ def on_load_checkpoint(self, *args, **kwargs):
def test_resume_training_preserves_old_ckpt_last(tmpdir):
- """Ensures that the last saved checkpoint is not deleted from the previous folder when training is resumed from
- the old checkpoint."""
+ """Ensures that the last saved checkpoint is not deleted from the previous folder when training is resumed from the
+ old checkpoint."""
model = BoringModel()
trainer_kwargs = {
"default_root_dir": tmpdir,
diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py
index fc6762b373745..47a4ea9063eef 100644
--- a/tests/tests_pytorch/conftest.py
+++ b/tests/tests_pytorch/conftest.py
@@ -94,6 +94,7 @@ def restore_signal_handlers():
"""Ensures that signal handlers get restored before the next test runs.
This is a safety net for tests that don't run Trainer's teardown.
+
"""
valid_signals = _SignalConnector._valid_signals()
if not _IS_WINDOWS:
@@ -207,6 +208,7 @@ def caplog(caplog):
"""Workaround for https://github.com/pytest-dev/pytest/issues/3697.
Setting ``filterwarnings`` with pytest breaks ``caplog`` when ``not logger.propagate``.
+
"""
import logging
@@ -248,6 +250,7 @@ def single_process_pg():
"""Initialize the default process group with only the current process for testing purposes.
The process group is destroyed when the with block is exited.
+
"""
if torch.distributed.is_initialized():
raise RuntimeError("Can't use `single_process_pg` when the default process group is already initialized.")
diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py
index 751f580bd4318..d254ec4385fdd 100644
--- a/tests/tests_pytorch/core/test_datamodules.py
+++ b/tests/tests_pytorch/core/test_datamodules.py
@@ -237,8 +237,7 @@ def test_full_loop(tmpdir):
def test_dm_reload_dataloaders_every_n_epochs(tmpdir):
- """Test datamodule, where trainer argument reload_dataloaders_every_n_epochs is set to a non negative
- integer."""
+ """Test datamodule, where trainer argument reload_dataloaders_every_n_epochs is set to a non negative integer."""
class CustomBoringDataModule(BoringDataModule):
def __init__(self):
diff --git a/tests/tests_pytorch/core/test_lightning_module.py b/tests/tests_pytorch/core/test_lightning_module.py
index 6d612dd3c9b03..3c57791e26497 100644
--- a/tests/tests_pytorch/core/test_lightning_module.py
+++ b/tests/tests_pytorch/core/test_lightning_module.py
@@ -424,6 +424,7 @@ def test_lightning_module_scriptable():
"""Test that the LightningModule is `torch.jit.script`-able.
Regression test for #15917.
+
"""
model = BoringModel()
trainer = Trainer()
diff --git a/tests/tests_pytorch/core/test_lightning_optimizer.py b/tests/tests_pytorch/core/test_lightning_optimizer.py
index 22f93ee7903a7..3d0ee4d7a9498 100644
--- a/tests/tests_pytorch/core/test_lightning_optimizer.py
+++ b/tests/tests_pytorch/core/test_lightning_optimizer.py
@@ -73,6 +73,7 @@ def test_lightning_optimizer_manual_optimization_and_accumulated_gradients(tmpdi
"""Test that the user can use our LightningOptimizer.
Not recommended.
+
"""
class TestModel(BoringModel):
diff --git a/tests/tests_pytorch/helpers/datasets.py b/tests/tests_pytorch/helpers/datasets.py
index e2954b0d73f97..8860160d6f30e 100644
--- a/tests/tests_pytorch/helpers/datasets.py
+++ b/tests/tests_pytorch/helpers/datasets.py
@@ -46,6 +46,7 @@ class MNIST(Dataset):
60000
>>> torch.bincount(dataset.targets)
tensor([5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949])
+
"""
RESOURCES = (
@@ -148,6 +149,7 @@ class TrialMNIST(MNIST):
[0, 1, 2]
>>> torch.bincount(dataset.targets)
tensor([100, 100, 100])
+
"""
def __init__(self, root: str, num_samples: int = 100, digits: Optional[Sequence] = (0, 1, 2), **kwargs):
diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py
index 3fab1b69067cd..3f090b0264d02 100644
--- a/tests/tests_pytorch/loggers/test_all.py
+++ b/tests/tests_pytorch/loggers/test_all.py
@@ -170,6 +170,7 @@ def test_loggers_pickle_all(tmpdir, monkeypatch, logger_class):
"""Test that the logger objects can be pickled.
This test only makes sense if the packages are installed.
+
"""
_patch_comet_atexit(monkeypatch)
try:
@@ -270,8 +271,8 @@ def log_hyperparams(self, params, *args, **kwargs) -> None:
@pytest.mark.parametrize("logger_class", [*ALL_LOGGER_CLASSES_WO_NEPTUNE, CustomLoggerWithoutExperiment])
@RunIf(skip_windows=True)
def test_logger_initialization(tmpdir, monkeypatch, logger_class):
- """Test that loggers get replaced by dummy loggers on global rank > 0 and that the experiment object is
- available at the right time in Trainer."""
+ """Test that loggers get replaced by dummy loggers on global rank > 0 and that the experiment object is available
+ at the right time in Trainer."""
_patch_comet_atexit(monkeypatch)
try:
_test_logger_initialization(tmpdir, logger_class)
diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py
index 239645c2a7e2c..c1f6821311f47 100644
--- a/tests/tests_pytorch/loggers/test_mlflow.py
+++ b/tests/tests_pytorch/loggers/test_mlflow.py
@@ -275,8 +275,7 @@ def test_mlflow_logger_with_long_param_value(client, _, param, tmpdir):
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
def test_mlflow_logger_with_many_params(client, _, param, tmpdir):
- """Test that the when logging more than 100 parameters, it will be split into batches of at most 100
- parameters."""
+ """Test that the when logging more than 100 parameters, it will be split into batches of at most 100 parameters."""
logger = MLFlowLogger("test", save_dir=tmpdir)
params = {f"test_{idx}": f"test_param_{idx}" for idx in range(150)}
diff --git a/tests/tests_pytorch/loggers/test_neptune.py b/tests/tests_pytorch/loggers/test_neptune.py
index 81f73b02d70e9..b2a7133240e2c 100644
--- a/tests/tests_pytorch/loggers/test_neptune.py
+++ b/tests/tests_pytorch/loggers/test_neptune.py
@@ -41,11 +41,12 @@ def create_run_mock(mode="async", **kwargs):
def create_neptune_mock():
- """Mock with provides nice `logger.name` and `logger.version` values. Additionally, it allows `mode` as an
- argument to test different Neptune modes.
+ """Mock with provides nice `logger.name` and `logger.version` values. Additionally, it allows `mode` as an argument
+ to test different Neptune modes.
Mostly due to fact, that windows tests were failing with MagicMock based strings, which were used to create local
directories in FS.
+
"""
return MagicMock(init_run=MagicMock(side_effect=create_run_mock))
@@ -88,6 +89,7 @@ def tmpdir_unittest_fixture(request, tmpdir):
Resources:
* https://docs.pytest.org/en/6.2.x/tmpdir.html#the-tmpdir-fixture
* https://towardsdatascience.com/mixing-pytest-fixture-and-unittest-testcase-for-selenium-test-9162218e8c8e
+
"""
request.cls.tmpdir = tmpdir
@@ -152,8 +154,8 @@ def test_neptune_pickling(self, neptune):
@patch("lightning.pytorch.loggers.neptune.Run", Run)
@patch("lightning.pytorch.loggers.neptune.Handler", Run)
def test_online_with_wrong_kwargs(self, neptune):
- """Tests combinations of kwargs together with `run` kwarg which makes some of other parameters unavailable
- in init."""
+ """Tests combinations of kwargs together with `run` kwarg which makes some of other parameters unavailable in
+ init."""
with self.assertRaises(ValueError):
NeptuneLogger(run="some string")
diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py
index 3929144b7be9c..1ec532761267e 100644
--- a/tests/tests_pytorch/loggers/test_wandb.py
+++ b/tests/tests_pytorch/loggers/test_wandb.py
@@ -53,6 +53,7 @@ def test_wandb_logger_init(wandb, monkeypatch):
"""Verify that basic functionality of wandb logger works.
Wandb doesn't work well with pytest so we have to mock it out here.
+
"""
# test wandb.init called when there is no W&B run
wandb.run = None
@@ -142,6 +143,7 @@ def test_wandb_pickle(wandb, tmpdir):
"""Verify that pickling trainer with wandb logger works.
Wandb doesn't work well with pytest so we have to mock it out here.
+
"""
class Experiment:
@@ -373,8 +375,7 @@ def test_wandb_log_model(wandb, monkeypatch, tmpdir):
@mock.patch("lightning.pytorch.loggers.wandb.Run", new=mock.Mock)
@mock.patch("lightning.pytorch.loggers.wandb.wandb")
def test_wandb_log_model_with_score(wandb, monkeypatch, tmpdir):
- """Test to prevent regression on #15543, ensuring the score is logged as a Python number, not a scalar
- tensor."""
+ """Test to prevent regression on #15543, ensuring the score is logged as a Python number, not a scalar tensor."""
wandb.run = None
model = BoringModel()
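The regression guarded by #15543 comes down to handing wandb a plain Python number rather than a zero-dim tensor; a minimal sketch of the conversion:

import torch

score = torch.tensor(0.123)  # e.g. a monitored metric kept as a 0-dim tensor
logged_score = score.item() if isinstance(score, torch.Tensor) else score
assert isinstance(logged_score, float)  # artifact metadata should receive a Python number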
diff --git a/tests/tests_pytorch/loops/test_evaluation_loop.py b/tests/tests_pytorch/loops/test_evaluation_loop.py
index 5758ecb1f6300..ea8f760a592d7 100644
--- a/tests/tests_pytorch/loops/test_evaluation_loop.py
+++ b/tests/tests_pytorch/loops/test_evaluation_loop.py
@@ -28,8 +28,7 @@
@mock.patch("lightning.pytorch.loops.evaluation_loop._EvaluationLoop._on_evaluation_epoch_end")
def test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir):
- """Tests that `on_evaluation_epoch_end` is called for `on_validation_epoch_end` and `on_test_epoch_end`
- hooks."""
+ """Tests that `on_evaluation_epoch_end` is called for `on_validation_epoch_end` and `on_test_epoch_end` hooks."""
model = BoringModel()
trainer = Trainer(
@@ -112,6 +111,7 @@ def test_memory_consumption_validation(tmpdir):
Cannot run with MPS, since there we can only measure shared memory and not dedicated, which device has how much
memory allocated.
+
"""
def get_memory():
diff --git a/tests/tests_pytorch/loops/test_fetchers.py b/tests/tests_pytorch/loops/test_fetchers.py
index f5990f266d478..39dd766f06dd7 100644
--- a/tests/tests_pytorch/loops/test_fetchers.py
+++ b/tests/tests_pytorch/loops/test_fetchers.py
@@ -129,12 +129,14 @@ def get_cycles_per_ms() -> float:
This is to avoid system disturbance that skew the results, e.g. the very first cuda call likely does a bunch of
init, which takes much longer than subsequent calls.
+
"""
def measure() -> float:
"""Measure and return approximate number of cycles per millisecond for `torch.cuda._sleep` Copied from:
https://github.com/pytorch/pytorch/blob/v1.9.0/test/test_cuda.py#L81.
+
"""
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
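The measurement pattern referenced by `get_cycles_per_ms` is the usual CUDA-event timing idiom; a guarded sketch (requires a CUDA device, and `torch.cuda._sleep` is a private helper, as in the linked PyTorch test):

import torch

def cycles_per_ms(cycles: int = 1_000_000) -> float:
    # time a busy-wait of `cycles` GPU cycles with CUDA events and convert to cycles per millisecond
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    torch.cuda._sleep(cycles)
    end.record()
    torch.cuda.synchronize()
    return cycles / start.elapsed_time(end)  # elapsed_time() returns milliseconds

if torch.cuda.is_available():
    print(f"approx. {cycles_per_ms():.0f} cycles/ms")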
diff --git a/tests/tests_pytorch/loops/test_progress.py b/tests/tests_pytorch/loops/test_progress.py
index daa381e6f844a..27184d7b17afb 100644
--- a/tests/tests_pytorch/loops/test_progress.py
+++ b/tests/tests_pytorch/loops/test_progress.py
@@ -93,6 +93,7 @@ def test_optimizer_progress_default_factory():
"""Ensure that the defaults are created appropriately.
If `default_factory` was not used, the default would be shared between instances.
+
"""
p1 = _OptimizerProgress()
p2 = _OptimizerProgress()
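The pitfall this hunk's docstring names is the standard dataclasses one: a mutable default needs `default_factory`, otherwise all instances share one object. An illustrative stand-in, not the real `_OptimizerProgress`:

from dataclasses import dataclass, field

@dataclass
class Progress:
    completed: list = field(default_factory=list)  # a fresh list per instance

p1, p2 = Progress(), Progress()
p1.completed.append(1)
assert p2.completed == []  # would fail if the default list were shared between instances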
diff --git a/tests/tests_pytorch/loops/test_training_epoch_loop.py b/tests/tests_pytorch/loops/test_training_epoch_loop.py
index 7d0690c5ac876..7814f56cff5c4 100644
--- a/tests/tests_pytorch/loops/test_training_epoch_loop.py
+++ b/tests/tests_pytorch/loops/test_training_epoch_loop.py
@@ -63,8 +63,7 @@ def test_no_val_on_train_epoch_loop_restart(tmpdir):
def test_should_stop_early_stopping_conditions_not_met(
caplog, min_epochs, min_steps, current_epoch, global_step, early_stop, epoch_loop_done, raise_info_msg
):
- """Test that checks that info message is logged when users sets `should_stop` but min conditions are not
- met."""
+ """Test that checks that info message is logged when users sets `should_stop` but min conditions are not met."""
trainer = Trainer(min_epochs=min_epochs, min_steps=min_steps, limit_val_batches=0)
trainer.fit_loop.max_batches = 10
trainer.should_stop = True
@@ -86,6 +85,7 @@ def test_should_stop_triggers_validation_once(min_epochs, min_steps, val_count,
Test that the request for `should_stop=True` only triggers validation when Trainer is allowed to stop
(min_epochs/steps is satisfied).
+
"""
model = BoringModel()
trainer = Trainer(
diff --git a/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py b/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py
index b6529f2614301..dc5e12defacea 100644
--- a/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py
+++ b/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py
@@ -138,8 +138,7 @@ def backward(self, loss):
def test_train_step_no_return(tmpdir):
- """Tests that only training_step raises a warning when nothing is returned in case of
- automatic_optimization."""
+ """Tests that only training_step raises a warning when nothing is returned in case of automatic_optimization."""
class TestModel(BoringModel):
def training_step(self, batch):
diff --git a/tests/tests_pytorch/models/test_cpu.py b/tests/tests_pytorch/models/test_cpu.py
index e79f014c7eddd..92423b85cb95e 100644
--- a/tests/tests_pytorch/models/test_cpu.py
+++ b/tests/tests_pytorch/models/test_cpu.py
@@ -155,6 +155,7 @@ def test_lbfgs_cpu_model(tmpdir):
"""Test each of the trainer options.
Testing LBFGS optimizer
+
"""
seed_everything(42)
@@ -247,6 +248,7 @@ def test_running_test_no_val(tmpdir):
"""Verify `test()` works on a model with no `val_dataloader`.
It performs train and test only
+
"""
seed_everything(42)
diff --git a/tests/tests_pytorch/models/test_restore.py b/tests/tests_pytorch/models/test_restore.py
index df0046daabb13..61bdfc0305976 100644
--- a/tests/tests_pytorch/models/test_restore.py
+++ b/tests/tests_pytorch/models/test_restore.py
@@ -105,8 +105,7 @@ def test_model_properties_fit_ckpt_path(tmpdir):
@RunIf(sklearn=True)
def test_trainer_properties_restore_ckpt_path(tmpdir):
- """Test that required trainer properties are set correctly when resuming from checkpoint in different
- phases."""
+ """Test that required trainer properties are set correctly when resuming from checkpoint in different phases."""
class CustomClassifModel(ClassificationModel):
def configure_optimizers(self):
diff --git a/tests/tests_pytorch/profilers/test_profiler.py b/tests/tests_pytorch/profilers/test_profiler.py
index 9c7c265fb86e7..00595bfd25cba 100644
--- a/tests/tests_pytorch/profilers/test_profiler.py
+++ b/tests/tests_pytorch/profilers/test_profiler.py
@@ -457,6 +457,7 @@ def test_pytorch_profiler_multiple_loggers(tmpdir):
multiple loggers.
See issue #8157.
+
"""
def look_for_trace(trace_dir):
diff --git a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
index 2eb367bee2759..73b006e6f7831 100644
--- a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
+++ b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
@@ -90,8 +90,7 @@ def test_global_state_snapshot():
@pytest.mark.parametrize("fake_node_rank", [0, 1])
@pytest.mark.parametrize("fake_local_rank", [0, 1])
def test_collect_rank_zero_results(trainer_fn, fake_node_rank, fake_local_rank, tmpdir):
- """Tests that the spawn strategy transfers the new weights to the main process and deletes the temporary
- file."""
+ """Tests that the spawn strategy transfers the new weights to the main process and deletes the temporary file."""
model = Mock(wraps=BoringModel(), spec=BoringModel)
fake_global_rank = 2 * fake_node_rank + fake_local_rank
@@ -130,8 +129,8 @@ def test_collect_rank_zero_results(trainer_fn, fake_node_rank, fake_local_rank,
@pytest.mark.parametrize("trainer_fn", [TrainerFn.FITTING, "other"])
def test_transfer_weights(tmpdir, trainer_fn):
- """Tests that the multiprocessing launcher transfers the new weights to the main process and deletes the
- temporary file."""
+ """Tests that the multiprocessing launcher transfers the new weights to the main process and deletes the temporary
+ file."""
model = Mock(wraps=BoringModel(), spec=BoringModel)
strategy = DDPStrategy(start_method="spawn")
trainer = Trainer(accelerator="cpu", default_root_dir=tmpdir, strategy=strategy)
diff --git a/tests/tests_pytorch/strategies/test_ddp_spawn.py b/tests/tests_pytorch/strategies/test_ddp_spawn.py
index 4c03af87fcd3d..74f562f3c1586 100644
--- a/tests/tests_pytorch/strategies/test_ddp_spawn.py
+++ b/tests/tests_pytorch/strategies/test_ddp_spawn.py
@@ -84,8 +84,7 @@ def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
def test_ddp_spawn_find_unused_parameters_exception():
- """Test that the DDP strategy can change PyTorch's error message so that it's more useful for Lightning
- users."""
+ """Test that the DDP strategy can change PyTorch's error message so that it's more useful for Lightning users."""
trainer = Trainer(accelerator="cpu", devices=1, strategy="ddp_spawn", max_steps=2)
with pytest.raises(
ProcessRaisedException, match="It looks like your LightningModule has parameters that were not used in"
diff --git a/tests/tests_pytorch/strategies/test_ddp_strategy.py b/tests/tests_pytorch/strategies/test_ddp_strategy.py
index 3e18264553769..aeabb378d6a63 100644
--- a/tests/tests_pytorch/strategies/test_ddp_strategy.py
+++ b/tests/tests_pytorch/strategies/test_ddp_strategy.py
@@ -297,8 +297,7 @@ def training_step(self, batch, batch_idx):
def test_ddp_strategy_find_unused_parameters_exception():
- """Test that the DDP strategy can change PyTorch's error message so that it's more useful for Lightning
- users."""
+ """Test that the DDP strategy can change PyTorch's error message so that it's more useful for Lightning users."""
trainer = Trainer(accelerator="cpu", devices=1, strategy="ddp", max_steps=2)
with pytest.raises(RuntimeError, match="It looks like your LightningModule has parameters that were not used in"):
trainer.fit(UnusedParametersModel())
diff --git a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py
index 6e432c0d9c3e7..d4dcb38a39594 100644
--- a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py
+++ b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py
@@ -106,8 +106,7 @@ def deepspeed_zero_config(deepspeed_config):
@RunIf(deepspeed=True)
@pytest.mark.parametrize("strategy", ["deepspeed", DeepSpeedStrategy])
def test_deepspeed_strategy_string(tmpdir, strategy):
- """Test to ensure that the strategy can be passed via string or instance, and parallel devices is correctly
- set."""
+ """Test to ensure that the strategy can be passed via string or instance, and parallel devices is correctly set."""
trainer = Trainer(
accelerator="cpu",
@@ -141,6 +140,7 @@ def test_deepspeed_precision_choice(cuda_count_1, tmpdir):
"""Test to ensure precision plugin is also correctly chosen.
DeepSpeed handles precision via Custom DeepSpeedPrecisionPlugin
+
"""
trainer = Trainer(
fast_dev_run=True,
@@ -286,8 +286,8 @@ def configure_optimizers(self):
@RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
def test_deepspeed_config(tmpdir, deepspeed_zero_config):
- """Test to ensure deepspeed works correctly when passed a DeepSpeed config object including
- optimizers/schedulers and saves the model weights to load correctly."""
+ """Test to ensure deepspeed works correctly when passed a DeepSpeed config object including optimizers/schedulers
+ and saves the model weights to load correctly."""
class TestCB(Callback):
def on_train_start(self, trainer, pl_module) -> None:
@@ -358,8 +358,8 @@ def on_train_start(self, trainer, pl_module) -> None:
@RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
@pytest.mark.parametrize("precision", ["fp16", "bf16"])
def test_deepspeed_inference_precision_during_inference(precision, tmpdir):
- """Ensure if we modify the precision for deepspeed and execute inference-only, the deepspeed config contains
- these changes."""
+ """Ensure if we modify the precision for deepspeed and execute inference-only, the deepspeed config contains these
+ changes."""
class TestCB(Callback):
def on_validation_start(self, trainer, pl_module) -> None:
@@ -399,8 +399,8 @@ def test_deepspeed_custom_activation_checkpointing_params(tmpdir):
@RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
def test_deepspeed_custom_activation_checkpointing_params_forwarded(tmpdir):
- """Ensure if we modify the activation checkpointing parameters, we pass these to
- deepspeed.checkpointing.configure correctly."""
+ """Ensure if we modify the activation checkpointing parameters, we pass these to deepspeed.checkpointing.configure
+ correctly."""
ds = DeepSpeedStrategy(
partition_activations=True,
cpu_checkpointing=True,
@@ -456,8 +456,7 @@ def setup(self, trainer, pl_module, stage=None) -> None:
@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
def test_deepspeed_multigpu(tmpdir):
- """Test to ensure that DeepSpeed with multiple GPUs works and deepspeed distributed is initialized
- correctly."""
+ """Test to ensure that DeepSpeed with multiple GPUs works and deepspeed distributed is initialized correctly."""
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
@@ -930,8 +929,8 @@ def on_train_epoch_start(self) -> None:
@RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
def test_deepspeed_multigpu_test_rnn(tmpdir):
- """Test to ensure that turning off explicit partitioning of the entire module for ZeRO Stage 3 works when
- training with certain layers which will crash with explicit partitioning."""
+ """Test to ensure that turning off explicit partitioning of the entire module for ZeRO Stage 3 works when training
+ with certain layers which will crash with explicit partitioning."""
class TestModel(BoringModel):
def __init__(self):
@@ -962,6 +961,7 @@ def test_deepspeed_strategy_env_variables(mock_deepspeed_distributed, tmpdir, pl
"""Test to ensure that we setup distributed communication using correctly.
When using windows, ranks environment variables should not be set, and deepspeed should handle this.
+
"""
trainer = Trainer(default_root_dir=tmpdir, strategy=DeepSpeedStrategy(stage=3))
strategy = trainer.strategy
@@ -1087,8 +1087,8 @@ def test_dataloader(self):
@pytest.mark.parametrize("limit_train_batches", [2])
@RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
def test_scheduler_step_count(mock_step, tmpdir, max_epoch, limit_train_batches, interval):
- """Test to ensure that the scheduler is called the correct amount of times during training when scheduler is
- set to step or epoch."""
+ """Test to ensure that the scheduler is called the correct amount of times during training when scheduler is set to
+ step or epoch."""
class TestModel(BoringModel):
def configure_optimizers(self):
@@ -1122,8 +1122,8 @@ def configure_optimizers(self):
@RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
def test_deepspeed_configure_gradient_clipping(tmpdir):
- """Test to ensure that a warning is raised when `LightningModule.configure_gradient_clipping` is overridden in
- case of deepspeed."""
+ """Test to ensure that a warning is raised when `LightningModule.configure_gradient_clipping` is overridden in case
+ of deepspeed."""
class TestModel(BoringModel):
def configure_gradient_clipping(self, optimizer, gradient_clip_val, gradient_clip_algorithm):
@@ -1162,8 +1162,8 @@ def test_deepspeed_gradient_clip_by_value(tmpdir):
@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
def test_deepspeed_multi_save_same_filepath(tmpdir):
- """Test that verifies that deepspeed saves only latest checkpoint in the specified path and deletes the old
- sharded checkpoints."""
+ """Test that verifies that deepspeed saves only latest checkpoint in the specified path and deletes the old sharded
+ checkpoints."""
class CustomModel(BoringModel):
def training_step(self, *args, **kwargs):
@@ -1278,6 +1278,7 @@ def test_validate_parallel_devices_indices(device_indices):
"""Test that the strategy validates that it doesn't support selecting specific devices by index.
DeepSpeed doesn't support it and needs the index to match to the local rank of the process.
+
"""
strategy = DeepSpeedStrategy(
accelerator=CUDAAccelerator(), parallel_devices=[torch.device("cuda", i) for i in device_indices]
diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py
index 96273daa58396..381f799b08067 100644
--- a/tests/tests_pytorch/strategies/test_fsdp.py
+++ b/tests/tests_pytorch/strategies/test_fsdp.py
@@ -285,6 +285,7 @@ def test_fsdp_strategy_full_state_dict(tmpdir, wrap_min_params):
"""Test to ensure that the full state dict is extracted when using FSDP strategy.
Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all.
+
"""
model = TestFSDPModelAutoWrapped(wrap_min_params=wrap_min_params)
correct_state_dict = model.state_dict() # State dict before wrapping
@@ -547,6 +548,7 @@ def test_fsdp_strategy_save_optimizer_states(tmpdir, wrap_min_params):
Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all. If the model can
be restored to DDP, it means that the optimizer states were saved correctly.
+
"""
model = TestFSDPModelAutoWrapped(wrap_min_params=wrap_min_params)
@@ -604,6 +606,7 @@ def test_fsdp_strategy_load_optimizer_states(tmpdir, wrap_min_params):
Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all. If the DDP model
can be restored to FSDP, it means that the optimizer states were restored correctly.
+
"""
# restore model to ddp
diff --git a/tests/tests_pytorch/strategies/test_single_device_strategy.py b/tests/tests_pytorch/strategies/test_single_device_strategy.py
index 85c92ded505c1..d4648f55f69e9 100644
--- a/tests/tests_pytorch/strategies/test_single_device_strategy.py
+++ b/tests/tests_pytorch/strategies/test_single_device_strategy.py
@@ -45,6 +45,7 @@ def test_single_gpu():
"""Tests if device is set correctly when training and after teardown for single GPU strategy.
Cannot run this test on MPS due to shared memory not allowing dedicated measurements of GPU memory utilization.
+
"""
trainer = Trainer(accelerator="gpu", devices=1, fast_dev_run=True)
# assert training strategy attributes for device setting
diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py
index c9ad6853a1fcd..b074429642abf 100644
--- a/tests/tests_pytorch/test_cli.py
+++ b/tests/tests_pytorch/test_cli.py
@@ -903,6 +903,7 @@ def foo(self, model: LightningModule, x: int, y: float = 1.0):
model: A model
x: The x
y: The y
+
"""
class TestCLI(LightningCLI):
diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py
index 00dab05ad357e..1e70bd0e59bb7 100644
--- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py
@@ -587,8 +587,8 @@ def test_error_raised_with_insufficient_float_limit_train_dataloader():
],
)
def test_attach_data_input_validation_with_none_dataloader(trainer_fn_name, dataloader_name, tmpdir):
- """Test that passing `Trainer.method(x_dataloader=None)` with no module-method implementations available raises
- an error."""
+ """Test that passing `Trainer.method(x_dataloader=None)` with no module-method implementations available raises an
+ error."""
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
model = BoringModel()
datamodule = BoringDataModule()
diff --git a/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py b/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py
index 13bbb2243c3ac..9019249c37c33 100644
--- a/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py
+++ b/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py
@@ -26,6 +26,7 @@ class AllRankLogger(Logger):
"""Logger to test all-rank logging (i.e. not just rank 0).
Logs are saved to local variable `logs`.
+
"""
def __init__(self):
@@ -102,6 +103,7 @@ def test_first_logger_call_in_subprocess(tmpdir):
"""Test that the Trainer does not call the logger too early.
Only when the worker processes are initialized do we have access to the rank and know which one is the main process.
+
"""
class LoggerCallsObserver(Callback):
diff --git a/tests/tests_pytorch/trainer/test_dataloaders.py b/tests/tests_pytorch/trainer/test_dataloaders.py
index 0dc0df1f9040a..2623108e5b9a1 100644
--- a/tests/tests_pytorch/trainer/test_dataloaders.py
+++ b/tests/tests_pytorch/trainer/test_dataloaders.py
@@ -675,8 +675,7 @@ def test_auto_add_worker_init_fn_distributed(tmpdir, monkeypatch):
def test_warning_with_small_dataloader_and_logging_interval(tmpdir):
- """Test that a warning message is shown if the dataloader length is too short for the chosen logging
- interval."""
+ """Test that a warning message is shown if the dataloader length is too short for the chosen logging interval."""
model = BoringModel()
dataloader = DataLoader(RandomDataset(32, length=10))
model.train_dataloader = lambda: dataloader
@@ -847,8 +846,7 @@ def train_dataloader(self):
@RunIf(min_cuda_gpus=2, skip_windows=True)
def test_dataloader_distributed_sampler_already_attached(tmpdir):
- """Test DistributedSampler and it's arguments for DDP backend when DistSampler already included on
- dataloader."""
+ """Test DistributedSampler and it's arguments for DDP backend when DistSampler already included on dataloader."""
seed_everything(123)
model = ModelWithDataLoaderDistributedSampler()
trainer = Trainer(
@@ -1209,8 +1207,8 @@ def side_effect_request_dataloader(ds):
def test_dataloaders_reset_and_attach(tmpdir):
- """Test that repeated calls to Trainer.{fit,validate,test,predict} properly reset dataloaders before attaching
- the new one."""
+ """Test that repeated calls to Trainer.{fit,validate,test,predict} properly reset dataloaders before attaching the
+ new one."""
# the assertions compare the datasets and not dataloaders since we patch and replace the samplers
dataloader_0 = DataLoader(dataset=RandomDataset(32, 64))
dataloader_1 = DataLoader(dataset=RandomDataset(32, 64))
diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py
index 346ef9a84e709..20438af3517dd 100644
--- a/tests/tests_pytorch/trainer/test_trainer.py
+++ b/tests/tests_pytorch/trainer/test_trainer.py
@@ -67,8 +67,7 @@
def test_trainer_error_when_input_not_lightning_module():
- """Test that a useful error gets raised when the Trainer methods receive something other than a
- LightningModule."""
+ """Test that a useful error gets raised when the Trainer methods receive something other than a LightningModule."""
trainer = Trainer()
for method in ("fit", "validate", "test", "predict"):
@@ -347,8 +346,8 @@ def mock_save_function(filepath, *args):
def test_model_checkpoint_only_weights(tmpdir):
- """Tests use case where ModelCheckpoint is configured to save only model weights, and user tries to load
- checkpoint to resume training."""
+ """Tests use case where ModelCheckpoint is configured to save only model weights, and user tries to load checkpoint
+ to resume training."""
model = BoringModel()
trainer = Trainer(
@@ -1447,8 +1446,8 @@ def test_predict_return_predictions_cpu(return_predictions, precision, tmpdir):
@pytest.mark.parametrize(("max_steps", "max_epochs", "global_step"), [(10, 5, 10), (20, None, 20)])
def test_repeated_fit_calls_with_max_epochs_and_steps(tmpdir, max_steps, max_epochs, global_step):
- """Ensure that the training loop is bound by `max_steps` and `max_epochs` for repeated calls of `trainer.fit`,
- and disabled if the limit is reached."""
+ """Ensure that the training loop is bound by `max_steps` and `max_epochs` for repeated calls of `trainer.fit`, and
+ disabled if the limit is reached."""
dataset_len = 200
batch_size = 10
diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py b/tests/tests_pytorch/tuner/test_lr_finder.py
index 87a7412396b91..bee5c679edca0 100644
--- a/tests/tests_pytorch/tuner/test_lr_finder.py
+++ b/tests/tests_pytorch/tuner/test_lr_finder.py
@@ -217,8 +217,7 @@ def test_datamodule_parameter(tmpdir):
def test_accumulation_and_early_stopping(tmpdir):
- """Test that early stopping of learning rate finder works, and that accumulation also works for this
- feature."""
+ """Test that early stopping of learning rate finder works, and that accumulation also works for this feature."""
seed_everything(1)
class TestModel(BoringModel):
diff --git a/tests/tests_pytorch/tuner/test_scale_batch_size.py b/tests/tests_pytorch/tuner/test_scale_batch_size.py
index 4e88953286885..3f32f8231dd4a 100644
--- a/tests/tests_pytorch/tuner/test_scale_batch_size.py
+++ b/tests/tests_pytorch/tuner/test_scale_batch_size.py
@@ -231,8 +231,7 @@ def test_call_to_trainer_method(tmpdir, scale_method):
def test_error_on_dataloader_passed_to_fit(tmpdir):
- """Verify that when the auto-scale batch size feature raises an error if a train dataloader is passed to
- fit."""
+ """Verify that when the auto-scale batch size feature raises an error if a train dataloader is passed to fit."""
# only train passed to fit
model = BatchSizeModel(batch_size=2)
diff --git a/tests/tests_pytorch/utilities/migration/test_utils.py b/tests/tests_pytorch/utilities/migration/test_utils.py
index f8f2be9786a79..c9cf6cd4de76b 100644
--- a/tests/tests_pytorch/utilities/migration/test_utils.py
+++ b/tests/tests_pytorch/utilities/migration/test_utils.py
@@ -178,8 +178,8 @@ def upgrade(ckpt):
def test_migrate_checkpoint_too_new():
- """Test checkpoint migration is a no-op with a warning when attempting to migrate a checkpoint from newer
- version of Lightning than installed."""
+ """Test checkpoint migration is a no-op with a warning when attempting to migrate a checkpoint from newer version
+ of Lightning than installed."""
super_new_checkpoint = {"pytorch-lightning_version": "99.0.0", "content": 123}
with pytest.warns(
PossibleUserWarning, match=f"v99.0.0, which is newer than your current Lightning version: v{pl.__version__}"
diff --git a/tests/tests_pytorch/utilities/test_combined_loader.py b/tests/tests_pytorch/utilities/test_combined_loader.py
index 7109523b378a9..9f247616fd2e9 100644
--- a/tests/tests_pytorch/utilities/test_combined_loader.py
+++ b/tests/tests_pytorch/utilities/test_combined_loader.py
@@ -403,8 +403,7 @@ def __len__(self):
@pytest.mark.parametrize("use_distributed_sampler", [False, True])
def test_combined_data_loader_validation_test(use_distributed_sampler):
- """This test makes sure distributed sampler has been properly injected in dataloaders when using
- CombinedLoader."""
+ """This test makes sure distributed sampler has been properly injected in dataloaders when using CombinedLoader."""
class CustomDataset(Dataset):
def __init__(self, data):
diff --git a/tests/tests_pytorch/utilities/test_data.py b/tests/tests_pytorch/utilities/test_data.py
index 3c6d2ababbe0d..930650dcb2789 100644
--- a/tests/tests_pytorch/utilities/test_data.py
+++ b/tests/tests_pytorch/utilities/test_data.py
@@ -139,11 +139,12 @@ def __init__(self, randomize, *args, **kwargs):
@pytest.mark.parametrize("predicting", [True, False])
def test_custom_batch_sampler(predicting):
- """This test asserts, that custom `BatchSampler`, with all the arguments, that are required in order to
- properly reinstantiate the class, is invoked properly.
+ """This test asserts, that custom `BatchSampler`, with all the arguments, that are required in order to properly
+ reinstantiate the class, is invoked properly.
It also asserts, that during the reinstantiation, the wrapper of `__init__` method is not present anymore, therefore
not setting `__pl_saved_{args,arg_names,kwargs}` attributes.
+
"""
class MyBatchSampler(BatchSampler):
@@ -189,8 +190,8 @@ def __init__(self, sampler, extra_arg, drop_last=True):
def test_custom_batch_sampler_no_drop_last():
- """Tests whether appropriate warning is raised when the custom `BatchSampler` does not support `drop_last` and
- we want to reset it."""
+ """Tests whether appropriate warning is raised when the custom `BatchSampler` does not support `drop_last` and we
+ want to reset it."""
class MyBatchSampler(BatchSampler):
# Custom batch sampler with extra argument, but without `drop_last`
@@ -217,8 +218,7 @@ def __init__(self, sampler, extra_arg):
def test_custom_batch_sampler_no_sampler():
- """Tests whether appropriate error is raised when the custom `BatchSampler` does not support sampler
- argument."""
+ """Tests whether appropriate error is raised when the custom `BatchSampler` does not support sampler argument."""
class MyBatchSampler(BatchSampler):
# Custom batch sampler, without sampler argument.
@@ -269,10 +269,10 @@ def test_dataloader_kwargs_replacement_with_iterable_dataset(mode):
def test_dataloader_kwargs_replacement_with_array_default_comparison():
- """Test that the comparison of attributes and default argument values works with arrays (truth value
- ambiguous).
+ """Test that the comparison of attributes and default argument values works with arrays (truth value ambiguous).
Regression test for issue #15408.
+
"""
dataset = RandomDataset(5, 100)
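For orientation, the shape of the custom `BatchSampler` these tests exercise is roughly the sketch below: the standard `(sampler, batch_size, drop_last)` signature plus an extra argument that has to survive when the class is reinstantiated.

from torch.utils.data import BatchSampler, SequentialSampler

class MyBatchSampler(BatchSampler):
    def __init__(self, sampler, extra_arg, batch_size=2, drop_last=False):
        # the extra argument must be preserved across reinstantiation by the Trainer
        self.extra_arg = extra_arg
        super().__init__(sampler, batch_size, drop_last)

batch_sampler = MyBatchSampler(SequentialSampler(range(10)), extra_arg="foo")
assert sum(len(batch) for batch in batch_sampler) == 10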
diff --git a/tests/tests_pytorch/utilities/test_deepspeed_model_summary.py b/tests/tests_pytorch/utilities/test_deepspeed_model_summary.py
index 146ab1aa6601b..919acee7e4af1 100644
--- a/tests/tests_pytorch/utilities/test_deepspeed_model_summary.py
+++ b/tests/tests_pytorch/utilities/test_deepspeed_model_summary.py
@@ -22,8 +22,8 @@
@RunIf(min_cuda_gpus=2, deepspeed=True, standalone=True)
def test_deepspeed_summary(tmpdir):
- """Test to ensure that the summary contains the correct values when stage 3 is enabled and that the trainer
- enables the `DeepSpeedSummary` when DeepSpeed is used."""
+ """Test to ensure that the summary contains the correct values when stage 3 is enabled and that the trainer enables
+ the `DeepSpeedSummary` when DeepSpeed is used."""
model = BoringModel()
total_parameters = sum(x.numel() for x in model.parameters())
diff --git a/tests/tests_pytorch/utilities/test_imports.py b/tests/tests_pytorch/utilities/test_imports.py
index 0fb31b59ba254..ae3c092907db4 100644
--- a/tests/tests_pytorch/utilities/test_imports.py
+++ b/tests/tests_pytorch/utilities/test_imports.py
@@ -63,10 +63,11 @@ def new_fn(*args, **kwargs):
@pytest.fixture()
def clean_import():
- """This fixture allows test to import {pytorch_}lightning* modules completely cleanly, regardless of the
- current state of the imported modules.
+ """This fixture allows test to import {pytorch_}lightning* modules completely cleanly, regardless of the current
+ state of the imported modules.
Afterwards, it restores the original state of the modules.
+
"""
import sys
@@ -108,6 +109,7 @@ def test_import_with_unavailable_dependencies(patch_name, new_fn, to_import, cle
When the patch is applied and the module is imported, it should not raise any errors. The list of cases to check was
compiled by finding else branches of top-level if statements checking for the availability of the module and
performing imports.
+
"""
with mock.patch(patch_name, new=new_fn):
importlib.import_module(to_import)
diff --git a/tests/tests_pytorch/utilities/test_warnings.py b/tests/tests_pytorch/utilities/test_warnings.py
index 78f0570ee2947..04c4d50ae8e8f 100644
--- a/tests/tests_pytorch/utilities/test_warnings.py
+++ b/tests/tests_pytorch/utilities/test_warnings.py
@@ -14,6 +14,7 @@
"""Test that the warnings actually appear and they have the correct `stacklevel`
Needs to be run outside of `pytest` as it captures all the warnings.
+
"""
from contextlib import redirect_stderr
from io import StringIO