Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[GoldDocFidAgent] clear mapping at the end of each eval_step rather than retrieve step #4503

Merged
merged 5 commits into from
Apr 20, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion parlai/agents/fid/fid.py
Original file line number Diff line number Diff line change
@@ -348,7 +348,13 @@ def __init__(self, opt: Opt, shared: TShared = None):
'GoldDocRetrieverFiDAgent only works with `rag_retriever_query` being `"full_history"`. '
f'Changing opt value for `rag_retriever_query`: `"{prev_sel}"` -> `"full_history"`'
)

if not (
opt['dynamic_batching'] in [None, 'off']
and opt.get('eval_dynamic_batching') in [None, 'off']
):
raise RuntimeError(
"For now dynamic batching doesn't work with ObservationEchoRetriever as it cleans up _saved_docs mapping after each batch act."
)
super().__init__(opt, shared=shared)

@abstractmethod
@@ -376,6 +382,15 @@ def _set_query_vec(self, observation: Message) -> Message:
self.show_observation_to_echo_retriever(observation)
super()._set_query_vec(observation)

def batch_act(self, observations):
"""
Clear the _saved_docs and _query_ids mappings in ObservationEchoRetriever.
"""
batch_reply = super().batch_act(observations)
if hasattr(self.model_api.retriever, 'clear_mapping'):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like a good solution. But what about cleaning on batchify that you suggested?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realize I can't do that in batchify, since model act flow is observe (insert k,v pairs to the _saved_doc and _query_ids) -> batchify -> act;
if I clear up the mapping during batchify then that literally clear up all the newly inserted kv pairs that are from the same batch.

Also notice another problem with this change (it won't reset the mapping in train step), let me update the pr and ping you guys again after it's ready.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if we clear after act/batch_act from the FiD agent?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea that's my sol rn.

self.model_api.retriever.clear_mapping()
return batch_reply


class WizIntGoldDocRetrieverFiDAgent(GoldDocRetrieverFiDAgent):
"""
7 changes: 2 additions & 5 deletions parlai/agents/rag/retrievers.py
Original file line number Diff line number Diff line change
@@ -1332,7 +1332,7 @@ def tokenize_query(self, query: str) -> List[int]:
def get_delimiter(self) -> str:
return self._delimiter

def _clear_mapping(self):
def clear_mapping(self):
self._query_ids = dict()
self._saved_docs = dict()
self._largest_seen_idx = -1
@@ -1354,9 +1354,6 @@ def retrieve_and_score(
query.device
)

# empty the 2 mappings after each retrieval
self._clear_mapping()

return retrieved_docs, retrieved_doc_scores


@@ -1416,7 +1413,7 @@ def get_top_chunks(
doc_url: str,
):
"""
Return chunks according to the woi_chunk_retrieved_docs_mutator
Return chunks according to the woi_chunk_retrieved_docs_mutator.
"""
if isinstance(doc_chunks, list):
docs = ''.join(doc_chunks)