diff --git a/README.md b/README.md
index 0556fb10..47ee47fa 100644
--- a/README.md
+++ b/README.md
@@ -121,6 +121,7 @@ The paper references and links are all listed at the bottom of this file.
|:--------------|:---------------------------------------------------------------------------------------------------------------------------------|:--------:|:--------:|:--------:|:--------:|:--------:|:-------------------|
| LLM | Time-Series.AI [^36] | ✅ | ✅ | ✅ | ✅ | ✅ | `Later in 2024` |
| Neural Net | TEFN🧑🔧[^39] | ✅ | | | | | `2024 - arXiv` |
+| Neural Net | FITS🧑🔧[^41] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | TimeMixer[^37] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | iTransformer🧑🔧[^24] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | ModernTCN[^38] | ✅ | | | | | `2024 - ICLR` |
@@ -501,7 +502,10 @@ Time-Series.AI
[^38]: Luo, D., & Wang X. (2024).
[ModernTCN: A Modern Pure Convolution Structure for General Time Series Analysis](https://openreview.net/forum?id=vpJMJerXHU).
*ICLR 2024*.
-[^39]: Zhan, T., He, Y., Li, Z., & Deng, Y. (2024).
+[^39]: Zhan, T., He, Y., Deng, Y., Li, Z., Du, W., & Wen, Q. (2024).
[Time Evidence Fusion Network: Multi-source View in Long-Term Time Series Forecasting](https://arxiv.org/abs/2405.06419).
*arXiv 2024*.
-[^40]: [Wikipedia: Linear interpolation](https://en.wikipedia.org/wiki/Linear_interpolation)
\ No newline at end of file
+[^40]: [Wikipedia: Linear interpolation](https://en.wikipedia.org/wiki/Linear_interpolation)
+[^41]: Xu, Z., Zeng, A., & Xu, Q. (2024).
+[FITS: Modeling Time Series with 10k parameters](https://openreview.net/forum?id=bWcnvZ3qMb).
+*ICLR 2024*.
diff --git a/README_zh.md b/README_zh.md
index ef47a408..55978e01 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -106,6 +106,7 @@ PyPOTS当前支持多变量POTS数据的插补, 预测, 分类, 聚类以及异
|:--------------|:---------------------------------------------------------------------------------------------------------------------------------|:------:|:------:|:------:|:------:|:--------:|:-------------------|
| LLM | Time-Series.AI [^36] | ✅ | ✅ | ✅ | ✅ | ✅ | `Later in 2024` |
| Neural Net | TEFN🧑🔧[^39] | ✅ | | | | | `2024 - arXiv` |
+| Neural Net | FITS🧑🔧[^41] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | TimeMixer[^37] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | iTransformer🧑🔧[^24] | ✅ | | | | | `2024 - ICLR` |
| Neural Net | ModernTCN[^38] | ✅ | | | | | `2024 - ICLR` |
@@ -474,7 +475,10 @@ Time-Series.AI
[^38]: Luo, D., & Wang X. (2024).
[ModernTCN: A Modern Pure Convolution Structure for General Time Series Analysis](https://openreview.net/forum?id=vpJMJerXHU).
*ICLR 2024*.
-[^39]: Zhan, T., He, Y., Li, Z., & Deng, Y. (2024).
+[^39]: Zhan, T., He, Y., Deng, Y., Li, Z., Du, W., & Wen, Q. (2024).
[Time Evidence Fusion Network: Multi-source View in Long-Term Time Series Forecasting](https://arxiv.org/abs/2405.06419).
*arXiv 2024*.
-[^40]: [Wikipedia: Linear interpolation](https://en.wikipedia.org/wiki/Linear_interpolation)
\ No newline at end of file
+[^40]: [Wikipedia: Linear interpolation](https://en.wikipedia.org/wiki/Linear_interpolation)
+[^41]: Xu, Z., Zeng, A., & Xu, Q. (2024).
+[FITS: Modeling Time Series with 10k parameters](https://openreview.net/forum?id=bWcnvZ3qMb).
+*ICLR 2024*.
diff --git a/docs/conf.py b/docs/conf.py
index 5b8e3822..f21afad7 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -62,7 +62,7 @@
intersphinx_mapping = {
"python": ("https://docs.python.org/3", None),
"sphinx": ("https://www.sphinx-doc.org/en/master", None),
- "torch": ("https://pytorch.org/docs/master/", None),
+ "torch": ("https://pytorch.org/docs/main/", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"pandas": ("https://pandas.pydata.org/docs/", None),
}
diff --git a/docs/index.rst b/docs/index.rst
index 1d1fe36f..4aaeb762 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -135,6 +135,8 @@ The paper references are all listed at the bottom of this readme file.
+================+===========================================================+======+======+======+======+======+=======================+
| Neural Net | TEFN🧑🔧 :cite:`zhan2024tefn` | ✅ | | | | | ``2024 - arXiv`` |
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
+| Neural Net | FITS🧑🔧 :cite:`xu2024fits` | ✅ | | | | | ``2024 - ICLR`` |
++----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
| Neural Net | TimeMixer :cite:`wang2024timemixer` | ✅ | | | | | ``2024 - ICLR`` |
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
| Neural Net | iTransformer🧑🔧 :cite:`liu2024itransformer` | ✅ | | | | | ``2024 - ICLR`` |
diff --git a/docs/pypots.imputation.rst b/docs/pypots.imputation.rst
index b7a94b03..a7b47b07 100644
--- a/docs/pypots.imputation.rst
+++ b/docs/pypots.imputation.rst
@@ -19,6 +19,24 @@ pypots.imputation.transformer
:show-inheritance:
:inherited-members:
+pypots.imputation.tefn
+------------------------------------
+
+.. automodule:: pypots.imputation.tefn
+ :members:
+ :undoc-members:
+ :show-inheritance:
+ :inherited-members:
+
+pypots.imputation.fits
+------------------------------------
+
+.. automodule:: pypots.imputation.fits
+ :members:
+ :undoc-members:
+ :show-inheritance:
+ :inherited-members:
+
pypots.imputation.timemixer
------------------------------------
diff --git a/docs/references.bib b/docs/references.bib
index 171d63c5..ce0014ea 100644
--- a/docs/references.bib
+++ b/docs/references.bib
@@ -766,7 +766,7 @@ @article{bai2018tcn
@article{zhan2024tefn,
title={Time Evidence Fusion Network: Multi-source View in Long-Term Time Series Forecasting},
-author={Zhan, Tianxiang and He, Yuanpeng and Li, Zhen and Deng, Yong},
+author={Zhan, Tianxiang and He, Yuanpeng and Deng, Yong and Li, Zhen and Du, Wenjie and Wen, Qingsong},
journal={arXiv preprint arXiv:2405.06419},
year={2024}
}
diff --git a/pypots/imputation/fits/__init__.py b/pypots/imputation/fits/__init__.py
new file mode 100644
index 00000000..23cbae58
--- /dev/null
+++ b/pypots/imputation/fits/__init__.py
@@ -0,0 +1,24 @@
+"""
+The package including the modules of FITS.
+
+Refer to the paper
+`Zhijian Xu, Ailing Zeng, and Qiang Xu.
+FITS: Modeling Time Series with 10k parameters.
+In The Twelfth International Conference on Learning Representations, 2024.
+<https://openreview.net/forum?id=bWcnvZ3qMb>`_
+
+Notes
+-----
+This implementation is inspired by the official one https://github.com/VEWOXIC/FITS
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+
+from .model import FITS
+
+__all__ = [
+ "FITS",
+]
diff --git a/pypots/imputation/fits/core.py b/pypots/imputation/fits/core.py
new file mode 100644
index 00000000..701ec4ca
--- /dev/null
+++ b/pypots/imputation/fits/core.py
@@ -0,0 +1,86 @@
+"""
+The core wrapper assembles the submodules of the FITS imputation model
+and takes over the forward process of the algorithm.
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+import torch.nn as nn
+
+from ...nn.functional import nonstationary_norm, nonstationary_denorm
+from ...nn.modules.fits import BackboneFITS
+from ...nn.modules.saits import SaitsLoss, SaitsEmbedding
+
+
+class _FITS(nn.Module):
+ def __init__(
+ self,
+ n_steps: int,
+ n_features: int,
+ cut_freq: int,
+ individual: bool,
+ ORT_weight: float = 1,
+ MIT_weight: float = 1,
+ apply_nonstationary_norm: bool = False,
+ ):
+ super().__init__()
+
+ self.n_steps = n_steps
+ self.apply_nonstationary_norm = apply_nonstationary_norm
+
+ self.saits_embedding = SaitsEmbedding(
+ n_features * 2,
+ n_features,
+ with_pos=False,
+ )
+ self.backbone = BackboneFITS(
+ n_steps,
+ n_features,
+ 0, # n_pred_steps is not used in the imputation task
+ cut_freq,
+ individual,
+ )
+
+ # for the imputation task, the output dim is the same as input dim
+ self.output_projection = nn.Linear(n_features, n_features)
+ self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight)
+
+ def forward(self, inputs: dict, training: bool = True) -> dict:
+ X, missing_mask = inputs["X"], inputs["missing_mask"]
+
+ if self.apply_nonstationary_norm:
+ # Normalization from Non-stationary Transformer
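+            # (nonstationary_norm standardizes each series using statistics computed over the
+            # observed positions only, returning means and stdev for later de-normalization)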
+ X, means, stdev = nonstationary_norm(X, missing_mask)
+
+ # WDU: the original FITS paper isn't proposed for imputation task. Hence the model doesn't take
+ # the missing mask into account, which means, in the process, the model doesn't know which part of
+ # the input data is missing, and this may hurt the model's imputation performance. Therefore, I apply the
+ # SAITS embedding method to project the concatenation of features and masks into a hidden space, as well as
+ # the output layers to project back from the hidden space to the original space.
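+        # A shape sketch of this step (B denotes batch size): X and missing_mask are both
+        # [B, n_steps, n_features]; SaitsEmbedding concatenates them along the last dim into
+        # [B, n_steps, 2 * n_features] and projects the result back to [B, n_steps, n_features]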
+ enc_out = self.saits_embedding(X, missing_mask)
+
+ # FITS encoder processing
+ enc_out = self.backbone(enc_out)
+ if self.apply_nonstationary_norm:
+ # De-Normalization from Non-stationary Transformer
+ enc_out = nonstationary_denorm(enc_out, means, stdev)
+
+        # project back to the original data space
+ reconstruction = self.output_projection(enc_out)
+
+ imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction
+ results = {
+ "imputed_data": imputed_data,
+ }
+
+ # if in training mode, return results with losses
+ if training:
+ X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"]
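+            # under the SAITS training scheme, ORT_loss is the reconstruction error on the
+            # observed values, while MIT_loss is the error on the artificially-masked values
+            # marked by indicating_mask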
+ loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask)
+ results["ORT_loss"] = ORT_loss
+ results["MIT_loss"] = MIT_loss
+ # `loss` is always the item for backward propagating to update the model
+ results["loss"] = loss
+
+ return results
diff --git a/pypots/imputation/fits/data.py b/pypots/imputation/fits/data.py
new file mode 100644
index 00000000..4f7b0b62
--- /dev/null
+++ b/pypots/imputation/fits/data.py
@@ -0,0 +1,24 @@
+"""
+Dataset class for FITS.
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+from typing import Union
+
+from ..saits.data import DatasetForSAITS
+
+
+class DatasetForFITS(DatasetForSAITS):
+ """Actually FITS uses the same data strategy as SAITS, needs MIT for training."""
+
+ def __init__(
+ self,
+ data: Union[dict, str],
+ return_X_ori: bool,
+ return_y: bool,
+ file_type: str = "hdf5",
+ rate: float = 0.2,
+ ):
+ super().__init__(data, return_X_ori, return_y, file_type, rate)
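+
+
+# A brief illustration of the inherited masking strategy: with the default rate=0.2, roughly
+# 20% of the observed values in each sample are randomly held out to form the model input X,
+# while X_ori keeps them and indicating_mask marks exactly those held-out positions, so the
+# MIT loss can be computed on values the model never saw.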
diff --git a/pypots/imputation/fits/model.py b/pypots/imputation/fits/model.py
new file mode 100644
index 00000000..2664da26
--- /dev/null
+++ b/pypots/imputation/fits/model.py
@@ -0,0 +1,308 @@
+"""
+The implementation of FITS for the partially-observed time-series imputation task.
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+from typing import Union, Optional
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+
+from .core import _FITS
+from .data import DatasetForFITS
+from ..base import BaseNNImputer
+from ...data.checking import key_in_data_set
+from ...data.dataset import BaseDataset
+from ...optim.adam import Adam
+from ...optim.base import Optimizer
+
+
+class FITS(BaseNNImputer):
+ """The PyTorch implementation of the FITS model.
+ FITS is originally proposed by Xu et al. in :cite:`xu2024fits`.
+
+ Parameters
+ ----------
+ n_steps :
+ The number of time steps in the time-series data sample.
+
+ n_features :
+ The number of features in the time-series data sample.
+
+    cut_freq :
+        The cut-off frequency of the low-pass filter, i.e. the number of low-frequency
+        components kept and fed into the complex-valued frequency upsampler.
+
+    individual :
+        Whether to use an individual frequency upsampler for each feature.
+        If False, all features share the same upsampling layer.
+
+    ORT_weight :
+        The weight for the ORT loss, the same as SAITS.
+
+    MIT_weight :
+        The weight for the MIT loss, the same as SAITS.
+
+    apply_nonstationary_norm :
+        Whether to apply the non-stationary normalization from the Non-stationary
+        Transformer to the input data before feeding it to the backbone.
+
+ batch_size :
+ The batch size for training and evaluating the model.
+
+ epochs :
+ The number of epochs for training the model.
+
+ patience :
+ The patience for the early-stopping mechanism. Given a positive integer, the training process will be
+ stopped when the model does not perform better after that number of epochs.
+ Leaving it default as None will disable the early-stopping.
+
+ optimizer :
+ The optimizer for model training.
+ If not given, will use a default Adam optimizer.
+
+ num_workers :
+ The number of subprocesses to use for data loading.
+ `0` means data loading will be in the main process, i.e. there won't be subprocesses.
+
+ device :
+ The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them.
+ If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple),
+ then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models.
+        If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')], the
+        model will be trained in parallel on the multiple devices (so far only parallel training on CUDA devices is supported).
+ Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future.
+
+ saving_path :
+ The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during
+ training into a tensorboard file). Will not save if not given.
+
+ model_saving_strategy :
+ The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"].
+ No model will be saved when it is set as None.
+ The "best" strategy will only automatically save the best model after the training finished.
+ The "better" strategy will automatically save the model during training whenever the model performs
+ better than in previous epochs.
+ The "all" strategy will save every model after each epoch training.
+
+ verbose :
+ Whether to print out the training logs during the training process.
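+
+    Examples
+    --------
+    A minimal usage sketch; the data shape and hyperparameter values here are illustrative
+    assumptions, not recommendations:
+
+    >>> import numpy as np
+    >>> from pypots.imputation.fits import FITS
+    >>> X = np.random.randn(64, 24, 7)
+    >>> X[X < -1.5] = np.nan  # introduce missing values
+    >>> model = FITS(n_steps=24, n_features=7, cut_freq=5, epochs=3)
+    >>> model.fit({"X": X})
+    >>> imputation = model.impute({"X": X})  # array of shape [64, 24, 7]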
+ """
+
+ def __init__(
+ self,
+ n_steps: int,
+ n_features: int,
+ cut_freq: int,
+ individual: bool = False,
+ ORT_weight: float = 1,
+ MIT_weight: float = 1,
+ apply_nonstationary_norm: bool = False,
+ batch_size: int = 32,
+ epochs: int = 100,
+        patience: Optional[int] = None,
+ optimizer: Optional[Optimizer] = Adam(),
+ num_workers: int = 0,
+ device: Optional[Union[str, torch.device, list]] = None,
+        saving_path: Optional[str] = None,
+ model_saving_strategy: Optional[str] = "best",
+ verbose: bool = True,
+ ):
+ super().__init__(
+ batch_size,
+ epochs,
+ patience,
+ num_workers,
+ device,
+ saving_path,
+ model_saving_strategy,
+ verbose,
+ )
+
+ self.n_steps = n_steps
+ self.n_features = n_features
+        # model hyper-parameters
+ self.individual = individual
+ self.cut_freq = cut_freq
+ self.ORT_weight = ORT_weight
+ self.MIT_weight = MIT_weight
+ self.apply_nonstationary_norm = apply_nonstationary_norm
+
+ # set up the model
+ self.model = _FITS(
+ self.n_steps,
+ self.n_features,
+ self.cut_freq,
+ self.individual,
+ self.ORT_weight,
+ self.MIT_weight,
+ self.apply_nonstationary_norm,
+ )
+ self._send_model_to_given_device()
+ self._print_model_size()
+
+ # set up the optimizer
+ self.optimizer = optimizer
+ self.optimizer.init_optimizer(self.model.parameters())
+
+ def _assemble_input_for_training(self, data: list) -> dict:
+ (
+ indices,
+ X,
+ missing_mask,
+ X_ori,
+ indicating_mask,
+ ) = self._send_data_to_given_device(data)
+
+ inputs = {
+ "X": X,
+ "missing_mask": missing_mask,
+ "X_ori": X_ori,
+ "indicating_mask": indicating_mask,
+ }
+
+ return inputs
+
+ def _assemble_input_for_validating(self, data: list) -> dict:
+ return self._assemble_input_for_training(data)
+
+ def _assemble_input_for_testing(self, data: list) -> dict:
+ indices, X, missing_mask = self._send_data_to_given_device(data)
+
+ inputs = {
+ "X": X,
+ "missing_mask": missing_mask,
+ }
+
+ return inputs
+
+ def fit(
+ self,
+ train_set: Union[dict, str],
+ val_set: Optional[Union[dict, str]] = None,
+ file_type: str = "hdf5",
+ ) -> None:
+ # Step 1: wrap the input data with classes Dataset and DataLoader
+ training_set = DatasetForFITS(train_set, return_X_ori=False, return_y=False, file_type=file_type)
+ training_loader = DataLoader(
+ training_set,
+ batch_size=self.batch_size,
+ shuffle=True,
+ num_workers=self.num_workers,
+ )
+ val_loader = None
+ if val_set is not None:
+ if not key_in_data_set("X_ori", val_set):
+ raise ValueError("val_set must contain 'X_ori' for model validation.")
+ val_set = DatasetForFITS(val_set, return_X_ori=True, return_y=False, file_type=file_type)
+ val_loader = DataLoader(
+ val_set,
+ batch_size=self.batch_size,
+ shuffle=False,
+ num_workers=self.num_workers,
+ )
+
+ # Step 2: train the model and freeze it
+ self._train_model(training_loader, val_loader)
+ self.model.load_state_dict(self.best_model_dict)
+ self.model.eval() # set the model as eval status to freeze it.
+
+ # Step 3: save the model if necessary
+ self._auto_save_model_if_necessary(confirm_saving=True)
+
+ def predict(
+ self,
+ test_set: Union[dict, str],
+ file_type: str = "hdf5",
+ ) -> dict:
+ """Make predictions for the input data with the trained model.
+
+ Parameters
+ ----------
+ test_set : dict or str
+            The dataset for model testing, should be a dictionary including the key 'X',
+            or a path string locating a data file supported by PyPOTS (e.g. h5 file).
+            If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features],
+            which is the time-series data for testing and can contain missing values.
+            If it is a path string, the path should point to a data file, e.g. an h5 file, which contains
+            key-value pairs like a dict, and it has to include the key 'X'.
+
+ file_type :
+ The type of the given file if test_set is a path string.
+
+ Returns
+ -------
+        result_dict :
+            The dictionary containing the imputation results and latent variables if necessary.
+
+ """
+ # Step 1: wrap the input data with classes Dataset and DataLoader
+ self.model.eval() # set the model as eval status to freeze it.
+ test_set = BaseDataset(
+ test_set,
+ return_X_ori=False,
+ return_X_pred=False,
+ return_y=False,
+ file_type=file_type,
+ )
+ test_loader = DataLoader(
+ test_set,
+ batch_size=self.batch_size,
+ shuffle=False,
+ num_workers=self.num_workers,
+ )
+ imputation_collector = []
+
+ # Step 2: process the data with the model
+ with torch.no_grad():
+ for idx, data in enumerate(test_loader):
+ inputs = self._assemble_input_for_testing(data)
+ results = self.model.forward(inputs, training=False)
+ imputation_collector.append(results["imputed_data"])
+
+ # Step 3: output collection and return
+ imputation = torch.cat(imputation_collector).cpu().detach().numpy()
+ result_dict = {
+ "imputation": imputation,
+ }
+ return result_dict
+
+ def impute(
+ self,
+ test_set: Union[dict, str],
+ file_type: str = "hdf5",
+ ) -> np.ndarray:
+ """Impute missing values in the given data with the trained model.
+
+ Parameters
+ ----------
+ test_set :
+ The data samples for testing, should be array-like of shape [n_samples, sequence length (n_steps),
+ n_features], or a path string locating a data file, e.g. h5 file.
+
+ file_type :
+            The type of the given file if test_set is a path string.
+
+ Returns
+ -------
+ array-like, shape [n_samples, sequence length (n_steps), n_features],
+ Imputed data.
+ """
+
+ result_dict = self.predict(test_set, file_type=file_type)
+ return result_dict["imputation"]
diff --git a/pypots/nn/modules/fits/__init__.py b/pypots/nn/modules/fits/__init__.py
new file mode 100644
index 00000000..38733b7d
--- /dev/null
+++ b/pypots/nn/modules/fits/__init__.py
@@ -0,0 +1,24 @@
+"""
+The package including the modules of FITS.
+
+Refer to the paper
+`Zhijian Xu, Ailing Zeng, and Qiang Xu.
+FITS: Modeling Time Series with 10k parameters.
+In The Twelfth International Conference on Learning Representations, 2024.
+<https://openreview.net/forum?id=bWcnvZ3qMb>`_
+
+Notes
+-----
+This implementation is inspired by the official one https://github.com/VEWOXIC/FITS
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+from .backbone import BackboneFITS
+
+
+__all__ = [
+ "BackboneFITS",
+]
diff --git a/pypots/nn/modules/fits/backbone.py b/pypots/nn/modules/fits/backbone.py
new file mode 100644
index 00000000..272cc800
--- /dev/null
+++ b/pypots/nn/modules/fits/backbone.py
@@ -0,0 +1,69 @@
+"""
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+import torch
+import torch.nn as nn
+
+
+class BackboneFITS(nn.Module):
+ def __init__(
+ self,
+ n_steps: int,
+ n_features: int,
+ n_pred_steps: int,
+ cut_freq: int,
+ individual: bool,
+ ):
+ super().__init__()
+ self.n_steps = n_steps
+ self.n_features = n_features
+ self.n_pred_steps = n_pred_steps
+ self.individual = individual
+
+ self.dominance_freq = cut_freq
+ self.length_ratio = (n_steps + n_pred_steps) / n_steps
+
+ if self.individual:
+ self.freq_upsampler = nn.ModuleList()
+ for i in range(self.n_features):
+ self.freq_upsampler.append(
+ nn.Linear(self.dominance_freq, int(self.dominance_freq * self.length_ratio)).to(torch.cfloat)
+ )
+ else:
+ # complex layer for frequency upsampling
+ self.freq_upsampler = nn.Linear(self.dominance_freq, int(self.dominance_freq * self.length_ratio)).to(
+ torch.cfloat
+ )
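+        # note: .to(torch.cfloat) casts the layer weights to complex numbers, so the linear
+        # map acts directly on the complex rFFT coefficients produced in forward()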
+
+ def forward(self, x):
+ low_specx = torch.fft.rfft(x, dim=1)
+ assert low_specx.size(1) >= self.dominance_freq, (
+            f"The number of frequency components after rFFT ({low_specx.size(1)}) is less than "
+            f"the cut frequency {self.dominance_freq}. "
+            f"Please increase the input sequence length, or decrease cut_freq."
+ )
+        low_specx[:, self.dominance_freq :] = 0  # low-pass filter: zero out components above the cut frequency
+        low_specx = low_specx[:, 0 : self.dominance_freq, :]  # keep only the retained low-frequency components
+
+ if self.individual:
+ low_specxy_ = torch.zeros(
+ [low_specx.size(0), int(self.dominance_freq * self.length_ratio), low_specx.size(2)],
+ dtype=low_specx.dtype,
+ ).to(low_specx.device)
+ for i in range(self.n_features):
+ low_specxy_[:, :, i] = self.freq_upsampler[i](low_specx[:, :, i].permute(0, 1)).permute(0, 1)
+ else:
+ low_specxy_ = self.freq_upsampler(low_specx.permute(0, 2, 1)).permute(0, 2, 1)
+
+ low_specxy = torch.zeros(
+ [low_specxy_.size(0), int((self.n_steps + self.n_pred_steps) / 2 + 1), low_specxy_.size(2)],
+ dtype=low_specxy_.dtype,
+ ).to(low_specxy_.device)
+ low_specxy[:, 0 : low_specxy_.size(1), :] = low_specxy_ # zero padding
+ low_xy = torch.fft.irfft(low_specxy, dim=1)
+ low_xy = low_xy * self.length_ratio # energy compensation for the length change
+
+ return low_xy
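+
+
+# A worked shape example for the forward pass above (illustration only): with n_steps=24,
+# n_pred_steps=0 and cut_freq=5, length_ratio = (24 + 0) / 24 = 1.0; rfft over dim=1 yields
+# 24 // 2 + 1 = 13 frequency bins, the low-pass filter keeps the first 5, the complex linear
+# layer maps 5 -> int(5 * 1.0) = 5 bins, zero padding restores 13 bins, and irfft returns a
+# [batch, 24, n_features] tensor that is finally scaled by length_ratio for energy compensation.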
diff --git a/tests/imputation/fits.py b/tests/imputation/fits.py
new file mode 100644
index 00000000..c35c942e
--- /dev/null
+++ b/tests/imputation/fits.py
@@ -0,0 +1,117 @@
+"""
+Test cases for FITS imputation model.
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+
+import os.path
+import unittest
+
+import numpy as np
+import pytest
+
+from pypots.imputation.fits import FITS
+from pypots.optim import Adam
+from pypots.utils.logging import logger
+from pypots.utils.metrics import calc_mse
+from tests.global_test_config import (
+ DATA,
+ EPOCHS,
+ DEVICE,
+ TRAIN_SET,
+ VAL_SET,
+ TEST_SET,
+ GENERAL_H5_TRAIN_SET_PATH,
+ GENERAL_H5_VAL_SET_PATH,
+ GENERAL_H5_TEST_SET_PATH,
+ RESULT_SAVING_DIR_FOR_IMPUTATION,
+ check_tb_and_model_checkpoints_existence,
+)
+
+
+class TestFITS(unittest.TestCase):
+ logger.info("Running tests for an imputation model FITS...")
+
+ # set the log and model saving path
+ saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "FITS")
+ model_save_name = "saved_fits_model.pypots"
+
+ # initialize an Adam optimizer
+ optimizer = Adam(lr=0.001, weight_decay=1e-5)
+
+ # initialize a FITS model
+ fits = FITS(
+ DATA["n_steps"],
+ DATA["n_features"],
+ individual=False,
+ cut_freq=5,
+ epochs=EPOCHS,
+ saving_path=saving_path,
+ optimizer=optimizer,
+ device=DEVICE,
+ )
+
+ @pytest.mark.xdist_group(name="imputation-fits")
+ def test_0_fit(self):
+ self.fits.fit(TRAIN_SET, VAL_SET)
+
+ @pytest.mark.xdist_group(name="imputation-fits")
+ def test_1_impute(self):
+ imputation_results = self.fits.predict(TEST_SET)
+ assert not np.isnan(
+ imputation_results["imputation"]
+        ).any(), "Output still has missing values after running predict()."
+
+ test_MSE = calc_mse(
+ imputation_results["imputation"],
+ DATA["test_X_ori"],
+ DATA["test_X_indicating_mask"],
+ )
+ logger.info(f"FITS test_MSE: {test_MSE}")
+
+ @pytest.mark.xdist_group(name="imputation-fits")
+ def test_2_parameters(self):
+ assert hasattr(self.fits, "model") and self.fits.model is not None
+
+ assert hasattr(self.fits, "optimizer") and self.fits.optimizer is not None
+
+ assert hasattr(self.fits, "best_loss")
+ self.assertNotEqual(self.fits.best_loss, float("inf"))
+
+ assert hasattr(self.fits, "best_model_dict") and self.fits.best_model_dict is not None
+
+ @pytest.mark.xdist_group(name="imputation-fits")
+ def test_3_saving_path(self):
+ # whether the root saving dir exists, which should be created by save_log_into_tb_file
+ assert os.path.exists(self.saving_path), f"file {self.saving_path} does not exist"
+
+ # check if the tensorboard file and model checkpoints exist
+ check_tb_and_model_checkpoints_existence(self.fits)
+
+ # save the trained model into file, and check if the path exists
+ saved_model_path = os.path.join(self.saving_path, self.model_save_name)
+ self.fits.save(saved_model_path)
+
+ # test loading the saved model, not necessary, but need to test
+ self.fits.load(saved_model_path)
+
+ @pytest.mark.xdist_group(name="imputation-fits")
+ def test_4_lazy_loading(self):
+ self.fits.fit(GENERAL_H5_TRAIN_SET_PATH, GENERAL_H5_VAL_SET_PATH)
+ imputation_results = self.fits.predict(GENERAL_H5_TEST_SET_PATH)
+ assert not np.isnan(
+ imputation_results["imputation"]
+        ).any(), "Output still has missing values after running predict()."
+
+ test_MSE = calc_mse(
+ imputation_results["imputation"],
+ DATA["test_X_ori"],
+ DATA["test_X_indicating_mask"],
+ )
+ logger.info(f"Lazy-loading FITS test_MSE: {test_MSE}")
+
+
+if __name__ == "__main__":
+ unittest.main()