Skip to content

Commit

Permalink
fix next_lags
Browse files Browse the repository at this point in the history
  • Loading branch information
kashif committed Sep 20, 2024
1 parent d8730ad commit 5a3419f
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions examples/anomaly_detection_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def main(args):
distr = model.output_distribution(
params, trailing_n=1, scale=scale
)
scaled_past_target = inputs["past_target"] / scale
batch_anomalies = []
for i in tqdm(
range(inputs["future_target"].shape[1]),
Expand All @@ -178,13 +179,16 @@ def main(args):
)
next_lags = lagged_sequence_values(
model.lags_seq,
inputs["past_target"] / scale,
scaled_past_target,
target / scale,
dim=-1,
)
rnn_input = torch.cat((next_lags, next_features), dim=-1)

output, state = model.rnn(rnn_input, state)
scaled_past_target = torch.cat(
(scaled_past_target, target / scale), dim=1
)

params = model.param_proj(output)
distr = model.output_distribution(params, scale=scale)
# stack the batch_anomalies along the prediction length dimension
Expand Down

0 comments on commit 5a3419f

Please sign in to comment.