Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: adapt scdsc to BaseClusteringMethod class; create BasePretrain mixin #198

Merged
Merged 13 commits on Feb 19, 2023
90 changes: 90 additions & 0 deletions dance/modules/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
import os
from abc import ABC, abstractmethod, abstractstaticmethod
from contextlib import contextmanager
from functools import partialmethod
from operator import attrgetter
from time import time

import torch

from dance import logger
from dance.data import Data
from dance.transforms.base import BaseTransform
from dance.typing import Any, Mapping, Optional, Tuple, Union
Expand Down Expand Up @@ -47,6 +55,88 @@ def score(self, x, y, score_func: Optional[Union[str, Mapping[Any, float]]] = No
return (score, y_pred) if return_pred else score


class BasePretrain(ABC):
    """Mixin providing a standard pre-training workflow.

    Subclasses override :meth:`pretrain` (and optionally
    :meth:`save_pretrained` / :meth:`load_pretrained`) and call
    :meth:`_pretrain`, which handles skip logic, checkpoint loading,
    timing, and checkpoint saving.
    """

    @property
    def is_pretrained(self) -> bool:
        """Whether the model has already been pre-trained (or loaded from disk)."""
        # Defaults to False until _pretrain sets the flag.
        return getattr(self, "_is_pretrained", False)

    def _pretrain(self, *args, force_pretrain: bool = False, **kwargs):
        """Run pre-training unless it can be skipped or loaded from disk.

        Parameters
        ----------
        *args, **kwargs
            Forwarded to :meth:`pretrain`.
        force_pretrain
            If True, always re-run pre-training, ignoring the pretrained
            flag and any existing checkpoint at ``self.pretrain_path``.
        """
        # ``pretrain_path`` is an optional attribute on the subclass.
        pt_path = getattr(self, "pretrain_path", None)
        if not force_pretrain:
            # Already pretrained in this session — nothing to do.
            if self.is_pretrained:
                logger.info("Skipping pre_train as the model appears to be pretrained already. "
                            "If you wish to force pre-training, please set 'force_pretrain' to True.")
                return

            # A checkpoint exists on disk — load it instead of re-training.
            if pt_path is not None and os.path.isfile(pt_path):
                logger.info(f"Loading pre-trained model from {pt_path}")
                self.load_pretrained(pt_path)
                self._is_pretrained = True
                return

        logger.info("Pre-training started")
        if pt_path is None:
            logger.warning("`pretrain_path` is not set, pre-trained model will not be saved.")
        else:
            # FIX: message previously read "will to saved to".
            logger.info(f"Pre-trained model will be saved to {pt_path}")

        t = time()
        self.pretrain(*args, **kwargs)
        elapsed = time() - t
        logger.info(f"Pre-training finished (took {elapsed:.2f} seconds)")
        self._is_pretrained = True

        # Persist the freshly pre-trained model when a path is configured.
        if pt_path is not None:
            logger.info(f"Saving pre-trained model to {pt_path}")
            self.save_pretrained(pt_path)

    def pretrain(self, *args, **kwargs):
        """Pre-train the model. No-op by default; override in subclasses."""
        ...

    def save_pretrained(self, path, **kwargs):
        """Save the pre-trained model to ``path``. No-op by default."""
        ...

    def load_pretrained(self, path, **kwargs):
        """Load a pre-trained model from ``path``. No-op by default."""
        ...


class TorchNNPretrain(BasePretrain, ABC):
    """Pre-training mixin for torch ``nn.Module``-based models.

    Adds parameter freeze/unfreeze helpers and default checkpointing via
    ``state_dict``. NOTE(review): assumes the subclass also inherits from
    ``torch.nn.Module`` (source of ``state_dict``/``load_state_dict`` and
    of the submodules' ``parameters()``) — confirm at use sites.
    """

    def _fix_unfix_modules(self, *module_names: Tuple[str], unfix: bool = False, single: bool = True):
        # attrgetter returns a single object for one name and a tuple for
        # several; ``single`` tells us which shape to normalize from.
        modules = attrgetter(*module_names)(self)
        modules = [modules] if single else modules

        for module in modules:
            for p in module.parameters():
                # unfix=True makes parameters trainable; unfix=False freezes them.
                p.requires_grad = unfix

    # Convenience aliases bound as class-level partialmethods of the helper above.
    fix_module = partialmethod(_fix_unfix_modules, unfix=False, single=True)
    fix_modules = partialmethod(_fix_unfix_modules, unfix=False, single=False)
    unfix_module = partialmethod(_fix_unfix_modules, unfix=True, single=True)
    unfix_modules = partialmethod(_fix_unfix_modules, unfix=True, single=False)

    @contextmanager
    def pretrain_context(self, *module_names: Tuple[str]):
        """Unlock module for pretraining and lock once pretraining is done."""
        is_single = len(module_names) == 1
        logger.info(f"Entering pre-training context; unlocking: {module_names}")
        self._fix_unfix_modules(*module_names, unfix=True, single=is_single)
        try:
            yield
        finally:
            # Re-freeze even if pre-training raised an exception.
            logger.info(f"Exiting pre-training context; locking: {module_names}")
            self._fix_unfix_modules(*module_names, unfix=False, single=is_single)

    def save_pretrained(self, path):
        """Save the model parameters (``state_dict``) to ``path``."""
        torch.save(self.state_dict(), path)

    def load_pretrained(self, path):
        """Load model parameters from ``path``, mapping onto ``self.device`` if set."""
        # map_location=None leaves tensors on their saved device.
        device = getattr(self, "device", None)
        checkpoint = torch.load(path, map_location=device)
        self.load_state_dict(checkpoint)


class BaseClassificationMethod(BaseMethod):

_DEFAULT_METRIC = "acc"
Expand Down
2 changes: 1 addition & 1 deletion dance/modules/single_modality/clustering/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .graphsc import GraphSC
from .scdcc import ScDCC
from .scdeepcluster import ScDeepCluster
from .scdsc import SCDSC
from .scdsc import ScDSCModel
from .sctag import ScTAG

__all__ = [
Expand Down
Loading