Skip to content
This repository was archived by the owner on Nov 3, 2023. It is now read-only.

[GoldDocFidAgent] clear mapping at the end of each eval_step rather than retrieve step #4503

Merged
merged 5 commits into from
Apr 20, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion parlai/agents/fid/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,13 @@ def __init__(self, opt: Opt, shared: TShared = None):
'GoldDocRetrieverFiDAgent only works with `rag_retriever_query` being `"full_history"`. '
f'Changing opt value for `rag_retriever_query`: `"{prev_sel}"` -> `"full_history"`'
)

if (
opt['dynamic_batching'] == 'full'
or opt.get('eval_dynamic_batching') == 'full'
):
raise RuntimeError(
"For now dynamic batching doesn't work with ObservationEchoRetriever as it cleans up _saved_docs mapping after each batch act."
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note that dynamic_batching can also be 'batchsort'

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing that out. fixed

super().__init__(opt, shared=shared)

@abstractmethod
Expand Down Expand Up @@ -376,6 +382,15 @@ def _set_query_vec(self, observation: Message) -> Message:
self.show_observation_to_echo_retriever(observation)
super()._set_query_vec(observation)

def batch_act(self, observations):
"""
Clear the _saved_docs and _query_ids mappings in ObservationEchoRetriever
"""
batch_reply = super().batch_act(observations)
if hasattr(self.model_api.retriever, 'clear_mapping'):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like a good solution. But what about cleaning on batchify that you suggested?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realize I can't do that in batchify, since model act flow is observe (insert k,v pairs to the _saved_doc and _query_ids) -> batchify -> act;
if I clear up the mapping during batchify then that literally clear up all the newly inserted kv pairs that are from the same batch.

Also notice another problem with this change (it won't reset the mapping in train step), let me update the pr and ping you guys again after it's ready.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if we clear after act/batch_act from the FiD agent?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea that's my sol rn.

self.model_api.retriever.clear_mapping()
return batch_reply


class WizIntGoldDocRetrieverFiDAgent(GoldDocRetrieverFiDAgent):
"""
Expand Down
5 changes: 1 addition & 4 deletions parlai/agents/rag/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1332,7 +1332,7 @@ def tokenize_query(self, query: str) -> List[int]:
def get_delimiter(self) -> str:
return self._delimiter

def _clear_mapping(self):
def clear_mapping(self):
self._query_ids = dict()
self._saved_docs = dict()
self._largest_seen_idx = -1
Expand All @@ -1354,9 +1354,6 @@ def retrieve_and_score(
query.device
)

# empty the 2 mappings after each retrieval
self._clear_mapping()

return retrieved_docs, retrieved_doc_scores


Expand Down