Implement Pyraformer as an imputation model #389

Merged · 3 commits · May 7, 2024
2 changes: 2 additions & 0 deletions pypots/imputation/__init__.py
@@ -14,6 +14,7 @@
from .transformer import Transformer
from .itransformer import iTransformer
from .nonstationary_transformer import NonstationaryTransformer
+from .pyraformer import Pyraformer
from .timesnet import TimesNet
from .etsformer import ETSformer
from .fedformer import FEDformer
@@ -47,6 +48,7 @@
"Informer",
"Autoformer",
"NonstationaryTransformer",
"Pyraformer",
"BRITS",
"MRNN",
"GPVAE",
2 changes: 1 addition & 1 deletion pypots/imputation/nonstationary_transformer/model.py
@@ -24,7 +24,7 @@

class NonstationaryTransformer(BaseNNImputer):
    """The PyTorch implementation of the Nonstationary-Transformer model.
-    NonstationaryTransformer is originally proposed by Wu et al. in :cite:`liu2022nonstationary`.
+    NonstationaryTransformer is originally proposed by Liu et al. in :cite:`liu2022nonstationary`.

    Parameters
    ----------
24 changes: 24 additions & 0 deletions pypots/imputation/pyraformer/__init__.py
@@ -0,0 +1,24 @@
"""
The package of the partially-observed time-series imputation model Pyraformer.

Refer to the paper
`Shizhan Liu, Hang Yu, Cong Liao, Jianguo Li, Weiyao Lin, Alex X. Liu, and Schahram Dustdar.
"Pyraformer: Low-Complexity Pyramidal Attention for Long-Range Time Series Modeling and Forecasting".
International Conference on Learning Representations. 2022.
<https://openreview.net/pdf?id=0EXmFzUn5I>`_

Notes
-----
This implementation is inspired by the official one at https://github.com/ant-research/Pyraformer

"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause


from .model import Pyraformer

__all__ = [
"Pyraformer",
]
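
For readers new to the package, here is a minimal usage sketch of the model this PR adds. The hyperparameter names mirror `_Pyraformer`'s constructor in `core.py` below; the training arguments (`epochs`) and the `predict()` output key are assumptions based on PyPOTS's usual fit/predict interface and are not shown in this diff.

```python
import numpy as np
from pypots.imputation import Pyraformer

# toy dataset: 100 samples, 24 time steps, 7 features, with NaNs marking missing values
X = np.random.randn(100, 24, 7)
X[np.random.rand(*X.shape) < 0.1] = np.nan
train_set = {"X": X}

model = Pyraformer(
    n_steps=24,
    n_features=7,
    n_layers=2,
    d_model=64,
    n_heads=4,
    d_ffn=128,
    dropout=0.1,
    attn_dropout=0.1,
    window_size=[2, 2],  # 24 -> 12 -> 6: two coarser pyramid scales
    inner_size=3,
    epochs=10,  # assumed training argument from the PyPOTS base class
)
model.fit(train_set)
imputation = model.predict(train_set)["imputation"]  # output key assumed from the PyPOTS API
```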
86 changes: 86 additions & 0 deletions pypots/imputation/pyraformer/core.py
@@ -0,0 +1,86 @@
"""
The core wrapper assembles the submodules of the Pyraformer imputation model
and takes over the forward process of the algorithm.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import torch.nn as nn

from ...nn.modules.pyraformer import PyraformerEncoder
from ...nn.modules.saits import SaitsLoss, SaitsEmbedding


class _Pyraformer(nn.Module):
    def __init__(
        self,
        n_steps: int,
        n_features: int,
        n_layers: int,
        d_model: int,
        n_heads: int,
        d_ffn: int,
        dropout: float,
        attn_dropout: float,
        window_size: list,
        inner_size: int,
        ORT_weight: float = 1,
        MIT_weight: float = 1,
    ):
        super().__init__()

        self.saits_embedding = SaitsEmbedding(
            n_features * 2,
            d_model,
            with_pos=False,
            dropout=dropout,
        )
        self.encoder = PyraformerEncoder(
            n_steps,
            n_layers,
            d_model,
            n_heads,
            d_ffn,
            dropout,
            attn_dropout,
            window_size,
            inner_size,
        )

        # for the imputation task, the output dim is the same as the input dim;
        # the encoder concatenates the representations of all pyramid scales
        # (the finest scale plus one coarser scale per entry in window_size),
        # hence the (len(window_size) + 1) * d_model input width
        self.output_projection = nn.Linear((len(window_size) + 1) * d_model, n_features)
        self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight)

    def forward(self, inputs: dict, training: bool = True) -> dict:
        X, missing_mask = inputs["X"], inputs["missing_mask"]

        # WDU: the original Pyraformer paper doesn't propose the model for the imputation
        # task, so the model doesn't take the missing mask into account, which means that,
        # in the process, the model doesn't know which parts of the input data are missing,
        # and this may hurt its imputation performance. Therefore, I apply the SAITS
        # embedding method to project the concatenation of features and masks into a hidden
        # space, as well as the output layers to project back from the hidden space to the
        # original space.
        enc_out = self.saits_embedding(X, missing_mask)

        # Pyraformer encoder processing
        enc_out, attns = self.encoder(enc_out)
        # project back to the original data space
        reconstruction = self.output_projection(enc_out)

        imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction
        results = {
            "imputed_data": imputed_data,
        }

        # if in training mode, return results with losses
        if training:
            X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"]
            loss, ORT_loss, MIT_loss = self.saits_loss_func(
                reconstruction, X_ori, missing_mask, indicating_mask
            )
            results["ORT_loss"] = ORT_loss
            results["MIT_loss"] = MIT_loss
            # `loss` is always the item for backward propagation to update the model
            results["loss"] = loss

        return results
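
The two ideas this wrapper adds around the plain Pyraformer encoder are easy to see in isolation: the SAITS-style embedding projects the concatenation [X, mask] into the hidden space so the model knows which values are observed, and the final blend keeps observed values untouched while filling only the missing positions. A minimal sketch with toy tensors (a plain `nn.Linear` stands in for `SaitsEmbedding`, and a random tensor for the encoder/projection output; names and shapes are illustrative):

```python
import torch
import torch.nn as nn

batch, n_steps, n_features, d_model = 2, 24, 7, 64

missing_mask = torch.bernoulli(torch.full((batch, n_steps, n_features), 0.9))  # 1 = observed
X = torch.randn(batch, n_steps, n_features) * missing_mask  # missing entries zeroed out

# SAITS-style embedding: project [X, mask] (2 * n_features wide) into the hidden space
embedding = nn.Linear(n_features * 2, d_model)
enc_in = embedding(torch.cat([X, missing_mask], dim=-1))  # (batch, n_steps, d_model)

# ... the Pyraformer encoder would transform enc_in here ...
reconstruction = torch.randn(batch, n_steps, n_features)  # stand-in for the projected output

# keep observed values as-is; take the model's reconstruction only where data is missing
imputed = missing_mask * X + (1 - missing_mask) * reconstruction
```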
24 changes: 24 additions & 0 deletions pypots/imputation/pyraformer/data.py
@@ -0,0 +1,24 @@
"""
Dataset class for Pyraformer.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from typing import Union

from ..saits.data import DatasetForSAITS


class DatasetForPyraformer(DatasetForSAITS):
    """Pyraformer uses the same data strategy as SAITS: it needs the masked imputation task (MIT) for training."""

    def __init__(
        self,
        data: Union[dict, str],
        return_X_ori: bool,
        return_y: bool,
        file_type: str = "hdf5",
        rate: float = 0.2,
    ):
        super().__init__(data, return_X_ori, return_y, file_type, rate)
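
For context, `DatasetForSAITS` implements the masked imputation task: a fraction `rate` of the observed values is artificially re-masked during training, so the MIT loss can score the model on held-out ground truth (the `indicating_mask` consumed in `core.py`'s forward pass). A rough sketch of the idea; the real implementation lives in pypots/imputation/saits/data.py and differs in details:

```python
import numpy as np

def mit_masking(X_ori: np.ndarray, rate: float = 0.2):
    """Hold out `rate` of the observed values to build a masked-imputation training sample."""
    observed = ~np.isnan(X_ori)
    # randomly pick observed entries to hide from the model
    artificially_masked = observed & (np.random.rand(*X_ori.shape) < rate)
    X = X_ori.copy()
    X[artificially_masked] = np.nan
    missing_mask = (~np.isnan(X)).astype(np.float32)          # what the model can see
    indicating_mask = artificially_masked.astype(np.float32)  # where the MIT loss is computed
    return X, missing_mask, indicating_mask
```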