diff --git a/README.md b/README.md index 7df93dfe..86e52e6f 100644 --- a/README.md +++ b/README.md @@ -203,6 +203,7 @@ This functionality is implemented with the [Microsoft NNI](https://github.com/mi | Neural Net | DLinear | Are Transformers Effective for Time Series Forecasting? [^17] | 2023 | | Neural Net | ETSformer | Exponential Smoothing Transformers for Time-series Forecasting [^19] | 2023 | | Neural Net | FEDformer | Frequency Enhanced Decomposed Transformer for Long-term Series Forecasting [^20] | 2022 | +| Neural Net | Informer | Beyond Efficient Transformer for Long Sequence Time-Series Forecasting [^21] | 2021 | | Neural Net | Autoformer | Decomposition Transformers with Auto-Correlation for Long-Term Series Forecasting [^15] | 2021 | | Neural Net | CSDI | Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation [^12] | 2021 | | Neural Net | US-GAN | Unsupervised GAN for Multivariate Time Series Imputation [^10] | 2021 | @@ -332,6 +333,8 @@ PyPOTS community is open, transparent, and surely friendly. Let's work together [^18]: Nie, Y., Nguyen, N. H., Sinthong, P., & Kalagnanam, J. (2023). [A time series is worth 64 words: Long-term forecasting with transformers](https://openreview.net/forum?id=Jbdc0vTOcol). *ICLR 2023* [^19]: Woo, G., Liu, C., Sahoo, D., Kumar, A., & Hoi, S. (2023). [ETSformer: Exponential Smoothing Transformers for Time-series Forecasting](https://openreview.net/forum?id=5m_3whfo483). *ICLR 2023* [^20]: Zhou, T., Ma, Z., Wen, Q., Wang, X., Sun, L., & Jin, R. (2022). [FEDformer: Frequency enhanced decomposed transformer for long-term series forecasting](https://proceedings.mlr.press/v162/zhou22g.html). *ICML 2022*. +[^21]: Zhou, H., Zhang, S., Peng, J., Zhang, S., Li, J., Xiong, H., & Zhang, W. (2021). [Informer: Beyond efficient transformer for long sequence time-series forecasting](https://ojs.aaai.org/index.php/AAAI/article/view/17325). *AAAI 2021*. +
diff --git a/docs/index.rst b/docs/index.rst index e600ebb2..265c97b1 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -181,6 +181,7 @@ Imputation Neural Net PatchTST (A Time Series is Worth Imputation Neural Net DLinear (Are transformers effective for time series forecasting?) 2023 :cite:`zeng2023dlinear` Imputation Neural Net ETSformer (Exponential Smoothing Transformers for Time-series Forecasting) 2023 :cite:`woo2023etsformer` Imputation Neural Net FEDformer (Frequency Enhanced Decomposed Transformer for Long-term Series Forecasting) 2022 :cite:`zhou2022fedformer` +Imputation Neural Net Informer (Beyond Efficient Transformer for Long Sequence Time-Series Forecasting) 2021 :cite:`zhou2021informer` Imputation Neural Net Autoformer (Decomposition Transformers with Auto-Correlation for Long-Term Series Forecasting) 2021 :cite:`wu2021autoformer` Imputation Neural Net US-GAN (Unsupervised GAN for Multivariate Time Series Imputation) 2021 :cite:`miao2021SSGAN` Imputation Neural Net CSDI (Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation) 2021 :cite:`tashiro2021csdi` diff --git a/docs/pypots.imputation.rst b/docs/pypots.imputation.rst index cd4d156a..1893591b 100644 --- a/docs/pypots.imputation.rst +++ b/docs/pypots.imputation.rst @@ -73,6 +73,15 @@ pypots.imputation.fedformer :show-inheritance: :inherited-members: +pypots.imputation.informer +------------------------------ + +.. automodule:: pypots.imputation.informer + :members: + :undoc-members: + :show-inheritance: + :inherited-members: + pypots.imputation.autoformer ------------------------------ diff --git a/docs/references.bib b/docs/references.bib index e4624d06..0d6eeabd 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -544,3 +544,13 @@ @inproceedings{zhou2022fedformer pdf = {https://proceedings.mlr.press/v162/zhou22g/zhou22g.pdf}, url = {https://proceedings.mlr.press/v162/zhou22g.html}, } + +@inproceedings{zhou2021informer, +title={Informer: Beyond efficient transformer for long sequence time-series forecasting}, +author={Zhou, Haoyi and Zhang, Shanghang and Peng, Jieqi and Zhang, Shuai and Li, Jianxin and Xiong, Hui and Zhang, Wancai}, +booktitle={Proceedings of the AAAI conference on artificial intelligence}, +volume={35}, +number={12}, +pages={11106--11115}, +year={2021} +} diff --git a/pypots/imputation/__init__.py b/pypots/imputation/__init__.py index a7052dcc..0d4e2184 100644 --- a/pypots/imputation/__init__.py +++ b/pypots/imputation/__init__.py @@ -16,6 +16,7 @@ from .etsformer import ETSformer from .fedformer import FEDformer from .crossformer import Crossformer +from .informer import Informer from .autoformer import Autoformer from .dlinear import DLinear from .patchtst import PatchTST @@ -36,6 +37,7 @@ "TimesNet", "PatchTST", "DLinear", + "Informer", "Autoformer", "BRITS", "MRNN", diff --git a/pypots/imputation/autoformer/modules/core.py b/pypots/imputation/autoformer/modules/core.py index d2d19f23..c3747fde 100644 --- a/pypots/imputation/autoformer/modules/core.py +++ b/pypots/imputation/autoformer/modules/core.py @@ -10,10 +10,10 @@ from .submodules import ( SeasonalLayerNorm, AutoformerEncoderLayer, - AutoformerEncoder, AutoCorrelation, AutoCorrelationLayer, ) +from ...informer.modules.submodules import InformerEncoder from ....nn.modules.transformer.embedding import DataEmbedding from ....utils.metrics import calc_mse @@ -43,7 +43,7 @@ def __init__( dropout=dropout, with_pos=False, ) - self.encoder = AutoformerEncoder( + self.encoder = InformerEncoder( [ AutoformerEncoderLayer( AutoCorrelationLayer( diff --git a/pypots/imputation/autoformer/modules/submodules.py b/pypots/imputation/autoformer/modules/submodules.py index 40665d05..6eb3d9e2 100644 --- a/pypots/imputation/autoformer/modules/submodules.py +++ b/pypots/imputation/autoformer/modules/submodules.py @@ -285,35 +285,6 @@ def forward(self, x, attn_mask=None): return res, attn -class AutoformerEncoder(nn.Module): - def __init__(self, attn_layers, conv_layers=None, norm_layer=None): - super().__init__() - self.attn_layers = nn.ModuleList(attn_layers) - self.conv_layers = ( - nn.ModuleList(conv_layers) if conv_layers is not None else None - ) - self.norm = norm_layer - - def forward(self, x, attn_mask=None): - attns = [] - if self.conv_layers is not None: - for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers): - x, attn = attn_layer(x, attn_mask=attn_mask) - x = conv_layer(x) - attns.append(attn) - x, attn = self.attn_layers[-1](x) - attns.append(attn) - else: - for attn_layer in self.attn_layers: - x, attn = attn_layer(x, attn_mask=attn_mask) - attns.append(attn) - - if self.norm is not None: - x = self.norm(x) - - return x, attns - - class AutoformerDecoderLayer(nn.Module): """ Autoformer decoder layer with the progressive decomposition architecture @@ -372,23 +343,3 @@ def forward(self, x, cross, x_mask=None, cross_mask=None): 1, 2 ) return x, residual_trend - - -class AutoformerDecoder(nn.Module): - def __init__(self, layers, norm_layer=None, projection=None): - super().__init__() - self.layers = nn.ModuleList(layers) - self.norm = norm_layer - self.projection = projection - - def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None): - for layer in self.layers: - x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask) - trend = trend + residual_trend - - if self.norm is not None: - x = self.norm(x) - - if self.projection is not None: - x = self.projection(x) - return x, trend diff --git a/pypots/imputation/crossformer/modules/submodules.py b/pypots/imputation/crossformer/modules/submodules.py index 0df19b81..2a67a227 100644 --- a/pypots/imputation/crossformer/modules/submodules.py +++ b/pypots/imputation/crossformer/modules/submodules.py @@ -9,7 +9,7 @@ import torch.nn as nn from einops import rearrange, repeat -from ....nn.modules.transformer import MultiHeadAttention +from ....nn.modules.transformer import ScaledDotProductAttention, MultiHeadAttention class TwoStageAttentionLayer(nn.Module): @@ -33,10 +33,26 @@ def __init__( super().__init__() d_ff = 4 * d_model if d_ff is None else d_ff self.time_attention = MultiHeadAttention( - n_heads, d_model, d_k, d_v, attn_dropout + n_heads, + d_model, + d_k, + d_v, + ScaledDotProductAttention(d_k**0.5, attn_dropout), + ) + self.dim_sender = MultiHeadAttention( + n_heads, + d_model, + d_k, + d_v, + ScaledDotProductAttention(d_k**0.5, attn_dropout), + ) + self.dim_receiver = MultiHeadAttention( + n_heads, + d_model, + d_k, + d_v, + ScaledDotProductAttention(d_k**0.5, attn_dropout), ) - self.dim_sender = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout) - self.dim_receiver = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout) self.router = nn.Parameter(torch.randn(seg_num, factor, d_model)) self.dropout = nn.Dropout(dropout) diff --git a/pypots/imputation/fedformer/modules/core.py b/pypots/imputation/fedformer/modules/core.py index 0be8b14f..895cf8d4 100644 --- a/pypots/imputation/fedformer/modules/core.py +++ b/pypots/imputation/fedformer/modules/core.py @@ -9,11 +9,11 @@ from .submodules import MultiWaveletTransform, FourierBlock from ...autoformer.modules.submodules import ( - AutoformerEncoder, AutoformerEncoderLayer, AutoCorrelationLayer, SeasonalLayerNorm, ) +from ...informer.modules.submodules import InformerEncoder from ....nn.modules.transformer.embedding import DataEmbedding from ....utils.metrics import calc_mse @@ -57,7 +57,7 @@ def __init__( f"Unsupported version: {version}. Please choose from ['Wavelets', 'Fourier']." ) - self.encoder = AutoformerEncoder( + self.encoder = InformerEncoder( [ AutoformerEncoderLayer( AutoCorrelationLayer( diff --git a/pypots/imputation/informer/__init__.py b/pypots/imputation/informer/__init__.py new file mode 100644 index 00000000..df52eeb2 --- /dev/null +++ b/pypots/imputation/informer/__init__.py @@ -0,0 +1,17 @@ +""" +The package of the partially-observed time-series imputation model Informer. + +Refer to the paper "Wu, H., Xu, J., Wang, J., & Long, M. (2021). +Informer: Decomposition transformers with auto-correlation for long-term series forecasting. NeurIPS 2021.". + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +from .model import Informer + +__all__ = [ + "Informer", +] diff --git a/pypots/imputation/informer/data.py b/pypots/imputation/informer/data.py new file mode 100644 index 00000000..bf6a146d --- /dev/null +++ b/pypots/imputation/informer/data.py @@ -0,0 +1,24 @@ +""" +Dataset class for Informer. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union + +from ..saits.data import DatasetForSAITS + + +class DatasetForInformer(DatasetForSAITS): + """Actually Informer uses the same data strategy as SAITS, needs MIT for training.""" + + def __init__( + self, + data: Union[dict, str], + return_X_ori: bool, + return_labels: bool, + file_type: str = "h5py", + rate: float = 0.2, + ): + super().__init__(data, return_X_ori, return_labels, file_type, rate) diff --git a/pypots/imputation/informer/model.py b/pypots/imputation/informer/model.py new file mode 100644 index 00000000..a1e70da8 --- /dev/null +++ b/pypots/imputation/informer/model.py @@ -0,0 +1,323 @@ +""" +The implementation of Informer for the partially-observed time-series imputation task. + +Refer to the paper "Zhou, H., Zhang, S., Peng, J., Zhang, S., Li, J., Xiong, H., & Zhang, W. (2021). +Informer: Beyond efficient transformer for long sequence time-series forecasting. AAAI 2021". + +Notes +----- +Partial implementation uses code from https://github.com/zhouhaoyi/Informer2020 + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union, Optional + +import numpy as np +import torch +from torch.utils.data import DataLoader + +from .data import DatasetForInformer +from .modules.core import _Informer +from ..base import BaseNNImputer +from ...data.base import BaseDataset +from ...data.checking import check_X_ori_in_val_set +from ...optim.adam import Adam +from ...optim.base import Optimizer +from ...utils.logging import logger + + +class Informer(BaseNNImputer): + """The PyTorch implementation of the Informer model. + Informer is originally proposed by Wu et al. in :cite:`zhou2021informer`. + + Parameters + ---------- + n_steps : + The number of time steps in the time-series data sample. + + n_features : + The number of features in the time-series data sample. + + n_layers : + The number of layers in the Informer model. + + n_heads : + The number of heads in each layer of Informer. + + d_model : + The dimension of the model. + + d_ffn : + The dimension of the feed-forward network. + + factor : + The factor of the auto correlation mechanism for the Informer model. + + moving_avg_window_size : + The window size of moving average. + + dropout : + The dropout rate for the model. + + batch_size : + The batch size for training and evaluating the model. + + epochs : + The number of epochs for training the model. + + patience : + The patience for the early-stopping mechanism. Given a positive integer, the training process will be + stopped when the model does not perform better after that number of epochs. + Leaving it default as None will disable the early-stopping. + + optimizer : + The optimizer for model training. + If not given, will use a default Adam optimizer. + + num_workers : + The number of subprocesses to use for data loading. + `0` means data loading will be in the main process, i.e. there won't be subprocesses. + + device : + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. + If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), + then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). + Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. + + saving_path : + The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during + training into a tensorboard file). Will not save if not given. + + model_saving_strategy : + The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"]. + No model will be saved when it is set as None. + The "best" strategy will only automatically save the best model after the training finished. + The "better" strategy will automatically save the model during training whenever the model performs + better than in previous epochs. + The "all" strategy will save every model after each epoch training. + + References + ---------- + .. [1] `Zhou, Haoyi, Shanghang Zhang, Jieqi Peng, Shuai Zhang, Jianxin Li, Hui Xiong, and Wancai Zhang. + "Informer: Beyond efficient transformer for long sequence time-series forecasting." + In Proceedings of the AAAI conference on artificial intelligence, vol. 35, no. 12, pp. 11106-11115. 2021. + `_ + + """ + + def __init__( + self, + n_steps: int, + n_features: int, + n_layers: int, + n_heads: int, + d_model: int, + d_ffn: int, + factor: int, + dropout: float = 0, + batch_size: int = 32, + epochs: int = 100, + patience: int = None, + optimizer: Optional[Optimizer] = Adam(), + num_workers: int = 0, + device: Optional[Union[str, torch.device, list]] = None, + saving_path: str = None, + model_saving_strategy: Optional[str] = "best", + ): + super().__init__( + batch_size, + epochs, + patience, + num_workers, + device, + saving_path, + model_saving_strategy, + ) + + self.n_steps = n_steps + self.n_features = n_features + # model hype-parameters + self.n_heads = n_heads + self.n_layers = n_layers + self.d_model = d_model + self.d_ffn = d_ffn + self.factor = factor + self.dropout = dropout + + # set up the model + self.model = _Informer( + self.n_steps, + self.n_features, + self.n_layers, + self.n_heads, + self.d_model, + self.d_ffn, + self.factor, + self.dropout, + ) + self._send_model_to_given_device() + self._print_model_size() + + # set up the optimizer + self.optimizer = optimizer + self.optimizer.init_optimizer(self.model.parameters()) + + def _assemble_input_for_training(self, data: list) -> dict: + ( + indices, + X, + missing_mask, + X_ori, + indicating_mask, + ) = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + "X_ori": X_ori, + "indicating_mask": indicating_mask, + } + + return inputs + + def _assemble_input_for_validating(self, data: list) -> dict: + return self._assemble_input_for_training(data) + + def _assemble_input_for_testing(self, data: list) -> dict: + indices, X, missing_mask = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + } + + return inputs + + def fit( + self, + train_set: Union[dict, str], + val_set: Optional[Union[dict, str]] = None, + file_type: str = "h5py", + ) -> None: + # Step 1: wrap the input data with classes Dataset and DataLoader + training_set = DatasetForInformer( + train_set, return_X_ori=False, return_labels=False, file_type=file_type + ) + training_loader = DataLoader( + training_set, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + ) + val_loader = None + if val_set is not None: + if not check_X_ori_in_val_set(val_set): + raise ValueError("val_set must contain 'X_ori' for model validation.") + val_set = DatasetForInformer( + val_set, return_X_ori=True, return_labels=False, file_type=file_type + ) + val_loader = DataLoader( + val_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + + # Step 2: train the model and freeze it + self._train_model(training_loader, val_loader) + self.model.load_state_dict(self.best_model_dict) + self.model.eval() # set the model as eval status to freeze it. + + # Step 3: save the model if necessary + self._auto_save_model_if_necessary(confirm_saving=True) + + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: + """Make predictions for the input data with the trained model. + + Parameters + ---------- + test_set : dict or str + The dataset for model validating, should be a dictionary including keys as 'X', + or a path string locating a data file supported by PyPOTS (e.g. h5 file). + If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features], + which is time-series data for validating, can contain missing values, and y should be array-like of shape + [n_samples], which is classification labels of X. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include keys as 'X' and 'y'. + + file_type : str + The type of the given file if test_set is a path string. + + Returns + ------- + result_dict : dict, + The dictionary containing the clustering results and latent variables if necessary. + + """ + # Step 1: wrap the input data with classes Dataset and DataLoader + self.model.eval() # set the model as eval status to freeze it. + test_set = BaseDataset( + test_set, return_X_ori=False, return_labels=False, file_type=file_type + ) + test_loader = DataLoader( + test_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + imputation_collector = [] + + # Step 2: process the data with the model + with torch.no_grad(): + for idx, data in enumerate(test_loader): + inputs = self._assemble_input_for_testing(data) + results = self.model.forward(inputs, training=False) + imputation_collector.append(results["imputed_data"]) + + # Step 3: output collection and return + imputation = torch.cat(imputation_collector).cpu().detach().numpy() + result_dict = { + "imputation": imputation, + } + return result_dict + + def impute( + self, + X: Union[dict, str], + file_type="h5py", + ) -> np.ndarray: + """Impute missing values in the given data with the trained model. + + Warnings + -------- + The method impute is deprecated. Please use `predict()` instead. + + Parameters + ---------- + X : + The data samples for testing, should be array-like of shape [n_samples, sequence length (time steps), + n_features], or a path string locating a data file, e.g. h5 file. + + file_type : + The type of the given file if X is a path string. + + Returns + ------- + array-like, shape [n_samples, sequence length (time steps), n_features], + Imputed data. + """ + logger.warning( + "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." + ) + + results_dict = self.predict(X, file_type=file_type) + return results_dict["imputation"] diff --git a/pypots/imputation/informer/modules/__init__.py b/pypots/imputation/informer/modules/__init__.py new file mode 100644 index 00000000..ceaa7ee3 --- /dev/null +++ b/pypots/imputation/informer/modules/__init__.py @@ -0,0 +1,6 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause diff --git a/pypots/imputation/informer/modules/core.py b/pypots/imputation/informer/modules/core.py new file mode 100644 index 00000000..455a7b1a --- /dev/null +++ b/pypots/imputation/informer/modules/core.py @@ -0,0 +1,86 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import torch.nn as nn + +from .submodules import ProbAttention, ConvLayer, InformerEncoderLayer, InformerEncoder +from ....nn.modules.transformer.embedding import DataEmbedding +from ....nn.modules.transformer import MultiHeadAttention +from ....utils.metrics import calc_mse + + +class _Informer(nn.Module): + def __init__( + self, + n_steps, + n_features, + n_layers, + n_heads, + d_model, + d_ffn, + factor, + dropout, + distil=False, + activation="relu", + output_attention=False, + ): + super().__init__() + + self.seq_len = n_steps + self.n_layers = n_layers + self.enc_embedding = DataEmbedding( + n_features, + d_model, + dropout=dropout, + ) + self.encoder = InformerEncoder( + [ + InformerEncoderLayer( + MultiHeadAttention( + n_heads, + d_model, + d_model // n_heads, + d_model // n_heads, + ProbAttention(False, factor, dropout, output_attention), + ), + d_model, + d_ffn, + dropout, + activation, + ) + for _ in range(n_layers) + ], + [ConvLayer(d_model) for _ in range(n_layers - 1)] if distil else None, + norm_layer=nn.LayerNorm(d_model), + ) + + # for the imputation task, the output dim is the same as input dim + self.projection = nn.Linear(d_model, n_features) + + def forward(self, inputs: dict, training: bool = True) -> dict: + X, masks = inputs["X"], inputs["missing_mask"] + + # embedding + enc_out = self.enc_embedding(X) + + # Informer encoder processing + enc_out, attns = self.encoder(enc_out) + + # project back the original data space + dec_out = self.projection(enc_out) + + imputed_data = masks * X + (1 - masks) * dec_out + results = { + "imputed_data": imputed_data, + } + + if training: + # `loss` is always the item for backward propagating to update the model + loss = calc_mse(dec_out, inputs["X_ori"], inputs["indicating_mask"]) + results["loss"] = loss + + return results diff --git a/pypots/imputation/informer/modules/submodules.py b/pypots/imputation/informer/modules/submodules.py new file mode 100644 index 00000000..8465feeb --- /dev/null +++ b/pypots/imputation/informer/modules/submodules.py @@ -0,0 +1,240 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from math import sqrt +from typing import Optional + +import numpy as np +import torch +import torch.fft +import torch.nn as nn +import torch.nn.functional as F + +from ....nn.modules.transformer.attention import AttentionOperator + + +class ProbMask: + def __init__(self, B, H, L, index, scores, device="cpu"): + _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1) + _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1]) + indicator = _mask_ex[ + torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, : + ].to(device) + self._mask = indicator.view(scores.shape).to(device) + + @property + def mask(self): + return self._mask + + +class ConvLayer(nn.Module): + def __init__(self, c_in): + super().__init__() + padding = 1 if torch.__version__ >= "1.5.0" else 2 + self.downConv = nn.Conv1d( + in_channels=c_in, + out_channels=c_in, + kernel_size=3, + padding=padding, + padding_mode="circular", + ) + self.norm = nn.BatchNorm1d(c_in) + self.activation = nn.ELU() + self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) + + def forward(self, x): + x = self.downConv(x.permute(0, 2, 1)) + x = self.norm(x) + x = self.activation(x) + x = self.maxPool(x) + x = x.transpose(1, 2) + return x + + +class ProbAttention(AttentionOperator): + def __init__( + self, + mask_flag=True, + factor=5, + scale=None, + attention_dropout=0.1, + output_attention=False, + ): + super().__init__() + self.factor = factor + self.scale = scale + self.mask_flag = mask_flag + self.output_attention = output_attention + self.dropout = nn.Dropout(attention_dropout) + + def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q) + # Q [B, H, L, D] + B, H, L_K, E = K.shape + _, _, L_Q, _ = Q.shape + + # calculate the sampled Q_K + K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E) + index_sample = torch.randint( + L_K, (L_Q, sample_k) + ) # real U = U_part(factor*ln(L_k))*L_q + K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :] + Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze( + -2 + ) + + # find the Top_k query with sparisty measurement + M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K) + M_top = M.topk(n_top, sorted=False)[1] + + # use the reduced Q to calculate Q_K + Q_reduce = Q[ + torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, : + ] # factor*ln(L_q) + Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k + + return Q_K, M_top + + def _get_initial_context(self, V, L_Q): + B, H, L_V, D = V.shape + if not self.mask_flag: + # V_sum = V.sum(dim=-2) + V_sum = V.mean(dim=-2) + contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone() + else: # use mask + assert L_Q == L_V # requires that L_Q == L_V, i.e. for self-attention only + contex = V.cumsum(dim=-2) + return contex + + def _update_context(self, context_in, V, scores, index, L_Q, attn_mask): + B, H, L_V, D = V.shape + + if self.mask_flag: + attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device) + scores.masked_fill_(attn_mask.mask, -np.inf) + + attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores) + + context_in[ + torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, : + ] = torch.matmul(attn, V).type_as(context_in) + if self.output_attention: + attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn).to(attn.device) + attns[ + torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, : + ] = attn + return (context_in, attns) + else: + return (context_in, None) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + **kwargs, + ): + + B, L_Q, H, D = q.shape + _, L_K, _, _ = k.shape + + q = q.transpose(2, 1) + k = k.transpose(2, 1) + v = v.transpose(2, 1) + + U_part = self.factor * np.ceil(np.log(L_K)).astype("int").item() # c*ln(L_k) + u = self.factor * np.ceil(np.log(L_Q)).astype("int").item() # c*ln(L_q) + + U_part = U_part if U_part < L_K else L_K + u = u if u < L_Q else L_Q + + scores_top, index = self._prob_QK(q, k, sample_k=U_part, n_top=u) + + # add scale factor + scale = self.scale or 1.0 / sqrt(D) + if scale is not None: + scores_top = scores_top * scale + # get the context + context = self._get_initial_context(v, L_Q) + # update the context with selected top_k queries + context, attn = self._update_context( + context, v, scores_top, index, L_Q, attn_mask + ) + + return context.transpose(2, 1).contiguous(), attn + + +class InformerEncoderLayer(nn.Module): + def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"): + super().__init__() + d_ff = d_ff or 4 * d_model + self.attention = attention + self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) + self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + self.activation = F.relu if activation == "relu" else F.gelu + + def forward(self, x, attn_mask=None): + new_x, attn = self.attention(x, x, x, attn_mask=attn_mask) + x = x + self.dropout(new_x) + + y = x = self.norm1(x) + y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) + y = self.dropout(self.conv2(y).transpose(-1, 1)) + + return self.norm2(x + y), attn + + +class InformerEncoder(nn.Module): + def __init__(self, attn_layers, conv_layers=None, norm_layer=None): + super().__init__() + self.attn_layers = nn.ModuleList(attn_layers) + self.conv_layers = ( + nn.ModuleList(conv_layers) if conv_layers is not None else None + ) + self.norm = norm_layer + + def forward(self, x, attn_mask=None): + attns = [] + if self.conv_layers is not None: + for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers): + x, attn = attn_layer(x, attn_mask=attn_mask) + x = conv_layer(x) + attns.append(attn) + x, attn = self.attn_layers[-1](x) + attns.append(attn) + else: + for attn_layer in self.attn_layers: + x, attn = attn_layer(x, attn_mask=attn_mask) + attns.append(attn) + + if self.norm is not None: + x = self.norm(x) + + return x, attns + + +class InformerDecoder(nn.Module): + def __init__(self, layers, norm_layer=None, projection=None): + super().__init__() + self.layers = nn.ModuleList(layers) + self.norm = norm_layer + self.projection = projection + + def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None): + for layer in self.layers: + x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask) + trend = trend + residual_trend + + if self.norm is not None: + x = self.norm(x) + + if self.projection is not None: + x = self.projection(x) + return x, trend diff --git a/pypots/imputation/patchtst/modules/core.py b/pypots/imputation/patchtst/modules/core.py index 1ba4206d..9013a802 100644 --- a/pypots/imputation/patchtst/modules/core.py +++ b/pypots/imputation/patchtst/modules/core.py @@ -8,6 +8,7 @@ import torch.nn as nn from .submodules import PatchEmbedding, FlattenHead +from ....nn.modules.transformer.attention import ScaledDotProductAttention from ....nn.modules.transformer.auto_encoder import EncoderLayer from ....utils.metrics import calc_mse @@ -49,8 +50,8 @@ def __init__( n_heads, d_k, d_v, + ScaledDotProductAttention(d_k**0.5, attn_dropout), dropout, - attn_dropout, ) for _ in range(n_layers) ] diff --git a/pypots/imputation/saits/modules/core.py b/pypots/imputation/saits/modules/core.py index b0a4f1c3..4976c594 100644 --- a/pypots/imputation/saits/modules/core.py +++ b/pypots/imputation/saits/modules/core.py @@ -20,6 +20,7 @@ import torch.nn.functional as F from ....nn.modules.transformer import EncoderLayer, PositionalEncoding +from ....nn.modules.transformer.attention import ScaledDotProductAttention from ....utils.metrics import calc_mae @@ -59,8 +60,8 @@ def __init__( n_heads, d_k, d_v, + ScaledDotProductAttention(d_k**0.5, attn_dropout), dropout, - attn_dropout, ) for _ in range(n_layers) ] @@ -73,8 +74,8 @@ def __init__( n_heads, d_k, d_v, + ScaledDotProductAttention(d_k**0.5, attn_dropout), dropout, - attn_dropout, ) for _ in range(n_layers) ] diff --git a/pypots/imputation/transformer/modules/core.py b/pypots/imputation/transformer/modules/core.py index 066b7790..f4cfb841 100644 --- a/pypots/imputation/transformer/modules/core.py +++ b/pypots/imputation/transformer/modules/core.py @@ -19,6 +19,7 @@ import torch.nn as nn from ....nn.modules.transformer import EncoderLayer, PositionalEncoding +from ....nn.modules.transformer.attention import ScaledDotProductAttention from ....utils.metrics import calc_mae @@ -52,8 +53,8 @@ def __init__( n_heads, d_k, d_v, + ScaledDotProductAttention(d_k**0.5, attn_dropout), dropout, - attn_dropout, ) for _ in range(n_layers) ] diff --git a/pypots/modules/__init__.py b/pypots/modules/__init__.py deleted file mode 100644 index 638464fe..00000000 --- a/pypots/modules/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -Everything used to be in this package has been moved to pypots.nn.modules. -This package is kept for backward compatibility and will be removed in the future. -""" - -# Created by Wenjie Du -# License: BSD-3-Clause - -from ..utils.logging import logger - -logger.warning( - "🚨 pypots.modules package has been moved to pypots.nn.modules. " - "Please import everything from pypots.nn.modules instead." -) diff --git a/pypots/nn/modules/transformer/attention.py b/pypots/nn/modules/transformer/attention.py index 89684473..1c23efd8 100644 --- a/pypots/nn/modules/transformer/attention.py +++ b/pypots/nn/modules/transformer/attention.py @@ -16,9 +16,30 @@ import torch import torch.nn as nn import torch.nn.functional as F +from abc import abstractmethod -class ScaledDotProductAttention(nn.Module): +class AttentionOperator(nn.Module): + """ + The abstract class for all attention layers. + """ + + def __init__(self): + super().__init__() + + @abstractmethod + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError + + +class ScaledDotProductAttention(AttentionOperator): """Scaled dot-product attention. Parameters @@ -44,6 +65,7 @@ def forward( k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, + **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: """Forward processing of the scaled dot-product attention. @@ -51,8 +73,10 @@ def forward( ---------- q: Query tensor. + k: Key tensor. + v: Value tensor. @@ -106,11 +130,8 @@ class MultiHeadAttention(nn.Module): d_v: The dimension of the value tensor. - attn_dropout: - The dropout rate for the attention map. - - attn_temperature: - The temperature for scaling. Default is None, which means d_k**0.5 will be applied. + attention_operator: + The attention operator, e.g. the self-attention proposed in Transformer. """ @@ -120,13 +141,10 @@ def __init__( d_model: int, d_k: int, d_v: int, - attn_dropout: float, - attn_temperature: float = None, + attention_operator: AttentionOperator, ): super().__init__() - attn_temperature = d_k**0.5 if attn_temperature is None else attn_temperature - self.n_heads = n_heads self.d_k = d_k self.d_v = d_v @@ -135,7 +153,7 @@ def __init__( self.w_ks = nn.Linear(d_model, n_heads * d_k, bias=False) self.w_vs = nn.Linear(d_model, n_heads * d_v, bias=False) - self.attention = ScaledDotProductAttention(attn_temperature, attn_dropout) + self.attention_operator = attention_operator self.fc = nn.Linear(n_heads * d_v, d_model, bias=False) def forward( @@ -144,6 +162,7 @@ def forward( k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor], + **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: """Forward processing of the multi-head attention module. @@ -189,7 +208,7 @@ def forward( # broadcasting on the head axis attn_mask = attn_mask.unsqueeze(1) - v, attn_weights = self.attention(q, k, v, attn_mask) + v, attn_weights = self.attention_operator(q, k, v, attn_mask, **kwargs) # transpose back -> [batch_size, n_steps, n_heads, d_v] # then merge the last two dimensions to combine all the heads -> [batch_size, n_steps, n_heads*d_v] diff --git a/pypots/nn/modules/transformer/auto_encoder.py b/pypots/nn/modules/transformer/auto_encoder.py index 76761ce3..6aa6e1f2 100644 --- a/pypots/nn/modules/transformer/auto_encoder.py +++ b/pypots/nn/modules/transformer/auto_encoder.py @@ -10,6 +10,7 @@ import torch import torch.nn as nn +from .attention import ScaledDotProductAttention from .embedding import PositionalEncoding from .layers import EncoderLayer, DecoderLayer @@ -78,8 +79,8 @@ def __init__( n_heads, d_k, d_v, + ScaledDotProductAttention(d_k**0.5, attn_dropout), dropout, - attn_dropout, ) for _ in range(n_layers) ] @@ -190,8 +191,9 @@ def __init__( n_heads, d_k, d_v, + ScaledDotProductAttention(d_k**0.5, attn_dropout), + ScaledDotProductAttention(d_k**0.5, attn_dropout), dropout, - attn_dropout, ) for _ in range(n_layers) ] diff --git a/pypots/nn/modules/transformer/layers.py b/pypots/nn/modules/transformer/layers.py index a5a558cc..e66b4b32 100644 --- a/pypots/nn/modules/transformer/layers.py +++ b/pypots/nn/modules/transformer/layers.py @@ -11,7 +11,7 @@ import torch.nn as nn import torch.nn.functional as F -from .attention import MultiHeadAttention +from .attention import MultiHeadAttention, AttentionOperator class PositionWiseFeedForward(nn.Module): @@ -85,11 +85,12 @@ class EncoderLayer(nn.Module): d_v: The dimension of the value tensor. + slf_attn_opt: + The attention operator for the self multi-head attention module in the encoder layer. + dropout: The dropout rate. - attn_dropout: - The dropout rate for the attention map. """ def __init__( @@ -99,11 +100,11 @@ def __init__( n_heads: int, d_k: int, d_v: int, + slf_attn_opt: AttentionOperator, dropout: float = 0.1, - attn_dropout: float = 0.1, ): super().__init__() - self.slf_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout) + self.slf_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, slf_attn_opt) self.dropout = nn.Dropout(dropout) self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) self.pos_ffn = PositionWiseFeedForward(d_model, d_ffn, dropout) @@ -112,6 +113,7 @@ def forward( self, enc_input: torch.Tensor, src_mask: Optional[torch.Tensor] = None, + **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: """Forward processing of the encoder layer. @@ -137,6 +139,7 @@ def forward( enc_input, enc_input, attn_mask=src_mask, + **kwargs, ) # apply dropout and residual connection @@ -170,12 +173,15 @@ class DecoderLayer(nn.Module): d_v: The dimension of the value tensor. + slf_attn_opt: + The attention operator for the self multi-head attention module in the decoder layer. + + enc_attn_opt: + The attention operator for the encoding multi-head attention module in the decoder layer. + dropout: The dropout rate. - attn_dropout: - The dropout rate for the attention map. - """ def __init__( @@ -185,12 +191,13 @@ def __init__( n_heads: int, d_k: int, d_v: int, + slf_attn_opt: AttentionOperator, + enc_attn_opt: AttentionOperator, dropout: float = 0.1, - attn_dropout: float = 0.1, ): super().__init__() - self.slf_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout) - self.enc_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, attn_dropout) + self.slf_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, slf_attn_opt) + self.enc_attn = MultiHeadAttention(n_heads, d_model, d_k, d_v, enc_attn_opt) self.pos_ffn = PositionWiseFeedForward(d_model, d_ffn, dropout) def forward( @@ -199,6 +206,7 @@ def forward( enc_output: torch.Tensor, slf_attn_mask: Optional[torch.Tensor] = None, dec_enc_attn_mask: Optional[torch.Tensor] = None, + **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Forward processing of the decoder layer. @@ -231,10 +239,18 @@ def forward( """ dec_output, dec_slf_attn = self.slf_attn( - dec_input, dec_input, dec_input, attn_mask=slf_attn_mask + dec_input, + dec_input, + dec_input, + attn_mask=slf_attn_mask, + **kwargs, ) dec_output, dec_enc_attn = self.enc_attn( - dec_output, enc_output, enc_output, attn_mask=dec_enc_attn_mask + dec_output, + enc_output, + enc_output, + attn_mask=dec_enc_attn_mask, + **kwargs, ) dec_output = self.pos_ffn(dec_output) return dec_output, dec_slf_attn, dec_enc_attn diff --git a/tests/classification/brits.py b/tests/classification/brits.py index a68442c0..0ec7b68d 100644 --- a/tests/classification/brits.py +++ b/tests/classification/brits.py @@ -44,7 +44,7 @@ class TestBRITS(unittest.TestCase): DATA["n_steps"], DATA["n_features"], n_classes=DATA["n_classes"], - rnn_hidden_size=256, + rnn_hidden_size=32, epochs=EPOCHS, saving_path=saving_path, model_saving_strategy="better", diff --git a/tests/classification/grud.py b/tests/classification/grud.py index 756451d4..5c165e07 100644 --- a/tests/classification/grud.py +++ b/tests/classification/grud.py @@ -44,7 +44,7 @@ class TestGRUD(unittest.TestCase): DATA["n_steps"], DATA["n_features"], n_classes=DATA["n_classes"], - rnn_hidden_size=256, + rnn_hidden_size=32, epochs=EPOCHS, saving_path=saving_path, optimizer=optimizer, diff --git a/tests/classification/raindrop.py b/tests/classification/raindrop.py index 56c31c83..64f6aa59 100644 --- a/tests/classification/raindrop.py +++ b/tests/classification/raindrop.py @@ -42,7 +42,7 @@ class TestRaindrop(unittest.TestCase): DATA["n_classes"], n_layers=2, d_model=DATA["n_features"] * 4, - d_ffn=256, + d_ffn=32, n_heads=2, dropout=0.3, d_static=0, diff --git a/tests/clustering/crli.py b/tests/clustering/crli.py index 6a36d670..3e29b9ca 100644 --- a/tests/clustering/crli.py +++ b/tests/clustering/crli.py @@ -50,7 +50,7 @@ class TestCRLI(unittest.TestCase): n_features=DATA["n_features"], n_clusters=DATA["n_classes"], n_generator_layers=2, - rnn_hidden_size=128, + rnn_hidden_size=32, rnn_cell_type="GRU", epochs=EPOCHS, saving_path=saving_path, diff --git a/tests/clustering/vader.py b/tests/clustering/vader.py index e3a7e334..bf0b0989 100644 --- a/tests/clustering/vader.py +++ b/tests/clustering/vader.py @@ -49,7 +49,7 @@ class TestVaDER(unittest.TestCase): n_steps=DATA["n_steps"], n_features=DATA["n_features"], n_clusters=DATA["n_classes"], - rnn_hidden_size=64, + rnn_hidden_size=32, d_mu_stddev=5, pretrain_epochs=20, epochs=EPOCHS, diff --git a/tests/global_test_config.py b/tests/global_test_config.py index 4b3ac41f..e72cdd20 100644 --- a/tests/global_test_config.py +++ b/tests/global_test_config.py @@ -51,7 +51,7 @@ RESULT_SAVING_DIR_FOR_FORECASTING = os.path.join(RESULT_SAVING_DIR, "forecasting") # set the number of epochs for all model training -EPOCHS = 5 +EPOCHS = 2 # set DEVICES to None if no cuda device is available, to avoid initialization failed while importing test classes n_cuda_devices = torch.cuda.device_count() diff --git a/tests/imputation/autoformer.py b/tests/imputation/autoformer.py index a70ecb1e..83610812 100644 --- a/tests/imputation/autoformer.py +++ b/tests/imputation/autoformer.py @@ -47,8 +47,8 @@ class TestAutoformer(unittest.TestCase): DATA["n_features"], n_layers=2, n_heads=2, - d_model=128, - d_ffn=256, + d_model=32, + d_ffn=32, factor=3, moving_avg_window_size=3, dropout=0, diff --git a/tests/imputation/brits.py b/tests/imputation/brits.py index 2145286b..1e63ffa4 100644 --- a/tests/imputation/brits.py +++ b/tests/imputation/brits.py @@ -45,7 +45,7 @@ class TestBRITS(unittest.TestCase): brits = BRITS( DATA["n_steps"], DATA["n_features"], - 256, + 32, epochs=EPOCHS, saving_path=saving_path, optimizer=optimizer, diff --git a/tests/imputation/crossformer.py b/tests/imputation/crossformer.py index fe8f6467..a6a6c55e 100644 --- a/tests/imputation/crossformer.py +++ b/tests/imputation/crossformer.py @@ -47,8 +47,8 @@ class TestCrossformer(unittest.TestCase): DATA["n_features"], n_layers=2, n_heads=2, - d_model=128, - d_ffn=256, + d_model=32, + d_ffn=32, factor=10, seg_len=12, win_size=2, diff --git a/tests/imputation/csdi.py b/tests/imputation/csdi.py index ae2fa2a3..a0ee0f93 100644 --- a/tests/imputation/csdi.py +++ b/tests/imputation/csdi.py @@ -49,7 +49,7 @@ class TestCSDI(unittest.TestCase): d_time_embedding=32, d_feature_embedding=3, d_diffusion_embedding=32, - n_diffusion_steps=10, + n_diffusion_steps=5, n_heads=1, epochs=EPOCHS, saving_path=saving_path, diff --git a/tests/imputation/etsformer.py b/tests/imputation/etsformer.py index c098b79f..87b8ce49 100644 --- a/tests/imputation/etsformer.py +++ b/tests/imputation/etsformer.py @@ -48,8 +48,8 @@ class TestETSformer(unittest.TestCase): n_e_layers=2, n_d_layers=2, n_heads=2, - d_model=128, - d_ffn=256, + d_model=32, + d_ffn=32, top_k=3, dropout=0, epochs=EPOCHS, diff --git a/tests/imputation/fedformer.py b/tests/imputation/fedformer.py index 7c8b4b04..7a6b24e5 100644 --- a/tests/imputation/fedformer.py +++ b/tests/imputation/fedformer.py @@ -47,8 +47,8 @@ class TestFEDformer(unittest.TestCase): DATA["n_features"], n_layers=1, n_heads=2, - d_model=64, - d_ffn=64, + d_model=32, + d_ffn=32, moving_avg_window_size=3, dropout=0, version="Wavelets", # TODO: Fourier version does not work diff --git a/tests/imputation/gpvae.py b/tests/imputation/gpvae.py index b9ed017d..9db47e7e 100644 --- a/tests/imputation/gpvae.py +++ b/tests/imputation/gpvae.py @@ -45,7 +45,7 @@ class TestGPVAE(unittest.TestCase): gp_vae = GPVAE( DATA["n_steps"], DATA["n_features"], - 256, + 32, epochs=EPOCHS, saving_path=saving_path, optimizer=optimizer, diff --git a/tests/imputation/informer.py b/tests/imputation/informer.py new file mode 100644 index 00000000..6f13680b --- /dev/null +++ b/tests/imputation/informer.py @@ -0,0 +1,128 @@ +""" +Test cases for Informer imputation model. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +import os.path +import unittest + +import numpy as np +import pytest + +from pypots.imputation import Informer +from pypots.optim import Adam +from pypots.utils.logging import logger +from pypots.utils.metrics import calc_mse +from tests.global_test_config import ( + DATA, + EPOCHS, + DEVICE, + TRAIN_SET, + VAL_SET, + TEST_SET, + H5_TRAIN_SET_PATH, + H5_VAL_SET_PATH, + H5_TEST_SET_PATH, + RESULT_SAVING_DIR_FOR_IMPUTATION, + check_tb_and_model_checkpoints_existence, +) + + +class TestInformer(unittest.TestCase): + logger.info("Running tests for an imputation model Informer...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "Informer") + model_save_name = "saved_informer_model.pypots" + + # initialize an Adam optimizer + optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a Informer model + informer = Informer( + DATA["n_steps"], + DATA["n_features"], + n_layers=2, + n_heads=2, + d_model=32, + d_ffn=32, + factor=3, + dropout=0, + epochs=EPOCHS, + saving_path=saving_path, + optimizer=optimizer, + device=DEVICE, + ) + + @pytest.mark.xdist_group(name="imputation-informer") + def test_0_fit(self): + self.informer.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="imputation-informer") + def test_1_impute(self): + imputation_results = self.informer.predict(TEST_SET) + assert not np.isnan( + imputation_results["imputation"] + ).any(), "Output still has missing values after running impute()." + + test_MSE = calc_mse( + imputation_results["imputation"], + DATA["test_X_ori"], + DATA["test_X_indicating_mask"], + ) + logger.info(f"Informer test_MSE: {test_MSE}") + + @pytest.mark.xdist_group(name="imputation-informer") + def test_2_parameters(self): + assert hasattr(self.informer, "model") and self.informer.model is not None + + assert ( + hasattr(self.informer, "optimizer") and self.informer.optimizer is not None + ) + + assert hasattr(self.informer, "best_loss") + self.assertNotEqual(self.informer.best_loss, float("inf")) + + assert ( + hasattr(self.informer, "best_model_dict") + and self.informer.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="imputation-informer") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.informer) + + # save the trained model into file, and check if the path exists + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.informer.save(saved_model_path) + + # test loading the saved model, not necessary, but need to test + self.informer.load(saved_model_path) + + @pytest.mark.xdist_group(name="imputation-informer") + def test_4_lazy_loading(self): + self.informer.fit(H5_TRAIN_SET_PATH, H5_VAL_SET_PATH) + imputation_results = self.informer.predict(H5_TEST_SET_PATH) + assert not np.isnan( + imputation_results["imputation"] + ).any(), "Output still has missing values after running impute()." + + test_MSE = calc_mse( + imputation_results["imputation"], + DATA["test_X_ori"], + DATA["test_X_indicating_mask"], + ) + logger.info(f"Lazy-loading Informer test_MSE: {test_MSE}") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/imputation/mrnn.py b/tests/imputation/mrnn.py index d1d6003d..4506e755 100644 --- a/tests/imputation/mrnn.py +++ b/tests/imputation/mrnn.py @@ -45,7 +45,7 @@ class TestMRNN(unittest.TestCase): mrnn = MRNN( DATA["n_steps"], DATA["n_features"], - 256, + 32, epochs=EPOCHS, saving_path=saving_path, optimizer=optimizer, diff --git a/tests/imputation/patchtst.py b/tests/imputation/patchtst.py index 6403f709..6a4c0632 100644 --- a/tests/imputation/patchtst.py +++ b/tests/imputation/patchtst.py @@ -46,11 +46,11 @@ class TestPatchTST(unittest.TestCase): DATA["n_steps"], DATA["n_features"], n_layers=2, - d_model=256, - d_ffn=128, - n_heads=4, - d_k=64, - d_v=64, + d_model=64, + d_ffn=32, + n_heads=2, + d_k=16, + d_v=16, patch_len=16, stride=8, dropout=0.1, diff --git a/tests/imputation/saits.py b/tests/imputation/saits.py index 3cfa669c..960e2bd4 100644 --- a/tests/imputation/saits.py +++ b/tests/imputation/saits.py @@ -46,11 +46,11 @@ class TestSAITS(unittest.TestCase): DATA["n_steps"], DATA["n_features"], n_layers=2, - d_model=256, - d_ffn=128, - n_heads=4, - d_k=64, - d_v=64, + d_model=32, + d_ffn=32, + n_heads=2, + d_k=16, + d_v=16, dropout=0.1, epochs=EPOCHS, saving_path=saving_path, diff --git a/tests/imputation/timesnet.py b/tests/imputation/timesnet.py index af35ae08..606d8747 100644 --- a/tests/imputation/timesnet.py +++ b/tests/imputation/timesnet.py @@ -47,8 +47,8 @@ class TestTimesNet(unittest.TestCase): DATA["n_features"], n_layers=2, top_k=3, - d_model=128, - d_ffn=256, + d_model=32, + d_ffn=32, n_kernels=3, dropout=0.1, epochs=EPOCHS, diff --git a/tests/imputation/transformer.py b/tests/imputation/transformer.py index aeeca8bc..2563680c 100644 --- a/tests/imputation/transformer.py +++ b/tests/imputation/transformer.py @@ -46,11 +46,11 @@ class TestTransformer(unittest.TestCase): DATA["n_steps"], DATA["n_features"], n_layers=2, - d_model=256, - d_ffn=128, - n_heads=4, - d_k=64, - d_v=64, + d_model=32, + d_ffn=32, + n_heads=2, + d_k=16, + d_v=16, dropout=0.1, epochs=EPOCHS, saving_path=saving_path, diff --git a/tests/imputation/usgan.py b/tests/imputation/usgan.py index 9ef33139..934553a3 100644 --- a/tests/imputation/usgan.py +++ b/tests/imputation/usgan.py @@ -46,7 +46,7 @@ class TestUSGAN(unittest.TestCase): usgan = USGAN( DATA["n_steps"], DATA["n_features"], - 256, + 32, epochs=EPOCHS, saving_path=saving_path, G_optimizer=G_optimizer,