
Commit

fix: dataset creation
liamj2311 committed Mar 25, 2024
1 parent 1e0533f commit a74060f
Showing 5 changed files with 43 additions and 29 deletions.
aequitas/core/datasets/__init__.py (8 changes: 4 additions & 4 deletions)
@@ -2,14 +2,14 @@

 from .structured_dataset import StructuredDataset
 from .binary_label_dataset import BinaryLabelDataset
-from .multi_class_label_dataset import MultiClassLabelDataset
+from .multi_class_label_dataset import MulticlassLabelDataset


 _DATASET_TYPES = {
     "binary label": BinaryLabelDataset,
-    "multi class": MultiClassLabelDataset,
+    "multi class": MulticlassLabelDataset,
     "binary": BinaryLabelDataset,
-    "multiclass": MultiClassLabelDataset,
+    "multiclass": MulticlassLabelDataset,
 }

@@ -21,4 +21,4 @@ def create_dataset(dataset_type, **kwargs):

 # keep this line at the bottom of this file
-aequitas.logger.debug("Module %s correctly loaded", __name__)
\ No newline at end of file
+aequitas.logger.debug("Module %s correctly loaded", __name__)
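
For orientation, the second hunk header above shows the signature of the collapsed factory, def create_dataset(dataset_type, **kwargs). A minimal sketch of how it presumably resolves an alias through _DATASET_TYPES (the body is hidden in this diff, so the lookup and the error handling below are assumptions):

    # Hypothetical body for the collapsed create_dataset; only the signature
    # and the _DATASET_TYPES mapping are visible in this commit.
    def create_dataset(dataset_type, **kwargs):
        if dataset_type not in _DATASET_TYPES:  # assumed guard, not shown in the diff
            raise ValueError(f"unknown dataset type: {dataset_type!r}")
        return _DATASET_TYPES[dataset_type](**kwargs)

Under that reading, create_dataset("multiclass", ...) and create_dataset("multi class", ...) both construct a MulticlassLabelDataset after this fix.
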
aequitas/core/datasets/binary_label_dataset.py (14 changes: 6 additions & 8 deletions)
@@ -6,13 +6,11 @@

 class BinaryLabelDataset(StructuredDataset, datasets.BinaryLabelDataset):

-    def __init__(self, imputation_strategy: MissingValuesImputationStrategy,
-                 favorable_label, unfavorable_label, label_names, protected_attribute_names, **kwargs):
-
-        super(BinaryLabelDataset, self).__init__(imputation_strategy=imputation_strategy, favorable_label=favorable_label,
-                                                 unfavorable_label=unfavorable_label, label_names=label_names,
-                                                 protected_attribute_names=protected_attribute_names, **kwargs)
+    def __init__(self, **kwargs):
+        self.params = kwargs
+        super(BinaryLabelDataset, self).__init__(**kwargs)

     @property
-    def metrics(self, **kwargs):
-        return BinaryLabelDatasetScoresMetric(self, **kwargs)
+    def metrics(self):
+        dataset = BinaryLabelDataset(**self.params)
+        return BinaryLabelDatasetScoresMetric(dataset=dataset)
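
The reworked property stores the constructor kwargs at init time and rebuilds an equivalent dataset on every access, handing it to the metric by keyword. A usage sketch mirroring the updated tests below (generate_binary_label_dataframe_with_scores and MCMCImputationStrategy are helpers from the project's test suite):

    # Each access to ds.metrics rebuilds a dataset from self.params and wraps
    # it in a BinaryLabelDatasetScoresMetric.
    ds = create_dataset("binary",  # alias for BinaryLabelDataset in _DATASET_TYPES
                        df=generate_binary_label_dataframe_with_scores(),
                        label_names=['label'],
                        protected_attribute_names=['feat'],
                        scores_names=['scores'],
                        imputation_strategy=MCMCImputationStrategy(),
                        favorable_label=1,
                        unfavorable_label=0)
    score = ds.metrics.new_fancy_metric()  # placeholder metric, per its TODO
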
aequitas/core/datasets/multi_class_label_dataset.py (6 changes: 3 additions & 3 deletions)
@@ -1,15 +1,15 @@
 from aequitas.core.datasets.structured_dataset import StructuredDataset
 from aequitas.core.imputation_strategies.imputation_strategy import MissingValuesImputationStrategy
 from aequitas.core.metrics.binary_label_dataset_scores_metric import BinaryLabelDatasetScoresMetric
-import aif360.datasets as datasets
+from aif360.datasets.multiclass_label_dataset import MulticlassLabelDataset


-class MultiClassLabelDataset(StructuredDataset, datasets.MultiClassLabelDataset):
+class MulticlassLabelDataset(StructuredDataset, MulticlassLabelDataset):

     def __init__(self, imputation_strategy: MissingValuesImputationStrategy,
                  favorable_label, unfavorable_label, **kwargs):

-        super(MultiClassLabelDataset, self).__init__(imputation_strategy=imputation_strategy, favorable_label=favorable_label,
+        super(MulticlassLabelDataset, self).__init__(imputation_strategy=imputation_strategy, favorable_label=favorable_label,
                                                      unfavorable_label=unfavorable_label, **kwargs)

     @property
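
The rename matters because aif360 spells its class MulticlassLabelDataset; the old datasets.MultiClassLabelDataset attribute is presumably what broke dataset creation. A hypothetical factory call for the multiclass variant (the tiny dataframe is invented for illustration, and the list-valued labels reflect aif360's MulticlassLabelDataset, which accepts lists of favorable/unfavorable label values):

    import pandas as pd

    # Hypothetical: build a multiclass dataset via the "multi class" alias.
    df = pd.DataFrame({'feat': [0, 1, 0, 1], 'label': [0, 1, 2, 1]})
    ds = create_dataset("multi class",
                        df=df,
                        label_names=['label'],
                        protected_attribute_names=['feat'],
                        imputation_strategy=MCMCImputationStrategy(),
                        favorable_label=[1, 2],   # label values treated as favorable
                        unfavorable_label=[0])
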
aequitas/core/metrics/binary_label_dataset_scores_metric.py (4 changes: 2 additions & 2 deletions)
@@ -3,8 +3,8 @@

 class BinaryLabelDatasetScoresMetric(metrics.BinaryLabelDatasetMetric):
-    def __init__(self, dataset, **kwargs):
-        super(BinaryLabelDatasetScoresMetric, self).__init__(dataset, **kwargs)
+    def __init__(self, **kwargs):
+        super(BinaryLabelDatasetScoresMetric, self).__init__(**kwargs)

     def new_fancy_metric(self):
         # TODO: change name and behaviour
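
Because the signature is now __init__(self, **kwargs), the dataset can no longer be passed positionally; given a dataset ds built as in the tests, callers must use the keyword form, matching the call in the new metrics property above:

    # Positional BinaryLabelDatasetScoresMetric(ds) would now raise TypeError;
    # the dataset must be forwarded by keyword.
    metric = BinaryLabelDatasetScoresMetric(dataset=ds)
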
test/core/test_classes.py (40 changes: 28 additions & 12 deletions)
@@ -10,31 +10,47 @@

 class TestBinaryLabelDataset(unittest.TestCase):

     def test_dataset_creation_via_factory(self):
-        ds = create_dataset("binary label",
-                            imputation_strategy=MCMCImputationStrategy(),
-                            df=generate_binary_label_dataframe(),
+        ds = create_dataset("binary label",
+                            # parameters of aif360.StructuredDataset init
+                            df=generate_binary_label_dataframe(),
                             label_names=['label'],
-                            protected_attribute_names=['feat'])
+                            protected_attribute_names=['feat'],
+                            # parameters of aequitas.StructuredDataset init
+                            imputation_strategy=MCMCImputationStrategy(),
+                            # parameters of aif360.BinaryLabelDataset init
+                            favorable_label=1,
+                            unfavorable_label=0
+                            )
         self.assertIsInstance(ds, BinaryLabelDataset)
         self.assertIsNotNone(ds)

     def test_dataset_creation_with_scores_via_factory(self):
         ds = create_dataset("binary label",
-                            imputation_strategy=MCMCImputationStrategy(),
-                            df=generate_binary_label_dataframe_with_scores(),
+                            # parameters of aif360.StructuredDataset init
+                            df=generate_binary_label_dataframe_with_scores(),
                             label_names=['label'],
-                            protected_attribute_names=['feat'],
-                            scores_names=["scores"])
+                            protected_attribute_names=['feat'],
+                            scores_names=['scores'],
+                            # parameters of aequitas.StructuredDataset init
+                            imputation_strategy=MCMCImputationStrategy(),
+                            # parameters of aif360.BinaryLabelDataset init
+                            favorable_label=1,
+                            unfavorable_label=0)
         self.assertIsInstance(ds, BinaryLabelDataset)
         self.assertIsNotNone(ds)

     def test_metrics_on_dataset(self):
         ds = create_dataset("binary label",
-                            imputation_strategy=MCMCImputationStrategy(),
-                            df=generate_binary_label_dataframe_with_scores(),
+                            # parameters of aif360.StructuredDataset init
+                            df=generate_binary_label_dataframe_with_scores(),
                             label_names=['label'],
-                            protected_attribute_names=['feat'],
-                            scores_names=["scores"])
+                            protected_attribute_names=['feat'],
+                            scores_names=['scores'],
+                            # parameters of aequitas.StructuredDataset init
+                            imputation_strategy=MCMCImputationStrategy(),
+                            # parameters of aif360.BinaryLabelDataset init
+                            favorable_label=1,
+                            unfavorable_label=0)
         self.assertIsInstance(ds.metrics, BinaryLabelDatasetScoresMetric)
         self.assertIsNotNone(ds)
         score = ds.metrics.new_fancy_metric()
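
The comment groups added to these calls document how the single **kwargs bag is split across the cooperative __init__ chain. A minimal sketch of the routing those comments imply (the StructuredDataset body is not part of this diff, so the stand-in base and the pop-and-forward pattern below are assumptions):

    # Hypothetical: aequitas.StructuredDataset consumes its own parameter and
    # forwards the rest (df, label_names, favorable_label, ...) down the MRO
    # to the aif360 base classes.
    class _Aif360Base:                       # stand-in for aif360's dataset base
        def __init__(self, df=None, label_names=None, **kwargs):
            self.df, self.label_names = df, label_names

    class StructuredDataset(_Aif360Base):    # sketch, not the real implementation
        def __init__(self, imputation_strategy=None, **kwargs):
            self.imputation_strategy = imputation_strategy  # aequitas-specific
            super().__init__(**kwargs)                      # aif360 parameters travel onward
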
