diff --git a/projects/README.md b/projects/README.md
index e85f3bab9d0..d40737da8b0 100644
--- a/projects/README.md
+++ b/projects/README.md
@@ -113,7 +113,7 @@ _Task & models for chitchat with a given persona._
 - **SeeKeR:** [[project]](http://parl.ai/projects/seeker)
 _Modular open source search-augmented language model._
 
-- **Reason first, then respond:** [[paper]](https://arxiv.org/abs/2111.05204)
+- **Reason first, then respond:** [[project]](https://parl.ai/projects/k2r/) [[paper]](https://arxiv.org/abs/2111.05204)
 _A modular Generation method for Knowledge-infused Dialogue._
 
 - **Internet-Augmented Dialogue Generation** [[project]](http://parl.ai/projects/sea) [[paper]](https://arxiv.org/abs/2107.07566).
 _Utilizing a search-engine for open domain chitchat task & models._
diff --git a/projects/k2r/README.md b/projects/k2r/README.md
new file mode 100644
index 00000000000..e2936cf7485
--- /dev/null
+++ b/projects/k2r/README.md
@@ -0,0 +1,53 @@
+# Reason first, then respond: Modular Generation for Knowledge-infused Dialogue
+Leonard Adolphs, Kurt Shuster, Jack Urbanek, Arthur Szlam, Jason Weston
+
+Paper Link: [https://arxiv.org/abs/2111.05204](https://arxiv.org/abs/2111.05204)
+
+## Abstract
+Large language models can produce fluent dialogue but often hallucinate factual inaccuracies. While retrieval-augmented models help alleviate this issue, they still face a difficult challenge of both reasoning to provide correct knowledge and generating conversation simultaneously. In this work, we propose a modular model, Knowledge to Response (K2R), for incorporating knowledge into conversational agents, which breaks down this problem into two easier steps. K2R first generates a knowledge sequence, given a dialogue context, as an intermediate step. After this "reasoning step", the model then attends to its own generated knowledge sequence, as well as the dialogue context, to produce a final response. In detailed experiments, we find that such a model hallucinates less in knowledge-grounded dialogue tasks, and has advantages in terms
+of interpretability and modularity.
+In particular, it can be used to fuse QA and dialogue systems together to enable dialogue agents to give knowledgeable answers, or QA models to give conversational responses in a zero-shot setting.
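+
+At a glance, K2R chains two generators. The sketch below is illustrative only:
+`knowledge_model` and `dialogue_model` stand in for any trained seq2seq models,
+while the actual pipeline is implemented by `StackedKnowledgeDialogueAgent` in
+`projects/k2r/stacked_agent/task/agents.py`.
+
+```python
+TOKEN_KNOWLEDGE, TOKEN_END_KNOWLEDGE = '__knowledge__', '__endknowledge__'
+
+
+def k2r_respond(context: str, knowledge_model, dialogue_model) -> str:
+    # Step 1 ("reason"): generate an intermediate knowledge sequence.
+    knowledge = knowledge_model.generate(context)
+    # Step 2 ("respond"): condition on the context *and* the generated
+    # knowledge, appended between the knowledge delimiter tokens.
+    conditioned = f'{context}\n{TOKEN_KNOWLEDGE} {knowledge} {TOKEN_END_KNOWLEDGE}'
+    return dialogue_model.generate(conditioned)
+```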
+
+
+## Train a shared K2R model on WoW
+```
+parlai train \
+    -t projects.k2r.wow.task.agents:WizardOfWikipediaGeneratorTeacher:mutators=flatten+wow_checked_sentence_as_label,projects.k2r.wow.task.agents:WizardOfWikipediaGeneratorTeacher:mutators=flatten+wow_add_checked_sentence_to_input \
+    --multitask_weights 1,1 --activation gelu --attention-dropout 0.0 --batchsize 16 --dropout 0.1 --fp16 True --gradient-clip 0.1 --label-truncate 128 \
+    --text-truncate 512 --log-every-n-secs 30 --lr-scheduler reduceonplateau --lr-scheduler-patience 1 --max-train-time 169344.0 --model-parallel True \
+    --model rag -o arch/bart_large --init-model zoo:bart/bart_large/model --dict-file zoo:bart/bart_large/model.dict --warmup-updates 0 \
+    --multitask-weights stochastic --relu-dropout 0.0 --save-after-valid True --skip-generation True -lr 1e-05 -vmm min -veps 0.25 -vme 1000 \
+    -vmt ppl -vp 5 --n-docs 5 -tblog True --indexer-type compressed --compressed-indexer-nprobe 128 \
+    --model-file ./models/wow/k2r_shared
+```
+
+## Evaluate the model on WoW
+```
+parlai em \
+    -t projects.k2r.wow.task.agents:WizardOfWikipediaGeneratorTeacher:random_split \
+    -m projects.k2r.stacked_agent.task.agents:StackedKnowledgeDialogueAgent \
+    --knowledge-response-model-path ./models/wow/k2r_shared \
+    --dialogue-response-model-path ./models/wow/k2r_shared \
+    --dialogue-response-no-knowledge-model-path None \
+    --dialogue-response-rag-wiki-model-path None \
+    --mutators flatten -dt valid --krm-fp16 False --krm-model-parallel False --drm-model-parallel False --krm-beam-min-length 15 \
+    --krm-beam-size 3 --krm-indexer-type compressed --krm-compressed-indexer-nprobe 128 --krm-n-docs 5 --drm-beam-size 3 --drm-beam-min-length 20 --batchsize 2 --log-every-n-secs 30 --metrics all
+```
+
+## Do interactive generations with the model
+```
+python projects/k2r/stacked_agent/scripts/stacked_agent_eval.py \
+    --task wizard_of_wikipedia:Generator -dt test -bs 1 -n 100 \
+    --interactive true --mutators flatten --random-order false --verbose true \
+    --drm-beam-context-block-ngram 3 --beam-disregard-knowledge-for-context-blocking false \
+    --knowledge-response-model-path ./models/wow/k2r_shared \
+    --dialogue-response-model-path ./models/wow/k2r_shared
+```
+
+## LightQA data
+Our goal with LightQA is to have a task that requires a model to answer questions *about the previous context*. For example, in LIGHT, a player might ask another character where to find a certain key to complete their quest. Here, we would want a model, acting as the character, to answer appropriately if the knowledge is in the context description. With this goal in mind, we design a dataset in the following way: First, we take a LightWild episode and use an abstractive summarization model, trained on CNN/Daily Mail and the SAMSum Corpus, to generate a summary. Then we identify all noun chunks, entities, and proper nouns and use them as possible answer candidates. For each answer candidate, we use a T5 question generation model, trained on SQuAD, to generate a possible question given the summary as context. As the last step, we filter the generated questions with a QA model, trained on SQuAD, by checking that it regenerates the answer candidate when given the summary and the question. An episode of our dataset consists of the original LightWild episode (up to a certain turn) and the generated question as the last utterance. Hence, our labels in this dataset are not the usual dialogue responses but short answers.
+```
+# Display the data.
+parlai dd -t projects.k2r.lightqa.task.agents -dt valid
+```
+
diff --git a/projects/k2r/lightqa/task/__init__.py b/projects/k2r/lightqa/task/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/projects/k2r/lightqa/task/agents.py b/projects/k2r/lightqa/task/agents.py
new file mode 100644
index 00000000000..6854144a098
--- /dev/null
+++ b/projects/k2r/lightqa/task/agents.py
@@ -0,0 +1,76 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import json
+import os
+
+from parlai.core.teachers import DialogTeacher
+from parlai.utils.io import PathManager
+from parlai.core.message import Message
+from parlai.core.metrics import F1Metric, normalize_answer, AverageMetric
+
+from .build import build
+
+
+class SummaryQATeacher(DialogTeacher):
+    """
+    Teacher for the SummaryQA dataset.
+    """
+
+    def __init__(self, opt, shared=None):
+        self.datatype = opt['datatype'].split(':')[0]
+        build(opt)
+        opt['datafile'] = os.path.join(
+            opt['datapath'], f'lightqa/lightqa-wild-summaryqa2-{self.datatype}.json'
+        )
+        self.id = 'summaryqa'
+        super().__init__(opt, shared)
+
+    def setup_data(self, path):
+        print('loading: ' + path)
+        with PathManager.open(path) as data_file:
+            self.episodes = json.load(data_file)
+        for ex in self.episodes:
+            episode_done = ex.pop('episode_done')
+            yield ex, episode_done
+
+    def custom_evaluation(
+        self, teacher_action: Message, labels, model_response: Message
+    ):
+        if 'text' in model_response and model_response['text']:
+            normalized_response = normalize_answer(model_response['text'])
+
+            if labels:
+                normalized_labels = [normalize_answer(a) for a in labels]
+                self.metrics.add(
+                    'norm_f1',
+                    F1Metric.compute(normalized_response, normalized_labels),
+                )
+                self.metrics.add(
+                    'norm_em',
+                    AverageMetric(int(normalized_response in normalized_labels)),
+                )
+                self.metrics.add(
+                    'kaa',
+                    AverageMetric(
+                        int(any([l in normalized_response for l in normalized_labels]))
+                    ),
+                )
+
+            if 'knowledge_response' in model_response:
+                # Is the predicted knowledge response in the dialogue response?
+                self.metrics.add(
+                    'pkaa',
+                    AverageMetric(
+                        int(
+                            normalize_answer(model_response['knowledge_response'])
+                            in normalized_response
+                        )
+                    ),
+                )
+
+
+class DefaultTeacher(SummaryQATeacher):
+    pass
diff --git a/projects/k2r/lightqa/task/build.py b/projects/k2r/lightqa/task/build.py
new file mode 100644
index 00000000000..95a8c9b6ae0
--- /dev/null
+++ b/projects/k2r/lightqa/task/build.py
@@ -0,0 +1,61 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Download and build the data if it does not exist.
+""" + +from parlai.core.build_data import DownloadableFile +import parlai.core.build_data as build_data +import os +from shutil import copyfile + + +RESOURCES = [ + DownloadableFile( + 'http://parl.ai/downloads/light_project/k2r/light_dialog_wild_summaryqa2_train.json', + 'lightqa-wild-summaryqa2-train.json', + '0c618e0736317fbb9a688f82777165675b5967ffc5208041da940a3e3a947d25', + zipped=False, + ), + DownloadableFile( + 'http://parl.ai/downloads/light_project/k2r/light_dialog_wild_summaryqa2_valid.json', + 'lightqa-wild-summaryqa2-valid.json', + '3646ff1e6549ec82588caaf7da998ef18df629cacdde43d8ce813df545aabe6c', + zipped=False, + ), + DownloadableFile( + 'http://parl.ai/downloads/light_project/k2r/light_dialog_wild_summaryqa2_test.json', + 'lightqa-wild-summaryqa2-test.json', + '70804bd77fe7568326a1e229b3ece578cd1867c3e0e8a14fef23faf4e2032f14', + zipped=False, + ), +] + + +def build(opt): + version = 'v1.0.0' + dpath = os.path.join(opt['datapath'], 'lightqa') + + if not build_data.built(dpath, version): + print('[building data: ' + dpath + ']') + if build_data.built(dpath): + # An older version exists, so remove these outdated files. + build_data.remove_dir(dpath) + build_data.make_dir(dpath) + + # Download the data. + for downloadable_file in RESOURCES: + if downloadable_file.url.startswith('/checkpoint'): + copyfile( + downloadable_file.url, + os.path.join(dpath, downloadable_file.file_name), + ) + else: + downloadable_file.download_file(dpath) + + # Mark the data as built. + build_data.mark_done(dpath, version) diff --git a/projects/k2r/stacked_agent/scripts/stacked_agent_eval.py b/projects/k2r/stacked_agent/scripts/stacked_agent_eval.py new file mode 100644 index 00000000000..add06c2e3d6 --- /dev/null +++ b/projects/k2r/stacked_agent/scripts/stacked_agent_eval.py @@ -0,0 +1,280 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Evaluate the stacked models. +""" + +from projects.k2r.stacked_agent.task.agents import ( + StackedKnowledgeDialogueAgent, +) +from parlai.core.params import ParlaiParser +from parlai.core.script import ParlaiScript +from parlai.utils.strings import colorize +from parlai.core.worlds import _create_task_agents + +import random +import json +from copy import deepcopy + +from parlai.tasks.wizard_of_wikipedia.agents import ( + TOKEN_KNOWLEDGE, + TOKEN_END_KNOWLEDGE, +) + + +def setup_args(parser=None): + if parser is None: + parser = ParlaiParser(True, True, 'Print out examples for merged model.') + parser.add_argument('-n', '--num-examples', type=int, default=5) + parser.add_argument( + '--interactive', + type=bool, + default=True, + help='Interactively choose the knowledge response.', + ) + parser.add_argument( + '--save-file', + type=str, + default='', + help='Save the responses as json file.', + ) + parser.add_argument( + '--random-order', + type=bool, + default=True, + help='Go through the examples in random order.', + ) + parser.add_argument( + '--verbose', + type=bool, + default=True, + help='Print the examples.', + ) + StackedKnowledgeDialogueAgent.add_cmdline_args(parser) + return parser + + +def model_output(opt): + # Init teacher and agent. 
+    teacher = _create_task_agents(opt)[0]
+    stacked_agent = StackedKnowledgeDialogueAgent(opt)
+
+    result = []
+
+    entry_idx = 0
+    episode_idx = 0
+    seen_episodes = 0
+    blocked_episode_prefixes = []
+    while seen_episodes < opt['num_examples']:
+
+        if opt['random_order']:
+            # Hack to go through examples more randomly.
+            for _ in range(random.randint(10, 30)):
+                while not teacher.act()['episode_done']:
+                    pass
+
+        episode_done = False
+        while not episode_done:
+            # Interaction between teacher and agent.
+            query = teacher.act()
+            if blocked_episode_prefixes and any(
+                [query['text'].startswith(pref) for pref in blocked_episode_prefixes]
+            ):
+                continue
+            stacked_agent.observe(query)
+            reply = stacked_agent.act()
+            if episode_idx != teacher.episode_idx:
+                entry_idx = 0
+                seen_episodes += 1
+                episode_idx = teacher.episode_idx
+            if episode_idx == 1 or episode_idx % 5 == 0:
+                print(f'Episode {episode_idx}/{opt["num_examples"]}')
+            episode_done = query['episode_done']
+
+            # Get the gold labels for both the knowledge and dialogue response.
+            label_key = 'eval_labels' if 'eval_labels' in query else 'labels'
+
+            knowledge_response_target_kwords = [
+                'knowledge_target',
+                'knowledge_answer',
+                'checked_sentence',
+                '__selected-sentences__',
+            ]
+            target_knowledge_response = ''
+            for kword in knowledge_response_target_kwords:
+                if kword in query:
+                    target_knowledge_response = query[kword]
+                    break
+            if 'SummaryQATeacher' in opt['task'] and not target_knowledge_response:
+                target_knowledge_response = query['eval_labels'][0]
+            if isinstance(target_knowledge_response, list):
+                target_knowledge_response = ' '.join(target_knowledge_response)
+
+            target_dialogue_response = reply.get(label_key, '')
+            if not target_dialogue_response:
+                target_dialogue_response = query.get('dialogue_response', '')
+            if isinstance(target_dialogue_response, list):
+                target_dialogue_response = ' '.join(target_dialogue_response)
+            if (
+                (
+                    'lightqa_labeltype' in opt
+                    and opt['lightqa_labeltype'] == 'dialogue_response'
+                    and 'eval_labels' in query
+                    and query['eval_labels']
+                )
+                or 'wizard_of_wikipedia' in opt['task']
+                or 'light_dialog_wild' in opt['task']
+            ):
+                target_dialogue_response = query['eval_labels'][0]
+
+            knowledge_response = reply.get('knowledge_response', '')
+            dialogue_response = reply.get('text', '')
+
+            result.append(
+                dict(
+                    episode_idx=episode_idx,
+                    entry_idx=entry_idx,
+                    context=query['text'],
+                    knowledge_response=knowledge_response,
+                    target_knowledge_response=target_knowledge_response,
+                    dialogue_response=dialogue_response,
+                    target_dialogue_response=target_dialogue_response,
+                )
+            )
+            entry_idx += 1
+
+            if opt['verbose']:
+                # Print the history and the predicted knowledge.
+                print('\n', query['text'])
+                print(
+                    ' Knowledge Prediction: '
+                    + colorize(knowledge_response, 'green')
+                    + ' Gold: '
+                    + colorize(target_knowledge_response, 'yellow')
+                )
+                if 'support_sentence' in reply:
+                    print(
+                        ' Support Sentence: '
+                        + colorize(reply.get('support_sentence', ''), 'green')
+                    )
+
+                if opt['interactive']:
+                    cont_choice = 'r'
+                    while cont_choice == 'r':
+                        # Let the user choose the conditioning knowledge.
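+                        # K2R's modularity makes the knowledge step swappable:
+                        # the string typed here replaces the model's predicted
+                        # knowledge, and the dialogue model regenerates its
+                        # response conditioned on it.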
+                        user_input = input(
+                            colorize(
+                                'Type in knowledge to condition generation: ', 'red'
+                            )
+                        )
+
+                        knowledge_infused_observation = deepcopy(
+                            stacked_agent.observations['raw']
+                        )
+                        text = knowledge_infused_observation.pop('text')
+                        text += (
+                            f'\n{TOKEN_KNOWLEDGE} {user_input} {TOKEN_END_KNOWLEDGE}'
+                        )
+                        knowledge_infused_observation['text'] = text
+                        dialogue_response_user_knowledge = stacked_agent.dialogue_reply(
+                            agent=stacked_agent.dialogue_agent,
+                            observation=knowledge_infused_observation,
+                        )
+                        dialogue_response_user_knowledge[
+                            'knowledge_response'
+                        ] = user_input
+                        stacked_agent._filter_beams(
+                            reply=dialogue_response_user_knowledge,
+                            filter_for_knowledge=stacked_agent.opts['init'][
+                                'beam_filter_for_knowledge_response'
+                            ],
+                            filter_questions=stacked_agent.opts['init'][
+                                'beam_filter_questions'
+                            ],
+                            filter_self_references=stacked_agent.opts['init'][
+                                'beam_filter_self_references'
+                            ],
+                        )
+                        print(
+                            ' Dialogue Prediction (user knowledge): '
+                            + colorize(
+                                dialogue_response_user_knowledge['text'], 'green'
+                            )
+                        )
+                        cont_choice = input(
+                            colorize('Continuation choice (c/d/n/r): ', 'red')
+                        )
+                        if cont_choice == 'd':
+                            # Debug.
+                            import pdb
+
+                            pdb.set_trace()
+                            cont_choice = 'c'
+                        elif cont_choice == 'n':
+                            # Skip the episode.
+                            blocked_episode_prefixes = [
+                                stacked_agent.observations['raw']['text']
+                            ]
+                            cont_choice = 'c'
+                        elif cont_choice == 'r' or cont_choice == 'c':
+                            # Continue or retry the episode.
+                            pass
+                        else:
+                            print(
+                                f"Can't parse continuation choice '{cont_choice}'. Continuing."
+                            )
+                            cont_choice = 'c'
+
+                # Print the remaining predictions.
+                if (
+                    'text_no_knowledge' in reply
+                    and reply['text_no_knowledge'] != 'None'
+                ):
+                    print(
+                        ' Dialogue Prediction (no knowledge): '
+                        + colorize(reply.get('text_no_knowledge', ''), 'green')
+                    )
+                if 'text_rag_wiki' in reply and reply['text_rag_wiki'] != 'None':
+                    print(
+                        ' Dialogue Prediction (rag wiki): '
+                        + colorize(reply.get('text_rag_wiki', ''), 'green')
+                    )
+                print(
+                    ' Dialogue Prediction (predicted knowledge): '
+                    + colorize(dialogue_response, 'green')
+                )
+                if (
+                    'text_knowledge_sentence' in reply
+                    and reply['text_knowledge_sentence'] != 'None'
+                ):
+                    print(
+                        ' Dialogue Prediction (predicted knowledge sentence): '
+                        + colorize(reply.get('text_knowledge_sentence', ''), 'green')
+                    )
+
+                if target_dialogue_response:
+                    print(
+                        ' Gold Dialogue Response: ',
+                        colorize(target_dialogue_response, 'yellow'),
+                    )
+
+    if opt['save_file']:
+        with open(opt['save_file'], 'w') as f:
+            json.dump(result, f)
+        print(f'Saved results to {opt["save_file"]}')
+
+
+class StackedAgentOutput(ParlaiScript):
+    @classmethod
+    def setup_args(cls):
+        return setup_args()
+
+    def run(self):
+        return model_output(self.opt)
+
+
+if __name__ == '__main__':
+    StackedAgentOutput.main()
diff --git a/projects/k2r/stacked_agent/task/agents.py b/projects/k2r/stacked_agent/task/agents.py
new file mode 100644
index 00000000000..90ef0070bfd
--- /dev/null
+++ b/projects/k2r/stacked_agent/task/agents.py
@@ -0,0 +1,757 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
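+
+"""
+Stacked K2R agents: a knowledge-response model first generates an intermediate
+knowledge sequence, and a dialogue-response model then conditions on it (plus
+the dialogue context) to produce the final reply.
+"""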
+
+from copy import deepcopy
+from typing import Optional, List
+from types import MethodType
+from collections import defaultdict
+from nltk.tokenize import sent_tokenize
+from nltk.corpus import stopwords
+import random
+import spacy
+import torch
+
+from parlai.core.agents import create_agent, create_agent_from_shared
+from parlai.core.params import ParlaiParser
+from parlai.core.opt import Opt
+from parlai.core.message import Message
+from parlai.core.agents import Agent
+from parlai.core.metrics import F1Metric
+from parlai.agents.rag.args import setup_rag_args
+from parlai.agents.bart.bart import BartAgent
+from parlai.core.metrics import normalize_answer
+from parlai.tasks.wizard_of_wikipedia.agents import (
+    TOKEN_NOCHOSEN,
+)
+
+from parlai.tasks.wizard_of_wikipedia.agents import (
+    TOKEN_KNOWLEDGE,
+    TOKEN_END_KNOWLEDGE,
+)
+
+STOP_WORDS = stopwords.words('english')
+
+# The spacy pipeline is loaded lazily in extract_entities().
+nlp = None
+
+
+def load_opt_from_file(opt_file):
+    if not opt_file.endswith('.opt'):
+        opt_file += '.opt'
+    return Opt.load(opt_file)
+
+
+def wow_get_batch_context(self, batch, orig_fun=None):
+    """
+    Set the beam context for n-gram context blocking specific to WoW data.
+
+    For WoW, we don't want to consider the knowledge in the input for the context
+    beam blocking. That's why we mask it out here.
+    """
+    ctxts = orig_fun(batch)
+    knowledge_start_id = self.dict.txt2vec(TOKEN_KNOWLEDGE)
+    knowledge_end_id = self.dict.txt2vec(TOKEN_END_KNOWLEDGE)
+
+    def mask_ctxttensor_between_sublists(
+        ctxts: torch.Tensor, sub1: List[int], sub2: List[int]
+    ) -> torch.Tensor:
+        """
+        Generate a mask that masks out the context between sub1 and sub2.
+        """
+        mask = []
+        for ctxt in ctxts:
+            mask_idxs = []
+            should_copy = False
+            idx_pointer = 0
+            id_to_match = sub1
+            for j, token in enumerate(ctxt.cpu().numpy()):
+                if token == id_to_match[idx_pointer]:
+                    idx_pointer += 1
+                    if idx_pointer == 1 and id_to_match == sub1:
+                        mask_idxs.append([j])
+                    elif idx_pointer >= len(id_to_match):
+                        should_copy = id_to_match == sub1
+                        idx_pointer = 0
+                        id_to_match = sub2 if (id_to_match == sub1) else sub1
+                        mask_idxs[-1].append(j)
+                    else:
+                        mask_idxs[-1].append(j)
+                elif should_copy:
+                    mask_idxs[-1].append(j)
+                elif idx_pointer > 0:
+                    idx_pointer = 0
+                    del mask_idxs[-1]
+            mask.append(
+                [
+                    0 if idx in [i for sl in mask_idxs for i in sl] else 1
+                    for idx in range(len(ctxt))
+                ]
+            )
+        return torch.LongTensor(mask).to(ctxts.device)
+
+    ctxts *= mask_ctxttensor_between_sublists(
+        ctxts, knowledge_start_id, knowledge_end_id
+    )
+    return ctxts
+
+
+def find_supporting_sentence(question: str, answer: str, docs: List[str]) -> str:
+    """
+    Find the supporting sentence for the answer in the docs.
+    """
+    # Remove the titles of the documents.
+    for i, doc in enumerate(docs):
+        if ' | ' in doc:
+            docs[i] = '. '.join(doc.split(' | ')[1:])
+    concat_docs = '. '.join(docs)
+    sentences = sent_tokenize(concat_docs)
+
+    # Sort sentences according to recall with the answer and question.
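+    # _prec_recall_f1_score returns (precision, recall, f1); with the answer
+    # (or question) tokens passed as the "predictions", the precision term at
+    # index [0] measures how much of the answer is covered by the sentence.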
+    sorted_sentences = sorted(
+        sentences,
+        key=lambda sentence: (
+            F1Metric._prec_recall_f1_score(
+                normalize_answer(answer).split(), normalize_answer(sentence).split()
+            )[0],
+            F1Metric._prec_recall_f1_score(
+                normalize_answer(question).split(), normalize_answer(sentence).split()
+            )[0],
+        ),
+        reverse=True,
+    )
+
+    return sorted_sentences[0]
+
+
+def extract_entities(
+    sentence, pos=('PROPN', 'NOUN'), use_named_entities=True, use_noun_chunks=True
+):
+    global nlp
+    if nlp is None:
+        nlp = spacy.load("en_core_web_sm")
+    doc = nlp(sentence)
+    results = []
+    if pos:
+        for token in doc:
+            if token.pos_ in pos:
+                results.append(token)
+    if use_named_entities:
+        for ent in doc.ents:
+            results.append(ent)
+    if use_noun_chunks:
+        for chunk in doc.noun_chunks:
+            if chunk.text.lower() not in STOP_WORDS:
+                results.append(chunk)
+    results = list(set([r.text for r in results]))
+    return results
+
+
+def extract_knowledge(txt: str) -> List[str]:
+    if not txt or not txt.split():
+        return []
+    entities = extract_entities(txt)
+    return [e.lower() for e in (entities if entities else txt.split())]
+
+
+def knowledge_from_dialogue_response(dialogue_response: str) -> str:
+    """
+    Get a knowledge response based on the dialogue response.
+
+    We use a random entity from the dialogue response as knowledge. If there is no
+    entity present, we use a random word. If there are no words present, we use
+    TOKEN_NOCHOSEN.
+    """
+    knowledge_options = extract_knowledge(dialogue_response)
+    if not knowledge_options:
+        return TOKEN_NOCHOSEN
+    return random.choice(knowledge_options)
+
+
+class StackedKnowledgeDialogueAgent(Agent):
+    """
+    Stacked model that first generates the knowledge, and then the dialogue response.
+    """
+
+    @classmethod
+    def add_cmdline_args(
+        cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
+    ) -> ParlaiParser:
+        agent = parser.add_argument_group('StackedKnowledgeDialogueAgent options')
+
+        additional_agent_parser = ParlaiParser(add_parlai_args=False)
+        BartAgent.add_cmdline_args(additional_agent_parser)
+        setup_rag_args(additional_agent_parser)
+
+        for action in additional_agent_parser._actions:
+            key = action.option_strings[-1]
+            type = action.type
+
+            for prefix in [
+                'krm',  # knowledge response model
+                'drm',  # dialogue response model
+                'drmrw',  # dialogue response model rag-wiki
+                'drmnk',  # dialogue response model no-knowledge
+            ]:
+                agent.add_argument(
+                    f'--{prefix}-{key.strip("-")}',
+                    type=type,
+                    required=False,
+                )
+
+        agent.add_argument(
+            '--knowledge-response-model-path',
+            type=str,
+            default='',
+            help='Model used to generate the knowledge response.',
+        )
+        agent.add_argument(
+            '--dialogue-response-model-path',
+            type=str,
+            default='',
+            help='Model used to generate the dialogue response.',
+        )
+        agent.add_argument(
+            '--dialogue-response-no-knowledge-model-path',
+            type=str,
+            default='',
+            help='Model used to generate the dialogue response without knowledge.',
+        )
+        agent.add_argument(
+            '--dialogue-response-rag-wiki-model-path',
+            type=str,
+            default='',
+            help='Model used to generate the dialogue response with Wiki knowledge.',
+        )
+        agent.add_argument(
+            '--use-supporting-sentence-as-knowledge',
+            type=bool,
+            default=False,
+            help='Instead of using the knowledge response directly to condition the dialogue'
+            ' model, we search for the top supporting sentence and use this.',
+        )
+        agent.add_argument(
+            '--beam-filter-for-knowledge-response',
+            type=bool,
+            default=False,
+            help='Try to pick a beam that contains the knowledge response.',
+        )
+        agent.add_argument(
+            '--beam-filter-questions',
+            type=bool,
+            default=False,
+            help='Try to pick a beam that does not contain a question mark.',
+        )
+        agent.add_argument(
+            '--beam-filter-self-references',
+            type=bool,
+            default=False,
+            help='Try to pick a beam that does not contain self references like "I" and "me".',
+        )
+        agent.add_argument(
+            '--beam-disregard-knowledge-for-context-blocking',
+            type=bool,
+            default=True,
+            help='If True, disregard the knowledge input for the context blocking.',
+        )
+        agent.add_argument(
+            '--add-fixed-confidence',
+            type=int,
+            default=-1,
+            help='Add a fixed confidence score to the knowledge response.',
+        )
+        agent.add_argument(
+            '--add-confidence-as-str',
+            type=bool,
+            default=False,
+            help='If we add a confidence score to the KRM, we add it as a str.',
+        )
+        return parser
+
+    def __init__(self, opt, shared=None):
+        self.id = 'StackedKnowledgeDialogueAgent'
+        self._construct_opts(opt)
+
+        self.knowledge_agent = None
+        self.dialogue_agent = None
+        self.dialogue_agent_no_knowledge = None
+        self.dialogue_agent_rag_wiki = None
+
+        if not shared:
+            self._init_knowledge_model()
+            self._init_dialogue_models()
+        else:
+            if 'knowledge_agent_share' in shared:
+                self.knowledge_agent = create_agent_from_shared(
+                    shared['knowledge_agent_share']
+                )
+            if 'dialogue_agent_share' in shared:
+                self.dialogue_agent = create_agent_from_shared(
+                    shared['dialogue_agent_share']
+                )
+            if 'dialogue_agent_no_knowledge_share' in shared:
+                self.dialogue_agent_no_knowledge = create_agent_from_shared(
+                    shared['dialogue_agent_no_knowledge_share']
+                )
+            if 'dialogue_agent_rag_wiki_share' in shared:
+                self.dialogue_agent_rag_wiki = create_agent_from_shared(
+                    shared['dialogue_agent_rag_wiki_share']
+                )
+
+        self.shared = shared
+        super().__init__(opt, shared)
+
+    @property
+    def has_no_knowledge_dialogue_model(self):
+        return (
+            self.opts['init']['dialogue_response_no_knowledge_model_path']
+            and self.opts['init']['dialogue_response_no_knowledge_model_path'] != 'None'
+        )
+
+    @property
+    def has_rag_wiki_dialogue_model(self):
+        return (
+            self.opts['init']['dialogue_response_rag_wiki_model_path']
+            and self.opts['init']['dialogue_response_rag_wiki_model_path'] != 'None'
+        )
+
+    def _agent_opt(
+        self, filename, specific_override_args, general_override_args, **kwargs
+    ):
+        opt = load_opt_from_file(filename)
+        opt['override'] = {}
+        blocklist_general = ['model', 'model_file', 'init_model']
+        standard_override_args = {
+            'skip_generation': False,
+            'inference': 'beam',
+            'beam_block_ngram': 3,
+            'beam_context_block_ngram': -1,
+            'beam_size': 3,
+        }
+        general_override_args = {
+            **general_override_args,
+            **standard_override_args,
+            **kwargs,
+        }
+
+        # Remove the model prefix from the specific override args.
+        specific_override_args = {
+            '_'.join(k.split('_')[1:]): v for k, v in specific_override_args.items()
+        }
+
+        # Specific for --indexer-type.
+        if 'indexer_type' in specific_override_args:
+            # TODO: do we also need to overwrite path_to_index?
+            pass
+
+        override_args = {**general_override_args, **specific_override_args}
+
+        for k, v in override_args.items():
+            if k not in blocklist_general and k in opt:
+                opt['override'][k] = v
+
+        return opt
+
+    def _construct_opts(self, opt):
+        self.opts = {}
+        self.opts['init'] = opt
+        override_opts = defaultdict(dict)
+        for k, v in opt['override'].items():
+            if k.startswith('krm_'):
+                if v is not None:
+                    override_opts['knowledge_agent'][k] = v
+            elif k.startswith('drm_'):
+                if v is not None:
+                    override_opts['dialogue_agent'][k] = v
+            elif k.startswith('drmnk_'):
+                if v is not None:
+                    override_opts['dialogue_agent_no_knowledge'][k] = v
+            elif k.startswith('drmrw_'):
+                if v is not None:
+                    override_opts['dialogue_agent_rag_wiki'][k] = v
+            else:
+                override_opts['general'][k] = v
+        self.opts['override'] = override_opts
+
+        if opt['knowledge_response_model_path'] and opt[
+            'knowledge_response_model_path'
+        ] not in ['oracle']:
+            self.opts['knowledge_agent'] = self._agent_opt(
+                filename=opt['knowledge_response_model_path'],
+                specific_override_args=override_opts['knowledge_agent'],
+                general_override_args=override_opts['general'],
+            )
+        self.opts['dialogue_agent'] = self._agent_opt(
+            filename=opt['dialogue_response_model_path'],
+            specific_override_args=override_opts['dialogue_agent'],
+            general_override_args=override_opts['general'],
+        )
+        if self.has_no_knowledge_dialogue_model:
+            self.opts['dialogue_agent_no_knowledge'] = self._agent_opt(
+                filename=opt['dialogue_response_no_knowledge_model_path'],
+                specific_override_args=override_opts['dialogue_agent_no_knowledge'],
+                general_override_args=override_opts['general'],
+            )
+        if self.has_rag_wiki_dialogue_model:
+            self.opts['dialogue_agent_rag_wiki'] = self._agent_opt(
+                filename=opt['dialogue_response_rag_wiki_model_path'],
+                specific_override_args=override_opts['dialogue_agent_rag_wiki'],
+                general_override_args=override_opts['general'],
+            )
+
+    def share(self):
+        shared = super().share()
+        shared['knowledge_agent_share'] = self.knowledge_agent.share()
+        shared['dialogue_agent_share'] = self.dialogue_agent.share()
+        if self.has_no_knowledge_dialogue_model:
+            shared[
+                'dialogue_agent_no_knowledge_share'
+            ] = self.dialogue_agent_no_knowledge.share()
+        if self.has_rag_wiki_dialogue_model:
+            shared[
+                'dialogue_agent_rag_wiki_share'
+            ] = self.dialogue_agent_rag_wiki.share()
+        return shared
+
+    def _init_knowledge_model(self):
+        # Initialize the knowledge agent.
+        if 'knowledge_agent' in self.opts:
+            self.knowledge_agent = create_agent(
+                self.opts['knowledge_agent'], requireModelExists=True
+            )
+            print('Options for Knowledge Response Agent')
+            self.knowledge_agent.opt.log()
+        elif self.opts['init']['knowledge_response_model_path'] == 'oracle':
+            self.knowledge_agent = OracleKnowledgeAgent(self.opts['init'])
+
+    def _init_dialogue_models(self):
+        ## Init dialogue models.
+
+        # Initialize the dialogue agent that uses the predicted knowledge.
+        self.dialogue_agent = create_agent(
+            self.opts['dialogue_agent'], requireModelExists=True
+        )
+        # Monkey-patch _get_batch_context to ignore the knowledge for
+        # beam-context-blocking.
+        if self.opts['init']['beam_disregard_knowledge_for_context_blocking']:
+            orig_fun = self.dialogue_agent._get_batch_context
+            self.dialogue_agent._get_batch_context = MethodType(
+                lambda self, batch: wow_get_batch_context(
+                    self, batch, orig_fun=orig_fun
+                ),
+                self.dialogue_agent,
+            )
+
+        print('Options for Dialogue Response Agent')
+        self.dialogue_agent.opt.log()
+
+        # Initialize the dialogue agent that doesn't use knowledge.
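+        # (Optional baseline: only created when
+        # --dialogue-response-no-knowledge-model-path is set.)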
+        if self.has_no_knowledge_dialogue_model:
+            self.dialogue_agent_no_knowledge = create_agent(
+                self.opts['dialogue_agent_no_knowledge'], requireModelExists=True
+            )
+
+        # Initialize the dialogue agent that uses RAG with Wiki.
+        if self.has_rag_wiki_dialogue_model:
+            self.dialogue_agent_rag_wiki = create_agent(
+                self.opts['dialogue_agent_rag_wiki'], requireModelExists=True
+            )
+
+    def dialogue_reply(self, agent, observation):
+        return self.batch_dialogue_reply(agent, [observation])[0]
+
+    def batch_dialogue_reply(self, agent, observations):
+        dialogue_observations = []
+        # Observations for the dialogue model.
+        for obs in observations:
+            dialogue_observation = agent.observe(obs)
+            agent.self_observe(dialogue_observation)
+            dialogue_observations.append(dialogue_observation)
+
+        return agent.batch_act(dialogue_observations)
+
+    def generate_knowledge_observation(self, knowledges: List[str], observations):
+        # Adjust the observation texts.
+        knowledge_infused_observations = deepcopy(observations)
+        for obs, knowledge in zip(knowledge_infused_observations, knowledges):
+            if 'text' not in obs:
+                continue
+            text = obs.pop('text')
+
+            if self.opts['init']['add_fixed_confidence'] >= 0:
+                confidence = self.opts['init']['add_fixed_confidence']
+                if self.opts['init']['add_confidence_as_str']:
+                    confidence = {
+                        0: 'low',
+                        5: 'medium',
+                        10: 'high',
+                    }[confidence] + ' confidence'
+                text += f'\n{TOKEN_KNOWLEDGE} {confidence}: {knowledge} {TOKEN_END_KNOWLEDGE}'
+            else:
+                text += f'\n{TOKEN_KNOWLEDGE} {knowledge} {TOKEN_END_KNOWLEDGE}'
+            obs['text'] = text
+        return knowledge_infused_observations
+
+    def batch_act(self, observations):
+        knowledge_agent_observations = [o['knowledge_agent'] for o in observations]
+        raw_observations = [o['raw'] for o in observations]
+
+        # Get the knowledge replies.
+        if self.knowledge_agent is None:
+            self._init_knowledge_model()
+        batch_reply_knowledge = self.knowledge_agent.batch_act(
+            knowledge_agent_observations
+        )
+
+        if (
+            'top_docs' in batch_reply_knowledge[0]
+            and self.opts['init']['use_supporting_sentence_as_knowledge']
+        ):
+            # The knowledge agent is a rag-style model. Instead of the actual
+            # knowledge response, we use the best matching sentence from the
+            # retrieved docs as the knowledge conditioning.
+            for i, reply_knowledge in enumerate(batch_reply_knowledge):
+                reply_knowledge['support_sentence'] = find_supporting_sentence(
+                    question=raw_observations[i]['text'],
+                    answer=reply_knowledge['text'],
+                    docs=reply_knowledge['top_docs'],
+                )
+
+        if self.dialogue_agent is None:
+            self._init_dialogue_models()
+
+        if self.dialogue_agent_no_knowledge:
+            batch_reply_dialogue_no_knowledge = self.batch_dialogue_reply(
+                self.dialogue_agent_no_knowledge, raw_observations
+            )
+        if self.dialogue_agent_rag_wiki:
+            batch_reply_dialogue_rag_wiki = self.batch_dialogue_reply(
+                self.dialogue_agent_rag_wiki, raw_observations
+            )
+
+        knowledge_infused_observations = self.generate_knowledge_observation(
+            knowledges=[
+                reply_knowledge.get('text', '')
+                for reply_knowledge in batch_reply_knowledge
+            ],
+            observations=raw_observations,
+        )
+        batch_reply_dialogue = self.batch_dialogue_reply(
+            self.dialogue_agent, knowledge_infused_observations
+        )
+
+        batch_reply_dialogue_knowledge_sentence = None
+        if (
+            self.opts['init']['use_supporting_sentence_as_knowledge']
+            and 'support_sentence' in batch_reply_knowledge[0]
+        ):
+            knowledge_sentence_infused_observations = (
+                self.generate_knowledge_observation(
+                    knowledges=[
+                        reply_knowledge.get('support_sentence', '')
+                        for reply_knowledge in batch_reply_knowledge
+                    ],
+                    observations=raw_observations,
+                )
+            )
+            batch_reply_dialogue_knowledge_sentence = self.batch_dialogue_reply(
+                self.dialogue_agent, knowledge_sentence_infused_observations
+            )
+
+        for i in range(len(batch_reply_dialogue)):
+            if batch_reply_knowledge and len(batch_reply_knowledge) > i:
+                batch_reply_dialogue[i]['knowledge_response'] = batch_reply_knowledge[
+                    i
+                ].get('text', '')
+                if 'support_sentence' in batch_reply_knowledge[i]:
+                    batch_reply_dialogue[i]['support_sentence'] = batch_reply_knowledge[
+                        i
+                    ].get('support_sentence', '')
+            if (
+                self.dialogue_agent_no_knowledge
+                and batch_reply_dialogue_no_knowledge
+                and len(batch_reply_dialogue_no_knowledge) > i
+            ):
+                batch_reply_dialogue[i][
+                    'text_no_knowledge'
+                ] = batch_reply_dialogue_no_knowledge[i].get('text', '')
+
+            if (
+                self.dialogue_agent_rag_wiki
+                and batch_reply_dialogue_rag_wiki
+                and len(batch_reply_dialogue_rag_wiki) > i
+            ):
+                batch_reply_dialogue[i][
+                    'text_rag_wiki'
+                ] = batch_reply_dialogue_rag_wiki[i].get('text', '')
+
+            if (
+                batch_reply_dialogue_knowledge_sentence
+                and len(batch_reply_dialogue_knowledge_sentence) > i
+            ):
+                batch_reply_dialogue[i][
+                    'text_knowledge_sentence'
+                ] = batch_reply_dialogue_knowledge_sentence[i].get('text', '')
+
+        [
+            self._filter_beams(
+                reply=reply,
+                filter_for_knowledge=self.opts['init'][
+                    'beam_filter_for_knowledge_response'
+                ],
+                filter_questions=self.opts['init']['beam_filter_questions'],
+                filter_self_references=self.opts['init']['beam_filter_self_references'],
+            )
+            for reply in batch_reply_dialogue
+        ]
+
+        return batch_reply_dialogue
+
+    def _filter_beams(
+        self,
+        reply,
+        filter_for_knowledge: bool = True,
+        filter_questions: bool = False,
+        filter_self_references: bool = False,
+    ):
+        knowledge = normalize_answer(reply['knowledge_response'])
+        self_references = [
+            'I live',
+            'I love',
+            ' me ',
+            'my favorite',
+            'My favorite',
+            'do you know',
+            'I have',
+            'I like ',
+            'My ',
+        ]
+        question_words = [
+            'who',
+            'when',
+            'where',
+            'what',
+            'do you',
+            'are you',
+        ]
+
+        def filter_fn(text: str) -> bool:
+            normalized_text = normalize_answer(text)
+            if filter_for_knowledge and knowledge not in normalized_text:
+                return False
+            if filter_questions and (
+                '?' in text or any([qw in normalized_text for qw in question_words])
+            ):
+                return False
+            if filter_self_references and any([ref in text for ref in self_references]):
+                return False
+            return True
+
+        if not (
+            'text' in reply
+            and 'beam_texts' in reply
+            and 'knowledge_response' in reply
+            and len(reply['beam_texts']) > 1
+        ):
+            return
+
+        beam_texts = [
+            text
+            for text, _ in sorted(reply['beam_texts'], key=lambda x: x[1], reverse=True)
+        ]
+        # print('\t' + '\n\t'.join(beam_texts[:10]))
+        for text in [reply['text']] + beam_texts:
+            if filter_fn(text):
+                reply.force_set('text', text)
+                return
+
+    def self_observe(self, self_message: Message) -> None:
+        # Hack: feed the final dialogue response back to the knowledge model.
+        # This is why we need to make sure that --mutators flatten is used.
+        self.knowledge_agent.self_observe(self_message)
+
+    def observe(self, observation):
+        # Delete unused keys.
+        for key in ['label_candidates', 'knowledge']:
+            if key in observation:
+                del observation[key]
+
+        label_key = 'eval_labels' if 'eval_labels' in observation else 'labels'
+        if 'nqopen' in self.opts['init']['task'].lower():
+            knowledge_target = observation.get(label_key, '')
+            if isinstance(knowledge_target, tuple):
+                knowledge_target = '\t'.join(knowledge_target)
+            observation['knowledge_target'] = knowledge_target
+            observation['dialogue_response'] = ''
+        else:
+            observation['dialogue_response'] = observation.get(label_key, '')
+            observation['knowledge_response'] = observation.get('checked_sentence', '')
+
+        if self.knowledge_agent is None:
+            self._init_knowledge_model()
+
+        observations = {
+            'raw': deepcopy(observation),
+            'knowledge_agent': self.knowledge_agent.observe(observation),
+        }
+        self.observations = observations
+        return observations
+
+    def act(self):
+        """
+        Call batch_act with the singleton batch.
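+
+        The reply is also fed back to the knowledge model via self_observe so
+        that its dialogue history stays in sync.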
+ """ + response = self.batch_act([self.observations])[0] + self.self_observe(response) + return response + + +class OracleKnowledgeAgent(Agent): + def __init__(self, opt, shared=None): + super().__init__(opt) + self.id = 'OracleKnowledgeAgent' + + def get_knowledge(self, obs): + labels_kword = 'labels' if 'train' in self.opt['datatype'] else 'eval_labels' + if 'wizardofwikipedia' in self.opt['task'].lower().replace('_', ''): + return obs.get('checked_sentence', '') + if 'wizardofinternet' in self.opt['task'].lower().replace('_', ''): + knowledge = obs.get('__selected-sentences__', '') + if isinstance(knowledge, list): + knowledge = '\n'.join(knowledge) + return knowledge + elif ( + 'nqopen' in self.opt['task'].lower() + or 'natural_questions' in self.opt['task'].lower() + ): + return obs.get(labels_kword, '') + elif 'LightTeacherPlus' in self.opt['task']: + if labels_kword not in obs or not obs[labels_kword]: + return '' + labels = obs[labels_kword] + if not isinstance(labels, str): + labels = labels[0] + return knowledge_from_dialogue_response(labels) + elif 'SummaryQA' in self.opt['task']: + return obs.get(labels_kword, '') + else: + raise NotImplementedError(f'Task "{self.opt["task"]}" is not known.') + + def self_observe(self, obs): + pass + + def batch_act(self, observations): + return [self.act(obs) for obs in observations] + + def act(self, obs=None): + if not obs: + obs = self.observation + if obs is None: + return {'text': 'Nothing to repeat yet.'} + reply = {} + reply['id'] = self.getID() + knowledge = self.get_knowledge(obs) + if not isinstance(knowledge, str): + knowledge = random.choice(knowledge) + reply['text'] = knowledge + return Message(reply) diff --git a/projects/k2r/wow/task/agents.py b/projects/k2r/wow/task/agents.py new file mode 100644 index 00000000000..c35a346ff2c --- /dev/null +++ b/projects/k2r/wow/task/agents.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import random +import os +from typing import Optional, Tuple +from tqdm import tqdm +from collections import defaultdict +import json + +from parlai.core.message import Message +from parlai.core.metrics import F1Metric +from parlai.core.mutators import register_mutator, MessageMutator + +from parlai.tasks.wizard_of_wikipedia.agents import GeneratorTeacher +from parlai.tasks.wizard_of_wikipedia.agents import ( + TOKEN_KNOWLEDGE, + TOKEN_END_KNOWLEDGE, +) + + +class WizardOfWikipediaGeneratorTeacher(GeneratorTeacher): + def __init__(self, opt, shared=None): + super().__init__(opt, shared) + + # Check if probabilistic checked sentence to input mutator is used. + if ( + opt['mutators'] + and 'add_probabilistic_checked_sentence_to_input_wow' in opt['mutators'] + ): + distractor_knowledge_fname = 'data/k2r/wow/distractor.txt' + if not os.path.exists(distractor_knowledge_fname): + # Collect all knowledge sentences. + checked_sentences = defaultdict(set) + for episode_idx in tqdm( + range(self.num_episodes()), + desc='Loading distractor knowledge', + ): + entry_idx = 0 + while True: + entry = self.get(episode_idx, entry_idx) + entry_idx += 1 + checked_sentence = entry.get('checked_sentence', '') + chosen_topic = entry.get('chosen_topic', '') + if not checked_sentence or not chosen_topic: + continue + checked_sentences[chosen_topic].add(checked_sentence) + if entry['episode_done']: + break + # Save it to file. 
+                checked_sentences = {k: list(v) for k, v in checked_sentences.items()}
+                with open(distractor_knowledge_fname, 'w') as f:
+                    json.dump(checked_sentences, f)
+                print(
+                    f'Saved distractor knowledge sentences to "{distractor_knowledge_fname}".'
+                )
+
+            # Add the file path to opt.
+            self.opt['distractor_knowledge_fname'] = distractor_knowledge_fname
+
+            # Update the mutator.
+            self.mutators = [
+                AddCheckedSentence(self.opt)
+                if isinstance(mutator, AddCheckedSentence)
+                else mutator
+                for mutator in self.mutators
+            ]
+
+    def getID(self):
+        name = super().getID()
+        if 'mutators' in self.opt and self.opt['mutators']:
+            return name + '__' + self.opt['mutators']
+        return name
+
+    def custom_evaluation(
+        self,
+        teacher_action: Message,
+        labels: Optional[Tuple[str]],
+        model_response: Message,
+    ):
+        super().custom_evaluation(teacher_action, labels, model_response)
+        if 'knowledge_response' in model_response:
+            self.metrics.add(
+                'predicted_knowledge_f1',
+                F1Metric.compute(
+                    model_response['knowledge_response'],
+                    [model_response['text']],
+                ),
+            )
+            self.metrics.add(
+                'knowledge_response_f1',
+                F1Metric.compute(
+                    model_response['knowledge_response'],
+                    [teacher_action['checked_sentence']],
+                ),
+            )
+
+
+@register_mutator("add_probabilistic_checked_sentence_to_input_wow")
+class AddCheckedSentence(MessageMutator):
+    """
+    Adds the checked sentence to the end of the text.
+
+    A confidence level p is sampled uniformly at random; the correct sentence is
+    kept with probability p and swapped for a distractor otherwise, and
+    round(p * 10) is added to the input as conditioning.
+    """
+
+    def __init__(self, opt):
+        super().__init__(opt)
+        if 'distractor_knowledge_fname' in opt:
+            # Load the distractor checked sentences.
+            with open(opt['distractor_knowledge_fname'], 'r') as f:
+                self.distractor_knowledge_sentences = json.load(f)
+
+    @property
+    def checked_sentence_kword(self):
+        return 'checked_sentence'
+
+    def message_mutation(self, message: Message) -> Message:
+        new_message = message.copy()
+        if 'text' not in message:
+            return message
+        text = new_message.pop('text')
+        checked_sentence = new_message.get(self.checked_sentence_kword, '')
+        if isinstance(checked_sentence, list):
+            checked_sentence = ' '.join(checked_sentence)
+        chosen_topic = message['chosen_topic']
+
+        # Get the probability of adding wrong knowledge.
+        p = random.random()
+        if chosen_topic not in self.distractor_knowledge_sentences:
+            print(f'Chosen topic "{chosen_topic}" does not have distractor sentences.')
+            p = 1.0
+        if random.random() > p:
+            # Replace the knowledge with an incorrect one.
+            distractors = list(
+                set(self.distractor_knowledge_sentences[chosen_topic])
+                - set([checked_sentence])
+            )
+            if distractors:
+                checked_sentence = random.choice(distractors)
+            else:
+                print(
+                    f'Chosen topic "{chosen_topic}" does not have distractor sentences.'
+                )
+                p = 1.0
+
+        confidence = round(p * 10)
+        text += f'\n{TOKEN_KNOWLEDGE} {confidence}: {checked_sentence} {TOKEN_END_KNOWLEDGE}'
+        new_message['text'] = text
+
+        return new_message