Skip to content

Commit

Permalink
test: test adult preproc function
Browse files Browse the repository at this point in the history
  • Loading branch information
liamj2311 committed May 16, 2024
1 parent dcd2553 commit f8fce95
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions test/core/test_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
from aequitas.core.metrics import BinaryLabelDatasetScoresMetric

from aequitas.core.algorithms import create_algorithm
from aif360.algorithms.preprocessing.optim_preproc_helpers.data_preproc_functions import load_preproc_data_adult

from aif360.datasets import AdultDataset

from aequitas.core.algorithms.preprocessing.optim_preproc_helpers import load_preproc_data_adult_aeq

class TestBinaryLabelDataset(unittest.TestCase):

Expand Down Expand Up @@ -334,11 +334,12 @@ class TestMitigationAlgorithms(unittest.TestCase):
def test_disparate_impact_remover_on_adult_dataset(self):
protected = "sex"
ds = create_dataset("adult",
# parameters of aif360.datasets.AdultDataset
unprivileged_groups=[{protected: 0}],
privileged_groups=[{protected: 1}],
protected_attribute_names=[protected],
privileged_classes=[['Male']], categorical_features=[],
features_to_keep=['age', 'education-num', 'capital-gain', 'capital-loss', 'hours-per-week']
)
privileged_classes=[['Male']],
categorical_features=[],
features_to_keep=['age', 'education-num', 'capital-gain', 'capital-loss', 'hours-per-week'])

scaler = MinMaxScaler(copy=False)

Expand Down Expand Up @@ -370,12 +371,9 @@ def test_disparate_impact_remover_on_adult_dataset(self):

def test_reweighing_on_adult_dataset(self):
protected = "sex"
ds = create_dataset("adult",
# parameters of aif360.datasets.AdultDataset
protected_attribute_names=[protected],
privileged_classes=[['Male']], categorical_features=[],
features_to_keep=['age', 'education-num', 'capital-gain', 'capital-loss', 'hours-per-week']
)
ds = load_preproc_data_adult_aeq(unprivileged_groups=[{protected: 0}],
privileged_groups=[{protected: 1}],
protected_attributes=[protected])
print(
f"Difference in mean outcomes between unprivileged and privileged groups before reweighing: {ds.metrics.mean_difference()}")
rw = create_algorithm("reweighing", unprivileged_groups=ds.unprivileged_groups,
Expand All @@ -384,6 +382,10 @@ def test_reweighing_on_adult_dataset(self):
print(
f"Difference in mean outcomes between unprivileged and privileged groups after reweighing: {repaired_ds.metrics.mean_difference()}")

def test_adversarial_debiasing_on_adult_dataset(self):
ds = load_preproc_data_adult()



if __name__ == '__main__':
unittest.main()

0 comments on commit f8fce95

Please sign in to comment.