Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
FiD gold retrieved docs agent n-docs debug (#4146)
Browse files Browse the repository at this point in the history
* trimming docs down to n_docs

* added debug message

* random shuffling of docs

* pr comments

* pr comments 2
  • Loading branch information
mojtaba-komeili authored Nov 8, 2021
1 parent 39dd264 commit 825a057
Showing 1 changed file with 64 additions and 11 deletions.
75 changes: 64 additions & 11 deletions parlai/agents/fid/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
"""
from abc import abstractmethod
from copy import deepcopy
from enum import unique
import torch
import random
from typing import Tuple, Union, Optional, List, Dict, Any

from parlai.core.dict import DictionaryAgent
Expand Down Expand Up @@ -334,6 +334,7 @@ class GoldDocRetrieverFiDAgent(SearchQueryFiDAgent):
def __init__(self, opt: Opt, shared: TShared = None):
opt = deepcopy(opt)
opt['rag_retriever_type'] = RetrieverType.OBSERVATION_ECHO_RETRIEVER.value
self._n_docs = opt['n_docs']
if opt['rag_retriever_query'] != 'full_history':
prev_sel = opt['rag_retriever_query']
opt['rag_retriever_query'] = 'full_history'
Expand All @@ -352,6 +353,15 @@ def get_retrieved_knowledge(self, message):

def _set_query_vec(self, observation: Message) -> Message:
retrieved_docs = self.get_retrieved_knowledge(observation)
if len(retrieved_docs) > self._n_docs:
logging.warning(
f'Your `get_retrieved_knowledge` method returned {len(retrieved_docs)} Documents, '
f'instead of the expected {self._n_docs} (set by `--n-docs`). '
f'This agent will only use the first {self._n_docs} Documents. '
'Consider modifying your implementation of `get_retrieved_knowledge` to avoid unexpected results. '
'(or alternatively you may increase `--n-docs` parameter)'
)
retrieved_docs = retrieved_docs[: self._n_docs]
self.model_api.retriever.add_retrieve_doc(
observation[self._query_key], retrieved_docs
)
Expand All @@ -363,17 +373,60 @@ class WizIntGoldDocRetrieverFiDAgent(GoldDocRetrieverFiDAgent):
Gold knowledge FiD agent for the Wizard of Internet task.
"""

def get_retrieved_knowledge(self, message):
def _extract_doc_from_message(self, message: Message, idx: int):
    """
    Builds a Document from the `idx`-th `__retrieved-docs__` entry of `message`.

    The URL, title and text stored at position `idx` of the corresponding
    message fields become the Document's `docid`, `title` and `text`.
    """
    doc_url = message[consts.RETRIEVED_DOCS_URLS][idx]
    doc_title = message[consts.RETRIEVED_DOCS_TITLES][idx]
    doc_text = message[consts.RETRIEVED_DOCS][idx]
    return Document(docid=doc_url, title=doc_title, text=doc_text)

def get_retrieved_knowledge(self, message: Message):
    """
    Returns up to `self._n_docs` Documents built from the retrieved docs in `message`.

    Documents containing at least one of the selected sentences are chosen
    first (in their original order); remaining slots are filled with the
    other documents, again in original order. The final list is shuffled.

    :param message:
        the observed message; read for the retrieved docs, their URLs and
        titles, and the selected sentences.

    :return:
        a (possibly empty) list of Document objects, in random order.
    """
    retrieved_docs = []
    if not message.get(consts.RETRIEVED_DOCS):
        return retrieved_docs

    doc_texts = message[consts.RETRIEVED_DOCS]
    n_docs_in_message = len(doc_texts)
    selected_sentences = message[consts.SELECTED_SENTENCES]

    # First picking the docs with selected sentences.
    docs_with_selection = []
    if ' '.join(selected_sentences) != consts.NO_SELECTED_SENTENCES_TOKEN:
        docs_with_selection = [
            doc_idx
            for doc_idx, doc_content in enumerate(doc_texts)
            if any(sel_sentc in doc_content for sel_sentc in selected_sentences)
        ]

    # Warn only when docs with selected sentences are really being dropped,
    # not when their count merely reaches the limit.
    if len(docs_with_selection) > self._n_docs:
        logging.warning(
            f'More than {self._n_docs} documents have selected sentences. '
            f'Trimming them to the first {self._n_docs}'
        )
        docs_with_selection = docs_with_selection[: self._n_docs]

    # Then adding other (filler) docs.
    # We add them by iterating forward in the __retrieved-docs__ list for repeatability,
    # but we shuffle the order of the final returned docs, to make sure model doesn't cheat.
    chosen_doc_idx = list(docs_with_selection)
    already_added_doc_idx = set(docs_with_selection)
    for doc_idx in range(n_docs_in_message):
        if len(chosen_doc_idx) >= self._n_docs:
            break
        if doc_idx in already_added_doc_idx:
            continue
        chosen_doc_idx.append(doc_idx)

    retrieved_docs = [
        self._extract_doc_from_message(message, doc_idx)
        for doc_idx in chosen_doc_idx
    ]

    if n_docs_in_message > len(retrieved_docs):
        logging.debug(
            f'Trimmed retrieved docs from {n_docs_in_message} to {len(retrieved_docs)}'
        )
    random.shuffle(retrieved_docs)
    return retrieved_docs


Expand Down

0 comments on commit 825a057

Please sign in to comment.