From 143b6d092106d0c4b067ea56f937c736855ec074 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Thu, 26 Sep 2024 17:52:22 +0200 Subject: [PATCH] fix(tests): update tests after torch 2.4.1 (#20302) * update * test_loggers_pickle_all * more... * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> (cherry picked from commit d1ca3c6e096f399ef16b339385acbdfaf564f15b) --- pyproject.toml | 1 - requirements/typing.txt | 2 +- src/lightning/fabric/utilities/imports.py | 2 ++ tests/tests_pytorch/callbacks/test_early_stopping.py | 6 +++--- .../checkpointing/test_model_checkpoint.py | 6 +++--- .../core/test_metric_result_integration.py | 4 ++-- tests/tests_pytorch/helpers/test_datasets.py | 6 +++--- tests/tests_pytorch/loggers/test_all.py | 10 +++++++--- tests/tests_pytorch/loggers/test_logger.py | 4 ++-- tests/tests_pytorch/loggers/test_wandb.py | 4 ++-- 10 files changed, 25 insertions(+), 20 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6edd6d1a8f11f..da4cd7f197d5a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,7 +76,6 @@ ignore = [ "S108", "E203", # conflicts with black ] -ignore-init-module-imports = true [tool.ruff.lint.per-file-ignores] ".actions/*" = ["S101", "S310"] diff --git a/requirements/typing.txt b/requirements/typing.txt index 9f1952605babc..0323edfd6098a 100644 --- a/requirements/typing.txt +++ b/requirements/typing.txt @@ -1,5 +1,5 @@ mypy==1.11.0 -torch==2.4.0 +torch==2.4.1 types-Markdown types-PyYAML diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py index 4dbd57e531859..a1c5a6f6dcd1b 100644 --- a/src/lightning/fabric/utilities/imports.py +++ b/src/lightning/fabric/utilities/imports.py @@ -31,7 +31,9 @@ _TORCH_GREATER_EQUAL_2_2 = compare_version("torch", operator.ge, "2.2.0") _TORCH_GREATER_EQUAL_2_3 = compare_version("torch", operator.ge, "2.3.0") +_TORCH_EQUAL_2_4_0 = compare_version("torch", operator.eq, "2.4.0") _TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0") +_TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1") _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10) diff --git a/tests/tests_pytorch/callbacks/test_early_stopping.py b/tests/tests_pytorch/callbacks/test_early_stopping.py index 633c1dc0853e0..b7e52ee549bcc 100644 --- a/tests/tests_pytorch/callbacks/test_early_stopping.py +++ b/tests/tests_pytorch/callbacks/test_early_stopping.py @@ -23,7 +23,7 @@ import cloudpickle import pytest import torch -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 +from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0 from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel @@ -193,12 +193,12 @@ def test_pickling(): early_stopping = EarlyStopping(monitor="foo") early_stopping_pickled = pickle.dumps(early_stopping) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): + with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): early_stopping_loaded = pickle.loads(early_stopping_pickled) assert vars(early_stopping) == vars(early_stopping_loaded) early_stopping_pickled = cloudpickle.dumps(early_stopping) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): + with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): early_stopping_loaded = cloudpickle.loads(early_stopping_pickled) assert vars(early_stopping) == vars(early_stopping_loaded) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index 8ef78a742f9a7..97d8d3c4d0e4a 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -32,7 +32,7 @@ import yaml from jsonargparse import ArgumentParser from lightning.fabric.utilities.cloud_io import _load as pl_load -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 +from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0 from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel @@ -352,12 +352,12 @@ def test_pickling(tmp_path): ckpt = ModelCheckpoint(dirpath=tmp_path) ckpt_pickled = pickle.dumps(ckpt) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): + with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): ckpt_loaded = pickle.loads(ckpt_pickled) assert vars(ckpt) == vars(ckpt_loaded) ckpt_pickled = cloudpickle.dumps(ckpt) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): + with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): ckpt_loaded = cloudpickle.loads(ckpt_pickled) assert vars(ckpt) == vars(ckpt_loaded) diff --git a/tests/tests_pytorch/core/test_metric_result_integration.py b/tests/tests_pytorch/core/test_metric_result_integration.py index 9818f9807ae6d..ef340d1e17ea9 100644 --- a/tests/tests_pytorch/core/test_metric_result_integration.py +++ b/tests/tests_pytorch/core/test_metric_result_integration.py @@ -19,7 +19,7 @@ import lightning.pytorch as pl import pytest import torch -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 +from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0 from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch import Trainer from lightning.pytorch.callbacks import OnExceptionCheckpoint @@ -254,7 +254,7 @@ def lightning_log(fx, *args, **kwargs): } # make sure can be pickled - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): + with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): pickle.loads(pickle.dumps(result)) # make sure can be torch.loaded filepath = str(tmp_path / "result") diff --git a/tests/tests_pytorch/helpers/test_datasets.py b/tests/tests_pytorch/helpers/test_datasets.py index ddc20c29e62e8..98d77a6d9a8ad 100644 --- a/tests/tests_pytorch/helpers/test_datasets.py +++ b/tests/tests_pytorch/helpers/test_datasets.py @@ -17,7 +17,7 @@ import cloudpickle import pytest import torch -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 +from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0 from tests_pytorch import _PATH_DATASETS from tests_pytorch.helpers.datasets import MNIST, AverageDataset, TrialMNIST @@ -44,9 +44,9 @@ def test_pickling_dataset_mnist(dataset_cls, args): mnist = dataset_cls(**args) mnist_pickled = pickle.dumps(mnist) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): + with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): pickle.loads(mnist_pickled) mnist_pickled = cloudpickle.dumps(mnist) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): + with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): cloudpickle.loads(mnist_pickled) diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py index 503e49fe6cdad..c5b07562afb0a 100644 --- a/tests/tests_pytorch/loggers/test_all.py +++ b/tests/tests_pytorch/loggers/test_all.py @@ -20,7 +20,7 @@ import pytest import torch -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 +from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0, _TORCH_GREATER_EQUAL_2_4_1 from lightning.pytorch import Callback, Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers import ( @@ -163,7 +163,7 @@ def test_loggers_pickle_all(tmp_path, monkeypatch, logger_class): pytest.xfail(f"pickle test requires {logger_class.__class__} dependencies to be installed.") -def _test_loggers_pickle(tmp_path, monkeypatch, logger_class): +def _test_loggers_pickle(tmp_path, monkeypatch, logger_class: Logger): """Verify that pickling trainer with logger works.""" _patch_comet_atexit(monkeypatch) @@ -184,7 +184,11 @@ def _test_loggers_pickle(tmp_path, monkeypatch, logger_class): trainer = Trainer(max_epochs=1, logger=logger) pkl_bytes = pickle.dumps(trainer) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): + with ( + pytest.warns(FutureWarning, match="`weights_only=False`") + if _TORCH_EQUAL_2_4_0 or (_TORCH_GREATER_EQUAL_2_4_1 and logger_class not in (CSVLogger, TensorBoardLogger)) + else nullcontext() + ): trainer2 = pickle.loads(pkl_bytes) trainer2.logger.log_metrics({"acc": 1.0}) diff --git a/tests/tests_pytorch/loggers/test_logger.py b/tests/tests_pytorch/loggers/test_logger.py index 7b384890f6148..de0028000cd9f 100644 --- a/tests/tests_pytorch/loggers/test_logger.py +++ b/tests/tests_pytorch/loggers/test_logger.py @@ -21,7 +21,7 @@ import numpy as np import pytest import torch -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 +from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0 from lightning.fabric.utilities.logger import _convert_params, _sanitize_params from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel @@ -124,7 +124,7 @@ def test_multiple_loggers_pickle(tmp_path): trainer = Trainer(logger=[logger1, logger2]) pkl_bytes = pickle.dumps(trainer) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): + with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): trainer2 = pickle.loads(pkl_bytes) for logger in trainer2.loggers: logger.log_metrics({"acc": 1.0}, 0) diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py index e9195f628348b..4e3fbb287a1f9 100644 --- a/tests/tests_pytorch/loggers/test_wandb.py +++ b/tests/tests_pytorch/loggers/test_wandb.py @@ -19,7 +19,7 @@ import pytest import yaml -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 +from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0 from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.cli import LightningCLI @@ -162,7 +162,7 @@ def name(self): assert trainer.logger.experiment, "missing experiment" assert trainer.log_dir == logger.save_dir pkl_bytes = pickle.dumps(trainer) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): + with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): trainer2 = pickle.loads(pkl_bytes) assert os.environ["WANDB_MODE"] == "dryrun"