Skip to content

Commit

Permalink
Fix issue #1249 pytorch-lightning patches (#1254)
Browse files Browse the repository at this point in the history
  • Loading branch information
a-gardner1 authored May 10, 2024
1 parent 2fbd864 commit 66a7f56
Showing 1 changed file with 64 additions and 69 deletions.
133 changes: 64 additions & 69 deletions clearml/binding/frameworks/pytorch_bind.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import sys
from typing import Any, Callable, Literal

import six
import threading
import importlib

from pathlib2 import Path

Expand Down Expand Up @@ -108,6 +110,65 @@ def _mmcv_save_checkpoint(original_fn, model, filename, *args, **kwargs):
del PatchPyTorchModelIO._checkpoint_filename[tid]
return ret

@staticmethod
def _patch_lightning_io_internal(lightning_name: Literal["lightning", "pytorch_lightning"]):
    """
    Wrap the checkpoint save/restore methods of a Lightning distribution.

    The following methods are wrapped, when they exist, using ``_patched_call``:

    * ``Trainer.save_checkpoint`` / ``Trainer.restore``
    * ``_CheckpointConnector.save_checkpoint`` / ``_CheckpointConnector.restore``
      (falling back to the public ``CheckpointConnector`` name)

    Patching is best-effort: if the package, one of its sub-modules, or one of
    the classes cannot be found, the corresponding step is skipped silently
    instead of raising into the caller.

    :param lightning_name: top-level package to patch. For ``"lightning"`` the
        classes are looked up under ``lightning.pytorch``; for
        ``"pytorch_lightning"`` they are looked up on the package itself.
    """
    # For the unified ``lightning`` distribution the trainer lives under the
    # ``pytorch`` sub-package. Import it explicitly instead of reading it as an
    # attribute, so a ``lightning`` package without that sub-package ends up in
    # the ImportError branch rather than raising AttributeError.
    if lightning_name == "lightning":
        module_name = lightning_name + ".pytorch"
    else:
        module_name = lightning_name

    try:
        pytorch_lightning = importlib.import_module(module_name)
    except ImportError:
        # lightning is not installed (or has no pytorch sub-package)
        # Nothing to do
        return

    def patch_method(cls: type, method_name: str,
                     patched_method: Callable[..., Any]) -> None:
        """
        Patch a method of a class if it exists.
        Otherwise, no effect.
        """
        try:
            method = getattr(cls, method_name)
        except AttributeError:
            # the method is not defined on the given class
            pass
        else:
            setattr(cls, method_name,
                    _patched_call(method, patched_method))

    trainer_module = getattr(pytorch_lightning, "trainer", None)
    if trainer_module is None:
        # Unexpected package layout; nothing we know how to patch
        return

    trainer_cls = getattr(trainer_module, "Trainer", None)
    if trainer_cls is not None:
        patch_method(trainer_cls, "save_checkpoint",
                     PatchPyTorchModelIO._save)
        patch_method(trainer_cls, "restore",
                     PatchPyTorchModelIO._load_from_obj)

    connectors = getattr(trainer_module, "connectors", None)
    checkpoint_connector = getattr(connectors, "checkpoint_connector", None)
    if checkpoint_connector is None:
        # checkpoint_connector does not yet exist; lightning version is < 0.10.0
        # Nothing left to do
        return

    connector_cls = getattr(checkpoint_connector, "_CheckpointConnector", None)
    if connector_cls is None:
        # CheckpointConnector has not yet been made protected
        # lightning version is < 2.0.0
        connector_cls = getattr(checkpoint_connector, "CheckpointConnector", None)
    if connector_cls is None:
        # Unexpected future breaking change in lightning
        # No way to automatically handle
        return

    patch_method(connector_cls, "save_checkpoint",
                 PatchPyTorchModelIO._save)
    patch_method(connector_cls, "restore",
                 PatchPyTorchModelIO._load_from_obj)

@staticmethod
def _patch_lightning_io():
if PatchPyTorchModelIO.__patched_lightning:
Expand All @@ -118,41 +179,7 @@ def _patch_lightning_io():

PatchPyTorchModelIO.__patched_lightning = True

# noinspection PyBroadException
try:
import lightning # noqa

lightning.pytorch.trainer.Trainer.save_checkpoint = _patched_call(
lightning.pytorch.trainer.Trainer.save_checkpoint, PatchPyTorchModelIO._save
) # noqa

lightning.pytorch.trainer.Trainer.restore = _patched_call(
lightning.pytorch.trainer.Trainer.restore, PatchPyTorchModelIO._load_from_obj
) # noqa
except ImportError:
pass
except Exception:
pass

# noinspection PyBroadException
try:
import lightning # noqa

# noinspection PyUnresolvedReferences
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint = _patched_call(
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint,
PatchPyTorchModelIO._save,
) # noqa

# noinspection PyUnresolvedReferences
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.restore = _patched_call(
lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.restore,
PatchPyTorchModelIO._load_from_obj,
) # noqa
except ImportError:
pass
except Exception:
pass
PatchPyTorchModelIO._patch_lightning_io_internal("lightning")

@staticmethod
def _patch_pytorch_lightning_io():
Expand All @@ -164,39 +191,7 @@ def _patch_pytorch_lightning_io():

PatchPyTorchModelIO.__patched_pytorch_lightning = True

# noinspection PyBroadException
try:
import pytorch_lightning # noqa

pytorch_lightning.trainer.Trainer.save_checkpoint = _patched_call(
pytorch_lightning.trainer.Trainer.save_checkpoint, PatchPyTorchModelIO._save) # noqa

pytorch_lightning.trainer.Trainer.restore = _patched_call(
pytorch_lightning.trainer.Trainer.restore, PatchPyTorchModelIO._load_from_obj) # noqa
except ImportError:
pass
except Exception:
pass

# noinspection PyBroadException
try:
import pytorch_lightning # noqa

# noinspection PyUnresolvedReferences
pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint = \
_patched_call(
pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint,
PatchPyTorchModelIO._save) # noqa

# noinspection PyUnresolvedReferences
pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.restore = \
_patched_call(
pytorch_lightning.trainer.connectors.checkpoint_connector.CheckpointConnector.restore,
PatchPyTorchModelIO._load_from_obj) # noqa
except ImportError:
pass
except Exception:
pass
PatchPyTorchModelIO._patch_lightning_io_internal("pytorch_lightning")

@staticmethod
def _save(original_fn, obj, f, *args, **kwargs):
Expand Down Expand Up @@ -334,4 +329,4 @@ def _load_from_obj(original_fn, obj, f, *args, **kwargs):
def __get_cached_checkpoint_filename():
    """
    Return the checkpoint filename cached for the calling thread.

    The lookup key is the current thread's ``ident`` in
    ``PatchPyTorchModelIO._checkpoint_filename``.

    :return: the cached filename, or ``None`` when nothing is cached for this
        thread or the cached value is falsy (e.g. an empty string).
    """
    tid = threading.current_thread().ident
    checkpoint_filename = PatchPyTorchModelIO._checkpoint_filename.get(tid)
    # ``or None`` normalizes any falsy cached value to None
    return checkpoint_filename or None

0 comments on commit 66a7f56

Please sign in to comment.