diff --git a/parlai/tasks/wizard_of_internet/agents.py b/parlai/tasks/wizard_of_internet/agents.py index edab711ca98..d1292c1f09d 100644 --- a/parlai/tasks/wizard_of_internet/agents.py +++ b/parlai/tasks/wizard_of_internet/agents.py @@ -687,6 +687,54 @@ def many_episode_mutation(self, episode): return out_episodes +def chunk_docs_in_message(message, chunk_sz): + if CONST.RETRIEVED_DOCS not in message: + return message + new_message = message.copy() + docs = message.get(CONST.RETRIEVED_DOCS) + titles = message.get(CONST.RETRIEVED_DOCS_TITLES) + urls = message.get(CONST.RETRIEVED_DOCS_URLS) + new_docs = [] + new_titles = [] + new_urls = [] + checked_sentences = message.get(CONST.SELECTED_SENTENCES) + for i in range(len(checked_sentences)): + checked_sentences[i] = checked_sentences[i].lstrip(' ').rstrip(' ') + if ' '.join(checked_sentences) == CONST.NO_SELECTED_SENTENCES_TOKEN: + checked_sentences = [] + for ind in range(len(docs)): + d = docs[ind] + # Guarantees that checked sentences are not split in half (as we split by space). + for i in range(len(checked_sentences)): + d = d.replace(checked_sentences[i], "||CHECKED_SENTENCE_" + str(i) + "||") + while True: + end_chunk = d.find(' ', chunk_sz) + if end_chunk == -1: + # last chunk + for i in range(len(checked_sentences)): + d = d.replace( + "||CHECKED_SENTENCE_" + str(i) + "||", checked_sentences[i] + ) + new_docs.append(d) + new_titles.append(titles[ind]) + new_urls.append(urls[ind]) + break + else: + new_d = d[0:end_chunk] + for i in range(len(checked_sentences)): + new_d = new_d.replace( + "||CHECKED_SENTENCE_" + str(i) + "||", checked_sentences[i] + ) + new_docs.append(new_d) + new_titles.append(titles[ind]) + new_urls.append(urls[ind]) + d = d[end_chunk + 1 : -1] + new_message.force_set(CONST.RETRIEVED_DOCS, new_docs) + new_message.force_set(CONST.RETRIEVED_DOCS_TITLES, new_titles) + new_message.force_set(CONST.RETRIEVED_DOCS_URLS, new_urls) + return new_message + + @register_mutator("woi_chunk_retrieved_docs") class WoiChunkRetrievedDocs(MessageMutator): """ @@ -705,25 +753,8 @@ def add_cmdline_args( ) def message_mutation(self, message: Message) -> Message: - if CONST.RETRIEVED_DOCS not in message: - return message - new_message = message.copy() - docs = message.get(CONST.RETRIEVED_DOCS) - new_docs = [] chunk_sz = self.opt.get('woi_doc_chunk_size') - for doc in docs: - d = doc - while True: - end_chunk = d.find(' ', chunk_sz) - if end_chunk == -1: - # last chunk - new_docs.append(d) - break - else: - new_docs.append(d[0:end_chunk]) - d = d[end_chunk + 1 : -1] - new_message.force_set(CONST.RETRIEVED_DOCS, new_docs) - return new_message + return chunk_docs_in_message(message, chunk_sz) @register_mutator("woi_dropout_retrieved_docs")