From 6373397bac8f29676089c4b90de9421b44b756b3 Mon Sep 17 00:00:00 2001 From: Asugawara Date: Mon, 26 Sep 2022 01:30:36 +0900 Subject: [PATCH 1/3] avoid to loop through the whole counter in bleu_score method --- torchtext/data/metrics.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/torchtext/data/metrics.py b/torchtext/data/metrics.py index 76fc09d9e9..6472b1c809 100644 --- a/torchtext/data/metrics.py +++ b/torchtext/data/metrics.py @@ -65,11 +65,12 @@ def bleu_score(candidate_corpus, references_corpus, max_n=4, weights=[0.25] * 4) refs_len = 0.0 for (candidate, refs) in zip(candidate_corpus, references_corpus): - candidate_len += len(candidate) + current_candidate_len = len(candidate) + candidate_len += current_candidate_len # Get the length of the reference that's closest in length to the candidate refs_len_list = [float(len(ref)) for ref in refs] - refs_len += min(refs_len_list, key=lambda x: abs(len(candidate) - x)) + refs_len += min(refs_len_list, key=lambda x: abs(current_candidate_len - x)) reference_counters = _compute_ngram_counter(refs[0], max_n) for ref in refs[1:]: @@ -79,11 +80,11 @@ def bleu_score(candidate_corpus, references_corpus, max_n=4, weights=[0.25] * 4) clipped_counter = candidate_counter & reference_counters - for ngram in clipped_counter: - clipped_counts[len(ngram) - 1] += clipped_counter[ngram] + for ngram, count in clipped_counter.items(): + clipped_counts[len(ngram) - 1] += count - for ngram in candidate_counter: # TODO: no need to loop through the whole counter - total_counts[len(ngram) - 1] += candidate_counter[ngram] + for i in range(max_n): + total_counts[i] += current_candidate_len - i if min(clipped_counts) == 0: return 0.0 From ad3903cc97216b0b61f063f2fc2fb8c740c4c1a8 Mon Sep 17 00:00:00 2001 From: Asugawara Date: Tue, 27 Sep 2022 00:26:08 +0900 Subject: [PATCH 2/3] fix bug when max_n > len(candidate) --- torchtext/data/metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtext/data/metrics.py b/torchtext/data/metrics.py index 6472b1c809..82778d819c 100644 --- a/torchtext/data/metrics.py +++ b/torchtext/data/metrics.py @@ -84,7 +84,7 @@ def bleu_score(candidate_corpus, references_corpus, max_n=4, weights=[0.25] * 4) clipped_counts[len(ngram) - 1] += count for i in range(max_n): - total_counts[i] += current_candidate_len - i + total_counts[i] += max(current_candidate_len - i, 0) if min(clipped_counts) == 0: return 0.0 From e306408abdb2f85730ba624056dfba5ece533a65 Mon Sep 17 00:00:00 2001 From: Asugawara Date: Tue, 27 Sep 2022 23:25:18 +0900 Subject: [PATCH 3/3] add comment to explain L88 --- torchtext/data/metrics.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchtext/data/metrics.py b/torchtext/data/metrics.py index 82778d819c..ff21fa7d0a 100644 --- a/torchtext/data/metrics.py +++ b/torchtext/data/metrics.py @@ -84,6 +84,7 @@ def bleu_score(candidate_corpus, references_corpus, max_n=4, weights=[0.25] * 4) clipped_counts[len(ngram) - 1] += count for i in range(max_n): + # The number of N-grams in a `candidate` of T tokens is `T - (N - 1)` total_counts[i] += max(current_candidate_len - i, 0) if min(clipped_counts) == 0: