Skip to content

Commit

Permalink
Re-introduce thread locking
Browse files Browse the repository at this point in the history
  • Loading branch information
Quasar-Kim committed May 30, 2023
1 parent 7a402ad commit 6227c43
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
9 changes: 7 additions & 2 deletions src/lightning/pytorch/utilities/migration/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import logging
import os
import sys
import threading
from types import ModuleType, TracebackType
from typing import Any, Dict, List, Optional, Tuple, Type

Expand Down Expand Up @@ -69,6 +70,9 @@ def migrate_checkpoint(
return checkpoint, applied_migrations


_lock = threading.Lock()


class pl_legacy_patch:
"""Registers legacy artifacts (classes, methods, etc.) that were removed but still need to be included for
unpickling old checkpoints. The following patches apply.
Expand All @@ -85,6 +89,7 @@ class pl_legacy_patch:
"""

def __enter__(self) -> "pl_legacy_patch":
_lock.acquire()
# `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse`
legacy_argparse_module = ModuleType("lightning.pytorch.utilities.argparse_utils")
sys.modules["lightning.pytorch.utilities.argparse_utils"] = legacy_argparse_module
Expand All @@ -102,8 +107,8 @@ def __exit__(
) -> None:
if hasattr(pl.utilities.argparse, "_gpus_arg_default"):
delattr(pl.utilities.argparse, "_gpus_arg_default")
if "lightning.pytorch.utilities.argparse_utils" in sys.modules:
del sys.modules["lightning.pytorch.utilities.argparse_utils"]
del sys.modules["lightning.pytorch.utilities.argparse_utils"]
_lock.release()


def _pl_migrate_checkpoint(checkpoint: _CHECKPOINT, checkpoint_path: Optional[_PATH] = None) -> _CHECKPOINT:
Expand Down
6 changes: 3 additions & 3 deletions tests/tests_pytorch/helpers/threading.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
class ThreadExceptionHandler(Thread):
"""Adopted from https://stackoverflow.com/a/67022927."""

def __init__(self, target):
super().__init__(target=target)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.exception = None

def run(self):
try:
self._target()
super().run()
except Exception as e:
self.exception = e

Expand Down

0 comments on commit 6227c43

Please sign in to comment.