This repository has been archived by the owner on Nov 3, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added a project page with sample commands to train, eval, and interact with a "shared" K2R model on WoW. Also added LightQA (=SummaryQA2) agent and a sample command to display the data. Linked K2R project page on the ParlAI project page. Co-authored-by: Leonard Adolphs <ladolphs@devfair0791.h2.fair>
- Loading branch information
Showing
8 changed files
with
1,383 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Reason first, then respond: Modular Generation for Knowledge-infused Dialogue | ||
Leonard Adolphs, Kurt Shuster, Jack Urbanek, Arthur Szlam, Jason Weston | ||
|
||
<b>Paper Link</b>: [https://arxiv.org/abs/2111.05204](https://arxiv.org/abs/2111.05204) | ||
|
||
## Abstract | ||
Large language models can produce fluent dialogue but often hallucinate factual inaccuracies. While retrieval-augmented models help alleviate this issue, they still face a difficult challenge of both reasoning to provide correct knowledge and generating conversation simultaneously. In this work, we propose a modular model, Knowledge to Response (K2R), for incorporating knowledge into conversational agents, which breaks down this problem into two easier steps. K2R first generates a knowledge sequence, given a dialogue context, as an intermediate step. After this "reasoning step", the model then attends to its own generated knowledge sequence, as well as the dialogue context, to produce a final response. In detailed experiments, we find that such a model hallucinates less in knowledge-grounded dialogue tasks, and has advantages in terms | ||
of interpretability and modularity. | ||
In particular, it can be used to fuse QA and dialogue systems together to enable dialogue agents to give knowledgeable answers, or QA models to give conversational responses in a zero-shot setting. | ||
|
||
|
||
## Train a shared K2R model on WoW | ||
``` | ||
parlai train \ | ||
-t projects.k2r.wow.task.agents:WizardOfWikipediaGeneratorTeacher:mutators=flatten+wow_checked_sentence_as_label,projects.k2r.wow.task.agents:WizardOfWikipediaGeneratorTeacher:mutators=flatten+wow_add_checked_sentence_to_input \ | ||
--multitask_weights 1,1 --activation gelu --attention-dropout 0.0 --batchsize 16 --dropout 0.1 --fp16 True --gradient-clip 0.1 --label-truncate 128 \ | ||
--text-truncate 512 --log-every-n-secs 30 --lr-scheduler reduceonplateau --lr-scheduler-patience 1 --max-train-time 169344.0 --model-parallel True \ | ||
--model rag -o arch/bart_large --init-model zoo:bart/bart_large/model --dict-file zoo:bart/bart_large/model.dict --warmup-updates 0 \ | ||
--multitask-weights stochastic --relu-dropout 0.0 --save-after-valid True --skip-generation True -lr 1e-05 -vmm min -veps 0.25 -vme 1000 \ | ||
-vmt ppl -vp 5 --n-docs 5 -tblog True --indexer-type compressed --compressed-indexer-nprobe 128 \ | ||
--model-file ./models/wow/k2r_shared | ||
``` | ||
|
||
## Evaluate the model on WoW | ||
``` | ||
parlai em \ | ||
-t projects.k2r.wow.task.agents:WizardOfWikipediaGeneratorTeacher:random_split \ | ||
-m projects.k2r.stacked_agent.task.agents:StackedKnowledgeDialogueAgent \ | ||
--knowledge-response-model-path ./models/wow/k2r_shared \ | ||
--dialogue-response-model-path ./models/wow/k2r_shared \ | ||
--dialogue-response-no-knowledge-model-path None \ | ||
--dialogue-response-rag-wiki-model-path None \ | ||
--mutators flatten -dt valid --krm-fp16 False --krm-model-parallel False --drm-model-parallel False --krm-beam-min-length 15 \ | ||
--krm-beam-size 3 --krm-indexer-type compressed --krm-compressed-indexer-nprobe 128 --krm-n-docs 5 --drm-beam-size 3 --drm-beam-min-length 20 --batchsize 2 --log-every-n-secs 30 --metrics all | ||
``` | ||
|
||
## Do interactive generations with the model | ||
``` | ||
python projects/k2r/stacked_agent/scripts/stacked_agent_eval.py \ | ||
--task wizard_of_wikipedia:Generator -dt test -bs 1 -n 100 \ | ||
--interactive true --mutators flatten --random-order false --verbose true \ | ||
--drm-beam-context-block-ngram 3 --beam-disregard-knowledge-for-context-blocking false \ | ||
--knowledge-response-model-path ./models/wow/k2r_shared \ | ||
--dialogue-response-model-path ./models/wow/k2r_shared | ||
``` | ||
|
||
## LightQA data | ||
Our goal with LightQA is to have a task that requires a model to answer questions *about the previous context*. For example, in LIGHT, a player might ask another character where to find a certain key to complete their quest. Here, we would want a model, acting as the character, to answer appropriately if the knowledge is in the context description. With this goal in mind, we design a dataset in the following way: First, we take a LightWild episode and use an abstractive summarization model, trained on CNN/Daily Mail and the SAMSum Corpus, to generate a summary. Then we identify all noun chunks, entities, and proper nouns and use them as possible answer candidates. For each answer candidate, we use a T5 question generation model, trained on SQuAD, to generate a possible question given the summary as context. As the last step, we filter the generated questions with a QA model, trained on SQuAD, by checking that it would generate the used answer candidate with access to the summary and question. An episode of our dataset consists of the original LightWild episode (up to a certain turn) and the generated question as the last utterance. Hence, our labels in this dataset are not the usual dialogue responses but short answers. | ||
``` | ||
# Display the data. | ||
parlai dd -t projects.k2r.lightqa.task.agents -dt valid | ||
``` | ||
|
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
import json | ||
import os | ||
|
||
from parlai.core.teachers import DialogTeacher | ||
from parlai.utils.io import PathManager | ||
from parlai.core.message import Message | ||
from parlai.core.metrics import F1Metric, normalize_answer, AverageMetric | ||
|
||
from .build import build | ||
|
||
|
||
class SummaryQATeacher(DialogTeacher): | ||
""" | ||
Teacher for the SummaryQA dataset. | ||
""" | ||
|
||
def __init__(self, opt, shared=None): | ||
self.datatype = opt['datatype'].split(':')[0] | ||
build(opt) | ||
opt['datafile'] = os.path.join( | ||
opt['datapath'], f'lightqa/lightqa-wild-summaryqa2-{self.datatype}.json' | ||
) | ||
self.id = 'summaryqa' | ||
super().__init__(opt, shared) | ||
|
||
def setup_data(self, path): | ||
print('loading: ' + path) | ||
with PathManager.open(path) as data_file: | ||
self.episodes = json.load(data_file) | ||
for ex in self.episodes: | ||
episode_done = ex.pop('episode_done') | ||
yield ex, episode_done | ||
|
||
def custom_evaluation( | ||
self, teacher_action: Message, labels, model_response: Message | ||
): | ||
if 'text' in model_response and model_response['text']: | ||
normalized_response = normalize_answer(model_response['text']) | ||
|
||
if labels: | ||
normalized_labels = [normalize_answer(a) for a in labels] | ||
self.metrics.add( | ||
'norm_f1', | ||
F1Metric.compute(normalized_response, normalized_labels), | ||
) | ||
self.metrics.add( | ||
'norm_em', | ||
AverageMetric(int(normalized_response in normalized_labels)), | ||
) | ||
self.metrics.add( | ||
'kaa', | ||
AverageMetric( | ||
int(any([l in normalized_response for l in normalized_labels])) | ||
), | ||
) | ||
|
||
if 'knowledge_response' in model_response: | ||
# Is the predicted knowledge response in the dialogue response? | ||
self.metrics.add( | ||
'pkaa', | ||
AverageMetric( | ||
int( | ||
normalize_answer(model_response['knowledge_response']) | ||
in normalized_response | ||
) | ||
), | ||
) | ||
|
||
|
||
class DefaultTeacher(SummaryQATeacher): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
""" | ||
Download and build the data if it does not exist. | ||
""" | ||
|
||
from parlai.core.build_data import DownloadableFile | ||
import parlai.core.build_data as build_data | ||
import os | ||
from shutil import copyfile | ||
|
||
|
||
RESOURCES = [ | ||
DownloadableFile( | ||
'http://parl.ai/downloads/light_project/k2r/light_dialog_wild_summaryqa2_train.json', | ||
'lightqa-wild-summaryqa2-train.json', | ||
'0c618e0736317fbb9a688f82777165675b5967ffc5208041da940a3e3a947d25', | ||
zipped=False, | ||
), | ||
DownloadableFile( | ||
'http://parl.ai/downloads/light_project/k2r/light_dialog_wild_summaryqa2_valid.json', | ||
'lightqa-wild-summaryqa2-valid.json', | ||
'3646ff1e6549ec82588caaf7da998ef18df629cacdde43d8ce813df545aabe6c', | ||
zipped=False, | ||
), | ||
DownloadableFile( | ||
'http://parl.ai/downloads/light_project/k2r/light_dialog_wild_summaryqa2_test.json', | ||
'lightqa-wild-summaryqa2-test.json', | ||
'70804bd77fe7568326a1e229b3ece578cd1867c3e0e8a14fef23faf4e2032f14', | ||
zipped=False, | ||
), | ||
] | ||
|
||
|
||
def build(opt): | ||
version = 'v1.0.0' | ||
dpath = os.path.join(opt['datapath'], 'lightqa') | ||
|
||
if not build_data.built(dpath, version): | ||
print('[building data: ' + dpath + ']') | ||
if build_data.built(dpath): | ||
# An older version exists, so remove these outdated files. | ||
build_data.remove_dir(dpath) | ||
build_data.make_dir(dpath) | ||
|
||
# Download the data. | ||
for downloadable_file in RESOURCES: | ||
if downloadable_file.url.startswith('/checkpoint'): | ||
copyfile( | ||
downloadable_file.url, | ||
os.path.join(dpath, downloadable_file.file_name), | ||
) | ||
else: | ||
downloadable_file.download_file(dpath) | ||
|
||
# Mark the data as built. | ||
build_data.mark_done(dpath, version) |
Oops, something went wrong.