Update GP-VAE (#277)
* docs: update docs;

* feat: add n_sampling_times for GPVAE predict();

* refactor: use MSE to replace MAE for testing results;
WenjieDu authored Dec 20, 2023
1 parent 88056a4 commit 022ee07
Showing 14 changed files with 320 additions and 110 deletions.
6 changes: 6 additions & 0 deletions pypots/imputation/brits/model.py
@@ -34,6 +34,12 @@ class BRITS(BaseNNImputer):
Parameters
----------
n_steps :
The number of time steps in the time-series data sample.
n_features :
The number of features in the time-series data sample.
rnn_hidden_size :
The size of the RNN hidden state, also the number of hidden units in the RNN cell.
201 changes: 196 additions & 5 deletions pypots/imputation/gpvae/model.py
@@ -10,32 +10,73 @@
# License: BSD-3-Clause


import os
from typing import Union, Optional

import numpy as np
import torch
from torch.utils.data import DataLoader

try:
import nni
except ImportError:
pass


from .data import DatasetForGPVAE
from .modules import _GPVAE
from ..base import BaseNNImputer
from ...data.checking import check_X_ori_in_val_set
from ...optim.adam import Adam
from ...optim.base import Optimizer
from ...utils.logging import logger
from ...utils.metrics import calc_mse


class GPVAE(BaseNNImputer):
"""The PyTorch implementation of the GPVAE model :cite:`fortuin2020GPVAEDeep`.
Parameters
----------
n_steps :
The number of time steps in the time-series data sample.
n_features :
The number of features in the time-series data sample.
latent_size : int
The feature dimension of the latent embedding.
encoder_sizes : tuple
The tuple of layer sizes in the encoder network.
decoder_sizes : tuple
The tuple of layer sizes in the decoder network.
beta : float
The weight of the KL divergence term in the ELBO.
M : int
The number of Monte Carlo samples for ELBO estimation during training.
K : int
The number of importance weights for the IWAE training loss.
kernel : str
The type of kernel function used in the Gaussian process prior,
chosen from ["cauchy", "diffusion", "rbf", "matern"].
sigma : float
The scale parameter of the kernel function.
length_scale : float
The length-scale parameter of the kernel function.
kernel_scales : int
The number of different length scales over the latent space dimensions.
window_size : int
The window size for the inference CNN.
batch_size : int
The batch size for training and evaluating the model.
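For orientation, a minimal instantiation sketch using the parameters documented above (all values below are hypothetical choices, not defaults):

from pypots.imputation import GPVAE

# a minimal sketch, assuming samples with 48 time steps and 35 features
model = GPVAE(
    n_steps=48,
    n_features=35,
    latent_size=16,
    kernel="cauchy",  # one of "cauchy", "diffusion", "rbf", "matern"
    batch_size=32,
    epochs=10,
)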
@@ -199,6 +240,124 @@ def _assemble_input_for_validating(self, data: list) -> dict:
def _assemble_input_for_testing(self, data: list) -> dict:
return self._assemble_input_for_training(data)

def _train_model(
self,
training_loader: DataLoader,
val_loader: DataLoader = None,
) -> None:
# each training starts from the very beginning, so reset the loss and model dict here
self.best_loss = float("inf")
self.best_model_dict = None

try:
training_step = 0
for epoch in range(self.epochs):
self.model.train()
epoch_train_loss_collector = []
for idx, data in enumerate(training_loader):
training_step += 1
inputs = self._assemble_input_for_training(data)
self.optimizer.zero_grad()
results = self.model.forward(inputs)
# use sum() before backward() in case of multi-gpu training
results["loss"].sum().backward()
self.optimizer.step()
epoch_train_loss_collector.append(results["loss"].sum().item())

# save training loss logs into the tensorboard file for every step if in need
if self.summary_writer is not None:
self._save_log_into_tb_file(training_step, "training", results)

# mean training loss of the current epoch
mean_train_loss = np.mean(epoch_train_loss_collector)

if val_loader is not None:
self.model.eval()
imputation_loss_collector = []
with torch.no_grad():
for idx, data in enumerate(val_loader):
inputs = self._assemble_input_for_validating(data)
results = self.model.forward(
inputs, training=False, n_sampling_times=1
)
imputed_data = results["imputed_data"].mean(axis=1)
imputation_mse = (
calc_mse(
imputed_data,
inputs["X_ori"],
inputs["indicating_mask"],
)
.sum()
.detach()
.item()
)
imputation_loss_collector.append(imputation_mse)

mean_val_loss = np.mean(imputation_loss_collector)

# save validating loss logs into the tensorboard file for every epoch if in need
if self.summary_writer is not None:
val_loss_dict = {
"imputation_loss": mean_val_loss,
}
self._save_log_into_tb_file(epoch, "validating", val_loss_dict)

logger.info(
f"Epoch {epoch} - "
f"training loss: {mean_train_loss:.4f}, "
f"validating loss: {mean_val_loss:.4f}"
)
mean_loss = mean_val_loss
else:
logger.info(f"Epoch {epoch} - training loss: {mean_train_loss:.4f}")
mean_loss = mean_train_loss

if np.isnan(mean_loss):
logger.warning(
f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors."
)

if mean_loss < self.best_loss:
self.best_loss = mean_loss
self.best_model_dict = self.model.state_dict()
self.patience = self.original_patience
# save the model if necessary
self._auto_save_model_if_necessary(
training_finished=False,
saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
)
else:
self.patience -= 1

if os.getenv("enable_tuning", False):
nni.report_intermediate_result(mean_loss)
if epoch == self.epochs - 1 or self.patience == 0:
nni.report_final_result(self.best_loss)

if self.patience == 0:
logger.info(
"Exceeded the training patience. Terminating the training procedure..."
)
break

except Exception as e:
logger.error(f"Exception: {e}")
if self.best_model_dict is None:
raise RuntimeError(
"Training got interrupted. Model was not trained. Please investigate the error printed above."
)
else:
logger.warning(
"Training got interrupted. Please investigate the error printed above.\n"
"Model got trained and will load the best checkpoint so far for testing.\n"
"If you don't want it, please try fit() again."
)

if np.isnan(self.best_loss):
raise ValueError("Something is wrong. best_loss is NaN after training.")

logger.info("Finished training.")

def fit(
self,
train_set: Union[dict, str],
@@ -240,8 +399,35 @@ def fit(
def predict(
self,
test_set: Union[dict, str],
file_type="h5py",
file_type: str = "h5py",
n_sampling_times: int = 1,
) -> dict:
"""
Parameters
----------
test_set : dict or str
The dataset for model validating, should be a dictionary including keys as 'X' and 'y',
or a path string locating a data file.
If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
which is time-series data for validating, can contain missing values, and y should be array-like of shape
[n_samples], which is classification labels of X.
If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
key-value pairs like a dict, and it has to include keys as 'X' and 'y'.
file_type : str
The type of the given file if test_set is a path string.
n_sampling_times:
The number of sampling times for the model to produce predictions.
Returns
-------
result_dict: dict
Prediction results in a Python Dictionary for the given samples.
It should be a dictionary including a key named 'imputation'.
"""
self.model.eval()  # set the model to evaluation mode to freeze it
test_set = DatasetForGPVAE(
test_set, return_X_ori=False, return_labels=False, file_type=file_type
@@ -257,7 +443,9 @@ def predict(
with torch.no_grad():
for idx, data in enumerate(test_loader):
inputs = self._assemble_input_for_testing(data)
results = self.model.forward(
inputs, training=False, n_sampling_times=n_sampling_times
)
imputed_data = results["imputed_data"]
imputation_collector.append(imputed_data)

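A hedged usage sketch of the new n_sampling_times argument (test_X is an assumed array of shape [n_samples, n_steps, n_features] containing NaNs, and 5 sampling times is an arbitrary choice):

# draw 5 imputations per sample, then average them into a point estimate
results = model.predict({"X": test_X}, n_sampling_times=5)
imputation = results["imputation"]  # shape: [n_samples, 5, n_steps, n_features]
point_estimate = imputation.mean(axis=1)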
@@ -271,6 +459,7 @@ def impute(
self,
X: Union[dict, str],
file_type="h5py",
n_sampling_times: int = 1,
) -> np.ndarray:
"""Impute missing values in the given data with the trained model.
@@ -295,5 +484,7 @@ def impute(
logger.warning(
"🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead."
)
results_dict = self.predict(
X, file_type=file_type, n_sampling_times=n_sampling_times
)
return results_dict["imputation"]
91 changes: 47 additions & 44 deletions pypots/imputation/gpvae/modules/core.py
@@ -156,54 +156,57 @@ def _init_prior(self, device="cpu"):
)
return prior

def forward(self, inputs, training=True, n_sampling_times=1):
x = inputs["X"]
n_samples, n_steps, n_features = x.shape
m_mask = inputs["missing_mask"]

# if in training mode, return results with losses
results = {}
if training:
x = x.repeat(self.K * self.M, 1, 1)
m_mask = m_mask.repeat(self.K * self.M, 1, 1).type(torch.bool)

if self.prior is None:
self.prior = self._init_prior(device=x.device)

qz_x = self.encode(x)
z = qz_x.rsample()
px_z = self.decode(z)
nll = -px_z.log_prob(x)
nll = torch.where(torch.isfinite(nll), nll, torch.zeros_like(nll))
if m_mask is not None:
nll = torch.where(m_mask, nll, torch.zeros_like(nll))
nll = nll.sum(dim=(1, 2))

if self.K > 1:
kl = qz_x.log_prob(z) - self.prior.log_prob(z)
kl = torch.where(torch.isfinite(kl), kl, torch.zeros_like(kl))
kl = kl.sum(1)

weights = -nll - kl
weights = torch.reshape(weights, [self.M, self.K, -1])

elbo = torch.logsumexp(weights, dim=1)
elbo = elbo.mean()
else:
kl = self.kl_divergence(qz_x, self.prior)
kl = torch.where(torch.isfinite(kl), kl, torch.zeros_like(kl))
kl = kl.sum(1)

elbo = -nll - self.beta * kl
elbo = elbo.mean()
results["loss"] = -elbo.mean()
else:
x = x.repeat(n_sampling_times, 1, 1)
m_mask = m_mask.repeat(n_sampling_times, 1, 1).type(torch.bool)
decode_x_mean = self.decode(self.encode(x).mean).mean
imputed_data = decode_x_mean * ~m_mask + x * m_mask
imputed_data = imputed_data.reshape(
n_sampling_times, n_samples, n_steps, n_features
).permute(1, 0, 2, 3)

results = {
"imputed_data": imputed_data,
}

return results
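To make the inference-branch reshape concrete, a small self-contained shape check mirroring the repeat/reshape/permute above (all sizes hypothetical):

import torch

n_samples, n_steps, n_features, n_sampling_times = 8, 48, 35, 3
x = torch.zeros(n_samples, n_steps, n_features)
x = x.repeat(n_sampling_times, 1, 1)  # [3 * 8, 48, 35]: batch stacked per sampling time
out = x.reshape(n_sampling_times, n_samples, n_steps, n_features).permute(1, 0, 2, 3)
assert out.shape == (n_samples, n_sampling_times, n_steps, n_features)  # [8, 3, 48, 35]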
6 changes: 6 additions & 0 deletions pypots/imputation/mrnn/model.py
@@ -28,6 +28,12 @@ class MRNN(BaseNNImputer):
Parameters
----------
n_steps :
The number of time steps in the time-series data sample.
n_features :
The number of features in the time-series data sample.
rnn_hidden_size :
The size of the RNN hidden state, also the number of hidden units in the RNN cell.
2 changes: 1 addition & 1 deletion tests/data/saving.py
@@ -48,4 +48,4 @@ def test_0_save_dict_into_h5(self):
def test_0_pickle_dump_load(self):
pickle_dump(self.data_to_save, self.pickle_saving_path)
loaded_data = pickle_load(self.pickle_saving_path)
assert (loaded_data["c"]["e"]["f"] == self.data_to_save["c"]["e"]["f"]).all()
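The reworked assertion sidesteps a NumPy pitfall: on dicts whose leaves are arrays, == compares those leaves element-wise, and the resulting array has no single truth value, so the dict comparison raises. A minimal sketch with hypothetical data:

import numpy as np

a = {"f": np.array([1, 2, 3])}
b = {"f": np.array([1, 2, 3])}
# `a == b` raises ValueError: the truth value of an array with more than one element is ambiguous
assert (a["f"] == b["f"]).all()  # element-wise comparison of the leaf works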
