Skip to content

Commit

Permalink
use _gdk_domain_map
Browse files Browse the repository at this point in the history
  • Loading branch information
kashif committed Sep 20, 2024
1 parent 55257e3 commit e127f2b
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions examples/anomaly_detection_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,13 @@ def fit_gpd(data, num_iterations=100, learning_rate=0.001):
optimizer, mode="min", factor=0.5, patience=3
)

def _gdk_domain_map(loc, scale, concentration):
def _gdk_domain_map(loc, scale, concentration, validate_args=None):
scale = F.softplus(scale)
neg_conc = concentration < 0
loc = torch.where(neg_conc, loc - scale / concentration, loc)
return GeneralizedPareto(loc, scale, concentration)
return GeneralizedPareto(
loc, scale, concentration, validate_args=validate_args
)

def closure():
optimizer.zero_grad()
Expand All @@ -74,7 +76,7 @@ def closure():
for _ in range(num_iterations):
optimizer.step(closure)

return GeneralizedPareto(
return _gdk_domain_map(
loc.detach(),
scale.detach(),
concentration.detach(),
Expand Down

0 comments on commit e127f2b

Please sign in to comment.