From a66ee8a1463116949d301db9079e6026042330fc Mon Sep 17 00:00:00 2001 From: Ashutosh Kumar Date: Wed, 5 Jan 2022 00:47:14 +0530 Subject: [PATCH 1/2] Unify the input order for text (NLG) metrics - BLEU, SacreBLEU, TER, CHRF (#696) * Standardize BLEU and CHRF * Standardize TER metric * Standardize SacreBLEU * Add warnings for breaking change files * Update docstring + CHANGELOG.md * Update order and keep naming unchanged * Apply suggestions from code review Co-authored-by: Jirka Borovec Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- CHANGELOG.md | 5 +- tests/text/test_bleu.py | 16 ++-- tests/text/test_chrf.py | 16 ++-- tests/text/test_sacre_bleu.py | 8 +- tests/text/test_ter.py | 20 ++--- torchmetrics/functional/text/bleu.py | 28 ++++--- torchmetrics/functional/text/chrf.py | 94 +++++++++++----------- torchmetrics/functional/text/sacre_bleu.py | 15 ++-- torchmetrics/functional/text/ter.py | 70 ++++++++-------- torchmetrics/text/bleu.py | 14 ++-- torchmetrics/text/chrf.py | 14 ++-- torchmetrics/text/sacre_bleu.py | 14 +++- torchmetrics/text/ter.py | 10 +-- 13 files changed, 176 insertions(+), 148 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a6639ec0e85..521c6172be8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,7 +42,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Metrics having third party dependencies removed from global import ([#463](https://github.com/PyTorchLightning/metrics/pull/463)) -- `BLEUScore` now expects untokenized input to stay consistent with all the other text metrics ([#640](https://github.com/PyTorchLightning/metrics/pull/640)) +- Untokenized for `BLEUScore` input stay consistent with all the other text metrics ([#640](https://github.com/PyTorchLightning/metrics/pull/640)) + + +- Arguments reordered for `TER`, `BLEUScore`, `SacreBLEUScore`, `CHRFScore` now expect input order as predictions first and target second ([#696](https://github.com/PyTorchLightning/metrics/pull/696)) - Renamed `torchmetrics.collections` to `torchmetrics.metrics_collections` to avoid clashing with system's `collections` package ([#695](https://github.com/PyTorchLightning/metrics/pull/695)) diff --git a/tests/text/test_bleu.py b/tests/text/test_bleu.py index 1866094fb6f..48bbb12633d 100644 --- a/tests/text/test_bleu.py +++ b/tests/text/test_bleu.py @@ -27,7 +27,7 @@ smooth_func = SmoothingFunction().method2 -def _compute_bleu_metric_nltk(list_of_references, hypotheses, weights, smoothing_function, **kwargs): +def _compute_bleu_metric_nltk(hypotheses, list_of_references, weights, smoothing_function, **kwargs): hypotheses_ = [hypothesis.split() for hypothesis in hypotheses] list_of_references_ = [[line.split() for line in ref] for ref in list_of_references] return corpus_bleu( @@ -67,7 +67,7 @@ def test_bleu_score_class(self, ddp, dist_sync_on_step, preds, targets, weights, sk_metric=compute_bleu_metric_nltk, dist_sync_on_step=dist_sync_on_step, metric_args=metric_args, - input_order=INPUT_ORDER.TARGETS_FIRST, + input_order=INPUT_ORDER.PREDS_FIRST, ) def test_bleu_score_functional(self, preds, targets, weights, n_gram, smooth_func, smooth): @@ -80,7 +80,7 @@ def test_bleu_score_functional(self, preds, targets, weights, n_gram, smooth_fun metric_functional=bleu_score, sk_metric=compute_bleu_metric_nltk, metric_args=metric_args, - input_order=INPUT_ORDER.TARGETS_FIRST, + input_order=INPUT_ORDER.PREDS_FIRST, ) def test_bleu_score_differentiability(self, preds, targets, weights, n_gram, smooth_func, smooth): @@ 
-92,31 +92,31 @@ def test_bleu_score_differentiability(self, preds, targets, weights, n_gram, smo metric_module=BLEUScore, metric_functional=bleu_score, metric_args=metric_args, - input_order=INPUT_ORDER.TARGETS_FIRST, + input_order=INPUT_ORDER.PREDS_FIRST, ) def test_bleu_empty_functional(): hyp = [[]] ref = [[[]]] - assert bleu_score(ref, hyp) == tensor(0.0) + assert bleu_score(hyp, ref) == tensor(0.0) def test_no_4_gram_functional(): hyps = ["My full pytorch-lightning"] refs = [["My full pytorch-lightning test", "Completely Different"]] - assert bleu_score(refs, hyps) == tensor(0.0) + assert bleu_score(hyps, refs) == tensor(0.0) def test_bleu_empty_class(): bleu = BLEUScore() hyp = [[]] ref = [[[]]] - assert bleu(ref, hyp) == tensor(0.0) + assert bleu(hyp, ref) == tensor(0.0) def test_no_4_gram_class(): bleu = BLEUScore() hyps = ["My full pytorch-lightning"] refs = [["My full pytorch-lightning test", "Completely Different"]] - assert bleu(refs, hyps) == tensor(0.0) + assert bleu(hyps, refs) == tensor(0.0) diff --git a/tests/text/test_chrf.py b/tests/text/test_chrf.py index 76743b6ecd4..4863d850dfc 100644 --- a/tests/text/test_chrf.py +++ b/tests/text/test_chrf.py @@ -15,8 +15,8 @@ def sacrebleu_chrf_fn( - targets: Sequence[Sequence[str]], preds: Sequence[str], + targets: Sequence[Sequence[str]], char_order: int, word_order: int, lowercase: bool, @@ -71,7 +71,7 @@ def test_chrf_score_class( sk_metric=nltk_metric, dist_sync_on_step=dist_sync_on_step, metric_args=metric_args, - input_order=INPUT_ORDER.TARGETS_FIRST, + input_order=INPUT_ORDER.PREDS_FIRST, ) def test_chrf_score_functional(self, preds, targets, char_order, word_order, lowercase, whitespace): @@ -91,7 +91,7 @@ def test_chrf_score_functional(self, preds, targets, char_order, word_order, low metric_functional=chrf_score, sk_metric=nltk_metric, metric_args=metric_args, - input_order=INPUT_ORDER.TARGETS_FIRST, + input_order=INPUT_ORDER.PREDS_FIRST, ) def test_chrf_score_differentiability(self, preds, targets, char_order, word_order, lowercase, whitespace): @@ -108,27 +108,27 @@ def test_chrf_score_differentiability(self, preds, targets, char_order, word_ord metric_module=CHRFScore, metric_functional=chrf_score, metric_args=metric_args, - input_order=INPUT_ORDER.TARGETS_FIRST, + input_order=INPUT_ORDER.PREDS_FIRST, ) def test_chrf_empty_functional(): hyp = [] ref = [[]] - assert chrf_score(ref, hyp) == tensor(0.0) + assert chrf_score(hyp, ref) == tensor(0.0) def test_chrf_empty_class(): chrf = CHRFScore() hyp = [] ref = [[]] - assert chrf(ref, hyp) == tensor(0.0) + assert chrf(hyp, ref) == tensor(0.0) def test_chrf_return_sentence_level_score_functional(): hyp = _inputs_single_sentence_multiple_references.preds ref = _inputs_single_sentence_multiple_references.targets - _, chrf_sentence_score = chrf_score(ref, hyp, return_sentence_level_score=True) + _, chrf_sentence_score = chrf_score(hyp, ref, return_sentence_level_score=True) isinstance(chrf_sentence_score, Tensor) @@ -136,5 +136,5 @@ def test_chrf_return_sentence_level_class(): chrf = CHRFScore(return_sentence_level_score=True) hyp = _inputs_single_sentence_multiple_references.preds ref = _inputs_single_sentence_multiple_references.targets - _, chrf_sentence_score = chrf(ref, hyp) + _, chrf_sentence_score = chrf(hyp, ref) isinstance(chrf_sentence_score, Tensor) diff --git a/tests/text/test_sacre_bleu.py b/tests/text/test_sacre_bleu.py index 8cd34a807ff..6cbe0aa8328 100644 --- a/tests/text/test_sacre_bleu.py +++ b/tests/text/test_sacre_bleu.py @@ -31,7 +31,7 @@ TOKENIZERS = 
("none", "13a", "zh", "intl", "char") -def sacrebleu_fn(targets: Sequence[Sequence[str]], preds: Sequence[str], tokenize: str, lowercase: bool) -> Tensor: +def sacrebleu_fn(preds: Sequence[str], targets: Sequence[Sequence[str]], tokenize: str, lowercase: bool) -> Tensor: sacrebleu_fn = BLEU(tokenize=tokenize, lowercase=lowercase) # Sacrebleu expects different format of input targets = [[target[i] for target in targets] for i in range(len(targets[0]))] @@ -61,7 +61,7 @@ def test_bleu_score_class(self, ddp, dist_sync_on_step, preds, targets, tokenize sk_metric=original_sacrebleu, dist_sync_on_step=dist_sync_on_step, metric_args=metric_args, - input_order=INPUT_ORDER.TARGETS_FIRST, + input_order=INPUT_ORDER.PREDS_FIRST, ) def test_bleu_score_functional(self, preds, targets, tokenize, lowercase): @@ -74,7 +74,7 @@ def test_bleu_score_functional(self, preds, targets, tokenize, lowercase): metric_functional=sacre_bleu_score, sk_metric=original_sacrebleu, metric_args=metric_args, - input_order=INPUT_ORDER.TARGETS_FIRST, + input_order=INPUT_ORDER.PREDS_FIRST, ) def test_bleu_score_differentiability(self, preds, targets, tokenize, lowercase): @@ -86,5 +86,5 @@ def test_bleu_score_differentiability(self, preds, targets, tokenize, lowercase) metric_module=SacreBLEUScore, metric_functional=sacre_bleu_score, metric_args=metric_args, - input_order=INPUT_ORDER.TARGETS_FIRST, + input_order=INPUT_ORDER.PREDS_FIRST, ) diff --git a/tests/text/test_ter.py b/tests/text/test_ter.py index 4f49cc4665c..50c38049031 100644 --- a/tests/text/test_ter.py +++ b/tests/text/test_ter.py @@ -15,8 +15,8 @@ def sacrebleu_ter_fn( - targets: Sequence[Sequence[str]], preds: Sequence[str], + targets: Sequence[Sequence[str]], normalized: bool, no_punct: bool, asian_support: bool, @@ -75,7 +75,7 @@ def test_chrf_score_class( sk_metric=nltk_metric, dist_sync_on_step=dist_sync_on_step, metric_args=metric_args, - input_order=INPUT_ORDER.TARGETS_FIRST, + input_order=INPUT_ORDER.PREDS_FIRST, ) def test_ter_score_functional(self, preds, targets, normalize, no_punctuation, asian_support, lowercase): @@ -99,7 +99,7 @@ def test_ter_score_functional(self, preds, targets, normalize, no_punctuation, a metric_functional=ter, sk_metric=nltk_metric, metric_args=metric_args, - input_order=INPUT_ORDER.TARGETS_FIRST, + input_order=INPUT_ORDER.PREDS_FIRST, ) def test_chrf_score_differentiability(self, preds, targets, normalize, no_punctuation, asian_support, lowercase): @@ -116,40 +116,40 @@ def test_chrf_score_differentiability(self, preds, targets, normalize, no_punctu metric_module=TER, metric_functional=ter, metric_args=metric_args, - input_order=INPUT_ORDER.TARGETS_FIRST, + input_order=INPUT_ORDER.PREDS_FIRST, ) def test_ter_empty_functional(): hyp = [] ref = [[]] - assert ter(ref, hyp) == tensor(0.0) + assert ter(hyp, ref) == tensor(0.0) def test_ter_empty_class(): ter_metric = TER() hyp = [] ref = [[]] - assert ter_metric(ref, hyp) == tensor(0.0) + assert ter_metric(hyp, ref) == tensor(0.0) def test_ter_empty_with_non_empty_hyp_functional(): hyp = ["python"] ref = [[]] - assert ter(ref, hyp) == tensor(0.0) + assert ter(hyp, ref) == tensor(0.0) def test_ter_empty_with_non_empty_hyp_class(): ter_metric = TER() hyp = ["python"] ref = [[]] - assert ter_metric(ref, hyp) == tensor(0.0) + assert ter_metric(hyp, ref) == tensor(0.0) def test_ter_return_sentence_level_score_functional(): hyp = _inputs_single_sentence_multiple_references.preds ref = _inputs_single_sentence_multiple_references.targets - _, sentence_ter = ter(ref, hyp, 
return_sentence_level_score=True) + _, sentence_ter = ter(hyp, ref, return_sentence_level_score=True) isinstance(sentence_ter, Tensor) @@ -157,5 +157,5 @@ def test_ter_return_sentence_level_class(): ter_metric = TER(return_sentence_level_score=True) hyp = _inputs_single_sentence_multiple_references.preds ref = _inputs_single_sentence_multiple_references.targets - _, sentence_ter = ter_metric(ref, hyp) + _, sentence_ter = ter_metric(hyp, ref) isinstance(sentence_ter, Tensor) diff --git a/torchmetrics/functional/text/bleu.py b/torchmetrics/functional/text/bleu.py index 81b6e100e9e..19176f229fd 100644 --- a/torchmetrics/functional/text/bleu.py +++ b/torchmetrics/functional/text/bleu.py @@ -16,6 +16,7 @@ # Authors: torchtext authors and @sluks # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score +import warnings from collections import Counter from typing import Callable, Sequence, Tuple, Union @@ -57,8 +58,8 @@ def _tokenize_fn(sentence: str) -> Sequence[str]: def _bleu_score_update( - reference_corpus: Sequence[Sequence[str]], translate_corpus: Sequence[str], + reference_corpus: Sequence[Sequence[str]], numerator: Tensor, denominator: Tensor, trans_len: Tensor, @@ -69,11 +70,11 @@ def _bleu_score_update( """Updates and returns variables required to compute the BLEU score. Args: - reference_corpus: An iterable of iterables of reference corpus translate_corpus: An iterable of machine translated corpus + reference_corpus: An iterable of iterables of reference corpus numerator: Numerator of precision score (true positives) denominator: Denominator of precision score (true positives + false positives) - trans_len: count of words in a candidate translation + trans_len: count of words in a candidate prediction ref_len: count of words in a reference translation n_gram: gram value ranged 1 to 4 tokenizer: A function that turns sentence into list of words @@ -106,7 +107,12 @@ def _bleu_score_update( def _bleu_score_compute( - trans_len: Tensor, ref_len: Tensor, numerator: Tensor, denominator: Tensor, n_gram: int = 4, smooth: bool = False + trans_len: Tensor, + ref_len: Tensor, + numerator: Tensor, + denominator: Tensor, + n_gram: int = 4, + smooth: bool = False, ) -> Tensor: """Computes the BLEU score. @@ -140,18 +146,18 @@ def _bleu_score_compute( def bleu_score( - reference_corpus: Sequence[Union[str, Sequence[str]]], translate_corpus: Union[str, Sequence[str]], + reference_corpus: Sequence[Union[str, Sequence[str]]], n_gram: int = 4, smooth: bool = False, ) -> Tensor: """Calculate `BLEU score`_ of machine translated text with one or more references. 
Args: - reference_corpus: - An iterable of iterables of reference corpus translate_corpus: An iterable of machine translated corpus + reference_corpus: + An iterable of iterables of reference corpus n_gram: Gram value ranged from 1 to 4 (Default 4) smooth: @@ -164,7 +170,7 @@ def bleu_score( >>> from torchmetrics.functional import bleu_score >>> translate_corpus = ['the cat is on the mat'] >>> reference_corpus = [['there is a cat on the mat', 'a cat is on the mat']] - >>> bleu_score(reference_corpus, translate_corpus) + >>> bleu_score(translate_corpus, reference_corpus) tensor(0.7598) References: @@ -174,6 +180,10 @@ def bleu_score( [2] Automatic Evaluation of Machine Translation Quality Using Longest Common Subsequence and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och `Machine Translation Evolution`_ """ + warnings.warn( + "Input order of targets and preds were changed to predictions firsts and targets second in v0.7." + " Warning will be removed in v0.8." + ) translate_corpus_ = [translate_corpus] if isinstance(translate_corpus, str) else translate_corpus reference_corpus_ = [ [reference_text] if isinstance(reference_text, str) else reference_text for reference_text in reference_corpus @@ -188,7 +198,7 @@ def bleu_score( ref_len = tensor(0, dtype=torch.float) trans_len, ref_len = _bleu_score_update( - reference_corpus_, translate_corpus_, numerator, denominator, trans_len, ref_len, n_gram, _tokenize_fn + translate_corpus_, reference_corpus_, numerator, denominator, trans_len, ref_len, n_gram, _tokenize_fn ) return _bleu_score_compute(trans_len, ref_len, numerator, denominator, n_gram, smooth) diff --git a/torchmetrics/functional/text/chrf.py b/torchmetrics/functional/text/chrf.py index fd646742f27..ba8b2214931 100644 --- a/torchmetrics/functional/text/chrf.py +++ b/torchmetrics/functional/text/chrf.py @@ -72,10 +72,10 @@ def _prepare_n_grams_dicts( total_matching_word_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_word_order)} return ( - total_ref_char_n_grams, - total_ref_word_n_grams, total_hyp_char_n_grams, total_hyp_word_n_grams, + total_ref_char_n_grams, + total_ref_word_n_grams, total_matching_char_n_grams, total_matching_word_n_grams, ) @@ -209,16 +209,17 @@ def _get_total_ngrams(n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]]) def _get_ngram_matches( - ref_n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]], hyp_n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]], + ref_n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]], ) -> Dict[int, Tensor]: """Get a number of n-gram matches between reference and hypothesis n-grams. Args: - ref_n_grams_counts: + hyp_n_grams_counts: ref_n_grams_counts: Return: + matching_n_grams """ matching_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) for n in hyp_n_grams_counts: @@ -251,10 +252,10 @@ def _sum_over_dicts(total_n_grams: Dict[int, Tensor], n_grams: Dict[int, Tensor] def _calculate_fscore( matching_char_n_grams: Dict[int, Tensor], matching_word_n_grams: Dict[int, Tensor], - ref_char_n_grams: Dict[int, Tensor], - ref_word_n_grams: Dict[int, Tensor], hyp_char_n_grams: Dict[int, Tensor], hyp_word_n_grams: Dict[int, Tensor], + ref_char_n_grams: Dict[int, Tensor], + ref_word_n_grams: Dict[int, Tensor], n_order: float, beta: float, ) -> Tensor: @@ -266,14 +267,14 @@ def _calculate_fscore( A total number of matching character n-grams between the best matching reference and hypothesis. 
matching_word_n_grams: A total number of matching word n-grams between the best matching reference and hypothesis. - ref_char_n_grams: - A total number of reference character n-grams. - ref_word_n_grams: - A total number of reference word n-grams. hyp_char_n_grams: A total number of hypothesis character n-grams. hyp_word_n_grams: A total number of hypothesis word n-grams. + ref_char_n_grams: + A total number of reference character n-grams. + ref_word_n_grams: + A total number of reference word n-grams. n_order: A sum of character and word n-gram order. beta: @@ -383,10 +384,10 @@ def _calculate_sentence_level_chrf_score( f_score = _calculate_fscore( matching_char_n_grams, matching_word_n_grams, - ref_char_n_grams, - ref_word_n_grams, hyp_char_n_grams, hyp_word_n_grams, + ref_char_n_grams, + ref_word_n_grams, n_order, beta, ) @@ -408,12 +409,12 @@ def _calculate_sentence_level_chrf_score( def _chrf_score_update( - reference_corpus: Union[Sequence[str], Sequence[Sequence[str]]], hypothesis_corpus: Union[str, Sequence[str]], - total_ref_char_n_grams: Dict[int, Tensor], - total_ref_word_n_grams: Dict[int, Tensor], + reference_corpus: Union[Sequence[str], Sequence[Sequence[str]]], total_hyp_char_n_grams: Dict[int, Tensor], total_hyp_word_n_grams: Dict[int, Tensor], + total_ref_char_n_grams: Dict[int, Tensor], + total_ref_word_n_grams: Dict[int, Tensor], total_matching_char_n_grams: Dict[int, Tensor], total_matching_word_n_grams: Dict[int, Tensor], n_char_order: int, @@ -434,18 +435,18 @@ def _chrf_score_update( ]: """ Args: - reference_corpus: - An iterable of iterables of reference corpus. hypothesis_corpus: An iterable of hypothesis corpus. - total_ref_char_n_grams: - A dictionary containing a total number of reference character n-grams. - total_ref_word_n_grams: - A dictionary containing a total number of reference word n-grams. + reference_corpus: + An iterable of iterables of reference corpus. total_hyp_char_n_grams: A dictionary containing a total number of hypothesis character n-grams. total_hyp_word_n_grams: A dictionary containing a total number of hypothesis word n-grams. + total_ref_char_n_grams: + A dictionary containing a total number of reference character n-grams. + total_ref_word_n_grams: + A dictionary containing a total number of reference word n-grams. total_matching_char_n_grams: A dictionary containing a total number of matching character n-grams between references and hypotheses. 
total_matching_word_n_grams: @@ -489,7 +490,7 @@ def _chrf_score_update( """ reference_corpus, hypothesis_corpus = _validate_inputs(reference_corpus, hypothesis_corpus) - for (references, hypothesis) in zip(reference_corpus, hypothesis_corpus): + for (hypothesis, references) in zip(hypothesis_corpus, reference_corpus): ( hyp_char_n_grams_counts, hyp_word_n_grams_counts, @@ -528,10 +529,10 @@ def _chrf_score_update( total_matching_word_n_grams = _sum_over_dicts(total_matching_word_n_grams, matching_word_n_grams) return ( - total_ref_char_n_grams, - total_ref_word_n_grams, total_hyp_char_n_grams, total_hyp_word_n_grams, + total_ref_char_n_grams, + total_ref_word_n_grams, total_matching_char_n_grams, total_matching_word_n_grams, sentence_chrf_score, @@ -539,27 +540,26 @@ def _chrf_score_update( def _chrf_score_compute( - total_ref_char_n_grams: Dict[int, Tensor], - total_ref_word_n_grams: Dict[int, Tensor], total_hyp_char_n_grams: Dict[int, Tensor], total_hyp_word_n_grams: Dict[int, Tensor], + total_ref_char_n_grams: Dict[int, Tensor], + total_ref_word_n_grams: Dict[int, Tensor], total_matching_char_n_grams: Dict[int, Tensor], total_matching_word_n_grams: Dict[int, Tensor], n_order: float, beta: float, ) -> Tensor: - """Compute chrF/chrF++ score based on pre-computed reference, hypothesis and matching character and word - n-grams. + """Compute chrF/chrF++ score based on pre-computed target, prediction and matching character and word n-grams. Args: - total_ref_char_n_grams: - A dictionary containing a total number of reference character n-grams. - total_ref_word_n_grams: - A dictionary containing a total number of reference word n-grams. total_hyp_char_n_grams: A dictionary containing a total number of hypothesis character n-grams. total_hyp_word_n_grams: A dictionary containing a total number of hypothesis word n-grams. + total_ref_char_n_grams: + A dictionary containing a total number of reference character n-grams. + total_ref_word_n_grams: + A dictionary containing a total number of reference word n-grams. total_matching_char_n_grams: A dictionary containing a total number of matching character n-grams between references and hypotheses. total_matching_word_n_grams: @@ -575,10 +575,10 @@ def _chrf_score_compute( chrf_f_score = _calculate_fscore( total_matching_char_n_grams, total_matching_word_n_grams, - total_ref_char_n_grams, - total_ref_word_n_grams, total_hyp_char_n_grams, total_hyp_word_n_grams, + total_ref_char_n_grams, + total_ref_word_n_grams, n_order, beta, ) @@ -586,8 +586,8 @@ def _chrf_score_compute( def chrf_score( - reference_corpus: Union[Sequence[str], Sequence[Sequence[str]]], hypothesis_corpus: Union[str, Sequence[str]], + reference_corpus: Union[Sequence[str], Sequence[Sequence[str]]], n_char_order: int = 6, n_word_order: int = 2, beta: float = 2.0, @@ -601,10 +601,10 @@ def chrf_score( https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/chrf.py. Args: - reference_corpus: - An iterable of iterables of reference corpus. hypothesis_corpus: An iterable of hypothesis corpus. + reference_corpus: + An iterable of iterables of reference corpus. n_char_order: A character n-gram order. If `n_char_order=6`, the metrics refers to the official chrF/chrF++. 
n_word_order: @@ -635,7 +635,7 @@ def chrf_score( >>> from torchmetrics.functional import chrf_score >>> hypothesis_corpus = ['the cat is on the mat'] >>> reference_corpus = [['there is a cat on the mat', 'a cat is on the mat']] - >>> chrf_score(reference_corpus, hypothesis_corpus) + >>> chrf_score(hypothesis_corpus, reference_corpus) tensor(0.8640) References: @@ -652,10 +652,10 @@ def chrf_score( n_order = float(n_char_order + n_word_order) ( - total_ref_char_n_grams, - total_ref_word_n_grams, total_hyp_char_n_grams, total_hyp_word_n_grams, + total_ref_char_n_grams, + total_ref_word_n_grams, total_matching_char_n_grams, total_matching_word_n_grams, ) = _prepare_n_grams_dicts(n_char_order, n_word_order) @@ -663,20 +663,20 @@ def chrf_score( sentence_chrf_score: Optional[List[Tensor]] = [] if return_sentence_level_score else None ( - total_ref_char_n_grams, - total_ref_word_n_grams, total_hyp_char_n_grams, total_hyp_word_n_grams, + total_ref_char_n_grams, + total_ref_word_n_grams, total_matching_char_n_grams, total_matching_word_n_grams, sentence_chrf_score, ) = _chrf_score_update( - reference_corpus, hypothesis_corpus, - total_ref_char_n_grams, - total_ref_word_n_grams, + reference_corpus, total_hyp_char_n_grams, total_hyp_word_n_grams, + total_ref_char_n_grams, + total_ref_word_n_grams, total_matching_char_n_grams, total_matching_word_n_grams, n_char_order, @@ -689,10 +689,10 @@ def chrf_score( ) chrf_f_score = _chrf_score_compute( - total_ref_char_n_grams, - total_ref_word_n_grams, total_hyp_char_n_grams, total_hyp_word_n_grams, + total_ref_char_n_grams, + total_ref_word_n_grams, total_matching_char_n_grams, total_matching_word_n_grams, n_order, diff --git a/torchmetrics/functional/text/sacre_bleu.py b/torchmetrics/functional/text/sacre_bleu.py index e42409c4f3c..835607579a0 100644 --- a/torchmetrics/functional/text/sacre_bleu.py +++ b/torchmetrics/functional/text/sacre_bleu.py @@ -39,6 +39,7 @@ import re +import warnings from functools import partial from typing import Sequence @@ -277,8 +278,8 @@ def _lower(line: str, lowercase: bool) -> str: def sacre_bleu_score( - reference_corpus: Sequence[Sequence[str]], translate_corpus: Sequence[str], + reference_corpus: Sequence[Sequence[str]], n_gram: int = 4, smooth: bool = False, tokenize: Literal["none", "13a", "zh", "intl", "char"] = "13a", @@ -288,10 +289,10 @@ def sacre_bleu_score( follows the behaviour of SacreBLEU [2] implementation from https://github.com/mjpost/sacrebleu. Args: - reference_corpus: - An iterable of iterables of reference corpus translate_corpus: An iterable of machine translated corpus + reference_corpus: + An iterable of iterables of reference corpus n_gram: Gram value ranged from 1 to 4 (Default 4) smooth: @@ -309,7 +310,7 @@ def sacre_bleu_score( >>> from torchmetrics.functional import sacre_bleu_score >>> translate_corpus = ['the cat is on the mat'] >>> reference_corpus = [['there is a cat on the mat', 'a cat is on the mat']] - >>> sacre_bleu_score(reference_corpus, translate_corpus) + >>> sacre_bleu_score(translate_corpus, reference_corpus) tensor(0.7598) References: @@ -321,6 +322,10 @@ def sacre_bleu_score( [3] Automatic Evaluation of Machine Translation Quality Using Longest Common Subsequence and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och `Machine Translation Evolution`_ """ + warnings.warn( + "Input order of targets and preds were changed to predictions firsts and targets second in v0.7." + " Warning will be removed in v0.8." 
+ ) if tokenize not in AVAILABLE_TOKENIZERS: raise ValueError(f"Argument `tokenize` expected to be one of {AVAILABLE_TOKENIZERS} but got {tokenize}.") @@ -343,8 +348,8 @@ def sacre_bleu_score( tokenize_fn = partial(_SacreBLEUTokenizer.tokenize, tokenize=tokenize, lowercase=lowercase) trans_len, ref_len = _bleu_score_update( - reference_corpus, translate_corpus, + reference_corpus, numerator, denominator, trans_len, diff --git a/torchmetrics/functional/text/ter.py b/torchmetrics/functional/text/ter.py index c63efb666bd..ba217c47225 100644 --- a/torchmetrics/functional/text/ter.py +++ b/torchmetrics/functional/text/ter.py @@ -206,24 +206,24 @@ def _preprocess_sentence(sentence: str, tokenizer: _TercomTokenizer) -> str: return tokenizer(sentence.rstrip()) -def _find_shifted_pairs(reference_words: List[str], hypothesis_words: List[str]) -> Iterator[Tuple[int, int, int]]: +def _find_shifted_pairs(hypothesis_words: List[str], reference_words: List[str]) -> Iterator[Tuple[int, int, int]]: """Find matching word sub-sequences in two lists of words. Ignores sub-sequences starting at the same position. Args: - reference_words: - A list of a tokenized reference sentence. hypothesis_words: A list of a tokenized hypothesis sentence. + reference_words: + A list of a tokenized reference sentence. Return: Yields tuples of `(reference_start, hypothesis_start, length` such that: reference_words[reference_start : reference_start + length] ==\ hypothesis_words[hypothesis_start : hypothesis_start + length] - reference_start: - A list of reference start indices. hypothesis_start: A list of hypothesis start indices. + reference_start: + A list of reference start indices. length: A length of a word span to be considered. """ @@ -238,7 +238,7 @@ def _find_shifted_pairs(reference_words: List[str], hypothesis_words: List[str]) # Check if hypothesis and reference are equal so far if hypothesis_words[hypothesis_start + length - 1] != reference_words[reference_start + length - 1]: break - yield reference_start, hypothesis_start, length + yield hypothesis_start, reference_start, length # Stop processing once a sequence is consumed. _hyp = len(hypothesis_words) == hypothesis_start + length @@ -249,10 +249,10 @@ def _find_shifted_pairs(reference_words: List[str], hypothesis_words: List[str]) def _handle_corner_cases_during_shifting( alignments: Dict[int, int], - reference_errors: List[int], hypothesis_errors: List[int], - reference_start: int, + reference_errors: List[int], hypothesis_start: int, + reference_start: int, length: int, ) -> bool: """A helper function which returns `True` if any of corner cases has been met. Otherwise, `False` is returned. @@ -260,14 +260,14 @@ def _handle_corner_cases_during_shifting( Args: alignments: A dictionary mapping aligned positions between a reference and a hypothesis. - reference_errors: - A list of error positions in a reference. hypothesis_errors: A list of error positions in a hypothesis. - reference_start: - A reference start index. + reference_errors: + A list of error positions in a reference. hypothesis_start: A hypothesis start index. + reference_start: + A reference start index. length: A length of a word span to be considered. 
@@ -327,8 +327,8 @@ def _shift_word_within_shifted_string(words: List[str], start: int, target: int, def _shift_words( - reference_words: List[str], hypothesis_words: List[str], + reference_words: List[str], cached_edit_distance: _LevenshteinEditDistance, checked_candidates: int, ) -> Tuple[int, List[str], int]: @@ -340,10 +340,10 @@ def _shift_words( choices. (The paragraph copied from https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/lib_ter.py) Args: - reference_words: - A list of lists of tokenized reference sentences. hypothesis_words: A list of tokenized hypothesis sentence. + reference_words: + A list of lists of tokenized reference sentences. cached_edit_distance: A pre-computed edit distance between a hypothesis and a reference. checked_candidates: @@ -363,9 +363,9 @@ def _shift_words( best: Optional[Tuple[int, int, int, int, List[str]]] = None - for reference_start, hypothesis_start, length in _find_shifted_pairs(reference_words, hypothesis_words): + for hypothesis_start, reference_start, length in _find_shifted_pairs(hypothesis_words, reference_words): if _handle_corner_cases_during_shifting( - alignments, reference_errors, hypothesis_errors, reference_start, hypothesis_start, length + alignments, hypothesis_errors, reference_errors, hypothesis_start, reference_start, length ): continue @@ -409,14 +409,14 @@ def _shift_words( return best_score, shifted_words, checked_candidates -def _translation_edit_rate(reference_words: List[str], hypothesis_words: List[str]) -> Tensor: - """Compute translation edit rate between reference and hypothesis sentences. +def _translation_edit_rate(hypothesis_words: List[str], reference_words: List[str]) -> Tensor: + """Compute translation edit rate between hypothesis and reference sentences. Args: - reference_words: - A list of lists of tokenized reference sentences. hypothesis_words: A list of a tokenized hypothesis sentence. + reference_words: + A list of lists of tokenized reference sentences. Return: A number of required edits to match hypothesis and reference sentences. @@ -432,7 +432,7 @@ def _translation_edit_rate(reference_words: List[str], hypothesis_words: List[st while True: # do shifts until they stop reducing the edit distance delta, new_input_words, checked_candidates = _shift_words( - reference_words, input_words, cached_edit_distance, checked_candidates + input_words, reference_words, cached_edit_distance, checked_candidates ) if checked_candidates >= _MAX_SHIFT_CANDIDATES or delta <= 0: break @@ -446,15 +446,15 @@ def _translation_edit_rate(reference_words: List[str], hypothesis_words: List[st def _compute_sentence_statistics( - references_words: List[List[str]], hypothesis_words: List[str] + hypothesis_words: List[str], references_words: List[List[str]] ) -> Tuple[Tensor, Tensor]: """Compute sentence TER statistics between hypothesis and provided references. Args: - reference_words: - A list of lists of tokenized reference sentences. hypothesis_words: A list of tokenized hypothesis sentence. + reference_words: + A list of lists of tokenized reference sentences. Return: best_num_edits: @@ -496,8 +496,8 @@ def _compute_ter_score_from_statistics(num_edits: Tensor, ref_length: Tensor) -> def _ter_update( - reference_corpus: Sequence[Union[str, Sequence[str]]], hypothesis_corpus: Union[str, Sequence[str]], + reference_corpus: Sequence[Union[str, Sequence[str]]], tokenizer: _TercomTokenizer, total_num_edits: Tensor, total_ref_length: Tensor, @@ -506,10 +506,10 @@ def _ter_update( """Update TER statistics. 
Args: - reference_corpus: - An iterable of iterables of reference corpus. hypothesis_corpus: An iterable of hypothesis corpus. + reference_corpus: + An iterable of iterables of reference corpus. tokenizer: total_num_edits: A total number of required edits to match hypothesis and reference sentences. @@ -530,12 +530,12 @@ def _ter_update( """ reference_corpus, hypothesis_corpus = _validate_inputs(reference_corpus, hypothesis_corpus) - for (references, hypothesis) in zip(reference_corpus, hypothesis_corpus): + for (hypothesis, references) in zip(hypothesis_corpus, reference_corpus): references_words_: List[List[str]] = [ [word for word in _preprocess_sentence(ref, tokenizer).split()] for ref in references ] hypothesis_words_: List[str] = [word for word in _preprocess_sentence(hypothesis, tokenizer).split()] - num_edits, ref_length = _compute_sentence_statistics(references_words_, hypothesis_words_) + num_edits, ref_length = _compute_sentence_statistics(hypothesis_words_, references_words_) total_num_edits += num_edits total_ref_length += ref_length if sentence_ter is not None: @@ -558,8 +558,8 @@ def _ter_compute(total_num_edits: Tensor, total_ref_length: Tensor) -> Tensor: def ter( - reference_corpus: Sequence[Union[str, Sequence[str]]], hypothesis_corpus: Union[str, Sequence[str]], + reference_corpus: Sequence[Union[str, Sequence[str]]], normalize: bool = False, no_punctuation: bool = False, lowercase: bool = True, @@ -572,10 +572,10 @@ def ter( near-exact reimplementation of the Tercom algorithm, produces identical results on all "sane" outputs. Args: - reference_corpus: - An iterable of iterables of reference corpus. hypothesis_corpus: An iterable of hypothesis corpus. + reference_corpus: + An iterable of iterables of reference corpus. normalize: An indication whether a general tokenization to be applied. no_punctuation: @@ -594,7 +594,7 @@ def ter( Example: >>> hypothesis_corpus = ['the cat is on the mat'] >>> reference_corpus = [['there is a cat on the mat', 'a cat is on the mat']] - >>> ter(reference_corpus, hypothesis_corpus) + >>> ter(hypothesis_corpus, reference_corpus) tensor(0.1538) References: @@ -617,7 +617,7 @@ def ter( sentence_ter: Optional[List[Tensor]] = [] if return_sentence_level_score else None total_num_edits, total_ref_length, sentence_ter = _ter_update( - reference_corpus, hypothesis_corpus, tokenizer, total_num_edits, total_ref_length, sentence_ter + hypothesis_corpus, reference_corpus, tokenizer, total_num_edits, total_ref_length, sentence_ter ) ter_score = _ter_compute(total_num_edits, total_ref_length) diff --git a/torchmetrics/text/bleu.py b/torchmetrics/text/bleu.py index c28bf9bc1df..10c1f5ab5a1 100644 --- a/torchmetrics/text/bleu.py +++ b/torchmetrics/text/bleu.py @@ -16,6 +16,7 @@ # Authors: torchtext authors and @sluks # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score +import warnings from typing import Any, Callable, Optional, Sequence import torch @@ -48,7 +49,7 @@ class BLEUScore(Metric): >>> translate_corpus = ['the cat is on the mat'] >>> reference_corpus = [['there is a cat on the mat', 'a cat is on the mat']] >>> metric = BLEUScore() - >>> metric(reference_corpus, translate_corpus) + >>> metric(translate_corpus, reference_corpus) tensor(0.7598) References: @@ -81,7 +82,10 @@ def __init__( process_group=process_group, dist_sync_fn=dist_sync_fn, ) - + warnings.warn( + "Input order of targets and preds were changed to predictions firsts and targets second in v0.7." + " Warning will be removed in v0.8." 
+ ) self.n_gram = n_gram self.smooth = smooth @@ -91,18 +95,18 @@ def __init__( self.add_state("denominator", torch.zeros(self.n_gram), dist_reduce_fx="sum") def update( # type: ignore - self, reference_corpus: Sequence[Sequence[str]], translate_corpus: Sequence[str] + self, translate_corpus: Sequence[str], reference_corpus: Sequence[Sequence[str]] ) -> None: """Compute Precision Scores. Args: - reference_corpus: An iterable of iterables of reference corpus translate_corpus: An iterable of machine translated corpus + reference_corpus: An iterable of iterables of reference corpus """ self.trans_len, self.ref_len = _bleu_score_update( - reference_corpus, translate_corpus, + reference_corpus, self.numerator, self.denominator, self.trans_len, diff --git a/torchmetrics/text/chrf.py b/torchmetrics/text/chrf.py index b413402246c..1059f13cf37 100644 --- a/torchmetrics/text/chrf.py +++ b/torchmetrics/text/chrf.py @@ -30,10 +30,10 @@ _TEXT_LEVELS = ("ref", "hyp", "matching") _DICT_STATES_NAMES = ( - "total_ref_char_n_grams", - "total_ref_word_n_grams", "total_hyp_char_n_grams", "total_hyp_word_n_grams", + "total_ref_char_n_grams", + "total_ref_word_n_grams", "total_matching_char_n_grams", "total_matching_word_n_grams", ) @@ -86,7 +86,7 @@ class CHRFScore(Metric): >>> hypothesis_corpus = ['the cat is on the mat'] >>> reference_corpus = [['there is a cat on the mat', 'a cat is on the mat']] >>> metric = CHRFScore() - >>> metric(reference_corpus, hypothesis_corpus) + >>> metric(hypothesis_corpus, reference_corpus) tensor(0.8640) References: @@ -143,19 +143,19 @@ def __init__( self.add_state("sentence_chrf_score", [], dist_reduce_fx="cat") def update( # type: ignore - self, reference_corpus: Sequence[Sequence[str]], hypothesis_corpus: Sequence[str] + self, hypothesis_corpus: Sequence[str], reference_corpus: Sequence[Sequence[str]] ) -> None: """Compute Precision Scores. Args: - reference_corpus: - An iterable of iterables of reference corpus. hypothesis_corpus: An iterable of hypothesis corpus. + reference_corpus: + An iterable of iterables of reference corpus. """ n_grams_dicts_tuple = _chrf_score_update( - reference_corpus, hypothesis_corpus, + reference_corpus, *self._convert_states_to_dicts(), self.n_char_order, self.n_word_order, diff --git a/torchmetrics/text/sacre_bleu.py b/torchmetrics/text/sacre_bleu.py index b8c59f3c646..97ae380a4a1 100644 --- a/torchmetrics/text/sacre_bleu.py +++ b/torchmetrics/text/sacre_bleu.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings + # referenced from # Library Name: torchtext # Authors: torchtext authors and @sluks @@ -67,7 +69,7 @@ class SacreBLEUScore(BLEUScore): >>> translate_corpus = ['the cat is on the mat'] >>> reference_corpus = [['there is a cat on the mat', 'a cat is on the mat']] >>> metric = SacreBLEUScore() - >>> metric(reference_corpus, translate_corpus) + >>> metric(translate_corpus, reference_corpus) tensor(0.7598) References: @@ -99,6 +101,10 @@ def __init__( process_group=process_group, dist_sync_fn=dist_sync_fn, ) + warnings.warn( + "Input order of targets and preds were changed to predictions firsts and targets \ + second in v0.7. 
Warning will be removed in v0.8" + ) if tokenize not in AVAILABLE_TOKENIZERS: raise ValueError(f"Argument `tokenize` expected to be one of {AVAILABLE_TOKENIZERS} but got {tokenize}.") @@ -110,17 +116,17 @@ def __init__( self.tokenizer = _SacreBLEUTokenizer(tokenize, lowercase) def update( # type: ignore - self, reference_corpus: Sequence[Sequence[str]], translate_corpus: Sequence[str] + self, translate_corpus: Sequence[str], reference_corpus: Sequence[Sequence[str]] ) -> None: """Compute Precision Scores. Args: - reference_corpus: An iterable of iterables of reference corpus translate_corpus: An iterable of machine translated corpus + reference_corpus: An iterable of iterables of reference corpus """ self.trans_len, self.ref_len = _bleu_score_update( - reference_corpus, translate_corpus, + reference_corpus, self.numerator, self.denominator, self.trans_len, diff --git a/torchmetrics/text/ter.py b/torchmetrics/text/ter.py index e9c1bd3cb15..662afeccc25 100644 --- a/torchmetrics/text/ter.py +++ b/torchmetrics/text/ter.py @@ -53,7 +53,7 @@ class TER(Metric): >>> hypothesis_corpus = ['the cat is on the mat'] >>> reference_corpus = [['there is a cat on the mat', 'a cat is on the mat']] >>> metric = TER() - >>> metric(reference_corpus, hypothesis_corpus) + >>> metric(hypothesis_corpus, reference_corpus) tensor(0.1538) References: @@ -104,20 +104,20 @@ def __init__( def update( # type: ignore self, - reference_corpus: Sequence[Union[str, Sequence[str]]], hypothesis_corpus: Union[str, Sequence[str]], + reference_corpus: Sequence[Union[str, Sequence[str]]], ) -> None: """Update TER statistics. Args: - reference_corpus: - An iterable of iterables of reference corpus. hypothesis_corpus: An iterable of hypothesis corpus. + reference_corpus: + An iterable of iterables of reference corpus. 
""" self.total_num_edits, self.total_ref_len, self.sentence_ter = _ter_update( - reference_corpus, hypothesis_corpus, + reference_corpus, self.tokenizer, self.total_num_edits, self.total_ref_len, From 2e585965b684339c74480ca5cbed538ca193b6f6 Mon Sep 17 00:00:00 2001 From: Edward Williams Date: Tue, 4 Jan 2022 11:18:37 -0800 Subject: [PATCH 2/2] fixing documentation for calibration error (#702) * fixing documentation for calibration error * Apply suggestions from code review Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec Co-authored-by: Nicki Skafte Detlefsen Co-authored-by: Jirka Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- .github/workflows/docs-check.yml | 4 ++-- .readthedocs.yml | 2 +- torchmetrics/classification/calibration_error.py | 11 ++++++----- .../functional/classification/calibration_error.py | 11 ++++++----- 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/.github/workflows/docs-check.yml b/.github/workflows/docs-check.yml index a503577f459..8a08019a3e2 100644 --- a/.github/workflows/docs-check.yml +++ b/.github/workflows/docs-check.yml @@ -12,7 +12,7 @@ jobs: - uses: actions/checkout@master - uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: 3.8 # Note: This uses an internal pip API and may not always work # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow @@ -47,7 +47,7 @@ jobs: - uses: actions/checkout@master - uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: 3.8 # Note: This uses an internal pip API and may not always work # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow diff --git a/.readthedocs.yml b/.readthedocs.yml index b5a64f98142..136d36b858b 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -14,6 +14,6 @@ formats: all # Optionally set the version of Python and requirements required to build your docs python: - version: 3.7 + version: 3.8 install: - requirements: requirements/docs.txt diff --git a/torchmetrics/classification/calibration_error.py b/torchmetrics/classification/calibration_error.py index 14849e91ddb..3a8d0782efc 100644 --- a/torchmetrics/classification/calibration_error.py +++ b/torchmetrics/classification/calibration_error.py @@ -30,20 +30,21 @@ class CalibrationError(Metric): L1 norm (Expected Calibration Error) .. math:: - \text{ECE} = \frac{1}{N}\sum_i^N \|(p_i - c_i)\| + \text{ECE} = \sum_i^N b_i \|(p_i - c_i)\| Infinity norm (Maximum Calibration Error) .. math:: - \text{RMSCE} = \max_{i} (p_i - c_i) + \text{MCE} = \max_{i} (p_i - c_i) L2 norm (Root Mean Square Calibration Error) .. math:: - \text{MCE} = \frac{1}{N}\sum_i^N (p_i - c_i)^2 + \text{RMSCE} = \sqrt{\sum_i^N b_i(p_i - c_i)^2} - Where :math:`p_i` is the top-1 prediction accuracy in bin i - and :math:`c_i` is the average confidence of predictions in bin i. + Where :math:`p_i` is the top-1 prediction accuracy in bin :math:`i`, + :math:`c_i` is the average confidence of predictions in bin :math:`i`, and + :math:`b_i` is the fraction of data points in bin :math:`i`. .. note:: L2-norm debiasing is not yet supported. 
diff --git a/torchmetrics/functional/classification/calibration_error.py b/torchmetrics/functional/classification/calibration_error.py index 12e413c6783..32d493f9839 100644 --- a/torchmetrics/functional/classification/calibration_error.py +++ b/torchmetrics/functional/classification/calibration_error.py @@ -118,20 +118,21 @@ def calibration_error(preds: Tensor, target: Tensor, n_bins: int = 15, norm: str L1 norm (Expected Calibration Error) .. math:: - \text{ECE} = \frac{1}{N}\sum_i^N \|(p_i - c_i)\| + \text{ECE} = \sum_i^N b_i \|(p_i - c_i)\| Infinity norm (Maximum Calibration Error) .. math:: - \text{RMSCE} = \max_{i} (p_i - c_i) + \text{MCE} = \max_{i} (p_i - c_i) L2 norm (Root Mean Square Calibration Error) .. math:: - \text{MCE} = \frac{1}{N}\sum_i^N (p_i - c_i)^2 + \text{RMSCE} = \sqrt{\sum_i^N b_i(p_i - c_i)^2} - Where :math:`p_i` is the top-1 prediction accuracy in - bin i and :math:`c_i` is the average confidence of predictions in bin i. + Where :math:`p_i` is the top-1 prediction accuracy in bin :math:`i`, + :math:`c_i` is the average confidence of predictions in bin :math:`i`, and + :math:`b_i` is the fraction of data points in bin :math:`i`. .. note: L2-norm debiasing is not yet supported.
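Note on the corrected calibration formulas above: the following is a minimal, self-contained sketch of the three norms (ECE, MCE, RMSCE) exactly as the updated docstrings define them, using equal-width confidence bins. It is illustrative only — the helper name `binned_calibration_errors`, its signature, and the binning scheme are assumptions for this note and not the code path torchmetrics itself uses.

import torch

def binned_calibration_errors(confidences: torch.Tensor, accuracies: torch.Tensor, n_bins: int = 15):
    """Illustrative ECE/MCE/RMSCE over equal-width confidence bins (hypothetical helper)."""
    bin_edges = torch.linspace(0.0, 1.0, n_bins + 1)
    ece = torch.zeros(())      # sum_i b_i * |p_i - c_i|
    mce = torch.zeros(())      # max_i |p_i - c_i|
    sq_sum = torch.zeros(())   # sum_i b_i * (p_i - c_i) ** 2, square-rooted below
    for lower, upper in zip(bin_edges[:-1], bin_edges[1:]):
        in_bin = (confidences > lower) & (confidences <= upper)
        b_i = in_bin.float().mean()                  # fraction of samples falling into bin i
        if b_i > 0:
            p_i = accuracies[in_bin].float().mean()  # top-1 accuracy inside bin i
            c_i = confidences[in_bin].mean()         # mean confidence inside bin i
            gap = (p_i - c_i).abs()
            ece = ece + b_i * gap
            mce = torch.maximum(mce, gap)
            sq_sum = sq_sum + b_i * gap ** 2
    return ece, mce, torch.sqrt(sq_sum)

The default of 15 bins mirrors the `n_bins: int = 15` default in the functional signature shown above; the built-in `calibration_error(preds, target, n_bins=15, norm="l1")` remains the supported entry point.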
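Returning to the first patch: after these changes every text metric takes predictions first and targets second, consistent with the rest of torchmetrics. A minimal doctest-style sketch of the new calling convention, assuming torchmetrics 0.7 with the functional `bleu_score`, `chrf_score`, and `ter` imports available; the expected values are the ones quoted in the updated docstring examples above.

>>> from torchmetrics.functional import bleu_score, chrf_score, ter
>>> preds = ['the cat is on the mat']
>>> targets = [['there is a cat on the mat', 'a cat is on the mat']]
>>> bleu_score(preds, targets)
tensor(0.7598)
>>> chrf_score(preds, targets)
tensor(0.8640)
>>> ter(preds, targets)
tensor(0.1538)

The module interfaces follow the same order, e.g. `BLEUScore()(preds, targets)`. The warnings added in this patch are emitted unconditionally to flag the breaking change; no argument swapping is performed, so existing call sites that still pass targets first must be updated by hand.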