Skip to content

Commit

Permalink
Reduce memory by garbage collecting matrix early
Browse files Browse the repository at this point in the history
  • Loading branch information
klieret committed Oct 12, 2023
1 parent 4a07814 commit 62b8526
Showing 1 changed file with 18 additions and 12 deletions.
30 changes: 18 additions & 12 deletions src/object_condensation/pytorch/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,16 @@ def condensation_loss(
``attractive``: Averaged over object, then averaged over all objects.
``repulsive``: Averaged like ``attractive``
``cl_peak``: Averaged over all objects
``cl_noise``
``cl_noise``: Averaged over all noise hits
"""
# x: n_nodes x n_outdim
not_noise = object_id > 0
unique_pids = torch.unique(object_id[not_noise])
assert len(unique_pids) > 0, "No particles found, cannot evaluate loss"
unique_oids = torch.unique(object_id[not_noise])
assert len(unique_oids) > 0, "No particles found, cannot evaluate loss"
# n_nodes x n_pids
# The nodes in every column correspond to the hits of a single particle and
# should attract each other
attractive_mask = object_id[:, None] == unique_pids[None, :]
attractive_mask = object_id[:, None] == unique_oids[None, :]

q = torch.arctanh(beta) ** 2 + q_min
assert not torch.isnan(q).any(), "q contains NaNs"
Expand All @@ -59,17 +59,23 @@ def condensation_loss(
# n_nodes x n_pids
dist = torch.cdist(x, x_alphas)

# It's important to directly do the .mean here so we don't keep these large
# matrices in memory longer than we need them
# Attractive potential (n_nodes x n_pids)
va = q[:, None] * attractive_mask * torch.square(dist) * q_alphas
va_matrix = q[:, None] * q_alphas * attractive_mask * torch.square(dist)
va = torch.mean(torch.mean(va_matrix[mask], dim=0))
# Repulsive potential (n_nodes x n_pids)
vr = q[:, None] * (~attractive_mask) * relu(radius_threshold - dist) * q_alphas
vr_matrix = (
q[:, None] * q_alphas * (~attractive_mask) * relu(radius_threshold - dist)
)
vr = torch.mean(torch.mean(vr_matrix, dim=0))

cl_peak = torch.mean(1 - beta[alphas])
cl_noise = torch.mean(beta[~not_noise])
peak = torch.mean(1 - beta[alphas])
noise = torch.mean(beta[~not_noise])

return {
"attractive": torch.mean(torch.mean(va[mask], dim=0)),
"repulsive": torch.mean(torch.mean(vr, dim=0)),
"cl_peak": cl_peak,
"cl_noise": cl_noise,
"attractive": va,
"repulsive": vr,
"peak": peak,
"noise": noise,
}

0 comments on commit 62b8526

Please sign in to comment.