Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Code for K2R paper. #4828

Merged
merged 4 commits into from
Oct 19, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions projects/k2r/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Reason first, then respond: Modular Generation for Knowledge-infused Dialogue
Leonard Adolphs, Kurt Shuster, Jack Urbanek, Arthur Szlam, Jason Weston

<b>Paper Link</b>: [https://arxiv.org/abs/2111.05204](https://arxiv.org/abs/2111.05204)

## Abstract
Large language models can produce fluent dialogue but often hallucinate factual inaccuracies. While retrieval-augmented models help alleviate this issue, they still face a difficult challenge of both reasoning to provide correct knowledge and generating conversation simultaneously. In this work, we propose a modular model, Knowledge to Response (K2R), for incorporating knowledge into conversational agents, which breaks down this problem into two easier steps. K2R first generates a knowledge sequence, given a dialogue context, as an intermediate step. After this "reasoning step", the model then attends to its own generated knowledge sequence, as well as the dialogue context, to produce a final response. In detailed experiments, we find that such a model hallucinates less in knowledge-grounded dialogue tasks, and has advantages in terms
of interpretability and modularity.
In particular, it can be used to fuse QA and dialogue systems together to enable dialogue agents to give knowledgeable answers, or QA models to give conversational responses in a zero-shot setting.


## Train a shared K2R model on WoW
```
parlai train \
-t projects.k2r.wow.task.agents:WizardOfWikipediaGeneratorTeacher:mutators=flatten+wow_checked_sentence_as_label,projects.k2r.wow.task.agents:WizardOfWikipediaGeneratorTeacher:mutators=flatten+wow_add_checked_sentence_to_input \
--multitask_weights 1,1 --activation gelu --attention-dropout 0.0 --batchsize 16 --dropout 0.1 --fp16 True --gradient-clip 0.1 --label-truncate 128 \
--text-truncate 512 --log-every-n-secs 30 --lr-scheduler reduceonplateau --lr-scheduler-patience 1 --max-train-time 169344.0 --model-parallel True \
--model rag -o arch/bart_large --init-model zoo:bart/bart_large/model --dict-file zoo:bart/bart_large/model.dict --warmup-updates 0 \
--multitask-weights stochastic --relu-dropout 0.0 --save-after-valid True --skip-generation True -lr 1e-05 -vmm min -veps 0.25 -vme 1000 \
-vmt ppl -vp 5 --n-docs 5 -tblog True --indexer-type compressed --compressed-indexer-nprobe 128 \
--model-file ./models/wow/k2r_shared
```

## Evaluate the model on WoW
```
parlai em \
-t projects.k2r.wow.task.agents:WizardOfWikipediaGeneratorTeacher:random_split \
-m projects.k2r.stacked_agent.task.agents:StackedKnowledgeDialogueAgent \
--knowledge-response-model-path ./models/wow/k2r_shared \
--dialogue-response-model-path ./models/wow/k2r_shared \
--dialogue-response-no-knowledge-model-path None \
--dialogue-response-rag-wiki-model-path None \
--mutators flatten -dt valid --krm-fp16 False --krm-model-parallel False --drm-model-parallel False --krm-beam-min-length 15 \
--krm-beam-size 3 --krm-indexer-type compressed --krm-compressed-indexer-nprobe 128 --krm-n-docs 5 --drm-beam-size 3 --drm-beam-min-length 20 --batchsize 2 --log-every-n-secs 30 --metrics all
```

## Do interactive generations with the model
```
python projects/k2r/stacked_agent/scripts/stacked_agent_eval.py \
--task wizard_of_wikipedia:Generator -dt test -bs 1 -n 100 \
--interactive true --mutators flatten --random-order false --verbose true \
--drm-beam-context-block-ngram 3 --beam-disregard-knowledge-for-context-blocking false \
--knowledge-response-model-path ./models/wow/k2r_shared \
--dialogue-response-model-path ./models/wow/k2r_shared
```

## LightQA data
Our goal with LightQA is to have a task that requires a model to answer questions *about the previous context*. For example, in LIGHT, a player might ask another character where to find a certain key to complete their quest. Here, we would want a model, acting as the character, to answer appropriately if the knowledge is in the context description. With this goal in mind, we design a dataset in the following way: First, we take a LightWild episode and use an abstractive summarization model, trained on CNN/Daily Mail and the SAMSum Corpus, to generate a summary. Then we identify all noun chunks, entities, and proper nouns and use them as possible answer candidates. For each answer candidate, we use a T5 question generation model, trained on SQuAD, to generate a possible question given the summary as context. As the last step, we filter the generated questions with a QA model, trained on SQuAD, by checking that it would generate the used answer candidate with access to the summary and question. An episode of our dataset consists of the original LightWild episode (up to a certain turn) and the generated question as the last utterance. Hence, our labels in this dataset are not the usual dialogue responses but short answers.
```
# Display the data.
parlai dd -t projects.k2r.lightqa.task.agents -dt valid
```

Empty file.
76 changes: 76 additions & 0 deletions projects/k2r/lightqa/task/agents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json
import os

from parlai.core.teachers import DialogTeacher
from parlai.utils.io import PathManager
from parlai.core.message import Message
from parlai.core.metrics import F1Metric, normalize_answer, AverageMetric

from .build import build


class SummaryQATeacher(DialogTeacher):
"""
Teacher for the SummaryQA dataset.
"""

def __init__(self, opt, shared=None):
self.datatype = opt['datatype'].split(':')[0]
build(opt)
opt['datafile'] = os.path.join(
opt['datapath'], f'lightqa/lightqa-wild-summaryqa2-{self.datatype}.json'
)
self.id = 'summaryqa'
super().__init__(opt, shared)

def setup_data(self, path):
print('loading: ' + path)
with PathManager.open(path) as data_file:
self.episodes = json.load(data_file)
for ex in self.episodes:
episode_done = ex.pop('episode_done')
yield ex, episode_done

def custom_evaluation(
self, teacher_action: Message, labels, model_response: Message
):
if 'text' in model_response and model_response['text']:
normalized_response = normalize_answer(model_response['text'])

if labels:
normalized_labels = [normalize_answer(a) for a in labels]
self.metrics.add(
'norm_f1',
F1Metric.compute(normalized_response, normalized_labels),
)
self.metrics.add(
'norm_em',
AverageMetric(int(normalized_response in normalized_labels)),
)
self.metrics.add(
'kaa',
AverageMetric(
int(any([l in normalized_response for l in normalized_labels]))
),
)

if 'knowledge_response' in model_response:
# Is the predicted knowledge response in the dialogue response?
self.metrics.add(
'pkaa',
AverageMetric(
int(
normalize_answer(model_response['knowledge_response'])
in normalized_response
)
),
)


class DefaultTeacher(SummaryQATeacher):
pass
61 changes: 61 additions & 0 deletions projects/k2r/lightqa/task/build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Download and build the data if it does not exist.
"""

from parlai.core.build_data import DownloadableFile
import parlai.core.build_data as build_data
import os
from shutil import copyfile


RESOURCES = [
DownloadableFile(
'http://parl.ai/downloads/light_project/k2r/light_dialog_wild_summaryqa2_train.json',
'lightqa-wild-summaryqa2-train.json',
'0c618e0736317fbb9a688f82777165675b5967ffc5208041da940a3e3a947d25',
zipped=False,
),
DownloadableFile(
'http://parl.ai/downloads/light_project/k2r/light_dialog_wild_summaryqa2_valid.json',
'lightqa-wild-summaryqa2-valid.json',
'3646ff1e6549ec82588caaf7da998ef18df629cacdde43d8ce813df545aabe6c',
zipped=False,
),
DownloadableFile(
'http://parl.ai/downloads/light_project/k2r/light_dialog_wild_summaryqa2_test.json',
'lightqa-wild-summaryqa2-test.json',
'70804bd77fe7568326a1e229b3ece578cd1867c3e0e8a14fef23faf4e2032f14',
zipped=False,
),
]


def build(opt):
version = 'v1.0.0'
dpath = os.path.join(opt['datapath'], 'lightqa')

if not build_data.built(dpath, version):
print('[building data: ' + dpath + ']')
if build_data.built(dpath):
# An older version exists, so remove these outdated files.
build_data.remove_dir(dpath)
build_data.make_dir(dpath)

# Download the data.
for downloadable_file in RESOURCES:
if downloadable_file.url.startswith('/checkpoint'):
copyfile(
downloadable_file.url,
os.path.join(dpath, downloadable_file.file_name),
)
else:
downloadable_file.download_file(dpath)

# Mark the data as built.
build_data.mark_done(dpath, version)
Loading