Skip to content

Commit

Permalink
learn scores in the original scale
Browse files Browse the repository at this point in the history
  • Loading branch information
kashif committed Sep 20, 2024
1 parent 13d777d commit d8730ad
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions examples/anomaly_detection_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,14 @@ def main(args):
)
# remove the very last param from the params
sliced_params = [p[:, :-1] for p in params]
distr = model.output_distribution(sliced_params)
distr = model.output_distribution(sliced_params, scale=scale)

# get the last target and calculate its anomaly score
context_target = take_last(
inputs["past_target"], dim=-1, num=model.context_length - 1
)
# calculate the surprisal scores for the context target
scores = -distr.log_prob(context_target / scale)
scores = -distr.log_prob(context_target)

# get the args.top_score_percentage of the scores for each time series of the batch
top_scores = torch.topk(
Expand All @@ -149,15 +149,18 @@ def main(args):
)

# Loop over each prediction length
scaled_future_target = inputs["future_target"] / scale
distr = model.output_distribution(params, trailing_n=1)
distr = model.output_distribution(
params, trailing_n=1, scale=scale
)
batch_anomalies = []
for i in tqdm(
range(scaled_future_target.shape[1]),
range(inputs["future_target"].shape[1]),
desc="Processing prediction length",
leave=False,
):
score = -distr.log_prob(scaled_future_target[:, i : i + 1])
target = inputs["future_target"][:, i : i + 1]
score = -distr.log_prob(target)

# only check if it's an anomaly for scores greater than gpd.loc for each entry in the batch
is_anomaly = torch.where(
score < gpd.loc,
Expand All @@ -173,19 +176,20 @@ def main(args):
),
dim=-1,
)

next_lags = lagged_sequence_values(
model.lags_seq,
inputs["past_target"] / scale,
scaled_future_target[:, i : i + 1],
target / scale,
dim=-1,
)
rnn_input = torch.cat((next_lags, next_features), dim=-1)

output, state = model.rnn(rnn_input, state)
params = model.param_proj(output)
distr = model.output_distribution(params)
distr = model.output_distribution(params, scale=scale)
# stack the batch_anomalies along the prediction length dimension
anomalies.append(torch.stack(batch_anomalies, dim=1))
# concat the anomalies along the batch dimension
anomalies = torch.cat(anomalies, dim=0).cpu().numpy()

# save as csv
Expand Down Expand Up @@ -238,7 +242,7 @@ def main(args):
parser.add_argument(
"--top_score_percentage",
type=float,
default=0.1,
default=0.2,
help="Percentage of top scores to consider for GPD fitting",
)
parser.add_argument(
Expand Down

0 comments on commit d8730ad

Please sign in to comment.