-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
Copy pathPatchTST.py
227 lines (197 loc) · 8.86 KB
/
PatchTST.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
import torch
from torch import nn
from layers.Transformer_EncDec import Encoder, EncoderLayer
from layers.SelfAttention_Family import FullAttention, AttentionLayer
from layers.Embed import PatchEmbedding
class Transpose(nn.Module):
    """Module form of ``Tensor.transpose`` so it can sit inside ``nn.Sequential``.

    *dims: the dimension indices forwarded unchanged to ``Tensor.transpose``
    contiguous: bool, if true the transposed result is passed through
        ``.contiguous()`` before being returned
    """

    def __init__(self, *dims, contiguous=False):
        super().__init__()
        self.dims = dims
        self.contiguous = contiguous

    def forward(self, x):
        """Return ``x`` with the configured dimensions swapped."""
        swapped = x.transpose(*self.dims)
        return swapped.contiguous() if self.contiguous else swapped
class FlattenHead(nn.Module):
    """Prediction head: flatten the last two dims, project linearly, apply dropout.

    n_vars: stored on the instance as ``self.n_vars``; not used by ``forward``
    nf: int, input feature size of the linear layer (size of the flattened
        last two dimensions)
    target_window: int, output feature size of the linear layer
    head_dropout: dropout probability applied after the linear layer
    """

    def __init__(self, n_vars, nf, target_window, head_dropout=0):
        super().__init__()
        self.n_vars = n_vars
        # Submodules are created in flatten -> linear -> dropout order, which
        # is also the order they are applied in ``forward``.
        self.flatten = nn.Flatten(start_dim=-2)
        self.linear = nn.Linear(nf, target_window)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):  # x: [bs x nvars x d_model x patch_num]
        """Run ``x`` through flatten, linear and dropout, in that order."""
        out = x
        for stage in (self.flatten, self.linear, self.dropout):
            out = stage(out)
        return out
class Model(nn.Module):
    """
    PatchTST model with task-specific heads.

    Paper link: https://arxiv.org/pdf/2211.14730.pdf

    Every task shares the same pipeline: normalise each series along the time
    axis (dim 1), move the variable axis in front of the time axis, run the
    patch embedding and the Transformer encoder, then apply a task head.
    Which head is built and which branch ``forward`` takes is selected by
    ``configs.task_name``.
    """

    def __init__(self, configs, patch_len=16, stride=8):
        """
        configs: object providing task_name, seq_len, pred_len, d_model,
            dropout, factor, n_heads, d_ff, activation, e_layers, enc_in and
            (classification only) num_class
        patch_len: int, patch len for patch_embedding
        stride: int, stride for patch_embedding
        """
        super().__init__()
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        padding = stride

        # patching and embedding
        self.patch_embedding = PatchEmbedding(
            configs.d_model, patch_len, stride, padding, configs.dropout)

        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        FullAttention(False, configs.factor, attention_dropout=configs.dropout,
                                      output_attention=False), configs.d_model, configs.n_heads),
                    configs.d_model,
                    configs.d_ff,
                    dropout=configs.dropout,
                    activation=configs.activation
                ) for _ in range(configs.e_layers)
            ],
            norm_layer=nn.Sequential(Transpose(1, 2), nn.BatchNorm1d(configs.d_model), Transpose(1, 2))
        )

        # Prediction Head
        # head_nf = d_model * number of patches; the patch count formula must
        # agree with what the patch embedding produces for this
        # seq_len / patch_len / stride.
        self.head_nf = configs.d_model * \
            int((configs.seq_len - patch_len) / stride + 2)
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            self.head = FlattenHead(configs.enc_in, self.head_nf, configs.pred_len,
                                    head_dropout=configs.dropout)
        elif self.task_name == 'imputation' or self.task_name == 'anomaly_detection':
            self.head = FlattenHead(configs.enc_in, self.head_nf, configs.seq_len,
                                    head_dropout=configs.dropout)
        elif self.task_name == 'classification':
            self.flatten = nn.Flatten(start_dim=-2)
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(
                self.head_nf * configs.enc_in, configs.num_class)

    @staticmethod
    def _normalize(x_enc):
        """Normalise ``x_enc`` along dim 1; return (normalised, means, stdev).

        Normalization from Non-stationary Transformer. ``means`` is detached,
        ``stdev`` is not; both keep dim 1 with size 1 so they broadcast
        against any tensor shaped like ``x_enc``.
        """
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(
            torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        # Out-of-place on purpose: torch.var keeps the centred tensor for its
        # backward pass, so an in-place `/=` would invalidate it and make
        # backward() fail whenever the input requires grad.
        x_enc = x_enc / stdev
        return x_enc, means, stdev

    def _encode(self, x_enc):
        """Patch-embed and encode ``x_enc``; return [bs x nvars x d_model x patch_num]."""
        # do patching and embedding
        x_enc = x_enc.permute(0, 2, 1)
        # u: [bs * nvars x patch_num x d_model]
        enc_out, n_vars = self.patch_embedding(x_enc)

        # Encoder
        # z: [bs * nvars x patch_num x d_model]
        enc_out, _ = self.encoder(enc_out)  # second return value is unused
        # z: [bs x nvars x patch_num x d_model]
        enc_out = torch.reshape(
            enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1]))
        # z: [bs x nvars x d_model x patch_num]
        return enc_out.permute(0, 1, 3, 2)

    def _decode(self, enc_out, means, stdev):
        """Apply ``self.head`` and undo the normalisation with ``means``/``stdev``."""
        dec_out = self.head(enc_out)  # z: [bs x nvars x target_window]
        dec_out = dec_out.permute(0, 2, 1)
        # De-Normalization from Non-stationary Transformer.
        # means/stdev have size 1 on dim 1, so broadcasting gives the same
        # values as explicitly repeating them along the time axis.
        return dec_out * stdev + means

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Forecast from ``x_enc``; the other arguments are accepted but unused."""
        x_enc, means, stdev = self._normalize(x_enc)
        return self._decode(self._encode(x_enc), means, stdev)

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        """Reconstruct ``x_enc`` using statistics of the observed positions.

        mask: tensor broadcast-compatible with ``x_enc``; positions equal to 1
            count as observed, positions equal to 0 are zeroed after centring.
        x_mark_enc, x_dec and x_mark_dec are accepted but unused.
        """
        # Normalization from Non-stationary Transformer (masked variant).
        # Clamp the observed count to >= 1 so a series with no observed steps
        # yields zero statistics instead of 0/0 = NaN.
        counts = torch.sum(mask == 1, dim=1).clamp(min=1)
        means = torch.sum(x_enc, dim=1) / counts
        means = means.unsqueeze(1).detach()
        x_enc = x_enc - means
        x_enc = x_enc.masked_fill(mask == 0, 0)
        stdev = torch.sqrt(torch.sum(x_enc * x_enc, dim=1) / counts + 1e-5)
        stdev = stdev.unsqueeze(1).detach()
        x_enc = x_enc / stdev
        return self._decode(self._encode(x_enc), means, stdev)

    def anomaly_detection(self, x_enc):
        """Reconstruct ``x_enc`` through the encoder and the head."""
        x_enc, means, stdev = self._normalize(x_enc)
        return self._decode(self._encode(x_enc), means, stdev)

    def classification(self, x_enc, x_mark_enc):
        """Return class scores for ``x_enc``; ``x_mark_enc`` is accepted but unused."""
        # Only the normalised series is needed; there is no de-normalisation.
        x_enc, _, _ = self._normalize(x_enc)
        enc_out = self._encode(x_enc)

        # Decoder
        output = self.flatten(enc_out)
        output = self.dropout(output)
        output = output.reshape(output.shape[0], -1)
        output = self.projection(output)  # (batch_size, num_classes)
        return output

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Dispatch on ``self.task_name``; return None for an unrecognised task.

        mask is only read by the imputation task.
        """
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'imputation':
            dec_out = self.imputation(
                x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
            return dec_out  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            dec_out = self.anomaly_detection(x_enc)
            return dec_out  # [B, L, D]
        if self.task_name == 'classification':
            dec_out = self.classification(x_enc, x_mark_enc)
            return dec_out  # [B, N]
        return None