Showing 11 changed files with 95 additions and 83 deletions.
The `aequitas.core.datasets` package initializer replaces the old `CreateDataset` factory class with a module-level `create_dataset` function backed by a type registry:

```diff
@@ -1,7 +1,24 @@
 import aequitas
 
 from .structured_dataset import StructuredDataset
-from .create_dataset import CreateDataset
 from .binary_label_dataset import BinaryLabelDataset
 from .multi_class_label_dataset import MultiClassLabelDataset
 
+
+_DATASET_TYPES = {
+    "binary label": BinaryLabelDataset,
+    "multi class": MultiClassLabelDataset,
+    "binary": BinaryLabelDataset,
+    "multiclass": MultiClassLabelDataset,
+}
+
+
+def create_dataset(dataset_type, **kwargs):
+    dataset_type = dataset_type.lower()
+    if dataset_type not in _DATASET_TYPES:
+        raise ValueError(f"Unknown dataset type: {dataset_type}")
+    return _DATASET_TYPES[dataset_type](**kwargs)
+
+
+# keep this line at the bottom of this file
+aequitas.logger.debug("Module %s correctly loaded", __name__)
```
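This turns dataset construction into a single call. A minimal usage sketch, assuming the `generate_binary_label_dataframe` helper and the column names from the test suite shown later in this diff:

```python
from test import generate_binary_label_dataframe  # test-suite helper, assumed available

from aequitas.core.datasets import create_dataset, BinaryLabelDataset
from aequitas.core.imputation_strategies import MCMCImputationStrategy

# Lookup is case-insensitive ("Binary", "BINARY LABEL", ...) because the
# factory lowercases its argument; unknown names raise ValueError.
ds = create_dataset("binary label",
                    imputation_strategy=MCMCImputationStrategy(),
                    df=generate_binary_label_dataframe(),
                    label_names=['label'],
                    protected_attribute_names=['feat'])
assert isinstance(ds, BinaryLabelDataset)
```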
3 files were deleted in this commit; their contents are not shown. (The import removed above suggests `create_dataset.py` is among them.)
A new file adds `MultiClassLabelDataset` (referenced by the initializer above), mixing the library's `StructuredDataset` into the corresponding aif360 dataset class:

```diff
@@ -0,0 +1,17 @@
+from aequitas.core.datasets.structured_dataset import StructuredDataset
+from aequitas.core.imputation_strategies.imputation_strategy import MissingValuesImputationStrategy
+from aequitas.core.metrics.binary_label_dataset_scores_metric import BinaryLabelDatasetScoresMetric
+import aif360.datasets as datasets
+
+
+class MultiClassLabelDataset(StructuredDataset, datasets.MultiClassLabelDataset):
+
+    def __init__(self, imputation_strategy: MissingValuesImputationStrategy,
+                 favorable_label, unfavorable_label, **kwargs):
+
+        super(MultiClassLabelDataset, self).__init__(imputation_strategy=imputation_strategy, favorable_label=favorable_label,
+                                                     unfavorable_label=unfavorable_label, **kwargs)
+
+    @property
+    def metrics(self, **kwargs):
+        return BinaryLabelDatasetScoresMetric(self, **kwargs)
```
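Note that because `metrics` is declared as a `@property`, the `**kwargs` in its signature can never be supplied through attribute access: `ds.metrics` always constructs the metric object with defaults. A hedged sketch of creating this dataset via the factory above; the DataFrame and the `favorable_label`/`unfavorable_label` values are illustrative, assuming aif360's list-based conventions for multiclass labels:

```python
import pandas as pd

from aequitas.core.datasets import create_dataset
from aequitas.core.imputation_strategies import MCMCImputationStrategy

# Illustrative data: a 3-class label where classes 0 and 1 count as favorable.
df = pd.DataFrame({'feat': [0, 1, 0, 1], 'label': [0, 1, 2, 2]})

ds = create_dataset("multiclass",                 # registry alias for "multi class"
                    imputation_strategy=MCMCImputationStrategy(),
                    df=df,
                    label_names=['label'],
                    protected_attribute_names=['feat'],
                    favorable_label=[0, 1],       # assumed aif360-style label lists
                    unfavorable_label=[2])
print(type(ds.metrics))  # -> BinaryLabelDatasetScoresMetric
```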
The `aequitas.core.imputation_strategies` package initializer now re-exports the strategy classes:

```diff
@@ -1,4 +1,8 @@
 import aequitas
 
+
+from .imputation_strategy import MissingValuesImputationStrategy
+from .mcmc_imputation_strategy import MCMCImputationStrategy
+
 # keep this line at the bottom of this file
 aequitas.logger.debug("Module %s correctly loaded", __name__)
```
Likewise, `aequitas.core.metrics` re-exports `BinaryLabelDatasetScoresMetric`:

```diff
@@ -1,4 +1,8 @@
 import aequitas
 
+
+from .binary_label_dataset_scores_metric import BinaryLabelDatasetScoresMetric
+
+
 # keep this line at the bottom of this file
 aequitas.logger.debug("Module %s correctly loaded", __name__)
```
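Together with the `datasets` initializer above, these re-exports let callers drop the inner module names, which is what the rewritten tests below rely on. For comparison (both forms keep working, since the inner modules still exist):

```python
# Old, module-level paths:
from aequitas.core.imputation_strategies.mcmc_imputation_strategy import MCMCImputationStrategy
from aequitas.core.metrics.binary_label_dataset_scores_metric import BinaryLabelDatasetScoresMetric

# New, package-level re-exports added by this commit:
from aequitas.core.imputation_strategies import MCMCImputationStrategy
from aequitas.core.metrics import BinaryLabelDatasetScoresMetric
```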
The test module is rewritten against the new API: the old `TestClasses` (built around the `CreateDataset` class) is replaced by `TestBinaryLabelDataset` (using the `create_dataset` factory), with assertions tightened from `assertTrue(x is not None)` to `assertIsInstance`/`assertIsNotNone`:

```diff
@@ -1,39 +1,47 @@
 import unittest
-from test import binary_label_dataset
-from test import binary_label_dataset_with_scores
-
-from aequitas.core.datasets.create_dataset import CreateDataset
-from aequitas.core.imputation_strategies.mcmc_imputation_strategy import MCMCImputationStrategy
-from aequitas.core.metrics.binary_label_dataset_scores_metric import BinaryLabelDatasetScoresMetric
-
-
-class TestClasses(unittest.TestCase):
-
-    def test_factory_pattern_bld(self):
-        df = binary_label_dataset()
-        strategy = MCMCImputationStrategy()
-        cd = CreateDataset(dataset_type="binary label")
-        ds = cd.create_dataset(imputation_strategy=strategy, df=df, label_names=['label'],
-                               protected_attribute_names=['feat'])
-        self.assertTrue(ds is not None)
-
-    def test_factory_pattern_bld_scores(self):
-        df = binary_label_dataset_with_scores()
-        strategy = MCMCImputationStrategy()
-        cd = CreateDataset(dataset_type="binary label")
-        ds = cd.create_dataset(imputation_strategy=strategy, df=df, label_names=['label'],
-                               protected_attribute_names=['feat'], scores_names=["scores"])
-        self.assertTrue(ds is not None)
-
-    def test_metric(self):
-        df = binary_label_dataset_with_scores()
-        strategy = MCMCImputationStrategy()
-        cd = CreateDataset(dataset_type="binary label")
-        ds = cd.create_dataset(imputation_strategy=strategy, df=df, label_names=['label'],
-                               protected_attribute_names=['feat'], scores_names=["scores"])
-        x = ds.metrics.scores_metric()
-        self.assertTrue(ds is not None)
-        self.assertTrue(x is not None)
+from test import generate_binary_label_dataframe
+from test import generate_binary_label_dataframe_with_scores
+
+from aequitas.core.datasets import create_dataset, BinaryLabelDataset
+from aequitas.core.imputation_strategies import MCMCImputationStrategy
+from aequitas.core.metrics import BinaryLabelDatasetScoresMetric
+
+
+class TestBinaryLabelDataset(unittest.TestCase):
+
+    def test_dataset_creation_via_factory(self):
+        ds = create_dataset("binary label",
+                            imputation_strategy=MCMCImputationStrategy(),
+                            df=generate_binary_label_dataframe(),
+                            label_names=['label'],
+                            protected_attribute_names=['feat'])
+        self.assertIsInstance(ds, BinaryLabelDataset)
+        self.assertIsNotNone(ds)
+
+    def test_dataset_creation_with_scores_via_factory(self):
+        ds = create_dataset("binary label",
+                            imputation_strategy=MCMCImputationStrategy(),
+                            df=generate_binary_label_dataframe_with_scores(),
+                            label_names=['label'],
+                            protected_attribute_names=['feat'],
+                            scores_names=["scores"])
+        self.assertIsInstance(ds, BinaryLabelDataset)
+        self.assertIsNotNone(ds)
+
+    def test_metrics_on_dataset(self):
+        ds = create_dataset("binary label",
+                            imputation_strategy=MCMCImputationStrategy(),
+                            df=generate_binary_label_dataframe_with_scores(),
+                            label_names=['label'],
+                            protected_attribute_names=['feat'],
+                            scores_names=["scores"])
+        self.assertIsInstance(ds.metrics, BinaryLabelDatasetScoresMetric)
+        self.assertIsNotNone(ds)
+        score = ds.metrics.new_fancy_metric()
+        self.assertIsNotNone(score)
+        score = ds.metrics.disparate_impact()
+        self.assertIsNotNone(score)
 
 
 if __name__ == '__main__':
     unittest.main()
```
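The `generate_binary_label_dataframe*` fixtures come from the `test` package and are not part of this diff. A hypothetical stand-in, assuming only the column names the tests use ('feat', 'label', 'scores'):

```python
import numpy as np
import pandas as pd


def generate_binary_label_dataframe_with_scores(n: int = 100) -> pd.DataFrame:
    """Illustrative stand-in for the test fixture: a binary protected
    attribute, a binary label, and a score column in [0, 1]."""
    rng = np.random.default_rng(42)
    return pd.DataFrame({
        'feat': rng.integers(0, 2, size=n),   # protected attribute
        'label': rng.integers(0, 2, size=n),  # binary label
        'scores': rng.random(size=n),         # classifier scores
    })
```

Thanks to the `unittest.main()` guard, the module can be run directly or via `python -m unittest`.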