From dee9acc25ef5f090082188ade5f6503cfa6cf753 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Fri, 8 Dec 2023 22:16:15 +0800 Subject: [PATCH 1/3] feat: add get_random_seed(); --- pypots/utils/random.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/pypots/utils/random.py b/pypots/utils/random.py index b3654752..9ec3af57 100644 --- a/pypots/utils/random.py +++ b/pypots/utils/random.py @@ -1,5 +1,5 @@ """ -Transformer model for time-series imputation. +PyPOTS util module about random seed setting. """ # Created by Wenjie Du @@ -7,6 +7,7 @@ import numpy as np import torch + from .logging import logger RANDOM_SEED = 2204 @@ -21,7 +22,19 @@ def set_random_seed(random_seed: int = RANDOM_SEED) -> None: The seed to be set for generating random numbers in PyPOTS. """ - - np.random.seed(RANDOM_SEED) + globals()["RANDOM_SEED"] = random_seed + np.random.seed(random_seed) torch.manual_seed(random_seed) logger.info(f"Have set the random seed as {random_seed} for numpy and pytorch.") + + +def get_random_seed() -> int: + """Get the random seed used in PyPOTS. + + Returns + ------- + random_seed : + The random seed used in PyPOTS. + + """ + return RANDOM_SEED From 8cd435f98c7726d1aa3f3517819d11dd24f5991b Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sat, 9 Dec 2023 00:33:57 +0800 Subject: [PATCH 2/3] feat: add calc_quantile_crps() and rename metric func into calc_*; --- pypots/clustering/crli/modules/core.py | 6 +- pypots/clustering/vader/modules/core.py | 4 +- pypots/imputation/base.py | 4 +- pypots/imputation/brits/modules/core.py | 8 +- pypots/imputation/csdi/model.py | 4 +- pypots/imputation/mrnn/modules/core.py | 4 +- pypots/imputation/saits/model.py | 4 +- pypots/imputation/saits/modules/core.py | 4 +- pypots/imputation/transformer/modules/core.py | 6 +- pypots/utils/metrics/__init__.py | 65 ++++++++-- pypots/utils/metrics/classification.py | 55 ++++++-- pypots/utils/metrics/clustering.py | 90 ++++++++++--- pypots/utils/metrics/error.py | 119 +++++++++++++++--- tests/classification/brits.py | 4 +- tests/classification/grud.py | 4 +- tests/classification/raindrop.py | 4 +- tests/clustering/crli.py | 12 +- tests/clustering/vader.py | 8 +- tests/forecasting/bttf.py | 4 +- tests/imputation/brits.py | 4 +- tests/imputation/csdi.py | 9 +- tests/imputation/gpvae.py | 4 +- tests/imputation/locf.py | 14 +-- tests/imputation/mrnn.py | 4 +- tests/imputation/saits.py | 4 +- tests/imputation/timesnet.py | 4 +- tests/imputation/transformer.py | 4 +- tests/imputation/usgan.py | 4 +- tests/optim/adadelta.py | 4 +- tests/optim/adagrad.py | 4 +- tests/optim/adam.py | 4 +- tests/optim/adamw.py | 4 +- tests/optim/lr_schedulers.py | 16 +-- tests/optim/rmsprop.py | 4 +- tests/optim/sgd.py | 4 +- tests/utils/random.py | 12 +- 36 files changed, 377 insertions(+), 135 deletions(-) diff --git a/pypots/clustering/crli/modules/core.py b/pypots/clustering/crli/modules/core.py index a4c16a2a..8cbd45b0 100644 --- a/pypots/clustering/crli/modules/core.py +++ b/pypots/clustering/crli/modules/core.py @@ -18,7 +18,7 @@ from sklearn.cluster import KMeans from .submodules import Generator, Decoder, Discriminator -from ....utils.metrics import cal_mse +from ....utils.metrics import calc_mse class _CRLI(nn.Module): @@ -89,8 +89,8 @@ def forward( l_G = F.binary_cross_entropy_with_logits( inputs["discrimination"], 1 - missing_mask, weight=1 - missing_mask ) - l_pre = cal_mse(inputs["imputation_latent"], X, missing_mask) - l_rec = cal_mse(inputs["reconstruction"], X, missing_mask) + l_pre = 
calc_mse(inputs["imputation_latent"], X, missing_mask) + l_rec = calc_mse(inputs["reconstruction"], X, missing_mask) HTH = torch.matmul(inputs["fcn_latent"], inputs["fcn_latent"].permute(1, 0)) if ( diff --git a/pypots/clustering/vader/modules/core.py b/pypots/clustering/vader/modules/core.py index 8ff2f4ac..41c5019f 100644 --- a/pypots/clustering/vader/modules/core.py +++ b/pypots/clustering/vader/modules/core.py @@ -21,7 +21,7 @@ PeepholeLSTMCell, ImplicitImputation, ) -from ....utils.metrics import cal_mse +from ....utils.metrics import calc_mse class _VaDER(nn.Module): @@ -184,7 +184,7 @@ def forward( } # calculate the reconstruction loss - unscaled_reconstruction_loss = cal_mse(X_reconstructed, X, missing_mask) + unscaled_reconstruction_loss = calc_mse(X_reconstructed, X, missing_mask) reconstruction_loss = ( unscaled_reconstruction_loss * self.n_steps diff --git a/pypots/imputation/base.py b/pypots/imputation/base.py index 27e2c063..488d5d6f 100644 --- a/pypots/imputation/base.py +++ b/pypots/imputation/base.py @@ -16,7 +16,7 @@ from ..base import BaseModel, BaseNNModel from ..utils.logging import logger -from ..utils.metrics import cal_mse +from ..utils.metrics import calc_mse try: import nni @@ -299,7 +299,7 @@ def _train_model( imputation_collector = torch.cat(imputation_collector) imputation_collector = imputation_collector.cpu().detach().numpy() - mean_val_loss = cal_mse( + mean_val_loss = calc_mse( imputation_collector, val_loader.dataset.data["X_intact"], val_loader.dataset.data["indicating_mask"], diff --git a/pypots/imputation/brits/modules/core.py b/pypots/imputation/brits/modules/core.py index 83b48f95..689d5582 100644 --- a/pypots/imputation/brits/modules/core.py +++ b/pypots/imputation/brits/modules/core.py @@ -21,7 +21,7 @@ from .submodules import FeatureRegression from ....modules.rnn import TemporalDecay -from ....utils.metrics import cal_mae +from ....utils.metrics import calc_mae class RITS(nn.Module): @@ -150,17 +150,17 @@ def impute( hidden_states = hidden_states * gamma_h # decay hidden states x_h = self.hist_reg(hidden_states) - reconstruction_loss += cal_mae(x_h, x, m) + reconstruction_loss += calc_mae(x_h, x, m) x_c = m * x + (1 - m) * x_h z_h = self.feat_reg(x_c) - reconstruction_loss += cal_mae(z_h, x, m) + reconstruction_loss += calc_mae(z_h, x, m) alpha = torch.sigmoid(self.combining_weight(torch.cat([gamma_x, m], dim=1))) c_h = alpha * z_h + (1 - alpha) * x_h - reconstruction_loss += cal_mae(c_h, x, m) + reconstruction_loss += calc_mae(c_h, x, m) c_c = m * x + (1 - m) * c_h estimations.append(c_h.unsqueeze(dim=1)) diff --git a/pypots/imputation/csdi/model.py b/pypots/imputation/csdi/model.py index 71b76918..25a455fb 100644 --- a/pypots/imputation/csdi/model.py +++ b/pypots/imputation/csdi/model.py @@ -32,7 +32,7 @@ from ...optim.adam import Adam from ...optim.base import Optimizer from ...utils.logging import logger -from ...utils.metrics import cal_mse +from ...utils.metrics import calc_mse class CSDI(BaseNNImputer): @@ -256,7 +256,7 @@ def _train_model( imputation_collector = torch.cat(imputation_collector) imputation_collector = imputation_collector.cpu().detach().numpy() - mean_val_loss = cal_mse( + mean_val_loss = calc_mse( imputation_collector, val_loader.dataset.data["X_intact"], val_loader.dataset.data["indicating_mask"], diff --git a/pypots/imputation/mrnn/modules/core.py b/pypots/imputation/mrnn/modules/core.py index e4936ec8..ba0ba2cf 100644 --- a/pypots/imputation/mrnn/modules/core.py +++ b/pypots/imputation/mrnn/modules/core.py @@ -12,7 +12,7 
@@ import torch.nn as nn from .submodules import FCN_Regression -from ....utils.metrics import cal_rmse +from ....utils.metrics import calc_rmse class _MRNN(nn.Module): @@ -74,7 +74,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: FCN_estimation = self.fcn_regression( x, m, RNN_imputed_data ) # FCN estimation is output estimation - reconstruction_loss += cal_rmse(FCN_estimation, x, m) + cal_rmse( + reconstruction_loss += calc_rmse(FCN_estimation, x, m) + calc_rmse( RNN_estimation, x, m ) estimations.append(FCN_estimation.unsqueeze(dim=1)) diff --git a/pypots/imputation/saits/model.py b/pypots/imputation/saits/model.py index 378062a4..99952645 100644 --- a/pypots/imputation/saits/model.py +++ b/pypots/imputation/saits/model.py @@ -27,7 +27,7 @@ from ...optim.adam import Adam from ...optim.base import Optimizer from ...utils.logging import logger -from ...utils.metrics import cal_mae +from ...utils.metrics import calc_mae class SAITS(BaseNNImputer): @@ -145,7 +145,7 @@ def __init__( batch_size: int = 32, epochs: int = 100, patience: Optional[int] = None, - customized_loss_func: Callable = cal_mae, + customized_loss_func: Callable = calc_mae, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, diff --git a/pypots/imputation/saits/modules/core.py b/pypots/imputation/saits/modules/core.py index eb062709..51c2dfd5 100644 --- a/pypots/imputation/saits/modules/core.py +++ b/pypots/imputation/saits/modules/core.py @@ -20,7 +20,7 @@ import torch.nn.functional as F from ....modules.transformer import EncoderLayer, PositionalEncoding -from ....utils.metrics import cal_mae +from ....utils.metrics import calc_mae class _SAITS(nn.Module): @@ -39,7 +39,7 @@ def __init__( diagonal_attention_mask: bool = True, ORT_weight: float = 1, MIT_weight: float = 1, - customized_loss_func: Callable = cal_mae, + customized_loss_func: Callable = calc_mae, ): super().__init__() self.n_layers = n_layers diff --git a/pypots/imputation/transformer/modules/core.py b/pypots/imputation/transformer/modules/core.py index 34750da8..7044358d 100644 --- a/pypots/imputation/transformer/modules/core.py +++ b/pypots/imputation/transformer/modules/core.py @@ -19,7 +19,7 @@ import torch.nn as nn from ....modules.transformer import EncoderLayer, PositionalEncoding -from ....utils.metrics import cal_mae +from ....utils.metrics import calc_mae class _TransformerEncoder(nn.Module): @@ -89,8 +89,8 @@ def forward(self, inputs: dict, training: bool = True) -> dict: # if in training mode, return results with losses if training: - ORT_loss = cal_mae(learned_presentation, X, masks) - MIT_loss = cal_mae( + ORT_loss = calc_mae(learned_presentation, X, masks) + MIT_loss = calc_mae( learned_presentation, inputs["X_intact"], inputs["indicating_mask"] ) results["ORT_loss"] = ORT_loss diff --git a/pypots/utils/metrics/__init__.py b/pypots/utils/metrics/__init__.py index 9856a49b..22e155de 100644 --- a/pypots/utils/metrics/__init__.py +++ b/pypots/utils/metrics/__init__.py @@ -6,6 +6,12 @@ # License: BSD-3-Clause from .classification import ( + calc_binary_classification_metrics, + calc_precision_recall_f1, + calc_pr_auc, + calc_roc_auc, + calc_acc, + # deprecated cal_binary_classification_metrics, cal_precision_recall_f1, cal_pr_auc, @@ -13,6 +19,16 @@ cal_acc, ) from .clustering import ( + calc_rand_index, + calc_adjusted_rand_index, + calc_cluster_purity, + calc_nmi, + calc_chs, + calc_dbs, + calc_silhouette, + calc_internal_cluster_validation_metrics, + 
calc_external_cluster_validation_metrics, + # deprecated cal_rand_index, cal_adjusted_rand_index, cal_cluster_purity, @@ -23,21 +39,49 @@ cal_internal_cluster_validation_metrics, cal_external_cluster_validation_metrics, ) -from .error import cal_mae, cal_mse, cal_rmse, cal_mre +from .error import ( + calc_mae, + calc_mse, + calc_rmse, + calc_mre, + calc_quantile_crps, + calc_quantile_crps_sum, + # deprecated + cal_mae, + cal_mse, + cal_rmse, + cal_mre, +) __all__ = [ # error + "calc_mae", + "calc_mse", + "calc_rmse", + "calc_mre", + "calc_quantile_crps", + "calc_quantile_crps_sum", + # classification + "calc_binary_classification_metrics", + "calc_precision_recall_f1", + "calc_pr_auc", + "calc_roc_auc", + "calc_acc", + # clustering + "calc_rand_index", + "calc_adjusted_rand_index", + "calc_cluster_purity", + "calc_nmi", + "calc_chs", + "calc_dbs", + "calc_silhouette", + "calc_internal_cluster_validation_metrics", + "calc_external_cluster_validation_metrics", + # deprecated "cal_mae", "cal_mse", "cal_rmse", "cal_mre", - # classification - "cal_binary_classification_metrics", - "cal_precision_recall_f1", - "cal_pr_auc", - "cal_roc_auc", - "cal_acc", - # clustering "cal_rand_index", "cal_adjusted_rand_index", "cal_cluster_purity", @@ -47,4 +91,9 @@ "cal_silhouette", "cal_internal_cluster_validation_metrics", "cal_external_cluster_validation_metrics", + "cal_binary_classification_metrics", + "cal_precision_recall_f1", + "cal_pr_auc", + "cal_roc_auc", + "cal_acc", ] diff --git a/pypots/utils/metrics/classification.py b/pypots/utils/metrics/classification.py index 8d45e22f..d9867f6f 100644 --- a/pypots/utils/metrics/classification.py +++ b/pypots/utils/metrics/classification.py @@ -10,8 +10,10 @@ import numpy as np from sklearn import metrics +from ..logging import logger -def cal_binary_classification_metrics( + +def calc_binary_classification_metrics( prob_predictions: np.ndarray, targets: np.ndarray, pos_label: int = 1, @@ -85,7 +87,7 @@ def cal_binary_classification_metrics( raise f"predictions dimensions should be 1 or 2, but got predictions.shape: {prob_predictions.shape}" # accuracy score doesn't have to be of binary classification - acc_score = cal_acc(prediction_categories, targets) + acc_score = calc_acc(prediction_categories, targets) # turn targets into binary targets mask_val = -1 if pos_label == 0 else 0 @@ -93,13 +95,13 @@ def cal_binary_classification_metrics( binary_targets = np.copy(targets) binary_targets[~mask] = mask_val - precision, recall, f1 = cal_precision_recall_f1( + precision, recall, f1 = calc_precision_recall_f1( binary_prediction_categories, binary_targets, pos_label ) - pr_auc, precisions, recalls, _ = cal_pr_auc( + pr_auc, precisions, recalls, _ = calc_pr_auc( binary_predictions, binary_targets, pos_label ) - ROC_AUC, fprs, tprs, _ = cal_roc_auc(binary_predictions, binary_targets, pos_label) + ROC_AUC, fprs, tprs, _ = calc_roc_auc(binary_predictions, binary_targets, pos_label) PR_AUC = metrics.auc(recalls, precisions) classification_metrics = { "predictions": prediction_categories, @@ -117,7 +119,7 @@ def cal_binary_classification_metrics( return classification_metrics -def cal_precision_recall_f1( +def calc_precision_recall_f1( prob_predictions: np.ndarray, targets: np.ndarray, pos_label: int = 1, @@ -154,7 +156,7 @@ def cal_precision_recall_f1( return precision, recall, f1 -def cal_pr_auc( +def calc_pr_auc( prob_predictions: np.ndarray, targets: np.ndarray, pos_label: int = 1, @@ -195,7 +197,7 @@ def cal_pr_auc( return pr_auc, precisions, recalls, thresholds 
-def cal_roc_auc( +def calc_roc_auc( prob_predictions: np.ndarray, targets: np.ndarray, pos_label: int = 1, @@ -235,7 +237,7 @@ def cal_roc_auc( return roc_auc, fprs, tprs, thresholds -def cal_acc(class_predictions: np.ndarray, targets: np.ndarray) -> float: +def calc_acc(class_predictions: np.ndarray, targets: np.ndarray) -> float: """Calculate accuracy score of model predictions. Parameters @@ -254,3 +256,38 @@ def cal_acc(class_predictions: np.ndarray, targets: np.ndarray) -> float: """ acc_score = metrics.accuracy_score(targets, class_predictions) return acc_score + + +######################################################################################################################## +# Deprecated functions +######################################################################################################################## + + +def cal_binary_classification_metrics(**kwargs): + logger.warning( + "🚨 cal_binary_classification_metrics() is deprecated, " + "use calc_binary_classification_metrics() instead." + ) + return calc_binary_classification_metrics(**kwargs) + + +def cal_precision_recall_f1(**kwargs): + logger.warning( + "🚨 cal_precision_recall_f1() is deprecated, use calc_precision_recall_f1() instead." + ) + return calc_precision_recall_f1(**kwargs) + + +def cal_pr_auc(**kwargs): + logger.warning("🚨 cal_pr_auc() is deprecated, use calc_pr_auc() instead.") + return calc_pr_auc(**kwargs) + + +def cal_roc_auc(**kwargs): + logger.warning("🚨 cal_roc_auc() is deprecated, use calc_roc_auc() instead.") + return calc_roc_auc(**kwargs) + + +def cal_acc(**kwargs): + logger.warning("🚨 cal_acc() is deprecated, use calc_acc() instead.") + return calc_acc(**kwargs) diff --git a/pypots/utils/metrics/clustering.py b/pypots/utils/metrics/clustering.py index bf05195d..f417d295 100644 --- a/pypots/utils/metrics/clustering.py +++ b/pypots/utils/metrics/clustering.py @@ -8,8 +8,10 @@ import numpy as np from sklearn import metrics +from ..logging import logger -def cal_rand_index( + +def calc_rand_index( class_predictions: np.ndarray, targets: np.ndarray, ) -> float: @@ -60,7 +62,7 @@ def cal_rand_index( return RI -def cal_adjusted_rand_index( +def calc_adjusted_rand_index( class_predictions: np.ndarray, targets: np.ndarray, ) -> float: @@ -96,7 +98,7 @@ def cal_adjusted_rand_index( return aRI -def cal_nmi( +def calc_nmi( class_predictions: np.ndarray, targets: np.ndarray, ) -> float: @@ -121,7 +123,7 @@ def cal_nmi( return NMI -def cal_cluster_purity( +def calc_cluster_purity( class_predictions: np.ndarray, targets: np.ndarray, ) -> float: @@ -152,7 +154,7 @@ def cal_cluster_purity( return cluster_purity -def cal_external_cluster_validation_metrics(class_predictions, targets): +def calc_external_cluster_validation_metrics(class_predictions, targets): """Computer all external cluster validation metrics available in PyPOTS and return as a dictionary. Parameters @@ -169,10 +171,10 @@ def cal_external_cluster_validation_metrics(class_predictions, targets): A dictionary contains all external cluster validation metrics available in PyPOTS. 
""" - ri = cal_rand_index(class_predictions, targets) - ari = cal_adjusted_rand_index(class_predictions, targets) - nmi = cal_nmi(class_predictions, targets) - cp = cal_cluster_purity(class_predictions, targets) + ri = calc_rand_index(class_predictions, targets) + ari = calc_adjusted_rand_index(class_predictions, targets) + nmi = calc_nmi(class_predictions, targets) + cp = calc_cluster_purity(class_predictions, targets) external_cluster_validation_metrics = { "rand_index": ri, @@ -183,7 +185,7 @@ def cal_external_cluster_validation_metrics(class_predictions, targets): return external_cluster_validation_metrics -def cal_silhouette(X: np.ndarray, predicted_labels: np.ndarray) -> float: +def calc_silhouette(X: np.ndarray, predicted_labels: np.ndarray) -> float: """Compute the mean Silhouette Coefficient of all samples. Parameters @@ -214,7 +216,7 @@ def cal_silhouette(X: np.ndarray, predicted_labels: np.ndarray) -> float: return silhouette_score -def cal_chs(X: np.ndarray, predicted_labels: np.ndarray) -> float: +def calc_chs(X: np.ndarray, predicted_labels: np.ndarray) -> float: """Compute the Calinski and Harabasz score (also known as the Variance Ratio Criterion). X : array-like of shape (n_samples_a, n_features) @@ -239,7 +241,7 @@ def cal_chs(X: np.ndarray, predicted_labels: np.ndarray) -> float: return calinski_harabasz_score -def cal_dbs(X: np.ndarray, predicted_labels: np.ndarray) -> float: +def calc_dbs(X: np.ndarray, predicted_labels: np.ndarray) -> float: """Compute the Davies-Bouldin score. Parameters @@ -268,7 +270,7 @@ def cal_dbs(X: np.ndarray, predicted_labels: np.ndarray) -> float: return davies_bouldin_score -def cal_internal_cluster_validation_metrics(X, predicted_labels): +def calc_internal_cluster_validation_metrics(X, predicted_labels): """Computer all internal cluster validation metrics available in PyPOTS and return as a dictionary. Parameters @@ -285,9 +287,9 @@ def cal_internal_cluster_validation_metrics(X, predicted_labels): A dictionary contains all internal cluster validation metrics available in PyPOTS. """ - silhouette_score = cal_silhouette(X, predicted_labels) - calinski_harabasz_score = cal_chs(X, predicted_labels) - davies_bouldin_score = cal_dbs(X, predicted_labels) + silhouette_score = calc_silhouette(X, predicted_labels) + calinski_harabasz_score = calc_chs(X, predicted_labels) + davies_bouldin_score = calc_dbs(X, predicted_labels) internal_cluster_validation_metrics = { "silhouette_score": silhouette_score, @@ -295,3 +297,59 @@ def cal_internal_cluster_validation_metrics(X, predicted_labels): "davies_bouldin_score": davies_bouldin_score, } return internal_cluster_validation_metrics + + +######################################################################################################################## +# Deprecated functions +######################################################################################################################## + + +def cal_rand_index(**kwargs): + logger.warning("🚨 Deprecated function, please use `calc_rand_index` instead.") + return calc_rand_index(**kwargs) + + +def cal_adjusted_rand_index(**kwargs): + logger.warning( + "🚨 Deprecated function, please use `calc_adjusted_rand_index` instead." 
+ ) + return calc_adjusted_rand_index(**kwargs) + + +def cal_nmi(**kwargs): + logger.warning("🚨 Deprecated function, please use `calc_nmi` instead.") + return calc_nmi(**kwargs) + + +def cal_cluster_purity(**kwargs): + logger.warning("🚨 Deprecated function, please use `calc_cluster_purity` instead.") + return calc_cluster_purity(**kwargs) + + +def cal_external_cluster_validation_metrics(**kwargs): + logger.warning( + "🚨 Deprecated function, please use `calc_external_cluster_validation_metrics` instead." + ) + return calc_external_cluster_validation_metrics(**kwargs) + + +def cal_silhouette(**kwargs): + logger.warning("🚨 Deprecated function, please use `calc_silhouette` instead.") + return calc_silhouette(**kwargs) + + +def cal_chs(**kwargs): + logger.warning("🚨 Deprecated function, please use `calc_chs` instead.") + return calc_chs(**kwargs) + + +def cal_dbs(**kwargs): + logger.warning("🚨 Deprecated function, please use `calc_dbs` instead.") + return calc_dbs(**kwargs) + + +def cal_internal_cluster_validation_metrics(**kwargs): + logger.warning( + "🚨 Deprecated function, please use `calc_internal_cluster_validation_metrics` instead." + ) + return calc_internal_cluster_validation_metrics(**kwargs) diff --git a/pypots/utils/metrics/error.py b/pypots/utils/metrics/error.py index fc8133d5..8cafcb61 100644 --- a/pypots/utils/metrics/error.py +++ b/pypots/utils/metrics/error.py @@ -10,8 +10,10 @@ import numpy as np import torch +from ..logging import logger -def cal_mae( + +def calc_mae( predictions: Union[np.ndarray, torch.Tensor, list], targets: Union[np.ndarray, torch.Tensor, list], masks: Optional[Union[np.ndarray, torch.Tensor, list]] = None, @@ -36,10 +38,10 @@ def cal_mae( -------- >>> import numpy as np - >>> from pypots.utils.metrics import cal_mae + >>> from pypots.utils.metrics import calc_mae >>> targets = np.array([1, 2, 3, 4, 5]) >>> predictions = np.array([1, 2, 1, 4, 6]) - >>> mae = cal_mae(predictions, targets) + >>> mae = calc_mae(predictions, targets) mae = 0.6 here, the error is from the 3rd and 5th elements and is :math:`|3-1|+|5-6|=3`, so the result is 3/5=0.6. @@ -47,7 +49,7 @@ def cal_mae( we can use ``masks`` to filter out them: >>> masks = np.array([0, 0, 0, 1, 1]) - >>> mae = cal_mae(predictions, targets, masks) + >>> mae = calc_mae(predictions, targets, masks) mae = 0.5 here, the first three elements are ignored, the error is from the 5th element and is :math:`|5-6|=1`, so the result is 1/2=0.5. @@ -66,7 +68,7 @@ def cal_mae( return lib.mean(lib.abs(predictions - targets)) -def cal_mse( +def calc_mse( predictions: Union[np.ndarray, torch.Tensor, list], targets: Union[np.ndarray, torch.Tensor, list], masks: Optional[Union[np.ndarray, torch.Tensor, list]] = None, @@ -91,10 +93,10 @@ def cal_mse( -------- >>> import numpy as np - >>> from pypots.utils.metrics import cal_mse + >>> from pypots.utils.metrics import calc_mse >>> targets = np.array([1, 2, 3, 4, 5]) >>> predictions = np.array([1, 2, 1, 4, 6]) - >>> mse = cal_mse(predictions, targets) + >>> mse = calc_mse(predictions, targets) mse = 1 here, the error is from the 3rd and 5th elements and is :math:`|3-1|^2+|5-6|^2=5`, so the result is 5/5=1. @@ -102,7 +104,7 @@ def cal_mse( we can use ``masks`` to filter out them: >>> masks = np.array([0, 0, 0, 1, 1]) - >>> mse = cal_mse(predictions, targets, masks) + >>> mse = calc_mse(predictions, targets, masks) mse = 0.5 here, the first three elements are ignored, the error is from the 5th element and is :math:`|5-6|^2=1`, so the result is 1/2=0.5. 
@@ -122,7 +124,7 @@ def cal_mse( return lib.mean(lib.square(predictions - targets)) -def cal_rmse( +def calc_rmse( predictions: Union[np.ndarray, torch.Tensor, list], targets: Union[np.ndarray, torch.Tensor, list], masks: Optional[Union[np.ndarray, torch.Tensor, list]] = None, @@ -147,10 +149,10 @@ def cal_rmse( -------- >>> import numpy as np - >>> from pypots.utils.metrics import cal_rmse + >>> from pypots.utils.metrics import calc_rmse >>> targets = np.array([1, 2, 3, 4, 5]) >>> predictions = np.array([1, 2, 1, 4, 6]) - >>> rmse = cal_rmse(predictions, targets) + >>> rmse = calc_rmse(predictions, targets) rmse = 1 here, the error is from the 3rd and 5th elements and is :math:`|3-1|^2+|5-6|^2=5`, so the result is :math:`\\sqrt{5/5}=1`. @@ -159,7 +161,7 @@ def cal_rmse( we can use ``masks`` to filter out them: >>> masks = np.array([0, 0, 0, 1, 1]) - >>> rmse = cal_rmse(predictions, targets, masks) + >>> rmse = calc_rmse(predictions, targets, masks) rmse = 0.707 here, the first three elements are ignored, the error is from the 5th element and is :math:`|5-6|^2=1`, so the result is :math:`\\sqrt{1/2}=0.5`. @@ -170,10 +172,10 @@ def cal_rmse( f"type(inputs)={type(predictions)}, type(target)={type(targets)}" ) lib = np if isinstance(predictions, np.ndarray) else torch - return lib.sqrt(cal_mse(predictions, targets, masks)) + return lib.sqrt(calc_mse(predictions, targets, masks)) -def cal_mre( +def calc_mre( predictions: Union[np.ndarray, torch.Tensor, list], targets: Union[np.ndarray, torch.Tensor, list], masks: Optional[Union[np.ndarray, torch.Tensor, list]] = None, @@ -198,10 +200,10 @@ def cal_mre( -------- >>> import numpy as np - >>> from pypots.utils.metrics import cal_mre + >>> from pypots.utils.metrics import calc_mre >>> targets = np.array([1, 2, 3, 4, 5]) >>> predictions = np.array([1, 2, 1, 4, 6]) - >>> mre = cal_mre(predictions, targets) + >>> mre = calc_mre(predictions, targets) mre = 0.2 here, the error is from the 3rd and 5th elements and is :math:`|3-1|+|5-6|=3`, so the result is :math:`\\sqrt{3/(1+2+3+4+5)}=1`. @@ -210,7 +212,7 @@ def cal_mre( we can use ``masks`` to filter out them: >>> masks = np.array([0, 0, 0, 1, 1]) - >>> mre = cal_mre(predictions, targets, masks) + >>> mre = calc_mre(predictions, targets, masks) mre = 0.111 here, the first three elements are ignored, the error is from the 5th element and is :math:`|5-6|^2=1`, so the result is :math:`\\sqrt{1/2}=0.5`. 
@@ -229,3 +231,86 @@ def cal_mre( return lib.sum(lib.abs(predictions - targets)) / ( lib.sum(lib.abs(targets)) + 1e-12 ) + + +def calc_quantile_loss(predictions, targets, q: float, eval_points) -> float: + quantile_loss = 2 * torch.sum( + torch.abs( + (predictions - targets) * eval_points * ((targets <= predictions) * 1.0 - q) + ) + ) + return quantile_loss + + +def calc_quantile_crps(predictions, targets, eval_points, mean_scaler=0, scaler=1): + """Continuous rank probability score for distributional predictions.""" + if isinstance(predictions, np.ndarray): + predictions = torch.from_numpy(predictions) + if isinstance(targets, np.ndarray): + targets = torch.from_numpy(targets) + if isinstance(eval_points, np.ndarray): + eval_points = torch.from_numpy(eval_points) + + targets = targets * scaler + mean_scaler + predictions = predictions * scaler + mean_scaler + + quantiles = np.arange(0.05, 1.0, 0.05) + denominator = torch.sum(torch.abs(targets * eval_points)) + CRPS = 0 + for i in range(len(quantiles)): + q_pred = [] + for j in range(len(predictions)): + q_pred.append(torch.quantile(predictions[j : j + 1], quantiles[i], dim=1)) + q_pred = torch.cat(q_pred, 0) + q_loss = calc_quantile_loss(targets, q_pred, quantiles[i], eval_points) + CRPS += q_loss / denominator + return CRPS.item() / len(quantiles) + + +def calc_quantile_crps_sum(predictions, targets, eval_points, mean_scaler=0, scaler=1): + """Continuous rank probability score for distributional predictions.""" + if isinstance(predictions, np.ndarray): + predictions = torch.from_numpy(predictions) + if isinstance(targets, np.ndarray): + targets = torch.from_numpy(targets) + if isinstance(eval_points, np.ndarray): + eval_points = torch.from_numpy(eval_points) + + eval_points = eval_points.mean(-1) + targets = targets * scaler + mean_scaler + targets = targets.sum(-1) + predictions = predictions * scaler + mean_scaler + + quantiles = np.arange(0.05, 1.0, 0.05) + denominator = torch.sum(torch.abs(targets * eval_points)) + CRPS = 0 + for i in range(len(quantiles)): + q_pred = torch.quantile(predictions.sum(-1), quantiles[i], dim=1) + q_loss = calc_quantile_loss(targets, q_pred, quantiles[i], eval_points) + CRPS += q_loss / denominator + return CRPS.item() / len(quantiles) + + +######################################################################################################################## +# Deprecated functions +######################################################################################################################## + + +def cal_mae(**kwargs): + logger.warning("🚨 cal_mae() is deprecated, use calc_mae() instead.") + return calc_mae(**kwargs) + + +def cal_rmse(**kwargs): + logger.warning("🚨 cal_rmse() is deprecated, use calc_rmse() instead.") + return calc_rmse(**kwargs) + + +def cal_mse(**kwargs): + logger.warning("🚨 cal_mse() is deprecated, use calc_mse() instead.") + return calc_mse(**kwargs) + + +def cal_mre(**kwargs): + logger.warning("🚨 cal_mre() is deprecated, use calc_mre() instead.") + return calc_mre(**kwargs) diff --git a/tests/classification/brits.py b/tests/classification/brits.py index 78e8c042..c7815f5c 100644 --- a/tests/classification/brits.py +++ b/tests/classification/brits.py @@ -13,7 +13,7 @@ from pypots.classification import BRITS from pypots.optim import Adam from pypots.utils.logging import logger -from pypots.utils.metrics import cal_binary_classification_metrics +from pypots.utils.metrics import calc_binary_classification_metrics from tests.classification.config import ( EPOCHS, TRAIN_SET, @@ 
-58,7 +58,7 @@ def test_0_fit(self): @pytest.mark.xdist_group(name="classification-brits") def test_1_classify(self): predictions = self.brits.classify(TEST_SET) - metrics = cal_binary_classification_metrics(predictions, DATA["test_y"]) + metrics = calc_binary_classification_metrics(predictions, DATA["test_y"]) logger.info( f'ROC_AUC: {metrics["roc_auc"]}, \n' f'PR_AUC: {metrics["pr_auc"]},\n' diff --git a/tests/classification/grud.py b/tests/classification/grud.py index abc94f63..37bad931 100644 --- a/tests/classification/grud.py +++ b/tests/classification/grud.py @@ -13,7 +13,7 @@ from pypots.classification import GRUD from pypots.optim import Adam from pypots.utils.logging import logger -from pypots.utils.metrics import cal_binary_classification_metrics +from pypots.utils.metrics import calc_binary_classification_metrics from tests.classification.config import ( EPOCHS, TRAIN_SET, @@ -57,7 +57,7 @@ def test_0_fit(self): @pytest.mark.xdist_group(name="classification-grud") def test_1_classify(self): predictions = self.grud.classify(TEST_SET) - metrics = cal_binary_classification_metrics(predictions, DATA["test_y"]) + metrics = calc_binary_classification_metrics(predictions, DATA["test_y"]) logger.info( f'ROC_AUC: {metrics["roc_auc"]}, \n' f'PR_AUC: {metrics["pr_auc"]},\n' diff --git a/tests/classification/raindrop.py b/tests/classification/raindrop.py index a7c42234..967f73ec 100644 --- a/tests/classification/raindrop.py +++ b/tests/classification/raindrop.py @@ -12,7 +12,7 @@ from pypots.classification import Raindrop from pypots.utils.logging import logger -from pypots.utils.metrics import cal_binary_classification_metrics +from pypots.utils.metrics import calc_binary_classification_metrics from tests.classification.config import ( EPOCHS, TRAIN_SET, @@ -60,7 +60,7 @@ def test_0_fit(self): @pytest.mark.xdist_group(name="classification-raindrop") def test_1_classify(self): predictions = self.raindrop.classify(TEST_SET) - metrics = cal_binary_classification_metrics(predictions, DATA["test_y"]) + metrics = calc_binary_classification_metrics(predictions, DATA["test_y"]) logger.info( f'ROC_AUC: {metrics["roc_auc"]}, \n' f'PR_AUC: {metrics["pr_auc"]},\n' diff --git a/tests/clustering/crli.py b/tests/clustering/crli.py index 2b849b79..63960619 100644 --- a/tests/clustering/crli.py +++ b/tests/clustering/crli.py @@ -15,8 +15,8 @@ from pypots.optim import Adam from pypots.utils.logging import logger from pypots.utils.metrics import ( - cal_external_cluster_validation_metrics, - cal_internal_cluster_validation_metrics, + calc_external_cluster_validation_metrics, + calc_internal_cluster_validation_metrics, ) from tests.clustering.config import ( EPOCHS, @@ -125,10 +125,10 @@ def test_1_parameters(self): def test_2_cluster(self): # GRU cell clustering_results = self.crli_gru.predict(TEST_SET, return_latent_vars=True) - external_metrics = cal_external_cluster_validation_metrics( + external_metrics = calc_external_cluster_validation_metrics( clustering_results["clustering"], DATA["test_y"] ) - internal_metrics = cal_internal_cluster_validation_metrics( + internal_metrics = calc_internal_cluster_validation_metrics( clustering_results["latent_vars"]["clustering_latent"], DATA["test_y"] ) logger.info(f"CRLI-GRU: {external_metrics}") @@ -136,10 +136,10 @@ def test_2_cluster(self): # LSTM cell clustering_results = self.crli_lstm.predict(TEST_SET, return_latent_vars=True) - external_metrics = cal_external_cluster_validation_metrics( + external_metrics = calc_external_cluster_validation_metrics( 
clustering_results["clustering"], DATA["test_y"] ) - internal_metrics = cal_internal_cluster_validation_metrics( + internal_metrics = calc_internal_cluster_validation_metrics( clustering_results["latent_vars"]["clustering_latent"], DATA["test_y"] ) logger.info(f"CRLI-LSTM: {external_metrics}") diff --git a/tests/clustering/vader.py b/tests/clustering/vader.py index cbdae092..d5143367 100644 --- a/tests/clustering/vader.py +++ b/tests/clustering/vader.py @@ -16,8 +16,8 @@ from pypots.optim import Adam from pypots.utils.logging import logger from pypots.utils.metrics import ( - cal_external_cluster_validation_metrics, - cal_internal_cluster_validation_metrics, + calc_external_cluster_validation_metrics, + calc_internal_cluster_validation_metrics, ) from tests.clustering.config import ( EPOCHS, @@ -65,10 +65,10 @@ def test_0_fit(self): def test_1_cluster(self): try: clustering_results = self.vader.predict(TEST_SET, return_latent_vars=True) - external_metrics = cal_external_cluster_validation_metrics( + external_metrics = calc_external_cluster_validation_metrics( clustering_results["clustering"], DATA["test_y"] ) - internal_metrics = cal_internal_cluster_validation_metrics( + internal_metrics = calc_internal_cluster_validation_metrics( clustering_results["latent_vars"]["z"], DATA["test_y"] ) logger.info(f"{external_metrics}") diff --git a/tests/forecasting/bttf.py b/tests/forecasting/bttf.py index 1ced03b0..1483e7d7 100644 --- a/tests/forecasting/bttf.py +++ b/tests/forecasting/bttf.py @@ -11,7 +11,7 @@ from pypots.forecasting import BTTF from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.forecasting.config import ( TEST_SET, TEST_SET_INTACT, @@ -38,7 +38,7 @@ class TestBTTF(unittest.TestCase): @pytest.mark.xdist_group(name="forecasting-bttf") def test_0_forecasting(self): predictions = self.bttf.forecast(TEST_SET) - mae = cal_mae(predictions, TEST_SET_INTACT["X"][:, -N_PRED_STEP:]) + mae = calc_mae(predictions, TEST_SET_INTACT["X"][:, -N_PRED_STEP:]) logger.info(f"prediction MAE: {mae}") diff --git a/tests/imputation/brits.py b/tests/imputation/brits.py index 69ea9613..e5eb2cb7 100644 --- a/tests/imputation/brits.py +++ b/tests/imputation/brits.py @@ -15,7 +15,7 @@ from pypots.imputation import BRITS from pypots.optim import Adam from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.global_test_config import ( DATA, DEVICE, @@ -61,7 +61,7 @@ def test_1_impute(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." 
- test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"BRITS test_MAE: {test_MAE}") diff --git a/tests/imputation/csdi.py b/tests/imputation/csdi.py index a5d3a73f..0ccf1222 100644 --- a/tests/imputation/csdi.py +++ b/tests/imputation/csdi.py @@ -15,7 +15,7 @@ from pypots.imputation import CSDI from pypots.optim import Adam from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae, calc_quantile_crps from tests.global_test_config import ( DATA, DEVICE, @@ -63,14 +63,17 @@ def test_0_fit(self): @pytest.mark.xdist_group(name="imputation-csdi") def test_1_impute(self): imputed_X = self.csdi.predict(TEST_SET)["imputation"] + test_CRPS = calc_quantile_crps( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) imputed_X = imputed_X.mean(axis=1) # mean over sampling times assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) - logger.info(f"CSDI test_MAE: {test_MAE}") + logger.info(f"CSDI test_MAE: {test_MAE}, test_CRPS: {test_CRPS}") @pytest.mark.xdist_group(name="imputation-csdi") def test_2_parameters(self): diff --git a/tests/imputation/gpvae.py b/tests/imputation/gpvae.py index d2e45f31..b94bff37 100644 --- a/tests/imputation/gpvae.py +++ b/tests/imputation/gpvae.py @@ -15,7 +15,7 @@ from pypots.imputation import GPVAE from pypots.optim import Adam from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.global_test_config import ( DATA, DEVICE, @@ -61,7 +61,7 @@ def test_1_impute(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"GP-VAE test_MAE: {test_MAE}") diff --git a/tests/imputation/locf.py b/tests/imputation/locf.py index 18f7ed68..b43b7414 100644 --- a/tests/imputation/locf.py +++ b/tests/imputation/locf.py @@ -14,7 +14,7 @@ from pypots.imputation import LOCF from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.global_test_config import ( DATA, ) @@ -37,7 +37,7 @@ def test_0_impute(self): assert not np.isnan( test_X_imputed_zero ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( test_X_imputed_zero, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"LOCF (zero) test_MAE: {test_MAE}") @@ -46,7 +46,7 @@ def test_0_impute(self): assert not np.isnan( test_X_imputed_backward ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( test_X_imputed_backward, DATA["test_X_intact"], DATA["test_X_indicating_mask"], @@ -57,7 +57,7 @@ def test_0_impute(self): assert not np.isnan( test_X_imputed_mean ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( test_X_imputed_mean, DATA["test_X_intact"], DATA["test_X_indicating_mask"], @@ -80,14 +80,14 @@ def test_0_impute(self): assert not torch.isnan( test_X_imputed_zero ).any(), "Output still has missing values after running impute()." 
- test_MAE = cal_mae(test_X_imputed_zero, test_X_intact, test_X_indicating_mask) + test_MAE = calc_mae(test_X_imputed_zero, test_X_intact, test_X_indicating_mask) logger.info(f"LOCF (zero) test_MAE: {test_MAE}") test_X_imputed_backward = self.locf_backward.predict({"X": X})["imputation"] assert not torch.isnan( test_X_imputed_backward ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( test_X_imputed_backward, test_X_intact, test_X_indicating_mask, @@ -98,7 +98,7 @@ def test_0_impute(self): assert not torch.isnan( test_X_imputed_mean ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( test_X_imputed_mean, test_X_intact, test_X_indicating_mask, diff --git a/tests/imputation/mrnn.py b/tests/imputation/mrnn.py index ae28d0eb..b3074f09 100644 --- a/tests/imputation/mrnn.py +++ b/tests/imputation/mrnn.py @@ -15,7 +15,7 @@ from pypots.imputation import MRNN from pypots.optim import Adam from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.global_test_config import ( DATA, DEVICE, @@ -61,7 +61,7 @@ def test_1_impute(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"MRNN test_MAE: {test_MAE}") diff --git a/tests/imputation/saits.py b/tests/imputation/saits.py index a5620569..d25f9361 100644 --- a/tests/imputation/saits.py +++ b/tests/imputation/saits.py @@ -15,7 +15,7 @@ from pypots.imputation import SAITS from pypots.optim import Adam from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.global_test_config import ( DATA, DEVICE, @@ -71,7 +71,7 @@ def test_1_impute(self): "latent_vars" in imputation_results.keys() ), "Latent variables are not returned thought `return_latent_vars` is set as True." - test_MAE = cal_mae( + test_MAE = calc_mae( imputation_results["imputation"], DATA["test_X_intact"], DATA["test_X_indicating_mask"], diff --git a/tests/imputation/timesnet.py b/tests/imputation/timesnet.py index 52e33ae4..33bfae3e 100644 --- a/tests/imputation/timesnet.py +++ b/tests/imputation/timesnet.py @@ -15,7 +15,7 @@ from pypots.imputation import TimesNet from pypots.optim import Adam from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.global_test_config import ( DATA, DEVICE, @@ -67,7 +67,7 @@ def test_1_impute(self): imputation_results["imputation"] ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputation_results["imputation"], DATA["test_X_intact"], DATA["test_X_indicating_mask"], diff --git a/tests/imputation/transformer.py b/tests/imputation/transformer.py index c145ecfa..15624dc4 100644 --- a/tests/imputation/transformer.py +++ b/tests/imputation/transformer.py @@ -15,7 +15,7 @@ from pypots.imputation import Transformer from pypots.optim import Adam from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.global_test_config import ( DATA, DEVICE, @@ -67,7 +67,7 @@ def test_1_impute(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." 
- test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"Transformer test_MAE: {test_MAE}") diff --git a/tests/imputation/usgan.py b/tests/imputation/usgan.py index ea723238..0ff25ea3 100644 --- a/tests/imputation/usgan.py +++ b/tests/imputation/usgan.py @@ -15,7 +15,7 @@ from pypots.imputation import USGAN from pypots.optim import Adam from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.global_test_config import ( DATA, DEVICE, @@ -63,7 +63,7 @@ def test_1_impute(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"US-GAN test_MAE: {test_MAE}") diff --git a/tests/optim/adadelta.py b/tests/optim/adadelta.py index 71c991f2..c7eb6e6d 100644 --- a/tests/optim/adadelta.py +++ b/tests/optim/adadelta.py @@ -13,7 +13,7 @@ from pypots.imputation import SAITS from pypots.optim import Adadelta from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.global_test_config import DATA from tests.optim.config import EPOCHS, TEST_SET, TRAIN_SET, VAL_SET @@ -46,7 +46,7 @@ def test_0_fit(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"SAITS test_MAE: {test_MAE}") diff --git a/tests/optim/adagrad.py b/tests/optim/adagrad.py index 6b055f21..7cb2a988 100644 --- a/tests/optim/adagrad.py +++ b/tests/optim/adagrad.py @@ -13,7 +13,7 @@ from pypots.imputation import SAITS from pypots.optim import Adagrad from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.global_test_config import DATA from tests.optim.config import EPOCHS, TEST_SET, TRAIN_SET, VAL_SET @@ -46,7 +46,7 @@ def test_0_fit(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"SAITS test_MAE: {test_MAE}") diff --git a/tests/optim/adam.py b/tests/optim/adam.py index be6cb89b..9f583aee 100644 --- a/tests/optim/adam.py +++ b/tests/optim/adam.py @@ -13,7 +13,7 @@ from pypots.imputation import SAITS from pypots.optim import Adam from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.global_test_config import DATA from tests.optim.config import EPOCHS, TEST_SET, TRAIN_SET, VAL_SET @@ -46,7 +46,7 @@ def test_0_fit(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." 
- test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"SAITS test_MAE: {test_MAE}") diff --git a/tests/optim/adamw.py b/tests/optim/adamw.py index e7f89797..e785e9f6 100644 --- a/tests/optim/adamw.py +++ b/tests/optim/adamw.py @@ -13,7 +13,7 @@ from pypots.imputation import SAITS from pypots.optim import AdamW from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.global_test_config import DATA from tests.optim.config import EPOCHS, TEST_SET, TRAIN_SET, VAL_SET @@ -46,7 +46,7 @@ def test_0_fit(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"SAITS test_MAE: {test_MAE}") diff --git a/tests/optim/lr_schedulers.py b/tests/optim/lr_schedulers.py index 3a88976a..2aa1c520 100644 --- a/tests/optim/lr_schedulers.py +++ b/tests/optim/lr_schedulers.py @@ -22,7 +22,7 @@ MultiplicativeLR, ) from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.global_test_config import DATA from tests.optim.config import EPOCHS, TEST_SET, TRAIN_SET, VAL_SET @@ -74,7 +74,7 @@ def test_0_lambda_lrs(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"SAITS test_MAE: {test_MAE}") @@ -102,7 +102,7 @@ def test_1_multiplicative_lrs(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"SAITS test_MAE: {test_MAE}") @@ -130,7 +130,7 @@ def test_2_step_lrs(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"SAITS test_MAE: {test_MAE}") @@ -158,7 +158,7 @@ def test_3_multistep_lrs(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"SAITS test_MAE: {test_MAE}") @@ -187,7 +187,7 @@ def test_4_constant_lrs(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"SAITS test_MAE: {test_MAE}") @@ -215,7 +215,7 @@ def test_5_linear_lrs(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"SAITS test_MAE: {test_MAE}") @@ -243,7 +243,7 @@ def test_6_exponential_lrs(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." 
- test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"SAITS test_MAE: {test_MAE}") diff --git a/tests/optim/rmsprop.py b/tests/optim/rmsprop.py index 29087520..f4a3f53c 100644 --- a/tests/optim/rmsprop.py +++ b/tests/optim/rmsprop.py @@ -13,7 +13,7 @@ from pypots.imputation import SAITS from pypots.optim import RMSprop from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.global_test_config import DATA from tests.optim.config import EPOCHS, TEST_SET, TRAIN_SET, VAL_SET @@ -46,7 +46,7 @@ def test_0_fit(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"SAITS test_MAE: {test_MAE}") diff --git a/tests/optim/sgd.py b/tests/optim/sgd.py index 569bb96d..7dec3bf3 100644 --- a/tests/optim/sgd.py +++ b/tests/optim/sgd.py @@ -13,7 +13,7 @@ from pypots.imputation import SAITS from pypots.optim import SGD from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.global_test_config import DATA from tests.optim.config import EPOCHS, TEST_SET, TRAIN_SET, VAL_SET @@ -46,7 +46,7 @@ def test_0_fit(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"SAITS test_MAE: {test_MAE}") diff --git a/tests/utils/random.py b/tests/utils/random.py index 89d78d93..096b02d4 100644 --- a/tests/utils/random.py +++ b/tests/utils/random.py @@ -9,7 +9,7 @@ import torch -from pypots.utils.random import set_random_seed +from pypots.utils.random import set_random_seed, get_random_seed class TestRandom(unittest.TestCase): @@ -31,6 +31,16 @@ def test_set_random_seed(self): random_state1, random_state2 ), "The random seed has been set, two random states are not the same." + current_seed = get_random_seed() + assert ( + not current_seed == 32 + ), "The random seed has been set to 26, not equal to 32." + set_random_seed(32) + current_seed = get_random_seed() + assert ( + current_seed == 32 + ), "The random seed has been set to 32, should be equal." 
+ if __name__ == "__main__": unittest.main() From 77f6794e1d12fdaa9eecb18a02f343c9a5aa65c5 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sat, 9 Dec 2023 01:00:44 +0800 Subject: [PATCH 3/3] fix: error args passing in deprecated functions; --- pypots/imputation/timesnet/modules/core.py | 4 +- pypots/utils/metrics/classification.py | 20 +++++----- pypots/utils/metrics/clustering.py | 45 ++++++++++++---------- pypots/utils/metrics/error.py | 16 ++++---- 4 files changed, 45 insertions(+), 40 deletions(-) diff --git a/pypots/imputation/timesnet/modules/core.py b/pypots/imputation/timesnet/modules/core.py index 9dd4bf5a..ff51fe86 100644 --- a/pypots/imputation/timesnet/modules/core.py +++ b/pypots/imputation/timesnet/modules/core.py @@ -11,7 +11,7 @@ from .embedding import DataEmbedding from .layer import TimesBlock -from ....utils.metrics import cal_mse +from ....utils.metrics import calc_mse class _TimesNet(nn.Module): @@ -88,7 +88,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: if training: # `loss` is always the item for backward propagating to update the model - loss = cal_mse(dec_out, inputs["X_intact"], inputs["indicating_mask"]) + loss = calc_mse(dec_out, inputs["X_intact"], inputs["indicating_mask"]) results["loss"] = loss return results diff --git a/pypots/utils/metrics/classification.py b/pypots/utils/metrics/classification.py index d9867f6f..ef74d218 100644 --- a/pypots/utils/metrics/classification.py +++ b/pypots/utils/metrics/classification.py @@ -263,31 +263,31 @@ def calc_acc(class_predictions: np.ndarray, targets: np.ndarray) -> float: ######################################################################################################################## -def cal_binary_classification_metrics(**kwargs): +def cal_binary_classification_metrics(*args): logger.warning( "🚨 cal_binary_classification_metrics() is deprecated, " "use calc_binary_classification_metrics() instead." ) - return calc_binary_classification_metrics(**kwargs) + return calc_binary_classification_metrics(*args) -def cal_precision_recall_f1(**kwargs): +def cal_precision_recall_f1(*args): logger.warning( "🚨 cal_precision_recall_f1() is deprecated, use calc_precision_recall_f1() instead." ) - return calc_precision_recall_f1(**kwargs) + return calc_precision_recall_f1(*args) -def cal_pr_auc(**kwargs): +def cal_pr_auc(*args): logger.warning("🚨 cal_pr_auc() is deprecated, use calc_pr_auc() instead.") - return calc_pr_auc(**kwargs) + return calc_pr_auc(*args) -def cal_roc_auc(**kwargs): +def cal_roc_auc(*args): logger.warning("🚨 cal_roc_auc() is deprecated, use calc_roc_auc() instead.") - return calc_roc_auc(**kwargs) + return calc_roc_auc(*args) -def cal_acc(**kwargs): +def cal_acc(*args): logger.warning("🚨 cal_acc() is deprecated, use calc_acc() instead.") - return calc_acc(**kwargs) + return calc_acc(*args) diff --git a/pypots/utils/metrics/clustering.py b/pypots/utils/metrics/clustering.py index f417d295..2e24edd7 100644 --- a/pypots/utils/metrics/clustering.py +++ b/pypots/utils/metrics/clustering.py @@ -154,7 +154,10 @@ def calc_cluster_purity( return cluster_purity -def calc_external_cluster_validation_metrics(class_predictions, targets): +def calc_external_cluster_validation_metrics( + class_predictions: np.ndarray, + targets: np.ndarray, +) -> dict: """Computer all external cluster validation metrics available in PyPOTS and return as a dictionary. 
Parameters @@ -270,7 +273,9 @@ def calc_dbs(X: np.ndarray, predicted_labels: np.ndarray) -> float: return davies_bouldin_score -def calc_internal_cluster_validation_metrics(X, predicted_labels): +def calc_internal_cluster_validation_metrics( + X: np.ndarray, predicted_labels: np.ndarray +) -> dict: """Computer all internal cluster validation metrics available in PyPOTS and return as a dictionary. Parameters @@ -304,52 +309,52 @@ def calc_internal_cluster_validation_metrics(X, predicted_labels): ######################################################################################################################## -def cal_rand_index(**kwargs): +def cal_rand_index(*args): logger.warning("🚨 Deprecated function, please use `calc_rand_index` instead.") - return calc_rand_index(**kwargs) + return calc_rand_index(*args) -def cal_adjusted_rand_index(**kwargs): +def cal_adjusted_rand_index(*args): logger.warning( "🚨 Deprecated function, please use `calc_adjusted_rand_index` instead." ) - return calc_adjusted_rand_index(**kwargs) + return calc_adjusted_rand_index(*args) -def cal_nmi(**kwargs): +def cal_nmi(*args): logger.warning("🚨 Deprecated function, please use `calc_nmi` instead.") - return calc_nmi(**kwargs) + return calc_nmi(*args) -def cal_cluster_purity(**kwargs): +def cal_cluster_purity(*args): logger.warning("🚨 Deprecated function, please use `calc_cluster_purity` instead.") - return calc_cluster_purity(**kwargs) + return calc_cluster_purity(*args) -def cal_external_cluster_validation_metrics(**kwargs): +def cal_external_cluster_validation_metrics(*args): logger.warning( "🚨 Deprecated function, please use `calc_external_cluster_validation_metrics` instead." ) - return calc_external_cluster_validation_metrics(**kwargs) + return calc_external_cluster_validation_metrics(*args) -def cal_silhouette(**kwargs): +def cal_silhouette(*args): logger.warning("🚨 Deprecated function, please use `calc_silhouette` instead.") - return calc_silhouette(**kwargs) + return calc_silhouette(*args) -def cal_chs(**kwargs): +def cal_chs(*args): logger.warning("🚨 Deprecated function, please use `calc_chs` instead.") - return calc_chs(**kwargs) + return calc_chs(*args) -def cal_dbs(**kwargs): +def cal_dbs(*args): logger.warning("🚨 Deprecated function, please use `calc_dbs` instead.") - return calc_dbs(**kwargs) + return calc_dbs(*args) -def cal_internal_cluster_validation_metrics(**kwargs): +def cal_internal_cluster_validation_metrics(*args): logger.warning( "🚨 Deprecated function, please use `calc_internal_cluster_validation_metrics` instead." 
) - return calc_internal_cluster_validation_metrics(**kwargs) + return calc_internal_cluster_validation_metrics(*args) diff --git a/pypots/utils/metrics/error.py b/pypots/utils/metrics/error.py index 8cafcb61..b8d10428 100644 --- a/pypots/utils/metrics/error.py +++ b/pypots/utils/metrics/error.py @@ -296,21 +296,21 @@ def calc_quantile_crps_sum(predictions, targets, eval_points, mean_scaler=0, sca ######################################################################################################################## -def cal_mae(**kwargs): +def cal_mae(*args): logger.warning("🚨 cal_mae() is deprecated, use calc_mae() instead.") - return calc_mae(**kwargs) + return calc_mae(*args) -def cal_rmse(**kwargs): +def cal_rmse(*args): logger.warning("🚨 cal_rmse() is deprecated, use calc_rmse() instead.") - return calc_rmse(**kwargs) + return calc_rmse(*args) -def cal_mse(**kwargs): +def cal_mse(*args): logger.warning("🚨 cal_mse() is deprecated, use calc_mse() instead.") - return calc_mse(**kwargs) + return calc_mse(*args) -def cal_mre(**kwargs): +def cal_mre(*args): logger.warning("🚨 cal_mre() is deprecated, use calc_mre() instead.") - return calc_mre(**kwargs) + return calc_mre(*args)
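
A minimal usage sketch of the random-seed utilities added in PATCH 1/3; the seed value below is illustrative rather than a PyPOTS default:

    from pypots.utils.random import set_random_seed, get_random_seed

    set_random_seed(2023)              # seeds numpy and pytorch, and records the value module-wide
    assert get_random_seed() == 2023   # the recorded seed can be read back later, e.g. for logging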
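
And a sketch of the renamed metric API from PATCH 2/3 and 3/3. The toy arrays mirror the docstring examples in error.py; the CRPS input shape (an extra sampling dimension on the predictions) is an assumption taken from the CSDI test above:

    import numpy as np

    from pypots.utils.metrics import calc_mae, calc_quantile_crps, cal_mae

    targets = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
    predictions = np.array([1.0, 2.0, 1.0, 4.0, 6.0])
    masks = np.array([0.0, 0.0, 0.0, 1.0, 1.0])

    calc_mae(predictions, targets, masks)  # 0.5 per the docstring example; errors outside the mask are ignored
    cal_mae(predictions, targets, masks)   # same value, but logs the deprecation warning

    # probabilistic imputations shaped (n_samples, n_sampling_times, n_steps, n_features)
    pred_samples = np.random.randn(2, 10, 3, 1)
    true_values = np.random.randn(2, 3, 1)
    eval_points = np.ones_like(true_values)  # 1 marks the positions to evaluate
    calc_quantile_crps(pred_samples, true_values, eval_points)

Note that the deprecated cal_* wrappers forward positional arguments only (they unpack *args), so keyword-style calls should move to the calc_* names.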