This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[Wizard of Internet] Gold retrieved documents FiD Agent. #4123

Merged: 11 commits merged into main from retrieve_gold on Nov 2, 2021

Conversation

mojtaba-komeili (Contributor):

Patch description
GoldDocRetrieverFiDAgent: it uses a custom retriever that, instead of retrieving from a corpus, fetches the documents included in the example itself. For the Wizard of Internet dataset these are the documents that were shown to the crowdworkers at the time of data collection (hence, gold retrieved docs).
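
For illustration, a minimal standalone sketch of the idea with hypothetical names (not the actual ParlAI classes): instead of searching an index, the retriever is handed the gold documents that arrived with the example and simply returns them when asked.

# Hypothetical sketch of an "echo" retriever: it never searches a corpus,
# it just returns the gold documents that were attached to the example.
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Document:
    docid: str
    title: str
    text: str


class GoldDocEchoRetriever:
    def __init__(self) -> None:
        self._docs_by_query: Dict[str, List[Document]] = {}

    def add_retrieve_doc(self, query: str, retrieved_docs: List[Document]) -> None:
        # Remember the gold documents that came with this example.
        self._docs_by_query[query] = retrieved_docs

    def retrieve(self, query: str) -> List[Document]:
        # "Retrieval" is just a lookup of what the example already provided.
        return self._docs_by_query.get(query, [])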

mojtaba-komeili (Contributor Author):

Haven't finished the document chunk splitter yet. Will add that soon as well.

mojtaba-komeili changed the title from "reformat" to "[Wizard of Internet] Gold retrieved documents FiD Agent." on Oct 29, 2021
parlai/agents/rag/retrievers.py (outdated)
return self._delimiter

def pick_chunk(self, doc_text: str, selected_sentences: List[str]):
# TODO replace with a better doc chunk selector.
Contributor:

we have an arg for doc chunk length, right? could probably use this to start
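
For illustration, a rough sketch of a placeholder along those lines; the chunk_length parameter here is an assumption standing in for the doc-chunk-length option.

# Rough sketch of a chunk picker along the suggested lines; `chunk_length`
# stands in for the agent's doc-chunk-length option (an assumption here).
from typing import List


def pick_chunk(doc_text: str, selected_sentences: List[str], chunk_length: int = 100) -> str:
    words = doc_text.split()
    chunks = [
        ' '.join(words[i : i + chunk_length])
        for i in range(0, len(words), chunk_length)
    ]
    # Prefer the first chunk containing one of the gold sentences,
    # otherwise fall back to the beginning of the document.
    for chunk in chunks:
        if any(sent in chunk for sent in selected_sentences):
            return chunk
    return chunks[0] if chunks else doc_text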

mojtaba-komeili (Contributor Author):

This method is still a placeholder until the rest of the code works. Trying to resolve another issue with the code.

Contributor:

Is this the code for using the data from a teacher? If so, we won't need it, since we can use the mutator I just wrote.

Contributor:

actually, yeah, let's have no model-side logic for this, don't want any unintended side effects

super().__init__(opt, shared=shared)

def observe(self, observation: Union[Dict, Message]) -> Message:
self.model.retriever.set_retrieve_doc(
Contributor:

I'd check if 'text_vec' is not in the Batch before setting, so we can skip padding examples
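
A tiny sketch of that guard (the surrounding names are assumed):

# Sketch of the suggested guard: padding examples carry no 'text_vec',
# so filter them out before touching the retriever.
def non_padding(observations):
    return [obs for obs in observations if 'text_vec' in obs]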

mojtaba-komeili (Contributor Author):

Yeah, I should check for padding too.

Contributor:

ditto for model_api

mojtaba-komeili marked this pull request as ready for review on October 29, 2021 at 15:05.
klshuster (Contributor) left a comment:

I'm assuming we'll want this retriever to handle batching, yeah?

I think the correct way to do this would be to handle within batchify, rather than observe; at that point, we have access to all observations, so we can set the retrieved documents for each observation (you can check BlenderBot2RagAgent.batchify to see how this is handled)
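
As a hedged sketch of what doing this at batchify time could look like (the retriever here is the echo-style sketch from above, and the field names 'full_text' and 'retrieved_docs' are assumptions, not the actual ParlAI keys):

# Hedged sketch (assumed names) of assigning gold docs at batchify time,
# where all observations in the batch are available at once.
def assign_gold_docs(retriever, obs_batch):
    for obs in obs_batch:
        if 'text_vec' not in obs:
            continue  # padding example, nothing to retrieve for
        # Key by the full dialogue context and hand over the example's gold docs.
        retriever.add_retrieve_doc(
            query=obs.get('full_text', ''),
            retrieved_docs=obs.get('retrieved_docs', []),
        )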

parlai/agents/fid/fid.py (outdated)
return observation

if not observation.get(consts.RETRIEVED_DOCS):
self.model.retriever.set_retrieve_doc(
Contributor:

can we do self.model_api.retriever here? that's a protection against DataParallel
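
A sketch of why such an accessor helps (an assumption about the general pattern, not necessarily ParlAI's exact implementation): under nn.DataParallel the real module, and therefore its .retriever, lives at .module.

# Return the underlying module whether or not it is wrapped in DataParallel.
import torch.nn as nn


def unwrap_model(model: nn.Module) -> nn.Module:
    return model.module if isinstance(model, nn.DataParallel) else model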

mojtaba-komeili (Contributor Author), Oct 29, 2021:

I moved the model_api to RagAgent. That should cover it for BB2 as well, so I removed it from there. Let's see how many things are gonna break now 😃

super().__init__(opt, shared=shared)

def observe(self, observation: Union[Dict, Message]) -> Message:
self.model.retriever.set_retrieve_doc(
Contributor:

ditto for model_api

retrieved_docs=None, selected_docs=None, selected_sentences=None
)
else:
rertrieved_docs = []
Contributor:

nit: "retrieved*"

parlai/agents/fid/fid.py (outdated)
parlai/agents/fid/fid.py (outdated)

def retrieve_and_score(
    self, query: torch.LongTensor
) -> Tuple[List[List[Document]], torch.Tensor]:
Contributor:

hmm so this still doesn't handle batching?
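
For reference, a hedged sketch of what a batched lookup could look like: one list of documents per query in the batch plus flat scores, matching the (List[List[Document]], Tensor) return shape above. Plain strings stand in for Document objects, and all names here are assumptions.

# Sketch of a batched gold-doc lookup with dummy uniform scores.
from typing import Dict, List, Tuple

import torch


def retrieve_and_score_batched(
    docs_by_query: Dict[str, List[str]],
    queries: List[str],
    n_docs: int,
) -> Tuple[List[List[str]], torch.Tensor]:
    batch_docs = [docs_by_query.get(q, [])[:n_docs] for q in queries]
    # Gold docs carry no retrieval score, so return flat scores
    # (padding of shorter doc lists is omitted in this sketch).
    scores = torch.ones(len(queries), n_docs)
    return batch_docs, scores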

klshuster (Contributor) left a comment:

thanks, batching looks great, although I had one minor concern related to the uniqueness of the queries...

have you encountered any issues with the same queries mapping to 2 sets of documents? if not, perhaps we can just leave as is and make a warning or something?

parlai/agents/rag/rag.py
parlai/agents/rag/retrievers.py (outdated)

def add_retrieve_doc(self, query: str, retrieved_docs: List[Document]):
    new_idx = len(self._querie_ids)
    self._querie_ids[query] = new_idx
Contributor:

what if query is already in querie_ids?

we might want to make this more robust --> maybe hash the observation? or use the query<sep>label?
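
One way to make the key more robust along those lines (the '<sep>' delimiter and the use of the label are assumptions):

# Sketch of a more collision-resistant key: hash the query together with the
# label rather than keying on the query text alone.
import hashlib


def doc_map_key(query: str, label: str = '') -> str:
    return hashlib.sha1(f'{query}<sep>{label}'.encode('utf-8')).hexdigest()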

Contributor:

i.e., lots of conversation starters can be similar for some datasets

mojtaba-komeili (Contributor Author):

I see what you mean. I believe I checked this for WizInt and we didn't have collisions.
My reason for ignoring it was that it makes sense for them to have the same knowledge if the conversation is the same. Given that retrieve_and_score only sees the query (not the label), it gets slightly more complicated to include the label in the query encoding for this agent.

Contributor:

yeah, i suppose we're operating under the (strong) assumption that the query is going to be the full dialogue context, and not the last turn of dialogue

for example, one could imagine a mid-conversation utterance, "that's so cool!", to elicit different documents depending on the preceding dialogue context

parlai/agents/rag/retrievers.py (outdated)
mojtaba-komeili (Contributor Author):

One more major change: GoldDocRetrieverFiDAgent was too specific to Wizard of Internet. I created a new abstract class with that name for general use later; the new WizIntGoldDocRetrieverFiDAgent replaces it now.
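
Roughly the shape of that refactor, as a hedged sketch (the method and field names are assumptions): the abstract agent owns the generic gold-doc plumbing, while the dataset-specific subclass knows how to read the gold documents out of a WizInt example.

# Rough shape of the abstract/concrete split (names assumed for illustration).
from abc import ABC, abstractmethod
from typing import Dict, List


class GoldDocRetrieverFiDAgent(ABC):
    @abstractmethod
    def get_retrieved_knowledge(self, message: Dict) -> List[str]:
        """Extract the gold documents from a single example."""


class WizIntGoldDocRetrieverFiDAgent(GoldDocRetrieverFiDAgent):
    def get_retrieved_knowledge(self, message: Dict) -> List[str]:
        # Wizard of Internet stores the documents shown to the crowdworkers
        # directly in the example (field name assumed here).
        return list(message.get('retrieved_docs', []))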

Comment on lines 347 to 355
if observation.get(consts.RETRIEVED_DOCS):
    for doc_id, doc_title, doc_txt in zip(
        observation[consts.RETRIEVED_DOCS_URLS],
        observation[consts.RETRIEVED_DOCS_TITLES],
        observation[consts.RETRIEVED_DOCS],
    ):
        retrieved_docs.append(
            Document(docid=doc_id, title=doc_title, text=doc_txt)
        )
Contributor:

this is redundant now, right?

def __init__(self, opt: Opt, shared: TShared = None):
    opt = deepcopy(opt)
    opt['rag_retriever_type'] = RetrieverType.OBSERVATION_ECHO_RETRIEVER.value
    if opt['rag_retriever_query'] != 'full_history':
mojtaba-komeili (Contributor Author):

@klshuster regarding our previous conversation about the possibility of collisions between contexts for document retrieval, I see this as an easy solution to avoid that. I believe forcing the use of full history for this type of agent makes sense, given that the whole retrieval process is recalling some past conversation. Let me know what you think.
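
A hedged sketch of the kind of override being discussed (the warning text and the free-function form are assumptions): force full-history queries so identical last utterances in different conversations cannot collide.

# Force full-history queries before building the agent (sketch).
import warnings
from copy import deepcopy


def force_full_history(opt: dict) -> dict:
    opt = deepcopy(opt)
    if opt.get('rag_retriever_query') != 'full_history':
        warnings.warn(
            "Gold-doc retrieval keys documents by the dialogue context; "
            "overriding rag_retriever_query to 'full_history'."
        )
        opt['rag_retriever_query'] = 'full_history'
    return opt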

Contributor:

yep, makes sense to me

mojtaba-komeili merged commit 06ac02d into main on Nov 2, 2021.
mojtaba-komeili deleted the retrieve_gold branch on November 2, 2021 at 14:44.