Skip to content

Commit

Permalink
Set torch.load(weights_only=) in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Jul 7, 2024
1 parent a6562b4 commit 1eff20f
Show file tree
Hide file tree
Showing 30 changed files with 87 additions and 66 deletions.
4 changes: 2 additions & 2 deletions tests/tests_fabric/strategies/test_deepspeed_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,11 +312,11 @@ def _assert_saved_model_is_equal(fabric, model, checkpoint_path):
single_ckpt_path = checkpoint_path / "single_model.pt"
# the tag is hardcoded in DeepSpeedStrategy
convert_zero_checkpoint_to_fp32_state_dict(checkpoint_path, single_ckpt_path, tag="checkpoint")
state_dict = torch.load(single_ckpt_path)
state_dict = torch.load(single_ckpt_path, weights_only=False)
else:
# 'checkpoint' is the tag, hardcoded in DeepSpeedStrategy
single_ckpt_path = checkpoint_path / "checkpoint" / "mp_rank_00_model_states.pt"
state_dict = torch.load(single_ckpt_path)["module"]
state_dict = torch.load(single_ckpt_path, weights_only=False)["module"]

model = model.cpu()

Expand Down
8 changes: 4 additions & 4 deletions tests/tests_fabric/strategies/test_fsdp_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def test_save_full_state_dict(tmp_path):
state = {"model": trainer.model, "optimizer": trainer.optimizer, "steps": 1}
fabric.save(checkpoint_path, state)

checkpoint = torch.load(checkpoint_path)
checkpoint = torch.load(checkpoint_path, weights_only=True)
assert checkpoint["steps"] == 1
loaded_state_dict = checkpoint["model"]

Expand Down Expand Up @@ -248,7 +248,7 @@ def test_save_full_state_dict(tmp_path):
# get optimizer state after loading
normal_checkpoint_path = Path(fabric.broadcast(str(tmp_path / "normal-checkpoint.pt")))
fabric.save(normal_checkpoint_path, {"model": trainer.model, "optimizer": trainer.optimizer, "steps": 2})
optimizer_state_after = torch.load(normal_checkpoint_path)["optimizer"]
optimizer_state_after = torch.load(normal_checkpoint_path, weights_only=True)["optimizer"]
optimizer_state_after = FullyShardedDataParallel.rekey_optim_state_dict(
optimizer_state_after, optim_state_key_type=OptimStateKeyType.PARAM_NAME, model=trainer.model
)
Expand Down Expand Up @@ -330,7 +330,7 @@ def test_load_full_state_dict_into_sharded_model(tmp_path):
# Create a raw state-dict checkpoint to test `Fabric.load_raw` too
raw_checkpoint_path = checkpoint_path.with_name("model-state-dict")
if fabric.global_rank == 0:
checkpoint = torch.load(checkpoint_path)
checkpoint = torch.load(checkpoint_path, weights_only=True)
torch.save(checkpoint["model"], raw_checkpoint_path)
fabric.barrier()

Expand Down Expand Up @@ -485,7 +485,7 @@ def test_save_filter(tmp_path):

checkpoint_path = tmp_path / "full.pth"
fabric.save(checkpoint_path, state, filter=filter)
checkpoint = torch.load(checkpoint_path)["model"]
checkpoint = torch.load(checkpoint_path, weights_only=True)["model"]
assert set(checkpoint) == {"bias"}
assert type(checkpoint["bias"]) is torch.Tensor

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def test_save_full_state_dict(tmp_path):
state = {"model": model, "optimizer": optimizer, "steps": 1}
fabric.save(checkpoint_path, state)

checkpoint = torch.load(checkpoint_path)
checkpoint = torch.load(checkpoint_path, weights_only=True)
assert checkpoint["steps"] == 1
loaded_state_dict = checkpoint["model"]

Expand Down Expand Up @@ -369,7 +369,7 @@ def test_save_full_state_dict(tmp_path):
normal_checkpoint_path = Path(fabric.broadcast(str(tmp_path / "normal-checkpoint.pt")))
fabric.save(normal_checkpoint_path, {"model": model, "optimizer": optimizer, "steps": 2})

optimizer_state_after = torch.load(normal_checkpoint_path)["optimizer"]
optimizer_state_after = torch.load(normal_checkpoint_path, weights_only=True)["optimizer"]
assert set(optimizer_state_after.keys()) == set(optimizer_state_before.keys()) == {"state", "param_groups"}
assert torch.equal(
optimizer_state_after["state"][0]["exp_avg"],
Expand Down Expand Up @@ -433,7 +433,7 @@ def test_load_full_state_dict_into_sharded_model(tmp_path):
# Create a raw state-dict checkpoint to test `Fabric.load_raw` too
raw_checkpoint_path = checkpoint_path.with_name("model-state-dict")
if fabric.global_rank == 0:
checkpoint = torch.load(checkpoint_path)
checkpoint = torch.load(checkpoint_path, weights_only=True)
torch.save(checkpoint["model"], raw_checkpoint_path)
fabric.barrier()

Expand Down Expand Up @@ -519,7 +519,7 @@ def test_save_filter(tmp_path):

checkpoint_path = tmp_path / "full.pth"
fabric.save(checkpoint_path, state, filter=filter)
checkpoint = torch.load(checkpoint_path)["model"]
checkpoint = torch.load(checkpoint_path, weights_only=True)["model"]
assert set(checkpoint) == {"w1.bias", "w2.bias", "w3.bias"}
assert type(checkpoint["w1.bias"]) is torch.Tensor

Expand Down
2 changes: 1 addition & 1 deletion tests/tests_pytorch/accelerators/test_xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_resume_training_on_cpu(tmp_path):
model_path = trainer.checkpoint_callback.best_model_path

# Verify saved Tensors are on CPU
ckpt = torch.load(model_path)
ckpt = torch.load(model_path, weights_only=True)
weight_tensor = list(ckpt["state_dict"].values())[0]
assert weight_tensor.device == torch.device("cpu")

Expand Down
10 changes: 7 additions & 3 deletions tests/tests_pytorch/callbacks/test_early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
import math
import os
import pickle
from contextlib import nullcontext
from typing import List, Optional
from unittest import mock
from unittest.mock import Mock

import cloudpickle
import pytest
import torch
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel
Expand Down Expand Up @@ -82,7 +84,7 @@ def test_resume_early_stopping_from_checkpoint(tmp_path):

checkpoint_filepath = checkpoint_callback.kth_best_model_path
# ensure state is persisted properly
checkpoint = torch.load(checkpoint_filepath)
checkpoint = torch.load(checkpoint_filepath, weights_only=True)
# the checkpoint saves "epoch + 1"
early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"]]
assert len(early_stop_callback.saved_states) == 4
Expand Down Expand Up @@ -191,11 +193,13 @@ def test_pickling():
early_stopping = EarlyStopping(monitor="foo")

early_stopping_pickled = pickle.dumps(early_stopping)
early_stopping_loaded = pickle.loads(early_stopping_pickled)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
early_stopping_loaded = pickle.loads(early_stopping_pickled)
assert vars(early_stopping) == vars(early_stopping_loaded)

early_stopping_pickled = cloudpickle.dumps(early_stopping)
early_stopping_loaded = cloudpickle.loads(early_stopping_pickled)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
early_stopping_loaded = cloudpickle.loads(early_stopping_pickled)
assert vars(early_stopping) == vars(early_stopping_loaded)


Expand Down
2 changes: 1 addition & 1 deletion tests/tests_pytorch/callbacks/test_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def test_multiple_pruning_callbacks(tmp_path, caplog, make_pruning_permanent: bo
filepath = str(tmp_path / "foo.ckpt")
trainer.save_checkpoint(filepath)

model.load_state_dict(torch.load(filepath), strict=False)
model.load_state_dict(torch.load(filepath, weights_only=True), strict=False)
has_pruning = hasattr(model.layer.mlp_1, "weight_orig")
assert not has_pruning if make_pruning_permanent else has_pruning

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def load_model():
from lightning.pytorch.utilities.migration import pl_legacy_patch

with pl_legacy_patch():
_ = torch.load(path_ckpt)
_ = torch.load(path_ckpt, weights_only=False)

with patch("sys.path", [PATH_LEGACY] + sys.path):
t1 = ThreadExceptionHandler(target=load_model)
Expand Down
18 changes: 11 additions & 7 deletions tests/tests_pytorch/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import re
import time
from argparse import Namespace
from contextlib import nullcontext
from datetime import timedelta
from inspect import signature
from pathlib import Path
Expand All @@ -31,6 +32,7 @@
import yaml
from jsonargparse import ArgumentParser
from lightning.fabric.utilities.cloud_io import _load as pl_load
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel
Expand Down Expand Up @@ -350,11 +352,13 @@ def test_pickling(tmp_path):
ckpt = ModelCheckpoint(dirpath=tmp_path)

ckpt_pickled = pickle.dumps(ckpt)
ckpt_loaded = pickle.loads(ckpt_pickled)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
ckpt_loaded = pickle.loads(ckpt_pickled)
assert vars(ckpt) == vars(ckpt_loaded)

ckpt_pickled = cloudpickle.dumps(ckpt)
ckpt_loaded = cloudpickle.loads(ckpt_pickled)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
ckpt_loaded = cloudpickle.loads(ckpt_pickled)
assert vars(ckpt) == vars(ckpt_loaded)


Expand Down Expand Up @@ -920,8 +924,8 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmp_path):
assert os.path.isfile(path_last_epoch)
assert os.path.isfile(path_last)

ckpt_last_epoch = torch.load(path_last_epoch)
ckpt_last = torch.load(path_last)
ckpt_last_epoch = torch.load(path_last_epoch, weights_only=True)
ckpt_last = torch.load(path_last, weights_only=True)

assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"]
assert ckpt_last_epoch["global_step"] == ckpt_last["global_step"]
Expand Down Expand Up @@ -1168,7 +1172,7 @@ def training_step(self, *args):
)
trainer.fit(TestModel())
assert model_checkpoint.current_score == 0.3
ckpts = [torch.load(ckpt) for ckpt in tmp_path.iterdir()]
ckpts = [torch.load(ckpt, weights_only=True) for ckpt in tmp_path.iterdir()]
ckpts = [
ckpt["callbacks"][
"ModelCheckpoint{'monitor': 'foo', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
Expand Down Expand Up @@ -1452,7 +1456,7 @@ def test_save_last_saves_correct_last_model_path(tmp_path):
expected = "foo=1-last.ckpt"
assert os.listdir(tmp_path) == [expected]
full_path = tmp_path / expected
ckpt = torch.load(full_path)
ckpt = torch.load(full_path, weights_only=True)
assert ckpt["callbacks"][mc.state_key]["last_model_path"] == str(full_path)


Expand Down Expand Up @@ -1484,7 +1488,7 @@ def test_none_monitor_saves_correct_best_model_path(tmp_path):
expected = "epoch=0-step=0.ckpt"
assert os.listdir(tmp_path) == [expected]
full_path = str(tmp_path / expected)
ckpt = torch.load(full_path)
ckpt = torch.load(full_path, weights_only=True)
assert ckpt["callbacks"][mc.state_key]["best_model_path"] == full_path


Expand Down
2 changes: 1 addition & 1 deletion tests/tests_pytorch/checkpointing/test_torch_saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_model_torch_save(tmp_path):
# Ensure these do not fail
torch.save(trainer.model, temp_path)
torch.save(trainer, temp_path)
trainer = torch.load(temp_path)
trainer = torch.load(temp_path, weights_only=False)


@RunIf(skip_windows=True)
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_pytorch/core/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
# fit model
trainer.fit(model, datamodule=dm)
checkpoint_path = list(trainer.checkpoint_callback.best_k_models.keys())[0]
checkpoint = torch.load(checkpoint_path)
checkpoint = torch.load(checkpoint_path, weights_only=True)
assert dm.__class__.__qualname__ in checkpoint
assert checkpoint[dm.__class__.__qualname__] == {"my": "state_dict"}

Expand Down
6 changes: 4 additions & 2 deletions tests/tests_pytorch/core/test_metric_result_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import lightning.pytorch as pl
import pytest
import torch
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import OnExceptionCheckpoint
Expand Down Expand Up @@ -253,11 +254,12 @@ def lightning_log(fx, *args, **kwargs):
}

# make sure can be pickled
pickle.loads(pickle.dumps(result))
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
pickle.loads(pickle.dumps(result))
# make sure can be torch.loaded
filepath = str(tmp_path / "result")
torch.save(result, filepath)
torch.load(filepath)
torch.load(filepath, weights_only=False)

# assert metric state reset to default values
result.reset()
Expand Down
4 changes: 2 additions & 2 deletions tests/tests_pytorch/core/test_saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_load_from_checkpoint_map_location_automatic(accelerator, tmp_path, monk
create_boring_checkpoint(tmp_path, BoringModel(), accelerator=accelerator)

# The checkpoint contains tensors with storage tag on the accelerator
checkpoint = torch.load(f"{tmp_path}/checkpoint.ckpt")
checkpoint = torch.load(f"{tmp_path}/checkpoint.ckpt", weights_only=True)
assert checkpoint["state_dict"]["layer.weight"].device.type.startswith(accelerator)

# Pretend that the accelerator is not available
Expand Down Expand Up @@ -113,7 +113,7 @@ def test_load_from_checkpoint_warn_on_empty_state_dict(tmp_path):
"""Test that checkpoints can be loaded with an empty state dict and that the appropriate warning is raised."""
create_boring_checkpoint(tmp_path, BoringModel(), accelerator="cpu")
# Now edit so the state_dict is empty
checkpoint = torch.load(tmp_path / "checkpoint.ckpt")
checkpoint = torch.load(tmp_path / "checkpoint.ckpt", weights_only=True)
checkpoint["state_dict"] = {}
torch.save(checkpoint, tmp_path / "checkpoint.ckpt")

Expand Down
2 changes: 1 addition & 1 deletion tests/tests_pytorch/helpers/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _try_load(path_data, trials: int = 30, delta: float = 1.0):
assert os.path.isfile(path_data), f"missing file: {path_data}"
for _ in range(trials):
try:
res = torch.load(path_data)
res = torch.load(path_data, weights_only=True)
# todo: specify the possible exception
except Exception as ex:
exception = ex
Expand Down
10 changes: 6 additions & 4 deletions tests/tests_pytorch/helpers/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
from contextlib import nullcontext

import cloudpickle
import pytest
import torch
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4

from tests_pytorch import _PATH_DATASETS
from tests_pytorch.helpers.datasets import MNIST, AverageDataset, TrialMNIST
Expand All @@ -42,9 +44,9 @@ def test_pickling_dataset_mnist(dataset_cls, args):
mnist = dataset_cls(**args)

mnist_pickled = pickle.dumps(mnist)
pickle.loads(mnist_pickled)
# assert vars(mnist) == vars(mnist_loaded)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
pickle.loads(mnist_pickled)

mnist_pickled = cloudpickle.dumps(mnist)
cloudpickle.loads(mnist_pickled)
# assert vars(mnist) == vars(mnist_loaded)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
cloudpickle.loads(mnist_pickled)
5 changes: 4 additions & 1 deletion tests/tests_pytorch/loggers/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
import inspect
import os
import pickle
from contextlib import nullcontext
from unittest import mock
from unittest.mock import ANY, Mock

import pytest
import torch
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.pytorch import Callback, Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loggers import (
Expand Down Expand Up @@ -182,7 +184,8 @@ def _test_loggers_pickle(tmp_path, monkeypatch, logger_class):
trainer = Trainer(max_epochs=1, logger=logger)
pkl_bytes = pickle.dumps(trainer)

trainer2 = pickle.loads(pkl_bytes)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
trainer2 = pickle.loads(pkl_bytes)
trainer2.logger.log_metrics({"acc": 1.0})

# make sure we restored properly
Expand Down
5 changes: 4 additions & 1 deletion tests/tests_pytorch/loggers/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
# limitations under the License.
import pickle
from argparse import Namespace
from contextlib import nullcontext
from copy import deepcopy
from typing import Any, Dict, Optional
from unittest.mock import patch

import numpy as np
import pytest
import torch
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.logger import _convert_params, _sanitize_params
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
Expand Down Expand Up @@ -122,7 +124,8 @@ def test_multiple_loggers_pickle(tmp_path):

trainer = Trainer(logger=[logger1, logger2])
pkl_bytes = pickle.dumps(trainer)
trainer2 = pickle.loads(pkl_bytes)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
trainer2 = pickle.loads(pkl_bytes)
for logger in trainer2.loggers:
logger.log_metrics({"acc": 1.0}, 0)

Expand Down
5 changes: 4 additions & 1 deletion tests/tests_pytorch/loggers/test_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
# limitations under the License.
import os
import pickle
from contextlib import nullcontext
from pathlib import Path
from unittest import mock

import pytest
import yaml
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.cli import LightningCLI
Expand Down Expand Up @@ -160,7 +162,8 @@ def name(self):
assert trainer.logger.experiment, "missing experiment"
assert trainer.log_dir == logger.save_dir
pkl_bytes = pickle.dumps(trainer)
trainer2 = pickle.loads(pkl_bytes)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
trainer2 = pickle.loads(pkl_bytes)

assert os.environ["WANDB_MODE"] == "dryrun"
assert trainer2.logger.__class__.__name__ == WandbLogger.__name__
Expand Down
Loading

0 comments on commit 1eff20f

Please sign in to comment.