This repository has been archived by the owner on Nov 3, 2023. It is now read-only.
Added DST Teacher for multiwoz_v22 task #4656
Merged
Changes from 4 commits
Commits (12)
2ce0d93  minor typo in docstring (prajjwal1)
61aac91  fixed docstring (prajjwal1)
414123a  added DST Teacher for multiwoz_v22 task (prajjwal1)
2caa4e0  removed comment and fix lint error (prajjwal1)
8c217d3  added docstring (prajjwal1)
7dddb04  fix lint (prajjwal1)
64a07e0  rm ref to parlai_fb (prajjwal1)
eb76569  fixing lint (prajjwal1)
9fc1ae6  fixing lint (prajjwal1)
c9d36a3  default teacher inherits from system teacher (prajjwal1)
d6ae09b  updated test file for multiwoz_v22 (prajjwal1)
5517720  rm multitask_bart (prajjwal1)
@@ -5,24 +5,28 @@
# LICENSE file in the root directory of this source tree.

"""
implementation for ParlAI.
Multiwoz 2.2 Dataset implementation for ParlAI.
"""

from parlai.core.params import ParlaiParser
import copy
import json
import os
from typing import Optional

import numpy as np
import pandas as pd
from parlai.core.opt import Opt
from parlai_fb.tasks.multiwoz_v22.build import build_dataset, fold_size

import parlai.core.tod.tod_agents as tod_agents
import parlai.core.tod.tod_core as tod
import json
from typing import Optional
import parlai.tasks.multiwoz_v22.build as build_
from parlai.core.message import Message
from parlai.core.metrics import AverageMetric
from parlai.core.opt import Opt
from parlai.core.params import ParlaiParser
from parlai.utils.data import DatatypeHelper
from parlai.utils.io import PathManager

import parlai.tasks.multiwoz_v22.build as build_
import parlai.core.tod.tod_agents as tod_agents


DOMAINS = [
    "attraction",
    "bus",
@@ -36,6 +40,8 @@

WELL_FORMATTED_DOMAINS = ["attraction", "bus", "hotel", "restaurant", "train", "taxi"]

SEED = 42


class MultiwozV22Parser(tod_agents.TodStructuredDataParser):
    """
@@ -373,6 +379,212 @@ def get_id_task_prefix(self):
        return "MultiwozV22"


class MultiWOZv22DSTTeacher(tod_agents.TodUserSimulatorTeacher):
    BELIEF_STATE_DELIM = " ; "

    domains = [
        "attraction",
        "hotel",
        "hospital",
        "restaurant",
        "police",
        "taxi",
        "train",
    ]

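    # Slots whose values are named entities; custom_evaluation below reports
    # separate "hallucination" and "*_ne" precision/recall metrics for these.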
    named_entity_slots = {
        "attraction--name",
        "restaurant--name",
        "hotel--name",
        "bus--departure",
        "bus--destination",
        "taxi--departure",
        "taxi--destination",
        "train--departure",
    }

    rng = np.random.RandomState(SEED)

    def __init__(self, opt: Opt, shared=None, *args, **kwargs):
        self.opt = opt
        self.fold = opt["datatype"].split(":")[0]
        opt["datafile"] = self.fold
        self.dpath = os.path.join(opt["datapath"], "multiwoz_v22")
        self.id = "multiwoz_v22"

        if shared is None:
            build_dataset(opt)
        super().__init__(opt, shared)

    def _load_data(self, fold):
        dataset_fold = "dev" if fold == "valid" else fold
        fold_path = os.path.join(self.dpath, dataset_fold)
        dialogs = []
        for file_id in range(1, fold_size(dataset_fold) + 1):
            filename = os.path.join(fold_path, f"dialogues_{file_id:03d}.json")
            with PathManager.open(filename, "r") as f:
                dialogs += json.load(f)
        return dialogs

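    # Illustrative example (hypothetical values): a frame whose state contains
    # slot_values {"restaurant-area": ["centre"]} is flattened by the helper
    # below into the belief-state string "restaurant area centre".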
    def _get_curr_belief_states(self, turn):
        belief_states = []
        for frame in turn["frames"]:
            if "state" in frame:
                if "slot_values" in frame["state"]:
                    for domain_slot_type in frame["state"]["slot_values"]:
                        for slot_value in frame["state"]["slot_values"][
                            domain_slot_type
                        ]:
                            domain, slot_type = domain_slot_type.split("-")
                            belief_state = f"{domain} {slot_type} {slot_value.lower()}"
                            belief_states.append(belief_state)
        return list(set(belief_states))

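    # Illustrative example (hypothetical values): for the string
    # "hotel area north ; hotel stars 4 ; restaurant name nandos",
    # _extract_slot_from_string returns
    #   slots_list              = ["hotel--area--north", "hotel--stars--4",
    #                              "restaurant--name--nandos"]
    #   per_domain_slot_lists   = {"hotel": {"area--north", "stars--4"},
    #                              "restaurant": {"name--nandos"}}
    #   named_entity_slot_lists = ["restaurant--name--nandos"]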
    def _extract_slot_from_string(self, slots_string):
        """
        Either the ground truth or the generated result should be in the format
        "dom slot_type slot_val ; dom slot_type slot_val ; ... ; dom slot_type
        slot_val", and this function reformats the string into the list:

            ["dom--slot_type--slot_val", ...]
        """
        slots_list = []
        per_domain_slot_lists = {}
        named_entity_slot_lists = []

        # split according to ";"
        str_split = slots_string.split(self.BELIEF_STATE_DELIM)

        if str_split[-1] == "":
            str_split = str_split[:-1]

        str_split = [slot.strip() for slot in str_split]

        for slot_ in str_split:
            slot = slot_.split()
            if len(slot) > 2 and slot[0] in self.domains:
                domain = slot[0]
                slot_type = slot[1]
                slot_val = " ".join(slot[2:])
                if not slot_val == "dontcare":
                    slots_list.append(domain + "--" + slot_type + "--" + slot_val)
                if domain in per_domain_slot_lists:
                    per_domain_slot_lists[domain].add(slot_type + "--" + slot_val)
                else:
                    per_domain_slot_lists[domain] = {slot_type + "--" + slot_val}
                if domain + "--" + slot_type in self.named_entity_slots:
                    named_entity_slot_lists.append(
                        domain + "--" + slot_type + "--" + slot_val
                    )
        return slots_list, per_domain_slot_lists, named_entity_slot_lists

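    # Illustrative example (hypothetical values): with
    #   slots_truth = ["hotel--area--north", "hotel--stars--4"]
    #   slots_pred  = ["hotel--area--north"]
    # the turn scores jga = 0 (the slot sets differ), all/slot_p = 1.0 (every
    # predicted slot is correct) and all/slot_r = 0.5 (one of the two
    # ground-truth slots was recovered).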
    def custom_evaluation(
        self, teacher_action: Message, labels, model_response: Message
    ):
        """
        For dialog state tracking, we compute the joint goal accuracy, which is
        the percentage of the turns where the model correctly and precisely
        predicts all slots (domain, slot_type, slot_value).
        """
        resp = model_response.get("text")
        if not resp:
            return

        # extract ground truth from labels
        (
            slots_truth,
            slots_truth_per_domain,
            slots_truth_named_entity,
        ) = self._extract_slot_from_string(labels[0])

        # extract generated slots from model_response
        (
            slots_pred,
            slots_pred_per_domain,
            slots_pred_named_entity,
        ) = self._extract_slot_from_string(resp)

        for gt_slot in slots_truth:
            self.metrics.add("all/slot_r", AverageMetric(gt_slot in slots_pred))
            curr_domain = gt_slot.split("--")[0]
            self.metrics.add(
                f"{curr_domain}/slot_r", AverageMetric(gt_slot in slots_pred)
            )

        for gt_slot in slots_pred_named_entity:
            self.metrics.add(
                "hallucination", AverageMetric(gt_slot not in slots_truth_named_entity)
            )

        for predicted_slot in slots_pred:
            self.metrics.add("all/slot_p", AverageMetric(predicted_slot in slots_truth))
            curr_domain = predicted_slot.split("--")[0]
            self.metrics.add(
                f"{curr_domain}/slot_p", AverageMetric(predicted_slot in slots_truth)
            )

        self.metrics.add("jga", AverageMetric(set(slots_truth) == set(slots_pred)))
        self.metrics.add(
            "named_entities/jga",
            AverageMetric(
                set(slots_truth_named_entity) == set(slots_pred_named_entity)
            ),
        )
        for gt_slot in slots_truth_named_entity:
            self.metrics.add("all_ne/slot_r", AverageMetric(gt_slot in slots_pred))
            curr_domain = gt_slot.split("--")[0]
            self.metrics.add(
                f"{curr_domain}_ne/slot_r", AverageMetric(gt_slot in slots_pred)
            )
        for predicted_slot in slots_pred_named_entity:
            self.metrics.add(
                "all_ne/slot_p", AverageMetric(predicted_slot in slots_truth)
            )
            curr_domain = predicted_slot.split("--")[0]
            self.metrics.add(
                f"{curr_domain}_ne/slot_p", AverageMetric(predicted_slot in slots_truth)
            )

        for domain in slots_truth_per_domain:
            if domain in slots_pred_per_domain:
                self.metrics.add(
                    f"{domain}/jga",
                    AverageMetric(
                        slots_truth_per_domain[domain] == slots_pred_per_domain[domain]
                    ),
                )

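    # Illustrative example (hypothetical values): setup_data yields one example
    # per user turn, e.g.
    #   {
    #       "dialogue_id": "PMUL0012.json",
    #       "turn_num": "2",
    #       "text": "<user> i need a cheap hotel <system> sure , which area ? <user> the north please",
    #       "labels": "hotel pricerange cheap ; hotel area north",
    #   }
    # where "text" is the lowercased dialogue history so far and "labels" is
    # the cumulative belief state at that turn.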
    def setup_data(self, fold):
        dialogs = self._load_data(fold)
        examples = []
        for dialog in dialogs:
            context = []
            for turn in dialog["turns"]:
                curr_turn = turn["utterance"].lower()
                curr_speaker = (
                    "<user>" if turn["speaker"].lower() == "user" else "<system>"
                )
                curr_context = f"{curr_speaker} {curr_turn}"
                context.append(curr_context)
                cum_belief_states = self._get_curr_belief_states(turn)
                if curr_speaker == "<user>":
                    examples.append(
                        {
                            "dialogue_id": dialog["dialogue_id"],
                            "turn_num": turn["turn_id"],
                            "text": " ".join(context),
                            "labels": self.BELIEF_STATE_DELIM.join(
                                set(cum_belief_states)
                            ),
                        }
                    )

        self.rng.shuffle(examples)
        for example in examples:
            yield example, True


class UserSimulatorTeacher(MultiwozV22Parser, tod_agents.TodUserSimulatorTeacher):
    pass

@@ -393,5 +605,5 @@ class SingleApiSchemaAgent(MultiwozV22Parser, tod_agents.TodSingleApiSchemaAgent
    pass


class DefaultTeacher(SystemTeacher):
class DefaultTeacher(MultiWOZv22DSTTeacher):
    pass

Review comment on this change: would be better to keep the default to SystemTeacher as is.
Reply: Fixed this.
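As a quick smoke test of the new teacher (a sketch only: it assumes a working ParlAI environment with this branch checked out and uses ParlAI's standard script entry point), the examples it yields can be printed with:

# Hypothetical usage: show a few examples from the DST teacher on the valid fold.
from parlai.scripts.display_data import DisplayData

DisplayData.main(
    task="multiwoz_v22:MultiWOZv22DSTTeacher", datatype="valid", num_examples=5
)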
Review comment: might be better to add a comment saying this teacher is needed for reproducing the Joint Goal Accuracy values reported in the SimpleTOD and SOLOIST papers.
Reply: Is it better now? @chinnadhurai
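Since the review discussion centers on reproducing Joint Goal Accuracy, here is a minimal self-contained sketch of the metric over the "dom slot_type slot_val ; ..." belief-state strings this teacher produces; the helper names and sample strings below are invented for illustration and this is not the implementation in the diff above.

BELIEF_STATE_DELIM = " ; "
DOMAINS = ["attraction", "hotel", "hospital", "restaurant", "police", "taxi", "train"]


def extract_slots(slots_string):
    # Turn "dom slot_type slot_val ; ..." into a set of "dom--slot_type--slot_val".
    slots = set()
    for chunk in slots_string.split(BELIEF_STATE_DELIM):
        parts = chunk.strip().split()
        if len(parts) > 2 and parts[0] in DOMAINS:
            domain, slot_type, slot_val = parts[0], parts[1], " ".join(parts[2:])
            if slot_val != "dontcare":
                slots.add(f"{domain}--{slot_type}--{slot_val}")
    return slots


def joint_goal_accuracy(truth_strings, pred_strings):
    # Fraction of turns whose predicted slot set exactly matches the ground truth.
    hits = [
        extract_slots(t) == extract_slots(p)
        for t, p in zip(truth_strings, pred_strings)
    ]
    return sum(hits) / len(hits)


truth = ["hotel area north ; hotel stars 4", "taxi departure cambridge"]
pred = ["hotel area north ; hotel stars 4", "taxi departure london"]
print(joint_goal_accuracy(truth, pred))  # 0.5 -- only the first turn matches exactly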