From f9b9ebfeb7447c0183fa6498bce6d7886bd0c351 Mon Sep 17 00:00:00 2001
From: Dref360
Date: Fri, 24 May 2024 18:05:23 -0400
Subject: [PATCH] Make doctest happy

---
 baal/active/heuristics/heuristics.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/baal/active/heuristics/heuristics.py b/baal/active/heuristics/heuristics.py
index 02d440b..398b84f 100644
--- a/baal/active/heuristics/heuristics.py
+++ b/baal/active/heuristics/heuristics.py
@@ -769,11 +769,13 @@ def __init__(self, shuffle_prop=DEPRECATED, reverse=False, reduction="none"):
 
     def marginal_entropy_from_probs(self, probs):
         """
-        Compute the marginal predictive entropy for each input, x_i:
+        Compute the marginal predictive entropy for each input, x_i.
+
+        Equation:
             H[p(y|x_i)] = H[E_{q(θ)}[p(y|x_i,θ)]] ~= H[(1/K) Σ_{j=1}^K p(y|x_i,θ_j)]
 
-        where θ_j ~ q(θ) is a parameter sample and p(y|x_i,θ_j) is the parameter-conditional
-        predictive distribution for x_i and θ_j.
+            where θ_j ~ q(θ) is a parameter sample and p(y|x_i,θ_j) is the parameter-conditional
+            predictive distribution for x_i and θ_j.
 
         Args:
             probs (Tensor[float], [N, C, K]): p(y|x_i,θ_j) for i in [1, N] and j in [1, K].
@@ -788,17 +790,19 @@ def marginal_entropy_from_probs(self, probs):
 
     @requireprobs
     def compute_score(self, predictions, target_predictions):
         """
-        Compute the expected predictive information gain for each candidate input, x_i:
+        Compute the expected predictive information gain for each candidate input, x_i.
+
+        Equation:
             EPIG(x_i) = E_{p_*(x_*)}[I(y;y_*|x_i,x_*)]
                       = H[p(y|x_i)] + E_{p_*(x_*)}[H[p(y_*|x_*)]] - E_{p_*(x_*)}[H[p(y,y_*|x_i,x_*)]]
 
-        where x_* ~ p_*(x_*) is a target input with unknown label y_*.
+            where x_* ~ p_*(x_*) is a target input with unknown label y_*.
 
         Args:
             predictions (ndarray, [N_p, C, K]): p(y|x_i,θ_j) for i in [1, N_p] and j in [1, K].
-            target_predictions (ndarray, [N_t, C, K]): p(y|x_*^i,θ_j)
-                for i in [1, N_t] and j in [1, K].
+            target_predictions (ndarray, [N_t, C, K]): Prediction from target distribution.
+                Or: p(y|x_*^i,θ_j) for i in [1, N_t] and j in [1, K].
 
         Returns:
             scores (ndarray, [N,]): EPIG(x_i) for i in [1, N_p].
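
For reference, here is a minimal NumPy sketch of the two quantities these docstrings describe. It is not Baal's implementation: the helper names (entropy, marginal_entropy, epig_scores) are hypothetical, and it assumes the inputs already hold normalized class probabilities with the [N, C, K] layout from the docstrings (N inputs, C classes, K parameter samples).

import numpy as np

def entropy(p, axis=-1, eps=1e-12):
    # Shannon entropy along `axis`; eps guards against log(0).
    return -np.sum(p * np.log(p + eps), axis=axis)

def marginal_entropy(probs):
    # probs: [N, C, K] parameter-conditional predictions p(y|x_i, theta_j).
    # The marginal p(y|x_i) is approximated by averaging over the K samples,
    # matching H[(1/K) sum_j p(y|x_i, theta_j)] in the docstring.
    marginal = probs.mean(axis=-1)   # [N, C]
    return entropy(marginal)         # [N]

def epig_scores(predictions, target_predictions):
    # predictions: [N_p, C, K]; target_predictions: [N_t, C, K].
    # EPIG(x_i) = H[p(y|x_i)] + E_*[H[p(y_*|x_*)]] - E_*[H[p(y, y_*|x_i, x_*)]],
    # with the joint approximated by (1/K) sum_j p(y|x_i,theta_j) p(y_*|x_*,theta_j).
    n_p, c, k = predictions.shape
    n_t = target_predictions.shape[0]

    h_pool = marginal_entropy(predictions)                   # [N_p]
    h_target = marginal_entropy(target_predictions).mean()   # scalar

    # Joint predictive over (y, y_*): outer product of pool and target class
    # probabilities, averaged over the shared parameter samples. [N_p, N_t, C, C]
    joint = np.einsum("ick,tdk->itcd", predictions, target_predictions) / k

    # Joint entropy per (pool, target) pair, then expectation over targets.
    h_joint = entropy(joint.reshape(n_p, n_t, c * c)).mean(axis=1)  # [N_p]
    return h_pool + h_target - h_joint

A quick smoke test with random probability vectors (Dirichlet draws are normalized by construction):

rng = np.random.default_rng(0)
preds = rng.dirichlet(np.ones(3), size=(5, 10)).transpose(0, 2, 1)    # [N_p=5, C=3, K=10]
targets = rng.dirichlet(np.ones(3), size=(4, 10)).transpose(0, 2, 1)  # [N_t=4, C=3, K=10]
print(epig_scores(preds, targets))  # 5 scores, one per candidate input

Note the [N_p, N_t, C, C] joint can get large; a production version would typically batch over the target inputs rather than materialize it at once.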