Skip to content

Commit

Permalink
fix: allow creation of MulticlassLabelDataset in create_dataset function
Browse files Browse the repository at this point in the history
  • Loading branch information
liamj2311 committed Jun 13, 2024
1 parent 4dd7668 commit cfdbe62
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion aequitas/core/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ def create_dataset(dataset_type,
kwargs['df'] = imputed_df
dataset = _DATASET_TYPES[dataset_type](**kwargs)

if dataset_type == "binary label":
binary_label_metrics = [k for k, v in _DATASET_TYPES.items() if v == BinaryLabelDataset or v == MulticlassLabelDataset]
if dataset_type in binary_label_metrics:
return DatasetWithBinaryLabelMetrics(dataset, unprivileged_groups, privileged_groups)
elif dataset_type == "regression":
return DatasetWithRegressionMetrics(dataset, unprivileged_groups, privileged_groups)
Expand Down

0 comments on commit cfdbe62

Please sign in to comment.