Commit

Merge branch 'master' into fix/collection_kwargs_filter
mergify[bot] authored Jan 4, 2022
2 parents c5e084f + 2e58596 commit bec6b79
Showing 17 changed files with 191 additions and 161 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/docs-check.yml
@@ -12,7 +12,7 @@ jobs:
- uses: actions/checkout@master
- uses: actions/setup-python@v2
with:
python-version: 3.7
python-version: 3.8

# Note: This uses an internal pip API and may not always work
# https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow
@@ -47,7 +47,7 @@ jobs:
- uses: actions/checkout@master
- uses: actions/setup-python@v2
with:
python-version: 3.7
python-version: 3.8

# Note: This uses an internal pip API and may not always work
# https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow
2 changes: 1 addition & 1 deletion .readthedocs.yml
@@ -14,6 +14,6 @@ formats: all

# Optionally set the version of Python and requirements required to build your docs
python:
version: 3.7
version: 3.8
install:
- requirements: requirements/docs.txt
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -42,7 +42,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Metrics having third party dependencies removed from global import ([#463](https://github.com/PyTorchLightning/metrics/pull/463))


- `BLEUScore` now expects untokenized input to stay consistent with all the other text metrics ([#640](https://github.com/PyTorchLightning/metrics/pull/640))
- Untokenized for `BLEUScore` input stay consistent with all the other text metrics ([#640](https://github.com/PyTorchLightning/metrics/pull/640))


- Arguments reordered for `TER`, `BLEUScore`, `SacreBLEUScore`, `CHRFScore` now expect input order as predictions first and target second ([#696](https://github.com/PyTorchLightning/metrics/pull/696))


- Renamed `torchmetrics.collections` to `torchmetrics.metrics_collections` to avoid clashing with system's `collections` package ([#695](https://github.com/PyTorchLightning/metrics/pull/695))
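The changelog entries above (#640 and #696) mean the text metrics now take raw, untokenized strings and expect predictions before targets. A minimal sketch of the new call order, assuming the functional metrics exported from `torchmetrics.functional`; the sentences are invented purely for illustration:

```python
# Illustrative sketch of the preds-first argument order described in the changelog above.
# The example sentences are made up; only the argument order and input format matter here.
from torchmetrics.functional import bleu_score, chrf_score

preds = ["the cat sat on the mat"]                              # untokenized predictions
targets = [["a cat sat on the mat", "the cat is on the mat"]]   # reference set per prediction

print(bleu_score(preds, targets))   # predictions first, targets second
print(chrf_score(preds, targets))
```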
16 changes: 8 additions & 8 deletions tests/text/test_bleu.py
@@ -27,7 +27,7 @@
smooth_func = SmoothingFunction().method2


def _compute_bleu_metric_nltk(list_of_references, hypotheses, weights, smoothing_function, **kwargs):
def _compute_bleu_metric_nltk(hypotheses, list_of_references, weights, smoothing_function, **kwargs):
hypotheses_ = [hypothesis.split() for hypothesis in hypotheses]
list_of_references_ = [[line.split() for line in ref] for ref in list_of_references]
return corpus_bleu(
@@ -67,7 +67,7 @@ def test_bleu_score_class(self, ddp, dist_sync_on_step, preds, targets, weights,
sk_metric=compute_bleu_metric_nltk,
dist_sync_on_step=dist_sync_on_step,
metric_args=metric_args,
input_order=INPUT_ORDER.TARGETS_FIRST,
input_order=INPUT_ORDER.PREDS_FIRST,
)

def test_bleu_score_functional(self, preds, targets, weights, n_gram, smooth_func, smooth):
@@ -80,7 +80,7 @@ def test_bleu_score_functional(self, preds, targets, weights, n_gram, smooth_fun
metric_functional=bleu_score,
sk_metric=compute_bleu_metric_nltk,
metric_args=metric_args,
input_order=INPUT_ORDER.TARGETS_FIRST,
input_order=INPUT_ORDER.PREDS_FIRST,
)

def test_bleu_score_differentiability(self, preds, targets, weights, n_gram, smooth_func, smooth):
@@ -92,31 +92,31 @@ def test_bleu_score_differentiability(self, preds, targets, weights, n_gram, smo
metric_module=BLEUScore,
metric_functional=bleu_score,
metric_args=metric_args,
input_order=INPUT_ORDER.TARGETS_FIRST,
input_order=INPUT_ORDER.PREDS_FIRST,
)


def test_bleu_empty_functional():
hyp = [[]]
ref = [[[]]]
assert bleu_score(ref, hyp) == tensor(0.0)
assert bleu_score(hyp, ref) == tensor(0.0)


def test_no_4_gram_functional():
hyps = ["My full pytorch-lightning"]
refs = [["My full pytorch-lightning test", "Completely Different"]]
assert bleu_score(refs, hyps) == tensor(0.0)
assert bleu_score(hyps, refs) == tensor(0.0)


def test_bleu_empty_class():
bleu = BLEUScore()
hyp = [[]]
ref = [[[]]]
assert bleu(ref, hyp) == tensor(0.0)
assert bleu(hyp, ref) == tensor(0.0)


def test_no_4_gram_class():
bleu = BLEUScore()
hyps = ["My full pytorch-lightning"]
refs = [["My full pytorch-lightning test", "Completely Different"]]
assert bleu(refs, hyps) == tensor(0.0)
assert bleu(hyps, refs) == tensor(0.0)
16 changes: 8 additions & 8 deletions tests/text/test_chrf.py
@@ -15,8 +15,8 @@


def sacrebleu_chrf_fn(
targets: Sequence[Sequence[str]],
preds: Sequence[str],
targets: Sequence[Sequence[str]],
char_order: int,
word_order: int,
lowercase: bool,
@@ -71,7 +71,7 @@ def test_chrf_score_class(
sk_metric=nltk_metric,
dist_sync_on_step=dist_sync_on_step,
metric_args=metric_args,
input_order=INPUT_ORDER.TARGETS_FIRST,
input_order=INPUT_ORDER.PREDS_FIRST,
)

def test_chrf_score_functional(self, preds, targets, char_order, word_order, lowercase, whitespace):
@@ -91,7 +91,7 @@ def test_chrf_score_functional(self, preds, targets, char_order, word_order, low
metric_functional=chrf_score,
sk_metric=nltk_metric,
metric_args=metric_args,
input_order=INPUT_ORDER.TARGETS_FIRST,
input_order=INPUT_ORDER.PREDS_FIRST,
)

def test_chrf_score_differentiability(self, preds, targets, char_order, word_order, lowercase, whitespace):
@@ -108,33 +108,33 @@ def test_chrf_score_differentiability(self, preds, targets, char_order, word_ord
metric_module=CHRFScore,
metric_functional=chrf_score,
metric_args=metric_args,
input_order=INPUT_ORDER.TARGETS_FIRST,
input_order=INPUT_ORDER.PREDS_FIRST,
)


def test_chrf_empty_functional():
hyp = []
ref = [[]]
assert chrf_score(ref, hyp) == tensor(0.0)
assert chrf_score(hyp, ref) == tensor(0.0)


def test_chrf_empty_class():
chrf = CHRFScore()
hyp = []
ref = [[]]
assert chrf(ref, hyp) == tensor(0.0)
assert chrf(hyp, ref) == tensor(0.0)


def test_chrf_return_sentence_level_score_functional():
hyp = _inputs_single_sentence_multiple_references.preds
ref = _inputs_single_sentence_multiple_references.targets
_, chrf_sentence_score = chrf_score(ref, hyp, return_sentence_level_score=True)
_, chrf_sentence_score = chrf_score(hyp, ref, return_sentence_level_score=True)
isinstance(chrf_sentence_score, Tensor)


def test_chrf_return_sentence_level_class():
chrf = CHRFScore(return_sentence_level_score=True)
hyp = _inputs_single_sentence_multiple_references.preds
ref = _inputs_single_sentence_multiple_references.targets
_, chrf_sentence_score = chrf(ref, hyp)
_, chrf_sentence_score = chrf(hyp, ref)
isinstance(chrf_sentence_score, Tensor)
8 changes: 4 additions & 4 deletions tests/text/test_sacre_bleu.py
@@ -31,7 +31,7 @@
TOKENIZERS = ("none", "13a", "zh", "intl", "char")


def sacrebleu_fn(targets: Sequence[Sequence[str]], preds: Sequence[str], tokenize: str, lowercase: bool) -> Tensor:
def sacrebleu_fn(preds: Sequence[str], targets: Sequence[Sequence[str]], tokenize: str, lowercase: bool) -> Tensor:
sacrebleu_fn = BLEU(tokenize=tokenize, lowercase=lowercase)
# Sacrebleu expects different format of input
targets = [[target[i] for target in targets] for i in range(len(targets[0]))]
@@ -61,7 +61,7 @@ def test_bleu_score_class(self, ddp, dist_sync_on_step, preds, targets, tokenize
sk_metric=original_sacrebleu,
dist_sync_on_step=dist_sync_on_step,
metric_args=metric_args,
input_order=INPUT_ORDER.TARGETS_FIRST,
input_order=INPUT_ORDER.PREDS_FIRST,
)

def test_bleu_score_functional(self, preds, targets, tokenize, lowercase):
@@ -74,7 +74,7 @@ def test_bleu_score_functional(self, preds, targets, tokenize, lowercase):
metric_functional=sacre_bleu_score,
sk_metric=original_sacrebleu,
metric_args=metric_args,
input_order=INPUT_ORDER.TARGETS_FIRST,
input_order=INPUT_ORDER.PREDS_FIRST,
)

def test_bleu_score_differentiability(self, preds, targets, tokenize, lowercase):
@@ -86,5 +86,5 @@ def test_bleu_score_differentiability(self, preds, targets, tokenize, lowercase)
metric_module=SacreBLEUScore,
metric_functional=sacre_bleu_score,
metric_args=metric_args,
input_order=INPUT_ORDER.TARGETS_FIRST,
input_order=INPUT_ORDER.PREDS_FIRST,
)
20 changes: 10 additions & 10 deletions tests/text/test_ter.py
@@ -15,8 +15,8 @@


def sacrebleu_ter_fn(
targets: Sequence[Sequence[str]],
preds: Sequence[str],
targets: Sequence[Sequence[str]],
normalized: bool,
no_punct: bool,
asian_support: bool,
@@ -75,7 +75,7 @@ def test_chrf_score_class(
sk_metric=nltk_metric,
dist_sync_on_step=dist_sync_on_step,
metric_args=metric_args,
input_order=INPUT_ORDER.TARGETS_FIRST,
input_order=INPUT_ORDER.PREDS_FIRST,
)

def test_ter_score_functional(self, preds, targets, normalize, no_punctuation, asian_support, lowercase):
@@ -99,7 +99,7 @@ def test_ter_score_functional(self, preds, targets, normalize, no_punctuation, a
metric_functional=ter,
sk_metric=nltk_metric,
metric_args=metric_args,
input_order=INPUT_ORDER.TARGETS_FIRST,
input_order=INPUT_ORDER.PREDS_FIRST,
)

def test_chrf_score_differentiability(self, preds, targets, normalize, no_punctuation, asian_support, lowercase):
@@ -116,46 +116,46 @@ def test_chrf_score_differentiability(self, preds, targets, normalize, no_punctu
metric_module=TER,
metric_functional=ter,
metric_args=metric_args,
input_order=INPUT_ORDER.TARGETS_FIRST,
input_order=INPUT_ORDER.PREDS_FIRST,
)


def test_ter_empty_functional():
hyp = []
ref = [[]]
assert ter(ref, hyp) == tensor(0.0)
assert ter(hyp, ref) == tensor(0.0)


def test_ter_empty_class():
ter_metric = TER()
hyp = []
ref = [[]]
assert ter_metric(ref, hyp) == tensor(0.0)
assert ter_metric(hyp, ref) == tensor(0.0)


def test_ter_empty_with_non_empty_hyp_functional():
hyp = ["python"]
ref = [[]]
assert ter(ref, hyp) == tensor(0.0)
assert ter(hyp, ref) == tensor(0.0)


def test_ter_empty_with_non_empty_hyp_class():
ter_metric = TER()
hyp = ["python"]
ref = [[]]
assert ter_metric(ref, hyp) == tensor(0.0)
assert ter_metric(hyp, ref) == tensor(0.0)


def test_ter_return_sentence_level_score_functional():
hyp = _inputs_single_sentence_multiple_references.preds
ref = _inputs_single_sentence_multiple_references.targets
_, sentence_ter = ter(ref, hyp, return_sentence_level_score=True)
_, sentence_ter = ter(hyp, ref, return_sentence_level_score=True)
isinstance(sentence_ter, Tensor)


def test_ter_return_sentence_level_class():
ter_metric = TER(return_sentence_level_score=True)
hyp = _inputs_single_sentence_multiple_references.preds
ref = _inputs_single_sentence_multiple_references.targets
_, sentence_ter = ter_metric(ref, hyp)
_, sentence_ter = ter_metric(hyp, ref)
isinstance(sentence_ter, Tensor)
11 changes: 6 additions & 5 deletions torchmetrics/classification/calibration_error.py
@@ -30,20 +30,21 @@ class CalibrationError(Metric):
L1 norm (Expected Calibration Error)
.. math::
\text{ECE} = \frac{1}{N}\sum_i^N \|(p_i - c_i)\|
\text{ECE} = \sum_i^N b_i \|(p_i - c_i)\|
Infinity norm (Maximum Calibration Error)
.. math::
\text{RMSCE} = \max_{i} (p_i - c_i)
\text{MCE} = \max_{i} (p_i - c_i)
L2 norm (Root Mean Square Calibration Error)
.. math::
\text{MCE} = \frac{1}{N}\sum_i^N (p_i - c_i)^2
\text{RMSCE} = \sqrt{\sum_i^N b_i(p_i - c_i)^2}
Where :math:`p_i` is the top-1 prediction accuracy in bin i
and :math:`c_i` is the average confidence of predictions in bin i.
Where :math:`p_i` is the top-1 prediction accuracy in bin :math:`i`,
:math:`c_i` is the average confidence of predictions in bin :math:`i`, and
:math:`b_i` is the fraction of data points in bin :math:`i`.
.. note::
L2-norm debiasing is not yet supported.
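The corrected formulas above weight each bin by its occupancy b_i instead of dividing by N. A small sketch of the binned computation, written only to illustrate those definitions; this is a hypothetical helper, not torchmetrics' own implementation, and it assumes top-1 confidences plus 0/1 correctness indicators as inputs:

```python
# Hypothetical helper illustrating the binned ECE / MCE / RMSCE definitions above.
# `confidences` are top-1 predicted probabilities, `accuracies` are 0/1 correctness flags.
import torch

def binned_calibration_errors(confidences: torch.Tensor, accuracies: torch.Tensor, n_bins: int = 15):
    edges = torch.linspace(0, 1, n_bins + 1)
    ece = torch.tensor(0.0)
    mce = torch.tensor(0.0)
    rmsce_sq = torch.tensor(0.0)
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.any():
            b_i = in_bin.float().mean()               # fraction of samples falling in bin i
            p_i = accuracies[in_bin].float().mean()   # top-1 accuracy within bin i
            c_i = confidences[in_bin].mean()          # mean confidence within bin i
            gap = (p_i - c_i).abs()
            ece += b_i * gap                          # ECE   = sum_i b_i |p_i - c_i|
            mce = torch.maximum(mce, gap)             # MCE   = max_i |p_i - c_i|
            rmsce_sq += b_i * gap ** 2                # RMSCE = sqrt(sum_i b_i (p_i - c_i)^2)
    return ece, mce, rmsce_sq.sqrt()
```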
11 changes: 6 additions & 5 deletions torchmetrics/functional/classification/calibration_error.py
@@ -118,20 +118,21 @@ def calibration_error(preds: Tensor, target: Tensor, n_bins: int = 15, norm: str
L1 norm (Expected Calibration Error)
.. math::
\text{ECE} = \frac{1}{N}\sum_i^N \|(p_i - c_i)\|
\text{ECE} = \sum_i^N b_i \|(p_i - c_i)\|
Infinity norm (Maximum Calibration Error)
.. math::
\text{RMSCE} = \max_{i} (p_i - c_i)
\text{MCE} = \max_{i} (p_i - c_i)
L2 norm (Root Mean Square Calibration Error)
.. math::
\text{MCE} = \frac{1}{N}\sum_i^N (p_i - c_i)^2
\text{RMSCE} = \sqrt{\sum_i^N b_i(p_i - c_i)^2}
Where :math:`p_i` is the top-1 prediction accuracy in
bin i and :math:`c_i` is the average confidence of predictions in bin i.
Where :math:`p_i` is the top-1 prediction accuracy in bin :math:`i`,
:math:`c_i` is the average confidence of predictions in bin :math:`i`, and
:math:`b_i` is the fraction of data points in bin :math:`i`.
.. note:
L2-norm debiasing is not yet supported.