Merge pull request #260 from WenjieDu/dev
Adding get_random_seed(), and adding the function calc_quantile_crps()
WenjieDu authored Dec 8, 2023
2 parents 610996f + 77f6794 commit d504e6b
Showing 38 changed files with 400 additions and 140 deletions.
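
The diffs below rename the metric helpers in pypots.utils.metrics from the cal_* prefix to calc_* and update every call site; the old names are kept as deprecated aliases (see pypots/utils/metrics/__init__.py further down). The new get_random_seed() and calc_quantile_crps() mentioned in the commit title are defined in files not shown on this page. As a minimal sketch of the renamed validation-loss pattern that recurs in these diffs, assuming toy NumPy arrays (the array names and shapes are illustrative, not taken from the diff):

    import numpy as np
    from pypots.utils.metrics import calc_mse  # new name; cal_mse stays available as a deprecated alias

    # toy stand-ins for the validation tensors used in pypots/imputation/base.py
    imputation = np.random.rand(16, 24, 10)                             # collected model imputations
    X_intact = np.random.rand(16, 24, 10)                               # ground truth before artificial masking
    indicating_mask = (np.random.rand(16, 24, 10) > 0.8).astype(float)  # 1 where values were held out for validation

    # same call pattern as mean_val_loss = calc_mse(...) in the diffs below
    mean_val_loss = calc_mse(imputation, X_intact, indicating_mask)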
6 changes: 3 additions & 3 deletions pypots/clustering/crli/modules/core.py
@@ -18,7 +18,7 @@
from sklearn.cluster import KMeans

from .submodules import Generator, Decoder, Discriminator
- from ....utils.metrics import cal_mse
+ from ....utils.metrics import calc_mse


class _CRLI(nn.Module):
@@ -89,8 +89,8 @@ def forward(
l_G = F.binary_cross_entropy_with_logits(
inputs["discrimination"], 1 - missing_mask, weight=1 - missing_mask
)
- l_pre = cal_mse(inputs["imputation_latent"], X, missing_mask)
- l_rec = cal_mse(inputs["reconstruction"], X, missing_mask)
+ l_pre = calc_mse(inputs["imputation_latent"], X, missing_mask)
+ l_rec = calc_mse(inputs["reconstruction"], X, missing_mask)
HTH = torch.matmul(inputs["fcn_latent"], inputs["fcn_latent"].permute(1, 0))

if (
4 changes: 2 additions & 2 deletions pypots/clustering/vader/modules/core.py
@@ -21,7 +21,7 @@
PeepholeLSTMCell,
ImplicitImputation,
)
- from ....utils.metrics import cal_mse
+ from ....utils.metrics import calc_mse


class _VaDER(nn.Module):
@@ -184,7 +184,7 @@ def forward(
}

# calculate the reconstruction loss
- unscaled_reconstruction_loss = cal_mse(X_reconstructed, X, missing_mask)
+ unscaled_reconstruction_loss = calc_mse(X_reconstructed, X, missing_mask)
reconstruction_loss = (
unscaled_reconstruction_loss
* self.n_steps
4 changes: 2 additions & 2 deletions pypots/imputation/base.py
@@ -16,7 +16,7 @@

from ..base import BaseModel, BaseNNModel
from ..utils.logging import logger
- from ..utils.metrics import cal_mse
+ from ..utils.metrics import calc_mse

try:
import nni
@@ -299,7 +299,7 @@ def _train_model(
imputation_collector = torch.cat(imputation_collector)
imputation_collector = imputation_collector.cpu().detach().numpy()

- mean_val_loss = cal_mse(
+ mean_val_loss = calc_mse(
imputation_collector,
val_loader.dataset.data["X_intact"],
val_loader.dataset.data["indicating_mask"],
8 changes: 4 additions & 4 deletions pypots/imputation/brits/modules/core.py
@@ -21,7 +21,7 @@

from .submodules import FeatureRegression
from ....modules.rnn import TemporalDecay
- from ....utils.metrics import cal_mae
+ from ....utils.metrics import calc_mae


class RITS(nn.Module):
@@ -150,17 +150,17 @@ def impute(

hidden_states = hidden_states * gamma_h # decay hidden states
x_h = self.hist_reg(hidden_states)
- reconstruction_loss += cal_mae(x_h, x, m)
+ reconstruction_loss += calc_mae(x_h, x, m)

x_c = m * x + (1 - m) * x_h

z_h = self.feat_reg(x_c)
- reconstruction_loss += cal_mae(z_h, x, m)
+ reconstruction_loss += calc_mae(z_h, x, m)

alpha = torch.sigmoid(self.combining_weight(torch.cat([gamma_x, m], dim=1)))

c_h = alpha * z_h + (1 - alpha) * x_h
- reconstruction_loss += cal_mae(c_h, x, m)
+ reconstruction_loss += calc_mae(c_h, x, m)

c_c = m * x + (1 - m) * c_h
estimations.append(c_h.unsqueeze(dim=1))
4 changes: 2 additions & 2 deletions pypots/imputation/csdi/model.py
@@ -32,7 +32,7 @@
from ...optim.adam import Adam
from ...optim.base import Optimizer
from ...utils.logging import logger
- from ...utils.metrics import cal_mse
+ from ...utils.metrics import calc_mse


class CSDI(BaseNNImputer):
@@ -256,7 +256,7 @@ def _train_model(
imputation_collector = torch.cat(imputation_collector)
imputation_collector = imputation_collector.cpu().detach().numpy()

- mean_val_loss = cal_mse(
+ mean_val_loss = calc_mse(
imputation_collector,
val_loader.dataset.data["X_intact"],
val_loader.dataset.data["indicating_mask"],
4 changes: 2 additions & 2 deletions pypots/imputation/mrnn/modules/core.py
@@ -12,7 +12,7 @@
import torch.nn as nn

from .submodules import FCN_Regression
- from ....utils.metrics import cal_rmse
+ from ....utils.metrics import calc_rmse


class _MRNN(nn.Module):
@@ -74,7 +74,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
FCN_estimation = self.fcn_regression(
x, m, RNN_imputed_data
) # FCN estimation is output estimation
- reconstruction_loss += cal_rmse(FCN_estimation, x, m) + cal_rmse(
+ reconstruction_loss += calc_rmse(FCN_estimation, x, m) + calc_rmse(
RNN_estimation, x, m
)
estimations.append(FCN_estimation.unsqueeze(dim=1))
4 changes: 2 additions & 2 deletions pypots/imputation/saits/model.py
@@ -27,7 +27,7 @@
from ...optim.adam import Adam
from ...optim.base import Optimizer
from ...utils.logging import logger
- from ...utils.metrics import cal_mae
+ from ...utils.metrics import calc_mae


class SAITS(BaseNNImputer):
@@ -145,7 +145,7 @@ def __init__(
batch_size: int = 32,
epochs: int = 100,
patience: Optional[int] = None,
- customized_loss_func: Callable = cal_mae,
+ customized_loss_func: Callable = calc_mae,
optimizer: Optional[Optimizer] = Adam(),
num_workers: int = 0,
device: Optional[Union[str, torch.device, list]] = None,
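
For context on the changed default above: SAITS keeps calc_mae as its customized_loss_func after the rename, and any metric with the same (predictions, targets, masks) call pattern can be passed instead. A hedged sketch of overriding the default; apart from the arguments visible in the constructor above (batch_size, epochs, patience, customized_loss_func, optimizer, num_workers, device), the hyperparameter names and values below are assumptions for illustration only:

    from pypots.imputation import SAITS
    from pypots.utils.metrics import calc_mse

    saits = SAITS(
        n_steps=24,        # sequence length (assumed value)
        n_features=10,     # variables per step (assumed value)
        n_layers=2,        # remaining hyperparameters are illustrative, not prescribed by the diff
        d_model=256,
        d_inner=128,
        n_heads=4,
        d_k=64,
        d_v=64,
        dropout=0.1,
        epochs=100,
        patience=10,
        customized_loss_func=calc_mse,  # swap the default calc_mae for an MSE training loss
    )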
4 changes: 2 additions & 2 deletions pypots/imputation/saits/modules/core.py
@@ -20,7 +20,7 @@
import torch.nn.functional as F

from ....modules.transformer import EncoderLayer, PositionalEncoding
- from ....utils.metrics import cal_mae
+ from ....utils.metrics import calc_mae


class _SAITS(nn.Module):
@@ -39,7 +39,7 @@ def __init__(
diagonal_attention_mask: bool = True,
ORT_weight: float = 1,
MIT_weight: float = 1,
- customized_loss_func: Callable = cal_mae,
+ customized_loss_func: Callable = calc_mae,
):
super().__init__()
self.n_layers = n_layers
4 changes: 2 additions & 2 deletions pypots/imputation/timesnet/modules/core.py
@@ -11,7 +11,7 @@

from .embedding import DataEmbedding
from .layer import TimesBlock
- from ....utils.metrics import cal_mse
+ from ....utils.metrics import calc_mse


class _TimesNet(nn.Module):
@@ -88,7 +88,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict:

if training:
# `loss` is always the item for backward propagating to update the model
- loss = cal_mse(dec_out, inputs["X_intact"], inputs["indicating_mask"])
+ loss = calc_mse(dec_out, inputs["X_intact"], inputs["indicating_mask"])
results["loss"] = loss

return results
6 changes: 3 additions & 3 deletions pypots/imputation/transformer/modules/core.py
@@ -19,7 +19,7 @@
import torch.nn as nn

from ....modules.transformer import EncoderLayer, PositionalEncoding
- from ....utils.metrics import cal_mae
+ from ....utils.metrics import calc_mae


class _TransformerEncoder(nn.Module):
@@ -89,8 +89,8 @@ def forward(self, inputs: dict, training: bool = True) -> dict:

# if in training mode, return results with losses
if training:
- ORT_loss = cal_mae(learned_presentation, X, masks)
- MIT_loss = cal_mae(
+ ORT_loss = calc_mae(learned_presentation, X, masks)
+ MIT_loss = calc_mae(
learned_presentation, inputs["X_intact"], inputs["indicating_mask"]
)
results["ORT_loss"] = ORT_loss
65 changes: 57 additions & 8 deletions pypots/utils/metrics/__init__.py
@@ -6,13 +6,29 @@
# License: BSD-3-Clause

from .classification import (
+ calc_binary_classification_metrics,
+ calc_precision_recall_f1,
+ calc_pr_auc,
+ calc_roc_auc,
+ calc_acc,
+ # deprecated
cal_binary_classification_metrics,
cal_precision_recall_f1,
cal_pr_auc,
cal_roc_auc,
cal_acc,
)
from .clustering import (
+ calc_rand_index,
+ calc_adjusted_rand_index,
+ calc_cluster_purity,
+ calc_nmi,
+ calc_chs,
+ calc_dbs,
+ calc_silhouette,
+ calc_internal_cluster_validation_metrics,
+ calc_external_cluster_validation_metrics,
+ # deprecated
cal_rand_index,
cal_adjusted_rand_index,
cal_cluster_purity,
@@ -23,21 +39,49 @@
cal_internal_cluster_validation_metrics,
cal_external_cluster_validation_metrics,
)
- from .error import cal_mae, cal_mse, cal_rmse, cal_mre
+ from .error import (
+ calc_mae,
+ calc_mse,
+ calc_rmse,
+ calc_mre,
+ calc_quantile_crps,
+ calc_quantile_crps_sum,
+ # deprecated
+ cal_mae,
+ cal_mse,
+ cal_rmse,
+ cal_mre,
+ )

__all__ = [
# error
"calc_mae",
"calc_mse",
"calc_rmse",
"calc_mre",
"calc_quantile_crps",
"calc_quantile_crps_sum",
# classification
"calc_binary_classification_metrics",
"calc_precision_recall_f1",
"calc_pr_auc",
"calc_roc_auc",
"calc_acc",
# clustering
"calc_rand_index",
"calc_adjusted_rand_index",
"calc_cluster_purity",
"calc_nmi",
"calc_chs",
"calc_dbs",
"calc_silhouette",
"calc_internal_cluster_validation_metrics",
"calc_external_cluster_validation_metrics",
# deprecated
"cal_mae",
"cal_mse",
"cal_rmse",
"cal_mre",
- # classification
- "cal_binary_classification_metrics",
- "cal_precision_recall_f1",
- "cal_pr_auc",
- "cal_roc_auc",
- "cal_acc",
- # clustering
"cal_rand_index",
"cal_adjusted_rand_index",
"cal_cluster_purity",
@@ -47,4 +91,9 @@
"cal_silhouette",
"cal_internal_cluster_validation_metrics",
"cal_external_cluster_validation_metrics",
"cal_binary_classification_metrics",
"cal_precision_recall_f1",
"cal_pr_auc",
"cal_roc_auc",
"cal_acc",
]
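
With the export list above, both spellings are importable from pypots.utils.metrics: the calc_* functions are the current API and the cal_* names remain only as deprecated aliases, while calc_quantile_crps and calc_quantile_crps_sum are newly re-exported from pypots/utils/metrics/error.py (not among the files shown here). A small usage sketch with hand-made arrays chosen purely for illustration:

    import numpy as np
    from pypots.utils.metrics import calc_mae, cal_mae  # new name and its deprecated alias

    targets = np.array([1.0, 2.0, 3.0])
    predictions = np.array([1.0, 2.5, 2.5])

    print(calc_mae(predictions, targets))  # 0.333..., the mean absolute error
    print(cal_mae(predictions, targets))   # same result via the deprecated alias; prefer calc_mae going forward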