Skip to content

Commit

Permalink
Avoid looping through the whole counter in bleu_score method (#1913)
Browse files Browse the repository at this point in the history
* avoid looping through the whole counter in bleu_score method

* fix bug when max_n > len(candidate)

* add comment to explain L88
  • Loading branch information
Asugawara authored Sep 27, 2022
1 parent 766cf9d commit 5c48f4a
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions torchtext/data/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,12 @@ def bleu_score(candidate_corpus, references_corpus, max_n=4, weights=[0.25] * 4)
refs_len = 0.0

for (candidate, refs) in zip(candidate_corpus, references_corpus):
candidate_len += len(candidate)
current_candidate_len = len(candidate)
candidate_len += current_candidate_len

# Get the length of the reference that's closest in length to the candidate
refs_len_list = [float(len(ref)) for ref in refs]
refs_len += min(refs_len_list, key=lambda x: abs(len(candidate) - x))
refs_len += min(refs_len_list, key=lambda x: abs(current_candidate_len - x))

reference_counters = _compute_ngram_counter(refs[0], max_n)
for ref in refs[1:]:
Expand All @@ -79,11 +80,12 @@ def bleu_score(candidate_corpus, references_corpus, max_n=4, weights=[0.25] * 4)

clipped_counter = candidate_counter & reference_counters

for ngram in clipped_counter:
clipped_counts[len(ngram) - 1] += clipped_counter[ngram]
for ngram, count in clipped_counter.items():
clipped_counts[len(ngram) - 1] += count

for ngram in candidate_counter: # TODO: no need to loop through the whole counter
total_counts[len(ngram) - 1] += candidate_counter[ngram]
for i in range(max_n):
# The number of N-grams in a `candidate` of T tokens is `T - (N - 1)`
total_counts[i] += max(current_candidate_len - i, 0)

if min(clipped_counts) == 0:
return 0.0
Expand Down

0 comments on commit 5c48f4a

Please sign in to comment.