From 1fa8f7a8034a394149b1c9f7ead2a021bf08518a Mon Sep 17 00:00:00 2001 From: Jing Date: Tue, 23 Nov 2021 15:57:09 -0500 Subject: [PATCH] flush (#4198) --- projects/blenderbot2/agents/modules.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/projects/blenderbot2/agents/modules.py b/projects/blenderbot2/agents/modules.py index e0c12514ef3..0f8343e3f42 100644 --- a/projects/blenderbot2/agents/modules.py +++ b/projects/blenderbot2/agents/modules.py @@ -295,6 +295,14 @@ def get_retrieval_indices( assert self.knowledge_access_method is KnowledgeAccessMethod.CLASSIFY return type_indices + def flush_previous_retriever_search_results(self): + if not hasattr(self, 'retriever'): + return + if hasattr(self.retriever, 'top_docs'): + delattr(self.retriever, 'top_docs') + if hasattr(self.retriever, 'search_queries'): + delattr(self.retriever, 'search_queries') + def retrieve_and_concat( self, input: torch.LongTensor, @@ -314,6 +322,7 @@ def retrieve_and_concat( Override RagModel.retrieve_and_concat to perform different retrieval, depending on the RetrieverType. """ + self.flush_previous_retriever_search_results() start = time.time() logging.debug(f'Begin encoder: {time.time() - start:.2f}') if input_turns_cnt is not None: