Skip to content

Commit

Permalink
feat: add metrics to MulticlassLabelDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
liamj2311 committed May 13, 2024
1 parent d2d86d6 commit 2abab38
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions aequitas/core/datasets/multi_class_label_dataset.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
from aequitas.core.datasets.structured_dataset import StructuredDataset
from aequitas.core.imputation_strategies.imputation_strategy import MissingValuesImputationStrategy
from aequitas.core.metrics.binary_label_dataset_scores_metric import BinaryLabelDatasetScoresMetric
from aif360.metrics import BinaryLabelDatasetMetric

from aif360.datasets.multiclass_label_dataset import MulticlassLabelDataset


class MulticlassLabelDataset(StructuredDataset, MulticlassLabelDataset):

def __init__(self, imputation_strategy: MissingValuesImputationStrategy,
favorable_label, unfavorable_label, **kwargs):

super(MulticlassLabelDataset, self).__init__(imputation_strategy=imputation_strategy, favorable_label=favorable_label,
unfavorable_label=unfavorable_label, **kwargs)
def __init__(self, unprivileged_groups, privileged_groups, **kwargs):
self.kwargs = kwargs
self.unprivileged_groups = unprivileged_groups
self.privileged_groups = privileged_groups
super(MulticlassLabelDataset, self).__init__(**kwargs)

@property
def metrics(self, **kwargs):
return BinaryLabelDatasetScoresMetric(self, **kwargs)
return BinaryLabelDatasetMetric(dataset=self,
unprivileged_groups=self.unprivileged_groups,
privileged_groups=self.privileged_groups)

@property
def scores_metrics(self, **kwargs):
return BinaryLabelDatasetScoresMetric(dataset=self,
unprivileged_groups=self.unprivileged_groups,
privileged_groups=self.privileged_groups)

0 comments on commit 2abab38

Please sign in to comment.