Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[Wizard of Internet] Gold retrieved documents FiD Agent. #4123

Merged
merged 11 commits into from
Nov 2, 2021
28 changes: 26 additions & 2 deletions parlai/agents/fid/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from parlai.core.dict import DictionaryAgent
from parlai.core.opt import Opt
from parlai.core.message import Message
from parlai.agents.transformer.transformer import TransformerGeneratorModel

from parlai.agents.rag.args import RetrieverType
Expand All @@ -25,6 +26,7 @@
fix_incremental_state,
)
from parlai.utils.typing import TShared
from parlai.tasks.wizard_of_internet import constants as consts


class Fid(RagToken):
Expand Down Expand Up @@ -92,8 +94,8 @@ def reorder_decoder_incremental_state(
self, incremental_state: Dict[int, dict], inds: torch.Tensor
) -> Dict[int, dict]:
"""
Override RagModel.reorder_decoder_incremental_state to resort back
to normal reordering.
Override RagModel.reorder_decoder_incremental_state to resort back to normal
reordering.

See ``TorchGeneratorModel.reorder_decoder_incremental_state`` for a description.
"""
Expand Down Expand Up @@ -318,6 +320,28 @@ def __init__(self, opt: Opt, shared: TShared = None):
super().__init__(opt, shared=shared)


class GoldDocRetrieverFiDAgent(SearchQueryFiDAgent):
    """
    Uses the gold retrieved docs (documents shown to crowdsourcing agents).

    This FiD agent has a mock retriever that picks the retrieved docs from the
    observed example.
    """

    def __init__(self, opt: Opt, shared: TShared = None):
        opt = deepcopy(opt)
        # Force the echo retriever so that retrieval returns exactly the docs
        # forwarded from the observation in ``observe`` below.
        opt['rag_retriever_type'] = RetrieverType.OBSERVATION_ECHO_RETRIEVER.value
        super().__init__(opt, shared=shared)

    def observe(self, observation: Union[Dict, Message]) -> Message:
        """
        Hand the gold retrieved/selected docs from the observation to the mock
        retriever, then run the regular observe pipeline.

        Missing doc fields fall back to a single empty string so the retriever
        always receives list-valued inputs.
        """
        # NOTE(review): padding examples (e.g., observations without 'text_vec')
        # also reach this call — consider skipping them; TODO confirm with the
        # batching/model_api code paths.
        self.model.retriever.set_retrieve_doc(
            retrieved_docs=observation.get(consts.RETRIEVED_DOCS, ['']),
            selected_docs=observation.get(consts.SELECTED_DOCS, ['']),
            selected_sentences=observation.get(consts.SELECTED_SENTENCES, ['']),
        )
        return super().observe(observation)


def concat_enc_outs(
input: torch.LongTensor,
enc_out: torch.Tensor,
Expand Down
1 change: 1 addition & 0 deletions parlai/agents/rag/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ class RetrieverType(Enum):
POLY_FAISS = 'poly_faiss'
SEARCH_ENGINE = 'search_engine'
SEARCH_TERM_FAISS = 'search_term_faiss'
OBSERVATION_ECHO_RETRIEVER = 'observation_echo_retriever'


def setup_rag_args(parser: ParlaiParser) -> ParlaiParser:
Expand Down
65 changes: 64 additions & 1 deletion parlai/agents/rag/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,13 @@ def __init__(self, opt: Opt, dictionary: DictionaryAgent, shared: TShared = None
super().__init__()
self.retriever_type = RetrieverType(opt['rag_retriever_type'])
if not (
(self.retriever_type == RetrieverType.SEARCH_ENGINE)
(
self.retriever_type
in (
RetrieverType.SEARCH_ENGINE,
RetrieverType.OBSERVATION_ECHO_RETRIEVER,
)
)
or (opt.get('retriever_debug_index') in [None, 'none'])
):
if opt.get('retriever_debug_index') == 'exact':
Expand Down Expand Up @@ -1267,6 +1273,61 @@ def retrieve_and_score(
return top_docs, top_doc_scores


class ObservationEchoRetriever(RagRetriever):
    """
    This retriever returns (echoes) documents that are already passed to it to
    return.

    Use this only with GoldFiD agents. It relies on the retrieved docs being
    included in the observed example of the agent: the agent must call
    ``set_retrieve_doc`` before each ``retrieve_and_score``.
    """

    def __init__(self, opt: Opt, dictionary: DictionaryAgent, shared: TShared = None):
        # Gold docs for the current example; populated via set_retrieve_doc().
        self._retrieved_docs = None
        self._selected_docs = None
        self._selected_sentences = None
        self._delimiter = '\n'
        super().__init__(opt, dictionary, shared=shared)

    def set_retrieve_doc(
        self,
        retrieved_docs: List[str],
        selected_docs: List[str],
        selected_sentences: List[str],
    ):
        """
        Store the gold docs to be echoed by the next ``retrieve_and_score``.

        :param retrieved_docs: full texts of the docs shown to the annotator.
        :param selected_docs: texts of the docs the annotator selected.
        :param selected_sentences: sentences the annotator selected.
        """
        self._retrieved_docs = retrieved_docs
        self._selected_docs = selected_docs
        self._selected_sentences = selected_sentences

    def get_delimiter(self) -> str:
        """
        Return the delimiter used to join document pieces.
        """
        return self._delimiter

    def pick_chunk(self, doc_text: str, selected_sentences: List[str]):
        """
        Select the chunk of a document to emit as retrieval output.

        TODO: replace with a better doc chunk selector (e.g., driven by the
        existing doc-chunk-length arg). This is a placeholder that simply
        truncates the document; ``selected_sentences`` is intentionally unused.
        """
        return doc_text[:256]

    def retrieve_and_score(
        self, query: torch.LongTensor
    ) -> Tuple[List[List[Document]], torch.Tensor]:
        """
        Echo back the docs previously stored via ``set_retrieve_doc``.

        Scores are arbitrary but monotonically decreasing, so earlier docs rank
        higher. (Previously the score list was left empty, producing a (1, 0)
        score tensor that did not match the number of returned docs.)

        :param query: tokenized query; batching is not supported, so the batch
            dimension must be 1.
        :return: (list of one list of Documents, (1, n_docs) score tensor)
        """
        assert query.size(0) == 1, 'This retriever only handles a single example batch.'
        retrieved_docs, retrieved_doc_scores = [], []
        for idx, doc_text in enumerate(self._retrieved_docs):
            retrieved_docs.append(
                Document(
                    docid=f'id_{idx}',
                    text=self.pick_chunk(doc_text, self._selected_sentences),
                    title=f'title_{idx}',
                )
            )
            # One score per doc, decreasing with rank.
            retrieved_doc_scores.append(1.0 / (idx + 1))
        return (
            [retrieved_docs],
            torch.Tensor(retrieved_doc_scores).reshape(1, -1).to(query.device),
        )


class DocumentChunkRanker:
"""
Base class for controlling splitting long documents and selecting relevant chunks.
Expand Down Expand Up @@ -1344,3 +1405,5 @@ def retriever_factory(
return SearchQuerySearchEngineRetriever(opt, dictionary, shared=shared)
elif retriever is RetrieverType.SEARCH_TERM_FAISS:
return SearchQueryFAISSIndexRetriever(opt, dictionary, shared=shared)
elif retriever is RetrieverType.OBSERVATION_ECHO_RETRIEVER:
return ObservationEchoRetriever(opt, dictionary, shared=shared)