Skip to content

Commit

Permalink
Changes according to review
Browse files Browse the repository at this point in the history
  • Loading branch information
Dref360 committed May 18, 2024
1 parent 0d3c024 commit 0cd798a
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 139 deletions.
1 change: 1 addition & 0 deletions baal/active/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,6 @@ def get_heuristic(
"variance": heuristics.Variance,
"precomputed": heuristics.Precomputed,
"batch_bald": heuristics.BatchBALD,
"epig": heuristics.EPIG,
}[name](shuffle_prop=shuffle_prop, reduction=reduction, **kwargs)
return heuristic
14 changes: 11 additions & 3 deletions baal/active/active_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,18 @@ def step(self, pool=None) -> bool:
targets = None
else:
probs = self.get_probabilities(pool, **self.kwargs)
targets = self.get_probabilities(self.dataset, **self.kwargs)
if isinstance(self.heuristic, heuristics.EPIG):
targets = self.get_probabilities(self.dataset, **self.kwargs)
else:
targets = None
if probs is not None and (isinstance(probs, types.GeneratorType) or len(probs) > 0):
to_label, uncertainty = self.heuristic.get_ranks(probs, targets) if type(self.heuristic) == heuristics.EPIG else self.heuristic.get_ranks(probs)

to_label, uncertainty = self.heuristic.get_ranks(probs, targets)
log.info(
"Uncertainty",
mean=uncertainty.mean(),
std=uncertainty.std(),
median=np.median(uncertainty),
)
if indices is not None:
to_label = indices[np.array(to_label)]
if self.uncertainty_folder is not None:
Expand Down
Loading

0 comments on commit 0cd798a

Please sign in to comment.