Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

FiD gold retrieved docs agent n-docs debug #4146

Merged
merged 5 commits into from
Nov 8, 2021
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 54 additions & 11 deletions parlai/agents/fid/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
"""
from abc import abstractmethod
from copy import deepcopy
from enum import unique
import torch
import random
from typing import Tuple, Union, Optional, List, Dict, Any

from parlai.core.dict import DictionaryAgent
Expand Down Expand Up @@ -334,6 +334,7 @@ class GoldDocRetrieverFiDAgent(SearchQueryFiDAgent):
def __init__(self, opt: Opt, shared: TShared = None):
opt = deepcopy(opt)
opt['rag_retriever_type'] = RetrieverType.OBSERVATION_ECHO_RETRIEVER.value
self._n_docs = opt['n_docs']
if opt['rag_retriever_query'] != 'full_history':
prev_sel = opt['rag_retriever_query']
opt['rag_retriever_query'] = 'full_history'
Expand Down Expand Up @@ -363,17 +364,59 @@ class WizIntGoldDocRetrieverFiDAgent(GoldDocRetrieverFiDAgent):
Gold knowledge FiD agent for the Wizard of Internet task.
"""

def get_retrieved_knowledge(self, message):
def _extract_doc_from_message(self, message: Message, idx: int):
"""
Returns the `idx`-th `__retrieved-docs__` in the `message` as a Document object.
"""
return Document(
docid=message[consts.RETRIEVED_DOCS_URLS][idx],
title=message[consts.RETRIEVED_DOCS_TITLES][idx],
text=message[consts.RETRIEVED_DOCS][idx],
)

def get_retrieved_knowledge(self, message: Message):

retrieved_docs = []
if message.get(consts.RETRIEVED_DOCS):
for doc_id, doc_title, doc_txt in zip(
message[consts.RETRIEVED_DOCS_URLS],
message[consts.RETRIEVED_DOCS_TITLES],
message[consts.RETRIEVED_DOCS],
):
retrieved_docs.append(
Document(docid=doc_id, title=doc_title, text=doc_txt)
)
if not message.get(consts.RETRIEVED_DOCS):
return retrieved_docs

# First adding the docs with selected sentences.
selected_sentences = message[consts.SELECTED_SENTENCES]
n_docs_in_message = len(message[consts.RETRIEVED_DOCS])
already_added_doc_idx = []

if ' '.join(selected_sentences) != consts.NO_SELECTED_SENTENCES_TOKEN:
for doc_idx in range(n_docs_in_message):
doc_content = message[consts.RETRIEVED_DOCS][doc_idx]
for sel_sentc in selected_sentences:
if sel_sentc in doc_content:
retrieved_docs.append(
self._extract_doc_from_message(message, doc_idx)
)
already_added_doc_idx.append(doc_idx)
break
if len(retrieved_docs) == self._n_docs:
logging.warning(
f'More than {self._n_docs} documents have selected sentences. Trimming them to the first {self._n_docs}'
)
break

# Then adding other (filler) docs.
# We add them by iterating forward in the __retrieved-docs__ list for repeatability,
# but we shuffle the order of the final retruned docs, to make sure model doesn't cheat.
for doc_idx in range(n_docs_in_message):
if doc_idx in already_added_doc_idx:
continue

retrieved_docs.append(self._extract_doc_from_message(message, doc_idx))
if len(retrieved_docs) == self._n_docs:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this not checked before this for-loop?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we gain much by adding it before the for loop. It also makes the code a bit more succinct by having the loop that breaks right after start. But I can move it to the top of the loop to avoid extra check on already_added_doc_idx.

break

if n_docs_in_message > len(retrieved_docs):
logging.debug(
f'Trimmed retrieved docs from {n_docs_in_message} to {len(retrieved_docs)}'
)
random.shuffle(retrieved_docs)
return retrieved_docs


Expand Down