Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AL-Pipeline #33

Merged
merged 28 commits into from
Jan 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
a95efc6
Linted
paulmorio Dec 3, 2022
e3af496
Updating strategy implementation with filter_kwargs pattern
paulmorio Dec 3, 2022
c38379a
GenericPipeline with GenericOracle
paulmorio Dec 4, 2022
e500692
_filter_kwargs with default pipeline attributes being used and update…
paulmorio Dec 4, 2022
e4b9e39
Example of changes required on strategies for classification with wor…
paulmorio Dec 4, 2022
a33ef98
Updated classification strategies to new pipeline API
paulmorio Dec 4, 2022
f59c41c
Updated tests for classifications strategies
paulmorio Dec 4, 2022
f65407f
Updated regression strategies to use new pipeline API
paulmorio Dec 4, 2022
8fc1e9e
Updated regression strategy tests to use new pipeline API
paulmorio Dec 4, 2022
3850b6d
Updated task agnostic strategies to use new pipeline API changes
paulmorio Dec 4, 2022
87b1a4f
Updated task agnostic strategy tests to reflect pipeline API change
paulmorio Dec 4, 2022
79dff74
Removed redundant processing of input kwargs
paulmorio Dec 12, 2022
03c5951
Removed model reset in computing current performance
paulmorio Dec 13, 2022
99442f4
Update pyrelational/pipeline/generic_pipeline.py
paulmorio Dec 13, 2022
4063781
Compute hit ratio
paulmorio Dec 13, 2022
c4a2001
Updated data-manager to set values for targets and the oracle with a …
paulmorio Dec 13, 2022
b0d7a61
Updated interface to the oracle
paulmorio Dec 13, 2022
b81c244
Abstracted away training and inferring in strategies
paulmorio Dec 13, 2022
7f1c87d
Updated oracle with interface and concrete dummy oracle
paulmorio Dec 14, 2022
fa0e3ef
Updated scikit_estimator example with dummy oracle
paulmorio Dec 14, 2022
cc9da1a
Updated test to new pipeline and dummy oracle, including note on depr…
paulmorio Dec 14, 2022
9f82b19
Combined update annotation logic into oracle
paulmorio Dec 14, 2022
541a1ac
Ensure data folder is not tracked if created using the examples.
paulmorio Dec 14, 2022
9a8d7c3
Updated name of dummy oracle -> benchmark oracle
paulmorio Dec 14, 2022
fb6b541
Addressing comments
paulmorio Dec 14, 2022
7e8b707
Addressing linter issue not caught in precommit
paulmorio Dec 14, 2022
de77778
Fix linter issue breaking test for the default oracle instantiation
paulmorio Dec 14, 2022
f492bf7
Merge branch 'main' into al-pipeline
paulmorio Jan 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
data/
.idea

# Dev files
Expand Down
28 changes: 19 additions & 9 deletions examples/demo/scikit_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@
from sklearn.metrics import accuracy_score
from torch.utils.data import Dataset

# Data and data manager
from examples.utils.datasets import BreastCancerDataset # noqa: E402
from pyrelational.data import GenericDataManager

# pyrelational
# Model, strategy, oracle, and pipeline
from pyrelational.models import GenericModel
# pyrelational.oracle exports BenchmarkOracle (the renamed dummy oracle), not DummyOracle;
# alias it so the DummyOracle() instantiation later in this example keeps working.
from pyrelational.oracle import BenchmarkOracle as DummyOracle
from pyrelational.pipeline.generic_pipeline import GenericPipeline
from pyrelational.strategies.classification import LeastConfidenceStrategy


Expand Down Expand Up @@ -81,16 +84,23 @@ def __call__(self, loader):
trainer_config = {}
model = SKRFC(RandomForestClassifier, model_config, trainer_config)

# Run active learning strategy
al_strategy = LeastConfidenceStrategy(data_manager, model)
# Instantiate an active learning strategy
al_strategy = LeastConfidenceStrategy()

# performance with the full trainset labelled
al_strategy.theoretical_performance()
# Instantiate an oracle (in this case a dummy one)
oracle = DummyOracle()

# Given that we have a data manager, a model, and an active learning strategy
# we may create an active learning pipeline
pipeline = GenericPipeline(data_manager=data_manager, model=model, strategy=al_strategy, oracle=oracle)

# theoretical performance if the full trainset is labelled
pipeline.theoretical_performance()

# New data to be annotated, followed by an update of the data_manager and model
to_annotate = al_strategy.active_learning_step(num_annotate=100)
al_strategy.active_learning_update(to_annotate, oracle_interface=None, update_tag="Manual Update")
to_annotate = pipeline.active_learning_step(num_annotate=100)
pipeline.active_learning_update(indices=to_annotate, update_tag=f"Manual Update with {str(al_strategy)}")

# Annotating data step by step until the trainset is fully annotated
al_strategy.full_active_learning_run(num_annotate=100)
print(al_strategy)
pipeline.full_active_learning_run(num_annotate=20)
print(pipeline)
1 change: 1 addition & 0 deletions pyrelational/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pyrelational.data
import pyrelational.informativeness
import pyrelational.models
import pyrelational.pipeline
import pyrelational.strategies
from pyrelational.version import __version__
12 changes: 12 additions & 0 deletions pyrelational/data/data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,18 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor]:
# So that one can access samples by index directly
return self.dataset[idx]

def set_target_value(self, idx: int, value: Any) -> None:
    """Overwrite the stored target of one observation in the underlying dataset.

    The dataset is probed for a ``y`` attribute and for a ``targets`` attribute;
    every one that is present has its entry at ``idx`` replaced by ``value``.
    A dataset exposing neither attribute is left unchanged.

    :param idx: index of the observation whose target is replaced
    :param value: new target value for that observation
    """
    for target_attr in ("y", "targets"):
        if hasattr(self.dataset, target_attr):
            getattr(self.dataset, target_attr)[idx] = value

def _top_unlabelled_set(self, percentage: Optional[Union[int, float]] = None) -> None:
"""
Sets the top unlabelled indices according to the value of their labels.
Expand Down
2 changes: 2 additions & 0 deletions pyrelational/oracle/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from pyrelational.oracle.benchmark_oracle import BenchmarkOracle
from pyrelational.oracle.generic_oracle import GenericOracle
28 changes: 28 additions & 0 deletions pyrelational/oracle/benchmark_oracle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

from pyrelational.data.data_manager import GenericDataManager

from .generic_oracle import GenericOracle


class BenchmarkOracle(GenericOracle):
    """A dummy oracle for evaluating strategies in R&D settings.

    Every observation is assumed to be sufficiently annotated already, so a
    query simply hands back the annotation stored in the dataset.
    """

    def __init__(self):
        super().__init__()

    def query_target_value(self, data_manager: GenericDataManager, idx: int) -> Any:
        """Return the target already held in the dataset for the queried observation.

        :param data_manager: reference to the data_manager which will load the observation if necessary
        :param idx: index to observation which we want to query an annotation

        :return: the output of the oracle, i.e. the second element of the sample
            returned by ``data_manager.get_sample(idx)``
        """
        sample = data_manager.get_sample(idx)
        return sample[1]
74 changes: 74 additions & 0 deletions pyrelational/oracle/generic_oracle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""
This file contains the implementation of a generic oracle interface for PyRelationAL
"""

import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

from pyrelational.data.data_manager import GenericDataManager


class GenericOracle(ABC):
    """Abstract interface for concrete oracles that interact with a pyrelational pipeline.

    Subclasses must implement :meth:`query_target_value`. The remaining methods
    write the obtained annotations back through the supplied data manager.
    """

    def __init__(self) -> None:
        super().__init__()

    def update_target_value(self, data_manager: GenericDataManager, idx: int, value: Any) -> None:
        """Update the target value for the observation denoted by the index

        :param data_manager: reference to the data_manager whose dataset we want to update
        :param idx: index to the observation we want to update
        :param value: value to update the observation with
        """
        data_manager.set_target_value(idx=idx, value=value)

    def update_target_values(self, data_manager: GenericDataManager, indices: List[int], values: List[Any]) -> None:
        """Update the target values of the observations at the supplied indices

        :param data_manager: reference to the data_manager whose dataset we want to update
        :param indices: list of indices to observations whose target values we want to update
        :param values: list of values which we want to assign to the corresponding observations in indices

        :raises ValueError: if ``indices`` and ``values`` differ in length; nothing is
            updated in that case
        """
        # Materialise both inputs so their lengths can be compared up front; a bare
        # zip() would silently drop whatever is left over in the longer sequence.
        indices = list(indices)
        values = list(values)
        if len(indices) != len(values):
            raise ValueError(
                f"indices and values must have the same length, got {len(indices)} and {len(values)}"
            )
        for idx, val in zip(indices, values):
            data_manager.set_target_value(idx=idx, value=val)

    def update_annotations(self, data_manager: GenericDataManager, indices: List[int]) -> None:
        """Ask the data_manager to mark the supplied indices as labelled.

        This delegates to ``data_manager.update_train_labels``, which is expected to move
        the observations from the unlabelled set to the labelled set and to leave indices
        that are already labelled untouched (behaviour defined by the data manager).

        Note this does not change the target values of the indices, this is handled by
        the target-updating methods of the oracle.

        :param data_manager: reference to the data_manager whose sets we are adjusting
        :param indices: list of indices selected for labelling
        """
        data_manager.update_train_labels(indices)

    @abstractmethod
    def query_target_value(self, data_manager: GenericDataManager, idx: int) -> Any:
        """Method that needs to be overridden to obtain the annotations for the input index

        :param data_manager: reference to the data_manager which will load the observation if necessary
        :param idx: index to observation which we want to query an annotation

        :return: the output of the oracle
        """
        pass

    def update_dataset(self, data_manager: GenericDataManager, indices: List[int]) -> None:
        """Obtain labels for the supplied indices and store them via the data manager.

        Each index is queried and its target value written back one at a time; once all
        of them are processed the indices are registered as labelled in a single call.

        :param data_manager: reference to DataManager whose dataset we intend to update
        :param indices: list of indices to observations we want updated
        """
        for idx in indices:
            target_val = self.query_target_value(data_manager=data_manager, idx=idx)
            self.update_target_value(data_manager=data_manager, idx=idx, value=target_val)
        self.update_annotations(data_manager=data_manager, indices=indices)
1 change: 1 addition & 0 deletions pyrelational/pipeline/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from pyrelational.pipeline.generic_pipeline import GenericPipeline
Loading