Adding the package pypots/modules, and appending modules of vanilla Transformer #173

Merged · 8 commits · Aug 20, 2023
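In short, the self-attention building blocks that previously lived in pypots/imputation/transformer/modules.py move into a new shared package, pypots/modules, so that SAITS and the vanilla Transformer imputer import them from one place. A minimal before/after sketch of the import change (paths resolved from the relative imports in the diff below):

# Before this PR, SAITS reached into the Transformer imputer's package:
# from pypots.imputation.transformer.modules import EncoderLayer, PositionalEncoding

# After this PR, both models import from the shared modules package:
from pypots.modules.self_attention import EncoderLayer, PositionalEncoding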
69 changes: 44 additions & 25 deletions pypots/imputation/saits/model.py
@@ -24,8 +24,8 @@

from .data import DatasetForSAITS
from ..base import BaseNNImputer
from ..transformer.modules import EncoderLayer, PositionalEncoding
from ...data.base import BaseDataset
from ...modules.self_attention import EncoderLayer, PositionalEncoding
from ...optim.adam import Adam
from ...optim.base import Optimizer
from ...utils.metrics import cal_mae
@@ -35,8 +35,8 @@ class _SAITS(nn.Module):
def __init__(
self,
n_layers: int,
d_time: int,
d_feature: int,
n_steps: int,
n_features: int,
d_model: int,
d_inner: int,
n_heads: int,
@@ -50,67 +50,69 @@ def __init__(
):
super().__init__()
self.n_layers = n_layers
actual_d_feature = d_feature * 2
self.n_steps = n_steps
# concatenate the feature vector and missing mask, hence double the number of features
actual_n_features = n_features * 2
self.diagonal_attention_mask = diagonal_attention_mask
self.ORT_weight = ORT_weight
self.MIT_weight = MIT_weight

self.layer_stack_for_first_block = nn.ModuleList(
[
EncoderLayer(
d_time,
actual_d_feature,
d_model,
d_inner,
n_heads,
d_k,
d_v,
dropout,
attn_dropout,
diagonal_attention_mask,
)
for _ in range(n_layers)
]
)
self.layer_stack_for_second_block = nn.ModuleList(
[
EncoderLayer(
d_time,
actual_d_feature,
d_model,
d_inner,
n_heads,
d_k,
d_v,
dropout,
attn_dropout,
diagonal_attention_mask,
)
for _ in range(n_layers)
]
)

self.dropout = nn.Dropout(p=dropout)
self.position_enc = PositionalEncoding(d_model, n_position=d_time)
# for operation on time dim
self.embedding_1 = nn.Linear(actual_d_feature, d_model)
self.reduce_dim_z = nn.Linear(d_model, d_feature)
# for operation on measurement dim
self.embedding_2 = nn.Linear(actual_d_feature, d_model)
self.reduce_dim_beta = nn.Linear(d_model, d_feature)
self.reduce_dim_gamma = nn.Linear(d_feature, d_feature)
self.position_enc = PositionalEncoding(d_model, n_position=n_steps)
# for the 1st block
self.embedding_1 = nn.Linear(actual_n_features, d_model)
self.reduce_dim_z = nn.Linear(d_model, n_features)
# for the 2nd block
self.embedding_2 = nn.Linear(actual_n_features, d_model)
self.reduce_dim_beta = nn.Linear(d_model, n_features)
self.reduce_dim_gamma = nn.Linear(n_features, n_features)
# for delta decay factor
self.weight_combine = nn.Linear(d_feature + d_time, d_feature)
self.weight_combine = nn.Linear(n_features + n_steps, n_features)

def _process(self, inputs: dict) -> Tuple[torch.Tensor, list]:
def _process(
self,
inputs: dict,
diagonal_attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, list]:
X, masks = inputs["X"], inputs["missing_mask"]

# first DMSA block
input_X_for_first = torch.cat([X, masks], dim=2)
input_X_for_first = self.embedding_1(input_X_for_first)
enc_output = self.dropout(
self.position_enc(input_X_for_first)
) # namely, term e in the math equation
for encoder_layer in self.layer_stack_for_first_block:
enc_output, _ = encoder_layer(enc_output)
enc_output, _ = encoder_layer(enc_output, diagonal_attention_mask)

X_tilde_1 = self.reduce_dim_z(enc_output)
X_prime = masks * X + (1 - masks) * X_tilde_1
@@ -146,9 +148,23 @@ def _process(self, inputs: dict) -> Tuple[torch.Tensor, list]:

return X_c, [X_tilde_1, X_tilde_2, X_tilde_3]

def forward(self, inputs: dict, training: bool = True) -> dict:
def forward(
self, inputs: dict, diagonal_attention_mask: bool = False, training: bool = True
) -> dict:
X, masks = inputs["X"], inputs["missing_mask"]
imputed_data, [X_tilde_1, X_tilde_2, X_tilde_3] = self._process(inputs)

if (training and self.diagonal_attention_mask) or (
(not training) and diagonal_attention_mask
):
diagonal_attention_mask = torch.eye(self.n_steps).to(X.device)
# then broadcast on the batch axis
diagonal_attention_mask = diagonal_attention_mask.unsqueeze(0)
else:
diagonal_attention_mask = None

imputed_data, [X_tilde_1, X_tilde_2, X_tilde_3] = self._process(
inputs, diagonal_attention_mask
)

if not training:
# if not in training mode, return the imputation result only
@@ -427,7 +443,8 @@ def fit(
def impute(
self,
X: Union[dict, str],
file_type="h5py",
file_type: str = "h5py",
diagonal_attention_mask: bool = True,
) -> np.ndarray:
# Step 1: wrap the input data with classes Dataset and DataLoader
self.model.eval() # set the model as eval status to freeze it.
@@ -444,7 +461,9 @@ def impute(
with torch.no_grad():
for idx, data in enumerate(test_loader):
inputs = self._assemble_input_for_testing(data)
results = self.model.forward(inputs, training=False)
results = self.model.forward(
inputs, diagonal_attention_mask, training=False
)
imputed_data = results["imputed_data"]
imputation_collector.append(imputed_data)

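For readers skimming the diff above, a minimal sketch of the two ideas the renaming clarifies in _SAITS: the input features are doubled because X is concatenated with its missing mask, and the optional diagonal attention mask is just an identity matrix over the time steps, broadcast across the batch. The shapes below are illustrative only:

import torch

batch_size, n_steps, n_features = 32, 4, 5

# X and its missing mask are concatenated on the feature axis,
# which is why actual_n_features = n_features * 2 in __init__ above.
X = torch.randn(batch_size, n_steps, n_features)
missing_mask = torch.ones(batch_size, n_steps, n_features)
input_X = torch.cat([X, missing_mask], dim=2)               # shape: [32, 4, 10]

# The diagonal attention mask built in forward(): an identity matrix over
# the n_steps time steps, with a leading axis added so it broadcasts over
# the batch dimension when passed to each EncoderLayer.
diagonal_attention_mask = torch.eye(n_steps).unsqueeze(0)   # shape: [1, 4, 4]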
5 changes: 1 addition & 4 deletions pypots/imputation/transformer/model.py
@@ -22,9 +22,9 @@
from torch.utils.data import DataLoader

from .data import DatasetForSAITS
from .modules import EncoderLayer, PositionalEncoding
from ..base import BaseNNImputer
from ...data.base import BaseDataset
from ...modules.self_attention import EncoderLayer, PositionalEncoding
from ...optim.adam import Adam
from ...optim.base import Optimizer
from ...utils.metrics import cal_mae
@@ -55,16 +55,13 @@ def __init__(
self.layer_stack = nn.ModuleList(
[
EncoderLayer(
d_time,
actual_d_feature,
d_model,
d_inner,
n_heads,
d_k,
d_v,
dropout,
attn_dropout,
False,
)
for _ in range(n_layers)
]
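A rough usage sketch of the shared layer after this change. The internals of pypots/modules/self_attention.py are not shown in this diff, so the constructor and forward signatures below are inferred from the call sites above (EncoderLayer no longer takes d_time, actual_d_feature, or a mask flag at construction; the mask is passed at call time) and may not match the real module exactly:

import torch
from pypots.modules.self_attention import EncoderLayer, PositionalEncoding

# Hypothetical hyperparameters, mirroring the call sites in the diffs above.
d_model, d_inner, n_heads, d_k, d_v = 256, 128, 4, 64, 64
dropout, attn_dropout = 0.1, 0.0
batch_size, n_steps = 8, 48

layer = EncoderLayer(d_model, d_inner, n_heads, d_k, d_v, dropout, attn_dropout)
position_enc = PositionalEncoding(d_model, n_position=n_steps)

x = torch.randn(batch_size, n_steps, d_model)
mask = None  # or a [1, n_steps, n_steps] diagonal mask, as SAITS builds in forward()
enc_output, attn_weights = layer(position_enc(x), mask)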
188 changes: 0 additions & 188 deletions pypots/imputation/transformer/modules.py

This file was deleted.

6 changes: 6 additions & 0 deletions pypots/modules/__init__.py
@@ -0,0 +1,6 @@
"""
Frequently-used modules, such as the self-attention modules of the vanilla Transformer, are put in this package.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: GPL-v3