Skip to content

Commit

Permalink
Add load_legacy_checkpoint function
Browse files Browse the repository at this point in the history
  • Loading branch information
Quasar-Kim committed May 22, 2023
1 parent 093bc11 commit 6749f41
Showing 1 changed file with 26 additions and 3 deletions.
29 changes: 26 additions & 3 deletions tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import glob
import os
import sys
import threading
from unittest.mock import patch

import pytest
Expand Down Expand Up @@ -80,8 +79,32 @@ def load_model():
from lightning.pytorch.utilities.migration import pl_legacy_patch

with pl_legacy_patch():
_ = torch.load(path_ckpt)

_ = torch.load(path_ckpt)

with patch("sys.path", [PATH_LEGACY] + sys.path):
t1 = ThreadExceptionHandler(target=load_model)
t2 = ThreadExceptionHandler(target=load_model)

t1.start()
t2.start()

t1.join()
t2.join()


@pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS)
@RunIf(sklearn=True)
def test_load_legacy_checkpoint_threading(tmpdir, pl_version: str):
PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version)
path_ckpts = sorted(glob.glob(os.path.join(PATH_LEGACY, f"*{CHECKPOINT_EXTENSION}")))
assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"'
path_ckpt = path_ckpts[-1]

def load_model():
from lightning.pytorch.utilities.migration.utils import load_legacy_checkpoint

_ = load_legacy_checkpoint(path_ckpt)

with patch("sys.path", [PATH_LEGACY] + sys.path):
t1 = ThreadExceptionHandler(target=load_model)
t2 = ThreadExceptionHandler(target=load_model)
Expand Down

0 comments on commit 6749f41

Please sign in to comment.