From 530de072a56943f18c81ddac29c04cc256eaa065 Mon Sep 17 00:00:00 2001 From: Jing Xu Date: Thu, 14 Apr 2022 14:39:23 -0700 Subject: [PATCH 1/5] clear mapping at the end --- parlai/agents/fid/fid.py | 6 ++++++ parlai/agents/rag/retrievers.py | 5 +---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/parlai/agents/fid/fid.py b/parlai/agents/fid/fid.py index 28ef2638360..a09b23514d4 100644 --- a/parlai/agents/fid/fid.py +++ b/parlai/agents/fid/fid.py @@ -376,6 +376,12 @@ def _set_query_vec(self, observation: Message) -> Message: self.show_observation_to_echo_retriever(observation) super()._set_query_vec(observation) + def eval_step(self, batch): + output = super().eval_step(batch) + if hasattr(self.model_api.retriever, 'clear_mapping'): + self.model_api.retriever.clear_mapping() + return output + class WizIntGoldDocRetrieverFiDAgent(GoldDocRetrieverFiDAgent): """ diff --git a/parlai/agents/rag/retrievers.py b/parlai/agents/rag/retrievers.py index d164c015525..db2bbdc317b 100644 --- a/parlai/agents/rag/retrievers.py +++ b/parlai/agents/rag/retrievers.py @@ -1332,7 +1332,7 @@ def tokenize_query(self, query: str) -> List[int]: def get_delimiter(self) -> str: return self._delimiter - def _clear_mapping(self): + def clear_mapping(self): self._query_ids = dict() self._saved_docs = dict() self._largest_seen_idx = -1 @@ -1354,9 +1354,6 @@ def retrieve_and_score( query.device ) - # empty the 2 mappings after each retrieval - self._clear_mapping() - return retrieved_docs, retrieved_doc_scores From 760395143030e6ae8a5dfec9500adf64728bf77f Mon Sep 17 00:00:00 2001 From: Jing Xu Date: Fri, 15 Apr 2022 09:31:18 -0700 Subject: [PATCH 2/5] apply to train step as well --- parlai/agents/fid/fid.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/parlai/agents/fid/fid.py b/parlai/agents/fid/fid.py index a09b23514d4..63416372d4a 100644 --- a/parlai/agents/fid/fid.py +++ b/parlai/agents/fid/fid.py @@ -376,8 +376,8 @@ def _set_query_vec(self, observation: Message) -> Message: self.show_observation_to_echo_retriever(observation) super()._set_query_vec(observation) - def eval_step(self, batch): - output = super().eval_step(batch) + def batch_act(self, observations): + output = super().batch_act(observations) if hasattr(self.model_api.retriever, 'clear_mapping'): self.model_api.retriever.clear_mapping() return output From 5c49fcc5f69f5b46b812609fec72121ef22bc7fb Mon Sep 17 00:00:00 2001 From: Jing Xu Date: Fri, 15 Apr 2022 09:39:00 -0700 Subject: [PATCH 3/5] dyn --- parlai/agents/fid/fid.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/parlai/agents/fid/fid.py b/parlai/agents/fid/fid.py index 63416372d4a..d98c5086d20 100644 --- a/parlai/agents/fid/fid.py +++ b/parlai/agents/fid/fid.py @@ -348,7 +348,13 @@ def __init__(self, opt: Opt, shared: TShared = None): 'GoldDocRetrieverFiDAgent only works with `rag_retriever_query` being `"full_history"`. ' f'Changing opt value for `rag_retriever_query`: `"{prev_sel}"` -> `"full_history"`' ) - + if ( + opt['dynamic_batching'] == 'full' + or opt.get('eval_dynamic_batching') == 'full' + ): + raise RuntimeError( + "For now dynamic batching doesn't work with ObservationEchoRetriever as it cleans up _saved_docs mapping after each batch act." + ) super().__init__(opt, shared=shared) @abstractmethod @@ -377,10 +383,13 @@ def _set_query_vec(self, observation: Message) -> Message: super()._set_query_vec(observation) def batch_act(self, observations): - output = super().batch_act(observations) + """ + Clear the _saved_docs and _query_ids mappings in ObservationEchoRetriever + """ + batch_reply = super().batch_act(observations) if hasattr(self.model_api.retriever, 'clear_mapping'): self.model_api.retriever.clear_mapping() - return output + return batch_reply class WizIntGoldDocRetrieverFiDAgent(GoldDocRetrieverFiDAgent): From 40e1a6c6f3f15a786521c2bef9c101af54d11a01 Mon Sep 17 00:00:00 2001 From: Jing Xu Date: Fri, 15 Apr 2022 10:24:08 -0700 Subject: [PATCH 4/5] comment --- parlai/agents/fid/fid.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/parlai/agents/fid/fid.py b/parlai/agents/fid/fid.py index d98c5086d20..80934b68103 100644 --- a/parlai/agents/fid/fid.py +++ b/parlai/agents/fid/fid.py @@ -348,9 +348,9 @@ def __init__(self, opt: Opt, shared: TShared = None): 'GoldDocRetrieverFiDAgent only works with `rag_retriever_query` being `"full_history"`. ' f'Changing opt value for `rag_retriever_query`: `"{prev_sel}"` -> `"full_history"`' ) - if ( - opt['dynamic_batching'] == 'full' - or opt.get('eval_dynamic_batching') == 'full' + if not ( + opt['dynamic_batching'] in [None, 'off'] + and opt.get('eval_dynamic_batching') in [None, 'off'] ): raise RuntimeError( "For now dynamic batching doesn't work with ObservationEchoRetriever as it cleans up _saved_docs mapping after each batch act." From 298d93e9de08d05fa9f3895a22d4597013e56db4 Mon Sep 17 00:00:00 2001 From: Jing Xu Date: Fri, 15 Apr 2022 10:28:38 -0700 Subject: [PATCH 5/5] black --- parlai/agents/fid/fid.py | 2 +- parlai/agents/rag/retrievers.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/parlai/agents/fid/fid.py b/parlai/agents/fid/fid.py index 80934b68103..b0c5c1bf1a3 100644 --- a/parlai/agents/fid/fid.py +++ b/parlai/agents/fid/fid.py @@ -384,7 +384,7 @@ def _set_query_vec(self, observation: Message) -> Message: def batch_act(self, observations): """ - Clear the _saved_docs and _query_ids mappings in ObservationEchoRetriever + Clear the _saved_docs and _query_ids mappings in ObservationEchoRetriever. """ batch_reply = super().batch_act(observations) if hasattr(self.model_api.retriever, 'clear_mapping'): diff --git a/parlai/agents/rag/retrievers.py b/parlai/agents/rag/retrievers.py index db2bbdc317b..7828ab40df1 100644 --- a/parlai/agents/rag/retrievers.py +++ b/parlai/agents/rag/retrievers.py @@ -1413,7 +1413,7 @@ def get_top_chunks( doc_url: str, ): """ - Return chunks according to the woi_chunk_retrieved_docs_mutator + Return chunks according to the woi_chunk_retrieved_docs_mutator. """ if isinstance(doc_chunks, list): docs = ''.join(doc_chunks)