Commit 25143bc
1 - gpd.cdf(score) < threshold
kashif committed Sep 21, 2024
1 parent 5a3419f commit 25143bc
Showing 1 changed file with 26 additions and 7 deletions.
examples/anomaly_detection_pytorch.py (26 additions & 7 deletions)
@@ -90,6 +90,12 @@ def main(args):
         prediction_length=dataset.metadata.prediction_length,
         context_length=args.context_length,
         freq=dataset.metadata.freq,
+        num_feat_static_cat=len(dataset.metadata.feat_static_cat),
+        cardinality=[
+            int(cat_feat_info.cardinality)
+            for cat_feat_info in dataset.metadata.feat_static_cat
+        ],
+        embedding_dimension=[3],
         trainer_kwargs=dict(
             max_epochs=args.max_epochs,
         ),
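The six added lines wire the dataset's static categorical features into the estimator: num_feat_static_cat counts the features, cardinality lists how many distinct values each one takes, and embedding_dimension=[3] gives each category id a learned 3-dimensional vector. A minimal sketch of that embedding idea (the cardinality value 370 is an assumption for illustration, not from this commit):

    import torch

    cardinality = [370]        # assumed: distinct category ids for the one static feature
    embedding_dimension = [3]  # matches the diff: a 3-dim vector per category

    embedders = [
        torch.nn.Embedding(num_embeddings=card, embedding_dim=dim)
        for card, dim in zip(cardinality, embedding_dimension)
    ]
    cat_ids = torch.tensor([[5]])        # one static category id per series
    static_feat = embedders[0](cat_ids)  # shape: (1, 1, 3)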
@@ -108,6 +114,7 @@ def main(args):
     )
 
     anomalies = []
+    means = []
     model.eval()
     with torch.no_grad():
         for batch in tqdm(test_data_loader, desc="Processing batches"):
@@ -154,19 +161,20 @@ def main(args):
             )
             scaled_past_target = inputs["past_target"] / scale
             batch_anomalies = []
+            batch_means = []
             for i in tqdm(
                 range(inputs["future_target"].shape[1]),
                 desc="Processing prediction length",
                 leave=False,
             ):
                 target = inputs["future_target"][:, i : i + 1]
                 score = -distr.log_prob(target)
-
+                batch_means.append(distr.mean)
                 # only check if it's an anomaly for scores greater than gpd.loc for each entry in the batch
                 is_anomaly = torch.where(
                     score < gpd.loc,
                     False,
-                    gpd.cdf(score) < args.anomaly_threshold,
+                    1 - gpd.cdf(score) < args.anomaly_threshold,
                 )
                 batch_anomalies.append(is_anomaly)
 
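This hunk carries the fix named in the commit title. gpd.cdf(score) approaches 1 for extreme scores, so the old test gpd.cdf(score) < args.anomaly_threshold could only flag the mildest scores; the corrected test flags a point when its tail (survival) probability 1 - gpd.cdf(score) drops below the threshold, i.e. when the score lies far into the fitted Generalized Pareto tail. A self-contained sketch of the corrected check, with SciPy's genpareto standing in for the torch distribution and illustrative parameter values:

    import numpy as np
    from scipy.stats import genpareto

    shape, loc, scale = 0.3, 5.0, 1.2    # assumed fitted GPD parameters
    threshold = 0.05                     # stands in for args.anomaly_threshold

    scores = np.array([4.0, 6.0, 15.0])  # -log_prob anomaly scores
    tail_prob = genpareto.sf(scores, shape, loc=loc, scale=scale)  # sf(x) == 1 - cdf(x)
    is_anomaly = (scores >= loc) & (tail_prob < threshold)
    print(is_anomaly)  # [False False  True]: only the extreme score is flagged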
@@ -193,14 +201,19 @@ def main(args):
                 distr = model.output_distribution(params, scale=scale)
             # stack the batch_anomalies along the prediction length dimension
             anomalies.append(torch.stack(batch_anomalies, dim=1))
+            means.append(torch.stack(batch_means, dim=1))
     # concat the anomalies along the batch dimension
     anomalies = torch.cat(anomalies, dim=0).cpu().numpy()
+    means = torch.cat(means, dim=0).cpu().numpy()
 
-    # save as csv
+    # save as pkl
     all_dates = []
     all_flags = []
     all_targets = []
-    for i, (entry, flags) in enumerate(zip(dataset.test, anomalies)):
+    all_means = []
+    for i, (entry, flags, mean) in enumerate(
+        zip(dataset.test, anomalies, means)
+    ):
         start_date = entry["start"].to_timestamp()
         target = entry["target"]
         dates = pd.date_range(
@@ -212,13 +225,19 @@ def main(args):
         all_dates.append(date_index)
         all_flags.append(flags.flatten().astype(bool))
         all_targets.append(target_slice)
+        all_means.append(mean.flatten())
 
     # create a dataframe with the date_index and the flags
     anomaly_df = pd.DataFrame(
-        {"date": all_dates, "is_anomaly": all_flags, "target": all_targets}
+        {
+            "date": all_dates,
+            "is_anomaly": all_flags,
+            "target": all_targets,
+            "mean": all_means,
+        }
     )
     anomaly_df.set_index("date", inplace=True)
-    anomaly_df.to_csv(f"anomalies_{args.dataset}.csv")
+    anomaly_df.to_pickle(f"anomalies_{args.dataset}.pkl")
 
 
 if __name__ == "__main__":
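Switching from to_csv to to_pickle preserves the per-series numpy arrays stored in each cell, which CSV would flatten into strings. Reading the output back might look like this (the dataset name electricity is only an example):

    import pandas as pd

    df = pd.read_pickle("anomalies_electricity.pkl")  # file name pattern from the diff

    row = df.iloc[0]                            # one row per test series
    flagged = row["target"][row["is_anomaly"]]  # observed values flagged as anomalous
    print(len(flagged), "anomalous points; first predicted means:", row["mean"][:3])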
@@ -246,7 +265,7 @@ def main(args):
     parser.add_argument(
         "--top_score_percentage",
         type=float,
-        default=0.2,
+        default=0.1,
         help="Percentage of top scores to consider for GPD fitting",
     )
     parser.add_argument(
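The new default fits the GPD to the top 10% of anomaly scores instead of the top 20%, concentrating the fit further into the tail. A rough sketch of the peaks-over-threshold step this flag controls, again with SciPy standing in for the torch-based fit and synthetic scores as stand-ins:

    import numpy as np
    from scipy.stats import genpareto

    rng = np.random.default_rng(0)
    scores = rng.standard_exponential(10_000)  # stand-in for -log_prob scores

    top_score_percentage = 0.1                 # the new default
    u = np.quantile(scores, 1.0 - top_score_percentage)
    tail = scores[scores > u]

    shape, loc, scale = genpareto.fit(tail, floc=u)  # pin loc to the threshold u
    print(f"fitted GPD: shape={shape:.3f}, loc={loc:.3f}, scale={scale:.3f}")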