Skip to content

Commit

Permalink
Add protection against division by zero
Browse files Browse the repository at this point in the history
(this guard may be unnecessary here, but it is added for consistency with the other implementations)
  • Loading branch information
klieret committed Oct 12, 2023
1 parent e5f940f commit 30a398c
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions src/object_condensation/pytorch/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ def condensation_loss(
``cl_peak``: Averaged over all objects
``cl_noise``: Averaged over all noise hits
"""
# To protect against nan in divisions
eps = 1e-9

# x: n_nodes x n_outdim
not_noise = object_id > noise_thld
unique_oids = torch.unique(object_id[not_noise])
Expand Down Expand Up @@ -71,24 +74,24 @@ def condensation_loss(
# It's important to directly do the .mean here so we don't keep these large
# matrices in memory longer than we need them
# Attractive potential per object normalized over number of hits in object
v_att_k = torch.sum(v_att_j_k[mask], dim=0) / torch.sum(
attractive_mask[mask], dim=0
v_att_k = torch.sum(v_att_j_k[mask], dim=0) / (
torch.sum(attractive_mask[mask], dim=0) + eps
)
v_att = torch.mean(v_att_k)

# Repulsive potential/loss
v_rep_j_k = (
q[:, None] * q_k * (~attractive_mask) * relu(radius_threshold - dist_j_k)
)
v_rep_k = torch.sum(v_rep_j_k, dim=0) / torch.sum(~attractive_mask, dim=0)
v_rep_k = torch.sum(v_rep_j_k, dim=0) / (torch.sum(~attractive_mask, dim=0) + eps)
v_rep = torch.mean(v_rep_k)

l_beta = torch.mean(1 - beta[alphas])
l_coward = torch.mean(1 - beta[alphas])
l_noise = torch.mean(beta[~not_noise])

return {
"attractive": v_att,
"repulsive": v_rep,
"peak": l_beta,
"coward": l_coward,
"noise": l_noise,
}

0 comments on commit 30a398c

Please sign in to comment.