Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
Code for K2R paper. (#4828)
Browse files Browse the repository at this point in the history
Added a project page with sample commands to train, eval, and interact with a "shared" K2R model on WoW. Also added LightQA (=SummaryQA2) agent and a sample command to display the data. Linked K2R project page on the ParlAI project page.

Co-authored-by: Leonard Adolphs <ladolphs@devfair0791.h2.fair>
  • Loading branch information
leox1v and Leonard Adolphs authored Oct 19, 2022
1 parent 2e2a2b2 commit 774fce5
Show file tree
Hide file tree
Showing 8 changed files with 1,383 additions and 1 deletion.
2 changes: 1 addition & 1 deletion projects/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ _Task & models for chitchat with a given persona._

- **SeeKeR:** [[project]](http://parl.ai/projects/seeker) _Modular open source search-augmented language model._

- **Reason first, then respond:** [[paper]](https://arxiv.org/abs/2111.05204) _A modular Generation method for Knowledge-infused Dialogue._
- **Reason first, then respond:** [[project]](https://parl.ai/projects/k2r/) [[paper]](https://arxiv.org/abs/2111.05204) _A modular Generation method for Knowledge-infused Dialogue._

- **Internet-Augmented Dialogue Generation** [[project]](http://parl.ai/projects/sea) [[paper]](https://arxiv.org/abs/2107.07566).
_Utilizing a search-engine for open domain chitchat task & models._
Expand Down
53 changes: 53 additions & 0 deletions projects/k2r/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Reason first, then respond: Modular Generation for Knowledge-infused Dialogue
Leonard Adolphs, Kurt Shuster, Jack Urbanek, Arthur Szlam, Jason Weston

<b>Paper Link</b>: [https://arxiv.org/abs/2111.05204](https://arxiv.org/abs/2111.05204)

## Abstract
Large language models can produce fluent dialogue but often hallucinate factual inaccuracies. While retrieval-augmented models help alleviate this issue, they still face a difficult challenge of both reasoning to provide correct knowledge and generating conversation simultaneously. In this work, we propose a modular model, Knowledge to Response (K2R), for incorporating knowledge into conversational agents, which breaks down this problem into two easier steps. K2R first generates a knowledge sequence, given a dialogue context, as an intermediate step. After this "reasoning step", the model then attends to its own generated knowledge sequence, as well as the dialogue context, to produce a final response. In detailed experiments, we find that such a model hallucinates less in knowledge-grounded dialogue tasks, and has advantages in terms
of interpretability and modularity.
In particular, it can be used to fuse QA and dialogue systems together to enable dialogue agents to give knowledgeable answers, or QA models to give conversational responses in a zero-shot setting.


## Train a shared K2R model on WoW
```
parlai train \
-t projects.k2r.wow.task.agents:WizardOfWikipediaGeneratorTeacher:mutators=flatten+wow_checked_sentence_as_label,projects.k2r.wow.task.agents:WizardOfWikipediaGeneratorTeacher:mutators=flatten+wow_add_checked_sentence_to_input \
--multitask_weights 1,1 --activation gelu --attention-dropout 0.0 --batchsize 16 --dropout 0.1 --fp16 True --gradient-clip 0.1 --label-truncate 128 \
--text-truncate 512 --log-every-n-secs 30 --lr-scheduler reduceonplateau --lr-scheduler-patience 1 --max-train-time 169344.0 --model-parallel True \
--model rag -o arch/bart_large --init-model zoo:bart/bart_large/model --dict-file zoo:bart/bart_large/model.dict --warmup-updates 0 \
--multitask-weights stochastic --relu-dropout 0.0 --save-after-valid True --skip-generation True -lr 1e-05 -vmm min -veps 0.25 -vme 1000 \
-vmt ppl -vp 5 --n-docs 5 -tblog True --indexer-type compressed --compressed-indexer-nprobe 128 \
--model-file ./models/wow/k2r_shared
```

## Evaluate the model on WoW
```
parlai em \
-t projects.k2r.wow.task.agents:WizardOfWikipediaGeneratorTeacher:random_split \
-m projects.k2r.stacked_agent.task.agents:StackedKnowledgeDialogueAgent \
--knowledge-response-model-path ./models/wow/k2r_shared \
--dialogue-response-model-path ./models/wow/k2r_shared \
--dialogue-response-no-knowledge-model-path None \
--dialogue-response-rag-wiki-model-path None \
--mutators flatten -dt valid --krm-fp16 False --krm-model-parallel False --drm-model-parallel False --krm-beam-min-length 15 \
--krm-beam-size 3 --krm-indexer-type compressed --krm-compressed-indexer-nprobe 128 --krm-n-docs 5 --drm-beam-size 3 --drm-beam-min-length 20 --batchsize 2 --log-every-n-secs 30 --metrics all
```

## Do interactive generations with the model
```
python projects/k2r/stacked_agent/scripts/stacked_agent_eval.py \
--task wizard_of_wikipedia:Generator -dt test -bs 1 -n 100 \
--interactive true --mutators flatten --random-order false --verbose true \
--drm-beam-context-block-ngram 3 --beam-disregard-knowledge-for-context-blocking false \
--knowledge-response-model-path ./models/wow/k2r_shared \
--dialogue-response-model-path ./models/wow/k2r_shared
```

## LightQA data
Our goal with LightQA is to have a task that requires a model to answer questions *about the previous context*. For example, in LIGHT, a player might ask another character where to find a certain key to complete their quest. Here, we would want a model, acting as the character, to answer appropriately if the knowledge is in the context description. With this goal in mind, we design a dataset in the following way: First, we take a LightWild episode and use an abstractive summarization model, trained on CNN/Daily Mail and the SAMSum Corpus, to generate a summary. Then we identify all noun chunks, entities, and proper nouns and use them as possible answer candidates. For each answer candidate, we use a T5 question generation model, trained on SQuAD, to generate a possible question given the summary as context. As the last step, we filter the generated questions with a QA model, trained on SQuAD, by checking that it would generate the used answer candidate with access to the summary and question. An episode of our dataset consists of the original LightWild episode (up to a certain turn) and the generated question as the last utterance. Hence, our labels in this dataset are not the usual dialogue responses but short answers.
```
# Display the data.
parlai dd -t projects.k2r.lightqa.task.agents -dt valid
```

Empty file.
76 changes: 76 additions & 0 deletions projects/k2r/lightqa/task/agents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json
import os

from parlai.core.teachers import DialogTeacher
from parlai.utils.io import PathManager
from parlai.core.message import Message
from parlai.core.metrics import F1Metric, normalize_answer, AverageMetric

from .build import build


class SummaryQATeacher(DialogTeacher):
"""
Teacher for the SummaryQA dataset.
"""

def __init__(self, opt, shared=None):
self.datatype = opt['datatype'].split(':')[0]
build(opt)
opt['datafile'] = os.path.join(
opt['datapath'], f'lightqa/lightqa-wild-summaryqa2-{self.datatype}.json'
)
self.id = 'summaryqa'
super().__init__(opt, shared)

def setup_data(self, path):
print('loading: ' + path)
with PathManager.open(path) as data_file:
self.episodes = json.load(data_file)
for ex in self.episodes:
episode_done = ex.pop('episode_done')
yield ex, episode_done

def custom_evaluation(
self, teacher_action: Message, labels, model_response: Message
):
if 'text' in model_response and model_response['text']:
normalized_response = normalize_answer(model_response['text'])

if labels:
normalized_labels = [normalize_answer(a) for a in labels]
self.metrics.add(
'norm_f1',
F1Metric.compute(normalized_response, normalized_labels),
)
self.metrics.add(
'norm_em',
AverageMetric(int(normalized_response in normalized_labels)),
)
self.metrics.add(
'kaa',
AverageMetric(
int(any([l in normalized_response for l in normalized_labels]))
),
)

if 'knowledge_response' in model_response:
# Is the predicted knowledge response in the dialogue response?
self.metrics.add(
'pkaa',
AverageMetric(
int(
normalize_answer(model_response['knowledge_response'])
in normalized_response
)
),
)


class DefaultTeacher(SummaryQATeacher):
pass
61 changes: 61 additions & 0 deletions projects/k2r/lightqa/task/build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Download and build the data if it does not exist.
"""

from parlai.core.build_data import DownloadableFile
import parlai.core.build_data as build_data
import os
from shutil import copyfile


RESOURCES = [
DownloadableFile(
'http://parl.ai/downloads/light_project/k2r/light_dialog_wild_summaryqa2_train.json',
'lightqa-wild-summaryqa2-train.json',
'0c618e0736317fbb9a688f82777165675b5967ffc5208041da940a3e3a947d25',
zipped=False,
),
DownloadableFile(
'http://parl.ai/downloads/light_project/k2r/light_dialog_wild_summaryqa2_valid.json',
'lightqa-wild-summaryqa2-valid.json',
'3646ff1e6549ec82588caaf7da998ef18df629cacdde43d8ce813df545aabe6c',
zipped=False,
),
DownloadableFile(
'http://parl.ai/downloads/light_project/k2r/light_dialog_wild_summaryqa2_test.json',
'lightqa-wild-summaryqa2-test.json',
'70804bd77fe7568326a1e229b3ece578cd1867c3e0e8a14fef23faf4e2032f14',
zipped=False,
),
]


def build(opt):
version = 'v1.0.0'
dpath = os.path.join(opt['datapath'], 'lightqa')

if not build_data.built(dpath, version):
print('[building data: ' + dpath + ']')
if build_data.built(dpath):
# An older version exists, so remove these outdated files.
build_data.remove_dir(dpath)
build_data.make_dir(dpath)

# Download the data.
for downloadable_file in RESOURCES:
if downloadable_file.url.startswith('/checkpoint'):
copyfile(
downloadable_file.url,
os.path.join(dpath, downloadable_file.file_name),
)
else:
downloadable_file.download_file(dpath)

# Mark the data as built.
build_data.mark_done(dpath, version)
Loading

0 comments on commit 774fce5

Please sign in to comment.