
Fix RAG generation for T5 FiD (#3657)
I noticed a big discrepancy between exact_match and token_em after I switched to using T5 FiD in public ParlAI.
Per discussion with Kurt, it's because generation for T5 is supposed to go through TorchGeneratorAgent, rather than using T5's generation via Hugging Face. This was in the original internal implementation, but got lost at some point while open-sourcing.
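
For context, a minimal, self-contained sketch of the pattern the fix relies on (the `*Sketch` class names are hypothetical stand-ins, not ParlAI classes): the T5 agent overrides `_generate` to defer to Hugging Face generation, which is not RAG-aware, so the RAG subclass bypasses that override by calling the `TorchGeneratorAgent`-style implementation explicitly.

```python
# Hypothetical sketch of the method-resolution issue behind this fix.


class TorchGeneratorAgentSketch:
    def _generate(self, batch, beam_size, max_ts):
        # Stand-in for TorchGeneratorAgent's RAG-aware beam search.
        return f"TGA beam search (beam_size={beam_size}, max_ts={max_ts})"


class T5AgentSketch(TorchGeneratorAgentSketch):
    def _generate(self, batch, beam_size, max_ts):
        # Stand-in for T5Agent's override that defers to Hugging Face generate().
        return "Hugging Face generate()"


class T5RagAgentSketch(T5AgentSketch):
    def _generate(self, batch, beam_size, max_ts):
        # Skip T5AgentSketch's override and call the grandparent directly,
        # mirroring the TorchGeneratorAgent._generate(self, ...) call in the diff.
        return TorchGeneratorAgentSketch._generate(self, batch, beam_size, max_ts)


if __name__ == "__main__":
    print(T5RagAgentSketch()._generate(batch=None, beam_size=3, max_ts=128))
    # -> TGA beam search (beam_size=3, max_ts=128)
```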

Verified by printing out the forced-decoding outputs and the output from eval_model, and confirmed that when `token_em == 1`, the eval_model output was also the same.
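
As a rough illustration of that check (toy values only, not ParlAI's metric code): `token_em` scores the teacher-forced, per-token predictions against the label, while `exact_match` scores the free-running generated text, so a broken generation path can leave `token_em` at 1.0 while `exact_match` collapses.

```python
# Toy illustration of why the two metrics diverge when generation goes
# through the wrong path even though the model itself is fine.

label_tokens = ["the", "cat", "sat"]

# Teacher-forced argmax predictions: computed with the gold prefix fed in,
# so they do not depend on the generation/search code path.
forced_preds = ["the", "cat", "sat"]
token_em = float(forced_preds == label_tokens)

# Free-running output produced by a broken generation path.
generated_text = "a feline was seated"
exact_match = float(generated_text == " ".join(label_tokens))

print(f"token_em={token_em}, exact_match={exact_match}")
# -> token_em=1.0, exact_match=0.0
```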
moyapchen authored May 19, 2021
1 parent 3bf87ea commit 91e883b
Showing 1 changed file with 22 additions and 2 deletions.
24 changes: 22 additions & 2 deletions parlai/agents/rag/rag.py
@@ -3,6 +3,7 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks.
@@ -27,7 +28,7 @@
from parlai.core.opt import Opt
from parlai.core.params import ParlaiParser
from parlai.core.torch_agent import History, Batch
-from parlai.core.torch_generator_agent import PPLMetric, TreeSearch
+from parlai.core.torch_generator_agent import PPLMetric, TorchGeneratorAgent, TreeSearch
from parlai.utils.distributed import sync_parameters
from parlai.utils.io import PathManager
import parlai.utils.logging as logging
@@ -75,6 +76,20 @@ class T5RagAgent(T5Agent, BaseGenerationAgentMixin):
    def build_rag_model(opt: Opt, dictionary: DictionaryAgent) -> T5RagModel:
        return T5RagModel(opt, dictionary)

+    def _generate(
+        self,
+        batch: Batch,
+        beam_size: int,
+        max_ts: int,
+        prefix_tokens: Optional[torch.LongTensor] = None,
+    ) -> Tuple[List[Tuple[torch.LongTensor, torch.Tensor]], List[TreeSearch]]:
+        """
+        Override since T5 needs to call TGA generate.
+        """
+        return TorchGeneratorAgent._generate(
+            self, batch, beam_size, max_ts, prefix_tokens
+        )


GENERATION_AGENTS = {
    'transformer/generator': TransformerGeneratorRagAgent,
@@ -805,7 +820,12 @@ def compute_loss(
        scores, preds, enc_state, *_ = model_output

        self._record_retrieval_metrics(batch, enc_state)
-        loss, metric_loss, metric_correct, metric_target_tokens = self._rag_model_interface.compute_loss(
+        (
+            loss,
+            metric_loss,
+            metric_correct,
+            metric_target_tokens,
+        ) = self._rag_model_interface.compute_loss(
            self.criterion, scores, preds, enc_state, batch.label_vec
        )

