Skip to content

Commit

Permalink
epig implementation using matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
reeshipaul authored and Dref360 committed May 18, 2024
1 parent c50caba commit 0d3c024
Showing 1 changed file with 58 additions and 36 deletions.
94 changes: 58 additions & 36 deletions baal/active/heuristics/heuristics.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,39 +739,36 @@ class EPIG(AbstractHeuristic):
def __init__(self, shuffle_prop=DEPRECATED, reverse=False, reduction="none"):
    """Set up the heuristic through the parent constructor.

    Args:
        shuffle_prop: Forwarded unchanged to the parent constructor.
        reverse: Accepted for signature compatibility but not forwarded;
            the parent constructor always receives ``reverse=True``.
        reduction: Forwarded unchanged to the parent constructor.
    """
    super().__init__(
        shuffle_prop=shuffle_prop,
        reverse=True,  # fixed on purpose: the caller's `reverse` is ignored
        reduction=reduction,
    )

def _conditional_epig_from_probs(self, predictions, targets):
# converting to Tensor
probs_pool = torch.Tensor(predictions)
probs_targ = torch.Tensor(targets)

# Estimate the joint predictive distribution.
probs_pool = probs_pool.permute(1, 0, 2) # [K, N_p, Cl]
probs_targ = probs_targ.permute(1, 0, 2) # [K, N_t, Cl]
probs_pool = probs_pool[:, :, None, :, None] # [K, N_p, 1, Cl, 1]
probs_targ = probs_targ[:, None, :, None, :] # [K, 1, N_t, 1, Cl]
probs_pool_targ_joint = probs_pool * probs_targ # [K, N_p, N_t, Cl, Cl]
probs_pool_targ_joint = torch.mean(probs_pool_targ_joint, dim=0) # [N_p, N_t, Cl, Cl]

# Estimate the marginal predictive distributions.
probs_pool = torch.mean(probs_pool, dim=0) # [N_p, 1, Cl, 1]
probs_targ = torch.mean(probs_targ, dim=0) # [1, N_t, 1, Cl]

# Estimate the product of the marginal predictive distributions.
probs_pool_targ_indep = probs_pool * probs_targ # [N_p, N_t, Cl, Cl]

# Estimate the conditional expected predictive information gain for each pair of examples.
# This is the KL divergence between probs_pool_targ_joint and probs_pool_targ_joint_indep.
nonzero_joint = probs_pool_targ_joint > 0 # [N_p, N_t, Cl, Cl]
log_term = torch.clone(probs_pool_targ_joint) # [N_p, N_t, Cl, Cl]
log_term[nonzero_joint] = torch.log(
probs_pool_targ_joint[nonzero_joint]
) # [N_p, N_t, Cl, Cl]
log_term[nonzero_joint] -= torch.log(
probs_pool_targ_indep[nonzero_joint]
) # [N_p, N_t, Cl, Cl]
scores = torch.sum(probs_pool_targ_joint * log_term, dim=(-2, -1)) # [N_p, N_t]
return scores # [N_p, N_t]
def entropy_from_probs(self, probs):
    """Entropy of the distributions stored along the last axis.

    Computes ``-sum(p * log p)`` over the last axis. Entries that are not
    strictly positive are never passed to ``log``, so a probability of 0
    contributes 0 to the sum rather than NaN.

    Arguments:
        probs: Tensor[float], [*N, Cl]

    Returns:
        Tensor[float], [*N,]
    """
    positive = probs > 0  # [*N, Cl]
    # Start from a copy so the caller's tensor is left untouched; entries
    # outside the mask keep their original value (0 for a zero probability).
    safe_log = probs.clone()  # [*N, Cl]
    safe_log[positive] = probs[positive].log()
    return -(probs * safe_log).sum(dim=-1)  # [*N,]

def marginal_entropy_from_probs(self, probs):
    """Entropy of the prediction obtained by averaging over axis 2.

    The input is averaged along axis 2 first, and the entropy of that
    averaged distribution is then taken with ``entropy_from_probs``.

    Arguments:
        probs: Tensor[float], [N, Cl, K]

    Returns:
        Tensor[float], [N,]
    """
    mean_probs = probs.mean(dim=2)  # [N, Cl]
    return self.entropy_from_probs(mean_probs)  # [N,]

def compute_score(self, predictions, targets):
    """
    Compute the score according to the heuristic.

    For each pool example the score is
    ``H[pool] + mean_t H[target_t] - mean_t H[pool, target_t]``,
    where every entropy is taken on a distribution averaged over the K axis.
    The joint distributions of all (pool, target) pairs come from a single
    matrix product that contracts the K axis.

    Args:
        predictions: Pool probabilities, array-like of shape [N_p, Cl, K].
        targets: Target probabilities, array-like of shape [N_t, Cl, K].
            Must be exactly 3-dimensional, since its shape is unpacked into
            three values.

    Returns:
        numpy array of shape [N_p,] holding one score per pool example.
    """
    assert predictions.ndim >= 3
    assert targets.ndim >= 3

    probs_pool = torch.Tensor(predictions)  # [N_p, Cl, K]
    probs_targ = torch.Tensor(targets)  # [N_t, Cl, K]

    n_targ, n_classes, n_samples = probs_targ.shape

    entropy_pool = self.marginal_entropy_from_probs(probs_pool)  # [N_p,]
    # Single scalar: the marginal entropy averaged over all target examples.
    entropy_targ = torch.mean(self.marginal_entropy_from_probs(probs_targ))

    # Flatten the targets to [K, N_t * Cl]. Dividing by K here turns the sum
    # over K performed by the matrix product into a mean over K.
    probs_targ = probs_targ.permute(2, 0, 1)  # [K, N_t, Cl]
    probs_targ = probs_targ.reshape(n_samples, n_targ * n_classes)  # [K, N_t * Cl]
    probs_pool_targ_joint = torch.matmul(
        probs_pool, probs_targ / n_samples
    )  # [N_p, Cl, N_t * Cl]

    # Summing over the last two axes adds up the joint entropies of all
    # N_t targets, so dividing by N_t gives their mean.
    entropy_pool_targ = (
        -torch.sum(xlogy(probs_pool_targ_joint, probs_pool_targ_joint), dim=(-2, -1))
        / n_targ
    )  # [N_p,]

    # NaN joint entropies are replaced by 0 before they enter the score.
    entropy_pool_targ[torch.isnan(entropy_pool_targ)] = 0.0
    scores = entropy_pool + entropy_targ - entropy_pool_targ  # [N_p,]
    return scores.numpy()  # [N_p,]

def get_uncertainties(self, predictions, targets):
"""
Expand All @@ -809,7 +829,7 @@ def get_uncertainties(self, predictions, targets):
scores[~np.isfinite(scores)] = fixed
return scores

def get_uncertainties_generator(self, predictions, targets):
    """
    Compute the score according to the heuristic, one batch at a time.

    Args:
        predictions: Iterable of prediction batches; each batch is passed
            to ``get_uncertainties`` together with ``targets``.
        targets: Target predictions, passed unchanged with every batch.

    Raises:
        ValueError: If ``predictions`` yields no batch at all.

    Returns:
        The per-batch scores joined with ``np.concatenate``.
    """
    # `get_uncertainties` takes (predictions, targets), so every batch must
    # be scored against the same targets.
    acc = [self.get_uncertainties(pred, targets) for pred in predictions]
    if not acc:
        raise ValueError("No prediction! Cannot order the values!")
    return np.concatenate(acc)
Expand All @@ -844,5 +864,7 @@ def get_ranks(self, predictions, targets):
scores = self.get_uncertainties_generator(predictions, targets)
else:
scores = self.get_uncertainties(predictions, targets)

print(scores)

return self.reorder_indices(scores), scores

0 comments on commit 0d3c024

Please sign in to comment.