Fixing unstable nonstationary norm, adding utils.visual, and doing some code refactoring #266

Merged
merged 7 commits on Dec 14, 2023
7 changes: 7 additions & 0 deletions docs/references.bib
@@ -466,3 +466,10 @@ @inproceedings{wu2023timesnet
year={2023},
url={https://openreview.net/forum?id=ju_Uqw384Oq}
}

@inproceedings{liu2022nonstationary,
title={Non-stationary Transformers: Exploring the Stationarity in Time Series Forecasting},
author={Liu, Yong and Wu, Haixu and Wang, Jianmin and Long, Mingsheng},
booktitle={Advances in Neural Information Processing Systems},
year={2022}
}
2 changes: 1 addition & 1 deletion pypots/classification/grud/modules/core.py
@@ -16,7 +16,7 @@
import torch.nn as nn
import torch.nn.functional as F

from ....modules.rnn import TemporalDecay
from ....nn.modules.rnn import TemporalDecay


class _GRUD(nn.Module):
2 changes: 1 addition & 1 deletion pypots/imputation/brits/modules/core.py
@@ -20,7 +20,7 @@
import torch.nn as nn

from .submodules import FeatureRegression
from ....modules.rnn import TemporalDecay
from ....nn.modules.rnn import TemporalDecay
from ....utils.metrics import calc_mae


2 changes: 1 addition & 1 deletion pypots/imputation/saits/modules/core.py
@@ -19,7 +19,7 @@
import torch.nn as nn
import torch.nn.functional as F

from ....modules.transformer import EncoderLayer, PositionalEncoding
from ....nn.modules.transformer import EncoderLayer, PositionalEncoding
from ....utils.metrics import calc_mae


10 changes: 9 additions & 1 deletion pypots/imputation/timesnet/model.py
@@ -21,12 +21,12 @@
from torch.utils.data import DataLoader

from .data import DatasetForTimesNet
from ...utils.logging import logger
from .modules.core import _TimesNet
from ..base import BaseNNImputer
from ...data.base import BaseDataset
from ...optim.adam import Adam
from ...optim.base import Optimizer
from ...utils.logging import logger


class TimesNet(BaseNNImputer):
@@ -59,6 +59,11 @@ class TimesNet(BaseNNImputer):
dropout :
The dropout rate for the model.

apply_nonstationary_norm :
        Whether to apply non-stationary normalization to the input data for TimesNet.
        Please refer to :cite:`liu2022nonstationary` for details about non-stationary normalization.
        Note that it is not part of the original TimesNet paper, so it is optional here and disabled by default.

batch_size :
The batch size for training and evaluating the model.

@@ -117,6 +122,7 @@ def __init__(
d_ffn: int,
n_kernels: int,
dropout: float = 0,
apply_nonstationary_norm: bool = False,
batch_size: int = 32,
epochs: int = 100,
patience: int = None,
@@ -145,6 +151,7 @@ def __init__(
self.d_ffn = d_ffn
self.n_kernels = n_kernels
self.dropout = dropout
self.apply_nonstationary_norm = apply_nonstationary_norm

# set up the model
self.model = _TimesNet(
@@ -156,6 +163,7 @@
self.d_ffn,
self.n_kernels,
self.dropout,
self.apply_nonstationary_norm,
)
self._send_model_to_given_device()
self._print_model_size()
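A minimal usage sketch of the new `apply_nonstationary_norm` switch added above. The toy data and hyperparameter values are placeholders, and the constructor arguments that do not appear in this hunk (`n_steps`, `n_features`, `n_layers`, `top_k`, `d_model`) are assumed to keep their existing names:

```python
import numpy as np

from pypots.imputation import TimesNet

# toy training set: PyPOTS imputation models expect {"X": (n_samples, n_steps, n_features)}
# with NaNs marking the missing values
X = np.random.randn(100, 48, 5)
X[X < -1.5] = np.nan
train_set = {"X": X}

model = TimesNet(
    n_steps=48,
    n_features=5,
    n_layers=2,
    top_k=3,
    d_model=64,
    d_ffn=128,
    n_kernels=5,
    dropout=0.1,
    apply_nonstationary_norm=True,  # new optional switch, defaults to False
    epochs=10,
)
model.fit(train_set)
imputed = model.impute(train_set)
```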
31 changes: 10 additions & 21 deletions pypots/imputation/timesnet/modules/core.py
@@ -5,12 +5,11 @@
# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

import torch
import torch.fft
import torch.nn as nn

from .embedding import DataEmbedding
from .layer import TimesBlock
from ....nn.functional import nonstationary_norm, nonstationary_denorm
from ....utils.metrics import calc_mse


@@ -25,11 +24,13 @@ def __init__(
d_ffn,
n_kernels,
dropout,
apply_nonstationary_norm,
):
super().__init__()

self.seq_len = n_steps
self.n_layers = n_layers
self.apply_nonstationary_norm = apply_nonstationary_norm

self.pred_len = 0 # for the imputation task, the pred_len is always 0
self.model = nn.ModuleList(
@@ -52,36 +53,24 @@ def __init__(
def forward(self, inputs: dict, training: bool = True) -> dict:
X, masks = inputs["X"], inputs["missing_mask"]

# Normalization from Non-stationary Transformer
means = torch.sum(X, dim=1) / torch.sum(masks == 1, dim=1)
means = means.unsqueeze(1).detach()
x_enc = X - means
x_enc = x_enc.masked_fill(masks == 0, 0)
stdev = torch.sqrt(
torch.sum(x_enc * x_enc, dim=1) / torch.sum(masks == 1, dim=1) + 1e-5
)
stdev = stdev.unsqueeze(1).detach()
x_enc /= stdev
if self.apply_nonstationary_norm:
# Normalization from Non-stationary Transformer
X, means, stdev = nonstationary_norm(X, masks)

# embedding
enc_out = self.enc_embedding(x_enc) # [B,T,C]
enc_out = self.enc_embedding(X) # [B,T,C]
# TimesNet
for i in range(self.n_layers):
enc_out = self.layer_norm(self.model[i](enc_out))

# project back the original data space
dec_out = self.projection(enc_out)

# De-Normalization from Non-stationary Transformer
dec_out = dec_out * (
stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len + self.seq_len, 1)
)
dec_out = dec_out + (
means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len + self.seq_len, 1)
)
if self.apply_nonstationary_norm:
# De-Normalization from Non-stationary Transformer
dec_out = nonstationary_denorm(dec_out, means, stdev)

imputed_data = masks * X + (1 - masks) * dec_out

results = {
"imputed_data": imputed_data,
}
2 changes: 1 addition & 1 deletion pypots/imputation/timesnet/modules/embedding.py
@@ -11,7 +11,7 @@
import torch.fft
import torch.nn as nn

from ....modules.transformer import PositionalEncoding
from ....nn.modules.transformer import PositionalEncoding


class TokenEmbedding(nn.Module):
2 changes: 1 addition & 1 deletion pypots/imputation/transformer/modules/core.py
@@ -18,7 +18,7 @@
import torch
import torch.nn as nn

from ....modules.transformer import EncoderLayer, PositionalEncoding
from ....nn.modules.transformer import EncoderLayer, PositionalEncoding
from ....utils.metrics import calc_mae


10 changes: 9 additions & 1 deletion pypots/modules/__init__.py
@@ -1,6 +1,14 @@
"""
Frequently-used modules like self-attention modules of vanilla Transformer are put in this package.
Everything that used to be in this package has been moved to pypots.nn.modules.
This package is kept for backward compatibility and will be removed in the future.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from ..utils.logging import logger

logger.warning(
"🚨 pypots.modules package has been moved to pypots.nn.modules. "
"Please import everything from pypots.nn.modules instead."
)
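For downstream code, the migration this stub points to is only an import-path change; the example below uses the paths taken from the Transformer/SAITS hunks in this PR:

```python
# import path before this PR (the modules now live under pypots.nn.modules)
from pypots.modules.transformer import EncoderLayer, PositionalEncoding

# import path after this PR
from pypots.nn.modules.transformer import EncoderLayer, PositionalEncoding
```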
6 changes: 6 additions & 0 deletions pypots/nn/__init__.py
@@ -0,0 +1,6 @@
"""

"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause
14 changes: 14 additions & 0 deletions pypots/nn/functional/__init__.py
@@ -0,0 +1,14 @@
"""

"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from .normalization import nonstationary_norm, nonstationary_denorm

__all__ = [
# normalization functions
"nonstationary_norm",
"nonstationary_denorm",
]
98 changes: 98 additions & 0 deletions pypots/nn/functional/normalization.py
@@ -0,0 +1,98 @@
"""
Store normalization functions for neural networks.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause


from typing import Tuple, Optional

import torch


def nonstationary_norm(
X: torch.Tensor,
missing_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Normalization from Non-stationary Transformer. Please refer to :cite:`liu2022nonstationary` for more details.

Parameters
----------
X : torch.Tensor
Input data to be normalized. Shape: (n_samples, n_steps (seq_len), n_features).

missing_mask : torch.Tensor, optional
Missing mask has the same shape as X. 1 indicates observed and 0 indicates missing.

Returns
-------
X_enc : torch.Tensor
Normalized data. Shape: (n_samples, n_steps (seq_len), n_features).

means : torch.Tensor
        Mean values for de-normalization. Shape: (n_samples, n_features) or (n_samples, 1, n_features).

stdev : torch.Tensor
Standard deviation values for de-normalization. Shape: (n_samples, n_features) or (n_samples, 1, n_features).

"""
    if torch.isnan(X).any():
        if missing_mask is None:
            # build the mask from the NaN positions (1 = observed, 0 = missing) and zero out the NaNs
            missing_mask = (~torch.isnan(X)).to(X.dtype)
            X = torch.nan_to_num(X)
        else:
            raise ValueError("missing_mask is given but X still contains nan values.")

if missing_mask is None:
means = X.mean(1, keepdim=True).detach()
X_enc = X - means
variance = torch.var(X_enc, dim=1, keepdim=True, unbiased=False) + 1e-9
stdev = torch.sqrt(variance).detach()
else:
        # for data containing missing values, add a small number to avoid dividing by 0
missing_sum = torch.sum(missing_mask == 1, dim=1, keepdim=True) + 1e-9
means = torch.sum(X, dim=1, keepdim=True) / missing_sum
X_enc = X - means
X_enc = X_enc.masked_fill(missing_mask == 0, 0)
variance = torch.sum(X_enc * X_enc, dim=1, keepdim=True) + 1e-9
stdev = torch.sqrt(variance / missing_sum)

X_enc /= stdev
return X_enc, means, stdev


def nonstationary_denorm(
X: torch.Tensor,
means: torch.Tensor,
stdev: torch.Tensor,
) -> torch.Tensor:
"""De-Normalization from Non-stationary Transformer. Please refer to :cite:`liu2022nonstationary` for more details.

Parameters
----------
X : torch.Tensor
Input data to be de-normalized. Shape: (n_samples, n_steps (seq_len), n_features).

means : torch.Tensor
        Mean values for de-normalization. Shape: (n_samples, n_features) or (n_samples, 1, n_features).

stdev : torch.Tensor
Standard deviation values for de-normalization. Shape: (n_samples, n_features) or (n_samples, 1, n_features).

Returns
-------
X_denorm : torch.Tensor
De-normalized data. Shape: (n_samples, n_steps (seq_len), n_features).

"""
assert (
len(X) == len(means) == len(stdev)
), "Input data and normalization parameters should have the same number of samples."
if len(means.shape) == 2:
means = means.unsqueeze(1)
if len(stdev.shape) == 2:
stdev = stdev.unsqueeze(1)

X = X * stdev # (stdev.repeat(1, n_steps, 1))
X = X + means # (means.repeat(1, n_steps, 1))
return X
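A round-trip sketch of the two helpers defined above; the tensor shapes and the random mask are illustrative only:

```python
import torch

from pypots.nn.functional import nonstationary_norm, nonstationary_denorm

X = torch.randn(8, 48, 5)                            # (n_samples, n_steps, n_features)
missing_mask = (torch.rand_like(X) > 0.2).float()    # 1 = observed, 0 = missing
X = X * missing_mask                                 # zero out the entries treated as missing

X_enc, means, stdev = nonstationary_norm(X, missing_mask)
# ... feed X_enc through the model ...
X_dec = nonstationary_denorm(X_enc, means, stdev)    # back to the original scale

assert X_dec.shape == X.shape
```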
6 changes: 6 additions & 0 deletions pypots/nn/modules/__init__.py
@@ -0,0 +1,6 @@
"""
Frequently-used modules like self-attention modules of vanilla Transformer are put in this package.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause
File renamed without changes.
3 changes: 2 additions & 1 deletion pypots/utils/__init__.py
@@ -8,8 +8,9 @@

__all__ = [
# content files in this package
"file.py",
"file",
"logging",
"metrics",
"random",
"visual",
]