Skip to content

Commit

Permalink
Merge pull request #447 from WenjieDu/dev
Browse files Browse the repository at this point in the history
Expose new models for tuning, add get_class_full_path(), and test visual funcs
  • Loading branch information
WenjieDu authored Jun 23, 2024
2 parents 8956ac5 + b3c419a commit ca30aab
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 18 deletions.
42 changes: 24 additions & 18 deletions pypots/cli/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@
BRITS,
GRUD,
Transformer,
TiDE,
Reformer,
RevIN_SCINet,
)
from ..optim import Adam
from ..utils.logging import logger
Expand All @@ -58,34 +61,37 @@
)

NN_MODELS = {
# imputation models
"pypots.imputation.SAITS": SAITS,
"pypots.imputation.iTransformer": iTransformer,
"pypots.imputation.Transformer": Transformer,
"pypots.imputation.FreTS": FreTS,
"pypots.imputation.Koopa": Koopa,
# imputation models, sorted by the first letter of the model name
"pypots.imputation.Autoformer": Autoformer,
"pypots.imputation.BRITS": BRITS,
"pypots.imputation.CSDI": CSDI,
"pypots.imputation.Crossformer": Crossformer,
"pypots.imputation.PatchTST": PatchTST,
"pypots.imputation.DLinear": DLinear,
"pypots.imputation.ETSformer": ETSformer,
"pypots.imputation.FreTS": FreTS,
"pypots.imputation.FiLM": FiLM,
"pypots.imputation.GPVAE": GPVAE,
"pypots.imputation.GRUD": GRUD,
"pypots.imputation.Informer": Informer,
"pypots.imputation.iTransformer": iTransformer,
"pypots.imputation.Koopa": Koopa,
"pypots.imputation.MICN": MICN,
"pypots.imputation.DLinear": DLinear,
"pypots.imputation.SCINet": SCINet,
"pypots.imputation.MRNN": MRNN,
"pypots.imputation.NonstationaryTransformer": NonstationaryTransformer,
"pypots.imputation.FiLM": FiLM,
"pypots.imputation.PatchTST": PatchTST,
"pypots.imputation.Pyraformer": Pyraformer,
"pypots.imputation.Autoformer": Autoformer,
"pypots.imputation.Informer": Informer,
"pypots.imputation.Reformer": Reformer,
"pypots.imputation.RevIN_SCINet": RevIN_SCINet,
"pypots.imputation.SAITS": SAITS,
"pypots.imputation.SCINet": SCINet,
"pypots.imputation.StemGNN": StemGNN,
"pypots.imputation.TimesNet": TimesNet,
"pypots.imputation.CSDI": CSDI,
"pypots.imputation.TiDE": TiDE,
"pypots.imputation.Transformer": Transformer,
"pypots.imputation.USGAN": USGAN,
"pypots.imputation.GPVAE": GPVAE,
"pypots.imputation.BRITS": BRITS,
"pypots.imputation.MRNN": MRNN,
"pypots.imputation.GRUD": GRUD,
# classification models
"pypots.classification.GRUD": GRUD_classification,
"pypots.classification.BRITS": BRITS_classification,
"pypots.classification.GRUD": GRUD_classification,
"pypots.classification.Raindrop": Raindrop,
# clustering models
"pypots.clustering.CRLI": CRLI,
Expand Down
21 changes: 21 additions & 0 deletions pypots/utils/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,24 @@ def create_dir_if_not_exist(path: str, is_dir: bool = True) -> None:
if not os.path.exists(path):
os.makedirs(path, exist_ok=True)
logger.info(f"Successfully created the given path {path}")


def get_class_full_path(cls) -> str:
"""Get the full path of the given class.
Parameters
----------
cls:
The class to get the full path.
Returns
-------
path :
The full path of the given class.
"""
module = cls.__module__
path = cls.__qualname__
if module is not None and module != "__builtin__":
path = module + "." + path
return path
7 changes: 7 additions & 0 deletions tests/imputation/saits.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pypots.optim import Adam
from pypots.utils.logging import logger
from pypots.utils.metrics import calc_mse
from pypots.utils.visual.data import plot_data, plot_missingness
from tests.global_test_config import (
DATA,
EPOCHS,
Expand Down Expand Up @@ -79,6 +80,12 @@ def test_1_impute(self):
)
logger.info(f"SAITS test_MSE: {test_MSE}")

# plot the missingness and imputed data
plot_missingness(
~np.isnan(TEST_SET["X"]), 0, imputation_results["imputation"].shape[1]
)
plot_data(TEST_SET["X"], TEST_SET["X_ori"], imputation_results["imputation"])

@pytest.mark.xdist_group(name="imputation-saits")
def test_2_parameters(self):
assert hasattr(self.saits, "model") and self.saits.model is not None
Expand Down

0 comments on commit ca30aab

Please sign in to comment.