diff --git a/baal/active/__init__.py b/baal/active/__init__.py index 20c46ae..f059ffe 100644 --- a/baal/active/__init__.py +++ b/baal/active/__init__.py @@ -30,5 +30,6 @@ def get_heuristic( "variance": heuristics.Variance, "precomputed": heuristics.Precomputed, "batch_bald": heuristics.BatchBALD, + "epig": heuristics.EPIG, }[name](shuffle_prop=shuffle_prop, reduction=reduction, **kwargs) return heuristic diff --git a/baal/active/active_loop.py b/baal/active/active_loop.py index 2c5eee3..5a8e66f 100644 --- a/baal/active/active_loop.py +++ b/baal/active/active_loop.py @@ -84,10 +84,18 @@ def step(self, pool=None) -> bool: targets = None else: probs = self.get_probabilities(pool, **self.kwargs) - targets = self.get_probabilities(self.dataset, **self.kwargs) + if isinstance(self.heuristic, heuristics.EPIG): + targets = self.get_probabilities(self.dataset, **self.kwargs) + else: + targets = None if probs is not None and (isinstance(probs, types.GeneratorType) or len(probs) > 0): - to_label, uncertainty = self.heuristic.get_ranks(probs, targets) if type(self.heuristic) == heuristics.EPIG else self.heuristic.get_ranks(probs) - + to_label, uncertainty = self.heuristic.get_ranks(probs, targets) + log.info( + "Uncertainty", + mean=uncertainty.mean(), + std=uncertainty.std(), + median=np.median(uncertainty), + ) if indices is not None: to_label = indices[np.array(to_label)] if self.uncertainty_folder is not None: diff --git a/baal/active/heuristics/heuristics.py b/baal/active/heuristics/heuristics.py index a751d3b..239cf50 100644 --- a/baal/active/heuristics/heuristics.py +++ b/baal/active/heuristics/heuristics.py @@ -2,6 +2,7 @@ import warnings from collections.abc import Sequence from functools import wraps as _wraps +from itertools import zip_longest from typing import List import numpy as np @@ -47,11 +48,13 @@ def singlepass(fn): """ @_wraps(fn) - def wrapper(self, probabilities): + def wrapper(self, probabilities, training_predictions=None): if probabilities.ndim >= 3: # Expected shape : [n_sample, n_classes, ..., n_iterations] probabilities = probabilities.mean(-1) - return fn(self, probabilities) + if training_predictions is not None and training_predictions.ndim >= 3: + training_predictions = training_predictions.mean(-1) + return fn(self, probabilities, training_predictions) return wrapper @@ -68,10 +71,13 @@ def requireprobs(fn): """ @_wraps(fn) - def wrapper(self, probabilities): + def wrapper(self, probabilities, training_predictions=None): # Expected shape : [n_sample, n_classes, ..., n_iterations] probabilities = to_prob(probabilities) - return fn(self, probabilities) + training_predictions = ( + to_prob(training_predictions) if training_predictions is not None else None + ) + return fn(self, probabilities, training_predictions=training_predictions) return wrapper @@ -90,7 +96,7 @@ def require_single_item(fn): """ @_wraps(fn) - def wrapper(self, probabilities): + def wrapper(self, probabilities, training_predictions=None): # Expected single shape : [n_sample, n_classes, ..., n_iterations] if isinstance(probabilities, (list, tuple)): if len(probabilities) == 1: @@ -103,7 +109,7 @@ def wrapper(self, probabilities): " we suggest using baal.active.heuristics.CombineHeuristics" ) - return fn(self, probabilities) + return fn(self, probabilities, training_predictions) return wrapper @@ -157,24 +163,26 @@ def __init__(self, shuffle_prop=DEPRECATED, reverse=False, reduction="none"): self._reduction_name = reduction self.reduction = reduction if callable(reduction) else 
available_reductions[reduction]

-    def compute_score(self, predictions):
+    def compute_score(self, predictions, training_predictions=None):
         """
         Compute the score according to the heuristic.

         Args:
             predictions (ndarray): Array of predictions
+            training_predictions (Optional[ndarray]): Array of predictions on the training set.

         Returns:
             Array of scores.
         """
         raise NotImplementedError

-    def get_uncertainties_generator(self, predictions):
+    def get_uncertainties_generator(self, predictions, training_predictions=None):
         """
         Compute the score according to the heuristic.

         Args:
             predictions (Iterable): Generator of predictions
+            training_predictions (Optional[ndarray]): Array of training predictions, passed whole to each chunk.

         Raises:
             ValueError if the generator is empty.
@@ -184,17 +192,18 @@ def get_uncertainties_generator(self, predictions):
         acc = []
         for pred in predictions:
-            acc.append(self.get_uncertainties(pred))
+            acc.append(self.get_uncertainties(pred, training_predictions=training_predictions))
         if len(acc) == 0:
             raise ValueError("No prediction! Cannot order the values!")
         return np.concatenate(acc)

-    def get_uncertainties(self, predictions):
+    def get_uncertainties(self, predictions, training_predictions=None):
         """
         Get the uncertainties.

         Args:
             predictions (ndarray): Array of predictions
+            training_predictions (Optional[ndarray]): Array of predictions on the training set.

         Returns:
             Array of uncertainties
@@ -202,7 +211,7 @@ def get_uncertainties(self, predictions):
         if isinstance(predictions, Tensor):
             predictions = predictions.numpy()
-        scores = self.compute_score(predictions)
+        scores = self.compute_score(predictions, training_predictions=training_predictions)
         scores = self.reduction(scores)
         if not np.all(np.isfinite(scores)):
             fixed = 0.0 if self.reversed else 10000
@@ -242,12 +251,13 @@ def reorder_indices(self, scores):
             ranks = _shuffle_subset(ranks, self.shuffle_prop)
         return ranks

-    def get_ranks(self, predictions):
+    def get_ranks(self, predictions, training_predictions=None):
         """
         Rank the predictions according to their uncertainties.

         Args:
             predictions (ndarray): [batch_size, C, ..., Iterations]
+            training_predictions (Optional[ndarray]): [batch_size, C, ..., Iterations]

         Returns:
             Ranked index according to the uncertainty (highest to lowest).
@@ -255,18 +265,20 @@
         """
         if isinstance(predictions, types.GeneratorType):
-            scores = self.get_uncertainties_generator(predictions)
+            scores = self.get_uncertainties_generator(
+                predictions, training_predictions=training_predictions
+            )
         else:
-            scores = self.get_uncertainties(predictions)
+            scores = self.get_uncertainties(predictions, training_predictions=training_predictions)

         return self.reorder_indices(scores), scores

-    def __call__(self, predictions):
+    def __call__(self, predictions, training_predictions=None):
         """Rank the predictions according to their uncertainties.

         Only return the scores and not the associated uncertainties.
         """
-        return self.get_ranks(predictions)[0]
+        return self.get_ranks(predictions, training_predictions)[0]


 class BALD(AbstractHeuristic):
@@ -288,12 +300,13 @@ def __init__(self, shuffle_prop=DEPRECATED, reduction="none"):

     @require_single_item
     @requireprobs
-    def compute_score(self, predictions):
+    def compute_score(self, predictions, training_predictions=None):
         """
         Compute the score according to the heuristic.

         Args:
             predictions (ndarray): Array of predictions
+            training_predictions (Optional[ndarray]): [batch_size, C, ..., Iterations]

         Returns:
             Array of scores.
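Reviewer note: the refactor above keeps every heuristic backward compatible, since `training_predictions` defaults to `None` at each layer (`compute_score`, `get_uncertainties`, `get_ranks`, `__call__`, and the decorators). A minimal sketch of the preserved call styles, assuming this patch is applied (random toy inputs, not part of the patch):

```python
import numpy as np
from baal.active.heuristics import BALD

# [n_sample, n_classes, n_iterations] MC-Dropout probabilities (toy data).
predictions = np.random.rand(100, 10, 20)

bald = BALD()
ranks, scores = bald.get_ranks(predictions)  # pre-existing one-argument call still works
ranks2, _ = bald.get_ranks(predictions, training_predictions=None)  # equivalent
assert np.array_equal(ranks, ranks2)  # deterministic: no shuffling by default
```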
@@ -432,12 +445,13 @@ def _joint_entropy(self, predictions, selected):

     @require_single_item
     @requireprobs
-    def compute_score(self, predictions):
+    def compute_score(self, predictions, training_predictions=None):
         """
         Compute the score according to the heuristic.

         Args:
             predictions (ndarray): Array of predictions [batch_size, C, Iterations]
+            training_predictions (Optional[ndarray]): [batch_size, C, ..., Iterations]

         Notes:
             Only Classification is supported, not semantic segmentation or other.
@@ -481,12 +495,13 @@ def compute_score(self, predictions):

         return uncertainties

-    def get_ranks(self, predictions):
+    def get_ranks(self, predictions, training_predictions=None):
         """
         Rank the predictions according to their uncertainties.

         Args:
             predictions (ndarray): [batch_size, C, Iterations]
+            training_predictions (Optional[ndarray]): [batch_size, C, Iterations]

         Returns:
             Ranked index according to the uncertainty (highest to lowest).
@@ -525,7 +540,7 @@ def __init__(self, shuffle_prop=DEPRECATED, reduction="mean"):
         super().__init__(shuffle_prop=shuffle_prop, reverse=True, reduction=reduction)

     @require_single_item
-    def compute_score(self, predictions):
+    def compute_score(self, predictions, training_predictions=None):
         assert predictions.ndim >= 3
         return np.var(predictions, -1)

@@ -546,7 +561,7 @@ def __init__(self, shuffle_prop=DEPRECATED, reduction="none"):
     @require_single_item
     @singlepass
     @requireprobs
-    def compute_score(self, predictions):
+    def compute_score(self, predictions, training_predictions=None):
         return scipy.stats.entropy(np.swapaxes(predictions, 0, 1))


@@ -568,7 +583,7 @@ def __init__(self, shuffle_prop=DEPRECATED, reduction="none"):
     @require_single_item
     @singlepass
     @requireprobs
-    def compute_score(self, predictions):
+    def compute_score(self, predictions, training_predictions=None):
         sort_arr = np.sort(predictions, axis=1)
         return sort_arr[:, -1] - sort_arr[:, -2]


@@ -587,7 +602,7 @@ def __init__(self, shuffle_prop=DEPRECATED, reduction="none"):

     @require_single_item
     @singlepass
-    def compute_score(self, predictions):
+    def compute_score(self, predictions, training_predictions=None):
         return np.max(predictions, axis=1)


@@ -602,7 +617,7 @@ class Precomputed(AbstractHeuristic):
     def __init__(self, shuffle_prop=DEPRECATED, reverse=False):
         super().__init__(shuffle_prop, reverse=reverse)

-    def compute_score(self, predictions):
+    def compute_score(self, predictions, training_predictions=None):
         return predictions


@@ -622,7 +637,7 @@ def __init__(self, shuffle_prop=DEPRECATED, reduction="none", seed=None):
         else:
             self.rng = np.random

-    def compute_score(self, predictions):
+    def compute_score(self, predictions, training_predictions=None):
         return self.rng.rand(predictions.shape[0])


@@ -667,7 +682,7 @@ def __init__(self, heuristics: List, weights: List, reduction="mean", shuffle_pr
         else:
             raise Exception("heuristics should have the same value for `reversed` parameter")

-    def get_uncertainties(self, predictions):
+    def get_uncertainties(self, predictions, training_predictions=None):
         """
         Computes the score for each part of predictions according to the
         assigned heuristic.
@@ -677,6 +692,7 @@

         Args:
             predictions (list[ndarray]): list of predictions arrays
+            training_predictions (Optional[List[ndarray]]): List of predictions on the training dataset.
Returns:
             Array of uncertainties

@@ -684,11 +700,21 @@
         """
         results = []
-        for ind, prediction in enumerate(predictions):
+        for ind, (prediction, train_pred) in enumerate(
+            zip_longest(predictions, training_predictions or [], fillvalue=None)
+        ):
             if isinstance(predictions[0], types.GeneratorType):
-                results.append(self.composed_heuristic[ind].get_uncertainties_generator(prediction))
+                results.append(
+                    self.composed_heuristic[ind].get_uncertainties_generator(
+                        prediction, training_predictions=train_pred
+                    )
+                )
             else:
-                results.append(self.composed_heuristic[ind].get_uncertainties(prediction))
+                results.append(
+                    self.composed_heuristic[ind].get_uncertainties(
+                        prediction, training_predictions=train_pred
+                    )
+                )
         return results

     def reorder_indices(self, scores_list):
@@ -732,6 +758,11 @@ class EPIG(AbstractHeuristic):
     Implementation of Expected Predictive Information Gain
     https://arxiv.org/abs/2304.08151

+    Args:
+        shuffle_prop (float): DEPRECATED
+        reverse (bool): UNUSED
+        reduction (Union[str, callable]): function that aggregates the results.
+
     References:
         Code from https://github.com/fbickfordsmith/epig
     """
@@ -739,52 +770,37 @@ def __init__(self, shuffle_prop=DEPRECATED, reverse=False, reduction="none"):
         super().__init__(shuffle_prop=shuffle_prop, reverse=True, reduction=reduction)

-    def entropy_from_probs(self,probs):
-        """
-        See entropy_from_logprobs.
-
-        If p(y=y'|x) is 0, we make sure p(y=y'|x) log p(y=y'|x) evaluates to 0, not NaN.
-
-        Arguments:
-            probs: Tensor[float], [*N, Cl]
-
-        Returns:
-            Tensor[float], [*N,]
-        """
-        logprobs = torch.clone(probs)  #  [*N, Cl]
-        logprobs[probs > 0] = torch.log(probs[probs > 0])  #  [*N, Cl]
-        return -torch.sum((probs * logprobs), dim=-1)  # [*N,]
-
-    def marginal_entropy_from_probs(self,probs):
+    def marginal_entropy_from_probs(self, probs):
         """
         See marginal_entropy_from_logprobs.

-        Arguments:
+        Args:
             probs: Tensor[float], [N, Cl, K]

         Returns:
             Tensor[float], [N,]
         """
-        probs = torch.mean(probs, dim=2)  # [N, Cl]
-        scores = self.entropy_from_probs(probs)
+        probs = torch.mean(probs, dim=-1)  # [N, Cl]
+        scores = -torch.sum(torch.xlogy(probs, probs), dim=-1)
         return scores  # [N,]

-    def compute_score(self, predictions, targets):
+    @requireprobs
+    def compute_score(self, predictions, training_predictions):
         """
         Compute the score according to the heuristic.

         Args:
-            predictions (ndarray): Array of predictions
-            targets (ndarray): Array of targets
+            predictions (ndarray): Array of predictions
+            training_predictions (ndarray): Array of predictions on the training set.

         Returns:
             Array of scores.
         """
-        assert predictions.ndim >= 3
-        assert targets.ndim >= 3
+        assert predictions.ndim == 3, "EPIG only supports classification for now."
+        assert training_predictions.ndim == 3, "EPIG only supports classification for now."
- probs_pool = torch.Tensor(predictions) #[N, Cl, K] - probs_targ = torch.Tensor(targets) + probs_pool = torch.Tensor(predictions) # [N, Cl, K] + probs_targ = torch.Tensor(training_predictions) N_t, C, K = probs_targ.shape @@ -792,79 +808,16 @@ def compute_score(self, predictions, targets): entropy_targ = self.marginal_entropy_from_probs(probs_targ) # [N_t,] entropy_targ = torch.mean(entropy_targ) # [1,] - #probs_pool = probs_pool.permute(0, 2, 1) # [N_p, Cl, K] + # probs_pool = probs_pool.permute(0, 2, 1) # [N_p, Cl, K] probs_targ = probs_targ.permute(2, 0, 1) # [K, N_t, Cl] probs_targ = probs_targ.reshape(K, N_t * C) # [K, N_t * Cl] - probs_pool_targ_joint = torch.matmul(probs_pool,probs_targ / K) # [N_p, Cl, N_t * Cl] + probs_pool_targ_joint = torch.matmul(probs_pool, probs_targ) / K # [N_p, Cl, N_t * Cl] entropy_pool_targ = ( - -torch.sum(xlogy(probs_pool_targ_joint, probs_pool_targ_joint), dim=(-2, -1)) / N_t + -torch.sum(torch.xlogy(probs_pool_targ_joint, probs_pool_targ_joint), dim=(-2, -1)) + / N_t ) # [N_p,] - + entropy_pool_targ[torch.isnan(entropy_pool_targ)] = 0.0 scores = entropy_pool + entropy_targ - entropy_pool_targ # [N_p,] return scores.numpy() # [N_p,] - - # scores = self._conditional_epig_from_probs(predictions, targets) - # return torch.mean(scores, dim=-1) # [N_p,] - - def get_uncertainties(self, predictions, targets): - """ - Get the uncertainties. - - Args: - predictions (ndarray): Array of predictions - - Returns: - Array of uncertainties - - """ - if isinstance(predictions, Tensor): - predictions = predictions.numpy() - scores = self.compute_score(predictions, targets) - scores = self.reduction(scores) - if not np.all(np.isfinite(scores)): - fixed = 0.0 if self.reversed else 10000 - warnings.warn(f"Invalid value in the score, will be put to {fixed}", UserWarning) - scores[~np.isfinite(scores)] = fixed - return scores - - def get_uncertainties_generator(self, predictions, targets): - """ - Compute the score according to the heuristic. - - Args: - predictions (Iterable): Generator of predictions - - Raises: - ValueError if the generator is empty. - - Returns: - Array of scores. - """ - acc = [] - for pred in predictions: - acc.append(self.get_uncertainties(pred, targets)) - if len(acc) == 0: - raise ValueError("No prediction! Cannot order the values!") - return np.concatenate(acc) - def get_ranks(self, predictions, targets): - """ - Rank the predictions according to their uncertainties. - - Args: - predictions (ndarray): [batch_size, C, ..., Iterations] - - Returns: - Ranked index according to the uncertainty (highest to lowes). - Scores for all predictions. 
-
-        """
-        if isinstance(predictions, types.GeneratorType):
-            scores = self.get_uncertainties_generator(predictions, targets)
-        else:
-            scores = self.get_uncertainties(predictions, targets)
-
-        print(scores)
-
-        return self.reorder_indices(scores), scores
diff --git a/baal/active/heuristics/heuristics_gpu.py b/baal/active/heuristics/heuristics_gpu.py
index 724cc12..8d7ea2c 100644
--- a/baal/active/heuristics/heuristics_gpu.py
+++ b/baal/active/heuristics/heuristics_gpu.py
@@ -34,15 +34,22 @@ def _shuffle_subset(data: torch.Tensor, shuffle_prop: float) -> torch.Tensor:
     return data


+def to_prob_torch(probabilities):
+    not_bounded = torch.min(probabilities) < 0 or torch.max(probabilities) > 1.0
+    if not_bounded or not probabilities.sum(1).allclose(torch.ones_like(probabilities.sum(1))):
+        probabilities = F.softmax(probabilities, 1)
+    return probabilities
+
+
 def requireprobs(fn):
     """Will convert logits to probs if needed"""

-    def wrapper(self, probabilities):
+    def wrapper(self, probabilities, training_predictions=None):
         # Expected shape : [n_sample, n_classes, ..., n_iterations]
-        bounded = torch.min(probabilities) < 0 or torch.max(probabilities) > 1.0
-        if bounded or not probabilities.sum(1).allclose(1):
-            probabilities = F.softmax(probabilities, 1)
-        return fn(self, probabilities)
+        probabilities = to_prob_torch(probabilities)
+        if training_predictions is not None:
+            training_predictions = to_prob_torch(training_predictions)
+        return fn(self, probabilities, training_predictions)

     return wrapper

@@ -71,13 +78,16 @@ def __init__(
         self.threshold = threshold
         self.reversed = reverse
         assert reduction in available_reductions or callable(reduction)
-        self.reduction = reduction if callable(reduction) else available_reductions[reduction]
+        self.reduction = (
+            reduction if callable(reduction) else available_reductions[reduction]
+        )

-    def compute_score(self, predictions):
+    def compute_score(self, predictions, training_predictions=None):
         """
         Compute the score according to the heuristic.

         Args:
             predictions (ndarray): Array of predictions
+            training_predictions (Optional[ndarray]): Array of predictions on the training set.

         Returns:
             Array of scores.
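Reviewer note: `to_prob_torch` extracts the logits-to-probabilities conversion so the wrapper can apply it to both arguments; it mirrors `baal.active.heuristics.to_prob` on the numpy side (the `.allclose(...)` call is given a tensor argument here, since `Tensor.allclose` does not accept a Python scalar). A quick behavioural sketch, assuming this patch is applied so the helper is importable (toy tensors, not part of the patch):

```python
import torch
import torch.nn.functional as F
from baal.active.heuristics.heuristics_gpu import to_prob_torch

logits = torch.randn(16, 4)        # negative entries, so treated as logits
probs = F.softmax(logits, dim=1)   # each row already sums to 1

assert torch.allclose(to_prob_torch(logits), F.softmax(logits, 1))  # softmaxed
assert torch.allclose(to_prob_torch(probs), probs)                  # passed through untouched
```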
@@ -105,14 +115,23 @@ def predict_on_dataset(
         return (
             super()
             .predict_on_dataset(
-                dataset, batch_size, iterations, use_cuda, workers, collate_fn, half, verbose
+                dataset,
+                batch_size,
+                iterations,
+                use_cuda,
+                workers,
+                collate_fn,
+                half,
+                verbose,
             )
             .reshape([-1])
         )

     def predict_on_batch(self, data, iterations=1, use_cuda=False):
         """Rank the predictions according to their uncertainties."""
-        return self.get_uncertainties(self.model.predict_on_batch(data, iterations, cuda=use_cuda))
+        return self.get_uncertainties(
+            self.model.predict_on_batch(data, iterations, cuda=use_cuda)
+        )


 class BALDGPUWrapper(AbstractGPUHeuristic):
@@ -122,7 +141,12 @@ class BALDGPUWrapper(AbstractGPUHeuristic):
     """

     def __init__(
-        self, model: ModelWrapper, criterion, shuffle_prop=0.0, threshold=None, reduction="none"
+        self,
+        model: ModelWrapper,
+        criterion,
+        shuffle_prop=0.0,
+        threshold=None,
+        reduction="none",
     ):
         super().__init__(
             model,
@@ -134,7 +158,7 @@ def __init__(
         )

     @requireprobs
-    def compute_score(self, predictions):
+    def compute_score(self, predictions, training_predictions=None):
         assert predictions.ndimension() >= 3
         # [n_sample, n_class, ..., n_iterations]
         expected_entropy = -torch.mean(
diff --git a/tests/active/heuristic_test.py b/tests/active/heuristic_test.py
index a342f75..5811765 100644
--- a/tests/active/heuristic_test.py
+++ b/tests/active/heuristic_test.py
@@ -9,6 +9,7 @@ from baal.active.heuristics import (
     Random,
     BALD,
+    EPIG,
     Margin,
     Entropy,
     Certainty,
@@ -66,6 +67,36 @@ def test_bald(distributions, reduction):
     # Unlikely, but not 100% sure
     assert np.any(marg != [1, 2, 0])

+@pytest.mark.parametrize(
+    'distributions, reduction',
+    [
+        (distributions_3d, 'none'),
+    ],
+)
+def test_epig(distributions, reduction):
+    np.random.seed(1338)
+    train_preds = distributions
+
+    epig = EPIG(reduction=reduction)
+    marg = epig(distributions, train_preds)
+    str_marg = epig(chunks(distributions, 10), train_preds)
+
+    # EPIG uses the mean entropy of the unlabelled predictions, so chunked scores only match approximately.
+    assert np.allclose(
+        epig.get_uncertainties(distributions, train_preds),
+        epig.get_uncertainties_generator(chunks(distributions, 10), train_preds),
+        rtol=.05
+    )
+
+    assert np.all(marg == [1, 2, 0]), "EPIG is not right {}".format(marg)
+    assert np.all(str_marg == [1, 2, 0]), "Streaming EPIG is not right {}".format(str_marg)
+
+    epig = EPIG(0.99, reduction=reduction)
+    marg = epig(distributions, train_preds)
+
+    # Unlikely, but not 100% sure
+    assert np.any(marg != [1, 2, 0])


 @pytest.mark.parametrize('distributions, reduction', [(distributions_3d, 'none')])
@@ -193,7 +224,7 @@ def test_that_logits_get_converted_to_probabilities(logits):

     # define a random func:
     @requireprobs
-    def wrapped(_, logits):
+    def wrapped(_, logits, training_predictions=None):
         return logits

     probability_distribution = wrapped(None, logits)
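Reviewer note: end to end, the new heuristic is used like the existing ones, plus the extra `training_predictions` argument. A toy sketch, assuming this patch is applied (Dirichlet samples stand in for MC-Dropout output; shapes follow the `[n_sample, n_classes, n_iterations]` convention used throughout baal):

```python
import numpy as np
from baal.active.heuristics import EPIG

n_classes, n_iterations = 5, 20
# Dirichlet samples are valid probability vectors for every MC iteration.
pool = np.random.dirichlet(np.ones(n_classes), size=(100, n_iterations)).transpose(0, 2, 1)
train = np.random.dirichlet(np.ones(n_classes), size=(30, n_iterations)).transpose(0, 2, 1)

epig = EPIG(reduction="none")
ranks, scores = epig.get_ranks(pool, training_predictions=train)
# EPIG is constructed with reverse=True: highest expected information gain first.
print(ranks[:10], scores.mean())
```

This is exactly what `ActiveLearningLoop.step` now wires up automatically: the pool predictions go in as `predictions`, and the predictions on the labelled dataset as `training_predictions`.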