Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce memory by garbage collecting matrix early #4

Merged
merged 2 commits into from
Oct 12, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 40 additions & 21 deletions src/object_condensation/pytorch/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def condensation_loss(
mask: T,
q_min: float,
radius_threshold: float,
noise_thld: int,
) -> dict[str, T]:
"""Condensation losses

Expand All @@ -29,47 +30,65 @@ def condensation_loss(
radius_threshold: Radius threshold for repulsive potential. In case of linear
scalarization of the multi-objective losses, this is redundant and should
be fixed to 1.
noise_thld: Threshold for noise hits. Hits with ``object_id <= noise_thld``
are considered to be noise

Returns:
Dictionary of scalar tensors.

``attractive``: Averaged over object, then averaged over all objects.
``repulsive``: Averaged like ``attractive``
``cl_peak``: Averaged over all objects
``cl_noise``
``cl_noise``: Averaged over all noise hits
"""
# x: n_nodes x n_outdim
not_noise = object_id > 0
unique_pids = torch.unique(object_id[not_noise])
assert len(unique_pids) > 0, "No particles found, cannot evaluate loss"
not_noise = object_id > noise_thld
unique_oids = torch.unique(object_id[not_noise])
assert len(unique_oids) > 0, "No particles found, cannot evaluate loss"
# n_nodes x n_pids
# The nodes in every column correspond to the hits of a single particle and
# should attract each other
attractive_mask = object_id[:, None] == unique_pids[None, :]
attractive_mask = object_id[:, None] == unique_oids[None, :]

q = torch.arctanh(beta) ** 2 + q_min
assert not torch.isnan(q).any(), "q contains NaNs"
# n_objs
alphas = torch.argmax(q[:, None] * attractive_mask, dim=0)

# n_pids x n_outdim
x_alphas = x[alphas]
# 1 x n_pids
q_alphas = q[alphas][None, :]
# _j means indexed by hits
# _k means indexed by objects

# n_nodes x n_pids
dist = torch.cdist(x, x_alphas)
# n_objs x n_outdim
x_k = x[alphas]
# 1 x n_objs
q_k = q[alphas][None, :]

dist_j_k = torch.cdist(x, x_k)

# Attractive potential/loss
# todo: do I need the copy/new axis here or would it broadcast?
v_att_j_k = q[:, None] * q_k * attractive_mask * torch.square(dist_j_k)
# It's important to directly do the .mean here so we don't keep these large
# matrices in memory longer than we need them
# Attractive potential per object normalized over number of hits in object
v_att_k = torch.sum(v_att_j_k[mask], dim=0) / torch.sum(
attractive_mask[mask], dim=0
)
v_att = torch.mean(v_att_k)

# Attractive potential (n_nodes x n_pids)
va = q[:, None] * attractive_mask * torch.square(dist) * q_alphas
# Repulsive potential (n_nodes x n_pids)
vr = q[:, None] * (~attractive_mask) * relu(radius_threshold - dist) * q_alphas
# Repulsive potential/loss
v_rep_j_k = (
q[:, None] * q_k * (~attractive_mask) * relu(radius_threshold - dist_j_k)
)
v_rep_k = torch.sum(v_rep_j_k, dim=0) / torch.sum(~attractive_mask, dim=0)
v_rep = torch.mean(v_rep_k)

cl_peak = torch.mean(1 - beta[alphas])
cl_noise = torch.mean(beta[~not_noise])
l_beta = torch.mean(1 - beta[alphas])
l_noise = torch.mean(beta[~not_noise])

return {
"attractive": torch.mean(torch.mean(va[mask], dim=0)),
"repulsive": torch.mean(torch.mean(vr, dim=0)),
"cl_peak": cl_peak,
"cl_noise": cl_noise,
"attractive": v_att,
"repulsive": v_rep,
"peak": l_beta,
"noise": l_noise,
}