Skip to content

Commit

Permalink
fix(tests): update tests after torch 2.4.1 (#20302)
Browse files Browse the repository at this point in the history
* update

* test_loggers_pickle_all

* more...

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
(cherry picked from commit d1ca3c6)
  • Loading branch information
Borda committed Sep 26, 2024
1 parent 1cc1c0d commit 143b6d0
Show file tree
Hide file tree
Showing 10 changed files with 25 additions and 20 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ ignore = [
"S108",
"E203", # conflicts with black
]
ignore-init-module-imports = true

[tool.ruff.lint.per-file-ignores]
".actions/*" = ["S101", "S310"]
Expand Down
2 changes: 1 addition & 1 deletion requirements/typing.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
mypy==1.11.0
torch==2.4.0
torch==2.4.1

types-Markdown
types-PyYAML
Expand Down
2 changes: 2 additions & 0 deletions src/lightning/fabric/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@

_TORCH_GREATER_EQUAL_2_2 = compare_version("torch", operator.ge, "2.2.0")
_TORCH_GREATER_EQUAL_2_3 = compare_version("torch", operator.ge, "2.3.0")
_TORCH_EQUAL_2_4_0 = compare_version("torch", operator.eq, "2.4.0")
_TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0")
_TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1")

_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)

Expand Down
6 changes: 3 additions & 3 deletions tests/tests_pytorch/callbacks/test_early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import cloudpickle
import pytest
import torch
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel
Expand Down Expand Up @@ -193,12 +193,12 @@ def test_pickling():
early_stopping = EarlyStopping(monitor="foo")

early_stopping_pickled = pickle.dumps(early_stopping)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
early_stopping_loaded = pickle.loads(early_stopping_pickled)
assert vars(early_stopping) == vars(early_stopping_loaded)

early_stopping_pickled = cloudpickle.dumps(early_stopping)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
early_stopping_loaded = cloudpickle.loads(early_stopping_pickled)
assert vars(early_stopping) == vars(early_stopping_loaded)

Expand Down
6 changes: 3 additions & 3 deletions tests/tests_pytorch/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import yaml
from jsonargparse import ArgumentParser
from lightning.fabric.utilities.cloud_io import _load as pl_load
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel
Expand Down Expand Up @@ -352,12 +352,12 @@ def test_pickling(tmp_path):
ckpt = ModelCheckpoint(dirpath=tmp_path)

ckpt_pickled = pickle.dumps(ckpt)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
ckpt_loaded = pickle.loads(ckpt_pickled)
assert vars(ckpt) == vars(ckpt_loaded)

ckpt_pickled = cloudpickle.dumps(ckpt)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
ckpt_loaded = cloudpickle.loads(ckpt_pickled)
assert vars(ckpt) == vars(ckpt_loaded)

Expand Down
4 changes: 2 additions & 2 deletions tests/tests_pytorch/core/test_metric_result_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import lightning.pytorch as pl
import pytest
import torch
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import OnExceptionCheckpoint
Expand Down Expand Up @@ -254,7 +254,7 @@ def lightning_log(fx, *args, **kwargs):
}

# make sure can be pickled
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
pickle.loads(pickle.dumps(result))
# make sure can be torch.loaded
filepath = str(tmp_path / "result")
Expand Down
6 changes: 3 additions & 3 deletions tests/tests_pytorch/helpers/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import cloudpickle
import pytest
import torch
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0

from tests_pytorch import _PATH_DATASETS
from tests_pytorch.helpers.datasets import MNIST, AverageDataset, TrialMNIST
Expand All @@ -44,9 +44,9 @@ def test_pickling_dataset_mnist(dataset_cls, args):
mnist = dataset_cls(**args)

mnist_pickled = pickle.dumps(mnist)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
pickle.loads(mnist_pickled)

mnist_pickled = cloudpickle.dumps(mnist)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
cloudpickle.loads(mnist_pickled)
10 changes: 7 additions & 3 deletions tests/tests_pytorch/loggers/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import pytest
import torch
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0, _TORCH_GREATER_EQUAL_2_4_1
from lightning.pytorch import Callback, Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loggers import (
Expand Down Expand Up @@ -163,7 +163,7 @@ def test_loggers_pickle_all(tmp_path, monkeypatch, logger_class):
pytest.xfail(f"pickle test requires {logger_class.__class__} dependencies to be installed.")


def _test_loggers_pickle(tmp_path, monkeypatch, logger_class):
def _test_loggers_pickle(tmp_path, monkeypatch, logger_class: Logger):
"""Verify that pickling trainer with logger works."""
_patch_comet_atexit(monkeypatch)

Expand All @@ -184,7 +184,11 @@ def _test_loggers_pickle(tmp_path, monkeypatch, logger_class):
trainer = Trainer(max_epochs=1, logger=logger)
pkl_bytes = pickle.dumps(trainer)

with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
with (
pytest.warns(FutureWarning, match="`weights_only=False`")
if _TORCH_EQUAL_2_4_0 or (_TORCH_GREATER_EQUAL_2_4_1 and logger_class not in (CSVLogger, TensorBoardLogger))
else nullcontext()
):
trainer2 = pickle.loads(pkl_bytes)
trainer2.logger.log_metrics({"acc": 1.0})

Expand Down
4 changes: 2 additions & 2 deletions tests/tests_pytorch/loggers/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import numpy as np
import pytest
import torch
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
from lightning.fabric.utilities.logger import _convert_params, _sanitize_params
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
Expand Down Expand Up @@ -124,7 +124,7 @@ def test_multiple_loggers_pickle(tmp_path):

trainer = Trainer(logger=[logger1, logger2])
pkl_bytes = pickle.dumps(trainer)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
trainer2 = pickle.loads(pkl_bytes)
for logger in trainer2.loggers:
logger.log_metrics({"acc": 1.0}, 0)
Expand Down
4 changes: 2 additions & 2 deletions tests/tests_pytorch/loggers/test_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import pytest
import yaml
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.cli import LightningCLI
Expand Down Expand Up @@ -162,7 +162,7 @@ def name(self):
assert trainer.logger.experiment, "missing experiment"
assert trainer.log_dir == logger.save_dir
pkl_bytes = pickle.dumps(trainer)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
trainer2 = pickle.loads(pkl_bytes)

assert os.environ["WANDB_MODE"] == "dryrun"
Expand Down

0 comments on commit 143b6d0

Please sign in to comment.