Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
Improve parlai cold start (#3482)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenroller authored Mar 2, 2021
1 parent 7092d77 commit a467bca
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 37 deletions.
40 changes: 16 additions & 24 deletions parlai/core/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,6 @@
ALL_METRICS = DEFAULT_METRICS | ROUGE_METRICS | BLEU_METRICS | DISTINCT_METRICS


try:
from nltk.translate import bleu_score as nltkbleu
except ImportError:
# User doesn't have nltk installed, so we can't use it for bleu
# We'll just turn off things, but we might want to warn the user
nltkbleu = None

try:
from fairseq.scoring import bleu as fairseqbleu
except ImportError:
fairseqbleu = None

try:
import rouge
except ImportError:
# User doesn't have py-rouge installed, so we can't use it.
# We'll just turn off rouge computations
rouge = None

re_art = re.compile(r'\b(a|an|the)\b')
re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')

Expand Down Expand Up @@ -454,9 +435,13 @@ def compute(guess: str, answers: List[str], k: int = 4) -> Optional[BleuMetric]:
"""
Compute approximate BLEU score between guess and a set of answers.
"""
if nltkbleu is None:
# bleu library not installed, just return a default value
try:
from nltk.translate import bleu_score as nltkbleu
except ImportError:
# User doesn't have nltk installed, so we can't use it for bleu
# We'll just turn off things, but we might want to warn the user
return None

# Warning: BLEU calculation *should* include proper tokenization and
# punctuation etc. We're using the normalize_answer for everything though,
# so we're over-estimating our BLEU scores. Also note that NLTK's bleu is
Expand All @@ -481,8 +466,11 @@ def compute_many(
"""
Return BLEU-1..4 using fairseq and tokens.
"""
if fairseqbleu is None:
try:
from fairseq.scoring import bleu as fairseqbleu
except ImportError:
return None

scorer = fairseqbleu.Scorer(pad_idx, end_idx, unk_idx)
answers = answers.cpu().int()
guess = guess.cpu().int()
Expand All @@ -505,9 +493,13 @@ def compute_many(
:return: (rouge-1, rouge-2, rouge-L)
"""
# possible global initialization
global rouge
if rouge is None:
try:
import rouge
except ImportError:
# User doesn't have py-rouge installed, so we can't use it.
# We'll just turn off rouge computations
return None, None, None

if RougeMetric._evaluator is None:
RougeMetric._evaluator = rouge.Rouge(
metrics=['rouge-n', 'rouge-l'], max_n=2
Expand Down
13 changes: 0 additions & 13 deletions parlai/core/torch_generator_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,6 @@
)


try:
from nltk.translate import bleu_score as nltkbleu

except ImportError:
nltkbleu = None

try:
from fairseq.scoring import bleu as fairseq_bleu

except ImportError:
fairseq_bleu = None


class SearchBlocklist(object):
"""
Search block list facilitates blocking ngrams from being generated.
Expand Down

0 comments on commit a467bca

Please sign in to comment.