
Binary group fairness metrics #1404

Merged
102 commits merged on Mar 4, 2023

Commits
7db8db3
init fairness
AndresAlgaba Dec 21, 2022
2a383ec
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 21, 2022
fe41f56
fix type
AndresAlgaba Dec 21, 2022
3ec0e6d
Merge branch 'master' into fairness
lucadiliello Dec 23, 2022
05b30f6
Merge branch 'master' into fairness
AndresAlgaba Dec 24, 2022
7458615
improve code speed
AndresAlgaba Dec 26, 2022
0a754c2
improvde code speed
AndresAlgaba Dec 26, 2022
3ef0a17
Merge branch 'master' into fairness
AndresAlgaba Dec 26, 2022
21d689f
add docs
AndresAlgaba Dec 26, 2022
baa996d
Merge branch 'fairness' of https://github.com/AndresAlgaba/metrics in…
AndresAlgaba Dec 26, 2022
cf15608
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 26, 2022
5ca45eb
fix docs
AndresAlgaba Dec 27, 2022
5cf8660
Merge branch 'master' into fairness
lucadiliello Jan 3, 2023
66d1ef9
Merge branch 'master' into fairness
AndresAlgaba Jan 4, 2023
5e5f96a
Merge branch 'master' into fairness
AndresAlgaba Jan 9, 2023
9c4ef5c
fix circular import and output of BinaryFairness
AndresAlgaba Jan 9, 2023
4f876c4
fix docs
AndresAlgaba Jan 9, 2023
5e49e8d
keep consistent output with functional
AndresAlgaba Jan 12, 2023
6b3d3bd
fix typing
AndresAlgaba Jan 12, 2023
8eb8918
initialize testing
AndresAlgaba Jan 12, 2023
d7599f6
Merge branch 'master' into fairness
AndresAlgaba Jan 12, 2023
93536f8
fix TPR
AndresAlgaba Jan 12, 2023
2c0f757
Merge branch 'master' into fairness
AndresAlgaba Jan 22, 2023
9b84937
add fairlearn
AndresAlgaba Jan 24, 2023
6fe5be2
Merge branch 'fairness' of https://github.com/AndresAlgaba/metrics in…
AndresAlgaba Jan 24, 2023
f36b502
create dict
AndresAlgaba Jan 24, 2023
0f8e29e
add Dict
AndresAlgaba Jan 24, 2023
4dba340
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 24, 2023
426cc29
Merge branch 'master' into fairness
AndresAlgaba Jan 25, 2023
3554ab4
return tensor in fairlearn
AndresAlgaba Jan 25, 2023
e33931e
Merge branch 'fairness' of https://github.com/AndresAlgaba/metrics in…
AndresAlgaba Jan 25, 2023
eee2b67
Merge branch 'master' into fairness
AndresAlgaba Jan 30, 2023
7fc7193
Merge branch 'master' into fairness
AndresAlgaba Jan 31, 2023
b194780
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 31, 2023
089f2d4
Merge branch 'master' into fairness
AndresAlgaba Feb 2, 2023
7345384
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 2, 2023
56769ec
changelog
SkafteNicki Feb 3, 2023
98883a2
Update src/torchmetrics/classification/group_fairness.py
AndresAlgaba Feb 4, 2023
8172b5e
Update src/torchmetrics/classification/group_fairness.py
AndresAlgaba Feb 4, 2023
73515b0
Update src/torchmetrics/functional/classification/group_fairness.py
AndresAlgaba Feb 4, 2023
c7d545e
Update src/torchmetrics/classification/group_fairness.py
AndresAlgaba Feb 4, 2023
25a6b61
Update src/torchmetrics/classification/group_fairness.py
AndresAlgaba Feb 4, 2023
7188387
Merge branch 'master' into fairness
AndresAlgaba Feb 4, 2023
01f6823
Merge branch 'master' into fairness
AndresAlgaba Feb 6, 2023
bbc9e49
Merge branch 'master' into fairness
Borda Feb 7, 2023
04735ee
Merge branch 'master' into fairness
AndresAlgaba Feb 8, 2023
839bf04
Merge branch 'master' into fairness
AndresAlgaba Feb 13, 2023
37e1835
Merge branch 'master' into fairness
AndresAlgaba Feb 15, 2023
2a1506c
Merge branch 'master' into fairness
AndresAlgaba Feb 17, 2023
85ae325
switch target and groups
AndresAlgaba Feb 17, 2023
1b69cb3
Merge branch 'master' into fairness
AndresAlgaba Feb 17, 2023
a0092a3
Merge branch 'master' into fairness
Borda Feb 18, 2023
645a8bc
Merge branch 'master' into fairness
AndresAlgaba Feb 20, 2023
7fd74f5
Merge branch 'master' into fairness
AndresAlgaba Feb 21, 2023
aa8385e
Merge branch 'master' into fairness
Borda Feb 21, 2023
d817d2f
Merge branch 'master' into fairness
Borda Feb 22, 2023
f3f7f1b
Merge branch 'master' into fairness
lucadiliello Feb 23, 2023
18467b6
Merge branch 'master' into fairness
mergify[bot] Feb 23, 2023
621b24c
uncomment tests
AndresAlgaba Feb 23, 2023
17e67f0
Merge branch 'master' into fairness
mergify[bot] Feb 23, 2023
93bfe08
add links
AndresAlgaba Feb 23, 2023
2eaf932
Merge branch 'fairness' of https://github.com/AndresAlgaba/metrics in…
AndresAlgaba Feb 23, 2023
0da6a85
Merge branch 'master' into fairness
mergify[bot] Feb 23, 2023
6565364
Merge branch 'master' into fairness
mergify[bot] Feb 24, 2023
147c213
Merge branch 'master' into fairness
mergify[bot] Feb 24, 2023
cf89ac4
Merge branch 'master' into fairness
mergify[bot] Feb 24, 2023
3fb2a63
Merge branch 'master' into fairness
mergify[bot] Feb 24, 2023
df1b549
Merge branch 'master' into fairness
mergify[bot] Feb 24, 2023
eb9af2d
ruff: first line split + imperative mood (#1548
SkafteNicki Feb 24, 2023
37021d3
Merge branch 'master' into fairness
mergify[bot] Feb 24, 2023
53566a9
ruff: first line split + imperative mood (#1548)
SkafteNicki Feb 24, 2023
8bd75dc
Merge branch 'master' into fairness
mergify[bot] Feb 24, 2023
067e74a
Merge branch 'master' into fairness
mergify[bot] Feb 24, 2023
31b032b
Merge branch 'master' into fairness
stancld Feb 25, 2023
3293a17
Fix import in tests
stancld Feb 25, 2023
7623387
Fix tests
stancld Feb 25, 2023
bc47ee0
Fix `D205 1 blank line required between summary line and description`
stancld Feb 25, 2023
e2cdc10
Try to fix tests for oldest configuration
stancld Feb 25, 2023
a42e3d3
Merge branch 'master' into fairness
mergify[bot] Feb 25, 2023
cfd6390
Merge branch 'master' into fairness
mergify[bot] Feb 25, 2023
2d93e13
Merge branch 'master' into fairness
AndresAlgaba Feb 27, 2023
cea4f67
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 27, 2023
01fd9e9
Merge branch 'master' into fairness
Borda Feb 27, 2023
0df6d10
Merge branch 'master' into fairness
mergify[bot] Feb 27, 2023
5dc53ac
Merge branch 'master' into fairness
mergify[bot] Feb 27, 2023
ebd729c
Merge branch 'master' into fairness
mergify[bot] Feb 27, 2023
d9788c0
Merge branch 'master' into fairness
mergify[bot] Feb 28, 2023
464d59e
Merge branch 'master' into fairness
mergify[bot] Feb 28, 2023
c1c3276
Merge branch 'master' into fairness
mergify[bot] Feb 28, 2023
b327066
Merge branch 'master' into fairness
mergify[bot] Feb 28, 2023
020ea63
Merge branch 'master' into fairness
mergify[bot] Feb 28, 2023
092878f
Merge branch 'master' into fairness
mergify[bot] Feb 28, 2023
79ce813
Merge branch 'master' into fairness
mergify[bot] Mar 1, 2023
3477658
Merge branch 'master' into fairness
mergify[bot] Mar 1, 2023
0516295
Merge branch 'master' into fairness
AndresAlgaba Mar 2, 2023
fe16216
reqs: Add pandas to classification test reqs
stancld Mar 3, 2023
a80d2b9
python: Skip tests for python 3.7 not supported by reference package
stancld Mar 3, 2023
7b91bb5
typo: Skip tests for python 3.7 not supported by reference package
stancld Mar 3, 2023
e74f6f9
Try to fix links for make-docs
stancld Mar 3, 2023
cc1ca50
Merge branch 'master' into fairness
Borda Mar 3, 2023
5be0127
Merge branch 'master' into fairness
stancld Mar 3, 2023
d0f7b17
Merge branch 'master' into fairness
mergify[bot] Mar 4, 2023
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -57,6 +57,7 @@ module = [
"torchmetrics.classification.confusion_matrix",
"torchmetrics.classification.exact_match",
"torchmetrics.classification.f_beta",
"torchmetrics.classification.group_fairness",
"torchmetrics.classification.hamming",
"torchmetrics.classification.hinge",
"torchmetrics.classification.jaccard",
@@ -76,6 +77,7 @@ module = [
"torchmetrics.functional.classification.calibration_error",
"torchmetrics.functional.classification.confusion_matrix",
"torchmetrics.functional.classification.f_beta",
"torchmetrics.functional.classification.group_fairness",
"torchmetrics.functional.classification.precision_recall_curve",
"torchmetrics.functional.classification.ranking",
"torchmetrics.functional.classification.recall_at_fixed_precision",
4 changes: 4 additions & 0 deletions src/torchmetrics/classification/__init__.py
@@ -55,6 +55,8 @@
MultilabelF1Score,
MultilabelFBetaScore,
)

# from torchmetrics.classification.group_fairness import BinaryFairness, BinaryGroupStatRates
from torchmetrics.classification.hamming import (
BinaryHammingDistance,
HammingDistance,
@@ -148,6 +150,8 @@
"MulticlassFBetaScore",
"MultilabelF1Score",
"MultilabelFBetaScore",
# "BinaryFairness",
# "BinaryGroupStatRates",
"BinaryHammingDistance",
"HammingDistance",
"MulticlassHammingDistance",
257 changes: 257 additions & 0 deletions src/torchmetrics/classification/group_fairness.py
@@ -0,0 +1,257 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple, Union

import torch
from typing_extensions import Literal

from torchmetrics import Metric
from torchmetrics.functional.classification.group_fairness import (
_binary_groups_stat_scores,
_compute_binary_demographic_parity,
_compute_binary_equal_opportunity,
)
from torchmetrics.functional.classification.stat_scores import _binary_stat_scores_arg_validation
from torchmetrics.utilities import rank_zero_warn


class _AbstractGroupStatScores(Metric):
"""Create and update states for computing group stats tp, fp, tn and fn."""

def _create_states(self, num_groups: int) -> None:
default = lambda: torch.zeros(num_groups, dtype=torch.long)
self.add_state("tp", default(), dist_reduce_fx="sum")
self.add_state("fp", default(), dist_reduce_fx="sum")
self.add_state("tn", default(), dist_reduce_fx="sum")
self.add_state("fn", default(), dist_reduce_fx="sum")

def _update_states(
self, group_stats: Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]
) -> None:
for group, stats in enumerate(group_stats.values()):
tp, fp, tn, fn = stats
self.tp[group] += tp
self.fp[group] += fp
self.tn[group] += tn
self.fn[group] += fn


class BinaryGroupStatRates(_AbstractGroupStatScores):
r"""Computes the true positive, false positive, true negative, and false negative rates for binary
classification by group. Related to `Type I and Type II errors`_.

Accepts the following input tensors:
- ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside
[0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally,
we convert to int tensor with thresholding using the value in ``threshold``.
- ``target`` (int tensor): ``(N, ...)``.
- ``groups`` (int tensor): ``(N, ...)``. The group identifiers should be ``0, 1, ..., (num_groups - 1)``.

The additional dimensions are flattened along the batch dimension.

Args:
num_groups: The number of groups.
threshold: Threshold for transforming probability to binary {0,1} predictions.
ignore_index: Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Returns:
The metric returns a dict with a group identifier as key and a tensor with the tp, fp, tn and fn rates as value.

Example (preds is int tensor):
>>> from torchmetrics.classification import BinaryGroupStatRates
>>> target = torch.tensor([0, 1, 0, 1, 0, 1])
>>> preds = torch.tensor([0, 1, 0, 1, 0, 1])
>>> groups = torch.tensor([0, 1, 0, 1, 0, 1])
>>> metric = BinaryGroupStatRates(2)
>>> metric(preds, target, groups)
{'group_0': tensor([0., 0., 1., 0.]), 'group_1': tensor([1., 0., 0., 0.])}

Example (preds is float tensor):
>>> from torchmetrics.classification import BinaryGroupStatRates
>>> target = torch.tensor([0, 1, 0, 1, 0, 1])
>>> preds = torch.tensor([0.11, 0.84, 0.22, 0.73, 0.33, 0.92])
>>> groups = torch.tensor([0, 1, 0, 1, 0, 1])
>>> metric = BinaryGroupStatRates(2)
>>> metric(preds, target, groups)
{'group_0': tensor([0., 0., 1., 0.]), 'group_1': tensor([1., 0., 0., 0.])}
"""
is_differentiable = False
higher_is_better = False
full_state_update: bool = False

def __init__(
self,
num_groups: int,
threshold: float = 0.5,
ignore_index: Optional[int] = None,
validate_args: bool = True,
**kwargs: Any,
) -> None:
"""Initialize states and validate arguments."""
super().__init__()

if validate_args:
_binary_stat_scores_arg_validation(threshold, "global", ignore_index)

self.num_groups = num_groups
self.threshold = threshold
self.ignore_index = ignore_index
self.validate_args = validate_args

self._create_states(self.num_groups)

def update(self, preds: torch.Tensor, target: torch.Tensor, groups: torch.Tensor) -> None:
"""Update state with predictions, target and group identifiers.

Args:
preds: Tensor with predictions.
target: Tensor with true labels.
groups: Tensor with group identifiers. The group identifiers should be ``0, 1, ..., (num_groups - 1)``.
"""
group_stats = _binary_groups_stat_scores(
preds, target, groups, self.num_groups, self.threshold, self.ignore_index, self.validate_args
)

self._update_states(group_stats)

def compute(
self,
) -> Union[Dict[str, torch.Tensor], Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]]:
"""Computes tp, fp, tn and fn rates based on inputs passed in to ``update`` previously."""
results = torch.stack((self.tp, self.fp, self.tn, self.fn), dim=1)

return {f"group_{i}": group / group.sum() for i, group in enumerate(results)}
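The ``compute`` method above stacks each group's ``(tp, fp, tn, fn)`` counts and divides by the per-group total, so the four rates for each group sum to 1. A minimal pure-Python sketch of that normalization (the helper name and plain-dict input are illustrative, not part of this PR):

```python
def group_stat_rates(group_stats):
    """Normalize per-group (tp, fp, tn, fn) counts to rates, mirroring compute().

    group_stats: dict mapping group id -> (tp, fp, tn, fn) counts.
    Returns a dict keyed "group_{i}" whose four rates sum to 1.0.
    """
    rates = {}
    for i, (tp, fp, tn, fn) in group_stats.items():
        total = tp + fp + tn + fn  # number of samples seen for this group
        rates[f"group_{i}"] = tuple(x / total for x in (tp, fp, tn, fn))
    return rates

# Mirrors the docstring example: each group has 3 samples, all predicted correctly,
# so group 0 is all true negatives and group 1 all true positives.
print(group_stat_rates({0: (0, 0, 3, 0), 1: (3, 0, 0, 0)}))
# -> {'group_0': (0.0, 0.0, 1.0, 0.0), 'group_1': (1.0, 0.0, 0.0, 0.0)}
```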


class BinaryFairness(_AbstractGroupStatScores):
r"""Computes demographic parity and equal opportunity ratio for binary classification problems.

This class computes the ratio between positivity rates and true positive rates for different groups.
If more than two groups are present, the disparity between the lowest and highest group is reported.
A disparity between positivity rates indicates a potential violation of demographic parity, and between
true positive rates indicates a potential violation of equal opportunity.

The lowest rate is divided by the highest, so a lower value indicates greater disparity against the low-rate group.
This is also reflected in the dict key, which takes the form ``{metric}_{identifier_low_group}_{identifier_high_group}``.

Accepts the following input tensors:
- ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside
[0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally,
we convert to int tensor with thresholding using the value in ``threshold``.
- ``groups`` (int tensor): ``(N, ...)``. The group identifiers should be ``0, 1, ..., (num_groups - 1)``.
- ``target`` (int tensor): ``(N, ...)``.

The additional dimensions are flattened along the batch dimension.

Args:
num_groups: The number of groups.
task: The task to compute. Can be either ``demographic_parity``, ``equal_opportunity`` or ``all``.
threshold: Threshold for transforming probability to binary {0,1} predictions.
ignore_index: Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Returns:
The metric returns a dict where the key identifies the metric and the groups with the lowest and highest true
positive rates as follows: {metric}_{identifier_low_group}_{identifier_high_group}.
The value is a tensor with the disparity rate.

Example (preds is int tensor):
>>> from torchmetrics.classification import BinaryFairness
>>> target = torch.tensor([0, 1, 0, 1, 0, 1])
>>> preds = torch.tensor([0, 1, 0, 1, 0, 1])
>>> groups = torch.tensor([0, 1, 0, 1, 0, 1])
>>> metric = BinaryFairness(2)
>>> metric(preds, groups, target)
({'DP_0_1': tensor(0.)}, {'EO_0_1': tensor(0.)})

Example (preds is float tensor):
>>> from torchmetrics.classification import BinaryFairness
>>> target = torch.tensor([0, 1, 0, 1, 0, 1])
>>> preds = torch.tensor([0.11, 0.84, 0.22, 0.73, 0.33, 0.92])
>>> groups = torch.tensor([0, 1, 0, 1, 0, 1])
>>> metric = BinaryFairness(2)
>>> metric(preds, groups, target)
({'DP_0_1': tensor(0.)}, {'EO_0_1': tensor(0.)})
"""
is_differentiable = False
higher_is_better = False
full_state_update: bool = False

def __init__(
self,
num_groups: int,
task: Literal["demographic_parity", "equal_opportunity", "all"] = "all",
threshold: float = 0.5,
ignore_index: Optional[int] = None,
validate_args: bool = True,
**kwargs: Any,
) -> None:
"""Initialize states and validate arguments."""
super().__init__()

if task not in ["demographic_parity", "equal_opportunity", "all"]:
raise ValueError(
f"Expected argument `task` to either be ``demographic_parity``, "
f"``equal_opportunity`` or ``all`` but got {task}."
)

if validate_args:
_binary_stat_scores_arg_validation(threshold, "global", ignore_index)

self.num_groups = num_groups
self.task = task
self.threshold = threshold
self.ignore_index = ignore_index
self.validate_args = validate_args

self._create_states(self.num_groups)

def update(self, preds: torch.Tensor, groups: torch.Tensor, target: Optional[torch.Tensor] = None) -> None:
"""Update state with predictions, groups, and target.

Args:
preds: Tensor with predictions.
groups: Tensor with group identifiers. The group identifiers should be ``0, 1, ..., (num_groups - 1)``.
target: Tensor with true labels.
"""
if self.task == "demographic_parity":
if target is not None:
rank_zero_warn("The task demographic_parity does not require a target.", UserWarning)
target = torch.zeros(preds.shape)

group_stats = _binary_groups_stat_scores(
preds, target, groups, self.num_groups, self.threshold, self.ignore_index, self.validate_args
)

self._update_states(group_stats)

def compute(
self,
) -> Union[Dict[str, torch.Tensor], Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]]:
"""Computes fairness criteria based on inputs passed in to ``update`` previously."""
if self.task == "demographic_parity":
return _compute_binary_demographic_parity(self.tp, self.fp, self.tn, self.fn)

if self.task == "equal_opportunity":
return _compute_binary_equal_opportunity(self.tp, self.fp, self.tn, self.fn)

if self.task == "all":
return _compute_binary_demographic_parity(
self.tp, self.fp, self.tn, self.fn
), _compute_binary_equal_opportunity(self.tp, self.fp, self.tn, self.fn)
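Per the docstring above, demographic parity compares per-group positivity rates ``(tp + fp) / n`` and equal opportunity compares per-group true positive rates ``tp / (tp + fn)``, each reported as the lowest rate divided by the highest. A simplified pure-Python sketch under those definitions (the helper names are illustrative, not this PR's internal `_compute_*` functions):

```python
def disparity(rates):
    """Return (low_group, high_group, ratio), with ratio = lowest rate / highest rate.

    rates: dict mapping group id -> rate. A ratio of 1.0 means parity; lower
    values mean more disparity against the low-rate group.
    """
    lo = min(rates, key=rates.get)
    hi = max(rates, key=rates.get)
    return lo, hi, rates[lo] / rates[hi]

def demographic_parity(group_stats):
    # Positivity rate per group: (tp + fp) / (tp + fp + tn + fn).
    pos = {g: (tp + fp) / (tp + fp + tn + fn) for g, (tp, fp, tn, fn) in group_stats.items()}
    lo, hi, ratio = disparity(pos)
    return {f"DP_{lo}_{hi}": ratio}

def equal_opportunity(group_stats):
    # True positive rate per group: tp / (tp + fn).
    tpr = {g: tp / (tp + fn) for g, (tp, fp, tn, fn) in group_stats.items()}
    lo, hi, ratio = disparity(tpr)
    return {f"EO_{lo}_{hi}": ratio}

# Group 0 is predicted positive in 2 of 4 samples, group 1 in 3 of 4,
# so the demographic parity ratio is 0.5 / 0.75.
stats = {0: (2, 0, 2, 0), 1: (3, 0, 1, 0)}
print(demographic_parity(stats))
# -> {'DP_0_1': 0.6666666666666666}
```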
6 changes: 6 additions & 0 deletions src/torchmetrics/functional/classification/__init__.py
@@ -61,6 +61,12 @@
multilabel_f1_score,
multilabel_fbeta_score,
)
from torchmetrics.functional.classification.group_fairness import ( # noqa: F401
binary_fairness,
binary_groups_stat_rates,
demographic_parity,
equal_opportunity,
)
from torchmetrics.functional.classification.hamming import ( # noqa: F401
binary_hamming_distance,
hamming_distance,