diff --git a/pypots/imputation/brits/model.py b/pypots/imputation/brits/model.py
index 2fd255d7..514ea821 100644
--- a/pypots/imputation/brits/model.py
+++ b/pypots/imputation/brits/model.py
@@ -34,6 +34,12 @@ class BRITS(BaseNNImputer):
     Parameters
     ----------
+    n_steps :
+        The number of time steps in the time-series data sample.
+
+    n_features :
+        The number of features in the time-series data sample.
+
     rnn_hidden_size :
         The size of the RNN hidden state, also the number of hidden units in the RNN cell.

diff --git a/pypots/imputation/gpvae/model.py b/pypots/imputation/gpvae/model.py
index a3ada5ac..9fac3fc2 100644
--- a/pypots/imputation/gpvae/model.py
+++ b/pypots/imputation/gpvae/model.py
@@ -10,12 +10,19 @@
 # License: BSD-3-Clause

+import os
 from typing import Union, Optional

 import numpy as np
 import torch
 from torch.utils.data import DataLoader

+try:
+    import nni
+except ImportError:
+    pass
+
+
 from .data import DatasetForGPVAE
 from .modules import _GPVAE
 from ..base import BaseNNImputer
@@ -23,6 +30,7 @@
 from ...optim.adam import Adam
 from ...optim.base import Optimizer
 from ...utils.logging import logger
+from ...utils.metrics import calc_mse


 class GPVAE(BaseNNImputer):
@@ -30,12 +38,45 @@ class GPVAE(BaseNNImputer):

     Parameters
     ----------
-    beta: float
-        The weight of KL divergence in EBLO.
+    n_steps :
+        The number of time steps in the time-series data sample.
+
+    n_features :
+        The number of features in the time-series data sample.
+
+    latent_size : int
+        The feature dimension of the latent embedding.
+
+    encoder_sizes : tuple
+        The tuple of the network size in the encoder.
+
+    decoder_sizes : tuple
+        The tuple of the network size in the decoder.
+
+    beta : float
+        The weight of KL divergence in the ELBO.
+
+    M : int
+        The number of Monte Carlo samples for ELBO estimation during training.
+
+    K : int
+        The number of importance weights for the IWAE training loss.

     kernel: str
         The type of kernel function chosen in the Gaussian Process Prior.
         ["cauchy", "diffusion", "rbf", "matern"]

+    sigma : float
+        The scale parameter for a kernel function.
+
+    length_scale : float
+        The length scale parameter for a kernel function.
+
+    kernel_scales : int
+        The number of different length scales over latent space dimensions.
+
+    window_size : int
+        Window size for the inference CNN.
+
     batch_size : int
         The batch size for training and evaluating the model.
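As a quick reference for reviewers, here is a minimal usage sketch of the constructor documented in the hunk above. The parameter names mirror the new docstring; the values are illustrative placeholders, not recommendations, and only `n_steps`/`n_features` must match your data.

```python
# Hedged sketch: instantiating GP-VAE with the parameters documented above.
from pypots.imputation import GPVAE

gp_vae = GPVAE(
    n_steps=48,              # time steps per sample (must match your data)
    n_features=37,           # features per time step (must match your data)
    latent_size=16,          # feature dimension of the latent embedding
    encoder_sizes=(64, 64),  # network sizes in the encoder
    decoder_sizes=(64, 64),  # network sizes in the decoder
    beta=0.2,                # weight of the KL divergence in the ELBO
    M=1,                     # Monte Carlo samples for ELBO estimation
    K=1,                     # importance weights for the IWAE training loss
    kernel="cauchy",         # GP prior kernel: "cauchy", "diffusion", "rbf", "matern"
    sigma=1.005,             # kernel scale parameter
    length_scale=7.0,        # kernel length scale
    kernel_scales=1,         # number of length scales over latent dimensions
    window_size=3,           # window size for the inference CNN
    batch_size=32,
)
```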
@@ -199,6 +240,124 @@ def _assemble_input_for_validating(self, data: list) -> dict:

     def _assemble_input_for_testing(self, data: list) -> dict:
         return self._assemble_input_for_training(data)

+    def _train_model(
+        self,
+        training_loader: DataLoader,
+        val_loader: DataLoader = None,
+    ) -> None:
+        # each training starts from the very beginning, so reset the loss and model dict here
+        self.best_loss = float("inf")
+        self.best_model_dict = None
+
+        try:
+            training_step = 0
+            for epoch in range(self.epochs):
+                self.model.train()
+                epoch_train_loss_collector = []
+                for idx, data in enumerate(training_loader):
+                    training_step += 1
+                    inputs = self._assemble_input_for_training(data)
+                    self.optimizer.zero_grad()
+                    results = self.model.forward(inputs)
+                    # use sum() before backward() in case of multi-gpu training
+                    results["loss"].sum().backward()
+                    self.optimizer.step()
+                    epoch_train_loss_collector.append(results["loss"].sum().item())
+
+                    # save training loss logs into the tensorboard file for every step if in need
+                    if self.summary_writer is not None:
+                        self._save_log_into_tb_file(training_step, "training", results)
+
+                # mean training loss of the current epoch
+                mean_train_loss = np.mean(epoch_train_loss_collector)
+
+                if val_loader is not None:
+                    self.model.eval()
+                    imputation_loss_collector = []
+                    with torch.no_grad():
+                        for idx, data in enumerate(val_loader):
+                            inputs = self._assemble_input_for_validating(data)
+                            results = self.model.forward(
+                                inputs, training=False, n_sampling_times=1
+                            )
+                            imputed_data = results["imputed_data"].mean(axis=1)
+                            imputation_mse = (
+                                calc_mse(
+                                    imputed_data,
+                                    inputs["X_ori"],
+                                    inputs["indicating_mask"],
+                                )
+                                .sum()
+                                .detach()
+                                .item()
+                            )
+                            imputation_loss_collector.append(imputation_mse)
+
+                    mean_val_loss = np.mean(imputation_loss_collector)
+
+                    # save validating loss logs into the tensorboard file for every epoch if in need
+                    if self.summary_writer is not None:
+                        val_loss_dict = {
+                            "imputation_loss": mean_val_loss,
+                        }
+                        self._save_log_into_tb_file(epoch, "validating", val_loss_dict)
+
+                    logger.info(
+                        f"Epoch {epoch} - "
+                        f"training loss: {mean_train_loss:.4f}, "
+                        f"validating loss: {mean_val_loss:.4f}"
+                    )
+                    mean_loss = mean_val_loss
+                else:
+                    logger.info(f"Epoch {epoch} - training loss: {mean_train_loss:.4f}")
+                    mean_loss = mean_train_loss
+
+                if np.isnan(mean_loss):
+                    logger.warning(
+                        f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors."
+                    )
+
+                if mean_loss < self.best_loss:
+                    self.best_loss = mean_loss
+                    self.best_model_dict = self.model.state_dict()
+                    self.patience = self.original_patience
+                    # save the model if necessary
+                    self._auto_save_model_if_necessary(
+                        training_finished=False,
+                        saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
+                    )
+                else:
+                    self.patience -= 1
+
+                if os.getenv("enable_tuning", False):
+                    nni.report_intermediate_result(mean_loss)
+                    if epoch == self.epochs - 1 or self.patience == 0:
+                        nni.report_final_result(self.best_loss)
+
+                if self.patience == 0:
+                    logger.info(
+                        "Exceeded the training patience. Terminating the training procedure..."
+                    )
+                    break
+
+        except Exception as e:
+            logger.error(f"Exception: {e}")
+            if self.best_model_dict is None:
+                raise RuntimeError(
+                    "Training got interrupted. Model was not trained. Please investigate the error printed above."
+                )
+            else:
+                logger.warning(
+                    "Training got interrupted. Please investigate the error printed above.\n"
+                    "Model got trained and will load the best checkpoint so far for testing.\n"
+                    "If you don't want it, please try fit() again."
+                )
+
+        if np.isnan(self.best_loss):
+            raise ValueError("Something is wrong. best_loss is NaN after training.")
+
+        logger.info("Finished training.")
+
     def fit(
         self,
         train_set: Union[dict, str],
@@ -240,8 +399,35 @@ def fit(
     def predict(
         self,
         test_set: Union[dict, str],
-        file_type="h5py",
+        file_type: str = "h5py",
+        n_sampling_times: int = 1,
     ) -> dict:
+        """
+
+        Parameters
+        ----------
+        test_set : dict or str
+            The dataset for model testing, should be a dictionary including the key 'X',
+            or a path string locating a data file.
+            If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
+            which is time-series data for testing and can contain missing values.
+            If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
+            key-value pairs like a dict, and it has to include the key 'X'.
+
+        file_type : str
+            The type of the given file if test_set is a path string.
+
+        n_sampling_times : int
+            The number of sampling times for the model to produce predictions.
+
+        Returns
+        -------
+        result_dict : dict
+            Prediction results in a Python dictionary for the given samples.
+            It should be a dictionary including a key named 'imputation'.
+
+        """
         self.model.eval()  # set the model as eval status to freeze it.
         test_set = DatasetForGPVAE(
             test_set, return_X_ori=False, return_labels=False, file_type=file_type
         )
@@ -257,7 +443,9 @@ def predict(
         with torch.no_grad():
             for idx, data in enumerate(test_loader):
                 inputs = self._assemble_input_for_testing(data)
-                results = self.model.forward(inputs, training=False)
+                results = self.model.forward(
+                    inputs, training=False, n_sampling_times=n_sampling_times
+                )
                 imputed_data = results["imputed_data"]
                 imputation_collector.append(imputed_data)
@@ -271,6 +459,7 @@ def impute(
         self,
         X: Union[dict, str],
         file_type="h5py",
+        n_sampling_times: int = 1,
     ) -> np.ndarray:
         """Impute missing values in the given data with the trained model.

@@ -295,5 +484,7 @@ def impute(
         logger.warning(
             "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead."
         )
-        results_dict = self.predict(X, file_type=file_type)
+        results_dict = self.predict(
+            X, file_type=file_type, n_sampling_times=n_sampling_times
+        )
         return results_dict["imputation"]
diff --git a/pypots/imputation/gpvae/modules/core.py b/pypots/imputation/gpvae/modules/core.py
index 7a37ffde..c546e915 100644
--- a/pypots/imputation/gpvae/modules/core.py
+++ b/pypots/imputation/gpvae/modules/core.py
@@ -156,54 +156,57 @@ def _init_prior(self, device="cpu"):
         )
         return prior

-    def forward(self, inputs, training=True):
+    def forward(self, inputs, training=True, n_sampling_times=1):
         x = inputs["X"]
+        n_samples, n_steps, n_features = x.shape
         m_mask = inputs["missing_mask"]
-        x = x.repeat(self.M * self.K, 1, 1)
-        if self.prior is None:
-            self.prior = self._init_prior(device=x.device)
-
-        if m_mask is not None:
-            m_mask = m_mask.repeat(self.M * self.K, 1, 1)
-            m_mask = m_mask.type(torch.bool)
-
-        # pz = self.prior()
-        qz_x = self.encode(x)
-        z = qz_x.rsample()
-        px_z = self.decode(z)
-
-        nll = -px_z.log_prob(x)
-        nll = torch.where(torch.isfinite(nll), nll, torch.zeros_like(nll))
-        if m_mask is not None:
-            nll = torch.where(m_mask, nll, torch.zeros_like(nll))
-        nll = nll.sum(dim=(1, 2))
-
-        if self.K > 1:
-            kl = qz_x.log_prob(z) - self.prior.log_prob(z)
-            kl = torch.where(torch.isfinite(kl), kl, torch.zeros_like(kl))
-            kl = kl.sum(1)
-
-            weights = -nll - kl
-            weights = torch.reshape(weights, [self.M, self.K, -1])
-
-            elbo = torch.logsumexp(weights, dim=1)
-            elbo = elbo.mean()
-        else:
-            kl = self.kl_divergence(qz_x, self.prior)
-            kl = torch.where(torch.isfinite(kl), kl, torch.zeros_like(kl))
-            kl = kl.sum(1)
-
-            elbo = -nll - self.beta * kl
-            elbo = elbo.mean()
-
-        imputed_data = self.decode(self.encode(x).mean).mean * ~m_mask + x * m_mask
-        results = {
-            "imputed_data": imputed_data,
-        }
-
-        # if in training mode, return results with losses
+        results = {}
         if training:
+            x = x.repeat(self.K * self.M, 1, 1)
+            m_mask = m_mask.repeat(self.K * self.M, 1, 1).type(torch.bool)
+
+            if self.prior is None:
+                self.prior = self._init_prior(device=x.device)
+
+            qz_x = self.encode(x)
+            z = qz_x.rsample()
+            px_z = self.decode(z)
+            nll = -px_z.log_prob(x)
+            nll = torch.where(torch.isfinite(nll), nll, torch.zeros_like(nll))
+            if m_mask is not None:
+                nll = torch.where(m_mask, nll, torch.zeros_like(nll))
+            nll = nll.sum(dim=(1, 2))
+
+            if self.K > 1:
+                kl = qz_x.log_prob(z) - self.prior.log_prob(z)
+                kl = torch.where(torch.isfinite(kl), kl, torch.zeros_like(kl))
+                kl = kl.sum(1)
+
+                weights = -nll - kl
+                weights = torch.reshape(weights, [self.M, self.K, -1])
+
+                elbo = torch.logsumexp(weights, dim=1)
+                elbo = elbo.mean()
+            else:
+                kl = self.kl_divergence(qz_x, self.prior)
+                kl = torch.where(torch.isfinite(kl), kl, torch.zeros_like(kl))
+                kl = kl.sum(1)
+
+                elbo = -nll - self.beta * kl
+                elbo = elbo.mean()
             results["loss"] = -elbo.mean()
+        else:
+            x = x.repeat(n_sampling_times, 1, 1)
+            m_mask = m_mask.repeat(n_sampling_times, 1, 1).type(torch.bool)
+            decode_x_mean = self.decode(self.encode(x).mean).mean
+            imputed_data = decode_x_mean * ~m_mask + x * m_mask
+            imputed_data = imputed_data.reshape(
+                n_sampling_times, n_samples, n_steps, n_features
+            ).permute(1, 0, 2, 3)
+
+            results = {
+                "imputed_data": imputed_data,
+            }

         return results
diff --git a/pypots/imputation/mrnn/model.py b/pypots/imputation/mrnn/model.py
index cbb095c7..ee33f2a8 100644
--- a/pypots/imputation/mrnn/model.py
+++ b/pypots/imputation/mrnn/model.py
@@ -28,6 +28,12 @@ class MRNN(BaseNNImputer):
     Parameters
     ----------
+    n_steps :
+        The number of time steps in the time-series data sample.
+
+    n_features :
+        The number of features in the time-series data sample.
+
     rnn_hidden_size :
         The size of the RNN hidden state, also the number of hidden units in the RNN cell.

diff --git a/tests/data/saving.py b/tests/data/saving.py
index d7c47dee..c52e5e9f 100644
--- a/tests/data/saving.py
+++ b/tests/data/saving.py
@@ -48,4 +48,4 @@ def test_0_save_dict_into_h5(self):
     def test_0_pickle_dump_load(self):
         pickle_dump(self.data_to_save, self.pickle_saving_path)
         loaded_data = pickle_load(self.pickle_saving_path)
-        assert loaded_data == self.data_to_save
+        assert (loaded_data["c"]["e"]["f"] == self.data_to_save["c"]["e"]["f"]).all()
diff --git a/tests/imputation/brits.py b/tests/imputation/brits.py
index 36efcc7c..2145286b 100644
--- a/tests/imputation/brits.py
+++ b/tests/imputation/brits.py
@@ -15,7 +15,7 @@
 from pypots.imputation import BRITS
 from pypots.optim import Adam
 from pypots.utils.logging import logger
-from pypots.utils.metrics import calc_mae
+from pypots.utils.metrics import calc_mse
 from tests.global_test_config import (
     DATA,
     EPOCHS,
@@ -62,10 +62,10 @@ def test_1_impute(self):
         assert not np.isnan(
             imputed_X
         ).any(), "Output still has missing values after running impute()."
-        test_MAE = calc_mae(
+        test_MSE = calc_mse(
             imputed_X, DATA["test_X_ori"], DATA["test_X_indicating_mask"]
         )
-        logger.info(f"BRITS test_MAE: {test_MAE}")
+        logger.info(f"BRITS test_MSE: {test_MSE}")

     @pytest.mark.xdist_group(name="imputation-brits")
     def test_2_parameters(self):
@@ -106,12 +106,12 @@ def test_4_lazy_loading(self):
             imputation_results["imputation"]
         ).any(), "Output still has missing values after running impute()."

-        test_MAE = calc_mae(
+        test_MSE = calc_mse(
             imputation_results["imputation"],
             DATA["test_X_ori"],
             DATA["test_X_indicating_mask"],
         )
-        logger.info(f"Lazy-loading BRITS test_MAE: {test_MAE}")
+        logger.info(f"Lazy-loading BRITS test_MSE: {test_MSE}")


 if __name__ == "__main__":
diff --git a/tests/imputation/csdi.py b/tests/imputation/csdi.py
index 3dc88602..ae2fa2a3 100644
--- a/tests/imputation/csdi.py
+++ b/tests/imputation/csdi.py
@@ -15,7 +15,7 @@
 from pypots.imputation import CSDI
 from pypots.optim import Adam
 from pypots.utils.logging import logger
-from pypots.utils.metrics import calc_mae, calc_quantile_crps
+from pypots.utils.metrics import calc_mse, calc_quantile_crps
 from tests.global_test_config import (
     DATA,
     EPOCHS,
@@ -63,7 +63,7 @@ def test_0_fit(self):

     @pytest.mark.xdist_group(name="imputation-csdi")
     def test_1_impute(self):
-        imputed_X = self.csdi.predict(TEST_SET)["imputation"]
+        imputed_X = self.csdi.predict(TEST_SET, n_sampling_times=2)["imputation"]
         test_CRPS = calc_quantile_crps(
             imputed_X, DATA["test_X_ori"], DATA["test_X_indicating_mask"]
         )
@@ -71,10 +71,10 @@ def test_1_impute(self):
         assert not np.isnan(
             imputed_X
         ).any(), "Output still has missing values after running impute()."
-        test_MAE = calc_mae(
+        test_MSE = calc_mse(
             imputed_X, DATA["test_X_ori"], DATA["test_X_indicating_mask"]
         )
-        logger.info(f"CSDI test_MAE: {test_MAE}, test_CRPS: {test_CRPS}")
+        logger.info(f"CSDI test_MSE: {test_MSE}, test_CRPS: {test_CRPS}")

     @pytest.mark.xdist_group(name="imputation-csdi")
     def test_2_parameters(self):
@@ -120,10 +120,10 @@ def test_4_lazy_loading(self):
             imputed_X
         ).any(), "Output still has missing values after running impute()."
-        test_MAE = calc_mae(
+        test_MSE = calc_mse(
             imputed_X, DATA["test_X_ori"], DATA["test_X_indicating_mask"]
         )
-        logger.info(f"Lazy-loading CSDI test_MAE: {test_MAE}, test_CRPS: {test_CRPS}")
+        logger.info(f"Lazy-loading CSDI test_MSE: {test_MSE}, test_CRPS: {test_CRPS}")


 if __name__ == "__main__":
diff --git a/tests/imputation/gpvae.py b/tests/imputation/gpvae.py
index 167e0dfc..b9ed017d 100644
--- a/tests/imputation/gpvae.py
+++ b/tests/imputation/gpvae.py
@@ -15,7 +15,7 @@
 from pypots.imputation import GPVAE
 from pypots.optim import Adam
 from pypots.utils.logging import logger
-from pypots.utils.metrics import calc_mae
+from pypots.utils.metrics import calc_mse
 from tests.global_test_config import (
     DATA,
     EPOCHS,
@@ -58,14 +58,15 @@ def test_0_fit(self):

     @pytest.mark.xdist_group(name="imputation-gpvae")
     def test_1_impute(self):
-        imputed_X = self.gp_vae.impute(TEST_SET)
+        imputed_X = self.gp_vae.predict(TEST_SET, n_sampling_times=2)["imputation"]
+        imputed_X = imputed_X.mean(axis=1)
         assert not np.isnan(
             imputed_X
         ).any(), "Output still has missing values after running impute()."
-        test_MAE = calc_mae(
+        test_MSE = calc_mse(
             imputed_X, DATA["test_X_ori"], DATA["test_X_indicating_mask"]
         )
-        logger.info(f"GP-VAE test_MAE: {test_MAE}")
+        logger.info(f"GP-VAE test_MSE: {test_MSE}")

     @pytest.mark.xdist_group(name="imputation-gpvae")
     def test_2_parameters(self):
@@ -101,17 +102,20 @@ def test_3_saving_path(self):

     @pytest.mark.xdist_group(name="imputation-gpvae")
     def test_4_lazy_loading(self):
         self.gp_vae.fit(H5_TRAIN_SET_PATH, H5_VAL_SET_PATH)
-        imputation_results = self.gp_vae.predict(H5_TEST_SET_PATH)
+        imputed_X = self.gp_vae.predict(H5_TEST_SET_PATH, n_sampling_times=2)[
+            "imputation"
+        ]
+        imputed_X = imputed_X.mean(axis=1)
         assert not np.isnan(
-            imputation_results["imputation"]
+            imputed_X
         ).any(), "Output still has missing values after running impute()."
-        test_MAE = calc_mae(
-            imputation_results["imputation"],
+        test_MSE = calc_mse(
+            imputed_X,
             DATA["test_X_ori"],
             DATA["test_X_indicating_mask"],
         )
-        logger.info(f"Lazy-loading GP-VAE test_MAE: {test_MAE}")
+        logger.info(f"Lazy-loading GP-VAE test_MSE: {test_MSE}")


 if __name__ == "__main__":
diff --git a/tests/imputation/locf.py b/tests/imputation/locf.py
index 626623b0..1111b262 100644
--- a/tests/imputation/locf.py
+++ b/tests/imputation/locf.py
@@ -14,7 +14,7 @@

 from pypots.imputation import LOCF
 from pypots.utils.logging import logger
-from pypots.utils.metrics import calc_mae
+from pypots.utils.metrics import calc_mse
 from tests.global_test_config import (
     DATA,
     DEVICE,
@@ -39,32 +39,32 @@ def test_0_impute(self):
         assert not np.isnan(
             test_X_imputed_zero
         ).any(), "Output still has missing values after running impute()."
-        test_MAE = calc_mae(
+        test_MSE = calc_mse(
             test_X_imputed_zero, DATA["test_X_ori"], DATA["test_X_indicating_mask"]
         )
-        logger.info(f"LOCF (zero) test_MAE: {test_MAE}")
+        logger.info(f"LOCF (zero) test_MSE: {test_MSE}")

         test_X_imputed_backward = self.locf_backward.predict(TEST_SET)["imputation"]
         assert not np.isnan(
             test_X_imputed_backward
         ).any(), "Output still has missing values after running impute()."
-        test_MAE = calc_mae(
+        test_MSE = calc_mse(
             test_X_imputed_backward,
             DATA["test_X_ori"],
             DATA["test_X_indicating_mask"],
         )
-        logger.info(f"LOCF (backward) test_MAE: {test_MAE}")
+        logger.info(f"LOCF (backward) test_MSE: {test_MSE}")

         test_X_imputed_mean = self.locf_mean.predict(TEST_SET)["imputation"]
         assert not np.isnan(
             test_X_imputed_mean
         ).any(), "Output still has missing values after running impute()."
-        test_MAE = calc_mae(
+        test_MSE = calc_mse(
             test_X_imputed_mean,
             DATA["test_X_ori"],
             DATA["test_X_indicating_mask"],
         )
-        logger.info(f"LOCF (mean) test_MAE: {test_MAE}")
+        logger.info(f"LOCF (mean) test_MSE: {test_MSE}")

         test_X_imputed_nan = self.locf_nan.predict(TEST_SET)["imputation"]
         num_of_missing = np.isnan(test_X_imputed_nan).sum()
@@ -82,30 +82,30 @@ def test_0_impute(self):
         assert not torch.isnan(
             test_X_imputed_zero
         ).any(), "Output still has missing values after running impute()."
-        test_MAE = calc_mae(test_X_imputed_zero, test_X_ori, test_X_indicating_mask)
-        logger.info(f"LOCF (zero) test_MAE: {test_MAE}")
+        test_MSE = calc_mse(test_X_imputed_zero, test_X_ori, test_X_indicating_mask)
+        logger.info(f"LOCF (zero) test_MSE: {test_MSE}")

         test_X_imputed_backward = self.locf_backward.predict({"X": X})["imputation"]
         assert not torch.isnan(
             test_X_imputed_backward
         ).any(), "Output still has missing values after running impute()."
-        test_MAE = calc_mae(
+        test_MSE = calc_mse(
             test_X_imputed_backward,
             test_X_ori,
             test_X_indicating_mask,
         )
-        logger.info(f"LOCF (backward) test_MAE: {test_MAE}")
+        logger.info(f"LOCF (backward) test_MSE: {test_MSE}")

         test_X_imputed_mean = self.locf_mean.predict({"X": X})["imputation"]
         assert not torch.isnan(
             test_X_imputed_mean
         ).any(), "Output still has missing values after running impute()."
-        test_MAE = calc_mae(
+        test_MSE = calc_mse(
             test_X_imputed_mean,
             test_X_ori,
             test_X_indicating_mask,
         )
-        logger.info(f"LOCF (mean) test_MAE: {test_MAE}")
+        logger.info(f"LOCF (mean) test_MSE: {test_MSE}")

         test_X_imputed_nan = self.locf_nan.predict({"X": X})["imputation"]
         num_of_missing = torch.isnan(test_X_imputed_nan).sum()
@@ -120,12 +120,12 @@ def test_4_lazy_loading(self):
             imputation_results["imputation"]
         ).any(), "Output still has missing values after running impute()."

-        test_MAE = calc_mae(
+        test_MSE = calc_mse(
             imputation_results["imputation"],
             DATA["test_X_ori"],
             DATA["test_X_indicating_mask"],
         )
-        logger.info(f"Lazy-loading LOCF test_MAE: {test_MAE}")
+        logger.info(f"Lazy-loading LOCF test_MSE: {test_MSE}")


 if __name__ == "__main__":
diff --git a/tests/imputation/mrnn.py b/tests/imputation/mrnn.py
index 3649d27b..d1d6003d 100644
--- a/tests/imputation/mrnn.py
+++ b/tests/imputation/mrnn.py
@@ -15,7 +15,7 @@
 from pypots.imputation import MRNN
 from pypots.optim import Adam
 from pypots.utils.logging import logger
-from pypots.utils.metrics import calc_mae
+from pypots.utils.metrics import calc_mse
 from tests.global_test_config import (
     DATA,
     EPOCHS,
@@ -62,10 +62,10 @@ def test_1_impute(self):
         assert not np.isnan(
             imputed_X
         ).any(), "Output still has missing values after running impute()."
-        test_MAE = calc_mae(
+        test_MSE = calc_mse(
             imputed_X, DATA["test_X_ori"], DATA["test_X_indicating_mask"]
         )
-        logger.info(f"MRNN test_MAE: {test_MAE}")
+        logger.info(f"MRNN test_MSE: {test_MSE}")

     @pytest.mark.xdist_group(name="imputation-mrnn")
     def test_2_parameters(self):
@@ -106,12 +106,12 @@ def test_4_lazy_loading(self):
             imputation_results["imputation"]
         ).any(), "Output still has missing values after running impute()."
-        test_MAE = calc_mae(
+        test_MSE = calc_mse(
             imputation_results["imputation"],
             DATA["test_X_ori"],
             DATA["test_X_indicating_mask"],
         )
-        logger.info(f"Lazy-loading MRNN test_MAE: {test_MAE}")
+        logger.info(f"Lazy-loading MRNN test_MSE: {test_MSE}")


 if __name__ == "__main__":
diff --git a/tests/imputation/saits.py b/tests/imputation/saits.py
index 33c6f6cd..842180f4 100644
--- a/tests/imputation/saits.py
+++ b/tests/imputation/saits.py
@@ -15,7 +15,7 @@
 from pypots.imputation import SAITS
 from pypots.optim import Adam
 from pypots.utils.logging import logger
-from pypots.utils.metrics import calc_mae
+from pypots.utils.metrics import calc_mse
 from tests.global_test_config import (
     DATA,
     EPOCHS,
@@ -72,12 +72,12 @@ def test_1_impute(self):
             "latent_vars" in imputation_results.keys()
         ), "Latent variables are not returned though `return_latent_vars` is set as True."

-        test_MAE = calc_mae(
+        test_MSE = calc_mse(
             imputation_results["imputation"],
             DATA["test_X_ori"],
             DATA["test_X_indicating_mask"],
         )
-        logger.info(f"SAITS test_MAE: {test_MAE}")
+        logger.info(f"SAITS test_MSE: {test_MSE}")

     @pytest.mark.xdist_group(name="imputation-saits")
     def test_2_parameters(self):
@@ -118,12 +118,12 @@ def test_4_lazy_loading(self):
             imputation_results["imputation"]
         ).any(), "Output still has missing values after running impute()."

-        test_MAE = calc_mae(
+        test_MSE = calc_mse(
             imputation_results["imputation"],
             DATA["test_X_ori"],
             DATA["test_X_indicating_mask"],
         )
-        logger.info(f"Lazy-loading SAITS test_MAE: {test_MAE}")
+        logger.info(f"Lazy-loading SAITS test_MSE: {test_MSE}")


 if __name__ == "__main__":
diff --git a/tests/imputation/timesnet.py b/tests/imputation/timesnet.py
index 34f9da03..af35ae08 100644
--- a/tests/imputation/timesnet.py
+++ b/tests/imputation/timesnet.py
@@ -15,7 +15,7 @@
 from pypots.imputation import TimesNet
 from pypots.optim import Adam
 from pypots.utils.logging import logger
-from pypots.utils.metrics import calc_mae
+from pypots.utils.metrics import calc_mse
 from tests.global_test_config import (
     DATA,
     EPOCHS,
@@ -68,12 +68,12 @@ def test_1_impute(self):
             imputation_results["imputation"]
         ).any(), "Output still has missing values after running impute()."

-        test_MAE = calc_mae(
+        test_MSE = calc_mse(
             imputation_results["imputation"],
             DATA["test_X_ori"],
             DATA["test_X_indicating_mask"],
         )
-        logger.info(f"TimesNet test_MAE: {test_MAE}")
+        logger.info(f"TimesNet test_MSE: {test_MSE}")

     @pytest.mark.xdist_group(name="imputation-timesnet")
     def test_2_parameters(self):
@@ -116,12 +116,12 @@ def test_4_lazy_loading(self):
             imputation_results["imputation"]
         ).any(), "Output still has missing values after running impute()."

-        test_MAE = calc_mae(
+        test_MSE = calc_mse(
             imputation_results["imputation"],
             DATA["test_X_ori"],
             DATA["test_X_indicating_mask"],
         )
-        logger.info(f"Lazy-loading TimesNet test_MAE: {test_MAE}")
+        logger.info(f"Lazy-loading TimesNet test_MSE: {test_MSE}")


 if __name__ == "__main__":
diff --git a/tests/imputation/transformer.py b/tests/imputation/transformer.py
index 88e02802..a4a62457 100644
--- a/tests/imputation/transformer.py
+++ b/tests/imputation/transformer.py
@@ -15,7 +15,7 @@
 from pypots.imputation import Transformer
 from pypots.optim import Adam
 from pypots.utils.logging import logger
-from pypots.utils.metrics import calc_mae
+from pypots.utils.metrics import calc_mse
 from tests.global_test_config import (
     DATA,
     EPOCHS,
@@ -68,10 +68,10 @@ def test_1_impute(self):
         assert not np.isnan(
             imputed_X
         ).any(), "Output still has missing values after running impute()."
-        test_MAE = calc_mae(
+        test_MSE = calc_mse(
             imputed_X, DATA["test_X_ori"], DATA["test_X_indicating_mask"]
         )
-        logger.info(f"Transformer test_MAE: {test_MAE}")
+        logger.info(f"Transformer test_MSE: {test_MSE}")

     @pytest.mark.xdist_group(name="imputation-transformer")
     def test_2_parameters(self):
@@ -115,12 +115,12 @@ def test_4_lazy_loading(self):
             imputation_results["imputation"]
         ).any(), "Output still has missing values after running impute()."

-        test_MAE = calc_mae(
+        test_MSE = calc_mse(
             imputation_results["imputation"],
             DATA["test_X_ori"],
             DATA["test_X_indicating_mask"],
         )
-        logger.info(f"Lazy-loading Transformer test_MAE: {test_MAE}")
+        logger.info(f"Lazy-loading Transformer test_MSE: {test_MSE}")


 if __name__ == "__main__":
diff --git a/tests/imputation/usgan.py b/tests/imputation/usgan.py
index af99b4e7..9ef33139 100644
--- a/tests/imputation/usgan.py
+++ b/tests/imputation/usgan.py
@@ -15,7 +15,7 @@
 from pypots.imputation import USGAN
 from pypots.optim import Adam
 from pypots.utils.logging import logger
-from pypots.utils.metrics import calc_mae
+from pypots.utils.metrics import calc_mse
 from tests.global_test_config import (
     DATA,
     EPOCHS,
@@ -64,10 +64,10 @@ def test_1_impute(self):
         assert not np.isnan(
             imputed_X
         ).any(), "Output still has missing values after running impute()."
-        test_MAE = calc_mae(
+        test_MSE = calc_mse(
             imputed_X, DATA["test_X_ori"], DATA["test_X_indicating_mask"]
         )
-        logger.info(f"US-GAN test_MAE: {test_MAE}")
+        logger.info(f"US-GAN test_MSE: {test_MSE}")

     @pytest.mark.xdist_group(name="imputation-usgan")
     def test_2_parameters(self):
@@ -109,12 +109,12 @@ def test_4_lazy_loading(self):
             imputation_results["imputation"]
         ).any(), "Output still has missing values after running impute()."

-        test_MAE = calc_mae(
+        test_MSE = calc_mse(
             imputation_results["imputation"],
             DATA["test_X_ori"],
             DATA["test_X_indicating_mask"],
         )
-        logger.info(f"Lazy-loading US-GAN test_MAE: {test_MAE}")
+        logger.info(f"Lazy-loading US-GAN test_MSE: {test_MSE}")


 if __name__ == "__main__":
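One behavioral consequence of this PR that the diff only shows indirectly: GP-VAE imputations now carry a sampling axis of shape [n_samples, n_sampling_times, n_steps, n_features] (see the reshape/permute added in `core.py`), so callers that want a single point estimate must collapse that axis themselves. A minimal sketch of the consuming pattern, reusing the names (`gp_vae`, `TEST_SET`, `DATA`) from the tests above:

```python
# Hedged sketch: consuming GP-VAE predictions after this change, as the
# updated tests do. Averaging over axis 1 collapses the sampling dimension.
from pypots.utils.metrics import calc_mse

results = gp_vae.predict(TEST_SET, n_sampling_times=2)
imputed_X = results["imputation"].mean(axis=1)  # [n_samples, n_steps, n_features]

test_MSE = calc_mse(imputed_X, DATA["test_X_ori"], DATA["test_X_indicating_mask"])
```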