From 6b6f8033e38210d969b3c765370e0565ca7823fe Mon Sep 17 00:00:00 2001 From: stancld Date: Sat, 8 Oct 2022 13:09:49 +0200 Subject: [PATCH 1/5] bugfix (Rouge): Evaluate pred_lsum only if lsum in rouge_keys Fixes #1257 Evaluate pred_lsum only if lsum in rouge_keys to avoid downloading "punkt" package from `nltk`, which raises an error when using e.g. DDP. --- src/torchmetrics/functional/text/rouge.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/torchmetrics/functional/text/rouge.py b/src/torchmetrics/functional/text/rouge.py index 1f3eb5fb950..0f6b0107fad 100644 --- a/src/torchmetrics/functional/text/rouge.py +++ b/src/torchmetrics/functional/text/rouge.py @@ -315,10 +315,11 @@ def _rouge_score_update( result_avg: Dict[Union[int, str], List[Dict[str, Tensor]]] = {rouge_key: [] for rouge_key in rouge_keys_values} list_results = [] pred = _normalize_and_tokenize_text(pred_raw, stemmer, normalizer, tokenizer) - pred_lsum = [ - _normalize_and_tokenize_text(pred_sentence, stemmer, normalizer, tokenizer) - for pred_sentence in _split_sentence(pred_raw) - ] + if "Lsum" in rouge_keys_values: + pred_lsum = [ + _normalize_and_tokenize_text(pred_sentence, stemmer, normalizer, tokenizer) + for pred_sentence in _split_sentence(pred_raw) + ] for target_raw_inner in target_raw: tgt = _normalize_and_tokenize_text(target_raw_inner, stemmer, normalizer, tokenizer) From d28703a735eed54f300377fc4218a18ce0651314 Mon Sep 17 00:00:00 2001 From: stancld Date: Sat, 8 Oct 2022 13:14:44 +0200 Subject: [PATCH 2/5] chlog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 29792ba357b..f0359ea2e57 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Fixed broken clone method for classification metrics ([#1250](https://github.com/Lightning-AI/metrics/pull/1250)) +- Fixed unintentional downloading of `nltk.punkt` when `lsum` not in `rouge_keys` ([#1258](https://github.com/Lightning-AI/metrics/pull/1258)) ## [0.10.0] - 2022-10-04 From dc63f8829baee3a22658bd38c522ff5f617d39f8 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 10 Oct 2022 15:31:02 +0200 Subject: [PATCH 3/5] lower --- src/torchmetrics/functional/text/rouge.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/text/rouge.py b/src/torchmetrics/functional/text/rouge.py index 0f6b0107fad..c2d1b104f3e 100644 --- a/src/torchmetrics/functional/text/rouge.py +++ b/src/torchmetrics/functional/text/rouge.py @@ -315,7 +315,8 @@ def _rouge_score_update( result_avg: Dict[Union[int, str], List[Dict[str, Tensor]]] = {rouge_key: [] for rouge_key in rouge_keys_values} list_results = [] pred = _normalize_and_tokenize_text(pred_raw, stemmer, normalizer, tokenizer) - if "Lsum" in rouge_keys_values: + rouge_keys_values = {k.lower() for k in rouge_keys_values} + if "lsum" in rouge_keys_values: pred_lsum = [ _normalize_and_tokenize_text(pred_sentence, stemmer, normalizer, tokenizer) for pred_sentence in _split_sentence(pred_raw) From 9a7e0ae78b5858e46eeaf4ca4b03d70e5117f3a4 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 10 Oct 2022 15:31:57 +0200 Subject: [PATCH 4/5] Apply suggestions from code review --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f0359ea2e57..6d33c26fcb9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Fixed broken clone method for classification metrics ([#1250](https://github.com/Lightning-AI/metrics/pull/1250)) + + - Fixed unintentional downloading of `nltk.punkt` when `lsum` not in `rouge_keys` ([#1258](https://github.com/Lightning-AI/metrics/pull/1258)) From be1f60455f3b6edb620ba9201bddecde3121242f Mon Sep 17 00:00:00 2001 From: stancld Date: Mon, 10 Oct 2022 21:35:42 +0200 Subject: [PATCH 5/5] Drop lower --- src/torchmetrics/functional/text/rouge.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/torchmetrics/functional/text/rouge.py b/src/torchmetrics/functional/text/rouge.py index c2d1b104f3e..0f6b0107fad 100644 --- a/src/torchmetrics/functional/text/rouge.py +++ b/src/torchmetrics/functional/text/rouge.py @@ -315,8 +315,7 @@ def _rouge_score_update( result_avg: Dict[Union[int, str], List[Dict[str, Tensor]]] = {rouge_key: [] for rouge_key in rouge_keys_values} list_results = [] pred = _normalize_and_tokenize_text(pred_raw, stemmer, normalizer, tokenizer) - rouge_keys_values = {k.lower() for k in rouge_keys_values} - if "lsum" in rouge_keys_values: + if "Lsum" in rouge_keys_values: pred_lsum = [ _normalize_and_tokenize_text(pred_sentence, stemmer, normalizer, tokenizer) for pred_sentence in _split_sentence(pred_raw)