Commit

Lint
Dref360 committed May 24, 2024
1 parent 31fde79 commit 8e89179
Showing 2 changed files with 11 additions and 13 deletions.
16 changes: 9 additions & 7 deletions baal/active/heuristics/heuristics.py
@@ -74,9 +74,7 @@ def requireprobs(fn):
     def wrapper(self, probabilities, target_predictions=None):
         # Expected shape : [n_sample, n_classes, ..., n_iterations]
         probabilities = to_prob(probabilities)
-        target_predictions = (
-            to_prob(target_predictions) if target_predictions is not None else None
-        )
+        target_predictions = to_prob(target_predictions) if target_predictions is not None else None
         return fn(self, probabilities, target_predictions=target_predictions)
 
     return wrapper
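
Aside: the wrapped helper only normalizes its inputs into probability distributions before the heuristic runs, so the lint change above is behavior-preserving. A rough sketch of what such a normalization step can look like, assuming a softmax fallback over the class axis (this is an illustration, not baal's actual to_prob implementation):

import numpy as np

def to_prob_sketch(x: np.ndarray) -> np.ndarray:
    # Assumed layout: [n_sample, n_classes, ..., n_iterations].
    # If the class axis already sums to one, treat x as probabilities;
    # otherwise apply a softmax over axis 1.
    if np.allclose(x.sum(axis=1), 1.0):
        return x
    e = np.exp(x - x.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)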
@@ -784,21 +782,23 @@ def marginal_entropy_from_probs(self, probs):
             scores (Tensor[float], [N,]): H[p(y|x_i)] for i in [1, N].
         """
         probs = torch.mean(probs, dim=-1)  # [N, C]
-        scores = -torch.sum(torch.xlogy(probs, probs), dim=-1)  # [N,]
+        scores = -torch.sum(torch.xlogy(probs, probs), dim=-1)  # [N,]
         return scores  # [N,]
 
     @requireprobs
     def compute_score(self, predictions, target_predictions):
         """
         Compute the expected predictive information gain for each candidate input, x_i:
             EPIG(x_i) = E_{p_*(x_*)}[I(y;y_*|x_i,x_*)]
-                      = H[p(y|x_i)] + E_{p_*(x_*)}[H[p(y_*|x_*)]] - E_{p_*(x_*)}[H[p(y,y_*|x_i,x_*)]]
+                      = H[p(y|x_i)] + E_{p_*(x_*)}[H[p(y_*|x_*)]]
+                        - E_{p_*(x_*)}[H[p(y,y_*|x_i,x_*)]]
         where x_* ~ p_*(x_*) is a target input with unknown label y_*.
         Args:
             predictions (ndarray, [N_p, C, K]): p(y|x_i,θ_j) for i in [1, N_p] and j in [1, K].
-            target_predictions (ndarray, [N_t, C, K]): p(y|x_*^i,θ_j) for i in [1, N_t] and j in [1, K].
+            target_predictions (ndarray, [N_t, C, K]): p(y|x_*^i,θ_j)
+                for i in [1, N_t] and j in [1, K].
         Returns:
             scores (ndarray, [N,]): EPIG(x_i) for i in [1, N_p].
@@ -818,7 +818,9 @@ def compute_score(self, predictions, target_predictions):
         probs_targ = probs_targ.reshape(K, N_t * C)  # [K, N_t * C]
         probs_joint = torch.matmul(probs_pool, probs_targ) / K  # [N_p, C, N_t * C]
 
-        entropy_joint = -torch.sum(torch.xlogy(probs_joint, probs_joint), dim=(-2, -1)) / N_t  # [N_p,]
+        entropy_joint = (
+            -torch.sum(torch.xlogy(probs_joint, probs_joint), dim=(-2, -1)) / N_t
+        )  # [N_p,]
         entropy_joint = torch.nan_to_num(entropy_joint, nan=0.0)  # [N_p,]
 
         scores = entropy_pool + torch.mean(entropy_targ) - entropy_joint  # [N_p,]
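
As a sanity check on the EPIG recipe above, here is a minimal, self-contained sketch that runs the same steps on random tensors. Shapes, seed, and helper names are made up for illustration; only PyTorch is assumed.

import torch

torch.manual_seed(0)
N_p, N_t, C, K = 5, 3, 4, 10  # pool points, target points, classes, MC samples

# Per-sample class probabilities with the [N, C, K] layout used in the docstring.
probs_pool = torch.softmax(torch.randn(N_p, C, K), dim=1)
probs_targ = torch.softmax(torch.randn(N_t, C, K), dim=1)

def marginal_entropy(p):
    # p: [N, C, K] -> H[p(y|x)] per sample, shape [N]
    p_mean = p.mean(dim=-1)
    return -torch.sum(torch.xlogy(p_mean, p_mean), dim=-1)

entropy_pool = marginal_entropy(probs_pool)  # [N_p]
entropy_targ = marginal_entropy(probs_targ)  # [N_t]

# Joint predictive p(y, y_* | x_i, x_*), averaged over the K posterior samples.
targ_flat = probs_targ.permute(2, 0, 1).reshape(K, N_t * C)   # [K, N_t * C]
probs_joint = torch.matmul(probs_pool, targ_flat) / K         # [N_p, C, N_t * C]
entropy_joint = -torch.sum(torch.xlogy(probs_joint, probs_joint), dim=(-2, -1)) / N_t  # [N_p]

scores = entropy_pool + entropy_targ.mean() - entropy_joint   # EPIG(x_i) for each pool point
print(scores.shape)  # torch.Size([5])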
8 changes: 2 additions & 6 deletions baal/active/heuristics/heuristics_gpu.py
@@ -78,9 +78,7 @@ def __init__(
         self.threshold = threshold
         self.reversed = reverse
         assert reduction in available_reductions or callable(reduction)
-        self.reduction = (
-            reduction if callable(reduction) else available_reductions[reduction]
-        )
+        self.reduction = reduction if callable(reduction) else available_reductions[reduction]
 
     def compute_score(self, predictions, target_predictions=None):
         """
@@ -129,9 +127,7 @@ def predict_on_dataset(
 
     def predict_on_batch(self, data, iterations=1, use_cuda=False):
         """Rank the predictions according to their uncertainties."""
-        return self.get_uncertainties(
-            self.model.predict_on_batch(data, iterations, cuda=use_cuda)
-        )
+        return self.get_uncertainties(self.model.predict_on_batch(data, iterations, cuda=use_cuda))
 
 
 class BALDGPUWrapper(AbstractGPUHeuristic):
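
For reference, the self.reduction line collapsed in the hunk above just resolves a reduction given either as a callable or as a registered name. A small sketch of that lookup pattern with a made-up registry (baal keeps its own available_reductions mapping; the entries below are illustrative):

import numpy as np

available_reductions = {
    "max": lambda arr: np.max(arr, axis=tuple(range(1, arr.ndim))),
    "mean": lambda arr: np.mean(arr, axis=tuple(range(1, arr.ndim))),
    "none": lambda arr: arr,
}

def resolve_reduction(reduction):
    # Accept either a callable or the name of a registered reduction.
    assert reduction in available_reductions or callable(reduction)
    return reduction if callable(reduction) else available_reductions[reduction]

reduce_fn = resolve_reduction("mean")
print(reduce_fn(np.ones((4, 3, 2))).shape)  # (4,)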
