From 87c0ea5ae02f863d34142c758b2f7cf23ebe3ba9 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Mon, 23 Sep 2024 17:27:43 +0800 Subject: [PATCH 1/4] feat: add FITS modules; --- pypots/nn/modules/fits/__init__.py | 24 +++++++++++ pypots/nn/modules/fits/backbone.py | 69 ++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+) create mode 100644 pypots/nn/modules/fits/__init__.py create mode 100644 pypots/nn/modules/fits/backbone.py diff --git a/pypots/nn/modules/fits/__init__.py b/pypots/nn/modules/fits/__init__.py new file mode 100644 index 00000000..38733b7d --- /dev/null +++ b/pypots/nn/modules/fits/__init__.py @@ -0,0 +1,24 @@ +""" +The package including the modules of FITS. + +Refer to the paper +`Zhijian Xu, Ailing Zeng, and Qiang Xu. +FITS: Modeling Time Series with 10k parameters. +In The Twelfth International Conference on Learning Representations, 2024. +`_ + +Notes +----- +This implementation is inspired by the official one https://github.com/VEWOXIC/FITS + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from .backbone import BackboneFITS + + +__all__ = [ + "BackboneFITS", +] diff --git a/pypots/nn/modules/fits/backbone.py b/pypots/nn/modules/fits/backbone.py new file mode 100644 index 00000000..272cc800 --- /dev/null +++ b/pypots/nn/modules/fits/backbone.py @@ -0,0 +1,69 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import torch +import torch.nn as nn + + +class BackboneFITS(nn.Module): + def __init__( + self, + n_steps: int, + n_features: int, + n_pred_steps: int, + cut_freq: int, + individual: bool, + ): + super().__init__() + self.n_steps = n_steps + self.n_features = n_features + self.n_pred_steps = n_pred_steps + self.individual = individual + + self.dominance_freq = cut_freq + self.length_ratio = (n_steps + n_pred_steps) / n_steps + + if self.individual: + self.freq_upsampler = nn.ModuleList() + for i in range(self.n_features): + self.freq_upsampler.append( + nn.Linear(self.dominance_freq, int(self.dominance_freq * self.length_ratio)).to(torch.cfloat) + ) + else: + # complex layer for frequency upsampling + self.freq_upsampler = nn.Linear(self.dominance_freq, int(self.dominance_freq * self.length_ratio)).to( + torch.cfloat + ) + + def forward(self, x): + low_specx = torch.fft.rfft(x, dim=1) + assert low_specx.size(1) >= self.dominance_freq, ( + f"The sequence length after FFT {low_specx.size(1)} is less than the cut frequency {self.dominance_freq}. " + f"Please check the input sequence length, or decrease the cut frequency." + ) + low_specx[:, self.dominance_freq :] = 0 # LPF + low_specx = low_specx[:, 0 : self.dominance_freq, :] # LPF + + if self.individual: + low_specxy_ = torch.zeros( + [low_specx.size(0), int(self.dominance_freq * self.length_ratio), low_specx.size(2)], + dtype=low_specx.dtype, + ).to(low_specx.device) + for i in range(self.n_features): + low_specxy_[:, :, i] = self.freq_upsampler[i](low_specx[:, :, i].permute(0, 1)).permute(0, 1) + else: + low_specxy_ = self.freq_upsampler(low_specx.permute(0, 2, 1)).permute(0, 2, 1) + + low_specxy = torch.zeros( + [low_specxy_.size(0), int((self.n_steps + self.n_pred_steps) / 2 + 1), low_specxy_.size(2)], + dtype=low_specxy_.dtype, + ).to(low_specxy_.device) + low_specxy[:, 0 : low_specxy_.size(1), :] = low_specxy_ # zero padding + low_xy = torch.fft.irfft(low_specxy, dim=1) + low_xy = low_xy * self.length_ratio # energy compensation for the length change + + return low_xy From 41bd4ddbb2c143d9d77647278cd4de3aec887af1 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Mon, 23 Sep 2024 17:29:03 +0800 Subject: [PATCH 2/4] feat: implement FITS as an imputation model; --- pypots/imputation/fits/__init__.py | 24 +++ pypots/imputation/fits/core.py | 86 ++++++++ pypots/imputation/fits/data.py | 24 +++ pypots/imputation/fits/model.py | 308 +++++++++++++++++++++++++++++ 4 files changed, 442 insertions(+) create mode 100644 pypots/imputation/fits/__init__.py create mode 100644 pypots/imputation/fits/core.py create mode 100644 pypots/imputation/fits/data.py create mode 100644 pypots/imputation/fits/model.py diff --git a/pypots/imputation/fits/__init__.py b/pypots/imputation/fits/__init__.py new file mode 100644 index 00000000..23cbae58 --- /dev/null +++ b/pypots/imputation/fits/__init__.py @@ -0,0 +1,24 @@ +""" +The package including the modules of FITS. + +Refer to the paper +`Zhijian Xu, Ailing Zeng, and Qiang Xu. +FITS: Modeling Time Series with 10k parameters. +In The Twelfth International Conference on Learning Representations, 2024. +`_ + +Notes +----- +This implementation is inspired by the official one https://github.com/VEWOXIC/FITS + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +from .model import FITS + +__all__ = [ + "FITS", +] diff --git a/pypots/imputation/fits/core.py b/pypots/imputation/fits/core.py new file mode 100644 index 00000000..701ec4ca --- /dev/null +++ b/pypots/imputation/fits/core.py @@ -0,0 +1,86 @@ +""" +The core wrapper assembles the submodules of FITS imputation model +and takes over the forward progress of the algorithm. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import torch.nn as nn + +from ...nn.functional import nonstationary_norm, nonstationary_denorm +from ...nn.modules.fits import BackboneFITS +from ...nn.modules.saits import SaitsLoss, SaitsEmbedding + + +class _FITS(nn.Module): + def __init__( + self, + n_steps: int, + n_features: int, + cut_freq: int, + individual: bool, + ORT_weight: float = 1, + MIT_weight: float = 1, + apply_nonstationary_norm: bool = False, + ): + super().__init__() + + self.n_steps = n_steps + self.apply_nonstationary_norm = apply_nonstationary_norm + + self.saits_embedding = SaitsEmbedding( + n_features * 2, + n_features, + with_pos=False, + ) + self.backbone = BackboneFITS( + n_steps, + n_features, + 0, # n_pred_steps is not used in the imputation task + cut_freq, + individual, + ) + + # for the imputation task, the output dim is the same as input dim + self.output_projection = nn.Linear(n_features, n_features) + self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight) + + def forward(self, inputs: dict, training: bool = True) -> dict: + X, missing_mask = inputs["X"], inputs["missing_mask"] + + if self.apply_nonstationary_norm: + # Normalization from Non-stationary Transformer + X, means, stdev = nonstationary_norm(X, missing_mask) + + # WDU: the original FITS paper isn't proposed for imputation task. Hence the model doesn't take + # the missing mask into account, which means, in the process, the model doesn't know which part of + # the input data is missing, and this may hurt the model's imputation performance. Therefore, I apply the + # SAITS embedding method to project the concatenation of features and masks into a hidden space, as well as + # the output layers to project back from the hidden space to the original space. + enc_out = self.saits_embedding(X, missing_mask) + + # FITS encoder processing + enc_out = self.backbone(enc_out) + if self.apply_nonstationary_norm: + # De-Normalization from Non-stationary Transformer + enc_out = nonstationary_denorm(enc_out, means, stdev) + + # project back the original data space + reconstruction = self.output_projection(enc_out) + + imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction + results = { + "imputed_data": imputed_data, + } + + # if in training mode, return results with losses + if training: + X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] + loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) + results["ORT_loss"] = ORT_loss + results["MIT_loss"] = MIT_loss + # `loss` is always the item for backward propagating to update the model + results["loss"] = loss + + return results diff --git a/pypots/imputation/fits/data.py b/pypots/imputation/fits/data.py new file mode 100644 index 00000000..4f7b0b62 --- /dev/null +++ b/pypots/imputation/fits/data.py @@ -0,0 +1,24 @@ +""" +Dataset class for FITS. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union + +from ..saits.data import DatasetForSAITS + + +class DatasetForFITS(DatasetForSAITS): + """Actually FITS uses the same data strategy as SAITS, needs MIT for training.""" + + def __init__( + self, + data: Union[dict, str], + return_X_ori: bool, + return_y: bool, + file_type: str = "hdf5", + rate: float = 0.2, + ): + super().__init__(data, return_X_ori, return_y, file_type, rate) diff --git a/pypots/imputation/fits/model.py b/pypots/imputation/fits/model.py new file mode 100644 index 00000000..2664da26 --- /dev/null +++ b/pypots/imputation/fits/model.py @@ -0,0 +1,308 @@ +""" +The implementation of FITS for the partially-observed time-series imputation task. + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union, Optional + +import numpy as np +import torch +from torch.utils.data import DataLoader + +from .core import _FITS +from .data import DatasetForFITS +from ..base import BaseNNImputer +from ...data.checking import key_in_data_set +from ...data.dataset import BaseDataset +from ...optim.adam import Adam +from ...optim.base import Optimizer + + +class FITS(BaseNNImputer): + """The PyTorch implementation of the FITS model. + FITS is originally proposed by Xu et al. in :cite:`xu2024fits`. + + Parameters + ---------- + n_steps : + The number of time steps in the time-series data sample. + + n_features : + The number of features in the time-series data sample. + + n_layers : + The number of layers in the FITS model. + + d_model : + The dimension of the model. + + n_heads : + The number of heads in each layer of FITS. + + d_ffn : + The dimension of the feed-forward network. + + factor : + The factor of the auto correlation mechanism for the FITS model. + + moving_avg_window_size : + The window size of moving average. + + dropout : + The dropout rate for the model. + + ORT_weight : + The weight for the ORT loss, the same as SAITS. + + MIT_weight : + The weight for the MIT loss, the same as SAITS. + + batch_size : + The batch size for training and evaluating the model. + + epochs : + The number of epochs for training the model. + + patience : + The patience for the early-stopping mechanism. Given a positive integer, the training process will be + stopped when the model does not perform better after that number of epochs. + Leaving it default as None will disable the early-stopping. + + optimizer : + The optimizer for model training. + If not given, will use a default Adam optimizer. + + num_workers : + The number of subprocesses to use for data loading. + `0` means data loading will be in the main process, i.e. there won't be subprocesses. + + device : + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. + If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), + then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). + Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. + + saving_path : + The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during + training into a tensorboard file). Will not save if not given. + + model_saving_strategy : + The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"]. + No model will be saved when it is set as None. + The "best" strategy will only automatically save the best model after the training finished. + The "better" strategy will automatically save the model during training whenever the model performs + better than in previous epochs. + The "all" strategy will save every model after each epoch training. + + verbose : + Whether to print out the training logs during the training process. + """ + + def __init__( + self, + n_steps: int, + n_features: int, + cut_freq: int, + individual: bool = False, + ORT_weight: float = 1, + MIT_weight: float = 1, + apply_nonstationary_norm: bool = False, + batch_size: int = 32, + epochs: int = 100, + patience: int = None, + optimizer: Optional[Optimizer] = Adam(), + num_workers: int = 0, + device: Optional[Union[str, torch.device, list]] = None, + saving_path: str = None, + model_saving_strategy: Optional[str] = "best", + verbose: bool = True, + ): + super().__init__( + batch_size, + epochs, + patience, + num_workers, + device, + saving_path, + model_saving_strategy, + verbose, + ) + + self.n_steps = n_steps + self.n_features = n_features + # model hype-parameters + self.individual = individual + self.cut_freq = cut_freq + self.ORT_weight = ORT_weight + self.MIT_weight = MIT_weight + self.apply_nonstationary_norm = apply_nonstationary_norm + + # set up the model + self.model = _FITS( + self.n_steps, + self.n_features, + self.cut_freq, + self.individual, + self.ORT_weight, + self.MIT_weight, + self.apply_nonstationary_norm, + ) + self._send_model_to_given_device() + self._print_model_size() + + # set up the optimizer + self.optimizer = optimizer + self.optimizer.init_optimizer(self.model.parameters()) + + def _assemble_input_for_training(self, data: list) -> dict: + ( + indices, + X, + missing_mask, + X_ori, + indicating_mask, + ) = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + "X_ori": X_ori, + "indicating_mask": indicating_mask, + } + + return inputs + + def _assemble_input_for_validating(self, data: list) -> dict: + return self._assemble_input_for_training(data) + + def _assemble_input_for_testing(self, data: list) -> dict: + indices, X, missing_mask = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + } + + return inputs + + def fit( + self, + train_set: Union[dict, str], + val_set: Optional[Union[dict, str]] = None, + file_type: str = "hdf5", + ) -> None: + # Step 1: wrap the input data with classes Dataset and DataLoader + training_set = DatasetForFITS(train_set, return_X_ori=False, return_y=False, file_type=file_type) + training_loader = DataLoader( + training_set, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + ) + val_loader = None + if val_set is not None: + if not key_in_data_set("X_ori", val_set): + raise ValueError("val_set must contain 'X_ori' for model validation.") + val_set = DatasetForFITS(val_set, return_X_ori=True, return_y=False, file_type=file_type) + val_loader = DataLoader( + val_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + + # Step 2: train the model and freeze it + self._train_model(training_loader, val_loader) + self.model.load_state_dict(self.best_model_dict) + self.model.eval() # set the model as eval status to freeze it. + + # Step 3: save the model if necessary + self._auto_save_model_if_necessary(confirm_saving=True) + + def predict( + self, + test_set: Union[dict, str], + file_type: str = "hdf5", + ) -> dict: + """Make predictions for the input data with the trained model. + + Parameters + ---------- + test_set : dict or str + The dataset for model validating, should be a dictionary including keys as 'X', + or a path string locating a data file supported by PyPOTS (e.g. h5 file). + If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features], + which is time-series data for validating, can contain missing values, and y should be array-like of shape + [n_samples], which is classification labels of X. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include keys as 'X' and 'y'. + + file_type : + The type of the given file if test_set is a path string. + + Returns + ------- + file_type : + The dictionary containing the clustering results and latent variables if necessary. + + """ + # Step 1: wrap the input data with classes Dataset and DataLoader + self.model.eval() # set the model as eval status to freeze it. + test_set = BaseDataset( + test_set, + return_X_ori=False, + return_X_pred=False, + return_y=False, + file_type=file_type, + ) + test_loader = DataLoader( + test_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + imputation_collector = [] + + # Step 2: process the data with the model + with torch.no_grad(): + for idx, data in enumerate(test_loader): + inputs = self._assemble_input_for_testing(data) + results = self.model.forward(inputs, training=False) + imputation_collector.append(results["imputed_data"]) + + # Step 3: output collection and return + imputation = torch.cat(imputation_collector).cpu().detach().numpy() + result_dict = { + "imputation": imputation, + } + return result_dict + + def impute( + self, + test_set: Union[dict, str], + file_type: str = "hdf5", + ) -> np.ndarray: + """Impute missing values in the given data with the trained model. + + Parameters + ---------- + test_set : + The data samples for testing, should be array-like of shape [n_samples, sequence length (n_steps), + n_features], or a path string locating a data file, e.g. h5 file. + + file_type : + The type of the given file if X is a path string. + + Returns + ------- + array-like, shape [n_samples, sequence length (n_steps), n_features], + Imputed data. + """ + + result_dict = self.predict(test_set, file_type=file_type) + return result_dict["imputation"] From 145069abc664e6bb6227252c8e3e423266be66b7 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Mon, 23 Sep 2024 17:30:16 +0800 Subject: [PATCH 3/4] test: add test cases for FITS imputation model; --- tests/imputation/fits.py | 117 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 tests/imputation/fits.py diff --git a/tests/imputation/fits.py b/tests/imputation/fits.py new file mode 100644 index 00000000..c35c942e --- /dev/null +++ b/tests/imputation/fits.py @@ -0,0 +1,117 @@ +""" +Test cases for FITS imputation model. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +import os.path +import unittest + +import numpy as np +import pytest + +from pypots.imputation.fits import FITS +from pypots.optim import Adam +from pypots.utils.logging import logger +from pypots.utils.metrics import calc_mse +from tests.global_test_config import ( + DATA, + EPOCHS, + DEVICE, + TRAIN_SET, + VAL_SET, + TEST_SET, + GENERAL_H5_TRAIN_SET_PATH, + GENERAL_H5_VAL_SET_PATH, + GENERAL_H5_TEST_SET_PATH, + RESULT_SAVING_DIR_FOR_IMPUTATION, + check_tb_and_model_checkpoints_existence, +) + + +class TestFITS(unittest.TestCase): + logger.info("Running tests for an imputation model FITS...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "FITS") + model_save_name = "saved_fits_model.pypots" + + # initialize an Adam optimizer + optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a FITS model + fits = FITS( + DATA["n_steps"], + DATA["n_features"], + individual=False, + cut_freq=5, + epochs=EPOCHS, + saving_path=saving_path, + optimizer=optimizer, + device=DEVICE, + ) + + @pytest.mark.xdist_group(name="imputation-fits") + def test_0_fit(self): + self.fits.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="imputation-fits") + def test_1_impute(self): + imputation_results = self.fits.predict(TEST_SET) + assert not np.isnan( + imputation_results["imputation"] + ).any(), "Output still has missing values after running impute()." + + test_MSE = calc_mse( + imputation_results["imputation"], + DATA["test_X_ori"], + DATA["test_X_indicating_mask"], + ) + logger.info(f"FITS test_MSE: {test_MSE}") + + @pytest.mark.xdist_group(name="imputation-fits") + def test_2_parameters(self): + assert hasattr(self.fits, "model") and self.fits.model is not None + + assert hasattr(self.fits, "optimizer") and self.fits.optimizer is not None + + assert hasattr(self.fits, "best_loss") + self.assertNotEqual(self.fits.best_loss, float("inf")) + + assert hasattr(self.fits, "best_model_dict") and self.fits.best_model_dict is not None + + @pytest.mark.xdist_group(name="imputation-fits") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists(self.saving_path), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.fits) + + # save the trained model into file, and check if the path exists + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.fits.save(saved_model_path) + + # test loading the saved model, not necessary, but need to test + self.fits.load(saved_model_path) + + @pytest.mark.xdist_group(name="imputation-fits") + def test_4_lazy_loading(self): + self.fits.fit(GENERAL_H5_TRAIN_SET_PATH, GENERAL_H5_VAL_SET_PATH) + imputation_results = self.fits.predict(GENERAL_H5_TEST_SET_PATH) + assert not np.isnan( + imputation_results["imputation"] + ).any(), "Output still has missing values after running impute()." + + test_MSE = calc_mse( + imputation_results["imputation"], + DATA["test_X_ori"], + DATA["test_X_indicating_mask"], + ) + logger.info(f"Lazy-loading FITS test_MSE: {test_MSE}") + + +if __name__ == "__main__": + unittest.main() From 903bdab29ffe23fc4afdabd0b3f9dd256929a849 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Fri, 27 Sep 2024 00:55:57 +0800 Subject: [PATCH 4/4] Update docs and configs (#530) * docs: add FITS into the algo table; * docs: update the reference of TEFN; * docs: update pytorch intersphinx mapping link; * docs: update docs to add TEFN and FITS imputation models; --- README.md | 8 ++++++-- README_zh.md | 8 ++++++-- docs/conf.py | 2 +- docs/index.rst | 2 ++ docs/pypots.imputation.rst | 18 ++++++++++++++++++ docs/references.bib | 2 +- 6 files changed, 34 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 0556fb10..47ee47fa 100644 --- a/README.md +++ b/README.md @@ -121,6 +121,7 @@ The paper references and links are all listed at the bottom of this file. |:--------------|:---------------------------------------------------------------------------------------------------------------------------------|:--------:|:--------:|:--------:|:--------:|:--------:|:-------------------| | LLM | Time-Series.AI [^36] | ✅ | ✅ | ✅ | ✅ | ✅ | `Later in 2024` | | Neural Net | TEFN🧑‍🔧[^39] | ✅ | | | | | `2024 - arXiv` | +| Neural Net | FITS🧑‍🔧[^41] | ✅ | | | | | `2024 - ICLR` | | Neural Net | TimeMixer[^37] | ✅ | | | | | `2024 - ICLR` | | Neural Net | iTransformer🧑‍🔧[^24] | ✅ | | | | | `2024 - ICLR` | | Neural Net | ModernTCN[^38] | ✅ | | | | | `2024 - ICLR` | @@ -501,7 +502,10 @@ Time-Series.AI [^38]: Luo, D., & Wang X. (2024). [ModernTCN: A Modern Pure Convolution Structure for General Time Series Analysis](https://openreview.net/forum?id=vpJMJerXHU). *ICLR 2024*. -[^39]: Zhan, T., He, Y., Li, Z., & Deng, Y. (2024). +[^39]: Zhan, T., He, Y., Deng, Y., Li, Z., Du, W., & Wen, Q. (2024). [Time Evidence Fusion Network: Multi-source View in Long-Term Time Series Forecasting](https://arxiv.org/abs/2405.06419). *arXiv 2024*. -[^40]: [Wikipedia: Linear interpolation](https://en.wikipedia.org/wiki/Linear_interpolation) \ No newline at end of file +[^40]: [Wikipedia: Linear interpolation](https://en.wikipedia.org/wiki/Linear_interpolation) +[^41]: Xu, Z., Zeng, A., & Xu, Q. (2024). +[FITS: Modeling Time Series with 10k parameters](https://openreview.net/forum?id=bWcnvZ3qMb). +*ICLR 2024*. diff --git a/README_zh.md b/README_zh.md index ef47a408..55978e01 100644 --- a/README_zh.md +++ b/README_zh.md @@ -106,6 +106,7 @@ PyPOTS当前支持多变量POTS数据的插补, 预测, 分类, 聚类以及异 |:--------------|:---------------------------------------------------------------------------------------------------------------------------------|:------:|:------:|:------:|:------:|:--------:|:-------------------| | LLM | Time-Series.AI [^36] | ✅ | ✅ | ✅ | ✅ | ✅ | `Later in 2024` | | Neural Net | TEFN🧑‍🔧[^39] | ✅ | | | | | `2024 - arXiv` | +| Neural Net | FITS🧑‍🔧[^41] | ✅ | | | | | `2024 - ICLR` | | Neural Net | TimeMixer[^37] | ✅ | | | | | `2024 - ICLR` | | Neural Net | iTransformer🧑‍🔧[^24] | ✅ | | | | | `2024 - ICLR` | | Neural Net | ModernTCN[^38] | ✅ | | | | | `2024 - ICLR` | @@ -474,7 +475,10 @@ Time-Series.AI [^38]: Luo, D., & Wang X. (2024). [ModernTCN: A Modern Pure Convolution Structure for General Time Series Analysis](https://openreview.net/forum?id=vpJMJerXHU). *ICLR 2024*. -[^39]: Zhan, T., He, Y., Li, Z., & Deng, Y. (2024). +[^39]: Zhan, T., He, Y., Deng, Y., Li, Z., Du, W., & Wen, Q. (2024). [Time Evidence Fusion Network: Multi-source View in Long-Term Time Series Forecasting](https://arxiv.org/abs/2405.06419). *arXiv 2024*. -[^40]: [Wikipedia: Linear interpolation](https://en.wikipedia.org/wiki/Linear_interpolation) \ No newline at end of file +[^40]: [Wikipedia: Linear interpolation](https://en.wikipedia.org/wiki/Linear_interpolation) +[^41]: Xu, Z., Zeng, A., & Xu, Q. (2024). +[FITS: Modeling Time Series with 10k parameters](https://openreview.net/forum?id=bWcnvZ3qMb). +*ICLR 2024*. diff --git a/docs/conf.py b/docs/conf.py index 5b8e3822..f21afad7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -62,7 +62,7 @@ intersphinx_mapping = { "python": ("https://docs.python.org/3", None), "sphinx": ("https://www.sphinx-doc.org/en/master", None), - "torch": ("https://pytorch.org/docs/master/", None), + "torch": ("https://pytorch.org/docs/main/", None), "numpy": ("https://numpy.org/doc/stable/", None), "pandas": ("https://pandas.pydata.org/docs/", None), } diff --git a/docs/index.rst b/docs/index.rst index 1d1fe36f..4aaeb762 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -135,6 +135,8 @@ The paper references are all listed at the bottom of this readme file. +================+===========================================================+======+======+======+======+======+=======================+ | Neural Net | TEFN🧑‍🔧 :cite:`zhan2024tefn` | ✅ | | | | | ``2024 - arXiv`` | +----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+ +| Neural Net | FITS🧑‍🔧 :cite:`xu2024fits` | ✅ | | | | | ``2024 - ICLR`` | ++----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+ | Neural Net | TimeMixer :cite:`wang2024timemixer` | ✅ | | | | | ``2024 - ICLR`` | +----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+ | Neural Net | iTransformer🧑‍🔧 :cite:`liu2024itransformer` | ✅ | | | | | ``2024 - ICLR`` | diff --git a/docs/pypots.imputation.rst b/docs/pypots.imputation.rst index b7a94b03..a7b47b07 100644 --- a/docs/pypots.imputation.rst +++ b/docs/pypots.imputation.rst @@ -19,6 +19,24 @@ pypots.imputation.transformer :show-inheritance: :inherited-members: +pypots.imputation.tefn +------------------------------------ + +.. automodule:: pypots.imputation.tefn + :members: + :undoc-members: + :show-inheritance: + :inherited-members: + +pypots.imputation.fits +------------------------------------ + +.. automodule:: pypots.imputation.fits + :members: + :undoc-members: + :show-inheritance: + :inherited-members: + pypots.imputation.timemixer ------------------------------------ diff --git a/docs/references.bib b/docs/references.bib index 171d63c5..ce0014ea 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -766,7 +766,7 @@ @article{bai2018tcn @article{zhan2024tefn, title={Time Evidence Fusion Network: Multi-source View in Long-Term Time Series Forecasting}, -author={Zhan, Tianxiang and He, Yuanpeng and Li, Zhen and Deng, Yong}, +author={Zhan, Tianxiang and He, Yuanpeng and Deng, Yong and Li, Zhen and Du, Wenjie and Wen, Qingsong}, journal={arXiv preprint arXiv:2405.06419}, year={2024} }