diff --git a/torchtext/data/metrics.py b/torchtext/data/metrics.py index 76fc09d9e9..ff21fa7d0a 100644 --- a/torchtext/data/metrics.py +++ b/torchtext/data/metrics.py @@ -65,11 +65,12 @@ def bleu_score(candidate_corpus, references_corpus, max_n=4, weights=[0.25] * 4) refs_len = 0.0 for (candidate, refs) in zip(candidate_corpus, references_corpus): - candidate_len += len(candidate) + current_candidate_len = len(candidate) + candidate_len += current_candidate_len # Get the length of the reference that's closest in length to the candidate refs_len_list = [float(len(ref)) for ref in refs] - refs_len += min(refs_len_list, key=lambda x: abs(len(candidate) - x)) + refs_len += min(refs_len_list, key=lambda x: abs(current_candidate_len - x)) reference_counters = _compute_ngram_counter(refs[0], max_n) for ref in refs[1:]: @@ -79,11 +80,12 @@ def bleu_score(candidate_corpus, references_corpus, max_n=4, weights=[0.25] * 4) clipped_counter = candidate_counter & reference_counters - for ngram in clipped_counter: - clipped_counts[len(ngram) - 1] += clipped_counter[ngram] + for ngram, count in clipped_counter.items(): + clipped_counts[len(ngram) - 1] += count - for ngram in candidate_counter: # TODO: no need to loop through the whole counter - total_counts[len(ngram) - 1] += candidate_counter[ngram] + for i in range(max_n): + # The number of N-grams in a `candidate` of T tokens is `T - (N - 1)` + total_counts[i] += max(current_candidate_len - i, 0) if min(clipped_counts) == 0: return 0.0