Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[TOD][Dataset][Easy] Google SGD in TOD Conversations format #4181

Merged
merged 51 commits into from
Dec 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
e365e48
[TOD] Core converesation structure, serialization, const tokens
Nov 15, 2021
c939174
[Tod] Agents, teacher metrics, and tests for these
Nov 16, 2021
3bf655f
[TOD] Tod json structure to teacher task
Nov 16, 2021
6cb4b86
[TOD] Core converesation structure, serialization, const tokens
Nov 15, 2021
1480def
fix test by adding init folder
Nov 16, 2021
de84801
[Tod] Agents, teacher metrics, and tests for these
Nov 16, 2021
638eb28
[TOD] World, world metrics, script, tests
Nov 16, 2021
0e3f492
hmmm... hoping stacks don't bite me. (change that was kept in upper d…
Nov 16, 2021
0643a62
Merge branch 'simpler_tod_1_core_only' into simpler_tod_2_agents_teac…
Nov 16, 2021
37aced2
minor, remove commented out print
Nov 16, 2021
4f91279
Merge branch 'simpler_tod_2_agents_teachers' into simpler_tod_3_world
Nov 16, 2021
b05930f
comment
Nov 16, 2021
5086e85
more comment updates (not sure if it actually helps clarity..)
Nov 16, 2021
1e30035
Merge branch 'simpler_tod_3_world' into simpler_tod_4_tod_json
Nov 16, 2021
9a25fc5
[TOD][Dataset][Easy] Google SGD in TOD Conversations format
Nov 16, 2021
51ed1a9
Merge branch 'main' into simpler_tod_1_core_only
Nov 16, 2021
a6508be
Merge branch 'simpler_tod_1_core_only' into simpler_tod_2_agents_teac…
Nov 16, 2021
eebc36b
Merge branch 'simpler_tod_2_agents_teachers' into simpler_tod_3_world
Nov 16, 2021
3675781
use same version of black as in the pre-commit hook
Nov 16, 2021
086c91c
Merge branch 'simpler_tod_2_agents_teachers' into simpler_tod_3_world
Nov 16, 2021
0bc961e
use same version of black as in the pre-commit hook
Nov 16, 2021
ed26407
Merge branch 'simpler_tod_3_world' into simpler_tod_4_tod_json
Nov 16, 2021
677df09
Merge branch 'simpler_tod_4_tod_json' into simpler_tod_5a_google_sgd
Nov 16, 2021
24ee898
black with version from pre-commit hook
Nov 16, 2021
3ca7ae3
Merge branch 'simpler_tod_4_tod_json' into simpler_tod_5a_google_sgd
Nov 16, 2021
3145e0e
Shouldn't worry about tod_json being in task_list
Nov 16, 2021
1b2a3fb
Merge branch 'simpler_tod_4_tod_json' into simpler_tod_5a_google_sgd
Nov 16, 2021
dfc4989
Merge branch 'main' into simpler_tod_2_agents_teachers
Nov 29, 2021
2f15448
address eric comments; add new readme + more documentation
Nov 30, 2021
abd1c7e
Merge branch 'simpler_tod_2_agents_teachers' into simpler_tod_3_world
Nov 30, 2021
5d0197d
minor wording change
Nov 30, 2021
39792a8
Merge branch 'simpler_tod_2_agents_teachers' into simpler_tod_3_world
Nov 30, 2021
76bfa89
add more documtnation to world tests (following comment on teacher te…
Nov 30, 2021
73c5c7a
minor comment update
Nov 30, 2021
f6acccb
Merge branch 'simpler_tod_3_world' into simpler_tod_4_tod_json
Nov 30, 2021
dc4b70e
Merge branch 'simpler_tod_4_tod_json' into simpler_tod_5a_google_sgd
Nov 30, 2021
7ab9d70
update to respect actual count of episodes (I think this might have i…
Dec 1, 2021
c6c728d
Merge branch 'main' into simpler_tod_2_agents_teachers
Dec 1, 2021
b3283d0
Merge branch 'simpler_tod_2_agents_teachers' into simpler_tod_3_world
Dec 1, 2021
85ab0fd
Merge branch 'simpler_tod_3_world' into simpler_tod_4_tod_json
Dec 1, 2021
0969aa1
Merge branch 'simpler_tod_4_tod_json' into simpler_tod_5a_google_sgd
Dec 1, 2021
0580ff0
Merge branch 'main' into simpler_tod_2_agents_teachers
Dec 2, 2021
e00accf
Merge branch 'simpler_tod_2_agents_teachers' into simpler_tod_3_world
Dec 2, 2021
701da8d
Merge branch 'simpler_tod_3_world' into simpler_tod_4_tod_json
Dec 2, 2021
d519dc2
Merge branch 'simpler_tod_4_tod_json' into simpler_tod_5a_google_sgd
Dec 2, 2021
9466144
regen after changing tod teacher logic to respect episode/examples le…
Dec 2, 2021
7b24acf
Merge branch 'main' into simpler_tod_3_world
Dec 18, 2021
e3fa063
Merge branch 'simpler_tod_3_world' into simpler_tod_4_tod_json
Dec 18, 2021
2384563
Merge branch 'simpler_tod_4_tod_json' into simpler_tod_5a_google_sgd
Dec 18, 2021
d9ba7e4
Merge branch 'main' into simpler_tod_5a_google_sgd
Dec 22, 2021
acd6ffe
not sure why this comment keeps not being merged correctly ugh...
Dec 22, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
354 changes: 195 additions & 159 deletions parlai/tasks/google_sgd/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,192 +8,228 @@
Google The Schema-Guided Dialogue(SGD) Dataset implementation for ParlAI.
"""

import os
import glob
import json
from parlai.core.opt import Opt
from parlai.core.teachers import DialogTeacher
from parlai.utils.misc import warn_once
from parlai.core.message import Message
from parlai.core.metrics import AverageMetric, BleuMetric
from parlai.utils.io import PathManager
import os
from typing import Optional

import parlai.tasks.google_sgd.build as build_
import parlai.core.tod.tod_core as tod
import parlai.core.tod.tod_agents as tod_agents
from parlai.core.tod.tod_core import SerializationHelpers
from parlai.core.params import ParlaiParser
from parlai.core.opt import Opt
from parlai.utils.io import PathManager


class Text2API2TextTeacher(DialogTeacher):
"""
Teacher which produces both API calls and NLG responses.
"""
class GoogleSGDParser(tod_agents.TodStructuredDataParser):
@classmethod
def add_cmdline_args(
cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
) -> ParlaiParser:
parser = super().add_cmdline_args(parser, partial_opt)
parser.add_argument(
"--delex", type="bool", default=False, help="Delexicalize labels"
)
parser.add_argument(
"--filter-dialogue-by-id",
default="",
type=str,
help="Path to a json file of `dialogue_id`s for which we will filter from. Assumes it will contain a map where the keys are a fold and the value is a list of ids",
)
return parser

def __init__(self, opt: Opt, shared=None):
self.fold = opt['datatype'].split(':')[0]
opt['datafile'] = self.fold
self.dpath = os.path.join(opt['datapath'], 'google_sgd')
self.fold = self.get_fold(opt)
opt["datafile"] = self.fold
self.dpath = os.path.join(opt["datapath"], "google_sgd")
if shared is None:
warn_once(
"Google SGD is a beta dataset, and format may significantly change."
)
# full initialize the teacher as this is not a clone
build_.build(opt)
super().__init__(opt, shared)

def get_fold(self, opt):
return opt["datatype"].split(":")[0]

def _load_data(self, fold):
dataset_fold = 'dev' if fold == 'valid' else fold
dataset_fold = "dev" if fold == "valid" else fold
fold_path = os.path.join(self.dpath, dataset_fold)
schema_file = os.path.join(fold_path, 'schema.json')
with PathManager.open(schema_file, 'r') as f:
schema_file = os.path.join(fold_path, "schema.json")
with PathManager.open(schema_file, "r") as f:
schema_lookup = {}
for schema in json.load(f):
schema_lookup[schema['service_name']] = schema

dialogs = []
for file_id in range(1, build_.fold_size(dataset_fold) + 1):
filename = os.path.join(fold_path, f'dialogues_{file_id:03d}.json')
with PathManager.open(filename, 'r') as f:
dialogs += json.load(f)
return schema_lookup, dialogs

def _get_api_call_and_results(self, sys_turn, schema_lookup):
schema_lookup[schema["service_name"]] = schema

dialogues = []
for filename in glob.glob(f"{fold_path}/dialogues*.json"):
with PathManager.open(filename, "r") as f:
dialogues += json.load(f)

filter_path = self.opt.get("filter_dialogue_by_id", "")
if len(filter_path) > 0:
filtered = []
with open(filter_path) as f:
dialogues_to_get = json.load(f)[fold]
for dialogue in dialogues:
if dialogue["dialogue_id"] in dialogues_to_get:
filtered.append(dialogue)
assert len(filtered) == len(
dialogues_to_get
), f"Different number of dialogues found than requested. Are you sure you've got the right form of Google SGD? Did you filter for dialogue ids correctly? len(filtered) = {len(filtered)}, len(dialogues_to_get) = {len(dialogues_to_get)}"
dialogues = filtered
return schema_lookup, dialogues

def _get_api_call_and_results(self, sys_turn):
api_call = {}
api_resp = {}
for frame in sys_turn['frames']:
if 'service_call' in frame:
for frame in sys_turn["frames"]:
if "service_call" in frame:
# API CALL
method = frame['service_call']['method']
for slot_type, slot_value in frame['service_call'][
'parameters'
for slot_type, slot_value in frame["service_call"][
"parameters"
].items():
api_call[f'{method}.{slot_type}'] = slot_value
assert 'service_results' in frame
if slot_value:
api_call[
f"{slot_type.strip()}"
] = SerializationHelpers.inner_list_join(slot_value)
api_call[tod.STANDARD_API_NAME_SLOT] = frame["service_call"]["method"]
assert "service_results" in frame

# API Resp
if 'actions' in frame:
for action in frame['actions']:
slot_type = action['slot']
slot_value = action['canonical_values']
api_resp[slot_type] = slot_value
if "service_results" in frame:
api_resp = {}
service_results = frame["service_results"]
if len(service_results) > 0:
for key, value in service_results[0].items():
api_resp[key] = SerializationHelpers.inner_list_join(value)
return api_call, api_resp

def custom_evaluation(
self, teacher_action: Message, labels, model_response: Message
):
resp = model_response.get('text')
if not resp:
return

if teacher_action['type'] == 'apicall' and resp.startswith('apicall: '):
gold = teacher_action['slots']
slot_strs = resp[9:].split(' ; ')
parsed = {}
for slot_str in slot_strs:
if ' = ' not in slot_str:
if slot_str != '':
# syntactically invalid generations should count against us
self.metrics.add('slot_p', AverageMetric(0))
continue
name, value = slot_str.split(' = ')
parsed[name] = value

# slot precision
for k, v in parsed.items():
self.metrics.add('slot_p', AverageMetric(v == gold.get(k)))
# slot recall
for k, v in gold.items():
self.metrics.add('slot_r', AverageMetric(v == parsed.get(k)))
elif teacher_action['type'] == 'apiresp':
delex_resp = self._delex(resp, teacher_action['slots'])
delex_label = self._delex(labels[0], teacher_action['slots'])
self.metrics.add(
'delex_bleu', BleuMetric.compute(delex_resp, [delex_label])
)

def _delex(self, text, slots):
delex = text
for slot, values in slots.items():
assert isinstance(values, list)
for value in values:
delex = delex.replace(value, slot)
return delex

def _api_dict_to_str(self, apidict):
return ' ; '.join(f'{k} = {v}' for k, v in apidict.items())

def setup_data(self, fold):
schema_lookup, dialogs = self._load_data(fold)
for dialog in dialogs:
# services = dialog['services']
turns = dialog['turns']
num_turns = len(turns)
for turn_id in range(0, num_turns, 2):
is_first_turn = turn_id == 0

def _get_apis_in_domain(self, schema, domain):
"""
Google SGD includes extra information with the call, so remove these.
"""
result = {}
for intent in schema[domain].get("intents", {}):
here = {}
if "required_slots" in intent and len(intent["required_slots"]) > 0:
here[tod.STANDARD_REQUIRED_KEY] = intent["required_slots"]
if "optional_slots" in intent and len(intent["optional_slots"]) > 0:
here[tod.STANDARD_OPTIONAL_KEY] = intent["optional_slots"]
if "result_slots" in intent:
here["results"] = intent["result_slots"]
result[intent["name"]] = here
return result

def _get_intent_groundinging(self, schema, domains):
"""
Returns map where keys are intents and values are names of required/optional
slots.

We do not care about `result_slots` or default values of optional slots.
"""
result = []
for domain in domains:
apis = self._get_apis_in_domain(schema, domain)
for intent, params in apis.items():
here = {}
here[tod.STANDARD_API_NAME_SLOT] = intent
if tod.STANDARD_REQUIRED_KEY in params:
here[tod.STANDARD_REQUIRED_KEY] = params[tod.STANDARD_REQUIRED_KEY]
if (
tod.STANDARD_OPTIONAL_KEY in params
and len(params[tod.STANDARD_OPTIONAL_KEY]) > 0
):
here[tod.STANDARD_OPTIONAL_KEY] = params[
tod.STANDARD_OPTIONAL_KEY
].keys()
result.append(here)
return result

def _get_all_service_calls(self, turns):
"""
Searches through all turns in a dialogue for any service calls, returns these.
"""
results = []
for turn in turns:
for frame in turn["frames"]:
if "service_call" in frame:
call = frame["service_call"]
item = call["parameters"]
item[tod.STANDARD_API_NAME_SLOT] = call["method"]
results.append(item)
return results

def setup_episodes(self, fold):
"""
Parses Google SGD episodes into TodStructuredEpisode.
"""
schema_lookup, dialogues = self._load_data(fold)
result = []
for dialogue in dialogues:
domains = {s.split("_")[0].strip() for s in dialogue["services"]}
turns = dialogue["turns"]
rounds = []
for turn_id in range(0, len(turns), 2):
user_turn = turns[turn_id]
sys_turn = turns[turn_id + 1]
api_call, api_results = self._get_api_call_and_results(
sys_turn, schema_lookup
api_call, api_results = self._get_api_call_and_results(sys_turn)
r = tod.TodStructuredRound(
user_utt=user_turn["utterance"],
api_call_machine=api_call,
api_resp_machine=api_results,
sys_utt=sys_turn["utterance"],
)
call_str = self._api_dict_to_str(api_call)
resp_str = self._api_dict_to_str(api_results)
if not api_call and not api_results:
# input: user_turn, output: sys_turn
yield {
'text': user_turn['utterance'],
'label': sys_turn['utterance'],
'type': 'text',
}, is_first_turn
elif not api_call and api_results:
yield {
'text': f"{user_turn['utterance']} api_resp: {resp_str}",
'label': sys_turn['utterance'],
'type': 'apiresp',
'slots': api_results,
}, is_first_turn
elif api_call and api_results:
# input: user_turn, output: api_call
yield {
'text': user_turn['utterance'],
'label': f'apicall: {call_str}',
'type': 'apicall',
'slots': api_call,
}, is_first_turn

# system turn, input : api results, output : assistant turn
yield {
'text': f"api_resp: {resp_str}",
'label': sys_turn['utterance'],
'type': 'apiresp',
'slots': api_results,
}, False
else:
assert (
api_call and api_results
), "API call without API results! Check Dataset!"


class Text2TextTeacher(Text2API2TextTeacher):
"""
Text-only teacher (with no API calls or slots)
"""

def setup_data(self, fold):
schema_lookup, dialogs = self._load_data(fold)
for dialog in dialogs:
turns = dialog['turns']
num_turns = len(turns)
for turn_id in range(0, num_turns, 2):
if turn_id == 0:
is_first_turn = True
else:
is_first_turn = False
rounds.append(r)
# Now that we've got the rounds, make the episode
episode = tod.TodStructuredEpisode(
domain=SerializationHelpers.inner_list_join(domains),
api_schemas_machine=self._get_intent_groundinging(
schema_lookup, set(dialogue["services"])
),
goal_calls_machine=self._get_all_service_calls(turns),
rounds=rounds,
delex=self.opt.get("delex"),
extras={"dialogue_id": dialogue["dialogue_id"]},
)
result.append(episode)
# check if the number of episodes should be limited and truncate as required
return result

user_turn = turns[turn_id]
sys_turn = turns[turn_id + 1]
# input: user_turn, output: sys_turn
yield {
'text': user_turn['utterance'],
'label': sys_turn['utterance'],
'type': 'text',
}, is_first_turn
def get_id_task_prefix(self):
return "GoogleSGD"


class SystemTeacher(GoogleSGDParser, tod_agents.TodSystemTeacher):
pass


class DefaultTeacher(SystemTeacher):
pass


class UserSimulatorTeacher(GoogleSGDParser, tod_agents.TodUserSimulatorTeacher):
pass


class StandaloneApiTeacher(GoogleSGDParser, tod_agents.TodStandaloneApiTeacher):
pass


class SingleGoalAgent(GoogleSGDParser, tod_agents.TodSingleGoalAgent):
pass


class GoalAgent(GoogleSGDParser, tod_agents.TodGoalAgent):
pass


class ApiSchemaAgent(GoogleSGDParser, tod_agents.TodApiSchemaAgent):
pass


class UserUttAgent(GoogleSGDParser, tod_agents.TodUserUttAgent):
pass


class DefaultTeacher(Text2API2TextTeacher):
class ApiCallAndSysUttAgent(GoogleSGDParser, tod_agents.TodApiCallAndSysUttAgent):
pass
Loading