From ad8dcdf6e638d7a4230d2009d9275cef48fa32ea Mon Sep 17 00:00:00 2001 From: Jason Weston Date: Wed, 8 Sep 2021 10:13:52 -0400 Subject: [PATCH 1/5] wizint knol pred teacher --- parlai/tasks/wizard_of_internet/agents.py | 36 +++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/parlai/tasks/wizard_of_internet/agents.py b/parlai/tasks/wizard_of_internet/agents.py index 2790fff20e7..81007a561b1 100644 --- a/parlai/tasks/wizard_of_internet/agents.py +++ b/parlai/tasks/wizard_of_internet/agents.py @@ -571,3 +571,39 @@ def _knowledge_piece(self): class GoldDocTitlesTeacher(BaseKnowledgeTeacher): def _knowledge_piece(self): return CONST.SELECTED_DOCS_TITLES + + +class PredictKnowledgeGivenLabelTeacher(WizardOfInternetBaseTeacher): + def __init__(self, opt, shared=None): + super().__init__(opt, shared=shared) + self.id = 'PredictKnowledgeGivenLabelTeacher' + + def _teacher_action_type(self) -> str: + return CONST.ACTION_WIZARD_DOC_SELECTION + + def _knowledge_piece(self): + return CONST.SELECTED_SENTENCES + + def additional_message_content(self, parlai_message: Message, action: Dict): + for item_key in ( + CONST.SELECTED_DOCS, + CONST.SELECTED_DOCS_TITLES, + CONST.SELECTED_SENTENCES, + ): + parlai_message[item_key] = action[item_key] + + def create_parlai_message(self, dict_message: Dict): + parlai_msg = Message( + { + CONST.SPEAKER_ID: dict_message[CONST.SPEAKER_ID], + # CONST.MESSAGE_TEXT: dict_message[CONST.MESSAGE_TEXT] + "\n _label_ " + + CONST.LABELS: [' '.join(dict_message[CONST.SELECTED_SENTENCES])], + } + ) + prv_msg = dict_message.get(CONST.PARTNER_PREVIOUS_MESSAGE) + label = '\n_label_ ' + dict_message[CONST.MESSAGE_TEXT] + if prv_msg: + parlai_msg[CONST.MESSAGE_TEXT] = prv_msg[1][CONST.MESSAGE_TEXT] + label + else: + parlai_msg[CONST.MESSAGE_TEXT] = label + return parlai_msg From f3abd877dc2e71b04a26d3392bcdf04b6b118896 Mon Sep 17 00:00:00 2001 From: Jason Weston Date: Wed, 8 Sep 2021 10:47:46 -0400 Subject: [PATCH 2/5] change to mutator --- parlai/tasks/wizard_of_internet/agents.py | 49 +++++++++-------------- 1 file changed, 19 insertions(+), 30 deletions(-) diff --git a/parlai/tasks/wizard_of_internet/agents.py b/parlai/tasks/wizard_of_internet/agents.py index 81007a561b1..8bc38825836 100644 --- a/parlai/tasks/wizard_of_internet/agents.py +++ b/parlai/tasks/wizard_of_internet/agents.py @@ -18,6 +18,7 @@ from parlai.utils.data import DatatypeHelper import parlai.utils.logging as logging import parlai.tasks.wizard_of_internet.constants as CONST +from parlai.core.mutators import register_mutator, MessageMutator from .build import build @@ -573,37 +574,25 @@ def _knowledge_piece(self): return CONST.SELECTED_DOCS_TITLES -class PredictKnowledgeGivenLabelTeacher(WizardOfInternetBaseTeacher): - def __init__(self, opt, shared=None): - super().__init__(opt, shared=shared) - self.id = 'PredictKnowledgeGivenLabelTeacher' +@register_mutator("add_checked_sentence") +class AddCheckedSentence(MessageMutator): + """ + Adds the checked sentences as the label, and the label to the end of text. + E.g. run with: parlai display_data -t wizard_of_internet -n 100 -dt valid --mutators flatten,add_checked_sentence + """ - def _teacher_action_type(self) -> str: - return CONST.ACTION_WIZARD_DOC_SELECTION + def message_mutation(self, message: Message) -> Message: + original_message = message.copy() + try: + text = message.pop('text') + label = message.pop('labels')[0] + checked_sentence = ' '.join(message.get(CONST.SELECTED_SENTENCES, '')) - def _knowledge_piece(self): - return CONST.SELECTED_SENTENCES + text += f'\n_label_ {label}' + message['text'] = text - def additional_message_content(self, parlai_message: Message, action: Dict): - for item_key in ( - CONST.SELECTED_DOCS, - CONST.SELECTED_DOCS_TITLES, - CONST.SELECTED_SENTENCES, - ): - parlai_message[item_key] = action[item_key] + message['labels'] = [checked_sentence] + except KeyError: + return original_message - def create_parlai_message(self, dict_message: Dict): - parlai_msg = Message( - { - CONST.SPEAKER_ID: dict_message[CONST.SPEAKER_ID], - # CONST.MESSAGE_TEXT: dict_message[CONST.MESSAGE_TEXT] + "\n _label_ " + - CONST.LABELS: [' '.join(dict_message[CONST.SELECTED_SENTENCES])], - } - ) - prv_msg = dict_message.get(CONST.PARTNER_PREVIOUS_MESSAGE) - label = '\n_label_ ' + dict_message[CONST.MESSAGE_TEXT] - if prv_msg: - parlai_msg[CONST.MESSAGE_TEXT] = prv_msg[1][CONST.MESSAGE_TEXT] + label - else: - parlai_msg[CONST.MESSAGE_TEXT] = label - return parlai_msg + return message From c3efe23228f737210347eef16ebae625684facbf Mon Sep 17 00:00:00 2001 From: Jason Weston Date: Wed, 8 Sep 2021 11:48:59 -0400 Subject: [PATCH 3/5] lm teacher as well --- parlai/tasks/wizard_of_internet/agents.py | 44 ++++++++++++++++++++--- 1 file changed, 40 insertions(+), 4 deletions(-) diff --git a/parlai/tasks/wizard_of_internet/agents.py b/parlai/tasks/wizard_of_internet/agents.py index 8bc38825836..53d9b4dc72d 100644 --- a/parlai/tasks/wizard_of_internet/agents.py +++ b/parlai/tasks/wizard_of_internet/agents.py @@ -20,6 +20,7 @@ import parlai.tasks.wizard_of_internet.constants as CONST from parlai.core.mutators import register_mutator, MessageMutator +import random from .build import build @@ -574,11 +575,13 @@ def _knowledge_piece(self): return CONST.SELECTED_DOCS_TITLES -@register_mutator("add_checked_sentence") -class AddCheckedSentence(MessageMutator): +@register_mutator("checked_sentence_as_label") +class CheckedSentenceAsLabel(MessageMutator): """ - Adds the checked sentences as the label, and the label to the end of text. - E.g. run with: parlai display_data -t wizard_of_internet -n 100 -dt valid --mutators flatten,add_checked_sentence + Sets the checked sentences as the label, and the label to the end of text. + + E.g. run with: parlai display_data -t wizard_of_internet -n 100 -dt valid --mutators + flatten,checked_sentence_as_label """ def message_mutation(self, message: Message) -> Message: @@ -596,3 +599,36 @@ def message_mutation(self, message: Message) -> Message: return original_message return message + + +@register_mutator("checked_sentence_as_label_lm") +class CheckedSentenceAsLabelLm(MessageMutator): + """ + Sets the checked sentences as the label, and the label to the end of text. + Language modeling version where a random piece of the label is sampled in the input. + + E.g. run with: parlai display_data -t wizard_of_internet -n 100 -dt valid --mutators + flatten,checked_sentence_as_label_lm + """ + + def message_mutation(self, message: Message) -> Message: + original_message = message.copy() + try: + text = message.pop('text') + label = message.pop('labels')[0] + checked_sentence = ' '.join(message.get(CONST.SELECTED_SENTENCES, '')) + + ls = label.split(' ') + ind = random.randint(0, len(ls) - 1) + + label1 = ' '.join(ls[0:ind]) + label2 = ' '.join(ls[ind : len(ls)]) + + text += f'{label1}\n_label_ {label2}' + message['text'] = text + + message['labels'] = [checked_sentence] + except KeyError: + return original_message + + return message From 34f66435708cfab92ca187ca4a40a235f085930b Mon Sep 17 00:00:00 2001 From: Jason Weston Date: Wed, 8 Sep 2021 11:50:58 -0400 Subject: [PATCH 4/5] lm teacher as well --- parlai/tasks/wizard_of_internet/agents.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/parlai/tasks/wizard_of_internet/agents.py b/parlai/tasks/wizard_of_internet/agents.py index 53d9b4dc72d..3692468617b 100644 --- a/parlai/tasks/wizard_of_internet/agents.py +++ b/parlai/tasks/wizard_of_internet/agents.py @@ -604,8 +604,8 @@ def message_mutation(self, message: Message) -> Message: @register_mutator("checked_sentence_as_label_lm") class CheckedSentenceAsLabelLm(MessageMutator): """ - Sets the checked sentences as the label, and the label to the end of text. - Language modeling version where a random piece of the label is sampled in the input. + Sets the checked sentences as the label, and the label to the end of text. Language + modeling version where a random piece of the label is sampled in the input. E.g. run with: parlai display_data -t wizard_of_internet -n 100 -dt valid --mutators flatten,checked_sentence_as_label_lm From 7fee28591d656122733f750f0b988f82ff723af0 Mon Sep 17 00:00:00 2001 From: Jason Weston Date: Wed, 8 Sep 2021 18:47:16 -0400 Subject: [PATCH 5/5] label name change --- parlai/tasks/wizard_of_internet/agents.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/parlai/tasks/wizard_of_internet/agents.py b/parlai/tasks/wizard_of_internet/agents.py index 3692468617b..74dc8509c13 100644 --- a/parlai/tasks/wizard_of_internet/agents.py +++ b/parlai/tasks/wizard_of_internet/agents.py @@ -591,7 +591,7 @@ def message_mutation(self, message: Message) -> Message: label = message.pop('labels')[0] checked_sentence = ' '.join(message.get(CONST.SELECTED_SENTENCES, '')) - text += f'\n_label_ {label}' + text += f'\n__label__ {label} __endlabel__' message['text'] = text message['labels'] = [checked_sentence] @@ -624,7 +624,7 @@ def message_mutation(self, message: Message) -> Message: label1 = ' '.join(ls[0:ind]) label2 = ' '.join(ls[ind : len(ls)]) - text += f'{label1}\n_label_ {label2}' + text += f'{label1}\n__label__ {label2} __endlabel__' message['text'] = text message['labels'] = [checked_sentence]