From 6749f4102f991a0d7e3e0f77102a278e28755698 Mon Sep 17 00:00:00 2001 From: Quasar-Kim Date: Tue, 23 May 2023 00:29:29 +0900 Subject: [PATCH] Add `load_legacy_checkpoint` function --- .../checkpointing/test_legacy_checkpoints.py | 29 +++++++++++++++++-- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py index 36dddf42f7d87..1061778843a4d 100644 --- a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py +++ b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py @@ -14,7 +14,6 @@ import glob import os import sys -import threading from unittest.mock import patch import pytest @@ -80,8 +79,32 @@ def load_model(): from lightning.pytorch.utilities.migration import pl_legacy_patch with pl_legacy_patch(): - _ = torch.load(path_ckpt) - + _ = torch.load(path_ckpt) + + with patch("sys.path", [PATH_LEGACY] + sys.path): + t1 = ThreadExceptionHandler(target=load_model) + t2 = ThreadExceptionHandler(target=load_model) + + t1.start() + t2.start() + + t1.join() + t2.join() + + +@pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS) +@RunIf(sklearn=True) +def test_load_legacy_checkpoint_threading(tmpdir, pl_version: str): + PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version) + path_ckpts = sorted(glob.glob(os.path.join(PATH_LEGACY, f"*{CHECKPOINT_EXTENSION}"))) + assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"' + path_ckpt = path_ckpts[-1] + + def load_model(): + from lightning.pytorch.utilities.migration.utils import load_legacy_checkpoint + + _ = load_legacy_checkpoint(path_ckpt) + with patch("sys.path", [PATH_LEGACY] + sys.path): t1 = ThreadExceptionHandler(target=load_model) t2 = ThreadExceptionHandler(target=load_model)