Skip to content

Commit

Permalink
Add tensorflow implementation (#5)
Browse files Browse the repository at this point in the history
* Crude implementation of OC loss

* make function external

* Allow for untracked "playground" directory everywhere

* tensorflow implementation **should** match torch implementation now

* fix type definitions
  • Loading branch information
phzehetn authored Oct 12, 2023
1 parent 844231f commit 25af69c
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Specific locations
**/playground/**

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
77 changes: 77 additions & 0 deletions src/object_condensation/tensorflow/object_condensation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from __future__ import annotations

import tensorflow as tf


def calculate_losses(
q_min: float,
object_id: tf.Tensor,
beta: tf.Tensor,
x: tf.Tensor,
weights: tf.Tensor=None,
noise_threshold: int=-1) -> dict[str, tf.Tensor]:
"""
Calculate the object condensation loss
"""
if weights is None:
weights = tf.ones_like(beta)

not_noise = object_id > noise_threshold
unique_oids, _ = tf.unique(object_id[not_noise])
q = tf.tanh(beta) ** 2 + q_min
mask_att = tf.cast(object_id[:, None] == unique_oids[None, :], tf.float32)
mask_rep = tf.cast(object_id[:, None] != unique_oids[None, :], tf.float32)
alphas = tf.argmax(beta * mask_att, axis=0)
beta_k = tf.gather(beta, alphas)
q_k = tf.gather(q, alphas)
x_k = tf.gather(x, alphas)
dist_j_k = tf.norm(x[None, :, :] - x_k[:, None, :], axis=-1)

v_att_k = tf.math.divide_no_nan(
tf.reduce_sum(
q_k
* tf.transpose(weights)
* tf.transpose(q)
* tf.transpose(mask_att)
* dist_j_k**2,
axis=1,
),
tf.reduce_sum(mask_att, axis=0) + 1e-3,
)
v_att = tf.divide_no_nan(
tf.reduce_sum(v_att_k), tf.cast(tf.shape(unique_oids)[0] - 1.0, tf.float32)
)

v_rep_k = tf.math.divide_no_nan(
tf.reduce_sum(
q_k
* tf.transpose(weights)
* tf.transpose(q)
* tf.transpose(mask_rep)
* tf.math.maximum(0, 1.0 - dist_j_k),
axis=1,
),
tf.reduce_sum(mask_rep, axis=0) + 1e-3,
)

v_rep = tf.divide_no_nan(
tf.reduce_sum(v_rep_k), tf.cast(tf.shape(unique_oids)[0] - 1.0, tf.float32)
)

noise_loss_k = 1.0 - beta_k
noise_loss = tf.divide_no_nan(
tf.reduce_sum(noise_loss_k[1:]),
tf.cast(tf.shape(unique_oids)[0] - 1.0, tf.float32),
)

coward_loss = tf.math.divide_no_nan(
tf.reduce_sum(beta[object_id == -1]),
tf.reduce_sum(tf.cast(object_id == -1, tf.float32)),
)

return {
"v_att": v_att,
"v_rep": v_rep,
"L_beta": noise_loss,
"L_noise": coward_loss,
}

0 comments on commit 25af69c

Please sign in to comment.