Skip to content

Commit

Permalink
fit returns GPD without validation
Browse files Browse the repository at this point in the history
  • Loading branch information
kashif committed Sep 20, 2024
1 parent e34a600 commit 55257e3
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions examples/anomaly_detection_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def fit_gpd(data, num_iterations=100, learning_rate=0.001):
learning_rate (float): Learning rate for the optimizer
Returns:
GeneralizedPareto: Fitted GPD(loc, scale, concentration) distribution
GeneralizedPareto: Fitted GPD(loc, scale, concentration) distribution without any validation
"""
batch_size, _ = data.shape

Expand Down Expand Up @@ -75,7 +75,10 @@ def closure():
optimizer.step(closure)

return GeneralizedPareto(
loc.detach(), scale.detach(), concentration.detach()
loc.detach(),
scale.detach(),
concentration.detach(),
validate_args=False,
)


Expand Down Expand Up @@ -152,12 +155,11 @@ def main(args):
leave=False,
):
score = -distr.log_prob(scaled_future_target[:, i : i + 1])
# check if the score are less than gpd.loc? for each entry in the batch
is_anomaly = score < gpd.loc
# mask out the score where is_anomaly is True
score = torch.where(is_anomaly, gpd.loc + 1, score)
# only check if its an anomaly for scores greater than gpd.loc for each entry in the batch
is_anomaly = torch.where(
is_anomaly, False, gpd.cdf(score) < args.anomaly_threshold
score < gpd.loc,
False,
gpd.cdf(score) < args.anomaly_threshold,
)
batch_anomalies.append(is_anomaly)

Expand Down

0 comments on commit 55257e3

Please sign in to comment.