remove shifts
malodetz committed Jul 9, 2023
1 parent 617d20f commit 7bf70f0
Showing 2 changed files with 11 additions and 16 deletions.
22 changes: 8 additions & 14 deletions etna/models/nn/patchts.py
@@ -112,7 +112,7 @@ def __init__(
PositionalEncoding(d_model=self.hidden_size),
nn.TransformerEncoder(encoder_layers, self.num_layers)
)
- self.max_patch_num = (encoder_length - 1 - self.patch_len) // self.stride + 1
+ self.max_patch_num = (encoder_length - self.patch_len) // self.stride + 1
self.projection = nn.Sequential(nn.Flatten(start_dim=-2),
nn.Linear(in_features=self.hidden_size * self.max_patch_num, out_features=1))
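A quick arithmetic check of the new patch-count formula (the sizes encoder_length=14, patch_len=4, stride=2 are illustrative assumptions, not values from this commit); with the shifted target gone, the encoder covers the full encoder_length steps rather than encoder_length - 1:

encoder_length, patch_len, stride = 14, 4, 2
old_max_patch_num = (encoder_length - 1 - patch_len) // stride + 1  # 5 patches over encoder_length - 1 steps
new_max_patch_num = (encoder_length - patch_len) // stride + 1      # 6 patches over the full encoder_length
assert (old_max_patch_num, new_max_patch_num) == (5, 6)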

@@ -132,15 +132,13 @@ def forward(self, x: PatchTSBatch, *args, **kwargs): # type: ignore
:
forecast with shape (batch_size, decoder_length, 1)
"""
- print(x)
encoder_real = x["encoder_real"].float() # (batch_size, encoder_length-1, input_size)
decoder_real = x["decoder_real"].float() # (batch_size, decoder_length, input_size)
decoder_length = decoder_real.shape[1]
outputs = []
x = encoder_real
for i in range(decoder_length):
pred = self._get_prediction(x)
- print(pred)
outputs.append(pred)
x = torch.cat((x[:, 1:, :], torch.unsqueeze(pred, dim=1)), dim=1)
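For reference, a minimal sketch of the autoregressive loop above, with a stand-in mean predictor instead of self._get_prediction and assumed sizes (batch_size=2, encoder_length=14, input_size=1, decoder_length=3):

import torch

def roll_forecast(encoder_real: torch.Tensor, decoder_length: int) -> torch.Tensor:
    x = encoder_real  # (batch_size, encoder_length, input_size)
    outputs = []
    for _ in range(decoder_length):
        pred = x.mean(dim=1)  # stand-in for self._get_prediction(x)
        outputs.append(pred)
        x = torch.cat((x[:, 1:, :], pred.unsqueeze(dim=1)), dim=1)  # drop oldest step, append prediction
    return torch.stack(outputs, dim=1)  # (batch_size, decoder_length, input_size)

assert roll_forecast(torch.randn(2, 14, 1), decoder_length=3).shape == (2, 3, 1)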

@@ -150,7 +148,7 @@ def forward(self, x: PatchTSBatch, *args, **kwargs): # type: ignore
return forecast

def _get_prediction(self, x: torch.Tensor) -> torch.Tensor:
- x = x.permute(0, 2, 1) # (batch_size, input_size, encoder_length + decoder_length - 1)
+ x = x.permute(0, 2, 1) # (batch_size, input_size, encoder_length)
# do patching
x = x.unfold(dimension=-1, size=self.patch_len,
step=self.stride) # (batch_size, input_size, patch_num, patch_len)
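A shape walk-through of the patching above, under assumed sizes (batch_size=2, encoder_length=14, input_size=1, patch_len=4, stride=2):

import torch

patch_len, stride = 4, 2
x = torch.randn(2, 14, 1)  # (batch_size, encoder_length, input_size)
x = x.permute(0, 2, 1)  # (batch_size, input_size, encoder_length)
x = x.unfold(dimension=-1, size=patch_len, step=stride)  # (batch_size, input_size, patch_num, patch_len)
assert x.shape == (2, 1, 6, 4)  # patch_num == (14 - 4) // 2 + 1 == 6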
@@ -173,7 +171,7 @@ def step(self, batch: PatchTSBatch, *args, **kwargs): # type: ignore
:
loss, true_target, prediction_target
"""
- encoder_real = batch["encoder_real"].float() # (batch_size, encoder_length-1, input_size)
+ encoder_real = batch["encoder_real"].float() # (batch_size, encoder_length, input_size)
decoder_real = batch["decoder_real"].float() # (batch_size, decoder_length, input_size)

decoder_target = batch["decoder_target"].float() # (batch_size, decoder_length, 1)
Expand All @@ -196,11 +194,7 @@ def step(self, batch: PatchTSBatch, *args, **kwargs): # type: ignore
def make_samples(self, df: pd.DataFrame, encoder_length: int, decoder_length: int) -> Iterator[dict]:
"""Make samples from segment DataFrame."""
values_real = (
- df.select_dtypes(include=[np.number])
- .assign(target_shifted=df["target"].shift(1))
- .drop(["target"], axis=1)
- .pipe(lambda x: x[["target_shifted"] + [i for i in x.columns if i != "target_shifted"]])
- .values
+ df.select_dtypes(include=[np.number]).values
)
values_target = df["target"].values
segment = df["segment"].values[0]
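To illustrate the removed shift: previously the first feature column was the target shifted one step back, while now the numeric columns, including the unshifted target, are used as-is. A toy example with a made-up DataFrame (the exog column is an assumption for illustration):

import numpy as np
import pandas as pd

df = pd.DataFrame({"target": [1.0, 2.0, 3.0], "exog": [10.0, 20.0, 30.0], "segment": "a"})

# old behaviour (removed): shift the target by one step and put it first
old_values = (
    df.select_dtypes(include=[np.number])
    .assign(target_shifted=df["target"].shift(1))
    .drop(["target"], axis=1)
    .pipe(lambda x: x[["target_shifted"] + [i for i in x.columns if i != "target_shifted"]])
    .values
)  # [[nan, 10.], [1., 20.], [2., 30.]]

# new behaviour: take the numeric columns directly
new_values = df.select_dtypes(include=[np.number]).values  # [[1., 10.], [2., 20.], [3., 30.]]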
@@ -232,10 +226,10 @@ def _make(

# Get shifted target and concatenate it with real values features
sample["encoder_real"] = values_real[start_idx: start_idx + encoder_length]
sample["encoder_real"] = sample["encoder_real"][1:]
sample["encoder_real"] = sample["encoder_real"]

target = values_target[start_idx: start_idx + encoder_length + decoder_length].reshape(-1, 1)
sample["encoder_target"] = target[1:encoder_length]
sample["encoder_target"] = target[:encoder_length]
sample["decoder_target"] = target[encoder_length:]

sample["segment"] = segment
@@ -278,8 +272,8 @@ def __init__(
nhead: int = 16,
lr: float = 1e-3,
loss: "torch.nn.Module" = nn.MSELoss(),
- train_batch_size: int = 64,
- test_batch_size: int = 64,
+ train_batch_size: int = 128,
+ test_batch_size: int = 128,
optimizer_params: Optional[dict] = None,
trainer_params: Optional[dict] = None,
train_dataloader_params: Optional[dict] = None,
5 changes: 3 additions & 2 deletions tests/test_models/nn/test_patchts.py
@@ -11,8 +11,8 @@
"horizon",
[
8,
- # 13,
- # 15
+ 13,
+ 15
],
)
def test_patchts_model_run_weekly_overfit_with_scaler_small_patch(ts_dataset_weekly_function_with_horizon, horizon):
@@ -30,6 +30,7 @@ def test_patchts_model_run_weekly_overfit_with_scaler_small_patch(ts_dataset_wee
future.inverse_transform([std])

mae = MAE("macro")
+ print(mae(ts_test, future))
assert mae(ts_test, future) < 0.9


