From 4e085e6a62fba2cd1742a0b564086e2c6bc02991 Mon Sep 17 00:00:00 2001
From: Dat Quoc Nguyen <2412555+datquocnguyen@users.noreply.github.com>
Date: Wed, 22 Sep 2021 16:25:50 +0700
Subject: [PATCH] Update BART hub_interface

Provide an extra option to map OOV tokens to <unk> rather than always
adding the OOV tokens to the dictionary.
---
 fairseq/models/bart/hub_interface.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/fairseq/models/bart/hub_interface.py b/fairseq/models/bart/hub_interface.py
index 4d47d97518..d20e99e3ae 100644
--- a/fairseq/models/bart/hub_interface.py
+++ b/fairseq/models/bart/hub_interface.py
@@ -31,7 +31,11 @@ def __init__(self, cfg, task, model):
         self.model = self.models[0]
 
     def encode(
-        self, sentence: str, *addl_sentences, no_separator=True
+        self,
+        sentence: str,
+        *addl_sentences,
+        no_separator=True,
+        add_if_not_exist=True
     ) -> torch.LongTensor:
         """
         BPE-encode a sentence (or multiple sentences).
@@ -59,7 +63,9 @@ def encode(
         for s in addl_sentences:
             bpe_sentence += " </s>" if not no_separator else ""
             bpe_sentence += " " + self.bpe.encode(s) + " </s>"
-        tokens = self.task.source_dictionary.encode_line(bpe_sentence, append_eos=False)
+        tokens = self.task.source_dictionary.encode_line(
+            bpe_sentence, append_eos=False, add_if_not_exist=add_if_not_exist
+        )
         return tokens.long()
 
     def decode(self, tokens: torch.LongTensor):
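
Usage note (not part of the patch): a minimal sketch of how the new
add_if_not_exist keyword could be exercised through the hub interface.
Model loading follows the standard fairseq BART torch.hub recipe; the
bart.large checkpoint name and the example sentence are illustrative
only, and the default (add_if_not_exist=True) keeps the previous
behaviour.

import torch

# Load a pretrained BART model through the PyTorch Hub interface.
bart = torch.hub.load("pytorch/fairseq", "bart.large")
bart.eval()

# Previous (and still default) behaviour: BPE tokens missing from the
# dictionary are appended to it.
tokens = bart.encode("Hello world!")

# With the new flag, out-of-vocabulary BPE tokens are mapped to <unk>
# instead of growing the dictionary.
tokens = bart.encode("Hello world!", add_if_not_exist=False)

print(bart.decode(tokens))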