Skip to content

Commit

Permalink
Adding hit ratio metric (#7)
Browse files Browse the repository at this point in the history
* adding hit ratio metric

* adding optional return of query history in full active learning run method

* specify version in single file and fix documentation

* readme semver update

* adding check on hit ratio at percentage

* make sure hit ratio is in final print

Co-authored-by: thomasgaudelet <thomasgaudelet@github.com>
  • Loading branch information
thomasgaudelet and thomasgaudelet authored Mar 28, 2022
1 parent 31998b4 commit 41edede
Show file tree
Hide file tree
Showing 13 changed files with 86 additions and 37 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<a alt="coverage">
<img src="https://img.shields.io/badge/coverage-93%25-green" /></a>
<a alt="semver">
<img src="https://img.shields.io/badge/semver-0.1.4-blue" /></a>
<img src="https://img.shields.io/badge/semver-0.1.5-blue" /></a>
<a alt="documentation" href="https://pyrelational.readthedocs.io/en/latest/index.html">
<img src="https://img.shields.io/badge/documentation-online-orange" /></a>
<a alt="pypi" href="https://pypi.org/project/pyrelational/">
Expand Down
3 changes: 1 addition & 2 deletions docs/source/_static/theme.css
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
.wy-nav-content {
max-width: 1200px !important;
min-width: 1200px !important;
min-width: 100% !important;
}
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
copyright = f"{datetime.datetime.now().year}, {author}"

# The full version, including alpha/beta/rc tags
release = "0.1.4"
release = pyrelational.__version__


# -- General configuration ---------------------------------------------------
Expand Down
1 change: 1 addition & 0 deletions examples/demo/lightning_diversity_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
train_indices=train_indices,
validation_indices=val_indices,
test_indices=test_indices,
hit_ratio_at=5,
)

strategy = RelativeDistanceStrategy(data_manager=data_manager, model=model)
Expand Down
1 change: 1 addition & 0 deletions examples/demo/lightning_diversity_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
train_indices=train_indices,
validation_indices=val_indices,
test_indices=test_indices,
hit_ratio_at=5,
)

strategy = RelativeDistanceStrategy(data_manager=data_manager, model=model)
Expand Down
3 changes: 1 addition & 2 deletions pyrelational/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,4 @@
import pyrelational.informativeness
import pyrelational.models
import pyrelational.strategies

__version__ = "0.1.4"
from pyrelational.version import __version__
17 changes: 17 additions & 0 deletions pyrelational/data/data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import random
from typing import Any, Callable, List, Optional, Sequence, Tuple, TypeVar, Union

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, Sampler, Subset

Expand All @@ -19,6 +20,7 @@ def __init__(
validation_indices: Optional[List[int]] = None,
test_indices: Optional[List[int]] = None,
random_label_size: Union[float, int] = 0.1,
hit_ratio_at: Optional[Union[int, float]] = None,
random_seed: int = 1234,
loader_batch_size: Union[int, str] = 1,
loader_shuffle: bool = False,
Expand All @@ -42,6 +44,7 @@ def __init__(
test performance of the model
:param random_label_size: Only used when labelled and unlabelled indices are not provided. Sets the size of
labelled set (should either be the number of samples or ratio w.r.t. train set)
:param hit_ratio_at: optional argument setting the top percentage threshold to compute hit ratio metric
:param random_seed: random seed
:param loader_batch_size: batch size for dataloader
:param loader_shuffle: shuffle flag for dataloader
Expand Down Expand Up @@ -92,6 +95,7 @@ def __init__(
self.random_label_size = random_label_size
self.process_random(random_seed)
self._ensure_no_l_or_u_leaks()
self._top_unlabelled_set(hit_ratio_at)

def _ensure_no_split_leaks(self) -> None:
tt = set.intersection(set(self.train_indices), set(self.test_indices))
Expand Down Expand Up @@ -205,6 +209,19 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor]:
# So that one can access samples by index directly
return self.dataset[idx]

def _top_unlabelled_set(self, percentage: Optional[Union[int, float]] = None) -> None:
if percentage is None:
self.top_unlabelled = None
else:
if isinstance(percentage, int):
percentage /= 100
assert 0 < percentage < 1, "hit ratio's percentage should be strictly between 0 and 1 (or 0 and 100)"
ixs = np.array(self.u_indices)
percentage = int(percentage * len(ixs))
y = torch.stack(self.get_sample_labels(ixs)).squeeze()
threshold = np.sort(y.abs())[-percentage]
self.top_unlabelled = set(ixs[(y.abs() >= threshold).numpy().astype(bool)])

def get_train_set(self) -> Dataset:
train_subset = Subset(self.dataset, self.train_indices)
return train_subset
Expand Down
18 changes: 9 additions & 9 deletions pyrelational/informativeness/task_agnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,16 @@ def relative_distance(
samples in the reference set.
:param query_set: input containing the features of samples in the queryable pool. query set should either be an
array-like object or a pytorch dataloader whose first element in each bactch is a featurisation of the samples in
the batch.
array-like object or a pytorch dataloader whose first element in each bactch is a featurisation of the samples
in the batch.
:param reference_set: input containing the features of samples already queried samples against which the distances
are computed. reference set should either be an array-like object or a pytorch dataloader whose first element in
each bactch is a featurisation of the samples in the batch.
are computed. reference set should either be an array-like object or a pytorch dataloader whose first element in
each bactch is a featurisation of the samples in the batch.
:param metric: defines the metric to be used to compute the distance. This should be supported by scikit-learn
pairwise_distances function.
pairwise_distances function.
:param axis: integer indicating which dimension the features are
:return: pytorch tensor of dimension the number of samples in query_set containing the minimum distance from each
sample to the reference set
sample to the reference set
"""
if isinstance(query_set, get_args(Array)):
query_set = np.array(query_set)
Expand Down Expand Up @@ -95,12 +95,12 @@ def representative_sampling(
algorithms in scikit-learn.
:param query_set: input containing the features of samples in the queryable pool. query set should either be an
array-like object or a pytorch dataloader whose first element in each bactch is a featurisation of the samples in
the batch.
array-like object or a pytorch dataloader whose first element in each bactch is a featurisation of the samples
in the batch
:param num_annotate: number of representative samples to identify
:param clustering_method: name, or instantiated class, of the clustering method to use
:param clustering_kwargs: arguments to be passed to instantiate clustering class if a string is passed to
clustering_method
clustering_method
:return: array-like containing the indices of the representative samples identified
"""

Expand Down
57 changes: 40 additions & 17 deletions pyrelational/strategies/generic_al_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from collections import defaultdict
from typing import Dict, List, Optional

import numpy as np
import pandas as pd
from tabulate import tabulate
from torch.utils.data import DataLoader
Expand All @@ -22,11 +23,10 @@ class GenericActiveLearningStrategy(ABC):
to define specific acquisition/active learning strategies and hence
drive the active learning process
Args:
data_manager (pyrelational.data_manager.GenericDataManager): an pyrelational data manager
:param data_manager: an pyrelational data manager
which keeps track of what has been labelled and creates data loaders for
active learning
model (pyrelational.models.generic_model.GenericModel): An pyrelational model
:param model: An pyrelational model
which serves as the machine learning model for the data in the
data manager
"""
Expand All @@ -51,21 +51,23 @@ def theoretical_performance(self, test_loader: Optional[DataLoader] = None) -> D
Would not make much sense when we are doing active learning for the real
situation, hence not part of __init__
Args:
test_loader (Optional: torch.utils.data.DataLoader): Pytorch Data Loader with
:param test_loader: Pytorch Data Loader with
test data compatible with model, optional as often the test loader can be
generated from data_manager but is here for case when it hasn't been defined
or there is a new test set.
"""
self.model.train(self.train_loader, self.valid_loader)

# use test loader in data_manager if there is one
self.performances["full"] = self.model.test(self.test_loader if test_loader is None else test_loader)
result = self.model.test(self.test_loader if test_loader is None else test_loader)
if self.data_manager.top_unlabelled is not None:
result["hit_ratio"] = np.nan
self.performances["full"] = result
# make sure that theoretical best model is not stored
self.model.current_model = None
return self.performances["full"]

def current_performance(self, test_loader: Optional[DataLoader] = None) -> Dict:
def current_performance(self, test_loader: Optional[DataLoader] = None, query: Optional[List[int]] = None) -> Dict:
if self.model.current_model is None: # no AL steps taken so far
self.model.train(self.l_loader, self.valid_loader)

Expand All @@ -75,15 +77,21 @@ def current_performance(self, test_loader: Optional[DataLoader] = None) -> Dict:
result = self.model.test(test_loader)
# reset current model, to avoid issues when model does not need to be trained during al_step.
self.model.current_model = None

if self.data_manager.top_unlabelled is not None:
result["hit_ratio"] = (
np.nan
if query is None
else len(set(query) & self.data_manager.top_unlabelled) / len(self.data_manager.top_unlabelled)
)
return result

@abstractmethod
def active_learning_step(self, num_annotate: int) -> List[int]:
"""Implements a single step of the active learning strategy stopping and returning the
unlabelled observations to be labelled as a list of dataset indices
Args:
num_annotate (int > 0): number of observations from u to suggest for labelling
:param num_annotate: number of observations from u to suggest for labelling
"""
pass

Expand All @@ -104,9 +112,11 @@ def active_learning_update(self, indices: List[int], oracle_interface: object =
def full_active_learning_run(
self,
num_annotate: int,
num_iterations: Optional[int] = None,
oracle_interface: object = None,
test_loader: DataLoader = None,
) -> None:
return_query_history: bool = False,
) -> Optional[Dict]:
"""Given the number of samples to annotate and a test loader
this method will go through the entire active learning process of training
the model on the labelled set, and recording the current performance
Expand All @@ -115,30 +125,43 @@ def full_active_learning_run(
labelled to be added to the next iteration's labelled dataset L'. This
process repeats until there are no observations left in the unlabelled set.
Args:
num_annotate (int): number of observations to get annotated per iteration
test_loader (pytorch.utils.data.DataLoader): test data with which we evaluate
the current state of the model given the labelled set L
oracle_interface (None or pyrelational.oracle_interface.OracleInterface)
:param num_annotate: number of observations to get annotated per iteration
:param num_iterations: number of active learning loop to perform
:param oracle_interface: undefined for now, this will be entry point for external oracle later
:param test_loader: test data with which we evaluate the current state of the model given the labelled set L
:param return_query_history: whether to return the history of queries or not
:return: optionally returns a dictionary storing the indices of queries at each iteration
"""

iter_count = 0
if return_query_history:
query_history = {}
while len(self.u_indices) > 0:
iter_count += 1

# Obtain samples for labelling and pass to the oracle interface if supplied
observations_for_labelling = self.active_learning_step(num_annotate)
if return_query_history:
query_history[iter_count] = observations_for_labelling

# Record the current performance
self.performances[self.iteration] = self.current_performance(test_loader=test_loader)
self.performances[self.iteration] = self.current_performance(
test_loader=test_loader,
query=observations_for_labelling,
)

self.active_learning_update(
observations_for_labelling,
oracle_interface,
update_tag=str(self.iteration),
)
if (num_iterations is not None) and iter_count == num_iterations:
break

# Final update the model and check final test performance
self.model.train(self.l_loader, self.valid_loader)
self.performances[self.iteration] = self.current_performance(test_loader=test_loader)
if return_query_history:
return query_history

def update_annotations(self, indices: List[int]) -> None:
self.data_manager.update_train_labels(indices)
Expand Down
1 change: 1 addition & 0 deletions pyrelational/version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__version__ = "0.1.5"
6 changes: 5 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()

version = {}
with open("pyrelational/version.py") as fp:
exec(fp.read(), version)

setup(
name="pyrelational",
description="Python tool box for quickly implementing active learning strategies",
Expand All @@ -29,7 +33,7 @@
long_description_content_type="text/markdown",
url="https://github.com/RelationRx/pyrelational",
packages=find_packages(),
version="0.1.4",
version=version["__version__"],
setup_requires=setup_requires,
tests_require=tests_require,
install_requires=install_requires,
Expand Down
6 changes: 4 additions & 2 deletions tests/strategies/test_classification_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_performances():


def test_full_active_learning_run():
gdm = get_classification_dataset()
gdm = get_classification_dataset(hit_ratio_at=5)
model = LightningMCDropoutModel(BreastCancerClassifier, {"ensemble_size": 3}, {"epochs": 1})
al_manager = LeastConfidenceStrategy(data_manager=gdm, model=model)
al_manager.theoretical_performance()
Expand All @@ -37,10 +37,12 @@ def test_full_active_learning_run():
# Test performance history data frame
df = al_manager.performance_history()
print(df)
assert df.shape == (3, 2)
assert df.shape == (3, 3)
assert len(al_manager.data_manager.l_indices) == len(gdm.train_indices)
assert len(al_manager.data_manager.u_indices) == 0
assert {"full", 0, 1, 2} == set(list(al_manager.performances.keys()))
for k in {"full", 0, 1, 2}:
assert "hit_ratio" in al_manager.performances[k].keys()


def test_update_annotations():
Expand Down
6 changes: 4 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pyrelational.data.data_manager import GenericDataManager


def get_regression_dataset():
def get_regression_dataset(hit_ratio_at=None):
pl.seed_everything(0)

ds = DiabetesDataset()
Expand All @@ -23,10 +23,11 @@ def get_regression_dataset():
validation_indices=valid_indices,
test_indices=test_indices,
loader_batch_size=10,
hit_ratio_at=hit_ratio_at,
)


def get_classification_dataset(labelled_size=None):
def get_classification_dataset(labelled_size=None, hit_ratio_at=None):
pl.seed_everything(0)

ds = BreastCancerDataset()
Expand All @@ -43,6 +44,7 @@ def get_classification_dataset(labelled_size=None):
test_indices=test_indices,
labelled_indices=labelled_indices,
loader_batch_size=10,
hit_ratio_at=hit_ratio_at,
)


Expand Down

0 comments on commit 41edede

Please sign in to comment.