Commit

fix: code-review and api
gciatto committed Mar 22, 2024
1 parent 9a7e657 commit fbcb4dd
Showing 11 changed files with 95 additions and 83 deletions.
19 changes: 18 additions & 1 deletion aequitas/core/datasets/__init__.py
@@ -1,7 +1,24 @@
 import aequitas
 
 from .structured_dataset import StructuredDataset
-from .create_dataset import CreateDataset
 from .binary_label_dataset import BinaryLabelDataset
+from .multi_class_label_dataset import MultiClassLabelDataset
+
+
+_DATASET_TYPES = {
+    "binary label": BinaryLabelDataset,
+    "multi class": MultiClassLabelDataset,
+    "binary": BinaryLabelDataset,
+    "multiclass": MultiClassLabelDataset,
+}
+
+
+def create_dataset(dataset_type, **kwargs):
+    dataset_type = dataset_type.lower()
+    if dataset_type not in _DATASET_TYPES:
+        raise ValueError(f"Unknown dataset type: {dataset_type}")
+    return _DATASET_TYPES[dataset_type](**kwargs)
+
+
 # keep this line at the bottom of this file
 aequitas.logger.debug("Module %s correctly loaded", __name__)
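
For reference, a minimal usage sketch of the new module-level factory (it mirrors the calls in test/core/test_classes.py further down; my_dataframe is a placeholder for any pandas DataFrame with the listed columns):

from aequitas.core.datasets import create_dataset, BinaryLabelDataset
from aequitas.core.imputation_strategies import MCMCImputationStrategy

# accepted type keys: "binary label", "binary", "multi class", "multiclass" (case-insensitive)
ds = create_dataset("binary label",
                    imputation_strategy=MCMCImputationStrategy(),
                    df=my_dataframe,  # placeholder DataFrame with 'feat' and 'label' columns
                    label_names=['label'],
                    protected_attribute_names=['feat'])
assert isinstance(ds, BinaryLabelDataset)
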
11 changes: 0 additions & 11 deletions aequitas/core/datasets/concrete_dataset_factories.py

This file was deleted.

12 changes: 0 additions & 12 deletions aequitas/core/datasets/create_dataset.py

This file was deleted.

10 changes: 0 additions & 10 deletions aequitas/core/datasets/dataset_factory.py

This file was deleted.

17 changes: 17 additions & 0 deletions aequitas/core/datasets/multi_class_label_dataset.py
@@ -0,0 +1,17 @@
+from aequitas.core.datasets.structured_dataset import StructuredDataset
+from aequitas.core.imputation_strategies.imputation_strategy import MissingValuesImputationStrategy
+from aequitas.core.metrics.binary_label_dataset_scores_metric import BinaryLabelDatasetScoresMetric
+import aif360.datasets as datasets
+
+
+class MultiClassLabelDataset(StructuredDataset, datasets.MultiClassLabelDataset):
+
+    def __init__(self, imputation_strategy: MissingValuesImputationStrategy,
+                 favorable_label, unfavorable_label, **kwargs):
+
+        super(MultiClassLabelDataset, self).__init__(imputation_strategy=imputation_strategy, favorable_label=favorable_label,
+                                                     unfavorable_label=unfavorable_label, **kwargs)
+
+    @property
+    def metrics(self, **kwargs):
+        return BinaryLabelDatasetScoresMetric(self, **kwargs)
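
By analogy with the binary case, a multi-class dataset should be obtainable through the same factory. The sketch below is an assumption (the commit adds no multi-class test); favorable_label and unfavorable_label are lists of label values, as in aif360's MultiClassLabelDataset. Note also that the **kwargs on the metrics property getter can never be supplied through attribute access, so the metric is always built with defaults.

# hypothetical usage; my_multiclass_dataframe is a placeholder DataFrame whose
# 'label' column takes values 0, 1, 2, with 1 and 2 treated as favorable
ds = create_dataset("multi class",
                    imputation_strategy=MCMCImputationStrategy(),
                    df=my_multiclass_dataframe,
                    label_names=['label'],
                    protected_attribute_names=['feat'],
                    favorable_label=[1, 2],
                    unfavorable_label=[0])
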
12 changes: 5 additions & 7 deletions aequitas/core/datasets/structured_dataset.py
@@ -6,13 +6,11 @@
 class StructuredDataset(datasets.StructuredDataset, ABC):
     def __init__(self,
                  imputation_strategy: MissingValuesImputationStrategy,
-                 **kwargs: object,
-                 ):
-        self.__strategy = imputation_strategy
-        self._df = kwargs.get('df')
-        self._df = self.__strategy.custom_preprocessing(df=self._df)
-        kwargs["df"] = self._df
-        super(StructuredDataset, self).__init__(**kwargs)
+                 **kwargs):
+        df = kwargs.get('df')
+        df = imputation_strategy.custom_preprocessing(df=df)
+        kwargs["df"] = df
+        super().__init__(**kwargs)
 
     @property
     @abstractmethod
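
The constructor now delegates missing-value handling entirely to the strategy object: whatever is passed as imputation_strategy must expose a custom_preprocessing(df=...) method returning the cleaned DataFrame. A minimal hypothetical strategy (MeanImputationStrategy is not part of this commit, and it assumes the base class only requires custom_preprocessing to be overridden):

from aequitas.core.imputation_strategies import MissingValuesImputationStrategy


class MeanImputationStrategy(MissingValuesImputationStrategy):
    """Hypothetical strategy: fill missing numeric values with the column mean."""

    def custom_preprocessing(self, df):
        # per-column means; non-numeric columns are left untouched
        return df.fillna(df.mean(numeric_only=True))
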
4 changes: 4 additions & 0 deletions aequitas/core/imputation_strategies/__init__.py
@@ -1,4 +1,8 @@
 import aequitas
 
+
+from .imputation_strategy import MissingValuesImputationStrategy
+from .mcmc_imputation_strategy import MCMCImputationStrategy
+
 # keep this line at the bottom of this file
 aequitas.logger.debug("Module %s correctly loaded", __name__)
4 changes: 4 additions & 0 deletions aequitas/core/metrics/__init__.py
@@ -1,4 +1,8 @@
 import aequitas
 
+
+from .binary_label_dataset_scores_metric import BinaryLabelDatasetScoresMetric
+
+
 # keep this line at the bottom of this file
 aequitas.logger.debug("Module %s correctly loaded", __name__)
7 changes: 2 additions & 5 deletions aequitas/core/metrics/binary_label_dataset_scores_metric.py
@@ -4,11 +4,8 @@
 
 class BinaryLabelDatasetScoresMetric(metrics.BinaryLabelDatasetMetric):
     def __init__(self, dataset, **kwargs):
-        self._scores = dataset.scores
-        if self._scores is None:
-            raise TypeError("Must provide a numpy array representing the score associated with each sample")
         super(BinaryLabelDatasetScoresMetric, self).__init__(dataset, **kwargs)
 
-    def scores_metric(self):
+    def new_fancy_metric(self):
         # TODO: change name and behaviour
-        return self._scores.mean()
+        return self.disparate_impact()
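
With this change the metric no longer caches scores itself; everything goes through aif360's BinaryLabelDatasetMetric. A usage sketch, given a dataset ds built as in the tests below:

metric = ds.metrics                # a BinaryLabelDatasetScoresMetric built by the dataset's metrics property
di = metric.disparate_impact()     # inherited from aif360's BinaryLabelDatasetMetric
fancy = metric.new_fancy_metric()  # currently just delegates to disparate_impact()
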
4 changes: 2 additions & 2 deletions test/__init__.py
@@ -3,13 +3,13 @@
 import pandas as pd
 
 
-def binary_label_dataset(rows: int = 1000) -> pd.DataFrame:
+def generate_binary_label_dataframe(rows: int = 1000) -> pd.DataFrame:
     features = random.uniform(0, 1, size=(rows, 1))
     labels = random.uniform(0, 1, size=(rows, 1)).astype(int)
     data = np.concatenate([features] + [labels], axis=1)
     return pd.DataFrame(data, columns=['feat', 'label'])
 
-def binary_label_dataset_with_scores(rows: int = 1000) -> pd.DataFrame:
+def generate_binary_label_dataframe_with_scores(rows: int = 1000) -> pd.DataFrame:
     features = random.uniform(0, 1, size=(rows, 1))
     scores = random.uniform(0, 1, size=(rows, 1))
     labels = (scores > 0.5).astype(int)
78 changes: 43 additions & 35 deletions test/core/test_classes.py
@@ -1,39 +1,47 @@
 import unittest
-from test import binary_label_dataset
-from test import binary_label_dataset_with_scores
-
-from aequitas.core.datasets.create_dataset import CreateDataset
-from aequitas.core.imputation_strategies.mcmc_imputation_strategy import MCMCImputationStrategy
-from aequitas.core.metrics.binary_label_dataset_scores_metric import BinaryLabelDatasetScoresMetric
-
-
-class TestClasses(unittest.TestCase):
-
-    def test_factory_pattern_bld(self):
-        df = binary_label_dataset()
-        strategy = MCMCImputationStrategy()
-        cd = CreateDataset(dataset_type="binary label")
-        ds = cd.create_dataset(imputation_strategy=strategy, df=df, label_names=['label'],
-                               protected_attribute_names=['feat'])
-        self.assertTrue(ds is not None)
-
-    def test_factory_pattern_bld_scores(self):
-        df = binary_label_dataset_with_scores()
-        strategy = MCMCImputationStrategy()
-        cd = CreateDataset(dataset_type="binary label")
-        ds = cd.create_dataset(imputation_strategy=strategy, df=df, label_names=['label'],
-                               protected_attribute_names=['feat'], scores_names=["scores"])
-        self.assertTrue(ds is not None)
-
-    def test_metric(self):
-        df = binary_label_dataset_with_scores()
-        strategy = MCMCImputationStrategy()
-        cd = CreateDataset(dataset_type="binary label")
-        ds = cd.create_dataset(imputation_strategy=strategy, df=df, label_names=['label'],
-                               protected_attribute_names=['feat'], scores_names=["scores"])
-        x = ds.metrics.scores_metric()
-        self.assertTrue(ds is not None)
-        self.assertTrue(x is not None)
+from test import generate_binary_label_dataframe
+from test import generate_binary_label_dataframe_with_scores
+
+from aequitas.core.datasets import create_dataset, BinaryLabelDataset
+from aequitas.core.imputation_strategies import MCMCImputationStrategy
+from aequitas.core.metrics import BinaryLabelDatasetScoresMetric
+
+
+class TestBinaryLabelDataset(unittest.TestCase):
+
+    def test_dataset_creation_via_factory(self):
+        ds = create_dataset("binary label",
+                            imputation_strategy=MCMCImputationStrategy(),
+                            df=generate_binary_label_dataframe(),
+                            label_names=['label'],
+                            protected_attribute_names=['feat'])
+        self.assertIsInstance(ds, BinaryLabelDataset)
+        self.assertIsNotNone(ds)
+
+    def test_dataset_creation_with_scores_via_factory(self):
+        ds = create_dataset("binary label",
+                            imputation_strategy=MCMCImputationStrategy(),
+                            df=generate_binary_label_dataframe_with_scores(),
+                            label_names=['label'],
+                            protected_attribute_names=['feat'],
+                            scores_names=["scores"])
+        self.assertIsInstance(ds, BinaryLabelDataset)
+        self.assertIsNotNone(ds)
+
+    def test_metrics_on_dataset(self):
+        ds = create_dataset("binary label",
+                            imputation_strategy=MCMCImputationStrategy(),
+                            df=generate_binary_label_dataframe_with_scores(),
+                            label_names=['label'],
+                            protected_attribute_names=['feat'],
+                            scores_names=["scores"])
+        self.assertIsInstance(ds.metrics, BinaryLabelDatasetScoresMetric)
+        self.assertIsNotNone(ds)
+        score = ds.metrics.new_fancy_metric()
+        self.assertIsNotNone(score)
+        score = ds.metrics.disparate_impact()
+        self.assertIsNotNone(score)
+
 
 if __name__ == '__main__':
     unittest.main()
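
A natural follow-up (not in this commit) would be an analogous test for the multi-class path; the method below is a sketch that assumes a hypothetical generate_multi_class_label_dataframe helper in test/__init__.py and the imports already used above.

    # hypothetical addition to the test suite; generate_multi_class_label_dataframe
    # does not exist yet, and the favorable/unfavorable labels are illustrative
    def test_multiclass_dataset_creation_via_factory(self):
        ds = create_dataset("multi class",
                            imputation_strategy=MCMCImputationStrategy(),
                            df=generate_multi_class_label_dataframe(),
                            label_names=['label'],
                            protected_attribute_names=['feat'],
                            favorable_label=[1, 2],
                            unfavorable_label=[0])
        self.assertIsNotNone(ds)
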
