From e365e48c16e868efbdd2cbc4ad3568a3ed2021b5 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Mon, 15 Nov 2021 15:50:00 -0800 Subject: [PATCH 01/57] [TOD] Core converesation structure, serialization, const tokens --- parlai/core/tod/tod_core.py | 227 ++++++++++++++++++++++++++++++++++++ 1 file changed, 227 insertions(+) create mode 100644 parlai/core/tod/tod_core.py diff --git a/parlai/core/tod/tod_core.py b/parlai/core/tod/tod_core.py new file mode 100644 index 00000000000..76ee8005a74 --- /dev/null +++ b/parlai/core/tod/tod_core.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Task Oriented Dialogue (TOD) enums and base classes. + +This file defines standard tokens, classes for round and conversation structure, and a serialization class to aid in converting between these. + +See `tod_agents.py` for usage of these classes to generate training data and `tod_world_script.py` for usage of these classes in simulated conversations. +""" +from enum import Enum +from typing import List, Dict +from dataclasses import dataclass, field +from collections.abc import Iterable +from parlai.utils.misc import warn_once + +STANDARD_CALL = "APICALL: " +STANDARD_RESP = "APIRESP: " +STANDARD_SYSTEM_UTTERANCE = "SYSTEM: " +STANDARD_USER_UTTERANCE = "USER: " + +STANDARD_GOAL = "GOAL: " +STANDARD_API_SCHEMAS = "APIS: " + +STANDARD_API_NAME_SLOT = "api_name" +STANDARD_REQUIRED_KEY = "reqArg" +STANDARD_OPTIONAL_KEY = "optArg" +STANDARD_DONE = "[DONE]" + +CONST_SILENCE = "__SILENCE__" + + +class TodAgentType(str, Enum): + USER_UTT_AGENT = "user_utt_model" + API_CALL_AGENT = "api_call_model" + API_RESP_AGENT = "api_resp_model" + SYSTEM_UTT_AGENT = "system_utt_model" + API_SCHEMA_GROUNDING_AGENT = "api_schema_grounding_model" + GOAL_GROUNDING_AGENT = "goal_grounding_model" + + +TOD_AGENT_TYPE_TO_PREFIX = { + TodAgentType.USER_UTT_AGENT: STANDARD_USER_UTTERANCE, + TodAgentType.API_CALL_AGENT: STANDARD_CALL, + TodAgentType.API_RESP_AGENT: STANDARD_RESP, + TodAgentType.SYSTEM_UTT_AGENT: STANDARD_SYSTEM_UTTERANCE, + TodAgentType.API_SCHEMA_GROUNDING_AGENT: STANDARD_API_SCHEMAS, + TodAgentType.GOAL_GROUNDING_AGENT: STANDARD_GOAL, +} + + +@dataclass +class TodStructuredRound: + """ + Dataclass for rounds. + """ + + # Variables set by those using this class + user_utt: str = "" + api_call_machine: Dict = field( + default_factory=dict + ) # Hashmap of slot keys and slot values. Note that STANDARD_API_NAME_SLOT (`api_name`) is expected to be one of the keys here when this is nonempty; simulation metrics wonky without + api_resp_machine: Dict = field(default_factory=dict) + sys_utt: str = "" + extras: Dict = field( + default_factory=dict + ) # Grab bag for extra data. Not currently referenced in any TOD core code, but a convenient leaky abstraction for passing dataset-specific data between Parser classes and realized agents/teachers. + + # Variables derived by class + api_call_utt: str = field(init=False) + api_resp_utt: str = field(init=False) + + def __post_init__(self): + self.api_call_utt = SerializationHelpers.api_dict_to_str(self.api_call_machine) + self.api_resp_utt = SerializationHelpers.api_dict_to_str(self.api_resp_machine) + if ( + len(self.api_call_machine) > 0 + and STANDARD_API_NAME_SLOT not in self.api_call_machine + ): + warn_once( + f"{STANDARD_API_NAME_SLOT} missing when API Call present. 
This may cause issues for simulation metrics." + ) + + +@dataclass +class TodStructuredEpisode: + """ + Dataclass for episode-level data. + """ + + # Variables set by those using this class + delex: bool = False # Set to true and this class will handle delexicalizing call + response utterances based on API calls and responses exposed to this class. + domain: str = "" # self-explanatory + api_schemas_machine: List[Dict[str, List]] = field( + default_factory=list + ) # Expected to be a List of Dicts with the API name, required arguments, and optional arguments (specified by consts at the top of this file) as keys + goal_calls_machine: List[Dict[str, str]] = field( + default_factory=list + ) # Machine-formatted API calls + rounds: List[TodStructuredRound] = field(default_factory=list) # self explanatory + extras: Dict = field( + default_factory=dict + ) # Grab bag for extra data. Not currently referenced in any TOD core code, but a convenient leaky abstraction for passing dataset-specific data between Parser classes and realized agents/teachers. + + # Variables derived by class + api_schemas_utt: str = field(init=False) + goal_calls_utt: str = field(init=False) + + def __post_init__(self): + self.api_schemas_utt = SerializationHelpers.list_of_maps_to_str( + self.api_schemas_machine + ) + self.goal_calls_machine = [ + call for call in self.goal_calls_machine if len(call) > 0 + ] + self.goal_calls_utt = SerializationHelpers.list_of_maps_to_str( + self.goal_calls_machine + ) + # Add a done turn at the end + self.rounds.append(TodStructuredRound(user_utt=STANDARD_DONE)) + if self.delex: + accum_slots = ( + {} + ) # separate since some slot values change as we go. Use this for delex first + cum_slots = self.get_all_slots() + for r in self.rounds: + accum_slots.update(r.api_call_machine) + accum_slots.update(r.api_resp_machine) + r.sys_utt = SerializationHelpers.delex(r.sys_utt, accum_slots) + r.sys_utt = SerializationHelpers.delex(r.sys_utt, cum_slots) + + def get_all_slots(self): + result = {} + for r in self.rounds: + result.update(r.api_call_machine) + result.update(r.api_resp_machine) + return result + + +class SerializationHelpers: + @classmethod + def delex(cls, text, slots): + delex = text + for slot, value in slots.items(): + if isinstance(value, str): + delex = delex.replace(value, f"[{slot}]") + else: + for v in value: + delex = delex.replace(v, f"[{slot}]") + return delex + + @classmethod + def inner_list_join(cls, values): + if isinstance(values, str): + return values + return ", ".join(sorted([v.strip() for v in values])) + + @classmethod + def inner_list_split(cls, s): + return s.split(", ") + + @classmethod + def maybe_inner_list_join(cls, values): + if isinstance(values, str) or isinstance(values, int): + return values + elif isinstance(values, Iterable): + return SerializationHelpers.inner_list_join(values) + else: + raise RuntimeError("invalid type of argument for maybe_inner_list_join") + + @classmethod + def api_dict_to_str(cls, apidict): + """ + Used for API Calls and Responses -> Utterance. + """ + return " ; ".join( + f"{k} = {SerializationHelpers.maybe_inner_list_join(v)}" + for k, v in sorted(apidict.items()) + ) + + @classmethod + def str_to_api_dict(cls, string): + """ + Used for API Call and Response Utterances -> Dict. 
+ """ + slot_strs = string.split(" ; ") + result = {} + for slot_str in slot_strs: + if " = " not in slot_str: + continue + name, value = slot_str.split(" = ", 1) + name = name.strip() + value = value.strip() + result[name] = value + return result + + @classmethod + def outer_list_join(cls, s): + return " | ".join(s) + + @classmethod + def outer_list_split(cls, s): + return s.split(" | ") + + @classmethod + def str_to_list_of_maps(cls, s): + return [ + SerializationHelpers.str_to_api_dict(x) + for x in SerializationHelpers.outer_list_split(s) + ] + + @classmethod + def list_of_maps_to_str(cls, list_of_maps): + return SerializationHelpers.outer_list_join( + [SerializationHelpers.api_dict_to_str(m) for m in list_of_maps] + ) + + @classmethod + def str_to_goals(cls, s): # convenience + return SerializationHelpers.str_to_list_of_maps(s) + + @classmethod + def str_to_api_schemas(cls, s): # convenience + return SerializationHelpers.str_to_list_of_maps(s) From c939174b44a8f4954a53e7327b7a66add28f7fce Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Mon, 15 Nov 2021 20:21:58 -0800 Subject: [PATCH 02/57] [Tod] Agents, teacher metrics, and tests for these See documentation block in `tod_agents.py` --- conftest.py | 1 + parlai/core/tod/teacher_metrics.py | 159 ++++ parlai/core/tod/tod_agents.py | 794 ++++++++++++++++++ parlai/core/tod/tod_test_utils/test_agents.py | 216 +++++ pytest.ini | 1 + tests/tod/test_tod_agents_and_teachers.py | 327 ++++++++ tests/tod/test_tod_teacher_metrics.py | 74 ++ 7 files changed, 1572 insertions(+) create mode 100644 parlai/core/tod/teacher_metrics.py create mode 100644 parlai/core/tod/tod_agents.py create mode 100644 parlai/core/tod/tod_test_utils/test_agents.py create mode 100644 tests/tod/test_tod_agents_and_teachers.py create mode 100644 tests/tod/test_tod_teacher_metrics.py diff --git a/conftest.py b/conftest.py index 7cc1e262461..8970273a460 100644 --- a/conftest.py +++ b/conftest.py @@ -67,6 +67,7 @@ def filter_tests_with_circleci(test_list): ('datatests/', 'data'), ('parlai/tasks/', 'teacher'), ('tasks/', 'tasks'), + ('tod/', 'tod'), ] diff --git a/parlai/core/tod/teacher_metrics.py b/parlai/core/tod/teacher_metrics.py new file mode 100644 index 00000000000..3fc85c7d107 --- /dev/null +++ b/parlai/core/tod/teacher_metrics.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Task Oriented Dialogue (TOD) teacher metrics. +""" +from typing import Optional, List, Dict, Any +from parlai.core.metrics import AverageMetric, BleuMetric, F1Metric, Metric, Metrics + + +class SlotMetrics(Metrics): + """ + Helper container which encapsulates standard slot metrics in task oriented learning + (jga, slot_p, slot_r, etc). + + Due to differences in dialogue representations between tasks, the input is pre- + parsed ground truth and predicted slot dictionaries. + + The 'jga+nlg' metric assumes a balanced set of JGA and NLG scores such that + 2 * Avg(JGA, NLG_BLEU) = Avg(JGA + NLG_BLEU) + The `jga+nlg` metric assumes that `NlgMetrics` is used to calculated the other side. 
+ """ + + def __init__( + self, + teacher_slots: Dict[str, str], + predicted_slots: Dict[str, str], + prefixes: Optional[List] = None, + shared: Dict[str, Any] = None, + avg_jga_nlg_bleu: bool = False, + ) -> None: + super().__init__(shared=shared) + self.prefixes = prefixes if prefixes else [] + # jga and optionally Avg(jga,nlg_bleu) + self.add_with_prefixes("jga", AverageMetric(teacher_slots == predicted_slots)) + if len(teacher_slots) > 0: + self.add_with_prefixes( + "jga_noempty", AverageMetric(teacher_slots == predicted_slots) + ) + else: + self.add_with_prefixes( + "jga_empty", AverageMetric(teacher_slots == predicted_slots) + ) + + if avg_jga_nlg_bleu: + # add one half of Avg(jga,nlg_bleu), NlgMetrics class (below) adds NLG-BLEU + self.add("jga+nlg", AverageMetric(teacher_slots == predicted_slots)) + # precision + for pred_slot_name, pred_value in predicted_slots.items(): + slot_p = AverageMetric(teacher_slots.get(pred_slot_name) == pred_value) + self.add_with_prefixes("slot_p", slot_p) + self.add_with_prefixes("slot_f1", SlotF1Metric(slot_p=slot_p)) + # recall + for teacher_slot_name, teacher_value in teacher_slots.items(): + slot_r = AverageMetric( + predicted_slots.get(teacher_slot_name) == teacher_value + ) + self.add_with_prefixes("slot_r", slot_r) + self.add_with_prefixes("slot_f1", SlotF1Metric(slot_r=slot_r)) + + def add_with_prefixes(self, name, value): + self.add(name, value) + for prefix in self.prefixes: + self.add(f"{prefix}/{name}", value) + + +class NlgMetrics(Metrics): + """ + Helper container for generation version of standard metrics (F1, BLEU, ..). + """ + + def __init__( + self, + guess: str, + labels: Optional[List[str]], + prefixes: Optional[List[str]] = None, + shared: Dict[str, Any] = None, + avg_jga_nlg_bleu: bool = False, + ) -> None: + super().__init__(shared=shared) + self.prefixes = prefixes if prefixes else [] + bleu = BleuMetric.compute(guess, labels) + f1 = F1Metric.compute(guess, labels) + self.add_with_prefixes("nlg_bleu", bleu) + self.add_with_prefixes("nlg_f1", f1) + if avg_jga_nlg_bleu: + # add one half of Avg(jga,nlg_bleu), SlotMetrics class (above) adds JGA + self.add("jga+nlg", bleu) + + def add_with_prefixes(self, name, value): + self.add(name, value) + for prefix in self.prefixes: + self.add(f"{prefix}/{name}", value) + + +AverageType = Optional[AverageMetric] + + +def _average_type_sum_helper(first: AverageType, second: AverageType) -> AverageType: + """ + Helper to deal with Nones. + + We are "clever" in how we aggregate SlotF1Metrics (See SlotMetrics `__init__`) in + that we add precision and recall values separately, but this means we need to handle + None. + """ + if first is None: + return second + if second is None: + return first + return first + second + + +class SlotF1Metric(Metric): + """ + Metric to keep track of slot F1. + + Keeps track of slot precision and slot recall as running metrics. + """ + + __slots__ = ("_slot_p", "_slot_r") + + @property + def macro_average(self) -> bool: + """ + Indicates whether this metric should be macro-averaged when globally reported. 
+ """ + return True + + def __init__(self, slot_p: AverageType = None, slot_r: AverageType = None): + if not isinstance(slot_p, AverageMetric) and slot_p is not None: + slot_p = AverageMetric(slot_p) + if not isinstance(slot_r, AverageMetric) and slot_r is not None: + slot_r = AverageMetric(slot_r) + self._slot_p = slot_p + self._slot_r = slot_r + + def __add__(self, other: Optional["SlotF1Metric"]) -> "SlotF1Metric": + # NOTE: hinting can be cleaned up with "from __future__ import annotations" when + # we drop Python 3.6 + if other is None: + return self + slot_p = _average_type_sum_helper(self._slot_p, other._slot_p) + slot_r = _average_type_sum_helper(self._slot_r, other._slot_r) + return type(self)(slot_p=slot_p, slot_r=slot_r) + + def value(self) -> float: + if self._slot_p is None or self._slot_r is None: + return float("nan") + else: + slot_p = self._slot_p.value() + slot_r = self._slot_r.value() + if slot_p == 0.0 and slot_r == 0.0: + return float("nan") + else: + return 2 * (slot_p * slot_r) / (slot_p + slot_r) diff --git a/parlai/core/tod/tod_agents.py b/parlai/core/tod/tod_agents.py new file mode 100644 index 00000000000..f22a8330760 --- /dev/null +++ b/parlai/core/tod/tod_agents.py @@ -0,0 +1,794 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Agents (used for dumping data) and Teachers (for training models) related to the TOD +conversation setup. + +# Usage + +For a given dataset, extend `TodStructuredDataParser` and implement `generate_episodes()` and `get_id_task_prefix()`. The former of these is expected to do the data processing to convert a dataset to `List[TodStructuredEpisode]`. From here, multiple inheritance can be used to define Agents and Teachers that utilize the data. + +For example, given a `class XX_DataParser(TodStructuredDataParser)`, `class XX_UserSimulatorTeacher(XX_DataParser, TodUserSimulatorTeacher)` would be how one would define a teacher that generates training data for a User Simulator model. + +Once the relevant agents have been created (or relevant models have been fine-tuned), see `parlai.scripts.tod_world_script` for usage in generating simulations. + +As a convention, agents and teachers that are inheritable are prefixed with "Tod" whereas those that can be used as-is are not. Similarly, classes and functions that do not need to be exposed outside of this file are prefixed with a single underscore ('_'). + +## Why we do this +These files aid in consistency between Teachers and Agents for simulation. Rather than having to align multiple different agents to be consistent about assuptions about data formatting, tokens, spacing, etc, we do this once (via converting everything to `TodStructuredEpisode`) and let the code handle the rest. + +# Description of Agents + Teachers useful for Simulation +## Teachers for training (generative) models + * TodSystemTeacher + * TodUserSimulatorTeacher + +## Agents for Grounding +For goal grounding for the User for simulation: + * TodGoalAgent + * TodSingleGoalAgent + +For (optional) API schema grounding for the System: + * TodApiSchemaAgent (must be used with `TodGoalAgent` only) + * TodSingleApiSchemaAgent (must be used with `TodSingleGoalAgent` only) + * EmptyApiSchemaAgent + * Used for simulations where the expectation is `no schema`, ie, evaluation simulations. 
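+
+Like the teachers, the grounding agents above are attached to a given dataset by mixing
+them into that dataset's parser. A hypothetical sketch, following the `XX_` naming
+convention from the Usage section (class names here are illustrative only):
+
+    class XX_GoalAgent(XX_DataParser, TodGoalAgent):
+        pass
+
+    class XX_ApiSchemaAgent(XX_DataParser, TodApiSchemaAgent):
+        pass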
+ +## Agents for mocking APIs: + * StandaloneApiAgent + * Assumed to be provided a .pickle file 'trained' by `TodStandaloneApiTeacher` + +# Agents for dumping data from a ground truth dataset +The following are for extracting TOD World metrics from a ground truth dataset. These are generally used sparingly and only for calculating baselines. + * TodApiCallAndSysUttAgent + * TodApiResponseAgent + * TodUserUttAgent + +For this metrics extraction, `TodGoalAgent` and `TodApiSchemaAgent` should be used. + +# Other agents +There is a `EmptyGoalAgent` for use in human-human conversations where a goal is unnecessary. +""" + +from parlai.core.agents import Agent +from parlai.core.message import Message +from parlai.core.metrics import AverageMetric +from parlai.core.params import ParlaiParser +from parlai.core.opt import Opt +from parlai.core.teachers import DialogTeacher +from parlai.utils.distributed import is_distributed, get_rank, num_workers + +import parlai.core.tod.tod_core as tod +from parlai.core.tod.tod_core import SerializationHelpers +from parlai.core.tod.teacher_metrics import SlotMetrics, NlgMetrics + +from typing import Optional, List +import json +import pickle +import difflib +import random +from math import ceil + + +######### Agents that dump information from a dataset; base classes +class TodStructuredDataParser(Agent): + """ + Base class that specifies intermediate representations for Tod conversations. + """ + + @classmethod + def add_cmdline_args( + cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None + ) -> ParlaiParser: + if hasattr(super(), "add_cmdline_args"): + parser = super().add_cmdline_args(parser, partial_opt) + group = parser.add_argument_group("TOD StructuredData agent") + group.add_argument( + "--episodes-randomization-seed", + type=int, + default=-1, + help="Randomize episodes in a predictable way (eg, for few shot). Set to -1 for no randomization. ", + ) + parser.add_argument( + "--n-shot", + default=-1, + type=int, + help="Number of dialogues to keep for each of train/valid/test. -1 means all. Dialogues of lower numbers are strict subsets of larger numbers. Do not use in conjunction with `--percent-shot`. Use `--episodes-randomization-seed` to change seed. NOTE: Beware of using this flag when multitasking as this will apply to *all* datasets unless the ':' syntax for specifying per-dataset flags is used.", + ) + parser.add_argument( + "--percent-shot", + default=-1, + type=float, + help="Percentage of dialogues to keep for each of train/valid/test. -1 means all. Dialogues of lower numbers are strict subsets of larger numbers. Do not use in conjunction with `--n-shot`. Use `--episodes-randomization-seed` to change seed. NOTE: Beware of using this flag when multitasking as this will apply to *all* datasets unless the ':' syntax for specifying per-dataset flags is used.", + ) + return parser + + def __init__(self, opt: Opt, shared=None): + super().__init__(opt, shared) + self.id = self.get_id_task_prefix() + "_" + self._get_agent_type_suffix() + if shared is None: + self.episodes = self.generate_episodes() + else: + self.episodes = shared["episodes"] + + def share(self): + share = super().share() + share["episodes"] = self.episodes + return share + + def setup_episodes(self, fold: str) -> List[tod.TodStructuredEpisode]: + """ + Fold here is a data fold. + """ + raise NotImplementedError( + "Must have method for generating an episode. 
Must be set in downstream Parser for a given task" + ) + + def generate_episodes(self) -> List[tod.TodStructuredEpisode]: + if self.opt.get("n_shot", -1) >= 0 and self.opt.get("percent_shot", -1) >= 0: + # Validate before spending a while to load eeverything + raise RuntimeError("Both `--n-shot` and `--percent-shot` in use!") + episodes = list(self.setup_episodes(self.fold)) + if self.opt.get("episodes_randomization_seed", -1) != -1: + random.Random(self.opt["episodes_randomization_seed"]).shuffle(episodes) + if self.opt.get("n_shot", -1) != -1: + episodes = episodes[: self.opt["n_shot"]] + elif self.opt.get("percent_shot", -1) >= 0: + episodes = episodes[: int(len(episodes) * self.opt["percent_shot"])] + return episodes + + def get_id_task_prefix(self) -> str: + """ + Convenience for setting IDs. + """ + raise NotImplementedError( + "Must set ID prefix in downstream task agent. Must be set in downsream Parser for a given task" + ) + + def _get_agent_type_suffix(self) -> str: + """ + Convenience for setting IDs. + """ + raise NotImplementedError( + "Must set in downstream agent within `tod_agents`. If you see this error, something is wrong with TOD Infrastructure" + ) + + +######### Agents that dump information from a dataset as gold (explicitly should *not* be used with teachers) +class _TodDataDumpAgent(TodStructuredDataParser): + """ + For agents which dump data from some dataset, without training/other modifications. + + Implements an "epoch done" + + Member variables assumed to be set in init downstream: + self.fold + """ + + def __init__(self, opt: Opt, shared=None): + super().__init__(opt, shared) + self.epochDone = False + self.batchsize = opt.get("batchsize", 1) + self.max_episodes = len(self.episodes) + if opt.get("num_episodes", 0) > 0: + self.max_episodes = min(self.max_episodes, opt.get("num_episodes")) + self.episode_idx = opt.get("batchindex", 0) + self._setup_next_episode() + self.round_idx = 0 # for some downstream utt + sysUttAndApiCallAgents. + if is_distributed(): # cause gotta manually handle + rank = get_rank() + chunk_size = ceil(self.max_episodes / num_workers()) + self.episode_idx += rank * chunk_size + self.max_episodes = min(self.max_episodes, (rank + 1) * chunk_size) + + def _setup_next_episode(self): + self.epochDone = not self.episode_idx < self.max_episodes + self.episode = None + if not self.epochDone: + self.episode = self.episodes[self.episode_idx] + self.round_idx = ( + 0 # so downstream agents know which round they are in. Update in `act()` + ) + + def epoch_done(self) -> bool: + return self.epochDone + + def episode_done(self) -> bool: + """ + This is not actually "episode_done" so much as "we want to signify to the world + that we have gone past the batch". + + This class should not control whether or not the episode is actually done since + the TodWorld expects that to come from the User agent. + """ + return self.epochDone + + def num_episodes(self) -> int: + return len(self.episodes) + + def reset(self): + self.episode_idx += self.batchsize + self._setup_next_episode() + + +class TodGoalAgent(_TodDataDumpAgent): + """ + Use as a mixin with classes that also extend + implement TodStructuredDataParser. 
+ """ + + def act(self): + return { + "text": f"{tod.STANDARD_GOAL}{self.episode.goal_calls_utt}", + "id": self.id, + "domain": self.episode.domain, + "episode_done": False, + } + + def _get_agent_type_suffix(self): + return "Goal" + + +class TodApiSchemaAgent(_TodDataDumpAgent): + def act(self): + return { + "text": f"{tod.STANDARD_API_SCHEMAS}{self.episode.api_schemas_utt}", + "id": self.id, + "domain": self.episode.domain, + "episode_done": False, + } + + def _get_agent_type_suffix(self): + return "ApiSchema" + + +############# Single Goal + Api Schema Agent +class _EpisodeToSingleGoalProcessor(_TodDataDumpAgent): + """ + Iterate through all of the goals of a dataset, one by one. + + Slightly different logic than the dump agent since how we count + setup examples for + an episode are different + + Used as a mixin in the SingleGoal and SingleApiSchema agents below. + + This class exposes a `filter_goals()` function that can be overridden by downstream agents. + """ + + def __init__(self, opt: Opt, shared=None): + super().__init__(opt, shared) + self.epochDone = False + if shared is None: + self.episodes = self._setup_single_goal_episodes() + else: + # Handled fine in _TodDataDumpAgent + pass + + self.max_episodes = len(self.episodes) + if opt.get("num_episodes", 0) > 0: + self.max_episodes = min(self.max_episodes, opt.get("num_episodes")) + if is_distributed(): # cause gotta manually handle + rank = get_rank() + chunk_size = ceil(self.max_episodes / num_workers()) + self.max_episodes = min(self.max_episodes, (rank + 1) * chunk_size) + + self._setup_next_episode() + + def _setup_single_goal_episodes(self) -> List[tod.TodStructuredEpisode]: + """ + This function assumes that `self.setup_episodes()` has already been called + prior. + + Based on the `__init__` order of this class, it should be done in + `TodStructuredDataParser` by this point. + """ + raw_episodes = self.episodes + result = [] + for raw in raw_episodes: + for call in self.filter_goals(raw.goal_calls_machine): + schema = {} + for cand in raw.api_schemas_machine: + if ( + cand[tod.STANDARD_API_NAME_SLOT] + == call[tod.STANDARD_API_NAME_SLOT] + ): + schema = cand + + result.append( + tod.TodStructuredEpisode( + domain=raw.domain, + api_schemas_machine=[schema], + goal_calls_machine=[call], + rounds=[], + ) + ) + return result + + def filter_goals(self, goals): + """ + Some downstream agents may want to filter the goals. + + Override this if so. + """ + return goals + + +class TodSingleGoalAgent(_EpisodeToSingleGoalProcessor, TodGoalAgent): + """ + Use as a mixin with classes that also extend + implement TodStructuredDataParser. + + NOTE: If an API schema agent is used, this *must* be used with `TodSingleApiSchemaAgent` since it will be nonsensicle otherwise. Additionally, this agent will not function properly with UserUtt + SystemUttAndApiCall agent, since episodes will not align. + """ + + def _get_agent_type_suffix(self): + return "SingleGoal" + + +class TodSingleApiSchemaAgent(_EpisodeToSingleGoalProcessor, TodApiSchemaAgent): + """ + Use as a mixin with classes that also extend + implement TodStructuredDataParser. + + NOTE: Must be used with TodSingleGoalAgent since nonsensicle otherwise. Additionally, this agent will not function properly with UserUtt + SystemUttAndApiCall agent, since episodes will not align. + """ + + def _get_agent_type_suffix(self): + return "SingleApiSchema" + + +###### Agents used for calculating TOD World Metrics based on a dataset. 
See `tod_world_script` or `parlai/projects/tod_simulator/` for examples. +class TodUserUttAgent(_TodDataDumpAgent): + """ + Agent used to calculate TOD World Metrics on a dataset. Represents the "User" agent. + + This class should only ever be used with the model-model chat world which will stop + upon seeing the '[DONE]' utterance; may go out of bounds otherwise. + """ + + def act(self): + result = { + "text": f"{tod.STANDARD_USER_UTTERANCE}{self.episode.rounds[self.round_idx].user_utt}", + "id": self.id, + "domain": self.episode.domain, + "episode_done": False, + } + self.round_idx += 1 + return result + + def reset(self): + super().reset() # setup next episode + self.round_idx = 0 + + def _get_agent_type_suffix(self): + return "User" + + +class TodApiCallAndSysUttAgent(_TodDataDumpAgent): + """ + Agent used to calculate TOD World Metrics on a dataset. Represents the "System" + agent. + + This class should only ever be used with the model-model chat world which will stop + upon seeing the '[DONE]' utterance; may go out of bounds otherwise. + """ + + def __init__(self, opt: Opt, shared=None): + # This class represents two "agents" so need to make sure we don't increment episode number (reset) twice + self.already_reset = False + self.api_call_turn = True + super().__init__(opt, shared) + + def act(self): + self.already_reset = False + if tod.STANDARD_API_SCHEMAS in self.observation.get("text", ""): + return { + "text": tod.STANDARD_API_SCHEMAS, + "id": self.id, + "domain": self.episode.domain, + "episode_down": False, + } + + if self.api_call_turn: # comes first, don't iterate round # + result = { + "text": f"{tod.STANDARD_CALL}{self.episode.rounds[self.round_idx].api_call_utt}", + "id": self.id, + "domain": self.episode.domain, + "episode_done": False, + } + else: + result = { + "text": f"{tod.STANDARD_SYSTEM_UTTERANCE}{self.episode.rounds[self.round_idx].sys_utt}", + "id": self.id, + "domain": self.episode.domain, + "episode_done": False, + } + self.round_idx += 1 + + self.api_call_turn ^= True + return result + + def reset(self): + if not self.already_reset: + super().reset() # setup next episode + self.api_call_turn = True + self.already_reset = True + + def _get_agent_type_suffix(self): + return "System" + + +class TodApiResponseAgent(_TodDataDumpAgent): + """ + Agent used to calculate TOD World Metrics on a dataset. Represents the API + Simulator. + + This class should only ever be used with the model-model chat world which will stop + upon seeing the '[DONE]' utterance; may go out of bounds otherwise. + """ + + def act(self): + result = { + "text": f"{tod.STANDARD_RESP}{self.episode.rounds[self.round_idx].api_resp_utt}", + "id": self.id, + "domain": self.episode.domain, + "episode_done": False, + } + self.round_idx += 1 + return result + + def reset(self): + super().reset() # setup next episode + self.round_idx = 0 + + def _get_agent_type_suffix(self): + return "ApiResponse" + + +###### Standalone API agent +class StandaloneApiAgent(Agent): + """ + Trainable agent that saves API calls and responses. + + Use `TodStandaloneApiTeacher` to train this class. For example for a MultiWoz V2.2 + standalone API, use ``` parlai train -t multiwoz_v22:StandaloneApiTeacher -m + parlai_fb.agents.tod.agents:StandaloneApiAgent -eps 4 -mf output ``` to generate the + `.pickle` file to use. 
+ """ + + EMPTY_RESP = { + "text": tod.STANDARD_RESP, + "id": "StandaloneApiAgent", + "episode_done": False, + } + + @classmethod + def add_cmdline_args( + cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None + ) -> ParlaiParser: + group = parser.add_argument_group("TOD Standalone API args") + group.add_argument( + "--exact-api-call", + type=bool, + default=True, + help="Validation-time flag. If true, will return '' if exact api call values not found. If false, will pick response from the same intent with similar api parameters (assuming intent is the same when available)", + ) + + group.add_argument( + "--fail-hard", + type=bool, + default=False, + help="Aids in deugging. Will throw exception if API call not found and '--exact-api-call' is set.", + ) + + group.add_argument( + "--standalone-api-file", + type=str, + default=None, + help="Path to file holding `.pickle` of standalone api for validation (will intelligently strip if suffix included). If not set, assumes the `model_file` argument will contain the `.pickle` file. ", + ) + return parser + + def __init__(self, opt, shared=None): + super().__init__(opt, shared) + self.id = "StandaloneApiAgent" + file_key = "model_file" + if self.opt["standalone_api_file"] is not None: + file_key = "standalone_api_file" + self.path_base = self.opt[file_key].replace(".pickle", "") + self.db_path = self.path_base + ".pickle" + self.exact_api_call = self.opt["exact_api_call"] + try: + with (open(self.db_path, "rb")) as openfile: + self.data = pickle.load(openfile) + self.training = True + print("Loaded Standalone API data successfully") + if self.exact_api_call != self.data.get("exact_api_call", True): + raise RuntimeError( + f"Standalone API .pickle file generated with `exact_api_call` of {self.data.get('exact_api_call', False)} but StandaloneApiAgent sets it to {self.exact_api_call}" + ) + except Exception: + print(f"No file at {self.db_path}; ASSUMING WE ARE TRAINING") + self.data = {} + self.data["exact_api_call"] = self.exact_api_call + self.training = True + + def _maybe_filter_prefix(self, text, prefix): + if prefix in text: + return text[len(prefix) :].strip() + return text.strip() + + def act(self): + if not self.observation["text"].startswith(tod.STANDARD_CALL): + return self.EMPTY_RESP + call_text_raw = self.observation["text"] + # decode then reencode the API call so that we get the API calls in a consistent order + call_text = SerializationHelpers.api_dict_to_str( + SerializationHelpers.str_to_api_dict( + call_text_raw[len(tod.STANDARD_CALL) :] + ) + ) + if "labels" in self.observation: + return self._do_train(call_text) + return self._do_fetch(call_text) + + def _do_train(self, call_text): + assert self.training is True + self.data[call_text] = self.observation["labels"][0] + return self.EMPTY_RESP + + def _do_fetch(self, call_text): + if self.exact_api_call: + if self.opt.get("fail_hard", False): + resp = self.data[call_text] + else: + resp = self.data.get(call_text, tod.STANDARD_RESP) + return { + "text": resp, + "id": self.id, + "episode_done": False, + } + + # Not exact case + best_key = difflib.get_close_matches(call_text, self.data.keys(), 1) + if len(best_key) == 0: + return self.EMPTY_RESP + return { + "text": self.data.get(best_key[0], tod.STANDARD_RESP), + "id": self.id, + "episode_done": False, + } + + def shutdown(self): + if self.training: + with (open(self.db_path, "wb")) as openfile: + pickle.dump(self.data, openfile) + print(f"Dumped output to {self.db_path}") + with open(self.path_base + ".opt", "w") as f: + 
json.dump(self.opt, f) + + +######### Empty agents +class EmptyApiSchemaAgent(Agent): + def __init__(self, opt, shared=None): + super().__init__(opt) + self.id = "EmptyApiSchemaAgent" + + def act(self): + msg = { + "id": self.getID(), + "text": tod.STANDARD_API_SCHEMAS, + "episode_done": False, + } + return Message(msg) + + +class EmptyGoalAgent(Agent): + def __init__(self, opt, shared=None): + super().__init__(opt) + self.id = "EmptyGoalAgent" + + def act(self): + msg = {"id": self.getID(), "text": tod.STANDARD_GOAL, "episode_done": False} + return Message(msg) + + +############# Teachers +class TodSystemTeacher(TodStructuredDataParser, DialogTeacher): + """ + TOD agent teacher which produces both API calls and NLG responses. + + First turn is API Schema grounding, which may be a an empty schema. + Subsequent turns alternate between + 1. User utterance -> API Call + 2. API Response -> System Utterance + """ + + @classmethod + def add_cmdline_args( + cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None + ) -> ParlaiParser: + parser = super().add_cmdline_args(parser, partial_opt) + parser.add_argument( + "--api-schemas", + type="bool", + default=False, + help="Preempt first turn with intents + required/optional parameters as key/value for given domain", + ) + parser.add_argument( + "--api-jga-record", + type=bool, + default=True, + help="Should we save jga information per api schema?", + ) + parser.add_argument( + "--domain-jga-record", + type=bool, + default=False, + help="Should we save jga information per domain?", + ) + parser.add_argument( + "--domain-nlg-record", + type=bool, + default=False, + help="Should we save nlg information per domain?", + ) + return parser + + def custom_evaluation( + self, teacher_action: Message, labels, model_response: Message + ): + resp = model_response.get("text") + if not resp: + return + if teacher_action["type"] == tod.STANDARD_CALL: + if resp.startswith(tod.STANDARD_CALL): + resp = resp[len(tod.STANDARD_CALL) :] + predicted = SerializationHelpers.str_to_api_dict(resp) + domains = ( + [teacher_action["domain"]] if self.opt["domain_jga_record"] else [] + ) + + metrics = SlotMetrics( + teacher_slots=teacher_action["slots"], + predicted_slots=predicted, + avg_jga_nlg_bleu=True, + prefixes=domains, + ).report() + for key, value in metrics.items(): + self.metrics.add(key, value) + + if self.opt["api_jga_record"] and len(teacher_action["slots"]) > 0: + teacher = teacher_action["slots"] + slots = list(teacher.keys()) + slots.remove(tod.STANDARD_API_NAME_SLOT) + api_here = ( + "api-" + + teacher[tod.STANDARD_API_NAME_SLOT] + + "--" + + "-".join(slots) + ) + self.metrics.add(f"{api_here}/jga", AverageMetric(teacher == predicted)) + + elif teacher_action["type"] == tod.STANDARD_SYSTEM_UTTERANCE: + domains = ( + [teacher_action["domain"]] if self.opt["domain_nlg_record"] else [] + ) + metrics = NlgMetrics( + guess=resp, + labels=labels, + prefixes=domains, + avg_jga_nlg_bleu=True, + ).report() + for key, value in metrics.items(): + self.metrics.add(key, value) + + def setup_data(self, fold): + for episode in self.generate_episodes(): + if self.opt.get("api_schemas"): + schemas = episode.api_schemas_utt + else: + schemas = "" + yield { + "text": f"{tod.STANDARD_API_SCHEMAS}{schemas}", + "label": f"{tod.STANDARD_API_SCHEMAS}", + "domain": episode.domain, + "type": tod.STANDARD_API_SCHEMAS, + "slots": {}, + }, True + for r in episode.rounds: + yield { + "text": f"{tod.STANDARD_USER_UTTERANCE}{r.user_utt}", + "label": f"{tod.STANDARD_CALL}{r.api_call_utt}", 
+ "domain": episode.domain, + "type": tod.STANDARD_CALL, + "slots": r.api_call_machine, + }, False + yield { + "text": f"{tod.STANDARD_RESP}{r.api_resp_utt}", + "label": f"{tod.STANDARD_SYSTEM_UTTERANCE}{r.sys_utt}", + "domain": episode.domain, + "slots": r.api_resp_machine, + "type": tod.STANDARD_SYSTEM_UTTERANCE, + }, False + + def _get_agent_type_suffix(self): + return "SystemTeacher" + + +class TodUserSimulatorTeacher(TodStructuredDataParser, DialogTeacher): + """ + Teacher that has `Goal->User Utterance` for its first turn, then `System + Utterance->User Utterance` for all subsequent turns. + """ + + def setup_data(self, fold): + for episode in self.generate_episodes(): + if len(episode.rounds) < 1: + continue + yield { + "text": f"{tod.STANDARD_GOAL}{episode.goal_calls_utt}", + "label": f"{tod.STANDARD_USER_UTTERANCE}{episode.rounds[0].user_utt}", + "domain": episode.domain, + "type": tod.STANDARD_USER_UTTERANCE, + }, True + for i, r in enumerate(episode.rounds): + if i == len(episode.rounds) - 1: + continue + yield { + "text": f"{tod.STANDARD_SYSTEM_UTTERANCE}{r.sys_utt}", + "label": f"{tod.STANDARD_USER_UTTERANCE}{episode.rounds[i+1].user_utt}", + "domain": episode.domain, + "type": tod.STANDARD_USER_UTTERANCE, + "slots": {}, # slots in agent/user turns are meaningless + }, False + + def custom_evaluation( + self, teacher_action: Message, labels, model_response: Message + ): + resp = model_response.get("text") + if not resp: + return + if teacher_action["type"] == tod.STANDARD_RESP: + if resp.startswith(tod.STANDARD_RESP): + resp = resp[len(tod.STANDARD_RESP) :] + predicted = SerializationHelpers.str_to_api_dict(resp) + + metrics = SlotMetrics(teacher_action["slots"], predicted).report() + for key, value in metrics.items(): + self.metrics.add(key, value) + + elif teacher_action["type"] == tod.STANDARD_USER_UTTERANCE: + metrics = NlgMetrics(resp, labels).report() + for key, value in metrics.items(): + self.metrics.add(key, value) + + def _get_agent_type_suffix(self): + return "UserSimulatorTeacher" + + +class TodStandaloneApiTeacher(TodStructuredDataParser, DialogTeacher): + """ + Use this to generate a database for `StandaloneApiAgent`. + + Set this as the teacher with `StandaloneApiAgent` as the agent. Ex for a MultiWoz + V2.2 standalone API, use ``` parlai train -t multiwoz_v22:StandaloneApiTeacher -m + parlai_fb.agents.tod.agents:StandaloneApiAgent -eps 4 -mf output ``` + """ + + def setup_data(self, fold): + # As a default, just put everything in + for fold_overwrite in ["train", "valid", "test"]: + for episode in self.setup_episodes(fold_overwrite): + first = True + for r in episode.rounds: + if len(r.api_call_machine) > 0: + yield { + "text": f"{tod.STANDARD_CALL}{r.api_call_utt}", + "label": f"{tod.STANDARD_RESP}{r.api_resp_utt}", + "id": self.id, + "domain": episode.domain, + }, first + first = False + + def _get_agent_type_suffix(self): + return "StandaloneApiTeacher" diff --git a/parlai/core/tod/tod_test_utils/test_agents.py b/parlai/core/tod/tod_test_utils/test_agents.py new file mode 100644 index 00000000000..b1339052764 --- /dev/null +++ b/parlai/core/tod/tod_test_utils/test_agents.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Helpers so we don't need to create agents all over. 
+""" + +import parlai.core.tod.tod_agents as tod_agents +import parlai.core.tod.tod_core as tod_core + +import os + +API_DATABASE_FILE = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "standalone_api_file.pickle" +) + + +def episode_has_broken_api_turn(episode_idx, max_turns): + return episode_idx % 2 == 1 and max_turns > 0 + + +def use_broken_api_calls_this_turn(round_idx, episode_idx): + return episode_idx % 2 == 1 and round_idx % 3 == 1 + + +def make_api_call_machine(round_idx, episode_idx=0, use_broken_mock_api_calls=False): + if round_idx == 0: + return {} + if use_broken_mock_api_calls: + # Hack as a way to test metrics reporting in tod world script + if use_broken_api_calls_this_turn(round_idx, episode_idx): + round_idx = -1 * round_idx + return {tod_core.STANDARD_API_NAME_SLOT: f"name_{round_idx}", "in": round_idx} + + +def make_api_resp_machine(round_idx): + if round_idx == 0: + return {} + return {"out": round_idx} + + +def make_api_schemas_machine(max_rounds): + return [ + { + tod_core.STANDARD_API_NAME_SLOT: f"name_{round_idx}", + tod_core.STANDARD_REQUIRED_KEY: ["in"], + tod_core.STANDARD_OPTIONAL_KEY: [], + } + for round_idx in range(1, max_rounds) + ] + + +def make_goal_calls_machine(max_rounds): + return [make_api_call_machine(x) for x in range(1, max_rounds)] + + +def get_rounds(episode_idx, max_rounds, use_broken_mock_api_calls=False): + return [ + tod_core.TodStructuredRound( + user_utt=f"user_utt_{episode_idx}_{round_idx}", + api_call_machine=make_api_call_machine( + round_idx, episode_idx, use_broken_mock_api_calls + ), + api_resp_machine=make_api_resp_machine(round_idx), + sys_utt=f"sys_utt_{episode_idx}_{round_idx}", + ) + for round_idx in range(max_rounds) + ] + + +def get_round_utts(episode_idx, max_rounds, filter_utts=None): + if max_rounds < 1: + return [] + utts = [ + [ + f"USER: user_utt_{episode_idx}_0", + "APICALL: ", + "APIRESP: ", + f"SYSTEM: sys_utt_{episode_idx}_0", + ] + ] + for i in range(1, max_rounds): + utts.append( + [ + f"USER: user_utt_{episode_idx}_{i}", + f"APICALL: api_name = name_{i} ; in = {i}", + f"APIRESP: out = {i}", + f"SYSTEM: sys_utt_{episode_idx}_{i}", + ] + ) + utts.append( + [ + "USER: [DONE]", + "APICALL: ", + "APIRESP: ", + "SYSTEM: ", + ] + ) + if filter_utts is not None: + utts = [ + [turn for i, turn in enumerate(round_data) if filter_utts[i]] + for round_data in utts + ] + return utts + + +TEST_NUM_EPISODES_OPT_KEY = "test_num_episodes" +TEST_NUM_ROUNDS_OPT_KEY = "test_num_rounds" + +# No api calls in this setup +EPISODE_SETUP__UTTERANCES_ONLY = { + TEST_NUM_ROUNDS_OPT_KEY: 1, + TEST_NUM_EPISODES_OPT_KEY: 1, +} + +# No one call, one goal, one api desscription in this setup +EPISODE_SETUP__SINGLE_API_CALL = { + TEST_NUM_ROUNDS_OPT_KEY: 2, + TEST_NUM_EPISODES_OPT_KEY: 1, +} +# Will start testing multiple api calls + schemas, multi-round logic +EPISODE_SETUP__MULTI_ROUND = {TEST_NUM_ROUNDS_OPT_KEY: 5, TEST_NUM_EPISODES_OPT_KEY: 1} + +# Test that episode logic is correct +EPISODE_SETUP__MULTI_EPISODE = { + TEST_NUM_ROUNDS_OPT_KEY: 5, + TEST_NUM_EPISODES_OPT_KEY: 8, +} + +# Test that episode + pesky-off-by-one batchinglogic is correct +EPISODE_SETUP__MULTI_EPISODE_BS = { + TEST_NUM_ROUNDS_OPT_KEY: 5, + TEST_NUM_EPISODES_OPT_KEY: 35, +} + + +class TestDataParser(tod_agents.TodStructuredDataParser): + """ + Assume that when we init, we init w/ num of episodes + rounds as opts. 
+ """ + + def __init__(self, opt, shared=None): + opt["datafile"] = "DUMMY" + self.fold = "DUMMY" + # Following lines are only reelvant in training the standalone api teacher + if TEST_NUM_EPISODES_OPT_KEY not in opt: + opt[TEST_NUM_EPISODES_OPT_KEY] = 35 + if TEST_NUM_ROUNDS_OPT_KEY not in opt: + opt[TEST_NUM_ROUNDS_OPT_KEY] = 5 + super().__init__(opt, shared) + + def setup_episodes(self, _): + result = [] + for ep_idx in range(0, self.opt[TEST_NUM_EPISODES_OPT_KEY]): + result.append( + tod_core.TodStructuredEpisode( + goal_calls_machine=[ + make_api_call_machine(x) + for x in range(1, self.opt[TEST_NUM_ROUNDS_OPT_KEY]) + ], + api_schemas_machine=make_api_schemas_machine( + self.opt[TEST_NUM_ROUNDS_OPT_KEY] + ), + rounds=get_rounds( + ep_idx, + self.opt[TEST_NUM_ROUNDS_OPT_KEY], + self.opt.get("use_broken_mock_api_calls", False), + ), + ) + ) + # print(result, self.opt) + return result + + def get_id_task_prefix(self): + return "Test" + + +class SystemTeacher(TestDataParser, tod_agents.TodSystemTeacher): + pass + + +class UserSimulatorTeacher(TestDataParser, tod_agents.TodUserSimulatorTeacher): + pass + + +class StandaloneApiTeacher(TestDataParser, tod_agents.TodStandaloneApiTeacher): + pass + + +class GoalAgent(TestDataParser, tod_agents.TodGoalAgent): + pass + + +class ApiSchemaAgent(TestDataParser, tod_agents.TodApiSchemaAgent): + pass + + +class SingleGoalAgent(TestDataParser, tod_agents.TodSingleGoalAgent): + pass + + +class SingleApiSchemaAgent(TestDataParser, tod_agents.TodSingleApiSchemaAgent): + pass + + +# Tested in tod world code +class UserUttAgent(TestDataParser, tod_agents.TodUserUttAgent): + pass + + +# Tested in tod world code +class ApiCallAndSysUttAgent(TestDataParser, tod_agents.TodApiCallAndSysUttAgent): + pass diff --git a/pytest.ini b/pytest.ini index 9aec92c0a89..d4095288194 100644 --- a/pytest.ini +++ b/pytest.ini @@ -12,3 +12,4 @@ markers = unit internal nofbcode + tod diff --git a/tests/tod/test_tod_agents_and_teachers.py b/tests/tod/test_tod_agents_and_teachers.py new file mode 100644 index 00000000000..5383c72416d --- /dev/null +++ b/tests/tod/test_tod_agents_and_teachers.py @@ -0,0 +1,327 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests different (more complicated) slot metrics. 
+""" + +import unittest + +import copy +import parlai.core.tod.tod_core as tod_core +import parlai.core.tod.tod_test_utils.test_agents as test_agents + + +class TestTodAgentsAndTeachersBase(unittest.TestCase): + def setup_agent_or_teacher(self, class_type, round_opt, opt): + full_opts = {**round_opt, **opt} + full_opts["datatype"] = "DUMMY" + full_opts["datafile"] = "DUMMY" + full_opts["episodes_randomization_seed"] = -1 # no random here + return class_type(full_opts) + + def dump_single_utt_per_episode_agent_text(self, class_type, round_opt, opt): + agent = self.setup_agent_or_teacher(class_type, round_opt, opt) + result = [] + while not agent.epoch_done(): + result.append(agent.act()["text"]) + agent.reset() + return result + + def dump_teacher_text(self, class_type, round_opt, opt): + """ + Array where [episode_idx][turn_idx][text=0,label=1] + """ + teacher = self.setup_agent_or_teacher(class_type, round_opt, opt) + data = [] + here = [] + for x, new in teacher.setup_data("dummy"): + if new and len(here) > 0: + data.append(copy.deepcopy(here)) + here = [] + here.append([x["text"], x["label"]]) + if len(here) > 0: + data.append(here) + return data + + def _test_roundDataCorrect(self): + self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__UTTERANCES_ONLY) + self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__SINGLE_API_CALL) + self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__MULTI_ROUND) + self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__MULTI_EPISODE) + + +class TestSystemTeacher(TestTodAgentsAndTeachersBase): + def test_apiSchemas_with_yesApiSchemas(self): + values = self.dump_teacher_text( + test_agents.SystemTeacher, + test_agents.EPISODE_SETUP__SINGLE_API_CALL, + {"api_schemas": True}, + ) + self.assertEqual( + values[0][0][0], + "APIS: " + + tod_core.SerializationHelpers.list_of_maps_to_str( + test_agents.make_api_schemas_machine(2) + ), + ) + + def test_apiSchemas_with_noApiSchemas(self): + values = self.dump_teacher_text( + test_agents.SystemTeacher, + test_agents.EPISODE_SETUP__SINGLE_API_CALL, + {"api_schemas": False}, + ) + self.assertEqual(values[0][0][0], "APIS: ") + + def _test_roundDataCorrect_helper(self, config): + max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] + values = self.dump_teacher_text(test_agents.SystemTeacher, config, {}) + for episode_idx, episode in enumerate(values): + utts = test_agents.get_round_utts(episode_idx, max_rounds) + comp = [] + for utt in utts: + comp.append([utt[0], utt[1]]) + comp.append([utt[2], utt[3]]) + # Skip context turn cause we check it above + self.assertEqual(episode[1:], comp) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TestUserTeacher(TestTodAgentsAndTeachersBase): + def _test_roundDataCorrect_helper(self, config): + max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] + values = self.dump_teacher_text(test_agents.UserSimulatorTeacher, config, {}) + for episode_idx, episode in enumerate(values): + utts = test_agents.get_round_utts(episode_idx, max_rounds) + comp = [] + comp.append( + [ + "GOAL: " + + tod_core.SerializationHelpers.list_of_maps_to_str( + test_agents.make_goal_calls_machine(max_rounds) + ), + utts[0][0], + ] + ) + last_sys = utts[0][3] + for i in range(1, len(utts)): + comp.append([last_sys, utts[i][0]]) + last_sys = utts[i][3] + self.assertEqual(episode, comp) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TestGoalAgent(TestTodAgentsAndTeachersBase): + def 
_test_roundDataCorrect_helper(self, config): + max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] + max_episodes = config[test_agents.TEST_NUM_EPISODES_OPT_KEY] + values = self.dump_single_utt_per_episode_agent_text( + test_agents.GoalAgent, config, {} + ) + + goal_text = [ + "GOAL: " + + tod_core.SerializationHelpers.list_of_maps_to_str( + test_agents.make_goal_calls_machine(max_rounds) + ) + for _ in range(max_episodes) + ] + + self.assertEqual(values, goal_text) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TestApiSchemaAgent(TestTodAgentsAndTeachersBase): + def _test_roundDataCorrect_helper(self, config): + max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] + max_episodes = config[test_agents.TEST_NUM_EPISODES_OPT_KEY] + values = self.dump_single_utt_per_episode_agent_text( + test_agents.ApiSchemaAgent, config, {} + ) + + apis_texts = [ + "APIS: " + + tod_core.SerializationHelpers.list_of_maps_to_str( + test_agents.make_api_schemas_machine(max_rounds) + ) + for _ in range(max_episodes) + ] + + self.assertEqual(values, apis_texts) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TestSingleGoalAgent(TestTodAgentsAndTeachersBase): + def _test_roundDataCorrect_helper(self, config): + max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] + max_episodes = config[test_agents.TEST_NUM_EPISODES_OPT_KEY] + values = self.dump_single_utt_per_episode_agent_text( + test_agents.SingleGoalAgent, config, {} + ) + + goal_text = [] + for _ in range(max_episodes): + goals = test_agents.make_goal_calls_machine(max_rounds) + for x in goals: + goal_text.append( + "GOAL: " + tod_core.SerializationHelpers.list_of_maps_to_str([x]) + ) + + self.assertEqual(values, goal_text) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TestSingleApiSchemaAgent(TestTodAgentsAndTeachersBase): + def _test_roundDataCorrect_helper(self, config): + max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] + max_episodes = config[test_agents.TEST_NUM_EPISODES_OPT_KEY] + values = self.dump_single_utt_per_episode_agent_text( + test_agents.SingleApiSchemaAgent, config, {} + ) + + apis_text = [] + for _ in range(max_episodes): + apis = test_agents.make_api_schemas_machine(max_rounds) + for x in apis: + apis_text.append( + "APIS: " + tod_core.SerializationHelpers.list_of_maps_to_str([x]) + ) + self.assertEqual(values, apis_text) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TestSingleGoalWithSingleApiSchemaAgent(TestTodAgentsAndTeachersBase): + """ + Make sure the SingleGoal + SingleApiSchema agents correspond. 
+ """ + + def _test_roundDataCorrect_helper(self, config): + goals = self.dump_single_utt_per_episode_agent_text( + test_agents.SingleGoalAgent, config, {} + ) + apis = self.dump_single_utt_per_episode_agent_text( + test_agents.SingleApiSchemaAgent, config, {} + ) + + for i in range(len(goals)): + goal = tod_core.SerializationHelpers.str_to_goals(goals[i][len("GOALS:") :]) + api = tod_core.SerializationHelpers.str_to_api_schemas( + apis[i][len("APIS:") :] + ) + self.assertEqual( + goal[0].get("api_name", None), api[0].get("api_name", None) + ) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TestLowShot(TestTodAgentsAndTeachersBase): + FEW_SHOT_SAMPLES = [0, 1, 5, 15] + PERCENTAGES = [0, 0.1, 0.3, 0.5] + + def setup_agent_or_teacher(self, class_type, round_opt, opt): + full_opts = {**round_opt, **opt} + full_opts["datatype"] = "DUMMY" + full_opts["datafile"] = "DUMMY" + return class_type(full_opts) + + def test_few_shot_lengths_correct(self): + def helper(n_shot): + values = self.dump_teacher_text( + test_agents.SystemTeacher, + test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, + { + "episodes_randomization_seed": 0, + "n_shot": n_shot, + }, + ) + self.assertEqual(len(values), n_shot) + + for i in self.FEW_SHOT_SAMPLES: + helper(i) + + def _test_subsets(self, data_dumps): + for i in range(len(data_dumps) - 1): + small = data_dumps[i] + larger = data_dumps[i + 1] + for i, episode in enumerate(small): + self.assertEqual(episode, larger[i]) + + def test_few_shot_subset(self): + def helper(n_shot, seed): + return self.dump_teacher_text( + test_agents.SystemTeacher, + test_agents.EPISODE_SETUP__MULTI_EPISODE, + { + "episodes_randomization_seed": seed, + "n_shot": n_shot, + }, + ) + + data_dumps_seed_zero = [helper(i, 0) for i in self.FEW_SHOT_SAMPLES] + self._test_subsets(data_dumps_seed_zero) + data_dumps_seed_three = [helper(i, 3) for i in self.FEW_SHOT_SAMPLES] + self._test_subsets(data_dumps_seed_three) + self.assertNotEqual(data_dumps_seed_zero[-1], data_dumps_seed_three[-1]) + + def test_percent_shot_lengths_correct(self): + def helper(percent_shot, correct): + values = self.dump_teacher_text( + test_agents.SystemTeacher, + test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, # 35 episodes + { + "episodes_randomization_seed": 0, + "percent_shot": percent_shot, + }, + ) + self.assertEqual(len(values), correct) + + helper(0, 0) + helper(0.1, 3) + helper(0.3, 10) + + def test_percent_shot_subset(self): + def helper(percent_shot, seed): + return self.dump_teacher_text( + test_agents.SystemTeacher, + test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, # 35 episodes + { + "episodes_randomization_seed": seed, + "percent_shot": percent_shot, + }, + ) + + data_dumps_seed_zero = [helper(i, 0) for i in self.PERCENTAGES] + self._test_subsets(data_dumps_seed_zero) + data_dumps_seed_three = [helper(i, 3) for i in self.PERCENTAGES] + self._test_subsets(data_dumps_seed_three) + + def test_correct_throw_when_both_shots_defined(self): + self.assertRaises( + RuntimeError, + self.dump_teacher_text, + test_agents.SystemTeacher, + test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, # 35 episodes + {"episodes_randomization_seed": 0, "percent_shot": 0.3, "n_shot": 3}, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/tod/test_tod_teacher_metrics.py b/tests/tod/test_tod_teacher_metrics.py new file mode 100644 index 00000000000..aec4aba40f6 --- /dev/null +++ b/tests/tod/test_tod_teacher_metrics.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. 
and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from math import isnan + +from parlai.core.metrics import AverageMetric +from parlai.core.tod.teacher_metrics import SlotF1Metric, SlotMetrics + + +class TestSlotF1Metric(unittest.TestCase): + """ + Test SlotF1Metric. + """ + + def test_slot_f1_metric_inputs(self): + slots_p_r_and_f1 = [ + (None, None, float("nan")), + (None, AverageMetric(0.0), float("nan")), + (AverageMetric(0.0), AverageMetric(0.0), float("nan")), + (AverageMetric(1), AverageMetric(1), 1.0), + (AverageMetric(1), AverageMetric(0), 0.0), + (AverageMetric(0.25), AverageMetric(0.75), 0.375), + ] + for slot_p, slot_r, slot_f1 in slots_p_r_and_f1: + actual_slot_f1 = SlotF1Metric(slot_p=slot_p, slot_r=slot_r).value() + if isnan(slot_f1): + self.assertTrue(isnan(actual_slot_f1)) + else: + self.assertEqual(slot_f1, actual_slot_f1) + + def test_slot_f1_metric_addition(self): + a = SlotF1Metric(slot_p=1) + b = SlotF1Metric(slot_r=0) + c = SlotF1Metric(slot_p=AverageMetric(numer=2, denom=3), slot_r=1) + d = a + b + c + # Slot P should be 3/4 = 0.75; slot R should be 1/2 = 0.5 + self.assertEqual(0.6, d.value()) + + +empty_slots = {} +basic_slots = {"a": "a_val", "b": "b_val", "c": "c_val"} +partial_slots = {"a": "a_val", "other": "other_val"} + + +class TestSlotMetrics(unittest.TestCase): + def test_base_slot_metrics(self): + cases = [ + (empty_slots, empty_slots, {"jga": 1}), + ( + basic_slots, + basic_slots, + {"jga": 1, "slot_p": 1, "slot_r": 1, "slot_f1": 1}, + ), + ( + basic_slots, + partial_slots, + {"jga": 0, "slot_p": 0.5, "slot_r": float(1.0 / 3), "slot_f1": 0.4}, + ), + ] + for teacher, predicted, result in cases: + metric = SlotMetrics( + teacher_slots=teacher, + predicted_slots=predicted, + ) + for key in result: + self.assertEqual(result[key], metric.report()[key]) + + +if __name__ == "__main__": + unittest.main() From 3bf655fa156bb13928a4d72b7f612127cfeefcb2 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 09:19:16 -0800 Subject: [PATCH 03/57] [TOD] Tod json structure to teacher task As noted in the README, this agent takes data generated from `tod_world_script.py` and dumps it out to a teacher. (Note that I tried setting up a regression test for this teacher, but I ran into issues getting it to save the output directory to not be something that included my local homedir name in it..) --- parlai/tasks/tod_json/README.md | 11 ++ parlai/tasks/tod_json/__init__.py | 5 + parlai/tasks/tod_json/agents.py | 188 +++++++++++++++++++++++ parlai/tasks/tod_json/example_data.jsonl | 15 ++ 4 files changed, 219 insertions(+) create mode 100644 parlai/tasks/tod_json/README.md create mode 100644 parlai/tasks/tod_json/__init__.py create mode 100644 parlai/tasks/tod_json/agents.py create mode 100644 parlai/tasks/tod_json/example_data.jsonl diff --git a/parlai/tasks/tod_json/README.md b/parlai/tasks/tod_json/README.md new file mode 100644 index 00000000000..f96d41e7448 --- /dev/null +++ b/parlai/tasks/tod_json/README.md @@ -0,0 +1,11 @@ +# TOD Json Task Agent + +Takes a .jsonl conversation output from model-model chats from `tod_world_script.py`, puts it into the TOD intermediate conversations format so we can use it in a variety of different teachers. 
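+
+Two teachers are provided: `SystemTeacher` (also the `DefaultTeacher`), which yields
+user utterance -> API call and API response -> system utterance training pairs, and
+`UserSimulatorTeacher`, which yields goal -> first user utterance and then system
+utterance -> user utterance pairs. Since both are built on `TodStructuredDataParser`
+and the TOD teachers in `parlai/core/tod/tod_agents.py`, the standard TOD flags apply
+as well; a sketch (flag values are illustrative):
+
+```
+# Include the API schema grounding turn and subsample 10% of episodes.
+parlai dd -t tod_json:SystemTeacher --jsonfile-datapath example_data.jsonl \
+    --api-schemas true --percent-shot 0.1
+```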
+
+For example, to see the display of the data:
+```
+parlai dd -t tod_json:SystemTeacher --jsonfile-datapath example_data.jsonl
+parlai dd -t tod_json:UserSimulatorTeacher --jsonfile-datapath example_data.jsonl
+```
+
+See the file `example_data.jsonl` in this directory for the format.
diff --git a/parlai/tasks/tod_json/__init__.py b/parlai/tasks/tod_json/__init__.py
new file mode 100644
index 00000000000..240697e3247
--- /dev/null
+++ b/parlai/tasks/tod_json/__init__.py
@@ -0,0 +1,5 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/parlai/tasks/tod_json/agents.py b/parlai/tasks/tod_json/agents.py
new file mode 100644
index 00000000000..55e3514de6c
--- /dev/null
+++ b/parlai/tasks/tod_json/agents.py
@@ -0,0 +1,188 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# This task simply loads the specified file: useful for quick tests without
+# setting up a new task.
+
+from typing import Optional
+from parlai.core.params import ParlaiParser
+from parlai.core.opt import Opt
+from parlai.utils.data import DatatypeHelper
+
+import parlai.core.tod.tod_agents as tod_agents
+import parlai.core.tod.tod_core as tod
+
+import json
+import os
+
+PREFIXES = [
+    tod.STANDARD_USER_UTTERANCE,
+    tod.STANDARD_CALL,
+    tod.STANDARD_RESP,
+    tod.STANDARD_SYSTEM_UTTERANCE,
+]
+
+PREFIXES_PREEMPT = [
+    tod.STANDARD_API_SCHEMAS,
+    tod.STANDARD_API_SCHEMAS,
+    tod.STANDARD_API_SCHEMAS,
+    tod.STANDARD_GOAL,
+]
+
+
+class JsonTodParser(tod_agents.TodStructuredDataParser):
+    """
+    This module provides access to data in the TOD conversations format.
+
+    See parlai/core/tod/tod_core.py for more info about the format.
+    """
+
+    @classmethod
+    def add_cmdline_args(
+        cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
+    ) -> ParlaiParser:
+        super().add_cmdline_args(parser, partial_opt)
+        agent = parser.add_argument_group("Tod Json Task Arguments")
+        agent.add_argument(
+            "-jfdp",
+            "--jsonfile-datapath",
+            type=str,
+            help="Data file. (Assumed to be in .jsonl)",
+        )
+        agent.add_argument(
+            "-tmdp",
+            "--tod-metrics-datapath",
+            type=str,
+            default=None,
+            help="Path to a report that includes per-turn tod metrics; used to filter which examples to load",
+        )
+        agent.add_argument(
+            "-f-agh",
+            "--filter-all-goals-hit",
+            type=bool,
+            default=False,
+            help="Filter episodes by all-goals-hit metric being 1. 
Assumes `tod-metrics-datapath` is set.", + ) + agent.add_argument( + "--split-to-folds", + type=bool, + default=True, + help="Use all data or split into 8:1:1 fold", + ) + agent.add_argument( + "--split-folds-seed", + type=int, + default=42, + help="Seed for the fold split", + ) + return parser + + def __init__(self, opt, shared=None): + if not opt.get("jsonfile_datapath"): + raise RuntimeError("jsonfile_datapath not specified") + if not hasattr(self, "opt"): + self.opt = opt + self.opt["datafile"] = opt["jsonfile_datapath"] + self.fold = self.opt["datatype"] # don't care + # Truncate datafile to just the immediate enclosing folder name and file name + dirname, basename = os.path.split(self.opt["datafile"]) + self.id = os.path.join(os.path.split(dirname)[1], basename) + super().__init__(opt, shared) + + def _process_line(self, line): + blob = json.loads(line) + if "dialog" not in blob or len(blob["dialog"]) < 1: + return None + rounds = [] + for raw_round in blob["dialog"][1:]: + if "prefix_stripped_text" not in raw_round[0]: + for i in range(len(raw_round)): + raw_round[i]["prefix_stripped_text"] = raw_round[i].get( + "text", PREFIXES[i] + )[len(PREFIXES[i]) :] + if len(raw_round) != 4: + if raw_round[0]["prefix_stripped_text"] != tod.STANDARD_DONE: + return None # misformatted convo, don't learn this. + break # TodStructuredEpisode will add in [DONE] + r = tod.TodStructuredRound( + user_utt=raw_round[0]["prefix_stripped_text"], + api_call_machine=tod.SerializationHelpers.str_to_api_dict( + raw_round[1]["prefix_stripped_text"] + ), + api_resp_machine=tod.SerializationHelpers.str_to_api_dict( + raw_round[2]["prefix_stripped_text"] + ), + sys_utt=raw_round[3]["prefix_stripped_text"], + ) + rounds.append(r) + preempt_round = blob["dialog"][0] + if len(preempt_round) != 4: + return None + for i in range(len(preempt_round)): + if "prefix_stripped_text" not in preempt_round[i]: + preempt_round[i]["prefix_stripped_text"] = preempt_round[i].get( + "text", PREFIXES_PREEMPT[i] + )[len(PREFIXES_PREEMPT[i]) :] + + episode = tod.TodStructuredEpisode( + domain=preempt_round[0].get("domain", ""), + api_schemas_machine=tod.SerializationHelpers.str_to_api_schemas( + preempt_round[0].get("prefix_stripped_text", "") + ), + goal_calls_machine=tod.SerializationHelpers.str_to_goals( + preempt_round[3].get("prefix_stripped_text") + ), + rounds=rounds, + ) + return episode + + def setup_episodes(self, fold): + result = [] + if self.opt["tod_metrics_datapath"] is not None: + with open(self.opt["tod_metrics_datapath"]) as f: + report_data = json.load(f) + tod_metrics = report_data["report"]["tod_metrics"] + lines_to_process = [] + with open(self.opt["datafile"], "r") as f: + result = [] + for i, line in enumerate(f.readlines()): + if ( + self.opt["filter_all_goals_hit"] + and tod_metrics[i]["all_goals_hit"] < 0.5 + ): + continue + if line: + lines_to_process.append(line) + + if self.opt["split_to_folds"]: + lines_to_process = DatatypeHelper.split_data_by_fold( + fold, lines_to_process, 0.8, 0.1, 0.1, self.opt["split_folds_seed"] + ) + + for line in lines_to_process: + processed = self._process_line(line) + if processed is not None: + result.append(processed) + return result + + def get_id_task_prefix(self): + return ( + "TodJson_#" + + os.path.basename(self.opt["jsonfile_datapath"]).split(".")[0] + + "#" + ) + + +class SystemTeacher(JsonTodParser, tod_agents.TodSystemTeacher): + pass + + +class DefaultTeacher(SystemTeacher): + pass + + +class UserSimulatorTeacher(JsonTodParser, tod_agents.TodUserSimulatorTeacher): 
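+    """
+    Teacher over the same .jsonl data for the user-simulator side.
+
+    Parsing comes from JsonTodParser; example construction comes from
+    tod_agents.TodUserSimulatorTeacher.
+    """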
+ pass diff --git a/parlai/tasks/tod_json/example_data.jsonl b/parlai/tasks/tod_json/example_data.jsonl new file mode 100644 index 00000000000..8efd5473d82 --- /dev/null +++ b/parlai/tasks/tod_json/example_data.jsonl @@ -0,0 +1,15 @@ +{"dialog": [[{"text": "APIS: api_name = GetTrainTickets ; optArg = class ; reqArg = date_of_journey, from, journey_start_time, number_of_adults, to, trip_protection | api_name = FindTrains ; optArg = class, number_of_adults ; reqArg = date_of_journey, from, to | api_name = FindEvents ; optArg = date ; reqArg = city, event_type | api_name = BuyEventTickets ; reqArg = city, date, event_name, number_of_tickets | api_name = ShareLocation ; reqArg = contact_name, location", "id": "GoogleSGD_ApiDescription", "domain": "Events, Messaging, Trains", "episode_done": false, "prefix_stripped_text": "api_name = GetTrainTickets ; optArg = class ; reqArg = date_of_journey, from, journey_start_time, number_of_adults, to, trip_protection | api_name = FindTrains ; optArg = class, number_of_adults ; reqArg = date_of_journey, from, to | api_name = FindEvents ; optArg = date ; reqArg = city, event_type | api_name = BuyEventTickets ; reqArg = city, date, event_name, number_of_tickets | api_name = ShareLocation ; reqArg = contact_name, location"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APIS: ", "beam_texts": [["APIS: ", 0.0]], "metrics": {"clen": 91.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "APIS: "}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"text": "GOAL: api_name = ShareLocation ; contact_name = Helen ; location = August Hall", "id": "GoogleSGD_Goal", "domain": "Events, Messaging, Trains", "episode_done": false, "prefix_stripped_text": "api_name = ShareLocation ; contact_name = Helen ; location = August Hall"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need to find a contact number for a friend.", "beam_texts": [["USER: I need to find a contact number for a friend.", -4.2508697509765625]], "metrics": {"clen": 24.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need to find a contact number for a friend."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.000830456439871341]], "metrics": {"clen": 301.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Where would you like to find the contact number?", "beam_texts": [["SYSTEM: Where would you like to find the contact number?", -3.0284223556518555]], "metrics": {"clen": 314.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Where would you like to find the contact number?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need to find a contact number for Helen in August Hall.", "beam_texts": [["USER: I need to find a contact number for Helen in August Hall.", -2.9425835609436035]], "metrics": {"clen": 62.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need to find a contact number for Helen in August Hall."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: api_name = FindContact ; contact_name = Helen ; location = August Hall", "beam_texts": [["APICALL: api_name = FindContact ; contact_name = Helen ; location = August Hall", 
-0.3795860707759857]], "metrics": {"clen": 334.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "api_name = FindContact ; contact_name = Helen ; location = August Hall"}, {"text": "APIRESP: ", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I was unable to find a contact number for Helen. Is there anything else I can help you with?", "beam_texts": [["SYSTEM: I was unable to find a contact number for Helen. Is there anything else I can help you with?", -2.7589914798736572]], "metrics": {"clen": 347.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I was unable to find a contact number for Helen. Is there anything else I can help you with?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, thank you.", "beam_texts": [["USER: No, thank you.", -1.1491308212280273]], "metrics": {"clen": 97.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, thank you."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0011205199407413602]], "metrics": {"clen": 373.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Have a great day.", "beam_texts": [["SYSTEM: Have a great day.", -0.7793283462524414]], "metrics": {"clen": 386.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Have a great day."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: [DONE]", "beam_texts": [["USER: [DONE]", -6.41131464362843e-06]], "metrics": {"clen": 116.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "[DONE]"}]], "context": [], "metadata_path": "/checkpoint/mpchen/projects/taskoriented/user_generator/google_sgd/mm8_gsgd_bartR3f_protoTestTrainTue_May_11/withApi_conversations.metadata"} +{"dialog": [[{"text": "APIS: api_name = BookHouse ; reqArg = check_in_date, check_out_date, number_of_adults, where_to | api_name = SearchHouse ; optArg = has_laundry_service, number_of_adults, rating ; reqArg = where_to | api_name = SearchOnewayFlight ; optArg = airlines, number_of_tickets, seating_class ; reqArg = departure_date, destination_airport, origin_airport | api_name = SearchRoundtripFlights ; optArg = airlines, number_of_tickets, seating_class ; reqArg = departure_date, destination_airport, origin_airport, return_date | api_name = FindAttractions ; optArg = category, free_entry, good_for_kids ; reqArg = location", "id": "GoogleSGD_ApiDescription", "domain": "Flights, Hotels, Travel", "episode_done": false, "prefix_stripped_text": "api_name = BookHouse ; reqArg = check_in_date, check_out_date, number_of_adults, where_to | api_name = SearchHouse ; optArg = has_laundry_service, number_of_adults, rating ; reqArg = where_to | api_name = SearchOnewayFlight ; optArg = airlines, number_of_tickets, seating_class ; reqArg = departure_date, destination_airport, origin_airport | api_name = SearchRoundtripFlights ; optArg = airlines, number_of_tickets, seating_class ; reqArg = departure_date, destination_airport, origin_airport, return_date | api_name = FindAttractions ; optArg = category, free_entry, good_for_kids ; reqArg = location"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APIS: ", "beam_texts": [["APIS: ", -2.1568732222476683e-07]], 
"metrics": {"clen": 143.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "APIS: "}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"text": "GOAL: api_name = SearchRoundtripFlights ; departure_date = 2019-03-02 ; destination_airport = Toronto ; origin_airport = Atlanta ; return_date = 2019-03-11", "id": "GoogleSGD_Goal", "domain": "Flights, Hotels, Travel", "episode_done": false, "prefix_stripped_text": "api_name = SearchRoundtripFlights ; departure_date = 2019-03-02 ; destination_airport = Toronto ; origin_airport = Atlanta ; return_date = 2019-03-11"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need to find a flight.", "beam_texts": [["USER: I need to find a flight.", -3.2091355323791504]], "metrics": {"clen": 32.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need to find a flight."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0005915741785429418]], "metrics": {"clen": 305.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Where are you going?", "beam_texts": [["SYSTEM: Where are you going?", -2.660393476486206]], "metrics": {"clen": 318.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Where are you going?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I'm going to Toronto, Canada.", "beam_texts": [["USER: I'm going to Toronto, Canada.", -1.6486107110977173]], "metrics": {"clen": 115.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I'm going to Toronto, Canada."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0026782331988215446]], "metrics": {"clen": 348.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: When are you leaving?", "beam_texts": [["SYSTEM: When are you leaving?", -2.153712749481201]], "metrics": {"clen": 377.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "When are you leaving?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I'm leaving on the 2nd of this month.", "beam_texts": [["USER: I'm leaving on the 2nd of this month.", -2.099220037460327]], "metrics": {"clen": 134.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I'm leaving on the 2nd of this month."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0051135048270225525]], "metrics": {"clen": 410.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Where are you leaving from?", "beam_texts": [["SYSTEM: Where are you leaving from?", -1.2600269317626953]], "metrics": {"clen": 423.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Where are you leaving from?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I'm leaving from Atlanta, GA.", "beam_texts": [["USER: I'm leaving from 
Atlanta, GA.", -1.2180325984954834]], "metrics": {"clen": 128.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I'm leaving from Atlanta, GA."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.260057270526886]], "metrics": {"clen": 490.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: When are you returning?", "beam_texts": [["SYSTEM: When are you returning?", -0.7014266848564148]], "metrics": {"clen": 516.0, "ctrunc": 1.0, "ctrunclen": 4.0}, "prefix_stripped_text": "When are you returning?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to return on the 11th of this month.", "beam_texts": [["USER: I want to return on the 11th of this month.", -1.6878958940505981]], "metrics": {"clen": 175.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to return on the 11th of this month."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: api_name = SearchRoundtripFlights ; departure_date = 2019-03-02 ; destination_airport = Toronto ; origin_airport = Atlanta ; return_date = 2019-03-11", "beam_texts": [["APICALL: api_name = SearchRoundtripFlights ; departure_date = 2019-03-02 ; destination_airport = Toronto ; origin_airport = Atlanta ; return_date = 2019-03-11", -0.006989812944084406]], "metrics": {"clen": 550.0, "ctrunc": 1.0, "ctrunclen": 38.0}, "prefix_stripped_text": "api_name = SearchRoundtripFlights ; departure_date = 2019-03-02 ; destination_airport = Toronto ; origin_airport = Atlanta ; return_date = 2019-03-11"}, {"text": "APIRESP: airlines = American Airlines ; departure_date = 2019-03-02 ; destination_airport = Toronto ; inbound_arrival_time = 11:17 ; inbound_departure_time = 04:15 ; is_nonstop = False ; number_of_tickets = 1 ; origin_airport = Atlanta ; outbound_arrival_time = 12:25 ; outbound_departure_time = 05:10 ; price = 332 ; return_date = 2019-03-11 ; seating_class = Economy", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": "airlines = American Airlines ; departure_date = 2019-03-02 ; destination_airport = Toronto ; inbound_arrival_time = 11:17 ; inbound_departure_time = 04:15 ; is_nonstop = False ; number_of_tickets = 1 ; origin_airport = Atlanta ; outbound_arrival_time = 12:25 ; outbound_departure_time = 05:10 ; price = 332 ; return_date = 2019-03-11 ; seating_class = Economy"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I found 4 flights. There is an American Airlines flight that leaves at 5:10 am and returns at 4:15 am. The cost is $332.", "beam_texts": [["SYSTEM: I found 4 flights. There is an American Airlines flight that leaves at 5:10 am and returns at 4:15 am. The cost is $332.", -2.6480491161346436]], "metrics": {"clen": 507.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I found 4 flights. There is an American Airlines flight that leaves at 5:10 am and returns at 4:15 am. 
The cost is $332."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: That sounds good.", "beam_texts": [["USER: That sounds good.", -1.531357765197754]], "metrics": {"clen": 228.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "That sounds good."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.00013564960681833327]], "metrics": {"clen": 419.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Do you want to buy tickets?", "beam_texts": [["SYSTEM: Do you want to buy tickets?", -1.0238951444625854]], "metrics": {"clen": 432.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Do you want to buy tickets?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, not now.", "beam_texts": [["USER: No, not now.", -1.9948810338974]], "metrics": {"clen": 183.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, not now."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.00037169820279814303]], "metrics": {"clen": 593.0, "ctrunc": 1.0, "ctrunclen": 81.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Is there anything else I can help you with?", "beam_texts": [["SYSTEM: Is there anything else I can help you with?", -0.9852166771888733]], "metrics": {"clen": 645.0, "ctrunc": 1.0, "ctrunclen": 133.0}, "prefix_stripped_text": "Is there anything else I can help you with?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, thank you.", "beam_texts": [["USER: No, thank you.", -0.6123467683792114]], "metrics": {"clen": 268.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, thank you."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -2.995951217599213e-05]], "metrics": {"clen": 676.0, "ctrunc": 1.0, "ctrunclen": 164.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Have a great day.", "beam_texts": [["SYSTEM: Have a great day.", -0.6601892709732056]], "metrics": {"clen": 689.0, "ctrunc": 1.0, "ctrunclen": 177.0}, "prefix_stripped_text": "Have a great day."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: [DONE]", "beam_texts": [["USER: [DONE]", -2.747716791873245e-07]], "metrics": {"clen": 368.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "[DONE]"}]], "context": [], "metadata_path": "/checkpoint/mpchen/projects/taskoriented/user_generator/google_sgd/mm8_gsgd_bartR3f_protoTestTrainTue_May_11/withApi_conversations.metadata"} +{"dialog": [[{"text": "APIS: api_name = SearchOnewayFlight ; optArg = airlines, number_of_tickets, seating_class ; reqArg = departure_date, destination_airport, origin_airport | api_name = SearchRoundtripFlights ; optArg = airlines, number_of_tickets, seating_class ; reqArg = departure_date, destination_airport, origin_airport, return_date", "id": "GoogleSGD_ApiDescription", "domain": "Flights", "episode_done": 
false, "prefix_stripped_text": "api_name = SearchOnewayFlight ; optArg = airlines, number_of_tickets, seating_class ; reqArg = departure_date, destination_airport, origin_airport | api_name = SearchRoundtripFlights ; optArg = airlines, number_of_tickets, seating_class ; reqArg = departure_date, destination_airport, origin_airport, return_date"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APIS: ", "beam_texts": [["APIS: ", -1.0784366821781077e-07]], "metrics": {"clen": 115.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "APIS: "}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"text": "GOAL: airlines = American Airlines ; api_name = SearchOnewayFlight ; departure_date = 2019-03-06 ; destination_airport = Los Angeles ; origin_airport = Seattle", "id": "GoogleSGD_Goal", "domain": "Flights", "episode_done": false, "prefix_stripped_text": "airlines = American Airlines ; api_name = SearchOnewayFlight ; departure_date = 2019-03-06 ; destination_airport = Los Angeles ; origin_airport = Seattle"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need a one way flight.", "beam_texts": [["USER: I need a one way flight.", -2.5476834774017334]], "metrics": {"clen": 26.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need a one way flight."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0014316064771264791]], "metrics": {"clen": 247.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Where are you going?", "beam_texts": [["SYSTEM: Where are you going?", -2.7428438663482666]], "metrics": {"clen": 260.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Where are you going?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I am going to LAX.", "beam_texts": [["USER: I am going to LAX.", -2.408816337585449]], "metrics": {"clen": 73.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I am going to LAX."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.002551180776208639]], "metrics": {"clen": 284.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: When are you leaving?", "beam_texts": [["SYSTEM: When are you leaving?", -1.995316982269287]], "metrics": {"clen": 332.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "When are you leaving?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I am leaving on the 6th of this month.", "beam_texts": [["USER: I am leaving on the 6th of this month.", -2.0818636417388916]], "metrics": {"clen": 100.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I am leaving on the 6th of this month."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.09947514533996582]], "metrics": {"clen": 360.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": 
"R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Where are you leaving from?", "beam_texts": [["SYSTEM: Where are you leaving from?", -0.6350602507591248]], "metrics": {"clen": 416.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Where are you leaving from?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I am leaving from Seattle, WA. I prefer American Airlines.", "beam_texts": [["USER: I am leaving from Seattle, WA. I prefer American Airlines.", -1.8830840587615967]], "metrics": {"clen": 100.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I am leaving from Seattle, WA. I prefer American Airlines."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: airlines = American Airlines ; api_name = SearchOnewayFlight ; departure_date = 2019-03-06 ; destination_airport = Los Angeles ; origin_airport = Seattle", "beam_texts": [["APICALL: airlines = American Airlines ; api_name = SearchOnewayFlight ; departure_date = 2019-03-06 ; destination_airport = Los Angeles ; origin_airport = Seattle", -0.004719942342489958]], "metrics": {"clen": 444.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "airlines = American Airlines ; api_name = SearchOnewayFlight ; departure_date = 2019-03-06 ; destination_airport = Los Angeles ; origin_airport = Seattle"}, {"text": "APIRESP: ", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I couldn't find any flights matching your criteria. Can I help with something else?", "beam_texts": [["SYSTEM: I couldn't find any flights matching your criteria. Can I help with something else?", -3.5894200801849365]], "metrics": {"clen": 406.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I couldn't find any flights matching your criteria. 
Can I help with something else?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, thank you.", "beam_texts": [["USER: No, thank you.", -0.8101848363876343]], "metrics": {"clen": 177.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, thank you."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.00011459175584604964]], "metrics": {"clen": 425.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Have a great day.", "beam_texts": [["SYSTEM: Have a great day.", -0.8406725525856018]], "metrics": {"clen": 500.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Have a great day."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: [DONE]", "beam_texts": [["USER: [DONE]", -1.0990861483151093e-06]], "metrics": {"clen": 244.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "[DONE]"}]], "context": [], "metadata_path": "/checkpoint/mpchen/projects/taskoriented/user_generator/google_sgd/mm8_gsgd_bartR3f_protoTestTrainTue_May_11/withApi_conversations.metadata"} +{"dialog": [[{"text": "APIS: api_name = GetTrainTickets ; optArg = class ; reqArg = date_of_journey, from, journey_start_time, number_of_adults, to, trip_protection | api_name = FindTrains ; optArg = class, number_of_adults ; reqArg = date_of_journey, from, to | api_name = FindEvents ; optArg = date ; reqArg = city, event_type | api_name = BuyEventTickets ; reqArg = city, date, event_name, number_of_tickets | api_name = ShareLocation ; reqArg = contact_name, location", "id": "GoogleSGD_ApiDescription", "domain": "Events, Messaging, Trains", "episode_done": false, "prefix_stripped_text": "api_name = GetTrainTickets ; optArg = class ; reqArg = date_of_journey, from, journey_start_time, number_of_adults, to, trip_protection | api_name = FindTrains ; optArg = class, number_of_adults ; reqArg = date_of_journey, from, to | api_name = FindEvents ; optArg = date ; reqArg = city, event_type | api_name = BuyEventTickets ; reqArg = city, date, event_name, number_of_tickets | api_name = ShareLocation ; reqArg = contact_name, location"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APIS: ", "beam_texts": [["APIS: ", 0.0]], "metrics": {"clen": 143.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "APIS: "}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"text": "GOAL: api_name = FindTrains ; date_of_journey = 2019-03-14 ; from = San Diego ; to = Sacramento", "id": "GoogleSGD_Goal", "domain": "Events, Messaging, Trains", "episode_done": false, "prefix_stripped_text": "api_name = FindTrains ; date_of_journey = 2019-03-14 ; from = San Diego ; to = Sacramento"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need to find a bus.", "beam_texts": [["USER: I need to find a bus.", -2.7568917274475098]], "metrics": {"clen": 41.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need to find a bus."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.00041706167394295335]], "metrics": {"clen": 197.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": 
false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Where are you going?", "beam_texts": [["SYSTEM: Where are you going?", -2.507762908935547]], "metrics": {"clen": 210.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Where are you going?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I'm going to Sacramento, CA.", "beam_texts": [["USER: I'm going to Sacramento, CA.", -1.6594734191894531]], "metrics": {"clen": 62.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I'm going to Sacramento, CA."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.005608933977782726]], "metrics": {"clen": 229.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: When do you want to leave?", "beam_texts": [["SYSTEM: When do you want to leave?", -1.8860359191894531]], "metrics": {"clen": 209.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "When do you want to leave?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to leave on the 14th of this month.", "beam_texts": [["USER: I want to leave on the 14th of this month.", -1.5564500093460083]], "metrics": {"clen": 173.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to leave on the 14th of this month."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0036852918565273285]], "metrics": {"clen": 237.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Where are you leaving from?", "beam_texts": [["SYSTEM: Where are you leaving from?", -0.7195356488227844]], "metrics": {"clen": 250.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Where are you leaving from?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I'm leaving from San Diego.", "beam_texts": [["USER: I'm leaving from San Diego.", -0.939035177230835]], "metrics": {"clen": 116.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I'm leaving from San Diego."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: api_name = FindTrains ; date_of_journey = 2019-03-14 ; from = San Diego ; to = Sacramento", "beam_texts": [["APICALL: api_name = FindTrains ; date_of_journey = 2019-03-14 ; from = San Diego ; to = Sacramento", -0.007453737314790487]], "metrics": {"clen": 274.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "api_name = FindTrains ; date_of_journey = 2019-03-14 ; from = San Diego ; to = Sacramento"}, {"text": "APIRESP: class = Value ; date_of_journey = 2019-03-14 ; from = San Diego ; from_station = Santa Fe Depot ; journey_start_time = 06:10 ; number_of_adults = 1 ; to = Sacramento ; to_station = Sacramento Valley Station ; total = 145 ; trip_protection = True", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": "class = Value ; date_of_journey = 2019-03-14 ; from = San Diego ; from_station = Santa Fe Depot ; journey_start_time = 06:10 ; number_of_adults = 1 ; to = Sacramento ; to_station = Sacramento Valley 
Station ; total = 145 ; trip_protection = True"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I found 10 buses. There's a 6:10 am bus for $145. It's a value ticket.", "beam_texts": [["SYSTEM: I found 10 buses. There's a 6:10 am bus for $145. It's a value ticket.", -3.985891103744507]], "metrics": {"clen": 354.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I found 10 buses. There's a 6:10 am bus for $145. It's a value ticket."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: That sounds good.", "beam_texts": [["USER: That sounds good.", -1.7473149299621582]], "metrics": {"clen": 135.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "That sounds good."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0016271225176751614]], "metrics": {"clen": 383.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Do you want to buy tickets?", "beam_texts": [["SYSTEM: Do you want to buy tickets?", -0.8871648907661438]], "metrics": {"clen": 563.0, "ctrunc": 1.0, "ctrunclen": 51.0}, "prefix_stripped_text": "Do you want to buy tickets?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, not now.", "beam_texts": [["USER: No, not now.", -1.4214577674865723]], "metrics": {"clen": 209.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, not now."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0008082308922894299]], "metrics": {"clen": 550.0, "ctrunc": 1.0, "ctrunclen": 38.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Anything else I can help with?", "beam_texts": [["SYSTEM: Anything else I can help with?", -1.0824320316314697]], "metrics": {"clen": 563.0, "ctrunc": 1.0, "ctrunclen": 51.0}, "prefix_stripped_text": "Anything else I can help with?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, thank you.", "beam_texts": [["USER: No, thank you.", -0.8818005323410034]], "metrics": {"clen": 243.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, thank you."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0006076796562410891]], "metrics": {"clen": 519.0, "ctrunc": 1.0, "ctrunclen": 7.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Have a nice day.", "beam_texts": [["SYSTEM: Have a nice day.", -0.5481600761413574]], "metrics": {"clen": 532.0, "ctrunc": 1.0, "ctrunclen": 20.0}, "prefix_stripped_text": "Have a nice day."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: [DONE]", "beam_texts": [["USER: [DONE]", -1.8318112893211946e-07]], "metrics": {"clen": 350.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "[DONE]"}]], "context": [], "metadata_path": "/checkpoint/mpchen/projects/taskoriented/user_generator/google_sgd/mm8_gsgd_bartR3f_protoTestTrainTue_May_11/withApi_conversations.metadata"} 
+{"dialog": [[{"text": "APIS: api_name = SearchOnewayFlight ; optArg = airlines, number_of_tickets, seating_class ; reqArg = departure_date, destination_airport, origin_airport | api_name = SearchRoundtripFlights ; optArg = airlines, number_of_tickets, seating_class ; reqArg = departure_date, destination_airport, origin_airport, return_date", "id": "GoogleSGD_ApiDescription", "domain": "Flights", "episode_done": false, "prefix_stripped_text": "api_name = SearchOnewayFlight ; optArg = airlines, number_of_tickets, seating_class ; reqArg = departure_date, destination_airport, origin_airport | api_name = SearchRoundtripFlights ; optArg = airlines, number_of_tickets, seating_class ; reqArg = departure_date, destination_airport, origin_airport, return_date"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APIS: ", "beam_texts": [["APIS: ", -1.0784366821781077e-07]], "metrics": {"clen": 142.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "APIS: "}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"text": "GOAL: api_name = SearchRoundtripFlights ; departure_date = 2019-03-04 ; destination_airport = San Francisco ; origin_airport = San Diego ; return_date = 2019-03-09", "id": "GoogleSGD_Goal", "domain": "Flights", "episode_done": false, "prefix_stripped_text": "api_name = SearchRoundtripFlights ; departure_date = 2019-03-04 ; destination_airport = San Francisco ; origin_airport = San Diego ; return_date = 2019-03-09"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need to find a flight.", "beam_texts": [["USER: I need to find a flight.", -3.215803623199463]], "metrics": {"clen": 41.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need to find a flight."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0010359908919781446]], "metrics": {"clen": 299.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Where are you going?", "beam_texts": [["SYSTEM: Where are you going?", -2.9211363792419434]], "metrics": {"clen": 312.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Where are you going?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I'm going to San Francisco.", "beam_texts": [["USER: I'm going to San Francisco.", -2.0261998176574707]], "metrics": {"clen": 101.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I'm going to San Francisco."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0031771883368492126]], "metrics": {"clen": 331.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: When are you leaving?", "beam_texts": [["SYSTEM: When are you leaving?", -2.7696292400360107]], "metrics": {"clen": 297.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "When are you leaving?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I'm leaving on the 4th of this month.", "beam_texts": [["USER: I'm leaving on the 4th of this month.", -1.9584730863571167]], "metrics": 
{"clen": 105.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I'm leaving on the 4th of this month."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.017428752034902573]], "metrics": {"clen": 324.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Where are you leaving from?", "beam_texts": [["SYSTEM: Where are you leaving from?", -1.4117623567581177]], "metrics": {"clen": 373.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Where are you leaving from?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I'm leaving from San Diego.", "beam_texts": [["USER: I'm leaving from San Diego.", -0.9567034244537354]], "metrics": {"clen": 158.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I'm leaving from San Diego."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.2181233912706375]], "metrics": {"clen": 363.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: When are you returning?", "beam_texts": [["SYSTEM: When are you returning?", -0.7378217577934265]], "metrics": {"clen": 319.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "When are you returning?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to return on the 9th of this month.", "beam_texts": [["USER: I want to return on the 9th of this month.", -1.6836446523666382]], "metrics": {"clen": 201.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to return on the 9th of this month."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: api_name = SearchRoundtripFlights ; departure_date = 2019-03-04 ; destination_airport = San Francisco ; origin_airport = San Diego ; return_date = 2019-03-09", "beam_texts": [["APICALL: api_name = SearchRoundtripFlights ; departure_date = 2019-03-04 ; destination_airport = San Francisco ; origin_airport = San Diego ; return_date = 2019-03-09", -0.0012267155107110739]], "metrics": {"clen": 396.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "api_name = SearchRoundtripFlights ; departure_date = 2019-03-04 ; destination_airport = San Francisco ; origin_airport = San Diego ; return_date = 2019-03-09"}, {"text": "APIRESP: airlines = Alaska Airlines ; departure_date = 2019-03-04 ; destination_airport = San Francisco ; inbound_arrival_time = 14:05 ; inbound_departure_time = 12:25 ; is_nonstop = False ; number_of_tickets = 1 ; origin_airport = San Diego ; outbound_arrival_time = 09:10 ; outbound_departure_time = 06:15 ; price = 220 ; return_date = 2019-03-09 ; seating_class = Economy", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": "airlines = Alaska Airlines ; departure_date = 2019-03-04 ; destination_airport = San Francisco ; inbound_arrival_time = 14:05 ; inbound_departure_time = 12:25 ; is_nonstop = False ; number_of_tickets = 1 ; origin_airport = San Diego ; outbound_arrival_time = 09:10 ; outbound_departure_time = 06:15 ; price = 220 ; return_date = 2019-03-09 ; seating_class = Economy"}, {"id": 
"R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I found 4 flights. There is an Alaska Airlines flight that leaves at 6:15 am and returns at 12:25 pm. The cost is $220.", "beam_texts": [["SYSTEM: I found 4 flights. There is an Alaska Airlines flight that leaves at 6:15 am and returns at 12:25 pm. The cost is $220.", -2.7658262252807617]], "metrics": {"clen": 362.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I found 4 flights. There is an Alaska Airlines flight that leaves at 6:15 am and returns at 12:25 pm. The cost is $220."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: That sounds good.", "beam_texts": [["USER: That sounds good.", -1.5201750993728638]], "metrics": {"clen": 185.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "That sounds good."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -4.8251091357087716e-05]], "metrics": {"clen": 582.0, "ctrunc": 1.0, "ctrunclen": 70.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Do you want to buy tickets?", "beam_texts": [["SYSTEM: Do you want to buy tickets?", -1.2585415840148926]], "metrics": {"clen": 595.0, "ctrunc": 1.0, "ctrunclen": 83.0}, "prefix_stripped_text": "Do you want to buy tickets?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, not now.", "beam_texts": [["USER: No, not now.", -1.9538514614105225]], "metrics": {"clen": 246.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, not now."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.00017885112902149558]], "metrics": {"clen": 582.0, "ctrunc": 1.0, "ctrunclen": 70.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Is there anything else I can help you with?", "beam_texts": [["SYSTEM: Is there anything else I can help you with?", -0.9863885045051575]], "metrics": {"clen": 595.0, "ctrunc": 1.0, "ctrunclen": 83.0}, "prefix_stripped_text": "Is there anything else I can help you with?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, thank you.", "beam_texts": [["USER: No, thank you.", -0.591467559337616]], "metrics": {"clen": 221.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, thank you."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -7.712452315900009e-06]], "metrics": {"clen": 617.0, "ctrunc": 1.0, "ctrunclen": 105.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Have a good day.", "beam_texts": [["SYSTEM: Have a good day.", -0.7458349466323853]], "metrics": {"clen": 630.0, "ctrunc": 1.0, "ctrunclen": 118.0}, "prefix_stripped_text": "Have a good day."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: [DONE]", "beam_texts": [["USER: [DONE]", -1.8318112893211946e-07]], "metrics": {"clen": 285.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "[DONE]"}]], "context": [], "metadata_path": 
"/checkpoint/mpchen/projects/taskoriented/user_generator/google_sgd/mm8_gsgd_bartR3f_protoTestTrainTue_May_11/withApi_conversations.metadata"} +{"dialog": [[{"text": "APIS: api_name = GetCarsAvailable ; optArg = car_type ; reqArg = city, end_date, pickup_time, start_date | api_name = ReserveCar ; reqArg = add_insurance, car_type, end_date, pickup_location, pickup_time, start_date", "id": "GoogleSGD_ApiDescription", "domain": "RentalCars", "episode_done": false, "prefix_stripped_text": "api_name = GetCarsAvailable ; optArg = car_type ; reqArg = city, end_date, pickup_time, start_date | api_name = ReserveCar ; reqArg = add_insurance, car_type, end_date, pickup_location, pickup_time, start_date"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APIS: ", "beam_texts": [["APIS: ", 0.0]], "metrics": {"clen": 168.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "APIS: "}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"text": "GOAL: api_name = GetCarsAvailable ; city = Fresno ; end_date = 2019-03-12 ; pickup_time = 09:30 ; start_date = 2019-03-11", "id": "GoogleSGD_Goal", "domain": "RentalCars", "episode_done": false, "prefix_stripped_text": "api_name = GetCarsAvailable ; city = Fresno ; end_date = 2019-03-12 ; pickup_time = 09:30 ; start_date = 2019-03-11"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need to find a rental car.", "beam_texts": [["USER: I need to find a rental car.", -2.235142469406128]], "metrics": {"clen": 48.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need to find a rental car."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.00040876015555113554]], "metrics": {"clen": 351.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: What date do you want to pick it up?", "beam_texts": [["SYSTEM: What date do you want to pick it up?", -3.3327901363372803]], "metrics": {"clen": 171.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "What date do you want to pick it up?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need it on the 11th of March.", "beam_texts": [["USER: I need it on the 11th of March.", -2.694045066833496]], "metrics": {"clen": 83.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need it on the 11th of March."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.00027602165937423706]], "metrics": {"clen": 206.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: What time do you want to pick it up?", "beam_texts": [["SYSTEM: What time do you want to pick it up?", -1.9576942920684814]], "metrics": {"clen": 219.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "What time do you want to pick it up?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need it at 9:30 in the morning.", "beam_texts": [["USER: I need it at 9:30 in the morning.", -2.4926252365112305]], "metrics": {"clen": 120.0, "ctrunc": 0.0, "ctrunclen": 0.0}, 
"prefix_stripped_text": "I need it at 9:30 in the morning."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0010227571474388242]], "metrics": {"clen": 434.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Where do you want to pick it up?", "beam_texts": [["SYSTEM: Where do you want to pick it up?", -1.4877737760543823]], "metrics": {"clen": 447.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Where do you want to pick it up?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need it in Fresno, CA.", "beam_texts": [["USER: I need it in Fresno, CA.", -1.823162317276001]], "metrics": {"clen": 144.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need it in Fresno, CA."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: api_name = GetCarsAvailable ; city = Fresno ; end_date = 2019-03-11 ; pickup_time = 09:30", "beam_texts": [["APICALL: api_name = GetCarsAvailable ; city = Fresno ; end_date = 2019-03-11 ; pickup_time = 09:30", -0.2495831549167633]], "metrics": {"clen": 476.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "api_name = GetCarsAvailable ; city = Fresno ; end_date = 2019-03-11 ; pickup_time = 09:30"}, {"text": "APIRESP: ", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I couldn't find any cars matching your criteria. Is there anything else I can help you with?", "beam_texts": [["SYSTEM: I couldn't find any cars matching your criteria. Is there anything else I can help you with?", -2.8082520961761475]], "metrics": {"clen": 519.0, "ctrunc": 1.0, "ctrunclen": 7.0}, "prefix_stripped_text": "I couldn't find any cars matching your criteria. 
Is there anything else I can help you with?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, I'm good.", "beam_texts": [["USER: No, I'm good.", -3.3139195442199707]], "metrics": {"clen": 167.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, I'm good."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0005645004566758871]], "metrics": {"clen": 487.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Have a nice day.", "beam_texts": [["SYSTEM: Have a nice day.", -1.0553348064422607]], "metrics": {"clen": 564.0, "ctrunc": 1.0, "ctrunclen": 52.0}, "prefix_stripped_text": "Have a nice day."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: [DONE]", "beam_texts": [["USER: [DONE]", -2.747716791873245e-07]], "metrics": {"clen": 218.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "[DONE]"}]], "context": [], "metadata_path": "/checkpoint/mpchen/projects/taskoriented/user_generator/google_sgd/mm8_gsgd_bartR3f_protoTestTrainTue_May_11/withApi_conversations.metadata"} +{"dialog": [[{"text": "APIS: api_name = FindHomeByArea ; optArg = has_garage, in_unit_laundry ; reqArg = area, intent, number_of_baths, number_of_beds | api_name = ScheduleVisit ; reqArg = property_name, visit_date", "id": "GoogleSGD_ApiDescription", "domain": "Homes", "episode_done": false, "prefix_stripped_text": "api_name = FindHomeByArea ; optArg = has_garage, in_unit_laundry ; reqArg = area, intent, number_of_baths, number_of_beds | api_name = ScheduleVisit ; reqArg = property_name, visit_date"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APIS: ", "beam_texts": [["APIS: ", -1.0784366821781077e-07]], "metrics": {"clen": 59.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "APIS: "}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"text": "GOAL: api_name = ScheduleVisit ; property_name = Indian Creek ; visit_date = 2019-03-02", "id": "GoogleSGD_Goal", "domain": "Homes", "episode_done": false, "prefix_stripped_text": "api_name = ScheduleVisit ; property_name = Indian Creek ; visit_date = 2019-03-02"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to visit a property.", "beam_texts": [["USER: I want to visit a property.", -3.0068488121032715]], "metrics": {"clen": 22.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to visit a property."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0010161403333768249]], "metrics": {"clen": 142.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: When do you want to visit?", "beam_texts": [["SYSTEM: When do you want to visit?", -1.7299853563308716]], "metrics": {"clen": 178.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "When do you want to visit?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to visit on the 2nd of this month.", "beam_texts": [["USER: I want to visit on the 2nd of this month.", 
-2.3218560218811035]], "metrics": {"clen": 144.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to visit on the 2nd of this month."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.020279226824641228]], "metrics": {"clen": 200.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: What property do you want to visit?", "beam_texts": [["SYSTEM: What property do you want to visit?", -1.9698238372802734]], "metrics": {"clen": 213.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "What property do you want to visit?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to visit Indian Creek.", "beam_texts": [["USER: I want to visit Indian Creek.", -0.8281252980232239]], "metrics": {"clen": 97.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to visit Indian Creek."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0015185625525191426]], "metrics": {"clen": 245.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Please confirm: You want to visit Indian Creek on March 2nd.", "beam_texts": [["SYSTEM: Please confirm: You want to visit Indian Creek on March 2nd.", -2.028269052505493]], "metrics": {"clen": 287.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Please confirm: You want to visit Indian Creek on March 2nd."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: Yes, that is correct.", "beam_texts": [["USER: Yes, that is correct.", -1.5401211977005005]], "metrics": {"clen": 203.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Yes, that is correct."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: api_name = ScheduleVisit ; property_name = Indian Creek ; visit_date = 2019-03-02", "beam_texts": [["APICALL: api_name = ScheduleVisit ; property_name = Indian Creek ; visit_date = 2019-03-02", -0.0001516681513749063]], "metrics": {"clen": 319.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "api_name = ScheduleVisit ; property_name = Indian Creek ; visit_date = 2019-03-02"}, {"text": "APIRESP: address = 801 Marine Parkway ; has_garage = False ; in_unit_laundry = True ; number_of_baths = 2 ; number_of_beds = 3 ; phone_number = 650-593-4300 ; price = 3500000 ; property_name = Indian Creek ; visit_date = 2019-03-02", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": "address = 801 Marine Parkway ; has_garage = False ; in_unit_laundry = True ; number_of_baths = 2 ; number_of_beds = 3 ; phone_number = 650-593-4300 ; price = 3500000 ; property_name = Indian Creek ; visit_date = 2019-03-02"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Your visit is scheduled.", "beam_texts": [["SYSTEM: Your visit is scheduled.", -0.8813648223876953]], "metrics": {"clen": 369.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Your visit is scheduled."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: Thank you.", "beam_texts": [["USER: Thank you.", 
-2.2147231101989746]], "metrics": {"clen": 136.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Thank you."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -6.387276516761631e-05]], "metrics": {"clen": 392.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Is there anything else I can help you with?", "beam_texts": [["SYSTEM: Is there anything else I can help you with?", -0.8272175192832947]], "metrics": {"clen": 438.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Is there anything else I can help you with?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, thank you.", "beam_texts": [["USER: No, thank you.", -0.5735105276107788]], "metrics": {"clen": 152.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, thank you."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -5.418340879259631e-05]], "metrics": {"clen": 441.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Have a good day.", "beam_texts": [["SYSTEM: Have a good day.", -0.735770583152771]], "metrics": {"clen": 491.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Have a good day."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: [DONE]", "beam_texts": [["USER: [DONE]", -2.83930285149836e-06]], "metrics": {"clen": 319.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "[DONE]"}]], "context": [], "metadata_path": "/checkpoint/mpchen/projects/taskoriented/user_generator/google_sgd/mm8_gsgd_bartR3f_protoTestTrainTue_May_11/withApi_conversations.metadata"} +{"dialog": [[{"text": "APIS: api_name = GetCarsAvailable ; optArg = car_type ; reqArg = city, end_date, pickup_time, start_date | api_name = ReserveCar ; reqArg = add_insurance, car_type, end_date, pickup_location, pickup_time, start_date", "id": "GoogleSGD_ApiDescription", "domain": "RentalCars", "episode_done": false, "prefix_stripped_text": "api_name = GetCarsAvailable ; optArg = car_type ; reqArg = city, end_date, pickup_time, start_date | api_name = ReserveCar ; reqArg = add_insurance, car_type, end_date, pickup_location, pickup_time, start_date"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APIS: ", "beam_texts": [["APIS: ", 0.0]], "metrics": {"clen": 84.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "APIS: "}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"text": "GOAL: api_name = GetCarsAvailable ; car_type = SUV ; city = Portland ; end_date = 2019-03-09 ; pickup_time = 15:00 ; start_date = 2019-03-05", "id": "GoogleSGD_Goal", "domain": "RentalCars", "episode_done": false, "prefix_stripped_text": "api_name = GetCarsAvailable ; car_type = SUV ; city = Portland ; end_date = 2019-03-09 ; pickup_time = 15:00 ; start_date = 2019-03-05"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need to find a rental car.", "beam_texts": [["USER: I need to find a rental car.", -2.37575101852417]], "metrics": {"clen": 46.0, "ctrunc": 0.0, 
"ctrunclen": 0.0}, "prefix_stripped_text": "I need to find a rental car."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.00040876015555113554]], "metrics": {"clen": 183.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: What date do you want to pick it up?", "beam_texts": [["SYSTEM: What date do you want to pick it up?", -3.3327901363372803]], "metrics": {"clen": 364.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "What date do you want to pick it up?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need it on the 5th of this month.", "beam_texts": [["USER: I need it on the 5th of this month.", -3.266505479812622]], "metrics": {"clen": 111.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need it on the 5th of this month."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.00025378118152730167]], "metrics": {"clen": 385.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: What time do you want to pick it up?", "beam_texts": [["SYSTEM: What time do you want to pick it up?", -1.9800610542297363]], "metrics": {"clen": 409.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "What time do you want to pick it up?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need it at 3 in the afternoon.", "beam_texts": [["USER: I need it at 3 in the afternoon.", -3.0727145671844482]], "metrics": {"clen": 131.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need it at 3 in the afternoon."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0008782641962170601]], "metrics": {"clen": 243.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Where do you want to pick it up?", "beam_texts": [["SYSTEM: Where do you want to pick it up?", -1.5364357233047485]], "metrics": {"clen": 277.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Where do you want to pick it up?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need it in Portland, OR.", "beam_texts": [["USER: I need it in Portland, OR.", -1.7654595375061035]], "metrics": {"clen": 155.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need it in Portland, OR."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.3837651014328003]], "metrics": {"clen": 306.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: What date do you want to pick it up?", "beam_texts": [["SYSTEM: What date do you want to pick it up?", -1.704606294631958]], "metrics": {"clen": 332.0, "ctrunc": 0.0, "ctrunclen": 
0.0}, "prefix_stripped_text": "What date do you want to pick it up?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need it until the 9th of this month. I want an SUV.", "beam_texts": [["USER: I need it until the 9th of this month. I want an SUV.", -2.499929666519165]], "metrics": {"clen": 181.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need it until the 9th of this month. I want an SUV."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: api_name = GetCarsAvailable ; car_type = Full-size ; city = Portland ; end_date = 2019-03-09 ; pickup_time = 15:00", "beam_texts": [["APICALL: api_name = GetCarsAvailable ; car_type = Full-size ; city = Portland ; end_date = 2019-03-09 ; pickup_time = 15:00", -0.31725868582725525]], "metrics": {"clen": 551.0, "ctrunc": 1.0, "ctrunclen": 39.0}, "prefix_stripped_text": "api_name = GetCarsAvailable ; car_type = Full-size ; city = Portland ; end_date = 2019-03-09 ; pickup_time = 15:00"}, {"text": "APIRESP: ", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I couldn't find any cars matching your criteria. Is there anything else I can help you with?", "beam_texts": [["SYSTEM: I couldn't find any cars matching your criteria. Is there anything else I can help you with?", -3.00557804107666]], "metrics": {"clen": 409.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I couldn't find any cars matching your criteria. Is there anything else I can help you with?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, thank you.", "beam_texts": [["USER: No, thank you.", -0.8569039106369019]], "metrics": {"clen": 198.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, thank you."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.01578957960009575]], "metrics": {"clen": 550.0, "ctrunc": 1.0, "ctrunclen": 38.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Have a great day.", "beam_texts": [["SYSTEM: Have a great day.", -0.99875807762146]], "metrics": {"clen": 563.0, "ctrunc": 1.0, "ctrunclen": 51.0}, "prefix_stripped_text": "Have a great day."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: [DONE]", "beam_texts": [["USER: [DONE]", -1.8318112893211946e-07]], "metrics": {"clen": 204.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "[DONE]"}]], "context": [], "metadata_path": "/checkpoint/mpchen/projects/taskoriented/user_generator/google_sgd/mm8_gsgd_bartR3f_protoTestTrainTue_May_11/withApi_conversations.metadata"} +{"dialog": [[{"text": "APIS: api_name = BookAppointment ; reqArg = appointment_date, appointment_time, stylist_name | api_name = FindProvider ; optArg = is_unisex ; reqArg = city | api_name = BuyMovieTickets ; reqArg = location, movie_name, number_of_tickets, show_date, show_time, show_type | api_name = FindMovies ; optArg = genre, show_type, theater_name ; reqArg = location | api_name = GetTimesForMovie ; optArg = show_type, theater_name ; reqArg = location, movie_name, show_date", "id": "GoogleSGD_ApiDescription", "domain": "Movies, Services", "episode_done": false, "prefix_stripped_text": "api_name = BookAppointment ; reqArg = appointment_date, 
appointment_time, stylist_name | api_name = FindProvider ; optArg = is_unisex ; reqArg = city | api_name = BuyMovieTickets ; reqArg = location, movie_name, number_of_tickets, show_date, show_time, show_type | api_name = FindMovies ; optArg = genre, show_type, theater_name ; reqArg = location | api_name = GetTimesForMovie ; optArg = show_type, theater_name ; reqArg = location, movie_name, show_date"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APIS: ", "beam_texts": [["APIS: ", -1.0784366821781077e-07]], "metrics": {"clen": 70.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "APIS: "}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"text": "GOAL: api_name = FindProvider ; city = Santa Rosa", "id": "GoogleSGD_Goal", "domain": "Movies, Services", "episode_done": false, "prefix_stripped_text": "api_name = FindProvider ; city = Santa Rosa"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need to find a dentist.", "beam_texts": [["USER: I need to find a dentist.", -2.2936532497406006]], "metrics": {"clen": 37.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need to find a dentist."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -4.657026147469878e-05]], "metrics": {"clen": 156.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: In which city?", "beam_texts": [["SYSTEM: In which city?", -1.2308539152145386]], "metrics": {"clen": 169.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "In which city?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need a dentist in Santa Rosa.", "beam_texts": [["USER: I need a dentist in Santa Rosa.", -1.99648118019104]], "metrics": {"clen": 94.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need a dentist in Santa Rosa."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: api_name = FindProvider ; city = Santa Rosa", "beam_texts": [["APICALL: api_name = FindProvider ; city = Santa Rosa", -0.00010167308937525377]], "metrics": {"clen": 196.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "api_name = FindProvider ; city = Santa Rosa"}, {"text": "APIRESP: address = 4812, 1144 Sonoma Avenue ; city = Santa Rosa ; dentist_name = Access Dental - Santa Rosa ; offers_cosmetic_services = False ; phone_number = 866-682-9904", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": "address = 4812, 1144 Sonoma Avenue ; city = Santa Rosa ; dentist_name = Access Dental - Santa Rosa ; offers_cosmetic_services = False ; phone_number = 866-682-9904"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I found 10 dentists. How about Access Dental - Santa Rosa in Santa Rosa?", "beam_texts": [["SYSTEM: I found 10 dentists. How about Access Dental - Santa Rosa in Santa Rosa?", -1.664900541305542]], "metrics": {"clen": 242.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I found 10 dentists. 
How about Access Dental - Santa Rosa in Santa Rosa?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: Yes, that sounds good.", "beam_texts": [["USER: Yes, that sounds good.", -2.585338592529297]], "metrics": {"clen": 115.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Yes, that sounds good."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.00035133821074850857]], "metrics": {"clen": 265.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Would you like to book an appointment?", "beam_texts": [["SYSTEM: Would you like to book an appointment?", -1.2300727367401123]], "metrics": {"clen": 278.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Would you like to book an appointment?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, not right now.", "beam_texts": [["USER: No, not right now.", -1.776436448097229]], "metrics": {"clen": 162.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, not right now."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0005406867130659521]], "metrics": {"clen": 406.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Is there anything else I can help you with?", "beam_texts": [["SYSTEM: Is there anything else I can help you with?", -0.7443912625312805]], "metrics": {"clen": 311.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Is there anything else I can help you with?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, thank you.", "beam_texts": [["USER: No, thank you.", -0.8158944845199585]], "metrics": {"clen": 179.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, thank you."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -3.9451406337320805e-05]], "metrics": {"clen": 335.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Have a great day.", "beam_texts": [["SYSTEM: Have a great day.", -0.9284720420837402]], "metrics": {"clen": 396.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Have a great day."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: [DONE]", "beam_texts": [["USER: [DONE]", -2.289761141582858e-06]], "metrics": {"clen": 194.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "[DONE]"}]], "context": [], "metadata_path": "/checkpoint/mpchen/projects/taskoriented/user_generator/google_sgd/mm8_gsgd_bartR3f_protoTestTrainTue_May_11/withApi_conversations.metadata"} +{"dialog": [[{"text": "APIS: api_name = FindEvents ; optArg = date ; reqArg = city, event_type | api_name = BuyEventTickets ; reqArg = city, date, event_name, number_of_tickets | api_name = ReserveHotel ; optArg = number_of_rooms ; reqArg = check_in_date, location, place_name, stay_length | api_name = SearchHotel ; 
optArg = number_of_rooms, smoking_allowed, star_rating ; reqArg = location", "id": "GoogleSGD_ApiDescription", "domain": "Events, Hotels", "episode_done": false, "prefix_stripped_text": "api_name = FindEvents ; optArg = date ; reqArg = city, event_type | api_name = BuyEventTickets ; reqArg = city, date, event_name, number_of_tickets | api_name = ReserveHotel ; optArg = number_of_rooms ; reqArg = check_in_date, location, place_name, stay_length | api_name = SearchHotel ; optArg = number_of_rooms, smoking_allowed, star_rating ; reqArg = location"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APIS: ", "beam_texts": [["APIS: ", 0.0]], "metrics": {"clen": 66.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "APIS: "}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"text": "GOAL: api_name = BuyEventTickets ; city = New York ; date = 2019-03-14 ; event_name = Alpha Blondy ; number_of_tickets = 2", "id": "GoogleSGD_Goal", "domain": "Events, Hotels", "episode_done": false, "prefix_stripped_text": "api_name = BuyEventTickets ; city = New York ; date = 2019-03-14 ; event_name = Alpha Blondy ; number_of_tickets = 2"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to buy tickets to an event.", "beam_texts": [["USER: I want to buy tickets to an event.", -2.4751245975494385]], "metrics": {"clen": 43.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to buy tickets to an event."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -8.394322503590956e-05]], "metrics": {"clen": 147.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: What city should I look in?", "beam_texts": [["SYSTEM: What city should I look in?", -3.1570680141448975]], "metrics": {"clen": 160.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "What city should I look in?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to find something in New York City.", "beam_texts": [["USER: I want to find something in New York City.", -2.8168575763702393]], "metrics": {"clen": 111.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to find something in New York City."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.000529915967490524]], "metrics": {"clen": 186.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: What date would you like to go?", "beam_texts": [["SYSTEM: What date would you like to go?", -2.9322915077209473]], "metrics": {"clen": 400.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "What date would you like to go?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to go on the 14th of this month.", "beam_texts": [["USER: I want to go on the 14th of this month.", -1.2178394794464111]], "metrics": {"clen": 138.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to go on the 14th of this month."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, 
"text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0003908717480953783]], "metrics": {"clen": 431.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: How many tickets do you want?", "beam_texts": [["SYSTEM: How many tickets do you want?", -1.6059118509292603]], "metrics": {"clen": 444.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "How many tickets do you want?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want two tickets.", "beam_texts": [["USER: I want two tickets.", -1.3766371011734009]], "metrics": {"clen": 165.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want two tickets."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.000701539684087038]], "metrics": {"clen": 465.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: What event are you interested in?", "beam_texts": [["SYSTEM: What event are you interested in?", -1.3350071907043457]], "metrics": {"clen": 478.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "What event are you interested in?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to see Alpha Blondy.", "beam_texts": [["USER: I want to see Alpha Blondy.", -1.0148762464523315]], "metrics": {"clen": 123.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to see Alpha Blondy."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -4.133000038564205e-05]], "metrics": {"clen": 500.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Please confirm: 2 tickets to Alpha Blondy on March 14th in New York.", "beam_texts": [["SYSTEM: Please confirm: 2 tickets to Alpha Blondy on March 14th in New York.", -1.1921658515930176]], "metrics": {"clen": 405.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Please confirm: 2 tickets to Alpha Blondy on March 14th in New York."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: Yes, that's great. Where is the venue?", "beam_texts": [["USER: Yes, that's great. Where is the venue?", -2.867255926132202]], "metrics": {"clen": 166.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Yes, that's great. 
Where is the venue?"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: api_name = BuyEventTickets ; city = New York ; date = 2019-03-14 ; event_name = Alpha Blondy ; number_of_tickets = 2", "beam_texts": [["APICALL: api_name = BuyEventTickets ; city = New York ; date = 2019-03-14 ; event_name = Alpha Blondy ; number_of_tickets = 2", -1.1100361916760448e-05]], "metrics": {"clen": 401.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "api_name = BuyEventTickets ; city = New York ; date = 2019-03-14 ; event_name = Alpha Blondy ; number_of_tickets = 2"}, {"text": "APIRESP: ", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Sorry, I couldn't book the tickets. Can I help with anything else?", "beam_texts": [["SYSTEM: Sorry, I couldn't book the tickets. Can I help with anything else?", -2.606656551361084]], "metrics": {"clen": 414.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Sorry, I couldn't book the tickets. Can I help with anything else?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, thank you.", "beam_texts": [["USER: No, thank you.", -0.8586272597312927]], "metrics": {"clen": 261.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, thank you."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.00046617790940217674]], "metrics": {"clen": 614.0, "ctrunc": 1.0, "ctrunclen": 102.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Have a nice day.", "beam_texts": [["SYSTEM: Have a nice day.", -0.6713320016860962]], "metrics": {"clen": 627.0, "ctrunc": 1.0, "ctrunclen": 115.0}, "prefix_stripped_text": "Have a nice day."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: [DONE]", "beam_texts": [["USER: [DONE]", -3.663622294425295e-07]], "metrics": {"clen": 260.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "[DONE]"}]], "context": [], "metadata_path": "/checkpoint/mpchen/projects/taskoriented/user_generator/google_sgd/mm8_gsgd_bartR3f_protoTestTrainTue_May_11/withApi_conversations.metadata"} +{"dialog": [[{"text": "APIS: api_name = BookAppointment ; reqArg = appointment_date, appointment_time, stylist_name | api_name = FindProvider ; optArg = is_unisex ; reqArg = city | api_name = ShareLocation ; reqArg = contact_name, location | api_name = GetAlarms | api_name = AddAlarm ; optArg = new_alarm_name ; reqArg = new_alarm_time", "id": "GoogleSGD_ApiDescription", "domain": "Alarm, Messaging, Services", "episode_done": false, "prefix_stripped_text": "api_name = BookAppointment ; reqArg = appointment_date, appointment_time, stylist_name | api_name = FindProvider ; optArg = is_unisex ; reqArg = city | api_name = ShareLocation ; reqArg = contact_name, location | api_name = GetAlarms | api_name = AddAlarm ; optArg = new_alarm_name ; reqArg = new_alarm_time"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APIS: ", "beam_texts": [["APIS: ", 0.0]], "metrics": {"clen": 70.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "APIS: "}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"text": "GOAL: api_name = FindProvider ; city = San Jose ; is_unisex = 
True", "id": "GoogleSGD_Goal", "domain": "Alarm, Messaging, Services", "episode_done": false, "prefix_stripped_text": "api_name = FindProvider ; city = San Jose ; is_unisex = True"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need to find a salon.", "beam_texts": [["USER: I need to find a salon.", -2.4835095405578613]], "metrics": {"clen": 34.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need to find a salon."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -7.652808562852442e-05]], "metrics": {"clen": 156.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: In which city?", "beam_texts": [["SYSTEM: In which city?", -1.3473020792007446]], "metrics": {"clen": 169.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "In which city?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need a unisex salon in San Jose.", "beam_texts": [["USER: I need a unisex salon in San Jose.", -1.449689507484436]], "metrics": {"clen": 88.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need a unisex salon in San Jose."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: api_name = FindProvider ; city = San Jose ; is_unisex = True", "beam_texts": [["APICALL: api_name = FindProvider ; city = San Jose ; is_unisex = True", -3.9753711462253705e-05]], "metrics": {"clen": 197.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "api_name = FindProvider ; city = San Jose ; is_unisex = True"}, {"text": "APIRESP: average_rating = 4.30 ; city = San Jose ; is_unisex = True ; phone_number = 408-293-3494 ; street_address = 1047 Lincoln Highway suite 101 ; stylist_name = Atelier Salon Willow Glen", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": "average_rating = 4.30 ; city = San Jose ; is_unisex = True ; phone_number = 408-293-3494 ; street_address = 1047 Lincoln Highway suite 101 ; stylist_name = Atelier Salon Willow Glen"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I found 10 salons. Atelier Salon Willow Glen is in San Jose.", "beam_texts": [["SYSTEM: I found 10 salons. Atelier Salon Willow Glen is in San Jose.", -1.7336264848709106]], "metrics": {"clen": 199.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I found 10 salons. 
Atelier Salon Willow Glen is in San Jose."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: That sounds good.", "beam_texts": [["USER: That sounds good.", -2.342952251434326]], "metrics": {"clen": 117.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "That sounds good."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.00011350425484124571]], "metrics": {"clen": 221.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Do you want to book an appointment?", "beam_texts": [["SYSTEM: Do you want to book an appointment?", -1.1533632278442383]], "metrics": {"clen": 234.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Do you want to book an appointment?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, not now.", "beam_texts": [["USER: No, not now.", -1.7858071327209473]], "metrics": {"clen": 147.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, not now."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.001105411909520626]], "metrics": {"clen": 261.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Anything else I can help with?", "beam_texts": [["SYSTEM: Anything else I can help with?", -1.460990309715271]], "metrics": {"clen": 369.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Anything else I can help with?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, thank you.", "beam_texts": [["USER: No, thank you.", -0.6900044679641724]], "metrics": {"clen": 164.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, thank you."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0001660985144553706]], "metrics": {"clen": 384.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Have a nice day.", "beam_texts": [["SYSTEM: Have a nice day.", -0.5248937010765076]], "metrics": {"clen": 369.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Have a nice day."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: [DONE]", "beam_texts": [["USER: [DONE]", -7.327242883548024e-07]], "metrics": {"clen": 153.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "[DONE]"}]], "context": [], "metadata_path": "/checkpoint/mpchen/projects/taskoriented/user_generator/google_sgd/mm8_gsgd_bartR3f_protoTestTrainTue_May_11/withApi_conversations.metadata"} +{"dialog": [[{"text": "APIS: api_name = GetCarsAvailable ; optArg = car_type ; reqArg = city, end_date, pickup_time, start_date | api_name = ReserveCar ; reqArg = add_insurance, car_type, end_date, pickup_location, pickup_time, start_date | api_name = FindHomeByArea ; optArg = has_garage, in_unit_laundry ; reqArg = area, intent, number_of_baths, number_of_beds | api_name = ScheduleVisit ; reqArg = property_name, 
visit_date | api_name = GetWeather ; optArg = date ; reqArg = city", "id": "GoogleSGD_ApiDescription", "domain": "Homes, RentalCars, Weather", "episode_done": false, "prefix_stripped_text": "api_name = GetCarsAvailable ; optArg = car_type ; reqArg = city, end_date, pickup_time, start_date | api_name = ReserveCar ; reqArg = add_insurance, car_type, end_date, pickup_location, pickup_time, start_date | api_name = FindHomeByArea ; optArg = has_garage, in_unit_laundry ; reqArg = area, intent, number_of_baths, number_of_beds | api_name = ScheduleVisit ; reqArg = property_name, visit_date | api_name = GetWeather ; optArg = date ; reqArg = city"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APIS: ", "beam_texts": [["APIS: ", -1.0784366821781077e-07]], "metrics": {"clen": 91.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "APIS: "}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"text": "GOAL: add_insurance = True ; api_name = ReserveCar ; car_type = Hatchback ; end_date = 2019-03-14 ; pickup_location = PDX International Airport ; pickup_time = 15:30 ; start_date = 2019-03-09", "id": "GoogleSGD_Goal", "domain": "Homes, RentalCars, Weather", "episode_done": false, "prefix_stripped_text": "add_insurance = True ; api_name = ReserveCar ; car_type = Hatchback ; end_date = 2019-03-14 ; pickup_location = PDX International Airport ; pickup_time = 15:30 ; start_date = 2019-03-09"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need to reserve a car.", "beam_texts": [["USER: I need to reserve a car.", -2.8472073078155518]], "metrics": {"clen": 20.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need to reserve a car."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.00011637131683528423]], "metrics": {"clen": 197.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: What time do you want to pick it up?", "beam_texts": [["SYSTEM: What time do you want to pick it up?", -3.0318820476531982]], "metrics": {"clen": 210.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "What time do you want to pick it up?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to pick it up at 3:30 pm.", "beam_texts": [["USER: I want to pick it up at 3:30 pm.", -2.287160634994507]], "metrics": {"clen": 68.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to pick it up at 3:30 pm."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0001661973656155169]], "metrics": {"clen": 229.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: What date do you want to pick it up?", "beam_texts": [["SYSTEM: What date do you want to pick it up?", -2.177884578704834]], "metrics": {"clen": 242.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "What date do you want to pick it up?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to pick it up on the 9th of this month.", "beam_texts": [["USER: I want 
to pick it up on the 9th of this month.", -1.9553515911102295]], "metrics": {"clen": 109.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to pick it up on the 9th of this month."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.00047072372399270535]], "metrics": {"clen": 265.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Where do you want to pick it up?", "beam_texts": [["SYSTEM: Where do you want to pick it up?", -1.259605050086975]], "metrics": {"clen": 278.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Where do you want to pick it up?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to pick it up at PDX International Airport.", "beam_texts": [["USER: I want to pick it up at PDX International Airport.", -0.8431561589241028]], "metrics": {"clen": 126.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to pick it up at PDX International Airport."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.008902691304683685]], "metrics": {"clen": 298.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: What date do you want to return the car?", "beam_texts": [["SYSTEM: What date do you want to return the car?", -1.4374130964279175]], "metrics": {"clen": 316.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "What date do you want to return the car?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to return it on the 14th of this month.", "beam_texts": [["USER: I want to return it on the 14th of this month.", -1.5039756298065186]], "metrics": {"clen": 175.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to return it on the 14th of this month."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.023951882496476173]], "metrics": {"clen": 349.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Please confirm: You want to pick up a standard car from PDX International Airport on March 9th at 3:30 pm. You will return it on March 14th.", "beam_texts": [["SYSTEM: Please confirm: You want to pick up a standard car from PDX International Airport on March 9th at 3:30 pm. You will return it on March 14th.", -2.0667858123779297]], "metrics": {"clen": 397.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Please confirm: You want to pick up a standard car from PDX International Airport on March 9th at 3:30 pm. 
You will return it on March 14th."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, I want a Hatchback.", "beam_texts": [["USER: No, I want a Hatchback.", -1.684838056564331]], "metrics": {"clen": 226.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, I want a Hatchback."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -2.9663365239684936e-06]], "metrics": {"clen": 738.0, "ctrunc": 1.0, "ctrunclen": 226.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Please confirm: You want to pick up a Hatchback from PDX International Airport.", "beam_texts": [["SYSTEM: Please confirm: You want to pick up a Hatchback from PDX International Airport.", -0.6164475083351135]], "metrics": {"clen": 751.0, "ctrunc": 1.0, "ctrunclen": 239.0}, "prefix_stripped_text": "Please confirm: You want to pick up a Hatchback from PDX International Airport."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: Yes, that is correct.", "beam_texts": [["USER: Yes, that is correct.", -0.887704610824585]], "metrics": {"clen": 244.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Yes, that is correct."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: api_name = ReserveCar ; car_type = Hatchback ; end_date = 2019-03-14 ; pickup_location = PDX International Airport ; pickup_time = 15:30", "beam_texts": [["APICALL: api_name = ReserveCar ; car_type = Hatchback ; end_date = 2019-03-14 ; pickup_location = PDX International Airport ; pickup_time = 15:30", -0.011755380779504776]], "metrics": {"clen": 770.0, "ctrunc": 1.0, "ctrunclen": 258.0}, "prefix_stripped_text": "api_name = ReserveCar ; car_type = Hatchback ; end_date = 2019-03-14 ; pickup_location = PDX International Airport ; pickup_time = 15:30"}, {"text": "APIRESP: ", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I was unable to reserve the car. Is there anything else I can help you with?", "beam_texts": [["SYSTEM: I was unable to reserve the car. Is there anything else I can help you with?", -1.8006048202514648]], "metrics": {"clen": 783.0, "ctrunc": 1.0, "ctrunclen": 271.0}, "prefix_stripped_text": "I was unable to reserve the car. 
Is there anything else I can help you with?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, that is all.", "beam_texts": [["USER: No, that is all.", -1.2942805290222168]], "metrics": {"clen": 266.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, that is all."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0007116166525520384]], "metrics": {"clen": 805.0, "ctrunc": 1.0, "ctrunclen": 293.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Have a good day.", "beam_texts": [["SYSTEM: Have a good day.", -0.58097243309021]], "metrics": {"clen": 818.0, "ctrunc": 1.0, "ctrunclen": 306.0}, "prefix_stripped_text": "Have a good day."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: [DONE]", "beam_texts": [["USER: [DONE]", -5.495433015312301e-07]], "metrics": {"clen": 283.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "[DONE]"}]], "context": [], "metadata_path": "/checkpoint/mpchen/projects/taskoriented/user_generator/google_sgd/mm8_gsgd_bartR3f_protoTestTrainTue_May_11/withApi_conversations.metadata"} +{"dialog": [[{"text": "APIS: api_name = BookAppointment ; reqArg = appointment_date, appointment_time, therapist_name | api_name = FindProvider ; reqArg = city, type | api_name = BookAppointment ; reqArg = appointment_date, appointment_time, stylist_name | api_name = FindProvider ; optArg = is_unisex ; reqArg = city", "id": "GoogleSGD_ApiDescription", "domain": "Services, Services", "episode_done": false, "prefix_stripped_text": "api_name = BookAppointment ; reqArg = appointment_date, appointment_time, therapist_name | api_name = FindProvider ; reqArg = city, type | api_name = BookAppointment ; reqArg = appointment_date, appointment_time, stylist_name | api_name = FindProvider ; optArg = is_unisex ; reqArg = city"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APIS: ", "beam_texts": [["APIS: ", -1.0784366821781077e-07]], "metrics": {"clen": 99.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "APIS: "}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"text": "GOAL: api_name = FindProvider ; city = San Jose ; is_unisex = True", "id": "GoogleSGD_Goal", "domain": "Services, Services", "episode_done": false, "prefix_stripped_text": "api_name = FindProvider ; city = San Jose ; is_unisex = True"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need to find a salon.", "beam_texts": [["USER: I need to find a salon.", -2.4835095405578613]], "metrics": {"clen": 22.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need to find a salon."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.00014603004092350602]], "metrics": {"clen": 213.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: In which city?", "beam_texts": [["SYSTEM: In which city?", -1.5344370603561401]], "metrics": {"clen": 226.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "In which city?"}], [{"id": "R3fFirstTurnHistoryRepeat", 
"episode_done": false, "text": "USER: I need a unisex salon in San Jose.", "beam_texts": [["USER: I need a unisex salon in San Jose.", -1.449689507484436]], "metrics": {"clen": 87.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need a unisex salon in San Jose."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: api_name = FindProvider ; city = San Jose ; is_unisex = True", "beam_texts": [["APICALL: api_name = FindProvider ; city = San Jose ; is_unisex = True", -0.0007448627729900181]], "metrics": {"clen": 248.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "api_name = FindProvider ; city = San Jose ; is_unisex = True"}, {"text": "APIRESP: average_rating = 4.30 ; city = San Jose ; is_unisex = True ; phone_number = 408-293-3494 ; street_address = 1047 Lincoln Highway suite 101 ; stylist_name = Atelier Salon Willow Glen", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": "average_rating = 4.30 ; city = San Jose ; is_unisex = True ; phone_number = 408-293-3494 ; street_address = 1047 Lincoln Highway suite 101 ; stylist_name = Atelier Salon Willow Glen"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I found 10 salons. Atelier Salon Willow Glen is in San Jose.", "beam_texts": [["SYSTEM: I found 10 salons. Atelier Salon Willow Glen is in San Jose.", -1.8086758852005005]], "metrics": {"clen": 210.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I found 10 salons. Atelier Salon Willow Glen is in San Jose."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: That sounds good.", "beam_texts": [["USER: That sounds good.", -2.342952251434326]], "metrics": {"clen": 113.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "That sounds good."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0001382200134685263]], "metrics": {"clen": 236.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Do you want to book an appointment?", "beam_texts": [["SYSTEM: Do you want to book an appointment?", -1.16355562210083]], "metrics": {"clen": 249.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Do you want to book an appointment?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, not now.", "beam_texts": [["USER: No, not now.", -1.7858071327209473]], "metrics": {"clen": 137.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, not now."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.000978908152319491]], "metrics": {"clen": 273.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Anything else I can help you with?", "beam_texts": [["SYSTEM: Anything else I can help you with?", -1.4312176704406738]], "metrics": {"clen": 286.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Anything else I can help you with?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, thank you.", "beam_texts": [["USER: No, thank you.", -0.6700665950775146]], "metrics": {"clen": 
177.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, thank you."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -6.713548646075651e-05]], "metrics": {"clen": 319.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Have a nice day.", "beam_texts": [["SYSTEM: Have a nice day.", -0.659913957118988]], "metrics": {"clen": 513.0, "ctrunc": 1.0, "ctrunclen": 1.0}, "prefix_stripped_text": "Have a nice day."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: [DONE]", "beam_texts": [["USER: [DONE]", -7.327242883548024e-07]], "metrics": {"clen": 287.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "[DONE]"}]], "context": [], "metadata_path": "/checkpoint/mpchen/projects/taskoriented/user_generator/google_sgd/mm8_gsgd_bartR3f_protoTestTrainTue_May_11/withApi_conversations.metadata"} +{"dialog": [[{"text": "APIS: api_name = ReserveRestaurant ; optArg = date, number_of_seats ; reqArg = location, restaurant_name, time | api_name = FindRestaurants ; optArg = has_seating_outdoors, has_vegetarian_options, price_range ; reqArg = category, location | api_name = BuyMovieTickets ; reqArg = location, movie_name, number_of_tickets, show_date, show_time, show_type | api_name = FindMovies ; optArg = genre, show_type, theater_name ; reqArg = location | api_name = GetTimesForMovie ; optArg = show_type, theater_name ; reqArg = location, movie_name, show_date", "id": "GoogleSGD_ApiDescription", "domain": "Movies, Restaurants", "episode_done": false, "prefix_stripped_text": "api_name = ReserveRestaurant ; optArg = date, number_of_seats ; reqArg = location, restaurant_name, time | api_name = FindRestaurants ; optArg = has_seating_outdoors, has_vegetarian_options, price_range ; reqArg = category, location | api_name = BuyMovieTickets ; reqArg = location, movie_name, number_of_tickets, show_date, show_time, show_type | api_name = FindMovies ; optArg = genre, show_type, theater_name ; reqArg = location | api_name = GetTimesForMovie ; optArg = show_type, theater_name ; reqArg = location, movie_name, show_date"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APIS: ", "beam_texts": [["APIS: ", -2.1568732222476683e-07]], "metrics": {"clen": 186.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "APIS: "}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"text": "GOAL: api_name = GetTimesForMovie ; location = Castro Valley ; movie_name = IT Chapter Two ; show_date = 2019-03-01", "id": "GoogleSGD_Goal", "domain": "Movies, Restaurants", "episode_done": false, "prefix_stripped_text": "api_name = GetTimesForMovie ; location = Castro Valley ; movie_name = IT Chapter Two ; show_date = 2019-03-01"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to watch a movie.", "beam_texts": [["USER: I want to watch a movie.", -3.0537710189819336]], "metrics": {"clen": 14.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to watch a movie."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.002036166610196233]], "metrics": {"clen": 387.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": 
"TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Where are you located?", "beam_texts": [["SYSTEM: Where are you located?", -0.9682686924934387]], "metrics": {"clen": 400.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Where are you located?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to watch it in Castro Valley.", "beam_texts": [["USER: I want to watch it in Castro Valley.", -1.7995065450668335]], "metrics": {"clen": 105.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to watch it in Castro Valley."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: api_name = FindMovies ; location = Castro Valley", "beam_texts": [["APICALL: api_name = FindMovies ; location = Castro Valley", -0.000328321271808818]], "metrics": {"clen": 420.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "api_name = FindMovies ; location = Castro Valley"}, {"text": "APIRESP: ", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I can't find any movies matching your criteria.", "beam_texts": [["SYSTEM: I can't find any movies matching your criteria.", -3.635852098464966]], "metrics": {"clen": 433.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I can't find any movies matching your criteria."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to watch IT Chapter Two.", "beam_texts": [["USER: I want to watch IT Chapter Two.", -2.210944652557373]], "metrics": {"clen": 130.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to watch IT Chapter Two."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.11667776852846146]], "metrics": {"clen": 456.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: What date do you want to watch it?", "beam_texts": [["SYSTEM: What date do you want to watch it?", -1.8153434991836548]], "metrics": {"clen": 469.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "What date do you want to watch it?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to watch it on the 1st of this month.", "beam_texts": [["USER: I want to watch it on the 1st of this month.", -1.2194650173187256]], "metrics": {"clen": 156.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to watch it on the 1st of this month."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: api_name = GetTimesForMovie ; location = Castro Valley ; movie_name = IT Chapter 2 ; show_date = 2019-03-01", "beam_texts": [["APICALL: api_name = GetTimesForMovie ; location = Castro Valley ; movie_name = IT Chapter 2 ; show_date = 2019-03-01", -0.18431638181209564]], "metrics": {"clen": 304.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "api_name = GetTimesForMovie ; location = Castro Valley ; movie_name = IT Chapter 2 ; show_date = 2019-03-01"}, {"text": "APIRESP: ", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I 
can't find any times for the movie. Is there anything else I can help you with?", "beam_texts": [["SYSTEM: I can't find any times for the movie. Is there anything else I can help you with?", -2.2572638988494873]], "metrics": {"clen": 503.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I can't find any times for the movie. Is there anything else I can help you with?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, thank you.", "beam_texts": [["USER: No, thank you.", -0.9540809392929077]], "metrics": {"clen": 146.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, thank you."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.000830456439871341]], "metrics": {"clen": 527.0, "ctrunc": 1.0, "ctrunclen": 15.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Have a great day.", "beam_texts": [["SYSTEM: Have a great day.", -0.8139710426330566]], "metrics": {"clen": 695.0, "ctrunc": 1.0, "ctrunclen": 183.0}, "prefix_stripped_text": "Have a great day."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: [DONE]", "beam_texts": [["USER: [DONE]", -3.663622294425295e-07]], "metrics": {"clen": 140.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "[DONE]"}]], "context": [], "metadata_path": "/checkpoint/mpchen/projects/taskoriented/user_generator/google_sgd/mm8_gsgd_bartR3f_protoTestTrainTue_May_11/withApi_conversations.metadata"} +{"dialog": [[{"text": "APIS: api_name = PlayMedia ; optArg = album, artist, device ; reqArg = track | api_name = LookupMusic ; optArg = album, artist, genre, year | api_name = GetWeather ; optArg = date ; reqArg = city", "id": "GoogleSGD_ApiDescription", "domain": "Music, Weather", "episode_done": false, "prefix_stripped_text": "api_name = PlayMedia ; optArg = album, artist, device ; reqArg = track | api_name = LookupMusic ; optArg = album, artist, genre, year | api_name = GetWeather ; optArg = date ; reqArg = city"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APIS: ", "beam_texts": [["APIS: ", -1.0784366821781077e-07]], "metrics": {"clen": 75.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "APIS: "}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"text": "GOAL: album = Sucker Punch ; api_name = PlayMedia ; device = Kitchen ; track = Strangers", "id": "GoogleSGD_Goal", "domain": "Music, Weather", "episode_done": false, "prefix_stripped_text": "album = Sucker Punch ; api_name = PlayMedia ; device = Kitchen ; track = Strangers"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to listen to the song Strangers from the album Sucker Punch.", "beam_texts": [["USER: I want to listen to the song Strangers from the album Sucker Punch.", -2.6374406814575195]], "metrics": {"clen": 60.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to listen to the song Strangers from the album Sucker Punch."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: album = Sucker Punch ; api_name = PlayMedia ; track = Strangers", "beam_texts": [["APICALL: album = Sucker Punch ; api_name = PlayMedia ; track = Strangers", -0.5099499821662903]], "metrics": {"clen": 165.0, "ctrunc": 0.0, "ctrunclen": 0.0}, 
"prefix_stripped_text": "album = Sucker Punch ; api_name = PlayMedia ; track = Strangers"}, {"text": "APIRESP: ", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I couldn't play the song. Is there anything else I can help you with?", "beam_texts": [["SYSTEM: I couldn't play the song. Is there anything else I can help you with?", -2.205378770828247]], "metrics": {"clen": 196.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I couldn't play the song. Is there anything else I can help you with?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to play the song on the kitchen device.", "beam_texts": [["USER: I want to play the song on the kitchen device.", -2.291337013244629]], "metrics": {"clen": 46.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to play the song on the kitchen device."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.000290650554234162]], "metrics": {"clen": 359.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I want to play Strangers on the kitchen speaker.", "beam_texts": [["SYSTEM: I want to play Strangers on the kitchen speaker.", -2.0472002029418945]], "metrics": {"clen": 372.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to play Strangers on the kitchen speaker."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: Yes, that is correct.", "beam_texts": [["USER: Yes, that is correct.", -1.4469143152236938]], "metrics": {"clen": 151.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Yes, that is correct."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: api_name = PlayMedia ; album = Sucker Punch ; device = Kitchen speaker ; track = Strangers", "beam_texts": [["APICALL: api_name = PlayMedia ; album = Sucker Punch ; device = Kitchen speaker ; track = Strangers", -0.029085755348205566]], "metrics": {"clen": 330.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "api_name = PlayMedia ; album = Sucker Punch ; device = Kitchen speaker ; track = Strangers"}, {"text": "APIRESP: ", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I couldn't play the song. Is there anything else I can help you with?", "beam_texts": [["SYSTEM: I couldn't play the song. Is there anything else I can help you with?", -0.39661797881126404]], "metrics": {"clen": 343.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I couldn't play the song. 
Is there anything else I can help you with?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, thank you.", "beam_texts": [["USER: No, thank you.", -0.48353710770606995]], "metrics": {"clen": 176.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, thank you."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -8.918321691453457e-05]], "metrics": {"clen": 393.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Have a great day.", "beam_texts": [["SYSTEM: Have a great day.", -0.6688835620880127]], "metrics": {"clen": 376.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Have a great day."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: [DONE]", "beam_texts": [["USER: [DONE]", -9.159052751783747e-07]], "metrics": {"clen": 232.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "[DONE]"}]], "context": [], "metadata_path": "/checkpoint/mpchen/projects/taskoriented/user_generator/google_sgd/mm8_gsgd_bartR3f_protoTestTrainTue_May_11/withApi_conversations.metadata"} From 6cb4b865a080fa33e00e440f02a9d4cd8b0ed60c Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Mon, 15 Nov 2021 15:50:00 -0800 Subject: [PATCH 04/57] [TOD] Core converesation structure, serialization, const tokens --- parlai/core/tod/tod_core.py | 227 ++++++++++++++++++++++++++++++++++++ 1 file changed, 227 insertions(+) create mode 100644 parlai/core/tod/tod_core.py diff --git a/parlai/core/tod/tod_core.py b/parlai/core/tod/tod_core.py new file mode 100644 index 00000000000..76ee8005a74 --- /dev/null +++ b/parlai/core/tod/tod_core.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Task Oriented Dialogue (TOD) enums and base classes. + +This file defines standard tokens, classes for round and conversation structure, and a serialization class to aid in converting between these. + +See `tod_agents.py` for usage of these classes to generate training data and `tod_world_script.py` for usage of these classes in simulated conversations. 
+""" +from enum import Enum +from typing import List, Dict +from dataclasses import dataclass, field +from collections.abc import Iterable +from parlai.utils.misc import warn_once + +STANDARD_CALL = "APICALL: " +STANDARD_RESP = "APIRESP: " +STANDARD_SYSTEM_UTTERANCE = "SYSTEM: " +STANDARD_USER_UTTERANCE = "USER: " + +STANDARD_GOAL = "GOAL: " +STANDARD_API_SCHEMAS = "APIS: " + +STANDARD_API_NAME_SLOT = "api_name" +STANDARD_REQUIRED_KEY = "reqArg" +STANDARD_OPTIONAL_KEY = "optArg" +STANDARD_DONE = "[DONE]" + +CONST_SILENCE = "__SILENCE__" + + +class TodAgentType(str, Enum): + USER_UTT_AGENT = "user_utt_model" + API_CALL_AGENT = "api_call_model" + API_RESP_AGENT = "api_resp_model" + SYSTEM_UTT_AGENT = "system_utt_model" + API_SCHEMA_GROUNDING_AGENT = "api_schema_grounding_model" + GOAL_GROUNDING_AGENT = "goal_grounding_model" + + +TOD_AGENT_TYPE_TO_PREFIX = { + TodAgentType.USER_UTT_AGENT: STANDARD_USER_UTTERANCE, + TodAgentType.API_CALL_AGENT: STANDARD_CALL, + TodAgentType.API_RESP_AGENT: STANDARD_RESP, + TodAgentType.SYSTEM_UTT_AGENT: STANDARD_SYSTEM_UTTERANCE, + TodAgentType.API_SCHEMA_GROUNDING_AGENT: STANDARD_API_SCHEMAS, + TodAgentType.GOAL_GROUNDING_AGENT: STANDARD_GOAL, +} + + +@dataclass +class TodStructuredRound: + """ + Dataclass for rounds. + """ + + # Variables set by those using this class + user_utt: str = "" + api_call_machine: Dict = field( + default_factory=dict + ) # Hashmap of slot keys and slot values. Note that STANDARD_API_NAME_SLOT (`api_name`) is expected to be one of the keys here when this is nonempty; simulation metrics wonky without + api_resp_machine: Dict = field(default_factory=dict) + sys_utt: str = "" + extras: Dict = field( + default_factory=dict + ) # Grab bag for extra data. Not currently referenced in any TOD core code, but a convenient leaky abstraction for passing dataset-specific data between Parser classes and realized agents/teachers. + + # Variables derived by class + api_call_utt: str = field(init=False) + api_resp_utt: str = field(init=False) + + def __post_init__(self): + self.api_call_utt = SerializationHelpers.api_dict_to_str(self.api_call_machine) + self.api_resp_utt = SerializationHelpers.api_dict_to_str(self.api_resp_machine) + if ( + len(self.api_call_machine) > 0 + and STANDARD_API_NAME_SLOT not in self.api_call_machine + ): + warn_once( + f"{STANDARD_API_NAME_SLOT} missing when API Call present. This may cause issues for simulation metrics." + ) + + +@dataclass +class TodStructuredEpisode: + """ + Dataclass for episode-level data. + """ + + # Variables set by those using this class + delex: bool = False # Set to true and this class will handle delexicalizing call + response utterances based on API calls and responses exposed to this class. + domain: str = "" # self-explanatory + api_schemas_machine: List[Dict[str, List]] = field( + default_factory=list + ) # Expected to be a List of Dicts with the API name, required arguments, and optional arguments (specified by consts at the top of this file) as keys + goal_calls_machine: List[Dict[str, str]] = field( + default_factory=list + ) # Machine-formatted API calls + rounds: List[TodStructuredRound] = field(default_factory=list) # self explanatory + extras: Dict = field( + default_factory=dict + ) # Grab bag for extra data. Not currently referenced in any TOD core code, but a convenient leaky abstraction for passing dataset-specific data between Parser classes and realized agents/teachers. 
+ + # Variables derived by class + api_schemas_utt: str = field(init=False) + goal_calls_utt: str = field(init=False) + + def __post_init__(self): + self.api_schemas_utt = SerializationHelpers.list_of_maps_to_str( + self.api_schemas_machine + ) + self.goal_calls_machine = [ + call for call in self.goal_calls_machine if len(call) > 0 + ] + self.goal_calls_utt = SerializationHelpers.list_of_maps_to_str( + self.goal_calls_machine + ) + # Add a done turn at the end + self.rounds.append(TodStructuredRound(user_utt=STANDARD_DONE)) + if self.delex: + accum_slots = ( + {} + ) # separate since some slot values change as we go. Use this for delex first + cum_slots = self.get_all_slots() + for r in self.rounds: + accum_slots.update(r.api_call_machine) + accum_slots.update(r.api_resp_machine) + r.sys_utt = SerializationHelpers.delex(r.sys_utt, accum_slots) + r.sys_utt = SerializationHelpers.delex(r.sys_utt, cum_slots) + + def get_all_slots(self): + result = {} + for r in self.rounds: + result.update(r.api_call_machine) + result.update(r.api_resp_machine) + return result + + +class SerializationHelpers: + @classmethod + def delex(cls, text, slots): + delex = text + for slot, value in slots.items(): + if isinstance(value, str): + delex = delex.replace(value, f"[{slot}]") + else: + for v in value: + delex = delex.replace(v, f"[{slot}]") + return delex + + @classmethod + def inner_list_join(cls, values): + if isinstance(values, str): + return values + return ", ".join(sorted([v.strip() for v in values])) + + @classmethod + def inner_list_split(cls, s): + return s.split(", ") + + @classmethod + def maybe_inner_list_join(cls, values): + if isinstance(values, str) or isinstance(values, int): + return values + elif isinstance(values, Iterable): + return SerializationHelpers.inner_list_join(values) + else: + raise RuntimeError("invalid type of argument for maybe_inner_list_join") + + @classmethod + def api_dict_to_str(cls, apidict): + """ + Used for API Calls and Responses -> Utterance. + """ + return " ; ".join( + f"{k} = {SerializationHelpers.maybe_inner_list_join(v)}" + for k, v in sorted(apidict.items()) + ) + + @classmethod + def str_to_api_dict(cls, string): + """ + Used for API Call and Response Utterances -> Dict. 
+ """ + slot_strs = string.split(" ; ") + result = {} + for slot_str in slot_strs: + if " = " not in slot_str: + continue + name, value = slot_str.split(" = ", 1) + name = name.strip() + value = value.strip() + result[name] = value + return result + + @classmethod + def outer_list_join(cls, s): + return " | ".join(s) + + @classmethod + def outer_list_split(cls, s): + return s.split(" | ") + + @classmethod + def str_to_list_of_maps(cls, s): + return [ + SerializationHelpers.str_to_api_dict(x) + for x in SerializationHelpers.outer_list_split(s) + ] + + @classmethod + def list_of_maps_to_str(cls, list_of_maps): + return SerializationHelpers.outer_list_join( + [SerializationHelpers.api_dict_to_str(m) for m in list_of_maps] + ) + + @classmethod + def str_to_goals(cls, s): # convenience + return SerializationHelpers.str_to_list_of_maps(s) + + @classmethod + def str_to_api_schemas(cls, s): # convenience + return SerializationHelpers.str_to_list_of_maps(s) From 1480defd9fe6dacfeac287cb64c7b1e3436a236f Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 10:08:40 -0800 Subject: [PATCH 05/57] fix test by adding init folder --- parlai/core/tod/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 parlai/core/tod/__init__.py diff --git a/parlai/core/tod/__init__.py b/parlai/core/tod/__init__.py new file mode 100644 index 00000000000..240697e3247 --- /dev/null +++ b/parlai/core/tod/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. From de84801e3b47d9555277b03552835f75f2c529cd Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Mon, 15 Nov 2021 20:21:58 -0800 Subject: [PATCH 06/57] [Tod] Agents, teacher metrics, and tests for these See documentation block in `tod_agents.py` (I'm not 100% sure if `conftest.py` is a right file to change, though I did notice that `pytest.ini` was necessary to get pytest to run.) --- conftest.py | 1 + parlai/core/tod/teacher_metrics.py | 159 ++++ parlai/core/tod/tod_agents.py | 794 ++++++++++++++++++ parlai/core/tod/tod_test_utils/__init__.py | 5 + parlai/core/tod/tod_test_utils/test_agents.py | 216 +++++ pytest.ini | 1 + tests/tod/__init__.py | 5 + tests/tod/test_tod_agents_and_teachers.py | 327 ++++++++ tests/tod/test_tod_teacher_metrics.py | 74 ++ 9 files changed, 1582 insertions(+) create mode 100644 parlai/core/tod/teacher_metrics.py create mode 100644 parlai/core/tod/tod_agents.py create mode 100644 parlai/core/tod/tod_test_utils/__init__.py create mode 100644 parlai/core/tod/tod_test_utils/test_agents.py create mode 100644 tests/tod/__init__.py create mode 100644 tests/tod/test_tod_agents_and_teachers.py create mode 100644 tests/tod/test_tod_teacher_metrics.py diff --git a/conftest.py b/conftest.py index 7cc1e262461..8970273a460 100644 --- a/conftest.py +++ b/conftest.py @@ -67,6 +67,7 @@ def filter_tests_with_circleci(test_list): ('datatests/', 'data'), ('parlai/tasks/', 'teacher'), ('tasks/', 'tasks'), + ('tod/', 'tod'), ] diff --git a/parlai/core/tod/teacher_metrics.py b/parlai/core/tod/teacher_metrics.py new file mode 100644 index 00000000000..3fc85c7d107 --- /dev/null +++ b/parlai/core/tod/teacher_metrics.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
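As a quick illustration of the flat string format that `SerializationHelpers` (defined in `tod_core.py` above) produces and parses. The slot values below are invented; the formatting follows `api_dict_to_str`, `str_to_api_dict`, and `list_of_maps_to_str` as defined above:

```
from parlai.core.tod.tod_core import SerializationHelpers

call = {"api_name": "FindRestaurants", "category": "Italian", "location": "San Jose"}
flat = SerializationHelpers.api_dict_to_str(call)
# flat == "api_name = FindRestaurants ; category = Italian ; location = San Jose"
assert SerializationHelpers.str_to_api_dict(flat) == call  # round-trips for string-valued slots

schemas = [
    {
        "api_name": "FindRestaurants",       # STANDARD_API_NAME_SLOT
        "reqArg": ["category", "location"],  # STANDARD_REQUIRED_KEY
        "optArg": ["price_range"],           # STANDARD_OPTIONAL_KEY
    }
]
flat_schemas = SerializationHelpers.list_of_maps_to_str(schemas)
# Maps are joined with " | "; list values are joined with ", ":
# flat_schemas == "api_name = FindRestaurants ; optArg = price_range ; reqArg = category, location"
```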
+ +""" +Task Oriented Dialogue (TOD) teacher metrics. +""" +from typing import Optional, List, Dict, Any +from parlai.core.metrics import AverageMetric, BleuMetric, F1Metric, Metric, Metrics + + +class SlotMetrics(Metrics): + """ + Helper container which encapsulates standard slot metrics in task oriented learning + (jga, slot_p, slot_r, etc). + + Due to differences in dialogue representations between tasks, the input is pre- + parsed ground truth and predicted slot dictionaries. + + The 'jga+nlg' metric assumes a balanced set of JGA and NLG scores such that + 2 * Avg(JGA, NLG_BLEU) = Avg(JGA + NLG_BLEU) + The `jga+nlg` metric assumes that `NlgMetrics` is used to calculated the other side. + """ + + def __init__( + self, + teacher_slots: Dict[str, str], + predicted_slots: Dict[str, str], + prefixes: Optional[List] = None, + shared: Dict[str, Any] = None, + avg_jga_nlg_bleu: bool = False, + ) -> None: + super().__init__(shared=shared) + self.prefixes = prefixes if prefixes else [] + # jga and optionally Avg(jga,nlg_bleu) + self.add_with_prefixes("jga", AverageMetric(teacher_slots == predicted_slots)) + if len(teacher_slots) > 0: + self.add_with_prefixes( + "jga_noempty", AverageMetric(teacher_slots == predicted_slots) + ) + else: + self.add_with_prefixes( + "jga_empty", AverageMetric(teacher_slots == predicted_slots) + ) + + if avg_jga_nlg_bleu: + # add one half of Avg(jga,nlg_bleu), NlgMetrics class (below) adds NLG-BLEU + self.add("jga+nlg", AverageMetric(teacher_slots == predicted_slots)) + # precision + for pred_slot_name, pred_value in predicted_slots.items(): + slot_p = AverageMetric(teacher_slots.get(pred_slot_name) == pred_value) + self.add_with_prefixes("slot_p", slot_p) + self.add_with_prefixes("slot_f1", SlotF1Metric(slot_p=slot_p)) + # recall + for teacher_slot_name, teacher_value in teacher_slots.items(): + slot_r = AverageMetric( + predicted_slots.get(teacher_slot_name) == teacher_value + ) + self.add_with_prefixes("slot_r", slot_r) + self.add_with_prefixes("slot_f1", SlotF1Metric(slot_r=slot_r)) + + def add_with_prefixes(self, name, value): + self.add(name, value) + for prefix in self.prefixes: + self.add(f"{prefix}/{name}", value) + + +class NlgMetrics(Metrics): + """ + Helper container for generation version of standard metrics (F1, BLEU, ..). + """ + + def __init__( + self, + guess: str, + labels: Optional[List[str]], + prefixes: Optional[List[str]] = None, + shared: Dict[str, Any] = None, + avg_jga_nlg_bleu: bool = False, + ) -> None: + super().__init__(shared=shared) + self.prefixes = prefixes if prefixes else [] + bleu = BleuMetric.compute(guess, labels) + f1 = F1Metric.compute(guess, labels) + self.add_with_prefixes("nlg_bleu", bleu) + self.add_with_prefixes("nlg_f1", f1) + if avg_jga_nlg_bleu: + # add one half of Avg(jga,nlg_bleu), SlotMetrics class (above) adds JGA + self.add("jga+nlg", bleu) + + def add_with_prefixes(self, name, value): + self.add(name, value) + for prefix in self.prefixes: + self.add(f"{prefix}/{name}", value) + + +AverageType = Optional[AverageMetric] + + +def _average_type_sum_helper(first: AverageType, second: AverageType) -> AverageType: + """ + Helper to deal with Nones. + + We are "clever" in how we aggregate SlotF1Metrics (See SlotMetrics `__init__`) in + that we add precision and recall values separately, but this means we need to handle + None. + """ + if first is None: + return second + if second is None: + return first + return first + second + + +class SlotF1Metric(Metric): + """ + Metric to keep track of slot F1. 
+ + Keeps track of slot precision and slot recall as running metrics. + """ + + __slots__ = ("_slot_p", "_slot_r") + + @property + def macro_average(self) -> bool: + """ + Indicates whether this metric should be macro-averaged when globally reported. + """ + return True + + def __init__(self, slot_p: AverageType = None, slot_r: AverageType = None): + if not isinstance(slot_p, AverageMetric) and slot_p is not None: + slot_p = AverageMetric(slot_p) + if not isinstance(slot_r, AverageMetric) and slot_r is not None: + slot_r = AverageMetric(slot_r) + self._slot_p = slot_p + self._slot_r = slot_r + + def __add__(self, other: Optional["SlotF1Metric"]) -> "SlotF1Metric": + # NOTE: hinting can be cleaned up with "from __future__ import annotations" when + # we drop Python 3.6 + if other is None: + return self + slot_p = _average_type_sum_helper(self._slot_p, other._slot_p) + slot_r = _average_type_sum_helper(self._slot_r, other._slot_r) + return type(self)(slot_p=slot_p, slot_r=slot_r) + + def value(self) -> float: + if self._slot_p is None or self._slot_r is None: + return float("nan") + else: + slot_p = self._slot_p.value() + slot_r = self._slot_r.value() + if slot_p == 0.0 and slot_r == 0.0: + return float("nan") + else: + return 2 * (slot_p * slot_r) / (slot_p + slot_r) diff --git a/parlai/core/tod/tod_agents.py b/parlai/core/tod/tod_agents.py new file mode 100644 index 00000000000..f22a8330760 --- /dev/null +++ b/parlai/core/tod/tod_agents.py @@ -0,0 +1,794 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Agents (used for dumping data) and Teachers (for training models) related to the TOD +conversation setup. + +# Usage + +For a given dataset, extend `TodStructuredDataParser` and implement `generate_episodes()` and `get_id_task_prefix()`. The former of these is expected to do the data processing to convert a dataset to `List[TodStructuredEpisode]`. From here, multiple inheritance can be used to define Agents and Teachers that utilize the data. + +For example, given a `class XX_DataParser(TodStructuredDataParser)`, `class XX_UserSimulatorTeacher(XX_DataParser, TodUserSimulatorTeacher)` would be how one would define a teacher that generates training data for a User Simulator model. + +Once the relevant agents have been created (or relevant models have been fine-tuned), see `parlai.scripts.tod_world_script` for usage in generating simulations. + +As a convention, agents and teachers that are inheritable are prefixed with "Tod" whereas those that can be used as-is are not. Similarly, classes and functions that do not need to be exposed outside of this file are prefixed with a single underscore ('_'). + +## Why we do this +These files aid in consistency between Teachers and Agents for simulation. Rather than having to align multiple different agents to be consistent about assuptions about data formatting, tokens, spacing, etc, we do this once (via converting everything to `TodStructuredEpisode`) and let the code handle the rest. 
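## Example (illustrative sketch)
A minimal sketch of the pattern above. `MyDataset`, its domain, and its utterances are hypothetical; a real parser would load the actual data for the given fold. In practice the method to implement is `setup_episodes()`, which `generate_episodes()` calls:

```
from typing import List

import parlai.core.tod.tod_core as tod
from parlai.core.tod.tod_agents import (
    TodStructuredDataParser,
    TodSystemTeacher,
    TodUserSimulatorTeacher,
)


class MyDatasetParser(TodStructuredDataParser):
    def setup_episodes(self, fold: str) -> List[tod.TodStructuredEpisode]:
        # Hypothetical: a real parser would read the raw dataset for `fold` here.
        return [
            tod.TodStructuredEpisode(
                domain="restaurants",
                goal_calls_machine=[
                    {tod.STANDARD_API_NAME_SLOT: "FindRestaurants", "location": "San Jose"}
                ],
                rounds=[
                    tod.TodStructuredRound(
                        user_utt="I need an Italian place in San Jose.",
                        api_call_machine={
                            tod.STANDARD_API_NAME_SLOT: "FindRestaurants",
                            "location": "San Jose",
                        },
                        api_resp_machine={"restaurant_name": "Example Bistro"},
                        sys_utt="How about Example Bistro?",
                    )
                ],
            )
        ]

    def get_id_task_prefix(self) -> str:
        return "MyDataset"


class MyDatasetSystemTeacher(MyDatasetParser, TodSystemTeacher):
    pass


class MyDatasetUserSimulatorTeacher(MyDatasetParser, TodUserSimulatorTeacher):
    pass
```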
+ +# Description of Agents + Teachers useful for Simulation +## Teachers for training (generative) models + * TodSystemTeacher + * TodUserSimulatorTeacher + +## Agents for Grounding +For goal grounding for the User for simulation: + * TodGoalAgent + * TodSingleGoalAgent + +For (optional) API schema grounding for the System: + * TodApiSchemaAgent (must be used with `TodGoalAgent` only) + * TodSingleApiSchemaAgent (must be used with `TodSingleGoalAgent` only) + * EmptyApiSchemaAgent + * Used for simulations where the expectation is `no schema`, ie, evaluation simulations. + +## Agents for mocking APIs: + * StandaloneApiAgent + * Assumed to be provided a .pickle file 'trained' by `TodStandaloneApiTeacher` + +# Agents for dumping data from a ground truth dataset +The following are for extracting TOD World metrics from a ground truth dataset. These are generally used sparingly and only for calculating baselines. + * TodApiCallAndSysUttAgent + * TodApiResponseAgent + * TodUserUttAgent + +For this metrics extraction, `TodGoalAgent` and `TodApiSchemaAgent` should be used. + +# Other agents +There is a `EmptyGoalAgent` for use in human-human conversations where a goal is unnecessary. +""" + +from parlai.core.agents import Agent +from parlai.core.message import Message +from parlai.core.metrics import AverageMetric +from parlai.core.params import ParlaiParser +from parlai.core.opt import Opt +from parlai.core.teachers import DialogTeacher +from parlai.utils.distributed import is_distributed, get_rank, num_workers + +import parlai.core.tod.tod_core as tod +from parlai.core.tod.tod_core import SerializationHelpers +from parlai.core.tod.teacher_metrics import SlotMetrics, NlgMetrics + +from typing import Optional, List +import json +import pickle +import difflib +import random +from math import ceil + + +######### Agents that dump information from a dataset; base classes +class TodStructuredDataParser(Agent): + """ + Base class that specifies intermediate representations for Tod conversations. + """ + + @classmethod + def add_cmdline_args( + cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None + ) -> ParlaiParser: + if hasattr(super(), "add_cmdline_args"): + parser = super().add_cmdline_args(parser, partial_opt) + group = parser.add_argument_group("TOD StructuredData agent") + group.add_argument( + "--episodes-randomization-seed", + type=int, + default=-1, + help="Randomize episodes in a predictable way (eg, for few shot). Set to -1 for no randomization. ", + ) + parser.add_argument( + "--n-shot", + default=-1, + type=int, + help="Number of dialogues to keep for each of train/valid/test. -1 means all. Dialogues of lower numbers are strict subsets of larger numbers. Do not use in conjunction with `--percent-shot`. Use `--episodes-randomization-seed` to change seed. NOTE: Beware of using this flag when multitasking as this will apply to *all* datasets unless the ':' syntax for specifying per-dataset flags is used.", + ) + parser.add_argument( + "--percent-shot", + default=-1, + type=float, + help="Percentage of dialogues to keep for each of train/valid/test. -1 means all. Dialogues of lower numbers are strict subsets of larger numbers. Do not use in conjunction with `--n-shot`. Use `--episodes-randomization-seed` to change seed. 
NOTE: Beware of using this flag when multitasking as this will apply to *all* datasets unless the ':' syntax for specifying per-dataset flags is used.", + ) + return parser + + def __init__(self, opt: Opt, shared=None): + super().__init__(opt, shared) + self.id = self.get_id_task_prefix() + "_" + self._get_agent_type_suffix() + if shared is None: + self.episodes = self.generate_episodes() + else: + self.episodes = shared["episodes"] + + def share(self): + share = super().share() + share["episodes"] = self.episodes + return share + + def setup_episodes(self, fold: str) -> List[tod.TodStructuredEpisode]: + """ + Fold here is a data fold. + """ + raise NotImplementedError( + "Must have method for generating an episode. Must be set in downstream Parser for a given task" + ) + + def generate_episodes(self) -> List[tod.TodStructuredEpisode]: + if self.opt.get("n_shot", -1) >= 0 and self.opt.get("percent_shot", -1) >= 0: + # Validate before spending a while to load eeverything + raise RuntimeError("Both `--n-shot` and `--percent-shot` in use!") + episodes = list(self.setup_episodes(self.fold)) + if self.opt.get("episodes_randomization_seed", -1) != -1: + random.Random(self.opt["episodes_randomization_seed"]).shuffle(episodes) + if self.opt.get("n_shot", -1) != -1: + episodes = episodes[: self.opt["n_shot"]] + elif self.opt.get("percent_shot", -1) >= 0: + episodes = episodes[: int(len(episodes) * self.opt["percent_shot"])] + return episodes + + def get_id_task_prefix(self) -> str: + """ + Convenience for setting IDs. + """ + raise NotImplementedError( + "Must set ID prefix in downstream task agent. Must be set in downsream Parser for a given task" + ) + + def _get_agent_type_suffix(self) -> str: + """ + Convenience for setting IDs. + """ + raise NotImplementedError( + "Must set in downstream agent within `tod_agents`. If you see this error, something is wrong with TOD Infrastructure" + ) + + +######### Agents that dump information from a dataset as gold (explicitly should *not* be used with teachers) +class _TodDataDumpAgent(TodStructuredDataParser): + """ + For agents which dump data from some dataset, without training/other modifications. + + Implements an "epoch done" + + Member variables assumed to be set in init downstream: + self.fold + """ + + def __init__(self, opt: Opt, shared=None): + super().__init__(opt, shared) + self.epochDone = False + self.batchsize = opt.get("batchsize", 1) + self.max_episodes = len(self.episodes) + if opt.get("num_episodes", 0) > 0: + self.max_episodes = min(self.max_episodes, opt.get("num_episodes")) + self.episode_idx = opt.get("batchindex", 0) + self._setup_next_episode() + self.round_idx = 0 # for some downstream utt + sysUttAndApiCallAgents. + if is_distributed(): # cause gotta manually handle + rank = get_rank() + chunk_size = ceil(self.max_episodes / num_workers()) + self.episode_idx += rank * chunk_size + self.max_episodes = min(self.max_episodes, (rank + 1) * chunk_size) + + def _setup_next_episode(self): + self.epochDone = not self.episode_idx < self.max_episodes + self.episode = None + if not self.epochDone: + self.episode = self.episodes[self.episode_idx] + self.round_idx = ( + 0 # so downstream agents know which round they are in. Update in `act()` + ) + + def epoch_done(self) -> bool: + return self.epochDone + + def episode_done(self) -> bool: + """ + This is not actually "episode_done" so much as "we want to signify to the world + that we have gone past the batch". 
+ + This class should not control whether or not the episode is actually done since + the TodWorld expects that to come from the User agent. + """ + return self.epochDone + + def num_episodes(self) -> int: + return len(self.episodes) + + def reset(self): + self.episode_idx += self.batchsize + self._setup_next_episode() + + +class TodGoalAgent(_TodDataDumpAgent): + """ + Use as a mixin with classes that also extend + implement TodStructuredDataParser. + """ + + def act(self): + return { + "text": f"{tod.STANDARD_GOAL}{self.episode.goal_calls_utt}", + "id": self.id, + "domain": self.episode.domain, + "episode_done": False, + } + + def _get_agent_type_suffix(self): + return "Goal" + + +class TodApiSchemaAgent(_TodDataDumpAgent): + def act(self): + return { + "text": f"{tod.STANDARD_API_SCHEMAS}{self.episode.api_schemas_utt}", + "id": self.id, + "domain": self.episode.domain, + "episode_done": False, + } + + def _get_agent_type_suffix(self): + return "ApiSchema" + + +############# Single Goal + Api Schema Agent +class _EpisodeToSingleGoalProcessor(_TodDataDumpAgent): + """ + Iterate through all of the goals of a dataset, one by one. + + Slightly different logic than the dump agent since how we count + setup examples for + an episode are different + + Used as a mixin in the SingleGoal and SingleApiSchema agents below. + + This class exposes a `filter_goals()` function that can be overridden by downstream agents. + """ + + def __init__(self, opt: Opt, shared=None): + super().__init__(opt, shared) + self.epochDone = False + if shared is None: + self.episodes = self._setup_single_goal_episodes() + else: + # Handled fine in _TodDataDumpAgent + pass + + self.max_episodes = len(self.episodes) + if opt.get("num_episodes", 0) > 0: + self.max_episodes = min(self.max_episodes, opt.get("num_episodes")) + if is_distributed(): # cause gotta manually handle + rank = get_rank() + chunk_size = ceil(self.max_episodes / num_workers()) + self.max_episodes = min(self.max_episodes, (rank + 1) * chunk_size) + + self._setup_next_episode() + + def _setup_single_goal_episodes(self) -> List[tod.TodStructuredEpisode]: + """ + This function assumes that `self.setup_episodes()` has already been called + prior. + + Based on the `__init__` order of this class, it should be done in + `TodStructuredDataParser` by this point. + """ + raw_episodes = self.episodes + result = [] + for raw in raw_episodes: + for call in self.filter_goals(raw.goal_calls_machine): + schema = {} + for cand in raw.api_schemas_machine: + if ( + cand[tod.STANDARD_API_NAME_SLOT] + == call[tod.STANDARD_API_NAME_SLOT] + ): + schema = cand + + result.append( + tod.TodStructuredEpisode( + domain=raw.domain, + api_schemas_machine=[schema], + goal_calls_machine=[call], + rounds=[], + ) + ) + return result + + def filter_goals(self, goals): + """ + Some downstream agents may want to filter the goals. + + Override this if so. + """ + return goals + + +class TodSingleGoalAgent(_EpisodeToSingleGoalProcessor, TodGoalAgent): + """ + Use as a mixin with classes that also extend + implement TodStructuredDataParser. + + NOTE: If an API schema agent is used, this *must* be used with `TodSingleApiSchemaAgent` since it will be nonsensicle otherwise. Additionally, this agent will not function properly with UserUtt + SystemUttAndApiCall agent, since episodes will not align. 
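For example (goal values invented, serialization as in `tod_core.py`), an episode with two goal calls gets expanded so that this agent grounds one call per episode rather than all of them at once:

```
from parlai.core.tod.tod_core import STANDARD_GOAL, SerializationHelpers

goal_calls = [
    {"api_name": "FindMovies", "location": "Castro Valley"},
    {"api_name": "GetTimesForMovie", "movie_name": "IT Chapter Two"},
]

# TodGoalAgent grounds with every goal call of the episode at once:
print(STANDARD_GOAL + SerializationHelpers.list_of_maps_to_str(goal_calls))
# GOAL: api_name = FindMovies ; location = Castro Valley | api_name = GetTimesForMovie ; movie_name = IT Chapter Two

# The single-goal processor expands the episode, so TodSingleGoalAgent emits
# one of these per expanded episode instead:
for call in goal_calls:
    print(STANDARD_GOAL + SerializationHelpers.list_of_maps_to_str([call]))
# GOAL: api_name = FindMovies ; location = Castro Valley
# GOAL: api_name = GetTimesForMovie ; movie_name = IT Chapter Two
```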
+ """ + + def _get_agent_type_suffix(self): + return "SingleGoal" + + +class TodSingleApiSchemaAgent(_EpisodeToSingleGoalProcessor, TodApiSchemaAgent): + """ + Use as a mixin with classes that also extend + implement TodStructuredDataParser. + + NOTE: Must be used with TodSingleGoalAgent since nonsensicle otherwise. Additionally, this agent will not function properly with UserUtt + SystemUttAndApiCall agent, since episodes will not align. + """ + + def _get_agent_type_suffix(self): + return "SingleApiSchema" + + +###### Agents used for calculating TOD World Metrics based on a dataset. See `tod_world_script` or `parlai/projects/tod_simulator/` for examples. +class TodUserUttAgent(_TodDataDumpAgent): + """ + Agent used to calculate TOD World Metrics on a dataset. Represents the "User" agent. + + This class should only ever be used with the model-model chat world which will stop + upon seeing the '[DONE]' utterance; may go out of bounds otherwise. + """ + + def act(self): + result = { + "text": f"{tod.STANDARD_USER_UTTERANCE}{self.episode.rounds[self.round_idx].user_utt}", + "id": self.id, + "domain": self.episode.domain, + "episode_done": False, + } + self.round_idx += 1 + return result + + def reset(self): + super().reset() # setup next episode + self.round_idx = 0 + + def _get_agent_type_suffix(self): + return "User" + + +class TodApiCallAndSysUttAgent(_TodDataDumpAgent): + """ + Agent used to calculate TOD World Metrics on a dataset. Represents the "System" + agent. + + This class should only ever be used with the model-model chat world which will stop + upon seeing the '[DONE]' utterance; may go out of bounds otherwise. + """ + + def __init__(self, opt: Opt, shared=None): + # This class represents two "agents" so need to make sure we don't increment episode number (reset) twice + self.already_reset = False + self.api_call_turn = True + super().__init__(opt, shared) + + def act(self): + self.already_reset = False + if tod.STANDARD_API_SCHEMAS in self.observation.get("text", ""): + return { + "text": tod.STANDARD_API_SCHEMAS, + "id": self.id, + "domain": self.episode.domain, + "episode_down": False, + } + + if self.api_call_turn: # comes first, don't iterate round # + result = { + "text": f"{tod.STANDARD_CALL}{self.episode.rounds[self.round_idx].api_call_utt}", + "id": self.id, + "domain": self.episode.domain, + "episode_done": False, + } + else: + result = { + "text": f"{tod.STANDARD_SYSTEM_UTTERANCE}{self.episode.rounds[self.round_idx].sys_utt}", + "id": self.id, + "domain": self.episode.domain, + "episode_done": False, + } + self.round_idx += 1 + + self.api_call_turn ^= True + return result + + def reset(self): + if not self.already_reset: + super().reset() # setup next episode + self.api_call_turn = True + self.already_reset = True + + def _get_agent_type_suffix(self): + return "System" + + +class TodApiResponseAgent(_TodDataDumpAgent): + """ + Agent used to calculate TOD World Metrics on a dataset. Represents the API + Simulator. + + This class should only ever be used with the model-model chat world which will stop + upon seeing the '[DONE]' utterance; may go out of bounds otherwise. 
+ """ + + def act(self): + result = { + "text": f"{tod.STANDARD_RESP}{self.episode.rounds[self.round_idx].api_resp_utt}", + "id": self.id, + "domain": self.episode.domain, + "episode_done": False, + } + self.round_idx += 1 + return result + + def reset(self): + super().reset() # setup next episode + self.round_idx = 0 + + def _get_agent_type_suffix(self): + return "ApiResponse" + + +###### Standalone API agent +class StandaloneApiAgent(Agent): + """ + Trainable agent that saves API calls and responses. + + Use `TodStandaloneApiTeacher` to train this class. For example for a MultiWoz V2.2 + standalone API, use ``` parlai train -t multiwoz_v22:StandaloneApiTeacher -m + parlai_fb.agents.tod.agents:StandaloneApiAgent -eps 4 -mf output ``` to generate the + `.pickle` file to use. + """ + + EMPTY_RESP = { + "text": tod.STANDARD_RESP, + "id": "StandaloneApiAgent", + "episode_done": False, + } + + @classmethod + def add_cmdline_args( + cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None + ) -> ParlaiParser: + group = parser.add_argument_group("TOD Standalone API args") + group.add_argument( + "--exact-api-call", + type=bool, + default=True, + help="Validation-time flag. If true, will return '' if exact api call values not found. If false, will pick response from the same intent with similar api parameters (assuming intent is the same when available)", + ) + + group.add_argument( + "--fail-hard", + type=bool, + default=False, + help="Aids in deugging. Will throw exception if API call not found and '--exact-api-call' is set.", + ) + + group.add_argument( + "--standalone-api-file", + type=str, + default=None, + help="Path to file holding `.pickle` of standalone api for validation (will intelligently strip if suffix included). If not set, assumes the `model_file` argument will contain the `.pickle` file. 
", + ) + return parser + + def __init__(self, opt, shared=None): + super().__init__(opt, shared) + self.id = "StandaloneApiAgent" + file_key = "model_file" + if self.opt["standalone_api_file"] is not None: + file_key = "standalone_api_file" + self.path_base = self.opt[file_key].replace(".pickle", "") + self.db_path = self.path_base + ".pickle" + self.exact_api_call = self.opt["exact_api_call"] + try: + with (open(self.db_path, "rb")) as openfile: + self.data = pickle.load(openfile) + self.training = True + print("Loaded Standalone API data successfully") + if self.exact_api_call != self.data.get("exact_api_call", True): + raise RuntimeError( + f"Standalone API .pickle file generated with `exact_api_call` of {self.data.get('exact_api_call', False)} but StandaloneApiAgent sets it to {self.exact_api_call}" + ) + except Exception: + print(f"No file at {self.db_path}; ASSUMING WE ARE TRAINING") + self.data = {} + self.data["exact_api_call"] = self.exact_api_call + self.training = True + + def _maybe_filter_prefix(self, text, prefix): + if prefix in text: + return text[len(prefix) :].strip() + return text.strip() + + def act(self): + if not self.observation["text"].startswith(tod.STANDARD_CALL): + return self.EMPTY_RESP + call_text_raw = self.observation["text"] + # decode then reencode the API call so that we get the API calls in a consistent order + call_text = SerializationHelpers.api_dict_to_str( + SerializationHelpers.str_to_api_dict( + call_text_raw[len(tod.STANDARD_CALL) :] + ) + ) + if "labels" in self.observation: + return self._do_train(call_text) + return self._do_fetch(call_text) + + def _do_train(self, call_text): + assert self.training is True + self.data[call_text] = self.observation["labels"][0] + return self.EMPTY_RESP + + def _do_fetch(self, call_text): + if self.exact_api_call: + if self.opt.get("fail_hard", False): + resp = self.data[call_text] + else: + resp = self.data.get(call_text, tod.STANDARD_RESP) + return { + "text": resp, + "id": self.id, + "episode_done": False, + } + + # Not exact case + best_key = difflib.get_close_matches(call_text, self.data.keys(), 1) + if len(best_key) == 0: + return self.EMPTY_RESP + return { + "text": self.data.get(best_key[0], tod.STANDARD_RESP), + "id": self.id, + "episode_done": False, + } + + def shutdown(self): + if self.training: + with (open(self.db_path, "wb")) as openfile: + pickle.dump(self.data, openfile) + print(f"Dumped output to {self.db_path}") + with open(self.path_base + ".opt", "w") as f: + json.dump(self.opt, f) + + +######### Empty agents +class EmptyApiSchemaAgent(Agent): + def __init__(self, opt, shared=None): + super().__init__(opt) + self.id = "EmptyApiSchemaAgent" + + def act(self): + msg = { + "id": self.getID(), + "text": tod.STANDARD_API_SCHEMAS, + "episode_done": False, + } + return Message(msg) + + +class EmptyGoalAgent(Agent): + def __init__(self, opt, shared=None): + super().__init__(opt) + self.id = "EmptyGoalAgent" + + def act(self): + msg = {"id": self.getID(), "text": tod.STANDARD_GOAL, "episode_done": False} + return Message(msg) + + +############# Teachers +class TodSystemTeacher(TodStructuredDataParser, DialogTeacher): + """ + TOD agent teacher which produces both API calls and NLG responses. + + First turn is API Schema grounding, which may be a an empty schema. + Subsequent turns alternate between + 1. User utterance -> API Call + 2. 
API Response -> System Utterance + """ + + @classmethod + def add_cmdline_args( + cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None + ) -> ParlaiParser: + parser = super().add_cmdline_args(parser, partial_opt) + parser.add_argument( + "--api-schemas", + type="bool", + default=False, + help="Preempt first turn with intents + required/optional parameters as key/value for given domain", + ) + parser.add_argument( + "--api-jga-record", + type=bool, + default=True, + help="Should we save jga information per api schema?", + ) + parser.add_argument( + "--domain-jga-record", + type=bool, + default=False, + help="Should we save jga information per domain?", + ) + parser.add_argument( + "--domain-nlg-record", + type=bool, + default=False, + help="Should we save nlg information per domain?", + ) + return parser + + def custom_evaluation( + self, teacher_action: Message, labels, model_response: Message + ): + resp = model_response.get("text") + if not resp: + return + if teacher_action["type"] == tod.STANDARD_CALL: + if resp.startswith(tod.STANDARD_CALL): + resp = resp[len(tod.STANDARD_CALL) :] + predicted = SerializationHelpers.str_to_api_dict(resp) + domains = ( + [teacher_action["domain"]] if self.opt["domain_jga_record"] else [] + ) + + metrics = SlotMetrics( + teacher_slots=teacher_action["slots"], + predicted_slots=predicted, + avg_jga_nlg_bleu=True, + prefixes=domains, + ).report() + for key, value in metrics.items(): + self.metrics.add(key, value) + + if self.opt["api_jga_record"] and len(teacher_action["slots"]) > 0: + teacher = teacher_action["slots"] + slots = list(teacher.keys()) + slots.remove(tod.STANDARD_API_NAME_SLOT) + api_here = ( + "api-" + + teacher[tod.STANDARD_API_NAME_SLOT] + + "--" + + "-".join(slots) + ) + self.metrics.add(f"{api_here}/jga", AverageMetric(teacher == predicted)) + + elif teacher_action["type"] == tod.STANDARD_SYSTEM_UTTERANCE: + domains = ( + [teacher_action["domain"]] if self.opt["domain_nlg_record"] else [] + ) + metrics = NlgMetrics( + guess=resp, + labels=labels, + prefixes=domains, + avg_jga_nlg_bleu=True, + ).report() + for key, value in metrics.items(): + self.metrics.add(key, value) + + def setup_data(self, fold): + for episode in self.generate_episodes(): + if self.opt.get("api_schemas"): + schemas = episode.api_schemas_utt + else: + schemas = "" + yield { + "text": f"{tod.STANDARD_API_SCHEMAS}{schemas}", + "label": f"{tod.STANDARD_API_SCHEMAS}", + "domain": episode.domain, + "type": tod.STANDARD_API_SCHEMAS, + "slots": {}, + }, True + for r in episode.rounds: + yield { + "text": f"{tod.STANDARD_USER_UTTERANCE}{r.user_utt}", + "label": f"{tod.STANDARD_CALL}{r.api_call_utt}", + "domain": episode.domain, + "type": tod.STANDARD_CALL, + "slots": r.api_call_machine, + }, False + yield { + "text": f"{tod.STANDARD_RESP}{r.api_resp_utt}", + "label": f"{tod.STANDARD_SYSTEM_UTTERANCE}{r.sys_utt}", + "domain": episode.domain, + "slots": r.api_resp_machine, + "type": tod.STANDARD_SYSTEM_UTTERANCE, + }, False + + def _get_agent_type_suffix(self): + return "SystemTeacher" + + +class TodUserSimulatorTeacher(TodStructuredDataParser, DialogTeacher): + """ + Teacher that has `Goal->User Utterance` for its first turn, then `System + Utterance->User Utterance` for all subsequent turns. 
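For instance, for a hypothetical two-round episode (utterances invented), the (text, label) pairs yielded by `setup_data()` look like:

```
# Sketch only; the goal and utterances below are illustrative.
pairs = [
    (
        "GOAL: api_name = FindMovies ; location = Castro Valley",
        "USER: I want to watch a movie.",
    ),
    (
        "SYSTEM: Where are you located?",
        "USER: I want to watch it in Castro Valley.",
    ),
    # ...and so on, ending with the label "USER: [DONE]" after the final system turn
    # (TodStructuredEpisode appends that [DONE] round automatically).
]
```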
+ """ + + def setup_data(self, fold): + for episode in self.generate_episodes(): + if len(episode.rounds) < 1: + continue + yield { + "text": f"{tod.STANDARD_GOAL}{episode.goal_calls_utt}", + "label": f"{tod.STANDARD_USER_UTTERANCE}{episode.rounds[0].user_utt}", + "domain": episode.domain, + "type": tod.STANDARD_USER_UTTERANCE, + }, True + for i, r in enumerate(episode.rounds): + if i == len(episode.rounds) - 1: + continue + yield { + "text": f"{tod.STANDARD_SYSTEM_UTTERANCE}{r.sys_utt}", + "label": f"{tod.STANDARD_USER_UTTERANCE}{episode.rounds[i+1].user_utt}", + "domain": episode.domain, + "type": tod.STANDARD_USER_UTTERANCE, + "slots": {}, # slots in agent/user turns are meaningless + }, False + + def custom_evaluation( + self, teacher_action: Message, labels, model_response: Message + ): + resp = model_response.get("text") + if not resp: + return + if teacher_action["type"] == tod.STANDARD_RESP: + if resp.startswith(tod.STANDARD_RESP): + resp = resp[len(tod.STANDARD_RESP) :] + predicted = SerializationHelpers.str_to_api_dict(resp) + + metrics = SlotMetrics(teacher_action["slots"], predicted).report() + for key, value in metrics.items(): + self.metrics.add(key, value) + + elif teacher_action["type"] == tod.STANDARD_USER_UTTERANCE: + metrics = NlgMetrics(resp, labels).report() + for key, value in metrics.items(): + self.metrics.add(key, value) + + def _get_agent_type_suffix(self): + return "UserSimulatorTeacher" + + +class TodStandaloneApiTeacher(TodStructuredDataParser, DialogTeacher): + """ + Use this to generate a database for `StandaloneApiAgent`. + + Set this as the teacher with `StandaloneApiAgent` as the agent. Ex for a MultiWoz + V2.2 standalone API, use ``` parlai train -t multiwoz_v22:StandaloneApiTeacher -m + parlai_fb.agents.tod.agents:StandaloneApiAgent -eps 4 -mf output ``` + """ + + def setup_data(self, fold): + # As a default, just put everything in + for fold_overwrite in ["train", "valid", "test"]: + for episode in self.setup_episodes(fold_overwrite): + first = True + for r in episode.rounds: + if len(r.api_call_machine) > 0: + yield { + "text": f"{tod.STANDARD_CALL}{r.api_call_utt}", + "label": f"{tod.STANDARD_RESP}{r.api_resp_utt}", + "id": self.id, + "domain": episode.domain, + }, first + first = False + + def _get_agent_type_suffix(self): + return "StandaloneApiTeacher" diff --git a/parlai/core/tod/tod_test_utils/__init__.py b/parlai/core/tod/tod_test_utils/__init__.py new file mode 100644 index 00000000000..240697e3247 --- /dev/null +++ b/parlai/core/tod/tod_test_utils/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/parlai/core/tod/tod_test_utils/test_agents.py b/parlai/core/tod/tod_test_utils/test_agents.py new file mode 100644 index 00000000000..b1339052764 --- /dev/null +++ b/parlai/core/tod/tod_test_utils/test_agents.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Helpers so we don't need to create agents all over. 
+""" + +import parlai.core.tod.tod_agents as tod_agents +import parlai.core.tod.tod_core as tod_core + +import os + +API_DATABASE_FILE = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "standalone_api_file.pickle" +) + + +def episode_has_broken_api_turn(episode_idx, max_turns): + return episode_idx % 2 == 1 and max_turns > 0 + + +def use_broken_api_calls_this_turn(round_idx, episode_idx): + return episode_idx % 2 == 1 and round_idx % 3 == 1 + + +def make_api_call_machine(round_idx, episode_idx=0, use_broken_mock_api_calls=False): + if round_idx == 0: + return {} + if use_broken_mock_api_calls: + # Hack as a way to test metrics reporting in tod world script + if use_broken_api_calls_this_turn(round_idx, episode_idx): + round_idx = -1 * round_idx + return {tod_core.STANDARD_API_NAME_SLOT: f"name_{round_idx}", "in": round_idx} + + +def make_api_resp_machine(round_idx): + if round_idx == 0: + return {} + return {"out": round_idx} + + +def make_api_schemas_machine(max_rounds): + return [ + { + tod_core.STANDARD_API_NAME_SLOT: f"name_{round_idx}", + tod_core.STANDARD_REQUIRED_KEY: ["in"], + tod_core.STANDARD_OPTIONAL_KEY: [], + } + for round_idx in range(1, max_rounds) + ] + + +def make_goal_calls_machine(max_rounds): + return [make_api_call_machine(x) for x in range(1, max_rounds)] + + +def get_rounds(episode_idx, max_rounds, use_broken_mock_api_calls=False): + return [ + tod_core.TodStructuredRound( + user_utt=f"user_utt_{episode_idx}_{round_idx}", + api_call_machine=make_api_call_machine( + round_idx, episode_idx, use_broken_mock_api_calls + ), + api_resp_machine=make_api_resp_machine(round_idx), + sys_utt=f"sys_utt_{episode_idx}_{round_idx}", + ) + for round_idx in range(max_rounds) + ] + + +def get_round_utts(episode_idx, max_rounds, filter_utts=None): + if max_rounds < 1: + return [] + utts = [ + [ + f"USER: user_utt_{episode_idx}_0", + "APICALL: ", + "APIRESP: ", + f"SYSTEM: sys_utt_{episode_idx}_0", + ] + ] + for i in range(1, max_rounds): + utts.append( + [ + f"USER: user_utt_{episode_idx}_{i}", + f"APICALL: api_name = name_{i} ; in = {i}", + f"APIRESP: out = {i}", + f"SYSTEM: sys_utt_{episode_idx}_{i}", + ] + ) + utts.append( + [ + "USER: [DONE]", + "APICALL: ", + "APIRESP: ", + "SYSTEM: ", + ] + ) + if filter_utts is not None: + utts = [ + [turn for i, turn in enumerate(round_data) if filter_utts[i]] + for round_data in utts + ] + return utts + + +TEST_NUM_EPISODES_OPT_KEY = "test_num_episodes" +TEST_NUM_ROUNDS_OPT_KEY = "test_num_rounds" + +# No api calls in this setup +EPISODE_SETUP__UTTERANCES_ONLY = { + TEST_NUM_ROUNDS_OPT_KEY: 1, + TEST_NUM_EPISODES_OPT_KEY: 1, +} + +# No one call, one goal, one api desscription in this setup +EPISODE_SETUP__SINGLE_API_CALL = { + TEST_NUM_ROUNDS_OPT_KEY: 2, + TEST_NUM_EPISODES_OPT_KEY: 1, +} +# Will start testing multiple api calls + schemas, multi-round logic +EPISODE_SETUP__MULTI_ROUND = {TEST_NUM_ROUNDS_OPT_KEY: 5, TEST_NUM_EPISODES_OPT_KEY: 1} + +# Test that episode logic is correct +EPISODE_SETUP__MULTI_EPISODE = { + TEST_NUM_ROUNDS_OPT_KEY: 5, + TEST_NUM_EPISODES_OPT_KEY: 8, +} + +# Test that episode + pesky-off-by-one batchinglogic is correct +EPISODE_SETUP__MULTI_EPISODE_BS = { + TEST_NUM_ROUNDS_OPT_KEY: 5, + TEST_NUM_EPISODES_OPT_KEY: 35, +} + + +class TestDataParser(tod_agents.TodStructuredDataParser): + """ + Assume that when we init, we init w/ num of episodes + rounds as opts. 
+ """ + + def __init__(self, opt, shared=None): + opt["datafile"] = "DUMMY" + self.fold = "DUMMY" + # Following lines are only reelvant in training the standalone api teacher + if TEST_NUM_EPISODES_OPT_KEY not in opt: + opt[TEST_NUM_EPISODES_OPT_KEY] = 35 + if TEST_NUM_ROUNDS_OPT_KEY not in opt: + opt[TEST_NUM_ROUNDS_OPT_KEY] = 5 + super().__init__(opt, shared) + + def setup_episodes(self, _): + result = [] + for ep_idx in range(0, self.opt[TEST_NUM_EPISODES_OPT_KEY]): + result.append( + tod_core.TodStructuredEpisode( + goal_calls_machine=[ + make_api_call_machine(x) + for x in range(1, self.opt[TEST_NUM_ROUNDS_OPT_KEY]) + ], + api_schemas_machine=make_api_schemas_machine( + self.opt[TEST_NUM_ROUNDS_OPT_KEY] + ), + rounds=get_rounds( + ep_idx, + self.opt[TEST_NUM_ROUNDS_OPT_KEY], + self.opt.get("use_broken_mock_api_calls", False), + ), + ) + ) + # print(result, self.opt) + return result + + def get_id_task_prefix(self): + return "Test" + + +class SystemTeacher(TestDataParser, tod_agents.TodSystemTeacher): + pass + + +class UserSimulatorTeacher(TestDataParser, tod_agents.TodUserSimulatorTeacher): + pass + + +class StandaloneApiTeacher(TestDataParser, tod_agents.TodStandaloneApiTeacher): + pass + + +class GoalAgent(TestDataParser, tod_agents.TodGoalAgent): + pass + + +class ApiSchemaAgent(TestDataParser, tod_agents.TodApiSchemaAgent): + pass + + +class SingleGoalAgent(TestDataParser, tod_agents.TodSingleGoalAgent): + pass + + +class SingleApiSchemaAgent(TestDataParser, tod_agents.TodSingleApiSchemaAgent): + pass + + +# Tested in tod world code +class UserUttAgent(TestDataParser, tod_agents.TodUserUttAgent): + pass + + +# Tested in tod world code +class ApiCallAndSysUttAgent(TestDataParser, tod_agents.TodApiCallAndSysUttAgent): + pass diff --git a/pytest.ini b/pytest.ini index 9aec92c0a89..d4095288194 100644 --- a/pytest.ini +++ b/pytest.ini @@ -12,3 +12,4 @@ markers = unit internal nofbcode + tod diff --git a/tests/tod/__init__.py b/tests/tod/__init__.py new file mode 100644 index 00000000000..240697e3247 --- /dev/null +++ b/tests/tod/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tests/tod/test_tod_agents_and_teachers.py b/tests/tod/test_tod_agents_and_teachers.py new file mode 100644 index 00000000000..5383c72416d --- /dev/null +++ b/tests/tod/test_tod_agents_and_teachers.py @@ -0,0 +1,327 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests different (more complicated) slot metrics. 
+""" + +import unittest + +import copy +import parlai.core.tod.tod_core as tod_core +import parlai.core.tod.tod_test_utils.test_agents as test_agents + + +class TestTodAgentsAndTeachersBase(unittest.TestCase): + def setup_agent_or_teacher(self, class_type, round_opt, opt): + full_opts = {**round_opt, **opt} + full_opts["datatype"] = "DUMMY" + full_opts["datafile"] = "DUMMY" + full_opts["episodes_randomization_seed"] = -1 # no random here + return class_type(full_opts) + + def dump_single_utt_per_episode_agent_text(self, class_type, round_opt, opt): + agent = self.setup_agent_or_teacher(class_type, round_opt, opt) + result = [] + while not agent.epoch_done(): + result.append(agent.act()["text"]) + agent.reset() + return result + + def dump_teacher_text(self, class_type, round_opt, opt): + """ + Array where [episode_idx][turn_idx][text=0,label=1] + """ + teacher = self.setup_agent_or_teacher(class_type, round_opt, opt) + data = [] + here = [] + for x, new in teacher.setup_data("dummy"): + if new and len(here) > 0: + data.append(copy.deepcopy(here)) + here = [] + here.append([x["text"], x["label"]]) + if len(here) > 0: + data.append(here) + return data + + def _test_roundDataCorrect(self): + self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__UTTERANCES_ONLY) + self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__SINGLE_API_CALL) + self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__MULTI_ROUND) + self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__MULTI_EPISODE) + + +class TestSystemTeacher(TestTodAgentsAndTeachersBase): + def test_apiSchemas_with_yesApiSchemas(self): + values = self.dump_teacher_text( + test_agents.SystemTeacher, + test_agents.EPISODE_SETUP__SINGLE_API_CALL, + {"api_schemas": True}, + ) + self.assertEqual( + values[0][0][0], + "APIS: " + + tod_core.SerializationHelpers.list_of_maps_to_str( + test_agents.make_api_schemas_machine(2) + ), + ) + + def test_apiSchemas_with_noApiSchemas(self): + values = self.dump_teacher_text( + test_agents.SystemTeacher, + test_agents.EPISODE_SETUP__SINGLE_API_CALL, + {"api_schemas": False}, + ) + self.assertEqual(values[0][0][0], "APIS: ") + + def _test_roundDataCorrect_helper(self, config): + max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] + values = self.dump_teacher_text(test_agents.SystemTeacher, config, {}) + for episode_idx, episode in enumerate(values): + utts = test_agents.get_round_utts(episode_idx, max_rounds) + comp = [] + for utt in utts: + comp.append([utt[0], utt[1]]) + comp.append([utt[2], utt[3]]) + # Skip context turn cause we check it above + self.assertEqual(episode[1:], comp) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TestUserTeacher(TestTodAgentsAndTeachersBase): + def _test_roundDataCorrect_helper(self, config): + max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] + values = self.dump_teacher_text(test_agents.UserSimulatorTeacher, config, {}) + for episode_idx, episode in enumerate(values): + utts = test_agents.get_round_utts(episode_idx, max_rounds) + comp = [] + comp.append( + [ + "GOAL: " + + tod_core.SerializationHelpers.list_of_maps_to_str( + test_agents.make_goal_calls_machine(max_rounds) + ), + utts[0][0], + ] + ) + last_sys = utts[0][3] + for i in range(1, len(utts)): + comp.append([last_sys, utts[i][0]]) + last_sys = utts[i][3] + self.assertEqual(episode, comp) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TestGoalAgent(TestTodAgentsAndTeachersBase): + def 
_test_roundDataCorrect_helper(self, config): + max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] + max_episodes = config[test_agents.TEST_NUM_EPISODES_OPT_KEY] + values = self.dump_single_utt_per_episode_agent_text( + test_agents.GoalAgent, config, {} + ) + + goal_text = [ + "GOAL: " + + tod_core.SerializationHelpers.list_of_maps_to_str( + test_agents.make_goal_calls_machine(max_rounds) + ) + for _ in range(max_episodes) + ] + + self.assertEqual(values, goal_text) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TestApiSchemaAgent(TestTodAgentsAndTeachersBase): + def _test_roundDataCorrect_helper(self, config): + max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] + max_episodes = config[test_agents.TEST_NUM_EPISODES_OPT_KEY] + values = self.dump_single_utt_per_episode_agent_text( + test_agents.ApiSchemaAgent, config, {} + ) + + apis_texts = [ + "APIS: " + + tod_core.SerializationHelpers.list_of_maps_to_str( + test_agents.make_api_schemas_machine(max_rounds) + ) + for _ in range(max_episodes) + ] + + self.assertEqual(values, apis_texts) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TestSingleGoalAgent(TestTodAgentsAndTeachersBase): + def _test_roundDataCorrect_helper(self, config): + max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] + max_episodes = config[test_agents.TEST_NUM_EPISODES_OPT_KEY] + values = self.dump_single_utt_per_episode_agent_text( + test_agents.SingleGoalAgent, config, {} + ) + + goal_text = [] + for _ in range(max_episodes): + goals = test_agents.make_goal_calls_machine(max_rounds) + for x in goals: + goal_text.append( + "GOAL: " + tod_core.SerializationHelpers.list_of_maps_to_str([x]) + ) + + self.assertEqual(values, goal_text) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TestSingleApiSchemaAgent(TestTodAgentsAndTeachersBase): + def _test_roundDataCorrect_helper(self, config): + max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] + max_episodes = config[test_agents.TEST_NUM_EPISODES_OPT_KEY] + values = self.dump_single_utt_per_episode_agent_text( + test_agents.SingleApiSchemaAgent, config, {} + ) + + apis_text = [] + for _ in range(max_episodes): + apis = test_agents.make_api_schemas_machine(max_rounds) + for x in apis: + apis_text.append( + "APIS: " + tod_core.SerializationHelpers.list_of_maps_to_str([x]) + ) + self.assertEqual(values, apis_text) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TestSingleGoalWithSingleApiSchemaAgent(TestTodAgentsAndTeachersBase): + """ + Make sure the SingleGoal + SingleApiSchema agents correspond. 
+ """ + + def _test_roundDataCorrect_helper(self, config): + goals = self.dump_single_utt_per_episode_agent_text( + test_agents.SingleGoalAgent, config, {} + ) + apis = self.dump_single_utt_per_episode_agent_text( + test_agents.SingleApiSchemaAgent, config, {} + ) + + for i in range(len(goals)): + goal = tod_core.SerializationHelpers.str_to_goals(goals[i][len("GOALS:") :]) + api = tod_core.SerializationHelpers.str_to_api_schemas( + apis[i][len("APIS:") :] + ) + self.assertEqual( + goal[0].get("api_name", None), api[0].get("api_name", None) + ) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TestLowShot(TestTodAgentsAndTeachersBase): + FEW_SHOT_SAMPLES = [0, 1, 5, 15] + PERCENTAGES = [0, 0.1, 0.3, 0.5] + + def setup_agent_or_teacher(self, class_type, round_opt, opt): + full_opts = {**round_opt, **opt} + full_opts["datatype"] = "DUMMY" + full_opts["datafile"] = "DUMMY" + return class_type(full_opts) + + def test_few_shot_lengths_correct(self): + def helper(n_shot): + values = self.dump_teacher_text( + test_agents.SystemTeacher, + test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, + { + "episodes_randomization_seed": 0, + "n_shot": n_shot, + }, + ) + self.assertEqual(len(values), n_shot) + + for i in self.FEW_SHOT_SAMPLES: + helper(i) + + def _test_subsets(self, data_dumps): + for i in range(len(data_dumps) - 1): + small = data_dumps[i] + larger = data_dumps[i + 1] + for i, episode in enumerate(small): + self.assertEqual(episode, larger[i]) + + def test_few_shot_subset(self): + def helper(n_shot, seed): + return self.dump_teacher_text( + test_agents.SystemTeacher, + test_agents.EPISODE_SETUP__MULTI_EPISODE, + { + "episodes_randomization_seed": seed, + "n_shot": n_shot, + }, + ) + + data_dumps_seed_zero = [helper(i, 0) for i in self.FEW_SHOT_SAMPLES] + self._test_subsets(data_dumps_seed_zero) + data_dumps_seed_three = [helper(i, 3) for i in self.FEW_SHOT_SAMPLES] + self._test_subsets(data_dumps_seed_three) + self.assertNotEqual(data_dumps_seed_zero[-1], data_dumps_seed_three[-1]) + + def test_percent_shot_lengths_correct(self): + def helper(percent_shot, correct): + values = self.dump_teacher_text( + test_agents.SystemTeacher, + test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, # 35 episodes + { + "episodes_randomization_seed": 0, + "percent_shot": percent_shot, + }, + ) + self.assertEqual(len(values), correct) + + helper(0, 0) + helper(0.1, 3) + helper(0.3, 10) + + def test_percent_shot_subset(self): + def helper(percent_shot, seed): + return self.dump_teacher_text( + test_agents.SystemTeacher, + test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, # 35 episodes + { + "episodes_randomization_seed": seed, + "percent_shot": percent_shot, + }, + ) + + data_dumps_seed_zero = [helper(i, 0) for i in self.PERCENTAGES] + self._test_subsets(data_dumps_seed_zero) + data_dumps_seed_three = [helper(i, 3) for i in self.PERCENTAGES] + self._test_subsets(data_dumps_seed_three) + + def test_correct_throw_when_both_shots_defined(self): + self.assertRaises( + RuntimeError, + self.dump_teacher_text, + test_agents.SystemTeacher, + test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, # 35 episodes + {"episodes_randomization_seed": 0, "percent_shot": 0.3, "n_shot": 3}, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/tod/test_tod_teacher_metrics.py b/tests/tod/test_tod_teacher_metrics.py new file mode 100644 index 00000000000..aec4aba40f6 --- /dev/null +++ b/tests/tod/test_tod_teacher_metrics.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. 
and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from math import isnan + +from parlai.core.metrics import AverageMetric +from parlai.core.tod.teacher_metrics import SlotF1Metric, SlotMetrics + + +class TestSlotF1Metric(unittest.TestCase): + """ + Test SlotF1Metric. + """ + + def test_slot_f1_metric_inputs(self): + slots_p_r_and_f1 = [ + (None, None, float("nan")), + (None, AverageMetric(0.0), float("nan")), + (AverageMetric(0.0), AverageMetric(0.0), float("nan")), + (AverageMetric(1), AverageMetric(1), 1.0), + (AverageMetric(1), AverageMetric(0), 0.0), + (AverageMetric(0.25), AverageMetric(0.75), 0.375), + ] + for slot_p, slot_r, slot_f1 in slots_p_r_and_f1: + actual_slot_f1 = SlotF1Metric(slot_p=slot_p, slot_r=slot_r).value() + if isnan(slot_f1): + self.assertTrue(isnan(actual_slot_f1)) + else: + self.assertEqual(slot_f1, actual_slot_f1) + + def test_slot_f1_metric_addition(self): + a = SlotF1Metric(slot_p=1) + b = SlotF1Metric(slot_r=0) + c = SlotF1Metric(slot_p=AverageMetric(numer=2, denom=3), slot_r=1) + d = a + b + c + # Slot P should be 3/4 = 0.75; slot R should be 1/2 = 0.5 + self.assertEqual(0.6, d.value()) + + +empty_slots = {} +basic_slots = {"a": "a_val", "b": "b_val", "c": "c_val"} +partial_slots = {"a": "a_val", "other": "other_val"} + + +class TestSlotMetrics(unittest.TestCase): + def test_base_slot_metrics(self): + cases = [ + (empty_slots, empty_slots, {"jga": 1}), + ( + basic_slots, + basic_slots, + {"jga": 1, "slot_p": 1, "slot_r": 1, "slot_f1": 1}, + ), + ( + basic_slots, + partial_slots, + {"jga": 0, "slot_p": 0.5, "slot_r": float(1.0 / 3), "slot_f1": 0.4}, + ), + ] + for teacher, predicted, result in cases: + metric = SlotMetrics( + teacher_slots=teacher, + predicted_slots=predicted, + ) + for key in result: + self.assertEqual(result[key], metric.report()[key]) + + +if __name__ == "__main__": + unittest.main() From 638eb286ed5bc14a5e627e9650c4cb5d709fb4e4 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Mon, 15 Nov 2021 20:29:54 -0800 Subject: [PATCH 07/57] [TOD] World, world metrics, script, tests See documentation in `tod_world_script.py` for usage. 
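
As a rough sketch only (the model files below are placeholders and the flag set
shown is not exhaustive; see `setup_tod_args` in the script for the full set),
the script can also be driven programmatically the way other ParlAI scripts are:

    from parlai.scripts.tod_world_script import TodWorldScript

    TodWorldScript.main(
        system_model_file="<path/to/system/model>",  # placeholder
        user_model_file="<path/to/user/model>",  # placeholder
        api_resp_model="<agent for API responses>",  # placeholder
        num_episodes=1,
        display_examples=True,
    )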
---
 parlai/core/tod/README.md                     |   6 +
 parlai/core/tod/tod_core.py                   |   7 +-
 .../tod_test_utils/standalone_api_file.pickle | Bin 0 -> 232 bytes
 parlai/core/tod/tod_test_utils/test_agents.py |   1 -
 parlai/core/tod/tod_world.py                  | 307 +++++++++++++
 parlai/core/tod/world_metrics.py              | 116 +++++
 parlai/core/tod/world_metrics_handlers.py     | 181 ++++++++
 .../scripts/distributed_tod_world_script.py   |  33 ++
 parlai/scripts/tod_world_script.py            | 408 +++++++++++++++++
 projects/tod_simulator/README.md              |   1 -
 .../world_metrics/extended_world_metrics.py   | 411 ++++++++++++++++++
 tests/tod/test_tod_world_and_script.py        | 225 ++++++++++
 tests/tod/test_tod_world_metrics.py           | 243 +++++++++++
 tests/tod/test_tod_world_metrics_in_script.py | 257 +++++++++++
 tests/tod/test_tod_world_script_metrics.py    | 154 +++++++
 15 files changed, 2346 insertions(+), 4 deletions(-)
 create mode 100644 parlai/core/tod/README.md
 create mode 100644 parlai/core/tod/tod_test_utils/standalone_api_file.pickle
 create mode 100644 parlai/core/tod/tod_world.py
 create mode 100644 parlai/core/tod/world_metrics.py
 create mode 100644 parlai/core/tod/world_metrics_handlers.py
 create mode 100644 parlai/scripts/distributed_tod_world_script.py
 create mode 100644 parlai/scripts/tod_world_script.py
 delete mode 100644 projects/tod_simulator/README.md
 create mode 100644 projects/tod_simulator/world_metrics/extended_world_metrics.py
 create mode 100644 tests/tod/test_tod_world_and_script.py
 create mode 100644 tests/tod/test_tod_world_metrics.py
 create mode 100644 tests/tod/test_tod_world_metrics_in_script.py
 create mode 100644 tests/tod/test_tod_world_script_metrics.py
diff --git a/parlai/core/tod/README.md b/parlai/core/tod/README.md
new file mode 100644
index 00000000000..f1ebc50e4d6
--- /dev/null
+++ b/parlai/core/tod/README.md
@@ -0,0 +1,6 @@
+# Core classes for Task-Oriented Dialog (TOD)
+
+For understanding usage of these classes, start with `tod_agents.py` (for understanding how to set up agents such that they work with new datasets) and `parlai/scripts/tod_world_script.py` (for understanding how to run simulations with the TOD conversations format).
+
+As a convention, files that have elements that are expected to be referenced outside of this directory (and outside of `parlai/projects/tod_simulator`) are prefixed with `tod_`.
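+
+As a very rough sketch (the dataset name, utterances, and slot names below are made up), hooking a new dataset into this framework mostly means subclassing the data parser, emitting `TodStructuredEpisode`s, and mixing the parser into the agent/teacher classes:
+
+```python
+import parlai.core.tod.tod_core as tod_core
+import parlai.core.tod.tod_agents as tod_agents
+
+
+class MyDatasetParser(tod_agents.TodStructuredDataParser):
+    def setup_episodes(self, fold):
+        # Normally built by parsing the raw data for `fold`; hardcoded here.
+        example_round = tod_core.TodStructuredRound(
+            user_utt="I'd like a hotel in Chicago",
+            api_call_machine={tod_core.STANDARD_API_NAME_SLOT: "FindHotel", "city": "Chicago"},
+            api_resp_machine={"name": "Hotel Example"},
+            sys_utt="How about Hotel Example?",
+        )
+        return [tod_core.TodStructuredEpisode(rounds=[example_round])]
+
+    def get_id_task_prefix(self):
+        return "MyDataset"
+
+
+class MyDatasetSystemTeacher(MyDatasetParser, tod_agents.TodSystemTeacher):
+    pass
+```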
+ diff --git a/parlai/core/tod/tod_core.py b/parlai/core/tod/tod_core.py index 76ee8005a74..e53be5c7376 100644 --- a/parlai/core/tod/tod_core.py +++ b/parlai/core/tod/tod_core.py @@ -160,7 +160,10 @@ def inner_list_join(cls, values): @classmethod def inner_list_split(cls, s): - return s.split(", ") + split = s.split(", ") + if len(split) == 1: + return split[0] + return set(split) @classmethod def maybe_inner_list_join(cls, values): @@ -193,7 +196,7 @@ def str_to_api_dict(cls, string): continue name, value = slot_str.split(" = ", 1) name = name.strip() - value = value.strip() + value = SerializationHelpers.inner_list_split(value.strip()) result[name] = value return result diff --git a/parlai/core/tod/tod_test_utils/standalone_api_file.pickle b/parlai/core/tod/tod_test_utils/standalone_api_file.pickle new file mode 100644 index 0000000000000000000000000000000000000000..b27e3b297415f56f1d054caf3b4045eec85505a0 GIT binary patch literal 232 zcmZo*t}SHHh>&7nU`Q;;jL%EVO;xZ}08#OV3f2mlc|e|FA!CF9P=RBBXOL@ffR#di zX$e@E39CLMm_DOIW^DS53R$q~GluCiE@Z`~&$y5ct3DH$K9fRrZ2C+JIZE{a-J&|v literal 0 HcmV?d00001 diff --git a/parlai/core/tod/tod_test_utils/test_agents.py b/parlai/core/tod/tod_test_utils/test_agents.py index b1339052764..b7639efea36 100644 --- a/parlai/core/tod/tod_test_utils/test_agents.py +++ b/parlai/core/tod/tod_test_utils/test_agents.py @@ -171,7 +171,6 @@ def setup_episodes(self, _): ), ) ) - # print(result, self.opt) return result def get_id_task_prefix(self): diff --git a/parlai/core/tod/tod_world.py b/parlai/core/tod/tod_world.py new file mode 100644 index 00000000000..d95a8706478 --- /dev/null +++ b/parlai/core/tod/tod_world.py @@ -0,0 +1,307 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Class for running task-oriented dialogue chats. See long comment on TodWorld for +functionality descriptions. + +Metrics calculated from these simulations are documented in `world_metrics.py` (for +general usage) and `world_metrics_handlers.py` (for specific metric calculations) +""" +from parlai.core.metrics import Metric, LegacyMetric +from parlai.core.message import Message +from parlai.core.opt import Opt +from parlai.core.worlds import World +from parlai.agents.local_human.local_human import LocalHumanAgent +from parlai.utils.misc import display_messages + +import parlai.core.tod.tod_core as tod +import parlai.core.tod.world_metrics as tod_metrics + +import sys +import copy + +# Following needs to be kept consistent with opt settings/tod script +USER_UTT_IDX = 0 +API_CALL_IDX = 1 +API_RESP_IDX = 2 +SYSTEM_UTT_IDX = 3 +API_SCHEMA_GROUNDING_IDX = 4 +GOAL_GROUNDING_IDX = 5 +AGENT_COUNT = 6 + +SPEAKER_TO_NAME = { + USER_UTT_IDX: tod.TodAgentType.USER_UTT_AGENT, + API_CALL_IDX: tod.TodAgentType.API_CALL_AGENT, + API_RESP_IDX: tod.TodAgentType.API_RESP_AGENT, + SYSTEM_UTT_IDX: tod.TodAgentType.SYSTEM_UTT_AGENT, + API_SCHEMA_GROUNDING_IDX: tod.TodAgentType.API_SCHEMA_GROUNDING_AGENT, + GOAL_GROUNDING_IDX: tod.TodAgentType.GOAL_GROUNDING_AGENT, +} + +NAME_TO_IDX = {v: k for k, v in SPEAKER_TO_NAME.items()} + + +class TodWorld(World): + """ + Base world for running TOD model-model chats. Following agents. + + * User utt agent + * API call agent + * Currently assumed to be same as system utt agent in script code, though used as if separate in this world. 
+ * API responder agent + * System utt agent + * API schema groundinger agent (given to api call + response agent) + * Goal groundinger agent (given to user) + + As is standard for ParlAI, these agents may be models or may be standalone classes that extend the "Agent" class. The models for these *are* expected to have their utterances in a standard format. + + Note that we expect these to be passed in via the opt manually, since some assumptions of regular ParlAI Worlds (ex. task = agent[0], model = agent[1]) are broken here since there is no "task agent" and one agent can be two "roles" (ex. system agent also making API calls) + """ + + def __init__(self, opt: Opt, agents=None, shared=None): + super().__init__(opt, agents, shared) + self.batchsize = opt["batchsize"] + self.batch_agents = [] + self.batch_acts = [] + self.batch_goals = [] # for case when num_episodes < batchsize + self.batch_tod_world_metrics = [] + for i in range(self.batchsize): + here_agents = [] + for j, agent in enumerate(agents): + if ( + j == SYSTEM_UTT_IDX + ): # handle separately cause we expect it to be same as API_CALL agent + here_agents.append(here_agents[API_CALL_IDX]) + continue + share = agent.share() + batch_opt = copy.deepcopy(share["opt"]) + batch_opt["batchindex"] = i + here_agents.append(share["class"](batch_opt, share)) + self.batch_agents.append(here_agents) + self.batch_acts.append([Message.padding_example()] * 4) + self.batch_tod_world_metrics.append(tod_metrics.TodMetrics()) + self.end_episode = [False] * self.batchsize + + self.max_turns = self.opt.get("max_turns", 30) + self.turns = 0 + self.need_grounding = True + + def grounding(self): + """ + Preempt with goal and schema-based intent schemas. + + As a logging hack, we stick the schema gronding in as a user utterance, but + manually pass the value in to the relevant API call/resp agent, since passing it + to the API call agent elsewhere is a little awkward. Similarly, we stick the + goal as a system utterance so that it is captured in logging. However, we do not + pass it in manually, since getting the user utterance will be the first turn of + `parley()`. + """ + self._observe_and_act( + SYSTEM_UTT_IDX, # Doesn't matter, empty at this point + USER_UTT_IDX, # Hack in to a place that'll look nice when printing + f"getting API schema grounding. (Must start with `{tod.STANDARD_API_SCHEMAS}`)", + API_SCHEMA_GROUNDING_IDX, + ) + + self._observe_and_act( + USER_UTT_IDX, + API_CALL_IDX, + "responding to api schema grounding (empty enter is usually fine) ", + ) + self._observe_and_act( + USER_UTT_IDX, + API_RESP_IDX, + "responding to api schema grounding (empty enter is usually fine)", + ) + + self._observe_and_act( + SYSTEM_UTT_IDX, # Doesn't matter for the most part, but want something empty + SYSTEM_UTT_IDX, # Hack into a place per comment above + f"getting goal grounding. 
(Must start with `{tod.STANDARD_GOAL}`)", + GOAL_GROUNDING_IDX, + ) + self.batch_goals = [act[SYSTEM_UTT_IDX] for act in self.batch_acts] + self.turns = 0 + + def parley(self): + if self.need_grounding: + self.grounding() + self.need_grounding = False + + else: + self._observe_and_act(SYSTEM_UTT_IDX, USER_UTT_IDX) + self._observe_and_act(USER_UTT_IDX, API_CALL_IDX) + self._observe_and_act(API_CALL_IDX, API_RESP_IDX) + self._observe_and_act(API_RESP_IDX, SYSTEM_UTT_IDX) + + self.turns += 1 + self.update_counters() + + def _observe_and_act( + self, observe_idx, act_idx, info="for regular parley", override_act_idx=None + ): + act_agent_idx = override_act_idx if override_act_idx else act_idx + act_agent = self.agents[act_agent_idx] + record_output_idx = act_idx + if hasattr(act_agent, "batch_act"): + batch_observations = [] + for i in range(self.batchsize): + if not self.end_episode[i]: + observe = self.batch_acts[i][observe_idx] + observe = self.batch_agents[i][act_agent_idx].observe(observe) + batch_observations.append(Message(observe)) + else: + # We're done with this episode, so just do a pad. + # NOTE: This could cause issues with RL down the line + batch_observations.append(Message.padding_example()) + self.batch_acts[i][record_output_idx] = {"text": "", "id": ""} + batch_actions = act_agent.batch_act(batch_observations) + for i in range(self.batchsize): + if self.end_episode[i]: + continue + self.batch_acts[i][record_output_idx] = batch_actions[i] + self.batch_agents[i][record_output_idx].self_observe(batch_actions[i]) + else: # Run on agents individually + for i in range(self.batchsize): + act_agent = ( + self.batch_agents[i][override_act_idx] + if override_act_idx + else self.batch_agents[i][act_idx] + ) + if hasattr(act_agent, "episode_done") and act_agent.episode_done(): + self.end_episode[i] = True + if self.end_episode[i]: + # Following line exists because: + # 1. Code for writing converseations is not hapy if an "id" does not exists with a sample + # 2. Because of the `self.end_episode` code, no agent will see this example anyway. + self.batch_acts[i][record_output_idx] = {"text": "", "id": ""} + continue + act_agent.observe(self.batch_acts[i][observe_idx]) + if isinstance(act_agent, LocalHumanAgent): + print( + f"Getting message for {SPEAKER_TO_NAME[record_output_idx]} for {info} in batch {i}" + ) + try: + self.batch_acts[i][record_output_idx] = act_agent.act() + except StopIteration: + self.end_episode[i] = True + for i in range(self.batchsize): + if self.end_episode[i]: + continue + self.batch_tod_world_metrics[i].handle_message( + self.batch_acts[i][record_output_idx], SPEAKER_TO_NAME[act_agent_idx] + ) + if tod.STANDARD_DONE in self.batch_acts[i][record_output_idx].get( + "text", "" + ): + # User models trained to output a "DONE" on last turn; same with human agents. + self.end_episode[i] = True + + def report(self): + """ + Report all metrics of all subagents + of this world in aggregate. 
+ """ + metrics_separate = [] + for i in range(self.batchsize): + here_metrics = self.batch_tod_world_metrics[i].report() + for name, agent in [ + (SPEAKER_TO_NAME[j], self.batch_agents[i][j]) + for j in [USER_UTT_IDX, API_CALL_IDX, API_RESP_IDX, SYSTEM_UTT_IDX] + ]: + name_prefix = name[:-6] # strip "_agent" + if hasattr(agent, "report"): + m = agent.report() + if m is None: + continue + for k, v in m.items(): + if not isinstance(v, Metric): + v = LegacyMetric(v) + here_metrics[f"{name_prefix}_{k}"] = v + metrics_separate.append(here_metrics) + metrics = metrics_separate[0] + for i in range(1, self.batchsize): + for k, v in metrics_separate[i].items(): + if k not in metrics: + metrics[k] = v + else: + metrics[k] = metrics[k] + v + return metrics + + def reset(self): + """ + Resets state of world; also sets up episode metrics. + """ + super().reset() + self.need_grounding = True + self.turns = 0 + + self.last_batch_episode_metrics = [] + self.batch_acts = [] + for i in range(self.batchsize): + for agent in self.batch_agents[i]: + agent.reset() + self.batch_acts.append([None] * 4) + + self.batch_tod_world_metrics[i].episode_reset() + metrics = self.batch_tod_world_metrics[i].get_last_episode_metrics() + if metrics: + self.last_batch_episode_metrics.append(metrics) + self.end_episode = [False] * self.batchsize + + def get_last_batch_episode_metrics(self): + return self.last_batch_episode_metrics + + def get_last_batch_goals(self): + return self.batch_goals + + def episode_done(self): + if self.turns >= self.max_turns or all(self.end_episode): + return True + for i in range(self.batchsize): + for j in [USER_UTT_IDX, API_CALL_IDX, API_RESP_IDX, SYSTEM_UTT_IDX]: + if ( + self.batch_acts[i][j] is not None + and tod.STANDARD_DONE in self.batch_acts[i][j].get("text", "") + ) or ( + hasattr(self.batch_agents[i][j], "episode_done") + and self.batch_agents[i][j].episode_done() + ): + self.end_episode[i] = True + return all(self.end_episode) + + def epoch_done(self): + for agent in self.agents: + if agent.epoch_done(): + return True + + def num_episodes(self): + result = sys.maxsize + for agent in self.agents: + if hasattr(agent, "num_episodes") and agent.num_episodes() > 0: + result = min(result, agent.num_episodes()) + if result == sys.maxsize: + return 0 + return result + + def get_batch_acts(self): + return self.batch_acts + + def display(self): + s = "[--batchsize " + str(self.batchsize) + "--]\n" + for i in range(self.batchsize): + s += "[batch " + str(i) + ":]\n" + s += display_messages( + self.batch_acts[i], + ignore_agent_reply=self.opt.get("ignore_agent_reply", False), + add_fields=self.opt.get("display_add_fields", ""), + prettify=self.opt.get("display_prettify", False), + max_len=self.opt.get("max_display_len", 1000), + verbose=self.opt.get("verbose", False), + ) + s += "\n" + s += "[--end of batch--]\n" + return s diff --git a/parlai/core/tod/world_metrics.py b/parlai/core/tod/world_metrics.py new file mode 100644 index 00000000000..4e8ba4555e4 --- /dev/null +++ b/parlai/core/tod/world_metrics.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Wrapper object holding metrics for TODWorld. + +This class is in its own file to prevent circular dependencies + monolithic files. 
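+
+A rough sketch of the expected call pattern (normally `TodWorld` drives this; the
+goal value below is made up purely for illustration):
+
+    metrics = TodMetrics()
+    goal_text = STANDARD_GOAL + SerializationHelpers.list_of_maps_to_str(
+        [{"api_name": "FindHotel", "city": "Chicago"}]
+    )
+    metrics.handle_message(Message({"text": goal_text}), TodAgentType.GOAL_GROUNDING_AGENT)
+    # ...handle the user / api call / api response / system messages of the episode...
+    metrics.episode_reset()
+    report = metrics.report()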
+""" + +from parlai.core.message import Message +from parlai.core.metrics import ( + Metrics, +) +from parlai.core.tod.tod_core import ( + TodAgentType, + TOD_AGENT_TYPE_TO_PREFIX, + SerializationHelpers, + STANDARD_GOAL, +) +from typing import Any, Dict +import parlai.core.tod.world_metrics_handlers as world_metrics_handlers + +# Change the following to define which Metrics Handlers are used in TodWorld. +# The ones used below are from `world_metrics_handlers.py` only. However, See `parlai/projects/tod_simulator/world_metrics/extended_world_metrics.py` for others. + +WORLD_METRIC_HANDLERS = [ + world_metrics_handlers.AllGoalApiCallSuccessMetricsHandler, + world_metrics_handlers.UserGeneratedDoneMetricHandler, +] + + +class TodMetrics(Metrics): + """ + Helper container which encapsulates TOD metrics and does some basic prepocessing to + handlers to calculate said metrics. + + This class should generally not need to be changed; add new metrics handlers to + `WORLD_METRIC_HANDLERS` (or otherwise override `self.handlers` of this class) to + change metrics actively being used. + """ + + def __init__(self, shared: Dict[str, Any] = None) -> None: + super().__init__(shared=shared) + self.handlers = [x() for x in WORLD_METRIC_HANDLERS] + self.convo_started = False + self.last_episode_metrics = Metrics() + + def handle_message(self, message: Message, agent_type: TodAgentType): + if "text" not in message: + return + if agent_type == TodAgentType.GOAL_GROUNDING_AGENT and len( + message["text"] + ) > len(STANDARD_GOAL): + # Only count a conversation as started if there is a goal. + self.convo_started = True + for handler in self.handlers: + metrics = self._handle_message_impl(message, agent_type, handler) + if metrics is not None: + for name, metric in metrics.items(): + if metric is not None: + self.add(name, metric) + + def _handle_message_impl( + self, + message: Message, + agent_type: TodAgentType, + handler: world_metrics_handlers.TodMetricsHandler, + ): + prefix_stripped_text = message["text"].replace( + TOD_AGENT_TYPE_TO_PREFIX[agent_type], "" + ) + if agent_type is TodAgentType.API_SCHEMA_GROUNDING_AGENT: + return handler.handle_api_schemas( + message, + SerializationHelpers.str_to_api_schemas(prefix_stripped_text), + ) + if agent_type is TodAgentType.GOAL_GROUNDING_AGENT: + return handler.handle_goals( + message, SerializationHelpers.str_to_goals(prefix_stripped_text) + ) + if agent_type is TodAgentType.USER_UTT_AGENT: + return handler.handle_user_utt(message, prefix_stripped_text) + if agent_type is TodAgentType.API_CALL_AGENT: + return handler.handle_api_call( + message, SerializationHelpers.str_to_api_dict(prefix_stripped_text) + ) + if agent_type is TodAgentType.API_RESP_AGENT: + return handler.handle_api_resp( + message, SerializationHelpers.str_to_api_dict(prefix_stripped_text) + ) + if agent_type is TodAgentType.SYSTEM_UTT_AGENT: + return handler.handle_sys_utt(message, prefix_stripped_text) + + def get_last_episode_metrics(self): + """ + This is a bit of a hack so that we can report whether or not a convo has + successfully hit all goals and associate this with each episode for the purposes + of doing filtering. 
+ """ + return self.last_episode_metrics + + def episode_reset(self): + self.last_episode_metrics = None + if self.convo_started: + self.last_episode_metrics = Metrics() + for handler in self.handlers: + metrics = handler.get_episode_metrics() + handler.episode_reset() + if metrics is not None: + for name, metric in metrics.items(): + if metric is not None: + self.add(name, metric) + self.last_episode_metrics.add(name, metric) + self.convo_started = False diff --git a/parlai/core/tod/world_metrics_handlers.py b/parlai/core/tod/world_metrics_handlers.py new file mode 100644 index 00000000000..3f4b68477fe --- /dev/null +++ b/parlai/core/tod/world_metrics_handlers.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Metrics handlers - ie, objects that handle generations from Tod World and calculates metrics from them. + +Note that only metrics handler classes in `WORLD_METRIC_HANDLERS` (of `world_metrics.py`) are actively being recorded as metrics. +""" + +from parlai.core.message import Message +from parlai.core.metrics import ( + Metric, + AverageMetric, +) +from parlai.core.tod.tod_core import ( + STANDARD_DONE, +) +from typing import Dict, List, Optional + +METRICS_HANDLER_CLASSES_TEST_REGISTRY = set() # for tests + + +def register_metrics_handler(cls): + METRICS_HANDLER_CLASSES_TEST_REGISTRY.add(cls) + return cls + + +class TodMetricsHandler: + """ + Base class for Tod Metrics handlers. Extend this class then add them to + `WORLD_METRIC_HANDLERS` to use. If you would like the class to be exposed to tests, + add the Metrics Handler to `METRICS_HANDLER_CLASSES_TEST_REGISTRY` via annotating + with `@register_metrics_handler`. + + The `TodMetrics` class will, on this class + 1. call `__init__` (which internally calls `episode_reset()`) to begin with. + 2. call each of the `handle..()` functions as the appropriate turns occur + 3. call `get_episode_metrics()` then `episode_reset()` at the end of the episode + + The `handle..()` should be used to set intermediate state within the class and `episode_reset()` should be used to clear this state. + + The output of the `handle..()` and `get_episode_metrics()` functions are both `Optional[Dict[str, Metric]]`s. Metrics from both of these paths will be aggregated and reported to `TodMetrics`, so which one to use is mostly a matter of preference, though + 1. one should take care to only use one or the other and not both, to avoid double-counting + 2. those from `get_episode_metrics()` will be recorded per-episode and saved to `tod_world_script`'s report as well + + `UserGeneratedDoneMetricHandler` in this file, which collects metrics about frequency of seeing the "[DONE]" token on User utterances and also records conversation length, is a fairly straightforward example of usage. + + Other tried (but not in current active use) Metrics Handers are in `projects/tod_simulator/world_metrics/extended_world_metrics.py`. 
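+
+    A rough sketch of a custom handler (the metric name and logic here are invented
+    purely for illustration):
+
+        @register_metrics_handler
+        class UserWordCountMetricHandler(TodMetricsHandler):
+            def episode_reset(self):
+                self.word_count = 0
+
+            def handle_user_utt(self, message, prefix_stripped_text):
+                # Accumulate intermediate state per turn...
+                self.word_count += len(prefix_stripped_text.split())
+
+            def get_episode_metrics(self):
+                # ...and report it once per episode.
+                return {"user_word_count": AverageMetric(self.word_count)}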
+ """ + + def __init__(self): + self.episode_reset() + + def episode_reset(self): + pass + + def handle_api_schemas( + self, message: Message, api_schemas: List[Dict] + ) -> Optional[Dict[str, Metric]]: + self.api_schemas = api_schemas + + def handle_goals( + self, message: Message, goals: List[Dict] + ) -> Optional[Dict[str, Metric]]: + self.goals = goals + + def handle_user_utt( + self, message: Message, prefix_stripped_text: str + ) -> Optional[Dict[str, Metric]]: + pass + + def handle_api_call( + self, message: Message, api_call: Dict + ) -> Optional[Dict[str, Metric]]: + pass + + def handle_api_resp( + self, message: Message, api_resp: Dict + ) -> Optional[Dict[str, Metric]]: + pass + + def handle_sys_utt( + self, message: Message, prefix_stripped_text: str + ) -> Optional[Dict[str, Metric]]: + pass + + def get_episode_metrics(self) -> Optional[Dict[str, Metric]]: + pass + + +################################ +# Functions and classes associated with calculating statistics between API Calls and Goals. +def goals_hit_helper( + goals: List[Dict], turnDict: List[Dict], permissive=False +) -> (AverageMetric, AverageMetric, AverageMetric): + """ + Helper function that aids in seeing if the API calls the system has attempted to + make manages to meet the goals the conversation has. + + Return values: + * if all goals hit + * # of turns it took to hit all goals (or None) + * fraction of goals hit + """ + goals_left = goals + + def exact_match(goal, turn): # if and only if + return goal == turn + + def permissive_match(goal, turn): # guess is superset + for key in goal: + if turn.get(key, "definitelyNotIn") != goal[key]: + return False + return True + + compare_func = permissive_match if permissive else exact_match + + for i, turn in enumerate(turnDict): + goals_left = [goal for goal in goals_left if not compare_func(goal, turn)] + if len(goals_left) == 0: + return AverageMetric(True), AverageMetric(i + 1), AverageMetric(1) + return ( + AverageMetric(False), + AverageMetric(0), + AverageMetric(len(goals) - len(goals_left), len(goals)), + ) + + +class _ApiCallGoalInteractionHelper(TodMetricsHandler): + """ + Base class for storing details about valid API calls (information about Goals + handled in TodMetricsHandler) + """ + + def episode_reset(self): + self.api_turns = [] + + def handle_api_call( + self, message: Message, api_call: Dict + ) -> Optional[Dict[str, Metric]]: + if len(api_call) > 0: + self.api_turns.append(api_call) + + +@register_metrics_handler +class AllGoalApiCallSuccessMetricsHandler(_ApiCallGoalInteractionHelper): + """ + Calculates synthetic Task Success + related metrics for converseations. 
+ + Test coverage of this class is with `LegacyGoalApiCallInteractionsMetricsHandler` + """ + + def get_episode_metrics(self) -> Optional[Dict[str, Metric]]: + all_goals_hit, _, _ = goals_hit_helper(self.goals, self.api_turns) + call_attempts = len(self.api_turns) + return { + "synthetic_task_success": all_goals_hit, + "api_call_attempts": AverageMetric(call_attempts), + } + + +@register_metrics_handler +class UserGeneratedDoneMetricHandler(TodMetricsHandler): + def episode_reset(self): + self.done_seen = False + self.turn_count = 0 + + def handle_user_utt( + self, message: Message, prefix_stripped_text: str + ) -> Optional[Dict[str, Metric]]: + self.done_seen |= STANDARD_DONE in message["text"] + self.turn_count += 1 + + def get_episode_metrics(self) -> Optional[Dict[str, Metric]]: + result = {"done_seen": AverageMetric(self.done_seen)} + if self.done_seen: + result["round_count_done_seen"] = AverageMetric(self.turn_count) + result["rounds_count_all_conversations"] = AverageMetric(self.turn_count) + return result diff --git a/parlai/scripts/distributed_tod_world_script.py b/parlai/scripts/distributed_tod_world_script.py new file mode 100644 index 00000000000..87c8333ec41 --- /dev/null +++ b/parlai/scripts/distributed_tod_world_script.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Distributed script for running TOD model-model chats. + +Not to be called directly; should be called from SLURM +""" + +from parlai.scripts.tod_world_script import ( + TodWorldScript, +) +from parlai.core.script import ParlaiScript +import parlai.utils.distributed as distributed_utils + + +class DistributedTodWorldScript(ParlaiScript): + @classmethod + def setup_args(cls): + parser = TodWorldScript.setup_args() + parser.add_distributed_training_args() + parser.add_argument("--port", type=int, default=61337, help="TCP port number") + return parser + + def run(self): + with distributed_utils.slurm_distributed_context(self.opt) as opt: + return TodWorldScript(opt).run() + + +if __name__ == "__main__": + DistributedTodWorldScript.main() diff --git a/parlai/scripts/tod_world_script.py b/parlai/scripts/tod_world_script.py new file mode 100644 index 00000000000..af4f963b71f --- /dev/null +++ b/parlai/scripts/tod_world_script.py @@ -0,0 +1,408 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Base script for running TOD model-model chats. 
+""" + +import json +from copy import deepcopy +from shutil import copyfile +import os + +import parlai.utils.logging as logging +import parlai.core.tod.tod_world as tod_world +import parlai.core.tod.tod_agents as tod_world_agents +from parlai.core.agents import create_agent +from parlai.core.metrics import dict_report, aggregate_unnamed_reports +from parlai.core.opt import Opt +from parlai.core.params import ParlaiParser +from parlai.core.script import ParlaiScript, register_script +from parlai.utils.distributed import ( + is_primary_worker, + all_gather_list, + is_distributed, + get_rank, + sync_object, + num_workers, +) +from parlai.utils.io import PathManager +from parlai.utils.misc import TimeLogger, nice_report +from parlai.utils.world_logging import WorldLogger + + +class TodWorldLogger(WorldLogger): + """ + WorldLogger has most of what we need. + + We could if-class this logic in it directly, but inheritence + override here is + neater. + """ + + def _is_batch_world(self, world): + return True + + def _log_batch(self, world): + batch_acts = world.get_batch_acts() + for i, acts in enumerate(batch_acts): + # filter out for empty + acts = [act for act in acts if act["id"] != "" and act["text"] != ""] + if len(acts) > 0: + self._add_msgs(acts, idx=i) + if world.episode_done(): + self.reset_world(idx=i) + + +class TodWorldParser(ParlaiParser): + def add_extra_args(self, args=None): + super().add_extra_args(args) + parsed = vars(self.parse_known_args(args, nohelp=True)[0]) + # Also load extra args options if a file is given. + if parsed.get("init_opt") is not None: + try: + self._load_known_opts(parsed.get("init_opt"), parsed) + except FileNotFoundError: + # don't die if -o isn't found here. See comment in second call + # later on. + pass + parsed = self._infer_datapath(parsed) + + partial = Opt(parsed) + + for model in [ + "system_model", + "user_model", + "api_schema_grounding_model", + "goal_grounding_model", + "api_resp_model", + ]: + if ( + model in partial + and partial[model] is not None + and len(partial[model]) > 0 + ): + self.add_model_subargs(partial[model], partial) + + for model_file_prefix in ["system", "user"]: + key = model_file_prefix + "_model_file" + if key in partial and partial[key] and len(partial[key]) > 0: + model_name = self._get_model_name_from_model_file(key, partial) + self.add_model_subargs(model_name, partial) + + def _get_model_name_from_model_file(self, key, opt): + """ + Get the model name from either `--model` or `--model-file`. + """ + # try to get model name from model opt file + model_file = opt.get(key, None) + optfile = model_file + ".opt" + new_opt = Opt.load(optfile) + model = new_opt.get("model", None) + return model + + +@register_script("tod_world_script") +class TodWorldScript(ParlaiScript): + @classmethod + def setup_tod_args(cls, parser: ParlaiParser): + tod_args = parser.add_argument_group( + "TOD World Script Agent arguments. NOTE: Agents setup with this path will be able to take command line arguments, whereas those set from `-o` or `--init-opt` will not." + ) + tod_args.add_argument( + "--system-model-file", + default="", + help="Define the system model for the chat. Exactly one of this or system-model must be specified", + ) + + tod_args.add_argument( + "--system-model", + default="", + help="Define the system agent for the chat. Exactly one of this or system-model-file must be specified", + ) + + tod_args.add_argument( + "--user-model-file", + default="", + help="Define the user model for the chat. 
Exactly one of this or user-model must be specified. Currently assumed to be the API Call creation agent as well.",
+        )
+
+        tod_args.add_argument(
+            "--user-model",
+            default="",
+            help="Define the user agent for the chat. Exactly one of this or user-model-file must be specified. Currently assumed to be the API Call creation agent as well.",
+        )
+
+        tod_args.add_argument(
+            "--api-resp-model",
+            default="",
+            help="Agent used for defining API response values",
+        )
+
+        tod_args.add_argument(
+            "--api-schema-grounding-model",
+            default="",
+            help="Agent used in first turn to ground api call/response agents with api schemas. Will use EmptyApiSchemaAgent if both this and `--api-schemas` not set.",
+        )
+
+        tod_args.add_argument(
+            "--goal-grounding-model",
+            default="",
+            help="Agent used in first turn to ground the user agent with a goal. Will use EmptyGoalAgent if not set",
+        )
+
+        tod_args.add_argument(
+            "--api-schemas",
+            default=None,
+            help="If set and `--api-schema-grounding-model` is empty, will infer `--api-schema-grounding-model` based on this and a regex on `--goal-grounding-model`. If you run into issues with parsing order of opts using this flag, just switch to `--api-schema-grounding-model`.",
+        )
+
+    @classmethod
+    def setup_args(cls):
+        # Use default parlai args for logging + the like, but don't need model args since we specify those manually via command-line
+        parser = TodWorldParser(
+            True,
+            False,
+            "World for chatting with the TOD conversation structure",
+        )
+        # Following params are same as the `eval_model` script
+        parser.add_argument(
+            "--report-filename",
+            type=str,
+            help="Saves a json file of the evaluation report either as an "
+            'extension to the model-file (if begins with a ".") or a whole '
+            "file path. Set to the empty string to not save at all.",
+        )
+        parser.add_argument(
+            "--world-logs",
+            type=str,
+            help="Saves a jsonl file containing all of the task examples and "
+            "model replies.",
+        )
+        parser.add_argument(
+            "--save-format",
+            type=str,
+            default="conversations",
+            choices=["conversations", "parlai"],
+        )
+        parser.add_argument(
+            "--num-episodes",
+            type=int,
+            default=10,
+            help="Number of episodes to display. Set to -1 for infinity or the number of examples of the first agent with a non-unlimited number of episodes in the world.",
+        )
+        parser.add_argument("-d", "--display-examples", type="bool", default=False)
+        parser.add_argument("-ltim", "--log-every-n-secs", type=float, default=10)
+        TodWorldLogger.add_cmdline_args(parser)
+
+        # Following are specific to TOD World
+        parser.add_argument(
+            "--max-turns",
+            type=int,
+            default=30,
+            help="The max number of full turns before chat ends, excluding prompting",
+        )
+        TodWorldScript.setup_tod_args(parser)
+
+        return parser
+
+    def _get_file_or_model_specifiable_agent(self, prefix, opt):
+        if len(opt.get(f"{prefix}_model_file", "")) > 0:
+            if len(opt.get(f"{prefix}_model", "")) > 0:
+                raise KeyError(
+                    f"Both `--{prefix}-model-file` and `--{prefix}-model` specified. Exactly one should be."
+                )
+            model = self._make_agent(
+                opt,
+                f"{prefix}_model_file",
+                requireModelExists=True,
+                opt_key="model_file",
+            )
+        elif len(opt.get(f"{prefix}_model", "")) > 0:
+            model = self._make_agent(opt, f"{prefix}_model", "")
+        else:
+            raise KeyError(
+                f"Neither `--{prefix}-model-file` nor `--{prefix}-model` specified. Exactly one should be."
+ ) + return model + + def _get_model_or_default_agent(self, opt, key, default_class): + if len(opt.get(key, "")) > 0: + return self._make_agent(opt, key) + return default_class(opt) + + def _get_tod_agents(self, opt: Opt): + agents = [None] * tod_world.AGENT_COUNT + + agents[tod_world.USER_UTT_IDX] = self._get_file_or_model_specifiable_agent( + "user", opt + ) + + # Get system agent, nothing that api call agent currently same as system agent + system_model = self._get_file_or_model_specifiable_agent("system", opt) + agents[tod_world.SYSTEM_UTT_IDX] = system_model + agents[tod_world.API_CALL_IDX] = system_model + + agents[tod_world.API_RESP_IDX] = self._make_agent(opt, "api_resp_model") + agents[tod_world.GOAL_GROUNDING_IDX] = self._get_model_or_default_agent( + opt, "goal_grounding_model", tod_world_agents.EmptyGoalAgent + ) + + if "api_schema_grounding_model" not in opt and "api_schemas" in opt: + opt["api_schema_grounding_model"] = opt.get( + "goal_grounding_model", "" + ).replace("Goal", "ApiSchema") + + agents[tod_world.API_SCHEMA_GROUNDING_IDX] = self._get_model_or_default_agent( + opt, + "api_schema_grounding_model", + tod_world_agents.EmptyApiSchemaAgent, + ) + + return agents + + def _make_agent(self, opt_raw, name, requireModelExists=False, opt_key="model"): + """ + Hack. + + `create_agent` expects opt[`model`] to specify the model type and we're + specifying multiple models from other opt arguments (ex. + `system_model`/`user_model` etc), so this swaps it in. + """ + opt = deepcopy(opt_raw) + opt[opt_key] = opt[name] + print(opt_key, name) + return create_agent(opt, requireModelExists) + + def _run_episode(self, opt, world, world_logger): + while not world.episode_done(): + world.parley() + world_logger.log(world) + + if opt["display_examples"]: + logging.info(world.display()) + + if opt["display_examples"]: + logging.info("-- end of episode --") + + world.reset() + world_logger.reset_world() # flush this episode + return zip(world.get_last_batch_goals(), world.get_last_batch_episode_metrics()) + + def _save_outputs(self, opt, world, logger, episode_metrics): + if is_distributed(): # flatten everything intelligently if need be + world_report = aggregate_unnamed_reports(all_gather_list(world.report())) + episode_metrics_unflattened = all_gather_list(episode_metrics) + flattened = [] + for rank_elem in episode_metrics_unflattened: + for elem in rank_elem: + flattened.append(elem) + episode_metrics = flattened + else: + world_report = world.report() + logging.report("Final report:\n" + nice_report(world_report)) + + report = dict_report(world_report) + + def get_episode_report(goal, episode_metric): + metrics_dict = dict_report(episode_metric.report()) + metrics_dict["goal"] = goal + return metrics_dict + + report["tod_metrics"] = [get_episode_report(g, e) for g, e in episode_metrics] + + if "report_filename" in opt and opt["report_filename"] is not None: + if len(world_report) == 0: + logging.warning("Report is empty; not saving report") + + report_fname = f"{opt['report_filename']}.json" + # Save report + if not is_distributed() or is_primary_worker(): + with PathManager.open(report_fname, "w") as f: + logging.info(f"Saving model report to {report_fname}") + json.dump({"opt": opt, "report": report}, f, indent=4) + f.write("\n") # for jq + + if "world_logs" in opt and opt["world_logs"] is not None: + if is_distributed(): # Save separately, then aggregate together + rank = get_rank() + log_outfile_part = ( + f"{opt['world_logs']}_{opt['save_format']}_{rank}.jsonl" + ) + 
logger.write(log_outfile_part, world, file_format=opt["save_format"]) + sync_object(None) + if is_primary_worker(): + log_outfile = f"{opt['world_logs']}_{opt['save_format']}.jsonl" + log_outfile_metadata = ( + f"{opt['world_logs']}_{opt['save_format']}.metadata" + ) + with open(log_outfile, "w+") as outfile: + for rank in range(num_workers()): + log_outfile_part = ( + f"{opt['world_logs']}_{opt['save_format']}_{rank}.jsonl" + ) + with open(log_outfile_part) as infile: + for line in infile: + json_blob = json.loads(line.strip()) + if ( + len(json_blob["dialog"]) < 2 + ): # skip when we don't have generation + continue + json_blob["metadata_path"] = log_outfile_metadata + outfile.write(json.dumps(json_blob)) + outfile.write("\n") + log_output_part_metadata = f"{opt['world_logs']}_{opt['save_format']}_{rank}.metadata" + if rank == 0: + copyfile( + log_output_part_metadata, log_outfile_metadata + ), + os.remove(log_outfile_part) + os.remove(log_output_part_metadata) + else: + log_outfile = f"{opt['world_logs']}_{opt['save_format']}.jsonl" + logger.write(log_outfile, world, file_format=opt["save_format"]) + + return report + + def _setup_world(self): + # setup world, manually finaggling necessary opt info as needed + self.opt["task"] = "TodWorld" + world = tod_world.TodWorld(self.opt, agents=self._get_tod_agents(self.opt)) + return world + + def run(self): + opt = self.opt + + world = self._setup_world() + logger = TodWorldLogger(opt) + + # set up logging + log_every_n_secs = opt.get("log_every_n_secs", -1) + if log_every_n_secs <= 0: + log_every_n_secs = float("inf") + log_time = TimeLogger() + + # episode counter + max_episodes = opt.get("num_episodes", -1) + if max_episodes < 0: + max_episodes = float("inf") + world_num_episodes = world.num_episodes() + if world_num_episodes > 0: + max_episodes = min(max_episodes, world_num_episodes) + + ep_count = 0 + episode_metrics = [] + while not world.epoch_done() and ep_count < max_episodes: + episode_metrics.extend(self._run_episode(opt, world, logger)) + ep_count += opt.get("batchsize", 1) + if log_time.time() > log_every_n_secs: + report = world.report() + text, report = log_time.log(ep_count, max_episodes, report) + logging.info(text) + + return self._save_outputs(opt, world, logger, episode_metrics) + + +if __name__ == "__main__": + TodWorldScript.main() diff --git a/projects/tod_simulator/README.md b/projects/tod_simulator/README.md deleted file mode 100644 index 65fbb41753a..00000000000 --- a/projects/tod_simulator/README.md +++ /dev/null @@ -1 +0,0 @@ -Page to be filled. :) diff --git a/projects/tod_simulator/world_metrics/extended_world_metrics.py b/projects/tod_simulator/world_metrics/extended_world_metrics.py new file mode 100644 index 00000000000..7aa4eca89c5 --- /dev/null +++ b/projects/tod_simulator/world_metrics/extended_world_metrics.py @@ -0,0 +1,411 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Metrics handlers - ie, classes that handle generations from Tod World and calculates metrics from them. + +Note that only metrics handler classes in `WORLD_METRIC_HANDLERS` are actively being recorded as metrics. 
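+
+One way a handler from this file could be opted in (a sketch only; it assumes this
+runs, e.g. in an experiment script, before the TOD world and its `TodMetrics`
+objects are constructed):
+
+    import parlai.core.tod.world_metrics as world_metrics
+    from projects.tod_simulator.world_metrics.extended_world_metrics import (
+        ApiCallMalformedMetricHandler,
+    )
+
+    # Editing WORLD_METRIC_HANDLERS in world_metrics.py directly works just as well.
+    world_metrics.WORLD_METRIC_HANDLERS.append(ApiCallMalformedMetricHandler)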
+""" + +from parlai.core.message import Message +from parlai.core.metrics import ( + Metric, + AverageMetric, + normalize_answer, + BleuMetric, + SumMetric, +) +from parlai.core.tod.tod_core import ( + STANDARD_API_NAME_SLOT, + STANDARD_REQUIRED_KEY, + STANDARD_OPTIONAL_KEY, + STANDARD_API_SCHEMAS, +) +from typing import Dict, List, Optional, Tuple +from parlai.core.tod.world_metrics_handlers import ( + TodMetricsHandler, + register_metrics_handler, + _ApiCallGoalInteractionHelper, + goals_hit_helper, +) + +try: + from nltk.translate import bleu_score as nltkbleu +except ImportError: + # User doesn't have nltk installed, so we can't use it for bleu + # We'll just turn off things, but we might want to warn the user + nltkbleu = None + +################################ +# Functions and classes associated with calculating statistics between API Calls and Goals. + + +def get_req_only_goals(goals_list: List[Dict], api_schemas: List[Dict]) -> List[Dict]: + """ + Given a list of goals and a list of api schemas that say if slots are required or + optional, this function filters for the goals to be only the required ones. + + If we have no api schemas or a goal is malformed, we return the empty list. If a + goal is malformed, we print a warning, since this whole req-only goals thing is + experimental at best anyhow. + """ + if len(api_schemas) == 0: + return [] + result = [] + for goal in goals_list: + req_goals = {} + method = goal.get(STANDARD_API_NAME_SLOT, None) + if method is None: + return [] + required = [] + for schema in api_schemas: + if schema.get(STANDARD_API_NAME_SLOT, "") == method: + required = schema.get(STANDARD_REQUIRED_KEY, {}) + print("-".join(required)) + for key in required: + if key not in goal: + print(f"No required key `{key}` in goal `{goal}`") + return [] + req_goals[key] = goal[key] + if len(req_goals) > 0: + req_goals[STANDARD_API_NAME_SLOT] = method # for consistency with all. + result.append(req_goals) + return result + + +def goals_slots_helper( + goals: List[Dict], turnDict: List[Dict] +) -> Tuple[Tuple[int, int], Tuple[int, int]]: + """ + Helper function to see how well the slot keys + slot values match between attempted + API calls and goals. + + Output is precision, recall. + """ + all_call_slots = {k: v for call in turnDict for k, v in call.items()} + all_goal_slots = {k: v for goal in goals for k, v in goal.items()} + goal_in_call = { + k: v + for k, v in all_call_slots.items() + if all_goal_slots.get(k, "definitelyNotInValuexyz") == v + } + call_in_goal = { + k: v + for k, v in all_goal_slots.items() + if all_call_slots.get(k, "definitelyNotInValuexyz") == v + } + + print(goal_in_call, all_call_slots) + + return ( + AverageMetric(len(goal_in_call), len(all_call_slots)), + AverageMetric(len(call_in_goal), len(all_goal_slots)), + ) + + +@register_metrics_handler +class LegacyGoalApiCallInteractionsMetricsHandler(_ApiCallGoalInteractionHelper): + """ + This class was reporting a few too many metrics, but is useful for test purposes, so + we're keeping it around. + + `AllGoalApiCallSuccessMetricsHandler` is the streamlined, less spammy version of + this class. 
+ """ + + def handle_goals( + self, message: Message, goals: List[Dict] + ) -> Optional[Dict[str, Metric]]: + self.goals = goals + self.required_goals = get_req_only_goals(goals, self.api_schemas) + + def get_episode_metrics(self) -> Optional[Dict[str, Metric]]: + all_goals_hit, all_goals_hit_turn_count, all_part_hit = goals_hit_helper( + self.goals, self.api_turns + ) + all_precision, all_recall = goals_slots_helper(self.goals, self.api_turns) + req_goals_hit, req_goals_hit_turn_count, req_part_hit = goals_hit_helper( + self.required_goals, self.api_turns, permissive=True + ) + req_precision, req_recall = goals_slots_helper( + self.required_goals, self.api_turns + ) + call_attempts = len(self.api_turns) + return { + "all_goals_hit": all_goals_hit, + "all_goals_hit_turn_count": all_goals_hit_turn_count, + "all_goals_fractional_hit": all_part_hit, + "all_goals_slot_precision": all_precision, + "all_goals_slot_recall": all_recall, + "req_goals_hit": req_goals_hit, + "req_goals_hit_turn_count": req_goals_hit_turn_count, + "req_goals_fractional_hit": req_part_hit, + "req_goals_slot_precision": req_precision, + "req_goals_slot_recall": req_recall, + "call_attempts": AverageMetric(call_attempts), + } + + +@register_metrics_handler +class UserGoalSlotCoverageMetricHandler(TodMetricsHandler): + """ + How well does our user simulator do at outputting utterances that goes closer to + satisfying relevant groundinged goals? Does it dump out all of the slots at once or + is it more intelligent than that? + + Since this is the user and we don't know the identity of potential slots, we ignore + the short (< 4 chars) goal slots since this tends to be things that are substrings + of other things. (Ex. "2" showing up as # of people in a reservation, but also + showing up as a phone number.) + """ + + def episode_reset(self): + self.mentioned_all_slot_values = set() + self.mentioned_req_slot_values = set() + self.all_goal_slot_values = set() + self.all_req_goal_slot_values = set() + + def handle_goals( + self, message: Message, goals: List[Dict] + ) -> Optional[Dict[str, Metric]]: + """ + Parse out all the slots as a blob, filtering out for short things. + """ + required_goals = get_req_only_goals(goals, self.api_schemas) + + def get_slot_values(goal_list): + result = set() + for goal in goal_list: + for key, value in goal.items(): + if key is not STANDARD_API_NAME_SLOT and len(value) > 3: + result.add(value) + return result + + self.all_goal_slot_values = get_slot_values(goals) + self.all_req_goal_slot_values = get_slot_values(required_goals) + + def handle_user_utt( + self, message: Message, prefix_stripped_text: str + ) -> Optional[Dict[str, Metric]]: + """ + Grab slots out of the user utterance based on an exact match. 
+ """ + utterance = prefix_stripped_text + + def get_slots(utt, options): + results = set() + for option in options: + if option in utt: + results.add(option) + return results + + all_slot_values_here = get_slots(utterance, self.all_goal_slot_values) + req_slot_values_here = get_slots(utterance, self.all_req_goal_slot_values) + + self.mentioned_all_slot_values |= all_slot_values_here + self.mentioned_req_slot_values |= req_slot_values_here + + metrics = {} + metrics["user_utt_avg_any_slot"] = AverageMetric(len(all_slot_values_here)) + metrics["user_utt_avg_req_slot"] = AverageMetric(len(req_slot_values_here)) + return metrics + + def get_episode_metrics(self) -> Optional[Dict[str, Metric]]: + result = { + "user_any_goal_slots_recall": AverageMetric( + len(self.mentioned_all_slot_values), len(self.all_goal_slot_values) + ), + "user_req_goal_slots_recall": AverageMetric( + len(self.mentioned_req_slot_values), len(self.all_req_goal_slot_values) + ), + } + + self.mentioned_all_slot_values = set() + self.mentioned_req_slot_values = set() + return result + + +class _ExactRepeatMetricsHandler(TodMetricsHandler): + """ + Helper class for defining % of episodes where a given agent type has exactly + repeated the same utterance. + """ + + def episode_reset(self): + self.turns = [] + self.repeated = False + + def metric_key(self): + raise NotImplementedError("must implement") + + def handle_message_helper( + self, prefix_stripped_text: str + ) -> Optional[Dict[str, Metric]]: + normalized = normalize_answer(prefix_stripped_text) + if normalized in self.turns: + self.repeated = True + self.turns.append(normalized) + + def get_episode_metrics(self) -> Optional[Dict[str, Metric]]: + repeat = int(self.repeated) + self.repeated = False + self.turns = [] + return {self.metric_key(): AverageMetric(repeat)} + + +@register_metrics_handler +class UserUttRepeatMetricHandler(_ExactRepeatMetricsHandler): + def metric_key(self): + return "user_utt_repeat" + + def handle_user_utt( + self, message: Message, prefix_stripped_text: str + ) -> Optional[Dict[str, Metric]]: + return self.handle_message_helper(prefix_stripped_text) + + +@register_metrics_handler +class SystemUttRepeatMetricHandler(_ExactRepeatMetricsHandler): + def metric_key(self): + return "sys_utt_repeat" + + def handle_sys_utt( + self, message: Message, prefix_stripped_text: str + ) -> Optional[Dict[str, Metric]]: + return self.handle_message_helper(prefix_stripped_text) + + +@register_metrics_handler +class _Bleu3MetricsHandler(TodMetricsHandler): + """ + For a given agent, this calculates the Bleu-3 of a new turn against prior turns. 
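+    Higher scores mean more n-gram overlap with the agent's earlier turns in the
+    episode, i.e. more repetition.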
+
+    This is an alternate metric for repetitiveness.
+    """
+
+    def episode_reset(self):
+        self.turns = []
+
+    def metric_key(self):
+        raise NotImplementedError("must implement")
+
+    def handle_message_helper(
+        self, prefix_stripped_text: str
+    ) -> Optional[Dict[str, Metric]]:
+        here = [normalize_answer(x) for x in prefix_stripped_text.split(" ")]
+        score = 1
+        if len(self.turns) > 0:
+            score = nltkbleu.corpus_bleu(
+                [self.turns],
+                [here],
+                smoothing_function=nltkbleu.SmoothingFunction(epsilon=1e-12).method1,
+                weights=[1.0 / 3.0] * 3,
+            )
+        self.turns.append(here)
+        return {self.metric_key(): BleuMetric(score)}
+
+
+@register_metrics_handler
+class UserUttSelfBleu3MetricHandler(_Bleu3MetricsHandler):
+    def metric_key(self):
+        return "user_utt_self_bleu3"
+
+    def handle_user_utt(
+        self, message: Message, prefix_stripped_text: str
+    ) -> Optional[Dict[str, Metric]]:
+        return self.handle_message_helper(prefix_stripped_text)
+
+
+@register_metrics_handler
+class SystemUttSelfBleu3MetricHandler(_Bleu3MetricsHandler):
+    def metric_key(self):
+        return "sys_utt_self_bleu3"
+
+    def handle_sys_utt(
+        self, message: Message, prefix_stripped_text: str
+    ) -> Optional[Dict[str, Metric]]:
+        return self.handle_message_helper(prefix_stripped_text)
+
+
+@register_metrics_handler
+class ApiCallMalformedMetricHandler(TodMetricsHandler):
+    def episode_reset(self):
+        self.api_schemas = []
+
+    def handle_api_call(
+        self, message: Message, api_call: Dict
+    ) -> Optional[Dict[str, Metric]]:
+        if STANDARD_API_SCHEMAS in message["text"]:
+            return  # Happens for API call grounding, so it's fine
+        if len(api_call) == 0:
+            return
+        if STANDARD_API_NAME_SLOT not in api_call:
+            return {
+                "apiCall_wellFormed": AverageMetric(0),
+                "apiCall_hasSlotsButNoApiNameSlot_count": SumMetric(1),
+            }
+        method = api_call[STANDARD_API_NAME_SLOT]
+
+        method_found = False
+        if len(self.api_schemas) > 0:
+            for schema in self.api_schemas:
+                if method == schema.get(STANDARD_API_NAME_SLOT, ""):
+                    method_found = True
+                    check = api_call.keys()
+                    required = set(schema.get(STANDARD_REQUIRED_KEY, []))
+                    required.add(STANDARD_API_NAME_SLOT)
+                    for req in required:
+                        if req not in check:  # missing required slot
+                            return {
+                                "apiCall_wellFormed": AverageMetric(0),
+                                "apiCall_missingRequiredSlot_count": SumMetric(1),
+                            }
+                    opt_count = 0
+                    for opt in schema.get(STANDARD_OPTIONAL_KEY, []):
+                        if opt in check:
+                            opt_count += 1
+                    if opt_count + len(required) != len(check):
+                        # call has extra params that are neither required nor optional
+                        return {
+                            "apiCall_wellFormed": AverageMetric(0),
+                            "apiCall_hasExtraParams_count": SumMetric(1),
+                        }
+                    break
+        if method_found:
+            return {
+                "apiCall_wellFormed": AverageMetric(1),
+                "apiCall_wellFormed_count": SumMetric(1),
+            }
+        return {
+            "apiCall_wellFormed": AverageMetric(0),
+            "apiCall_methodDNE_count": SumMetric(1),
+        }
+
+
+@register_metrics_handler
+class PseudoInformMetricsHandler(TodMetricsHandler):
+    """
+    Pseudo-inform rate.
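+
+    Loosely, this counts how many slot values from the API responses seen so far in the
+    episode show up verbatim in a given system utterance (see `handle_sys_utt` below).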
+ """ + + def episode_reset(self): + self.api_resp_slots = {} + + def handle_api_resp( + self, message: Message, api_resp: Dict + ) -> Optional[Dict[str, Metric]]: + self.api_resp_slots.update(api_resp) + + def handle_sys_utt( + self, message: Message, prefix_stripped_text: str + ) -> Optional[Dict[str, Metric]]: + count = 0 + for val in self.api_resp_slots.values(): + if val in prefix_stripped_text: + count += 1 + result = {"pseudo_inform_allSysTurns": AverageMetric(count)} + if len(self.api_resp_slots) > 0: + result["pseudo_inform_postApiRespSysTurns"] = AverageMetric(count) + return result diff --git a/tests/tod/test_tod_world_and_script.py b/tests/tod/test_tod_world_and_script.py new file mode 100644 index 00000000000..d3f764bc13b --- /dev/null +++ b/tests/tod/test_tod_world_and_script.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests tod world, notably for batching. +""" + +import copy +import unittest + +import parlai.core.tod.tod_test_utils.test_agents as test_agents +import parlai.core.tod.tod_core as tod_core +import parlai.scripts.tod_world_script as tod_world_script +from parlai.core.tod.tod_agents import StandaloneApiAgent + + +class TestTodWorldScript(tod_world_script.TodWorldScript): + """ + Wrap around it to check its logic; also makes it easier to do things w/ underlying + World. + """ + + def _get_tod_agents(self, opt): + """ + Hack so we can separate out logic of making sure agent parsing is correct. + """ + if hasattr(self, "agents"): + return self.agents + return super()._get_tod_agents(opt) + + def _save_outputs(self, opt, world, logger, episode_metrics): + self.world = world + self.logger = logger + + +class TodWorldInScriptTestBase(unittest.TestCase): + def add_tod_world_opts(self, base_opts): + """ + Convenience since we're initing the opt directly without parlai parser. + """ + opts = copy.deepcopy(base_opts) + opts["datatype"] = "DUMMY" + opts["datafile"] = "DUMMY" + opts["episodes_randomization_seed"] = -1 + opts["standalone_api_file"] = test_agents.API_DATABASE_FILE + opts["exact_api_call"] = True + opts["log_keep_fields"] = "all" + opts["display_examples"] = False + opts[ + "include_api_schemas" + ] = True # do this to test_agents.make sure they're done correctly. + return opts + + def setup_agents(self, added_opts): + full_opts = self.add_tod_world_opts(added_opts) + sys = test_agents.ApiCallAndSysUttAgent(full_opts) + agents = [ + test_agents.UserUttAgent(full_opts), + sys, + StandaloneApiAgent(full_opts), + sys, + test_agents.ApiSchemaAgent(full_opts), + test_agents.GoalAgent(full_opts), + ] + return agents, full_opts + + def _test_roundDataCorrect(self): + self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__SINGLE_API_CALL) + self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__MULTI_ROUND) + self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__MULTI_EPISODE) + self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__MULTI_EPISODE_BS) + + def _check_correctness_from_script_logs( + self, script, opt, process_round_utts=lambda x: x + ): + """ + Last argument is only relevant for the max_turn test. 
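+
+        The `process_round_utts` callable lets that test trim the expected round
+        utterances to match the truncated conversation.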
+ """ + max_rounds = opt[test_agents.TEST_NUM_ROUNDS_OPT_KEY] + max_episodes = opt[test_agents.TEST_NUM_EPISODES_OPT_KEY] + # there's something funky with logger.get_log() that inserts a space, but not gonna worry about it for now + logs = [x for x in script.logger.get_logs() if len(x) > 0] + for episode_idx in range(max_episodes): + episode_from_world = logs[episode_idx] + # first round is context + context = episode_from_world[0] + self.assertEquals( + context[0]["text"], + "APIS: " + + tod_core.SerializationHelpers.list_of_maps_to_str( + test_agents.make_api_schemas_machine(max_rounds) + ), + ) + self.assertEquals( + context[3]["text"], + "GOAL: " + + tod_core.SerializationHelpers.list_of_maps_to_str( + test_agents.make_goal_calls_machine(max_rounds) + ), + ) + # Check the rest + world_utts = [[x["text"] for x in turn] for turn in episode_from_world[1:]] + # ... ignore the last DONE turn here cause it's not that important + + self.assertEquals( + world_utts[:-1], + process_round_utts( + test_agents.get_round_utts(episode_idx, max_rounds)[:-1] + ), + ) + + +class TodWorldSingleBatchTest(TodWorldInScriptTestBase): + def _test_roundDataCorrect_helper(self, config): + config["batchsize"] = 1 + config["max_turns"] = 10 + agents, opt = self.setup_agents(config) + script = TestTodWorldScript(opt) + script.agents = agents + script.run() + self._check_correctness_from_script_logs(script, opt) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + def test_max_turn(self): + self._test_max_turn_helper(4) + self._test_max_turn_helper(7) + + def _test_max_turn_helper(self, max_turns): + config = {} + config["batchsize"] = 1 + config["max_turns"] = max_turns + config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] = 10 + config[test_agents.TEST_NUM_EPISODES_OPT_KEY] = 2 # cause why not + agents, opt = self.setup_agents(config) + script = TestTodWorldScript(opt) + script.agents = agents + script.run() + + def filter_round_utt(utts): + # tad imprecise, but more important that it does stop. 
+ # subtract 1 for the context turn, then 1 cause there's an off by one somewhere + return utts[: max_turns - 2] + + self._check_correctness_from_script_logs(script, opt, filter_round_utt) + + +class TodWorldNonSingleBatchTest(TodWorldInScriptTestBase): + def _test_roundDataCorrect_helper(self, config): + config["batchsize"] = 4 + config["max_turns"] = 10 + agents, opt = self.setup_agents(config) + script = TestTodWorldScript(opt) + script.agents = agents + script.run() + self._check_correctness_from_script_logs(script, opt) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TodWorldTestSingleDumpAgents(TodWorldInScriptTestBase): + def setup_agents(self, added_opts, api_agent, goal_agent): + full_opts = self.add_tod_world_opts(added_opts) + full_opts["fixed_response"] = "USER: [DONE]" + sys = test_agents.ApiCallAndSysUttAgent(full_opts) + agents = [ + test_agents.UserUttAgent(full_opts), + sys, + StandaloneApiAgent(full_opts), + sys, + api_agent(full_opts), + goal_agent(full_opts), + ] + return agents, full_opts + + def test_SingleGoalApiResp_noBatching(self): + config = {} + config["batchsize"] = 1 + config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] = 10 + config[test_agents.TEST_NUM_EPISODES_OPT_KEY] = 2 # cause why not + single_agents, opt = self.setup_agents( + config, test_agents.SingleApiSchemaAgent, test_agents.SingleGoalAgent + ) + single_script = TestTodWorldScript(opt) + single_script.agents = single_agents + single_script.run() + single_logs = [x for x in single_script.logger.get_logs() if len(x) > 0] + + multi_agents, opt = self.setup_agents( + config, test_agents.ApiSchemaAgent, test_agents.GoalAgent + ) + multi_script = TestTodWorldScript(opt) + multi_script.agents = multi_agents + multi_script.run() + multi_logs = [x for x in single_script.logger.get_logs() if len(x) > 0] + + single_idx = 0 + for multi_log in multi_logs: + context = multi_log[0] + goals = tod_core.SerializationHelpers.str_to_goals( + context[3]["text"][len("GOAL:") :].strip() + ) + for goal in goals: + single_context = single_logs[single_idx][0] + single_goal = tod_core.SerializationHelpers.str_to_goals( + single_context[3]["text"][len("GOAL:") :].strip() + ) + self.assertEqual(len(single_goal), 1) + self.assertEquals(goal, single_goal[0]) + single_des = tod_core.SerializationHelpers.str_to_api_schemas( + single_context[0]["text"][len("APIS:") :].strip() + ) + self.assertEqual(len(single_des), 1) + self.assertEqual(single_goal[0]["api_name"], single_des[0]["api_name"]) + + single_idx += 1 + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/tod/test_tod_world_metrics.py b/tests/tod/test_tod_world_metrics.py new file mode 100644 index 00000000000..0367f655994 --- /dev/null +++ b/tests/tod/test_tod_world_metrics.py @@ -0,0 +1,243 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests different (more complicated) slot metrics. 
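+
+Exercises the registered metrics handlers by replaying hand-built `TodStructuredEpisode`s
+through `TodMetricsTestHelper` below.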
+""" + +import unittest + +from parlai.core.tod.tod_core import ( + STANDARD_API_NAME_SLOT, + STANDARD_REQUIRED_KEY, + STANDARD_OPTIONAL_KEY, + TodStructuredRound, + TodStructuredEpisode, + TodAgentType, + TOD_AGENT_TYPE_TO_PREFIX, +) +from parlai.core.tod.world_metrics import ( + TodMetrics, +) +from parlai.core.tod.world_metrics_handlers import ( + METRICS_HANDLER_CLASSES_TEST_REGISTRY, +) + +# Ignore lint on following line; want to have registered classes show up for tests +import projects.tod_simulator.world_metrics.extended_world_metrics # noqa: F401 + +GOAL__SINGLE_ONE_KEY = [{STANDARD_API_NAME_SLOT: "name", "a": "1"}] +GOAL__SINGLE_THREE_KEYS = [ + {STANDARD_API_NAME_SLOT: "name", "a": "1", "b": "2", "c": "3"} +] +GOAL__HARD = [ + { + STANDARD_API_NAME_SLOT: "otherName", + "w": "1", + "x": "2", + "y": "3", + "z": "will_be_missing", + "diff": "right", + } +] + +API_CALL__NO_API_NAME_SLOT = {"random": "blah"} +API_CALL__API_NAME_DNE = {STANDARD_API_NAME_SLOT: "not_an_api_name"} +API_CALL__VALID_NAME_BUT_EMPTY = {STANDARD_API_NAME_SLOT: "name"} +API_CALL__SINGLE_ONE_KEY = GOAL__SINGLE_ONE_KEY[0] +API_CALL__SINGLE_ONE_KEY_WITH_OPT = {**GOAL__SINGLE_ONE_KEY[0], **{"c": "3"}} +API_CALL__SINGLE_ONE_KEY_WITH_OPT_AND_NONVALID = { + **GOAL__SINGLE_ONE_KEY[0], + **{"c": "3", "nonExistent": "blah"}, +} +API_CALL__FUNKY_AGAINST_HARD = { + STANDARD_API_NAME_SLOT: "otherName", + "w": "1", + "x": "2", + "y": "3", + "diff": "wrong", +} + +API_SCHEMA__ONE_CALL_ONE_REQ_MATCH_ONE_KEY = [ + { + STANDARD_API_NAME_SLOT: "name", + STANDARD_REQUIRED_KEY: ["a"], + STANDARD_OPTIONAL_KEY: [], + } +] + +API_SCHEMA__ONE_CALL_MATCH_THREE_KEYS = [ + { + STANDARD_API_NAME_SLOT: "name", + STANDARD_REQUIRED_KEY: ["a"], + STANDARD_OPTIONAL_KEY: ["b", "c", "d"], + } +] + +API_SCHEMA__ONE_CALL_HARD = [ + { + STANDARD_API_NAME_SLOT: "otherName", + STANDARD_REQUIRED_KEY: ["w", "x"], + STANDARD_OPTIONAL_KEY: ["y", "z", "diff"], + } +] + + +class TodMetricsTestHelper: + def __init__(self, e: TodStructuredEpisode): + self.m = TodMetrics() + self.m.handlers = [ + x() for x in METRICS_HANDLER_CLASSES_TEST_REGISTRY + ] # run on ALL + self.e = e + + def _process(self, t: TodAgentType, text: str): + self.m.handle_message({"text": f"{TOD_AGENT_TYPE_TO_PREFIX[t]}{text}"}, t) + + def run(self): + self._process(TodAgentType.API_SCHEMA_GROUNDING_AGENT, self.e.api_schemas_utt) + self._process(TodAgentType.GOAL_GROUNDING_AGENT, self.e.goal_calls_utt) + + for r in self.e.rounds: + self._process(TodAgentType.USER_UTT_AGENT, r.user_utt) + self._process(TodAgentType.API_CALL_AGENT, r.api_call_utt) + self._process(TodAgentType.API_RESP_AGENT, r.api_resp_utt) + self._process(TodAgentType.SYSTEM_UTT_AGENT, r.sys_utt) + + self.m.episode_reset() + + def report(self): + return self.m.report() + + +class TestApiGoalHitMetricsHandler(unittest.TestCase): + def __helper(self, api_schemas_machine, goal_calls_machine, single_turn_api_call): + e = TodStructuredEpisode( + api_schemas_machine=api_schemas_machine, + goal_calls_machine=goal_calls_machine, + rounds=[TodStructuredRound(api_call_machine=single_turn_api_call)], + ) + helper = TodMetricsTestHelper(e) + helper.run() + result = helper.report() + return result + + def test_one_goal_only_req(self): + result = self.__helper( + api_schemas_machine=API_SCHEMA__ONE_CALL_ONE_REQ_MATCH_ONE_KEY, + goal_calls_machine=GOAL__SINGLE_ONE_KEY, + single_turn_api_call=API_CALL__SINGLE_ONE_KEY, + ) + self.assertAlmostEqual(result["all_goals_hit"], 1) + self.assertAlmostEqual(result["all_goals_hit_turn_count"], 1) + 
self.assertAlmostEqual(result["all_goals_fractional_hit"], 1) + self.assertAlmostEqual(result["all_goals_slot_precision"], 1) + self.assertAlmostEqual(result["all_goals_slot_recall"], 1) + + self.assertAlmostEqual(result["req_goals_hit"], 1) + self.assertAlmostEqual(result["req_goals_hit_turn_count"], 1) + self.assertAlmostEqual(result["req_goals_fractional_hit"], 1) + self.assertAlmostEqual(result["req_goals_slot_precision"], 1) + self.assertAlmostEqual(result["req_goals_slot_recall"], 1) + + def test_one_goal_api_name_missing_slots(self): + result = self.__helper( + api_schemas_machine=API_SCHEMA__ONE_CALL_ONE_REQ_MATCH_ONE_KEY, + goal_calls_machine=GOAL__SINGLE_ONE_KEY, + single_turn_api_call=API_CALL__VALID_NAME_BUT_EMPTY, + ) + self.assertAlmostEqual(result["all_goals_hit"], 0) + self.assertAlmostEqual(result["all_goals_hit_turn_count"], 0) + self.assertAlmostEqual(result["all_goals_fractional_hit"], 0) + self.assertAlmostEqual(result["all_goals_slot_precision"], 1) # api_name + self.assertAlmostEqual(result["all_goals_slot_recall"], 0.5) + + self.assertAlmostEqual(result["req_goals_hit"], 0) + self.assertAlmostEqual(result["req_goals_hit_turn_count"], 0) + self.assertAlmostEqual(result["req_goals_fractional_hit"], 0) + self.assertAlmostEqual(result["req_goals_slot_precision"], 1) + self.assertAlmostEqual(result["req_goals_slot_recall"], 0.5) + + def test_one_goal_with_opts(self): + result = self.__helper( + api_schemas_machine=API_SCHEMA__ONE_CALL_MATCH_THREE_KEYS, + goal_calls_machine=GOAL__SINGLE_THREE_KEYS, + single_turn_api_call=API_CALL__SINGLE_ONE_KEY, + ) + self.assertAlmostEqual(result["all_goals_hit"], 0) + self.assertAlmostEqual(result["all_goals_hit_turn_count"], 0) + self.assertAlmostEqual(result["all_goals_fractional_hit"], 0) + self.assertAlmostEqual(result["all_goals_slot_precision"], 1) + self.assertAlmostEqual(result["all_goals_slot_recall"], 0.5) + + self.assertAlmostEqual(result["req_goals_hit"], 1) + self.assertAlmostEqual(result["req_goals_hit_turn_count"], 1) + self.assertAlmostEqual(result["req_goals_fractional_hit"], 1) + self.assertAlmostEqual(result["req_goals_slot_precision"], 1) + self.assertAlmostEqual(result["req_goals_slot_recall"], 1) + + def test_hard_case(self): + result = self.__helper( + api_schemas_machine=API_SCHEMA__ONE_CALL_HARD, + goal_calls_machine=GOAL__HARD, + single_turn_api_call=API_CALL__FUNKY_AGAINST_HARD, + ) + self.assertAlmostEqual(result["all_goals_hit"], 0) + self.assertAlmostEqual(result["all_goals_hit_turn_count"], 0) + self.assertAlmostEqual(result["all_goals_fractional_hit"], 0) + self.assertAlmostEqual(result["all_goals_slot_precision"], 0.8) + self.assertAlmostEqual(result["all_goals_slot_recall"], 2.0 / 3.0) + + self.assertAlmostEqual(result["req_goals_hit"], 1) + self.assertAlmostEqual(result["req_goals_hit_turn_count"], 1) + self.assertAlmostEqual(result["req_goals_fractional_hit"], 1) + self.assertAlmostEqual(result["req_goals_slot_precision"], 0.6) + self.assertAlmostEqual(result["req_goals_slot_recall"], 1) + + +class TestApiCallMalformedMetricsHandler(unittest.TestCase): + def __helper(self, single_turn_api_call): + e = TodStructuredEpisode( + api_schemas_machine=API_SCHEMA__ONE_CALL_MATCH_THREE_KEYS, + rounds=[TodStructuredRound(api_call_machine=single_turn_api_call)], + ) + helper = TodMetricsTestHelper(e) + helper.run() + return helper.report() + + def test_no_api_name_slot(self): + result = self.__helper(API_CALL__NO_API_NAME_SLOT) + self.assertEqual(result["apiCall_wellFormed"], 0) + 
self.assertEqual(result["apiCall_hasSlotsButNoApiNameSlot_count"], 1) + + def test_api_name_DNE(self): + result = self.__helper(API_CALL__API_NAME_DNE) + self.assertEqual(result["apiCall_wellFormed"], 0) + self.assertEqual(result["apiCall_methodDNE_count"], 1) + + def test_missing_required_slot(self): + result = self.__helper(API_CALL__VALID_NAME_BUT_EMPTY) + self.assertEqual(result["apiCall_wellFormed"], 0) + self.assertEqual(result["apiCall_missingRequiredSlot_count"], 1) + + def test_has_single_required_slot(self): + result = self.__helper(API_CALL__SINGLE_ONE_KEY) + self.assertEqual(result["apiCall_wellFormed"], 1) + self.assertEqual(result["apiCall_wellFormed_count"], 1) + + def test_has_valid_optional_slot(self): + result = self.__helper(API_CALL__SINGLE_ONE_KEY_WITH_OPT) + self.assertEqual(result["apiCall_wellFormed"], 1) + self.assertEqual(result["apiCall_wellFormed_count"], 1) + + def test_has_invalid_extra_slots(self): + result = self.__helper(API_CALL__SINGLE_ONE_KEY_WITH_OPT_AND_NONVALID) + self.assertEqual(result["apiCall_wellFormed"], 0) + self.assertEqual(result["apiCall_hasExtraParams_count"], 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/tod/test_tod_world_metrics_in_script.py b/tests/tod/test_tod_world_metrics_in_script.py new file mode 100644 index 00000000000..ffac2a25e12 --- /dev/null +++ b/tests/tod/test_tod_world_metrics_in_script.py @@ -0,0 +1,257 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +sTests tod world, notably for batching. +""" + +import copy +import unittest + +from parlai.core.metrics import dict_report +from parlai.core.opt import Opt +from parlai.core.tod.tod_core import SerializationHelpers +import parlai.core.tod.tod_test_utils.test_agents as test_agents +from parlai.core.tod.world_metrics_handlers import ( + METRICS_HANDLER_CLASSES_TEST_REGISTRY, +) +import parlai.scripts.tod_world_script as tod_world_script + +# Ignore lint on following line; want to have registered classes show up for tests +import projects.tod_simulator.world_metrics.extended_world_metrics # noqa: F401 + +NUM_EPISODES = 35 + +TEST_SETUP = { + "api_schema_grounding_model": "parlai.core.tod.tod_test_utils.test_agents:ApiSchemaAgent", + "goal_grounding_model": "parlai.core.tod.tod_test_utils.test_agents:GoalAgent", + "user_model": "parlai.core.tod.tod_test_utils.test_agents:UserUttAgent", + "system_model": "parlai.core.tod.tod_test_utils.test_agents:ApiCallAndSysUttAgent", + "api_resp_model": "fixed_response", + test_agents.TEST_NUM_EPISODES_OPT_KEY: NUM_EPISODES, +} +TEST_SETUP_BROKEN_USER_SYSTEM = { + "api_schema_grounding_model": "parlai.core.tod.tod_test_utils.test_agents:ApiSchemaAgent", + "goal_grounding_model": "parlai.core.tod.tod_test_utils.test_agents:GoalAgent", + "user_model": "fixed_response", + "system_model": "fixed_response", + "api_resp_model": "fixed_response", + test_agents.TEST_NUM_EPISODES_OPT_KEY: NUM_EPISODES, +} + +TEST_SETUP_EMPTY_APISCHEMA = copy.deepcopy(TEST_SETUP) +TEST_SETUP_EMPTY_APISCHEMA[ + "api_schema_grounding_model" +] = "parlai.core.tod.tod_agents:EmptyApiSchemaAgent" + +TEST_SETUP_BROKEN_USER_SYSTEM_EMPTY_APISCHEMA = copy.deepcopy( + TEST_SETUP_BROKEN_USER_SYSTEM +) +TEST_SETUP_BROKEN_USER_SYSTEM_EMPTY_APISCHEMA[ + "api_schema_grounding_model" +] = "parlai.core.tod.tod_agents:EmptyApiSchemaAgent" + +DATATYPE = "valid" + + +class 
TestTodWorldScript(tod_world_script.TodWorldScript): + """ + Wrap around it to check its logic; also makes it easier to do things w/ underlying + World. + """ + + def __init__(self, opt: Opt): + opt["datatype"] = DATATYPE + # none of the below matter, but need to set to keep other code happy. + opt["log_keep_fields"] = "all" + opt["display_examples"] = False + + super().__init__(opt) + + def _setup_world(self): + world = super()._setup_world() + for i in range(len(world.batch_tod_world_metrics)): + world.batch_tod_world_metrics[i].handlers = [ + x() for x in METRICS_HANDLER_CLASSES_TEST_REGISTRY + ] + return world + + def _save_outputs(self, opt, world, logger, episode_metrics): + self.world = world + self.logger = logger + self.episode_metrics = episode_metrics + + +class TodMetricsInScriptTests(unittest.TestCase): + def test_all_goals_hit_all_success(self): + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP, batchsize=1, num_episodes=1, target_all_goals_hit=1 + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP, batchsize=1, num_episodes=32, target_all_goals_hit=1 + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP, batchsize=32, num_episodes=8, target_all_goals_hit=1 + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP, batchsize=32, num_episodes=33, target_all_goals_hit=1 + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP, + batchsize=32, + num_episodes=-1, + target_all_goals_hit=1, + target_metrics_length=NUM_EPISODES, + ) + + def test_all_goals_hit_all_fail(self): + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_BROKEN_USER_SYSTEM, + batchsize=1, + num_episodes=1, + target_all_goals_hit=0, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_BROKEN_USER_SYSTEM, + batchsize=1, + num_episodes=32, + target_all_goals_hit=0, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_BROKEN_USER_SYSTEM, + batchsize=32, + num_episodes=32, + target_all_goals_hit=0, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_BROKEN_USER_SYSTEM, + batchsize=32, + num_episodes=33, + target_all_goals_hit=0, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_BROKEN_USER_SYSTEM, + batchsize=32, + num_episodes=-1, + target_all_goals_hit=0, + target_metrics_length=NUM_EPISODES, + ) + + def test_all_goals_hit_all_success_emptySchema(self): + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_EMPTY_APISCHEMA, + batchsize=1, + num_episodes=1, + target_all_goals_hit=1, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_EMPTY_APISCHEMA, + batchsize=1, + num_episodes=32, + target_all_goals_hit=1, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_EMPTY_APISCHEMA, + batchsize=32, + num_episodes=32, + target_all_goals_hit=1, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_EMPTY_APISCHEMA, + batchsize=32, + num_episodes=33, + target_all_goals_hit=1, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_EMPTY_APISCHEMA, + batchsize=32, + num_episodes=-1, + target_all_goals_hit=1, + target_metrics_length=NUM_EPISODES, + ) + + def test_all_goals_hit_all_fail_emptySchema(self): + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_BROKEN_USER_SYSTEM_EMPTY_APISCHEMA, + batchsize=1, + num_episodes=1, + target_all_goals_hit=0, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_BROKEN_USER_SYSTEM_EMPTY_APISCHEMA, + batchsize=1, + num_episodes=32, + target_all_goals_hit=0, + ) + 
self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_BROKEN_USER_SYSTEM_EMPTY_APISCHEMA, + batchsize=32, + num_episodes=32, + target_all_goals_hit=0, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_BROKEN_USER_SYSTEM_EMPTY_APISCHEMA, + batchsize=32, + num_episodes=33, + target_all_goals_hit=0, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_BROKEN_USER_SYSTEM_EMPTY_APISCHEMA, + batchsize=32, + num_episodes=-1, + target_all_goals_hit=0, + target_metrics_length=NUM_EPISODES, + ) + + def _check_all_goals_hit_by_opt_and_batchsize( + self, + opt, + batchsize, + num_episodes, + target_all_goals_hit, + target_metrics_length=None, + ): + opt = copy.deepcopy(opt) + opt["batchsize"] = batchsize + opt["num_episodes"] = num_episodes + report, metrics = self._run_opt_get_report(opt) + self.assertEqual(report.get("all_goals_hit"), target_all_goals_hit) + metrics_comp_length = num_episodes + if target_metrics_length: + metrics_comp_length = target_metrics_length + self.assertEqual(len(metrics), metrics_comp_length) + + def _run_opt_get_report(self, opt): + script = TestTodWorldScript(opt) + script.run() + + def get_episode_report(goal, episode_metric): + metrics_dict = dict_report(episode_metric.report()) + metrics_dict["goal"] = goal + return metrics_dict + + return dict_report(script.world.report()), [ + get_episode_report(g, e) for g, e in script.episode_metrics + ] + + def test_apiCallAttempts_usingGold(self): + opt = copy.deepcopy(TEST_SETUP) + opt["batchsize"] = 1 + opt["num_episodes"] = -1 + _, metrics = self._run_opt_get_report(opt) + for metric in metrics: + self.assertEqual( + len( + SerializationHelpers.str_to_goals( + metric["goal"]["text"][len("GOALS: ") :] + ) + ), + metric["call_attempts"], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/tod/test_tod_world_script_metrics.py b/tests/tod/test_tod_world_script_metrics.py new file mode 100644 index 00000000000..9862b3c824f --- /dev/null +++ b/tests/tod/test_tod_world_script_metrics.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests tod world, notably for batching. +""" + +import copy +import unittest + +import parlai.core.tod.tod_test_utils.test_agents as test_agents +import parlai.scripts.tod_world_script as tod_world_script +from parlai.core.tod.tod_agents import StandaloneApiAgent +from parlai.core.tod.world_metrics_handlers import ( + METRICS_HANDLER_CLASSES_TEST_REGISTRY, +) +from parlai.core.metrics import dict_report + +# Ignore lint on following line; want to have registered classes show up for tests +import projects.tod_simulator.world_metrics.extended_world_metrics # noqa: F401 + + +class TestTodWorldScript(tod_world_script.TodWorldScript): + """ + Wrap around it to check its logic; also makes it easier to do things w/ underlying + World. + """ + + def _get_tod_agents(self, opt): + """ + Hack so we can separate out logic of making sure agent parsing is correct. 
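+
+        Tests assign `script.agents` directly; when that attribute exists we return it
+        instead of building agents from the opt.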
+ """ + if hasattr(self, "agents"): + return self.agents + return super()._get_tod_agents(opt) + + def _setup_world(self): + world = super()._setup_world() + for i in range(len(world.batch_tod_world_metrics)): + world.batch_tod_world_metrics[i].handlers = [ + x() for x in METRICS_HANDLER_CLASSES_TEST_REGISTRY + ] + return world + + def _save_outputs(self, opt, world, logger, episode_metrics): + self.world = world + self.episode_metrics = episode_metrics + + +class TodWorldInScriptTestBase(unittest.TestCase): + def add_tod_world_opts(self, base_opts): + """ + Convenience since we're initing the opt directly without parlai parser. + """ + opts = copy.deepcopy(base_opts) + opts["datatype"] = "DUMMY" + opts["datafile"] = "DUMMY" + opts["standalone_api_file"] = test_agents.API_DATABASE_FILE + opts["exact_api_call"] = True + opts["log_keep_fields"] = "all" + opts["display_examples"] = False + opts[ + "include_api_schemas" + ] = True # do this to test_agents.make sure they're done correctly. + return opts + + def setup_agents(self, added_opts): + full_opts = self.add_tod_world_opts(added_opts) + sys = test_agents.ApiCallAndSysUttAgent(full_opts) + agents = [ + test_agents.UserUttAgent(full_opts), + sys, + StandaloneApiAgent(full_opts), + sys, + test_agents.ApiSchemaAgent(full_opts), + test_agents.GoalAgent(full_opts), + ] + return agents, full_opts + + def _run_test(self): + self._run_test_helper(test_agents.EPISODE_SETUP__SINGLE_API_CALL) + self._run_test_helper(test_agents.EPISODE_SETUP__MULTI_ROUND) + self._run_test_helper(test_agents.EPISODE_SETUP__MULTI_EPISODE) + self._run_test_helper(test_agents.EPISODE_SETUP__MULTI_EPISODE_BS) + + def _run_test_helper(self, config_base): + config = copy.deepcopy(config_base) + config["use_broken_mock_api_calls"] = True + add = self.config_args() + for key in add: + config[key] = add[key] + agents, opt = self.setup_agents(config) + script = TestTodWorldScript(opt) + script.agents = agents + script.run() + self._check_metrics_correct(script, opt) + + def _check_metrics_correct(self, script, opt): + """ + Last argument is only relevant for the max_turn test. + """ + max_rounds = opt[test_agents.TEST_NUM_ROUNDS_OPT_KEY] + max_episodes = opt[test_agents.TEST_NUM_EPISODES_OPT_KEY] + episode_metrics = script.episode_metrics + for episode_idx, episode in enumerate(episode_metrics): + # if episode_idx >= max_episodes: + # break + # See how we make broken mock api calls in the test_agents. 
+ goal, episode_metric = episode + episode_metric = dict_report(episode_metric.report()) + self.assertAlmostEqual( + episode_metric["all_goals_hit"], + not test_agents.episode_has_broken_api_turn(episode_idx, max_rounds), + ) + broken_episodes = sum( + [ + test_agents.episode_has_broken_api_turn(i, max_rounds) + for i in range(max_episodes) + ] + ) + report = dict_report(script.world.report()) + self.assertAlmostEqual( + report["all_goals_hit"], + float(max_episodes - broken_episodes) / max_episodes, + ) + + +class TodWorldSingleBatchTest(TodWorldInScriptTestBase): + def config_args(self): + config = {} + config["batchsize"] = 1 + config["max_turns"] = 10 + return config + + def test_metricsCorrect(self): + self._run_test() + + +class TodWorldNonSingleBatchTest(TodWorldInScriptTestBase): + def config_args(self): + config = {} + config["batchsize"] = 4 + config["max_turns"] = 10 + return config + + def test_metricsCorrect(self): + self._run_test() + + +if __name__ == "__main__": + unittest.main() From 0e3f492dfffb4746a17ce367d97e91af895dfd49 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 11:59:27 -0800 Subject: [PATCH 08/57] hmmm... hoping stacks don't bite me. (change that was kept in upper diff in stack, but lost from this one --- parlai/core/tod/tod_core.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/parlai/core/tod/tod_core.py b/parlai/core/tod/tod_core.py index 76ee8005a74..e53be5c7376 100644 --- a/parlai/core/tod/tod_core.py +++ b/parlai/core/tod/tod_core.py @@ -160,7 +160,10 @@ def inner_list_join(cls, values): @classmethod def inner_list_split(cls, s): - return s.split(", ") + split = s.split(", ") + if len(split) == 1: + return split[0] + return set(split) @classmethod def maybe_inner_list_join(cls, values): @@ -193,7 +196,7 @@ def str_to_api_dict(cls, string): continue name, value = slot_str.split(" = ", 1) name = name.strip() - value = value.strip() + value = SerializationHelpers.inner_list_split(value.strip()) result[name] = value return result From 37aced299af86863069b7655ef760033dcd9a4e5 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 12:01:13 -0800 Subject: [PATCH 09/57] minor, remove commented out print --- parlai/core/tod/tod_test_utils/test_agents.py | 1 - 1 file changed, 1 deletion(-) diff --git a/parlai/core/tod/tod_test_utils/test_agents.py b/parlai/core/tod/tod_test_utils/test_agents.py index b1339052764..b7639efea36 100644 --- a/parlai/core/tod/tod_test_utils/test_agents.py +++ b/parlai/core/tod/tod_test_utils/test_agents.py @@ -171,7 +171,6 @@ def setup_episodes(self, _): ), ) ) - # print(result, self.opt) return result def get_id_task_prefix(self): From b05930fa1238faa1f125ae80ee93e212ce1b8ada Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 12:16:33 -0800 Subject: [PATCH 10/57] comment --- parlai/scripts/tod_world_script.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/parlai/scripts/tod_world_script.py b/parlai/scripts/tod_world_script.py index af4f963b71f..81b38b2b620 100644 --- a/parlai/scripts/tod_world_script.py +++ b/parlai/scripts/tod_world_script.py @@ -5,6 +5,13 @@ # LICENSE file in the root directory of this source tree. """ Base script for running TOD model-model chats. 
+ +For example, to extract gold ground truth data from Google SGD, run + +``` +python -u -m parlai.scripts.tod_world_script --api-schema-grounding-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainApiSchemaAgent --goal-grounding-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainGoalAgent --user-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainUserUttAgent --system-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainApiCallAndSysUttAgent --api-resp-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainApiResponseAgent -dt valid --num-episodes -1 --episodes-randomization-seed 42 --world-logs gold-valid +``` + """ import json @@ -108,7 +115,7 @@ class TodWorldScript(ParlaiScript): @classmethod def setup_tod_args(cls, parser: ParlaiParser): tod_args = parser.add_argument_group( - "TOD World Script Agent arguments. NOTE: Agents setup with this path will be able to take command line arguments, whereas those set from `-o` or `--init-opt` will not." + "TOD World Script Agent arguments. NOTE: If there are issues with invoking downstream opts of agents specified here sometimes you will have more luck with `python -u -m parlai.scripts.tod_world_script` than `parlai tod_world_script`." ) tod_args.add_argument( "--system-model-file", From 5086e8507c785c53f1b78bc5060b4e7e393ab430 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 12:19:08 -0800 Subject: [PATCH 11/57] more comment updates (not sure if it actually helps clarity..) --- parlai/scripts/tod_world_script.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/parlai/scripts/tod_world_script.py b/parlai/scripts/tod_world_script.py index 81b38b2b620..42199bc4ff3 100644 --- a/parlai/scripts/tod_world_script.py +++ b/parlai/scripts/tod_world_script.py @@ -6,12 +6,11 @@ """ Base script for running TOD model-model chats. -For example, to extract gold ground truth data from Google SGD, run +For example, to extract gold ground truth data from Google SGD, run ``` python -u -m parlai.scripts.tod_world_script --api-schema-grounding-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainApiSchemaAgent --goal-grounding-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainGoalAgent --user-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainUserUttAgent --system-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainApiCallAndSysUttAgent --api-resp-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainApiResponseAgent -dt valid --num-episodes -1 --episodes-randomization-seed 42 --world-logs gold-valid ``` - """ import json From 9a25fc56b94b7dc47961c8e04b63a3848fc073df Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 12:40:39 -0800 Subject: [PATCH 12/57] [TOD][Dataset][Easy] Google SGD in TOD Conversations format Refactor Google SGD away from old format into TOD Conversations format. 
Datasets added in this substack: * *Google SGD* * Google SGD Simulation Splits (In-domain, Out-domain) * MetalWoz * MSR_E2E * Multidogo * MultiWoz V2.2 * Taskmaster * Taskmaster2 * Taskmaster3 (TicketTalk) Test plan: Regression test, `parlai dd` of dataset --- parlai/tasks/google_sgd/agents.py | 354 ++++++++++-------- parlai/tasks/google_sgd/build.py | 16 +- parlai/tasks/google_sgd/test.py | 10 +- .../google_sgd_UserSimulatorTeacher_test.yml | 51 +++ .../google_sgd_UserSimulatorTeacher_train.yml | 49 +++ .../google_sgd_UserSimulatorTeacher_valid.yml | 46 +++ .../tasks/google_sgd/test/google_sgd_test.yml | 108 ++---- .../google_sgd/test/google_sgd_train.yml | 85 ++--- .../google_sgd/test/google_sgd_valid.yml | 101 ++--- 9 files changed, 466 insertions(+), 354 deletions(-) create mode 100644 parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_test.yml create mode 100644 parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_train.yml create mode 100644 parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_valid.yml diff --git a/parlai/tasks/google_sgd/agents.py b/parlai/tasks/google_sgd/agents.py index 12e55a5deff..bfcdf02b3dc 100644 --- a/parlai/tasks/google_sgd/agents.py +++ b/parlai/tasks/google_sgd/agents.py @@ -8,192 +8,228 @@ Google The Schema-Guided Dialogue(SGD) Dataset implementation for ParlAI. """ -import os +import glob import json -from parlai.core.opt import Opt -from parlai.core.teachers import DialogTeacher -from parlai.utils.misc import warn_once -from parlai.core.message import Message -from parlai.core.metrics import AverageMetric, BleuMetric -from parlai.utils.io import PathManager +import os +from typing import Optional import parlai.tasks.google_sgd.build as build_ +import parlai.core.tod.tod_core as tod +import parlai.core.tod.tod_agents as tod_agents +from parlai.core.tod.tod_core import SerializationHelpers +from parlai.core.params import ParlaiParser +from parlai.core.opt import Opt +from parlai.utils.io import PathManager -class Text2API2TextTeacher(DialogTeacher): - """ - Teacher which produces both API calls and NLG responses. - """ +class GoogleSGDParser(tod_agents.TodStructuredDataParser): + @classmethod + def add_cmdline_args( + cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None + ) -> ParlaiParser: + parser = super().add_cmdline_args(parser, partial_opt) + parser.add_argument( + "--delex", type="bool", default=False, help="Delexicalize labels" + ) + parser.add_argument( + "--filter-dialogue-by-id", + default="", + type=str, + help="Path to a json file of `dialogue_id`s for which we will filter from. Assumes it will contain a map where the keys are a fold and the value is a list of ids", + ) + return parser def __init__(self, opt: Opt, shared=None): - self.fold = opt['datatype'].split(':')[0] - opt['datafile'] = self.fold - self.dpath = os.path.join(opt['datapath'], 'google_sgd') + self.fold = self.get_fold(opt) + opt["datafile"] = self.fold + self.dpath = os.path.join(opt["datapath"], "google_sgd") if shared is None: - warn_once( - "Google SGD is a beta dataset, and format may significantly change." 
- ) + # full initialize the teacher as this is not a clone build_.build(opt) super().__init__(opt, shared) + def get_fold(self, opt): + return opt["datatype"].split(":")[0] + def _load_data(self, fold): - dataset_fold = 'dev' if fold == 'valid' else fold + dataset_fold = "dev" if fold == "valid" else fold fold_path = os.path.join(self.dpath, dataset_fold) - schema_file = os.path.join(fold_path, 'schema.json') - with PathManager.open(schema_file, 'r') as f: + schema_file = os.path.join(fold_path, "schema.json") + with PathManager.open(schema_file, "r") as f: schema_lookup = {} for schema in json.load(f): - schema_lookup[schema['service_name']] = schema - - dialogs = [] - for file_id in range(1, build_.fold_size(dataset_fold) + 1): - filename = os.path.join(fold_path, f'dialogues_{file_id:03d}.json') - with PathManager.open(filename, 'r') as f: - dialogs += json.load(f) - return schema_lookup, dialogs - - def _get_api_call_and_results(self, sys_turn, schema_lookup): + schema_lookup[schema["service_name"]] = schema + + dialogues = [] + for filename in glob.glob(f"{fold_path}/dialogues*.json"): + with PathManager.open(filename, "r") as f: + dialogues += json.load(f) + + filter_path = self.opt.get("filter_dialogue_by_id", "") + if len(filter_path) > 0: + filtered = [] + with open(filter_path) as f: + dialogues_to_get = json.load(f)[fold] + for dialogue in dialogues: + if dialogue["dialogue_id"] in dialogues_to_get: + filtered.append(dialogue) + assert len(filtered) == len( + dialogues_to_get + ), f"Different number of dialogues found than requested. Are you sure you've got the right form of Google SGD? Did you filter for dialogue ids correctly? len(filtered) = {len(filtered)}, len(dialogues_to_get) = {len(dialogues_to_get)}" + dialogues = filtered + return schema_lookup, dialogues + + def _get_api_call_and_results(self, sys_turn): api_call = {} api_resp = {} - for frame in sys_turn['frames']: - if 'service_call' in frame: + for frame in sys_turn["frames"]: + if "service_call" in frame: # API CALL - method = frame['service_call']['method'] - for slot_type, slot_value in frame['service_call'][ - 'parameters' + for slot_type, slot_value in frame["service_call"][ + "parameters" ].items(): - api_call[f'{method}.{slot_type}'] = slot_value - assert 'service_results' in frame + if slot_value: + api_call[ + f"{slot_type.strip()}" + ] = SerializationHelpers.inner_list_join(slot_value) + api_call[tod.STANDARD_API_NAME_SLOT] = frame["service_call"]["method"] + assert "service_results" in frame # API Resp - if 'actions' in frame: - for action in frame['actions']: - slot_type = action['slot'] - slot_value = action['canonical_values'] - api_resp[slot_type] = slot_value + if "service_results" in frame: + api_resp = {} + service_results = frame["service_results"] + if len(service_results) > 0: + for key, value in service_results[0].items(): + api_resp[key] = SerializationHelpers.inner_list_join(value) return api_call, api_resp - def custom_evaluation( - self, teacher_action: Message, labels, model_response: Message - ): - resp = model_response.get('text') - if not resp: - return - - if teacher_action['type'] == 'apicall' and resp.startswith('apicall: '): - gold = teacher_action['slots'] - slot_strs = resp[9:].split(' ; ') - parsed = {} - for slot_str in slot_strs: - if ' = ' not in slot_str: - if slot_str != '': - # syntactically invalid generations should count against us - self.metrics.add('slot_p', AverageMetric(0)) - continue - name, value = slot_str.split(' = ') - parsed[name] = value - - # slot precision 
- for k, v in parsed.items(): - self.metrics.add('slot_p', AverageMetric(v == gold.get(k))) - # slot recall - for k, v in gold.items(): - self.metrics.add('slot_r', AverageMetric(v == parsed.get(k))) - elif teacher_action['type'] == 'apiresp': - delex_resp = self._delex(resp, teacher_action['slots']) - delex_label = self._delex(labels[0], teacher_action['slots']) - self.metrics.add( - 'delex_bleu', BleuMetric.compute(delex_resp, [delex_label]) - ) - - def _delex(self, text, slots): - delex = text - for slot, values in slots.items(): - assert isinstance(values, list) - for value in values: - delex = delex.replace(value, slot) - return delex - - def _api_dict_to_str(self, apidict): - return ' ; '.join(f'{k} = {v}' for k, v in apidict.items()) - - def setup_data(self, fold): - schema_lookup, dialogs = self._load_data(fold) - for dialog in dialogs: - # services = dialog['services'] - turns = dialog['turns'] - num_turns = len(turns) - for turn_id in range(0, num_turns, 2): - is_first_turn = turn_id == 0 - + def _get_apis_in_domain(self, schema, domain): + """ + Google SGD includes extra information with the call, so remove these. + """ + result = {} + for intent in schema[domain].get("intents", {}): + here = {} + if "required_slots" in intent and len(intent["required_slots"]) > 0: + here[tod.STANDARD_REQUIRED_KEY] = intent["required_slots"] + if "optional_slots" in intent and len(intent["optional_slots"]) > 0: + here[tod.STANDARD_OPTIONAL_KEY] = intent["optional_slots"] + if "result_slots" in intent: + here["results"] = intent["result_slots"] + result[intent["name"]] = here + return result + + def _get_intent_groundinging(self, schema, domains): + """ + Returns map where keys are intents and values are names of required/optional + slots. + + We do not care about `result_slots` or default values of optional slots. + """ + result = [] + for domain in domains: + apis = self._get_apis_in_domain(schema, domain) + for intent, params in apis.items(): + here = {} + here[tod.STANDARD_API_NAME_SLOT] = intent + if tod.STANDARD_REQUIRED_KEY in params: + here[tod.STANDARD_REQUIRED_KEY] = params[tod.STANDARD_REQUIRED_KEY] + if ( + tod.STANDARD_OPTIONAL_KEY in params + and len(params[tod.STANDARD_OPTIONAL_KEY]) > 0 + ): + here[tod.STANDARD_OPTIONAL_KEY] = params[ + tod.STANDARD_OPTIONAL_KEY + ].keys() + result.append(here) + return result + + def _get_all_service_calls(self, turns): + """ + Searches through all turns in a dialogue for any service calls, returns these. + """ + results = [] + for turn in turns: + for frame in turn["frames"]: + if "service_call" in frame: + call = frame["service_call"] + item = call["parameters"] + item[tod.STANDARD_API_NAME_SLOT] = call["method"] + results.append(item) + return results + + def setup_episodes(self, fold): + """ + Parses Google SGD episodes into TodStructuredEpisode. 
+ """ + schema_lookup, dialogues = self._load_data(fold) + result = [] + for dialogue in dialogues: + domains = {s.split("_")[0].strip() for s in dialogue["services"]} + turns = dialogue["turns"] + rounds = [] + for turn_id in range(0, len(turns), 2): user_turn = turns[turn_id] sys_turn = turns[turn_id + 1] - api_call, api_results = self._get_api_call_and_results( - sys_turn, schema_lookup + api_call, api_results = self._get_api_call_and_results(sys_turn) + r = tod.TodStructuredRound( + user_utt=user_turn["utterance"], + api_call_machine=api_call, + api_resp_machine=api_results, + sys_utt=sys_turn["utterance"], ) - call_str = self._api_dict_to_str(api_call) - resp_str = self._api_dict_to_str(api_results) - if not api_call and not api_results: - # input: user_turn, output: sys_turn - yield { - 'text': user_turn['utterance'], - 'label': sys_turn['utterance'], - 'type': 'text', - }, is_first_turn - elif not api_call and api_results: - yield { - 'text': f"{user_turn['utterance']} api_resp: {resp_str}", - 'label': sys_turn['utterance'], - 'type': 'apiresp', - 'slots': api_results, - }, is_first_turn - elif api_call and api_results: - # input: user_turn, output: api_call - yield { - 'text': user_turn['utterance'], - 'label': f'apicall: {call_str}', - 'type': 'apicall', - 'slots': api_call, - }, is_first_turn - - # system turn, input : api results, output : assistant turn - yield { - 'text': f"api_resp: {resp_str}", - 'label': sys_turn['utterance'], - 'type': 'apiresp', - 'slots': api_results, - }, False - else: - assert ( - api_call and api_results - ), "API call without API results! Check Dataset!" - - -class Text2TextTeacher(Text2API2TextTeacher): - """ - Text-only teacher (with no API calls or slots) - """ - - def setup_data(self, fold): - schema_lookup, dialogs = self._load_data(fold) - for dialog in dialogs: - turns = dialog['turns'] - num_turns = len(turns) - for turn_id in range(0, num_turns, 2): - if turn_id == 0: - is_first_turn = True - else: - is_first_turn = False + rounds.append(r) + # Now that we've got the rounds, make the episode + episode = tod.TodStructuredEpisode( + domain=SerializationHelpers.inner_list_join(domains), + api_schemas_machine=self._get_intent_groundinging( + schema_lookup, set(dialogue["services"]) + ), + goal_calls_machine=self._get_all_service_calls(turns), + rounds=rounds, + delex=self.opt.get("delex"), + extras={"dialogue_id": dialogue["dialogue_id"]}, + ) + result.append(episode) + # check if the number of episodes should be limited and truncate as required + return result - user_turn = turns[turn_id] - sys_turn = turns[turn_id + 1] - # input: user_turn, output: sys_turn - yield { - 'text': user_turn['utterance'], - 'label': sys_turn['utterance'], - 'type': 'text', - }, is_first_turn + def get_id_task_prefix(self): + return "GoogleSGD" + + +class SystemTeacher(GoogleSGDParser, tod_agents.TodSystemTeacher): + pass + + +class DefaultTeacher(SystemTeacher): + pass + + +class UserSimulatorTeacher(GoogleSGDParser, tod_agents.TodUserSimulatorTeacher): + pass + + +class StandaloneApiTeacher(GoogleSGDParser, tod_agents.TodStandaloneApiTeacher): + pass + + +class SingleGoalAgent(GoogleSGDParser, tod_agents.TodSingleGoalAgent): + pass + + +class GoalAgent(GoogleSGDParser, tod_agents.TodGoalAgent): + pass + + +class ApiSchemaAgent(GoogleSGDParser, tod_agents.TodApiSchemaAgent): + pass + + +class UserUttAgent(GoogleSGDParser, tod_agents.TodUserUttAgent): + pass -class DefaultTeacher(Text2API2TextTeacher): +class ApiCallAndSysUttAgent(GoogleSGDParser, 
tod_agents.TodApiCallAndSysUttAgent): pass diff --git a/parlai/tasks/google_sgd/build.py b/parlai/tasks/google_sgd/build.py index d0432ed1241..d0bf47adce2 100644 --- a/parlai/tasks/google_sgd/build.py +++ b/parlai/tasks/google_sgd/build.py @@ -8,7 +8,7 @@ import parlai.core.build_data as build_data import os -ROOT_URL = 'https://github.com/google-research-datasets/dstc8-schema-guided-dialogue/raw/master' +ROOT_URL = "https://github.com/google-research-datasets/dstc8-schema-guided-dialogue/raw/master" DATA_LEN = {"train": 127, "dev": 20, "test": 34} @@ -18,13 +18,13 @@ def fold_size(fold): def build(opt): # get path to data directory - dpath = os.path.join(opt['datapath'], 'google_sgd') + dpath = os.path.join(opt["datapath"], "google_sgd") # define version if any version = "1.0" # check if data had been previously built if not build_data.built(dpath, version_string=version): - print('[building data: ' + dpath + ']') + print("[building data: " + dpath + "]") # make a clean directory if needed if build_data.built(dpath): @@ -33,16 +33,16 @@ def build(opt): build_data.make_dir(dpath) # Download the data. - for split_type in ['train', 'dev', 'test']: + for split_type in ["train", "dev", "test"]: outpath = os.path.join(dpath, split_type) - filename = 'schema.json' - url = f'{ROOT_URL}/{split_type}/{filename}' + filename = "schema.json" + url = f"{ROOT_URL}/{split_type}/{filename}" build_data.make_dir(outpath) build_data.download(url, outpath, filename) for file_id in range(1, DATA_LEN[split_type] + 1): - filename = f'dialogues_{file_id:03d}.json' - url = f'{ROOT_URL}/{split_type}/{filename}' + filename = f"dialogues_{file_id:03d}.json" + url = f"{ROOT_URL}/{split_type}/{filename}" build_data.download(url, outpath, filename) # mark the data as built diff --git a/parlai/tasks/google_sgd/test.py b/parlai/tasks/google_sgd/test.py index e348c087c61..5ea8078f74e 100644 --- a/parlai/tasks/google_sgd/test.py +++ b/parlai/tasks/google_sgd/test.py @@ -7,13 +7,9 @@ from parlai.utils.testing import AutoTeacherTest -class DisabledTestDefaultTeacher(AutoTeacherTest): +class TestDefaultTeacher(AutoTeacherTest): task = "google_sgd" -class DisabledTestText2API2TextTeacher(AutoTeacherTest): - task = "google_sgd:text2_a_p_i2_text" - - -class DisabledTestText2TextTeacher(AutoTeacherTest): - task = "google_sgd:text2_text" +class TestUserSimulatorTeacher(AutoTeacherTest): + task = "google_sgd:UserSimulatorTeacher" diff --git a/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_test.yml b/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_test.yml new file mode 100644 index 00000000000..dad04075246 --- /dev/null +++ b/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_test.yml @@ -0,0 +1,51 @@ +acts: +- - domain: Restaurants + episode_done: false + eval_labels: + - 'USER: Hi, could you get me a restaurant booking on the 8th please?' + id: GoogleSGD_UserSimulatorTeacher + text: 'GOAL: api_name = ReserveRestaurant ; date = 2019-03-08 ; location = Corte + Madera ; number_of_seats = 2 ; restaurant_name = P.f. Chang''s ; time = 12:00 + | api_name = ReserveRestaurant ; date = 2019-03-08 ; location = Corte Madera + ; number_of_seats = 2 ; restaurant_name = Benissimo Restaurant & Bar ; time + = 12:00' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'USER: Could you get me a reservation at P.f. Chang''s in Corte Madera at afternoon + 12?' + id: GoogleSGD_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Any preference on the restaurant, location and time?' 
+ type: 'USER: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'USER: Sure, that is great.' + id: GoogleSGD_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Please confirm your reservation at P.f. Chang''s in Corte Madera + at 12 pm for 2 on March 8th.' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'USER: Could you try booking a table at Benissimo instead?' + id: GoogleSGD_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Sorry, your reservation could not be made. Could I help you with + something else?' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'USER: Sure, may I know if they have vegetarian options and how expensive is + their food?' + id: GoogleSGD_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Sure, please confirm your reservation at Benissimo Restaurant & + Bar in Corte Madera at 12 pm for 2 on March 8th.' + type: 'USER: ' +num_episodes: 4201 +num_examples: 97197 diff --git a/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_train.yml b/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_train.yml new file mode 100644 index 00000000000..fca91a3b09f --- /dev/null +++ b/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_train.yml @@ -0,0 +1,49 @@ +acts: +- - domain: Restaurants + episode_done: false + id: GoogleSGD_UserSimulatorTeacher + labels: + - 'USER: I am feeling hungry so I would like to find a place to eat.' + text: 'GOAL: api_name = FindRestaurants ; city = San Jose ; cuisine = American + | api_name = FindRestaurants ; city = Palo Alto ; cuisine = American ; price_range + = moderate | api_name = ReserveRestaurant ; city = Palo Alto ; date = 2019-03-01 + ; party_size = 2 ; restaurant_name = Bird Dog ; time = 11:30' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + id: GoogleSGD_UserSimulatorTeacher + labels: + - 'USER: I would like for it to be in San Jose.' + slots: {} + text: 'SYSTEM: Do you have a specific which you want the eating place to be located + at?' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + id: GoogleSGD_UserSimulatorTeacher + labels: + - 'USER: I usually like eating the American type of food.' + slots: {} + text: 'SYSTEM: Is there a specific cuisine type you enjoy, such as Mexican, Italian + or something else?' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + id: GoogleSGD_UserSimulatorTeacher + labels: + - 'USER: Can you give me the address of this restaurant.' + slots: {} + text: 'SYSTEM: I see that at 71 Saint Peter there is a good restaurant which is + in San Jose.' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + id: GoogleSGD_UserSimulatorTeacher + labels: + - 'USER: Can you give me the phone number that I can contact them with?' + slots: {} + text: 'SYSTEM: If you want to go to this restaurant you can find it at 71 North + San Pedro Street.' + type: 'USER: ' +num_episodes: 16142 +num_examples: 378390 diff --git a/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_valid.yml b/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_valid.yml new file mode 100644 index 00000000000..e9a511d9519 --- /dev/null +++ b/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_valid.yml @@ -0,0 +1,46 @@ +acts: +- - domain: Restaurants + episode_done: false + eval_labels: + - 'USER: I want to make a restaurant reservation for 2 people at half past 11 + in the morning.' 
+ id: GoogleSGD_UserSimulatorTeacher + text: 'GOAL: api_name = ReserveRestaurant ; date = 2019-03-01 ; location = San + Jose ; number_of_seats = 2 ; restaurant_name = Sino ; time = 11:30' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'USER: Please find restaurants in San Jose. Can you try Sino?' + id: GoogleSGD_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: What city do you want to dine in? Do you have a preferred restaurant?' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'USER: Yes, thanks. What''s their phone number?' + id: GoogleSGD_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Confirming: I will reserve a table for 2 people at Sino in San + Jose. The reservation time is 11:30 am today.' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'USER: What''s their address? Do they have vegetarian options on their menu?' + id: GoogleSGD_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Your reservation has been made. Their phone number is 408-247-8880.' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'USER: Thanks very much.' + id: GoogleSGD_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: The street address is 377 Santana Row #1000. They have good vegetarian + options.' + type: 'USER: ' +num_episodes: 2482 +num_examples: 56172 diff --git a/parlai/tasks/google_sgd/test/google_sgd_test.yml b/parlai/tasks/google_sgd/test/google_sgd_test.yml index 1d66bb9f0bb..4d8e4e2578e 100644 --- a/parlai/tasks/google_sgd/test/google_sgd_test.yml +++ b/parlai/tasks/google_sgd/test/google_sgd_test.yml @@ -1,77 +1,45 @@ acts: -- - episode_done: false +- - domain: Restaurants + episode_done: false eval_labels: - - Any preference on the restaurant, location and time? - id: google_sgd - slots: - location: [] - restaurant_name: [] - time: [] - text: 'Hi, could you get me a restaurant booking on the 8th please? api_resp: - time = [] ; restaurant_name = [] ; location = []' - type: apiresp -- - episode_done: false + - 'APIS: ' + id: GoogleSGD_SystemTeacher + slots: {} + text: 'APIS: ' + type: 'APIS: ' +- - domain: Restaurants + episode_done: false eval_labels: - - Please confirm your reservation at P.f. Chang's in Corte Madera at 12 pm for - 2 on March 8th. - id: google_sgd - slots: - date: - - '2019-03-08' - location: - - Corte Madera - number_of_seats: - - '2' - restaurant_name: - - P.f. Chang's - time: - - '12:00' - text: 'Could you get me a reservation at P.f. Chang''s in Corte Madera at afternoon - 12? api_resp: restaurant_name = ["P.f. Chang''s"] ; location = [''Corte Madera''] - ; time = [''12:00''] ; date = [''2019-03-08''] ; number_of_seats = [''2'']' - type: apiresp -- - episode_done: false + - 'APICALL: ' + id: GoogleSGD_SystemTeacher + slots: {} + text: 'USER: Hi, could you get me a restaurant booking on the 8th please?' + type: 'APICALL: ' +- - domain: Restaurants + episode_done: false eval_labels: - - 'apicall: ReserveRestaurant.date = 2019-03-08 ; ReserveRestaurant.location = - Corte Madera ; ReserveRestaurant.number_of_seats = 2 ; ReserveRestaurant.restaurant_name - = P.f. Chang''s ; ReserveRestaurant.time = 12:00' - id: google_sgd - slots: - ReserveRestaurant.date: '2019-03-08' - ReserveRestaurant.location: Corte Madera - ReserveRestaurant.number_of_seats: '2' - ReserveRestaurant.restaurant_name: P.f. Chang's - ReserveRestaurant.time: '12:00' - text: Sure, that is great. 
- type: apicall -- - episode_done: false + - 'SYSTEM: Any preference on the restaurant, location and time?' + id: GoogleSGD_SystemTeacher + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +- - domain: Restaurants + episode_done: false eval_labels: - - Sorry, your reservation could not be made. Could I help you with something else? - id: google_sgd - slots: - ? '' - : [] - text: 'api_resp: = []' - type: apiresp -- - episode_done: false + - 'APICALL: ' + id: GoogleSGD_SystemTeacher + slots: {} + text: 'USER: Could you get me a reservation at P.f. Chang''s in Corte Madera at + afternoon 12?' + type: 'APICALL: ' +- - domain: Restaurants + episode_done: false eval_labels: - - Sure, please confirm your reservation at Benissimo Restaurant & Bar in Corte - Madera at 12 pm for 2 on March 8th. - id: google_sgd - slots: - date: - - '2019-03-08' - location: - - Corte Madera - number_of_seats: - - '2' - restaurant_name: - - Benissimo Restaurant & Bar - time: - - '12:00' - text: 'Could you try booking a table at Benissimo instead? api_resp: restaurant_name - = [''Benissimo Restaurant & Bar''] ; location = [''Corte Madera''] ; time = - [''12:00''] ; date = [''2019-03-08''] ; number_of_seats = [''2'']' - type: apiresp + - 'SYSTEM: Please confirm your reservation at P.f. Chang''s in Corte Madera at + 12 pm for 2 on March 8th.' + id: GoogleSGD_SystemTeacher + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' num_episodes: 4201 -num_examples: 54970 +num_examples: 97197 diff --git a/parlai/tasks/google_sgd/test/google_sgd_train.yml b/parlai/tasks/google_sgd/test/google_sgd_train.yml index cd2b3a48e30..596025f82da 100644 --- a/parlai/tasks/google_sgd/test/google_sgd_train.yml +++ b/parlai/tasks/google_sgd/test/google_sgd_train.yml @@ -1,54 +1,45 @@ acts: -- - episode_done: false - id: google_sgd +- - domain: Restaurants + episode_done: false + id: GoogleSGD_SystemTeacher labels: - - Do you have a specific which you want the eating place to be located at? - slots: - city: [] - text: 'I am feeling hungry so I would like to find a place to eat. api_resp: city - = []' - type: apiresp -- - episode_done: false - id: google_sgd + - 'APIS: ' + slots: {} + text: 'APIS: ' + type: 'APIS: ' +- - domain: Restaurants + episode_done: false + id: GoogleSGD_SystemTeacher labels: - - Is there a specific cuisine type you enjoy, such as Mexican, Italian or something - else? - slots: - cuisine: - - Mexican - - Italian - text: 'I would like for it to be in San Jose. api_resp: cuisine = [''Mexican'', - ''Italian'']' - type: apiresp -- - episode_done: false - id: google_sgd + - 'APICALL: ' + slots: {} + text: 'USER: I am feeling hungry so I would like to find a place to eat.' + type: 'APICALL: ' +- - domain: Restaurants + episode_done: false + id: GoogleSGD_SystemTeacher labels: - - 'apicall: FindRestaurants.city = San Jose ; FindRestaurants.cuisine = American' - slots: - FindRestaurants.city: San Jose - FindRestaurants.cuisine: American - text: I usually like eating the American type of food. - type: apicall -- - episode_done: false - id: google_sgd + - 'SYSTEM: Do you have a specific which you want the eating place to be located + at?' + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +- - domain: Restaurants + episode_done: false + id: GoogleSGD_SystemTeacher labels: - - I see that at 71 Saint Peter there is a good restaurant which is in San Jose. 
- slots: - city: - - San Jose - restaurant_name: - - 71 Saint Peter - text: 'api_resp: restaurant_name = [''71 Saint Peter''] ; city = [''San Jose'']' - type: apiresp -- - episode_done: false - id: google_sgd + - 'APICALL: ' + slots: {} + text: 'USER: I would like for it to be in San Jose.' + type: 'APICALL: ' +- - domain: Restaurants + episode_done: false + id: GoogleSGD_SystemTeacher labels: - - If you want to go to this restaurant you can find it at 71 North San Pedro Street. - slots: - street_address: - - 71 North San Pedro Street - text: 'Can you give me the address of this restaurant. api_resp: street_address - = [''71 North San Pedro Street'']' - type: apiresp + - 'SYSTEM: Is there a specific cuisine type you enjoy, such as Mexican, Italian + or something else?' + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' num_episodes: 16142 -num_examples: 215128 +num_examples: 378390 diff --git a/parlai/tasks/google_sgd/test/google_sgd_valid.yml b/parlai/tasks/google_sgd/test/google_sgd_valid.yml index 8c95cc85c06..2cd09f89ce7 100644 --- a/parlai/tasks/google_sgd/test/google_sgd_valid.yml +++ b/parlai/tasks/google_sgd/test/google_sgd_valid.yml @@ -1,70 +1,45 @@ acts: -- - episode_done: false +- - domain: Restaurants + episode_done: false eval_labels: - - What city do you want to dine in? Do you have a preferred restaurant? - id: google_sgd - slots: - location: [] - restaurant_name: [] - text: 'I want to make a restaurant reservation for 2 people at half past 11 in - the morning. api_resp: restaurant_name = [] ; location = []' - type: apiresp -- - episode_done: false + - 'APIS: ' + id: GoogleSGD_SystemTeacher + slots: {} + text: 'APIS: ' + type: 'APIS: ' +- - domain: Restaurants + episode_done: false eval_labels: - - 'Confirming: I will reserve a table for 2 people at Sino in San Jose. The reservation - time is 11:30 am today.' - id: google_sgd - slots: - date: - - '2019-03-01' - location: - - San Jose - number_of_seats: - - '2' - restaurant_name: - - Sino - time: - - '11:30' - text: 'Please find restaurants in San Jose. Can you try Sino? api_resp: restaurant_name - = [''Sino''] ; location = [''San Jose''] ; time = [''11:30''] ; number_of_seats - = [''2''] ; date = [''2019-03-01'']' - type: apiresp -- - episode_done: false + - 'APICALL: ' + id: GoogleSGD_SystemTeacher + slots: {} + text: 'USER: I want to make a restaurant reservation for 2 people at half past + 11 in the morning.' + type: 'APICALL: ' +- - domain: Restaurants + episode_done: false eval_labels: - - 'apicall: ReserveRestaurant.date = 2019-03-01 ; ReserveRestaurant.location = - San Jose ; ReserveRestaurant.number_of_seats = 2 ; ReserveRestaurant.restaurant_name - = Sino ; ReserveRestaurant.time = 11:30' - id: google_sgd - slots: - ReserveRestaurant.date: '2019-03-01' - ReserveRestaurant.location: San Jose - ReserveRestaurant.number_of_seats: '2' - ReserveRestaurant.restaurant_name: Sino - ReserveRestaurant.time: '11:30' - text: Yes, thanks. What's their phone number? - type: apicall -- - episode_done: false + - 'SYSTEM: What city do you want to dine in? Do you have a preferred restaurant?' + id: GoogleSGD_SystemTeacher + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +- - domain: Restaurants + episode_done: false eval_labels: - - Your reservation has been made. Their phone number is 408-247-8880. - id: google_sgd - slots: - ? 
'' - : [] - phone_number: - - 408-247-8880 - text: 'api_resp: phone_number = [''408-247-8880''] ; = []' - type: apiresp -- - episode_done: false + - 'APICALL: ' + id: GoogleSGD_SystemTeacher + slots: {} + text: 'USER: Please find restaurants in San Jose. Can you try Sino?' + type: 'APICALL: ' +- - domain: Restaurants + episode_done: false eval_labels: - - 'The street address is 377 Santana Row #1000. They have good vegetarian options.' - id: google_sgd - slots: - address: - - '377 Santana Row #1000' - has_vegetarian_options: - - 'True' - text: 'What''s their address? Do they have vegetarian options on their menu? api_resp: - has_vegetarian_options = [''True''] ; address = [''377 Santana Row #1000'']' - type: apiresp + - 'SYSTEM: Confirming: I will reserve a table for 2 people at Sino in San Jose. + The reservation time is 11:30 am today.' + id: GoogleSGD_SystemTeacher + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' num_episodes: 2482 -num_examples: 31825 +num_examples: 56172 From faa2356e817094339634f6f2e409399316814661 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 12:45:06 -0800 Subject: [PATCH 13/57] [TOD][Dataset][Easyish] Google Simulation Splits Code for processing Google SGD into In-domain and Out-domain data via `build.py`, using via agents. Datasets added in this substack: * Google SGD * **Google SGD Simulation Splits (In-domain, Out-domain)** * MetalWoz * MSR_E2E * Multidogo * MultiWoz V2.2 * Taskmaster * Taskmaster2 * Taskmaster3 (TicketTalk) Test plan: Regression test, `parlai dd` of dataset --- .../google_sgd_simulation_splits/README.md | 8 + .../google_sgd_simulation_splits/__init__.py | 5 + .../google_sgd_simulation_splits/agents.py | 221 ++++++++++++++++++ .../google_sgd_simulation_splits/build.py | 138 +++++++++++ .../google_sgd_simulation_splits/test.py | 23 ++ ...tion_splits_InDomainSystemTeacher_test.yml | 45 ++++ ...ion_splits_InDomainSystemTeacher_train.yml | 45 ++++ ...ion_splits_InDomainSystemTeacher_valid.yml | 45 ++++ ...lits_InDomainUserSimulatorTeacher_test.yml | 51 ++++ ...its_InDomainUserSimulatorTeacher_train.yml | 49 ++++ ...its_InDomainUserSimulatorTeacher_valid.yml | 46 ++++ ...ion_splits_OutDomainSystemTeacher_test.yml | 54 +++++ ...on_splits_OutDomainSystemTeacher_train.yml | 63 +++++ ...on_splits_OutDomainSystemTeacher_valid.yml | 43 ++++ ...its_OutDomainUserSimulatorTeacher_test.yml | 52 +++++ ...ts_OutDomainUserSimulatorTeacher_train.yml | 49 ++++ ...ts_OutDomainUserSimulatorTeacher_valid.yml | 47 ++++ 17 files changed, 984 insertions(+) create mode 100644 parlai/tasks/google_sgd_simulation_splits/README.md create mode 100644 parlai/tasks/google_sgd_simulation_splits/__init__.py create mode 100644 parlai/tasks/google_sgd_simulation_splits/agents.py create mode 100644 parlai/tasks/google_sgd_simulation_splits/build.py create mode 100644 parlai/tasks/google_sgd_simulation_splits/test.py create mode 100644 parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_test.yml create mode 100644 parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_train.yml create mode 100644 parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_valid.yml create mode 100644 parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainUserSimulatorTeacher_test.yml create mode 100644 parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainUserSimulatorTeacher_train.yml create mode 
100644 parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainUserSimulatorTeacher_valid.yml create mode 100644 parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_test.yml create mode 100644 parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_train.yml create mode 100644 parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_valid.yml create mode 100644 parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainUserSimulatorTeacher_test.yml create mode 100644 parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainUserSimulatorTeacher_train.yml create mode 100644 parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainUserSimulatorTeacher_valid.yml diff --git a/parlai/tasks/google_sgd_simulation_splits/README.md b/parlai/tasks/google_sgd_simulation_splits/README.md new file mode 100644 index 00000000000..d911207640f --- /dev/null +++ b/parlai/tasks/google_sgd_simulation_splits/README.md @@ -0,0 +1,8 @@ +# The Schema-Guided Dialogue Dataset + +Originally from the +[Google Research +Datasets](https://github.com/google-research-datasets/dstc8-schema-guided-dialogue/blob/master/README.md). +See that page for details. + +This has two custom splits: one where Messaging, Payments, Home Search, and Rental Cars have also been fully extracted out of Google SGD (but is otherwise the same) and one that includes these extracted domains. diff --git a/parlai/tasks/google_sgd_simulation_splits/__init__.py b/parlai/tasks/google_sgd_simulation_splits/__init__.py new file mode 100644 index 00000000000..240697e3247 --- /dev/null +++ b/parlai/tasks/google_sgd_simulation_splits/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/parlai/tasks/google_sgd_simulation_splits/agents.py b/parlai/tasks/google_sgd_simulation_splits/agents.py new file mode 100644 index 00000000000..459dc5e81c1 --- /dev/null +++ b/parlai/tasks/google_sgd_simulation_splits/agents.py @@ -0,0 +1,221 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Google The Schema-Guided Dialogue(SGD) Dataset implementation for ParlAI. +""" + +import os + +import parlai.tasks.google_sgd_simulation_splits.build as build_ +import parlai.core.tod.tod_core as tod +from parlai.core.metrics import AverageMetric +from parlai.core.message import Message +from parlai.core.params import ParlaiParser +import parlai.core.tod.tod_agents as tod_agents +from parlai.core.opt import Opt +from parlai.tasks.google_sgd.agents import GoogleSGDParser + +from typing import List, Optional + + +class GoogleSgdInDomainParser(GoogleSGDParser): + """ + Overrides `__init__` and `_load_data` so that we grab our examples from our + separately constructed custom splits. 
+ """ + + def __init__(self, opt: Opt, shared=None): + if shared is None: + # full initialize the teacher as this is not a clone + build_.build(opt) + super().__init__(opt, shared) + + def _load_data(self, fold): + # Best to override here because of init order + self.dpath = os.path.join(self.opt["datapath"], "google_sgd_rl_splits") + return super()._load_data(fold) + + def get_id_task_prefix(self): + return "GoogleSgdInDomain" + + +class InDomainSystemTeacher(GoogleSgdInDomainParser, tod_agents.TodSystemTeacher): + pass + + +class InDomainUserSimulatorTeacher( + GoogleSgdInDomainParser, tod_agents.TodUserSimulatorTeacher +): + pass + + +class InDomainSingleGoalAgent(GoogleSgdInDomainParser, tod_agents.TodSingleGoalAgent): + pass + + +class InDomainSingleApiSchemaAgent( + GoogleSgdInDomainParser, tod_agents.TodSingleApiSchemaAgent +): + pass + + +class InDomainGoalAgent(GoogleSgdInDomainParser, tod_agents.TodGoalAgent): + pass + + +class InDomainApiSchemaAgent(GoogleSgdInDomainParser, tod_agents.TodApiSchemaAgent): + pass + + +class InDomainUserUttAgent(GoogleSgdInDomainParser, tod_agents.TodUserUttAgent): + pass + + +class InDomainApiCallAndSysUttAgent( + GoogleSgdInDomainParser, tod_agents.TodApiCallAndSysUttAgent +): + pass + + +VALID_OUT_DOMAIN_API_NAMES = [ + "ShareLocation", + "RequestPayment", + "MakePayment", + "FindApartment", + "ScheduleVisit", + "FindHomeByArea", + "GetCarsAvailable", + "ReserveCar", +] + + +class GoogleSgdOutDomainParser(GoogleSGDParser): + """ + Overrides `__init__` and `_load_data` so that we grab our examples from our + separately constructed custom splits. + """ + + @classmethod + def add_cmdline_args( + cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None + ) -> ParlaiParser: + super().add_cmdline_args(parser, partial_opt) + group = parser.add_argument_group("Google SGD Out Domain Parser") + group.add_argument( + "--filter-single-goal-episodes", + type=bool, + default=False, + help="Filter for only conversations where the original has single goals", + ) + + def __init__(self, opt: Opt, shared=None): + if shared is None: + # full initialize the teacher as this is not a clone + build_.build(opt) + super().__init__(opt, shared) + + def _load_data(self, fold): + # Best to override here because of init order + self.dpath = os.path.join( + self.opt["datapath"], "google_sgd_rl_splits/model_model_splits" + ) + return super()._load_data(fold) + + def get_id_task_prefix(self): + return "GoogleSgdOutDomain" + + def filter_goals(self, goals): + """ + Used in single goal/api schema agents only. 
+ """ + result = [] + for goal in goals: + if goal["api_name"] in VALID_OUT_DOMAIN_API_NAMES: + result.append(goal) + return result + + def generate_episodes(self) -> List[tod.TodStructuredEpisode]: + data = super().generate_episodes() + if self.opt.get("filter_single_goal_episodes"): + result = [] + for episode in data: + if len(episode.goal_calls_machine) == 1: + result.append(episode) + data = result + return data + + def custom_evaluation( + self, teacher_action: Message, labels, model_response: Message + ): + super().custom_evaluation(teacher_action, labels, model_response) + resp = model_response.get("text") + if not resp: + return + + if ( + teacher_action["type"] == tod.STANDARD_CALL + and tod.STANDARD_API_NAME_SLOT in teacher_action["slots"] + and teacher_action["slots"][tod.STANDARD_API_NAME_SLOT] + in VALID_OUT_DOMAIN_API_NAMES + ): + if resp.startswith(tod.STANDARD_CALL): + resp = resp[len(tod.STANDARD_CALL) :] + predicted = tod.SerializationHelpers.str_to_api_dict(resp) + self.metrics.add( + f"OutDomainOnlyApis/jga", + AverageMetric(teacher_action["slots"] == predicted), + ) + + +class OutDomainStandaloneApiTeacher( + GoogleSgdOutDomainParser, tod_agents.TodStandaloneApiTeacher +): + pass + + +class OutDomainSystemTeacher(GoogleSgdOutDomainParser, tod_agents.TodSystemTeacher): + pass + + +class OutDomainUserSimulatorTeacher( + GoogleSgdOutDomainParser, tod_agents.TodUserSimulatorTeacher +): + pass + + +class OutDomainGoalAgent(GoogleSgdOutDomainParser, tod_agents.TodGoalAgent): + pass + + +class OutDomainApiSchemaAgent(GoogleSgdOutDomainParser, tod_agents.TodApiSchemaAgent): + pass + + +class OutDomainSingleGoalAgent(GoogleSgdOutDomainParser, tod_agents.TodSingleGoalAgent): + pass + + +class OutDomainSingleApiSchemaAgent( + GoogleSgdOutDomainParser, tod_agents.TodSingleApiSchemaAgent +): + pass + + +class OutDomainUserUttAgent(GoogleSgdOutDomainParser, tod_agents.TodUserUttAgent): + pass + + +class OutDomainApiCallAndSysUttAgent( + GoogleSgdOutDomainParser, tod_agents.TodApiCallAndSysUttAgent +): + pass + + +class OutDomainApiResponseAgent( + GoogleSgdOutDomainParser, tod_agents.TodApiResponseAgent +): + pass diff --git a/parlai/tasks/google_sgd_simulation_splits/build.py b/parlai/tasks/google_sgd_simulation_splits/build.py new file mode 100644 index 00000000000..a62fa25086e --- /dev/null +++ b/parlai/tasks/google_sgd_simulation_splits/build.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import parlai.tasks.google_sgd.build as original_google_sgd_build +import parlai.core.build_data as build_data + +import os +import json +import random +from shutil import copyfile + +DATA_LEN = original_google_sgd_build.DATA_LEN + +MODEL_MODEL_HOLDOUT_DOMAINS = [ + "Homes_1", + "Homes_2", + "FindHomeByArea", + "RentalCars_1", + "RentalCars_2", + "RentalCars_3", + "Messaging_1", + "Payment_1", +] + + +def build(opt): + # get path to data directory + dpath = os.path.join(opt["datapath"], "google_sgd_rl_splits") + # define version if any + version = "1.0" + + # check if data had been previously built + if not build_data.built(dpath, version_string=version): + print("[building data: " + dpath + "]") + + # Grab things from the original Google SGD + original_google_sgd_build.build(opt) + + # make a clean directory if needed + if build_data.built(dpath): + # an older version exists, so remove these outdated files. 
+ build_data.remove_dir(dpath) + build_data.make_dir(dpath) + + model_model_convos = [] + model_model_schemas = {} + tot_count = 0 + for split_type in ["train", "dev", "test"]: + outpath = os.path.join(dpath, split_type) + os.makedirs(outpath, exist_ok=True) + original_path = os.path.join(opt["datapath"], "google_sgd", split_type) + + copyfile( + os.path.join(original_path, "schema.json"), + os.path.join(outpath, "schema.json"), + ) + with open(os.path.join(original_path, "schema.json")) as f: + schemas = json.load(f) + for schema in schemas: + model_model_schemas[schema["service_name"]] = schema + + for file_id in range(1, DATA_LEN[split_type] + 1): + filename = f"dialogues_{file_id:03d}.json" + original_file = os.path.join(original_path, filename) + with open(original_file) as f: + blobs = json.load(f) + save_data = [] + for blob in blobs: + blob[ + "dialogue_id" + ] = f"{blob['dialogue_id']}_from_{split_type}_{file_id:03d}" + in_model_model = False + for service in blob["services"]: + if service in MODEL_MODEL_HOLDOUT_DOMAINS: + in_model_model = True + tot_count += 1 + if in_model_model: + model_model_convos.append(blob) + else: + save_data.append(blob) + with open(os.path.join(outpath, filename), "w+") as f: + json.dump(save_data, f, indent=4) + print(split_type, filename) + + print("done processing train + dev + test ") + # deal with custom splits + print( + "number of samples in homes + rental cars + messaging + payments", + len(model_model_convos), + ) + print("service usage count of above domains", tot_count) + model_model_path = os.path.join(dpath, "model_model_splits") + os.makedirs(model_model_path, exist_ok=True) + random.Random(42).shuffle(model_model_convos) + + def save_model_model(convos, split_type, model_model_path, schema): + os.makedirs(os.path.join(model_model_path, split_type), exist_ok=True) + for i in range(int(len(convos) / 64) + 1): + with open( + os.path.join( + model_model_path, split_type, f"dialogues_{i:03d}.json" + ), + "w+", + ) as f: + json.dump(convos[i * 64 : (i + 1) * 64], f, indent=4) + with open( + os.path.join(model_model_path, split_type, "schema.json"), "w+" + ) as f: + json.dump(list(schema.values()), f, indent=4) + + save_model_model( + model_model_convos[: int(0.6 * len(model_model_convos))], + "train", + model_model_path, + model_model_schemas, + ) + save_model_model( + model_model_convos[ + int(0.6 * len(model_model_convos)) : int(0.8 * len(model_model_convos)) + ], + "dev", + model_model_path, + model_model_schemas, + ) + save_model_model( + model_model_convos[int(0.8 * len(model_model_convos)) :], + "test", + model_model_path, + model_model_schemas, + ) + + print("done processing test") + + # mark the data as built + build_data.mark_done(dpath, version_string=version) diff --git a/parlai/tasks/google_sgd_simulation_splits/test.py b/parlai/tasks/google_sgd_simulation_splits/test.py new file mode 100644 index 00000000000..b98d338bce8 --- /dev/null +++ b/parlai/tasks/google_sgd_simulation_splits/test.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
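# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch itself): the test file continued
# below registers the new teachers via AutoTeacherTest. Assuming those task
# names, the same splits can also be spot-checked by hand with ParlAI's
# display_data script, which is the `parlai dd` step named in the commit's
# test plan. The datatype and example counts below are arbitrary choices.
from parlai.scripts.display_data import DisplayData

if __name__ == "__main__":
    # In-domain split: Google SGD with the held-out simulation domains removed.
    DisplayData.main(
        task="google_sgd_simulation_splits:InDomainSystemTeacher",
        datatype="valid",
        num_examples=6,
    )
    # Out-domain split: only the held-out Homes / RentalCars / Messaging /
    # Payment dialogues (see MODEL_MODEL_HOLDOUT_DOMAINS in build.py above).
    DisplayData.main(
        task="google_sgd_simulation_splits:OutDomainUserSimulatorTeacher",
        datatype="valid",
        num_examples=6,
    )
# ---------------------------------------------------------------------------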
+ +from parlai.utils.testing import AutoTeacherTest + + +class TestInDomainSystemTeacher(AutoTeacherTest): + task = "google_sgd_simulation_splits:InDomainSystemTeacher" + + +class TestInDomainUserSimulatorTeacher(AutoTeacherTest): + task = "google_sgd_simulation_splits:InDomainUserSimulatorTeacher" + + +class TestOutDomainSystemTeacher(AutoTeacherTest): + task = "google_sgd_simulation_splits:OutDomainSystemTeacher" + + +class TestOutDomainUserSimulatorTeacher(AutoTeacherTest): + task = "google_sgd_simulation_splits:OutDomainUserSimulatorTeacher" diff --git a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_test.yml b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_test.yml new file mode 100644 index 00000000000..d61223204be --- /dev/null +++ b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_test.yml @@ -0,0 +1,45 @@ +acts: +- - domain: Restaurants + episode_done: false + eval_labels: + - 'APIS: ' + id: GoogleSgdInDomain_SystemTeacher + slots: {} + text: 'APIS: ' + type: 'APIS: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'APICALL: ' + id: GoogleSgdInDomain_SystemTeacher + slots: {} + text: 'USER: Hi, could you get me a restaurant booking on the 8th please?' + type: 'APICALL: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'SYSTEM: Any preference on the restaurant, location and time?' + id: GoogleSgdInDomain_SystemTeacher + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'APICALL: ' + id: GoogleSgdInDomain_SystemTeacher + slots: {} + text: 'USER: Could you get me a reservation at P.f. Chang''s in Corte Madera at + afternoon 12?' + type: 'APICALL: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'SYSTEM: Please confirm your reservation at P.f. Chang''s in Corte Madera at + 12 pm for 2 on March 8th.' + id: GoogleSgdInDomain_SystemTeacher + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +num_episodes: 3132 +num_examples: 67286 diff --git a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_train.yml b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_train.yml new file mode 100644 index 00000000000..08800de222d --- /dev/null +++ b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_train.yml @@ -0,0 +1,45 @@ +acts: +- - domain: Restaurants + episode_done: false + id: GoogleSgdInDomain_SystemTeacher + labels: + - 'APIS: ' + slots: {} + text: 'APIS: ' + type: 'APIS: ' +- - domain: Restaurants + episode_done: false + id: GoogleSgdInDomain_SystemTeacher + labels: + - 'APICALL: ' + slots: {} + text: 'USER: I am feeling hungry so I would like to find a place to eat.' + type: 'APICALL: ' +- - domain: Restaurants + episode_done: false + id: GoogleSgdInDomain_SystemTeacher + labels: + - 'SYSTEM: Do you have a specific which you want the eating place to be located + at?' + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +- - domain: Restaurants + episode_done: false + id: GoogleSgdInDomain_SystemTeacher + labels: + - 'APICALL: ' + slots: {} + text: 'USER: I would like for it to be in San Jose.' 
+ type: 'APICALL: ' +- - domain: Restaurants + episode_done: false + id: GoogleSgdInDomain_SystemTeacher + labels: + - 'SYSTEM: Is there a specific cuisine type you enjoy, such as Mexican, Italian + or something else?' + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +num_episodes: 13888 +num_examples: 320622 diff --git a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_valid.yml b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_valid.yml new file mode 100644 index 00000000000..92f565fea21 --- /dev/null +++ b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_valid.yml @@ -0,0 +1,45 @@ +acts: +- - domain: Restaurants + episode_done: false + eval_labels: + - 'APIS: ' + id: GoogleSgdInDomain_SystemTeacher + slots: {} + text: 'APIS: ' + type: 'APIS: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'APICALL: ' + id: GoogleSgdInDomain_SystemTeacher + slots: {} + text: 'USER: I want to make a restaurant reservation for 2 people at half past + 11 in the morning.' + type: 'APICALL: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'SYSTEM: What city do you want to dine in? Do you have a preferred restaurant?' + id: GoogleSgdInDomain_SystemTeacher + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'APICALL: ' + id: GoogleSgdInDomain_SystemTeacher + slots: {} + text: 'USER: Please find restaurants in San Jose. Can you try Sino?' + type: 'APICALL: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'SYSTEM: Confirming: I will reserve a table for 2 people at Sino in San Jose. + The reservation time is 11:30 am today.' + id: GoogleSgdInDomain_SystemTeacher + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +num_episodes: 1966 +num_examples: 42808 diff --git a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainUserSimulatorTeacher_test.yml b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainUserSimulatorTeacher_test.yml new file mode 100644 index 00000000000..2934f561f8c --- /dev/null +++ b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainUserSimulatorTeacher_test.yml @@ -0,0 +1,51 @@ +acts: +- - domain: Restaurants + episode_done: false + eval_labels: + - 'USER: Hi, could you get me a restaurant booking on the 8th please?' + id: GoogleSgdInDomain_UserSimulatorTeacher + text: 'GOAL: api_name = ReserveRestaurant ; date = 2019-03-08 ; location = Corte + Madera ; number_of_seats = 2 ; restaurant_name = P.f. Chang''s ; time = 12:00 + | api_name = ReserveRestaurant ; date = 2019-03-08 ; location = Corte Madera + ; number_of_seats = 2 ; restaurant_name = Benissimo Restaurant & Bar ; time + = 12:00' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'USER: Could you get me a reservation at P.f. Chang''s in Corte Madera at afternoon + 12?' + id: GoogleSgdInDomain_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Any preference on the restaurant, location and time?' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'USER: Sure, that is great.' + id: GoogleSgdInDomain_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Please confirm your reservation at P.f. Chang''s in Corte Madera + at 12 pm for 2 on March 8th.' 
+ type: 'USER: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'USER: Could you try booking a table at Benissimo instead?' + id: GoogleSgdInDomain_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Sorry, your reservation could not be made. Could I help you with + something else?' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'USER: Sure, may I know if they have vegetarian options and how expensive is + their food?' + id: GoogleSgdInDomain_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Sure, please confirm your reservation at Benissimo Restaurant & + Bar in Corte Madera at 12 pm for 2 on March 8th.' + type: 'USER: ' +num_episodes: 3132 +num_examples: 67286 diff --git a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainUserSimulatorTeacher_train.yml b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainUserSimulatorTeacher_train.yml new file mode 100644 index 00000000000..37257b324ef --- /dev/null +++ b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainUserSimulatorTeacher_train.yml @@ -0,0 +1,49 @@ +acts: +- - domain: Restaurants + episode_done: false + id: GoogleSgdInDomain_UserSimulatorTeacher + labels: + - 'USER: I am feeling hungry so I would like to find a place to eat.' + text: 'GOAL: api_name = FindRestaurants ; city = San Jose ; cuisine = American + | api_name = FindRestaurants ; city = Palo Alto ; cuisine = American ; price_range + = moderate | api_name = ReserveRestaurant ; city = Palo Alto ; date = 2019-03-01 + ; party_size = 2 ; restaurant_name = Bird Dog ; time = 11:30' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + id: GoogleSgdInDomain_UserSimulatorTeacher + labels: + - 'USER: I would like for it to be in San Jose.' + slots: {} + text: 'SYSTEM: Do you have a specific which you want the eating place to be located + at?' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + id: GoogleSgdInDomain_UserSimulatorTeacher + labels: + - 'USER: I usually like eating the American type of food.' + slots: {} + text: 'SYSTEM: Is there a specific cuisine type you enjoy, such as Mexican, Italian + or something else?' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + id: GoogleSgdInDomain_UserSimulatorTeacher + labels: + - 'USER: Can you give me the address of this restaurant.' + slots: {} + text: 'SYSTEM: I see that at 71 Saint Peter there is a good restaurant which is + in San Jose.' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + id: GoogleSgdInDomain_UserSimulatorTeacher + labels: + - 'USER: Can you give me the phone number that I can contact them with?' + slots: {} + text: 'SYSTEM: If you want to go to this restaurant you can find it at 71 North + San Pedro Street.' + type: 'USER: ' +num_episodes: 13888 +num_examples: 320622 diff --git a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainUserSimulatorTeacher_valid.yml b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainUserSimulatorTeacher_valid.yml new file mode 100644 index 00000000000..c6d888e6f57 --- /dev/null +++ b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainUserSimulatorTeacher_valid.yml @@ -0,0 +1,46 @@ +acts: +- - domain: Restaurants + episode_done: false + eval_labels: + - 'USER: I want to make a restaurant reservation for 2 people at half past 11 + in the morning.' 
+ id: GoogleSgdInDomain_UserSimulatorTeacher + text: 'GOAL: api_name = ReserveRestaurant ; date = 2019-03-01 ; location = San + Jose ; number_of_seats = 2 ; restaurant_name = Sino ; time = 11:30' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'USER: Please find restaurants in San Jose. Can you try Sino?' + id: GoogleSgdInDomain_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: What city do you want to dine in? Do you have a preferred restaurant?' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'USER: Yes, thanks. What''s their phone number?' + id: GoogleSgdInDomain_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Confirming: I will reserve a table for 2 people at Sino in San + Jose. The reservation time is 11:30 am today.' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'USER: What''s their address? Do they have vegetarian options on their menu?' + id: GoogleSgdInDomain_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Your reservation has been made. Their phone number is 408-247-8880.' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'USER: Thanks very much.' + id: GoogleSgdInDomain_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: The street address is 377 Santana Row #1000. They have good vegetarian + options.' + type: 'USER: ' +num_episodes: 1966 +num_examples: 42808 diff --git a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_test.yml b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_test.yml new file mode 100644 index 00000000000..4d9686d40f2 --- /dev/null +++ b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_test.yml @@ -0,0 +1,54 @@ +acts: +- - domain: Messaging, RentalCars, Travel + episode_done: false + eval_labels: + - 'APIS: ' + id: GoogleSgdOutDomain_SystemTeacher + slots: {} + text: 'APIS: ' + type: 'APIS: ' +- - domain: Messaging, RentalCars, Travel + episode_done: false + eval_labels: + - 'APICALL: ' + id: GoogleSgdOutDomain_SystemTeacher + slots: {} + text: 'USER: I''m looking for some attractions to visit while I''m vacationing.' + type: 'APICALL: ' +- - domain: Messaging, RentalCars, Travel + episode_done: false + eval_labels: + - 'SYSTEM: Sure, where will you be visiting?' + id: GoogleSgdOutDomain_SystemTeacher + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +- - domain: Messaging, RentalCars, Travel + episode_done: false + eval_labels: + - 'APICALL: api_name = FindAttractions ; location = London' + id: GoogleSgdOutDomain_SystemTeacher + slots: + api_name: FindAttractions + location: London + text: 'USER: I''m looking for things to do in London.' + type: 'APICALL: ' +- - domain: Messaging, RentalCars, Travel + episode_done: false + eval_labels: + - 'SYSTEM: Alright, one of the top attractions is 30 St Mary Axe (The Gherkin). + It''s a historical landmark.' 
+ id: GoogleSgdOutDomain_SystemTeacher + slots: + attraction_name: 30 St Mary Axe (The Gherkin) + category: Historical Landmark + free_entry: 'False' + good_for_kids: 'True' + location: London + phone_number: 20 7071 5029 + text: 'APIRESP: attraction_name = 30 St Mary Axe (The Gherkin) ; category = Historical + Landmark ; free_entry = False ; good_for_kids = True ; location = London ; phone_number + = 20 7071 5029' + type: 'SYSTEM: ' +num_episodes: 3132 +num_examples: 67286 diff --git a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_train.yml b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_train.yml new file mode 100644 index 00000000000..2792ba26440 --- /dev/null +++ b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_train.yml @@ -0,0 +1,63 @@ +acts: +- - domain: RentalCars + episode_done: false + id: GoogleSgdOutDomain_SystemTeacher + labels: + - 'APIS: ' + slots: {} + text: 'APIS: ' + type: 'APIS: ' +- - domain: RentalCars + episode_done: false + id: GoogleSgdOutDomain_SystemTeacher + labels: + - 'APICALL: ' + slots: {} + text: 'USER: I would like to find a SUV till 7th of March.' + type: 'APICALL: ' +- - domain: RentalCars + episode_done: false + id: GoogleSgdOutDomain_SystemTeacher + labels: + - 'SYSTEM: From where do you want to rent the car from? Also at which time do + you want to pick up the car and start the car rental?' + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +- - domain: RentalCars + episode_done: false + id: GoogleSgdOutDomain_SystemTeacher + labels: + - 'APICALL: api_name = GetCarsAvailable ; car_type = SUV ; city = Long Beach ; + end_date = 2019-03-07 ; pickup_time = 11:00 ; start_date = 2019-03-04' + slots: + api_name: GetCarsAvailable + car_type: SUV + city: Long Beach + end_date: '2019-03-07' + pickup_time: '11:00' + start_date: '2019-03-04' + text: 'USER: I am looking for cars in Long Beach, CA and want to pick it up at + 11 o"clock in the morning from 4th of this month.' + type: 'APICALL: ' +- - domain: RentalCars + episode_done: false + id: GoogleSgdOutDomain_SystemTeacher + labels: + - 'SYSTEM: There is a 1 car available at Downtown Station and it is a SUV of Alfa + Romeo Stelvio.' 
+ slots: + car_name: Alfa Romeo Stelvio + car_type: SUV + city: Long Beach + end_date: '2019-03-07' + pickup_location: Downtown Station + pickup_time: '11:00' + price_per_day: '38.00' + start_date: '2019-03-04' + text: 'APIRESP: car_name = Alfa Romeo Stelvio ; car_type = SUV ; city = Long Beach + ; end_date = 2019-03-07 ; pickup_location = Downtown Station ; pickup_time = + 11:00 ; price_per_day = 38.00 ; start_date = 2019-03-04' + type: 'SYSTEM: ' +num_episodes: 13888 +num_examples: 320622 diff --git a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_valid.yml b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_valid.yml new file mode 100644 index 00000000000..be07f99ea7e --- /dev/null +++ b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_valid.yml @@ -0,0 +1,43 @@ +acts: +- - domain: Homes, Messaging, RideSharing + episode_done: false + eval_labels: + - 'APIS: ' + id: GoogleSgdOutDomain_SystemTeacher + slots: {} + text: 'APIS: ' + type: 'APIS: ' +- - domain: Homes, Messaging, RideSharing + episode_done: false + eval_labels: + - 'APICALL: ' + id: GoogleSgdOutDomain_SystemTeacher + slots: {} + text: 'USER: i''m looking to move house so i''m after some options' + type: 'APICALL: ' +- - domain: Homes, Messaging, RideSharing + episode_done: false + eval_labels: + - 'SYSTEM: Are you looking to rent or buy? How many room will you need?' + id: GoogleSgdOutDomain_SystemTeacher + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +- - domain: Homes, Messaging, RideSharing + episode_done: false + eval_labels: + - 'APICALL: ' + id: GoogleSgdOutDomain_SystemTeacher + slots: {} + text: 'USER: Umm three bedrooms would be good and i want to buy' + type: 'APICALL: ' +- - domain: Homes, Messaging, RideSharing + episode_done: false + eval_labels: + - 'SYSTEM: What are are you interested in?' + id: GoogleSgdOutDomain_SystemTeacher + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +num_episodes: 1966 +num_examples: 42808 diff --git a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainUserSimulatorTeacher_test.yml b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainUserSimulatorTeacher_test.yml new file mode 100644 index 00000000000..9ba470ac70d --- /dev/null +++ b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainUserSimulatorTeacher_test.yml @@ -0,0 +1,52 @@ +acts: +- - domain: Messaging, RentalCars, Travel + episode_done: false + eval_labels: + - 'USER: I''m looking for some attractions to visit while I''m vacationing.' + id: GoogleSgdOutDomain_UserSimulatorTeacher + text: 'GOAL: api_name = FindAttractions ; location = London | api_name = GetCarsAvailable + ; city = London ; end_date = 2019-03-05 ; pickup_time = 09:30 ; start_date = + 2019-03-01 | api_name = GetCarsAvailable ; car_type = SUV ; city = London ; + end_date = 2019-03-05 ; pickup_time = 09:30 ; start_date = 2019-03-01 | add_insurance + = False ; api_name = ReserveCar ; car_type = SUV ; end_date = 2019-03-05 ; pickup_location + = Heathrow International Airport ; pickup_time = 09:30 ; start_date = 2019-03-01 + | api_name = ShareLocation ; contact_name = Emma ; location = Heathrow International + Airport' + type: 'USER: ' +- - domain: Messaging, RentalCars, Travel + episode_done: false + eval_labels: + - 'USER: I''m looking for things to do in London.' 
+ id: GoogleSgdOutDomain_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Sure, where will you be visiting?' + type: 'USER: ' +- - domain: Messaging, RentalCars, Travel + episode_done: false + eval_labels: + - 'USER: That sounds perfect. Can you help me find a rental car there?' + id: GoogleSgdOutDomain_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Alright, one of the top attractions is 30 St Mary Axe (The Gherkin). + It''s a historical landmark.' + type: 'USER: ' +- - domain: Messaging, RentalCars, Travel + episode_done: false + eval_labels: + - 'USER: I''ll pick it up in the morning 9:30 and I''ll need it until Tuesday + next week.' + id: GoogleSgdOutDomain_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Sure, what time would you like to get the car? Also, when will + you be dropping it off?' + type: 'USER: ' +- - domain: Messaging, RentalCars, Travel + episode_done: false + eval_labels: + - 'USER: Yes, that''s right.' + id: GoogleSgdOutDomain_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Alright, will you be picking it up in London?' + type: 'USER: ' +num_episodes: 3132 +num_examples: 67286 diff --git a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainUserSimulatorTeacher_train.yml b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainUserSimulatorTeacher_train.yml new file mode 100644 index 00000000000..289b95c3333 --- /dev/null +++ b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainUserSimulatorTeacher_train.yml @@ -0,0 +1,49 @@ +acts: +- - domain: RentalCars + episode_done: false + id: GoogleSgdOutDomain_UserSimulatorTeacher + labels: + - 'USER: I would like to find a SUV till 7th of March.' + text: 'GOAL: api_name = GetCarsAvailable ; car_type = SUV ; city = Long Beach + ; end_date = 2019-03-07 ; pickup_time = 11:00 ; start_date = 2019-03-04' + type: 'USER: ' +- - domain: RentalCars + episode_done: false + id: GoogleSgdOutDomain_UserSimulatorTeacher + labels: + - 'USER: I am looking for cars in Long Beach, CA and want to pick it up at 11 + o"clock in the morning from 4th of this month.' + slots: {} + text: 'SYSTEM: From where do you want to rent the car from? Also at which time + do you want to pick up the car and start the car rental?' + type: 'USER: ' +- - domain: RentalCars + episode_done: false + id: GoogleSgdOutDomain_UserSimulatorTeacher + labels: + - 'USER: This sounds good. That''s all I needed.' + slots: {} + text: 'SYSTEM: There is a 1 car available at Downtown Station and it is a SUV + of Alfa Romeo Stelvio.' + type: 'USER: ' +- - domain: RentalCars + episode_done: true + id: GoogleSgdOutDomain_UserSimulatorTeacher + labels: + - 'USER: [DONE]' + slots: {} + text: 'SYSTEM: Have a nice day.' 
+ type: 'USER: ' +- - domain: RentalCars, Travel + episode_done: false + id: GoogleSgdOutDomain_UserSimulatorTeacher + labels: + - 'USER: I need a rental car in Fresno' + text: 'GOAL: api_name = GetCarsAvailable ; dropoff_date = 2019-03-08 ; pickup_city + = Fresno ; pickup_date = 2019-03-07 ; pickup_time = 16:30 | api_name = ReserveCar + ; dropoff_date = 2019-03-08 ; pickup_date = 2019-03-07 ; pickup_location = Fresno + Station ; pickup_time = 16:30 ; type = Standard | api_name = FindAttractions + ; location = Fresno' + type: 'USER: ' +num_episodes: 13888 +num_examples: 320622 diff --git a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainUserSimulatorTeacher_valid.yml b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainUserSimulatorTeacher_valid.yml new file mode 100644 index 00000000000..d63e986ddb5 --- /dev/null +++ b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainUserSimulatorTeacher_valid.yml @@ -0,0 +1,47 @@ +acts: +- - domain: Homes, Messaging, RideSharing + episode_done: false + eval_labels: + - 'USER: i''m looking to move house so i''m after some options' + id: GoogleSgdOutDomain_UserSimulatorTeacher + text: 'GOAL: api_name = FindHomeByArea ; area = Sunnyvale ; intent = buy ; number_of_baths + = 2 ; number_of_beds = 3 | api_name = ScheduleVisit ; property_name = Apricot + Pit Apartments ; visit_date = 2019-03-06 | api_name = ShareLocation ; contact_name + = Sophia ; location = 400 East Remington Drive | api_name = GetRide ; destination + = 400 East Remington Drive ; number_of_seats = 1 ; ride_type = Regular' + type: 'USER: ' +- - domain: Homes, Messaging, RideSharing + episode_done: false + eval_labels: + - 'USER: Umm three bedrooms would be good and i want to buy' + id: GoogleSgdOutDomain_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Are you looking to rent or buy? How many room will you need?' + type: 'USER: ' +- - domain: Homes, Messaging, RideSharing + episode_done: false + eval_labels: + - 'USER: can you look for me in sunnyvale' + id: GoogleSgdOutDomain_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: What are are you interested in?' + type: 'USER: ' +- - domain: Homes, Messaging, RideSharing + episode_done: false + eval_labels: + - 'USER: can you find places with two bathrooms' + id: GoogleSgdOutDomain_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: How many bathrooms do you need' + type: 'USER: ' +- - domain: Homes, Messaging, RideSharing + episode_done: false + eval_labels: + - 'USER: Does this place have a garage?' + id: GoogleSgdOutDomain_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: I have 10 properties that might suit, including Apricot pit apartments + 400 east remington drive. The listing price is $3,650,000' + type: 'USER: ' +num_episodes: 1966 +num_examples: 42808 From 9426997d0eac5a3b600d8415c74b5588cbddd782 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 13:36:14 -0800 Subject: [PATCH 14/57] [TOD][Datasets][Easy] MetalWoz Code for process MetalWoz into System + User Simulator teachers Getting it to be in the Conversations format is a pain, so I don't even try here. 
(It's documented this way in the paper as well) ---------------------------- Datasets added in this substack: * Google SGD * Google SGD Simulation Splits (In-domain, Out-domain) * **MetalWoz** * MSR_E2E * Multidogo * MultiWoz V2.2 * Taskmaster * Taskmaster2 * Taskmaster3 (TicketTalk) Test plan: Regression test, `parlai dd` of dataset --- parlai/tasks/metalwoz/agents.py | 118 +++++++++++++----- parlai/tasks/metalwoz/build.py | 26 ++-- parlai/tasks/metalwoz/test.py | 4 + .../metalwoz_UserSimulatorTeacher_test.yml | 78 ++++++++++++ .../metalwoz_UserSimulatorTeacher_train.yml | 70 +++++++++++ .../metalwoz_UserSimulatorTeacher_valid.yml | 71 +++++++++++ parlai/tasks/metalwoz/test/metalwoz_test.yml | 106 ++++++++-------- parlai/tasks/metalwoz/test/metalwoz_train.yml | 102 +++++++-------- parlai/tasks/metalwoz/test/metalwoz_valid.yml | 110 ++++++++-------- 9 files changed, 481 insertions(+), 204 deletions(-) create mode 100644 parlai/tasks/metalwoz/test/metalwoz_UserSimulatorTeacher_test.yml create mode 100644 parlai/tasks/metalwoz/test/metalwoz_UserSimulatorTeacher_train.yml create mode 100644 parlai/tasks/metalwoz/test/metalwoz_UserSimulatorTeacher_valid.yml diff --git a/parlai/tasks/metalwoz/agents.py b/parlai/tasks/metalwoz/agents.py index a217c4e8a8f..f3f9d2c8238 100644 --- a/parlai/tasks/metalwoz/agents.py +++ b/parlai/tasks/metalwoz/agents.py @@ -4,22 +4,36 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from parlai.core.opt import Opt +from parlai.core.params import ParlaiParser from parlai.core.teachers import DialogTeacher from parlai.utils.io import PathManager from parlai.utils.data import DatatypeHelper from .build import build import os import pandas as pd -import hashlib +from typing import Optional -class MetalWozTeacher(DialogTeacher): +class MetalWozTeacherBase(DialogTeacher): + @classmethod + def add_cmdline_args( + cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None + ) -> ParlaiParser: + super().add_cmdline_args(parser, partial_opt) + parser.add_argument( + "--metalwoz-domains", + nargs="+", + help="Use only a subset of the domains", + ) + return parser + def _path(self, opt): - fold = DatatypeHelper.fold(opt['datatype']) - if fold == 'train' or fold == 'valid': - folder = os.path.join(opt['datapath'], 'metalwoz', 'train') + fold = DatatypeHelper.fold(opt["datatype"]) + if fold == "train" or fold == "valid": + folder = os.path.join(opt["datapath"], "metalwoz", "train") else: - folder = os.path.join(opt['datapath'], 'metalwoz', 'test') + folder = os.path.join(opt["datapath"], "metalwoz", "test") return folder, fold def __init__(self, opt, shared=None): @@ -27,51 +41,89 @@ def __init__(self, opt, shared=None): build(opt) folder, fold = self._path(opt) self.fold = fold - opt['datafile'] = os.path.join(folder, fold) + opt["datafile"] = os.path.join(folder, fold) super().__init__(opt, shared) - def _hash(self, string): - return int(hashlib.sha1(string.encode('utf-8')).hexdigest(), 16) % 10 - - def setup_data(self, datapath): + def load_data(self, datapath): folder, fold = os.path.split(datapath) - with PathManager.open(os.path.join(folder, 'tasks.txt')) as taskf: + with PathManager.open(os.path.join(folder, "tasks.txt")) as taskf: tasks_table = pd.read_json(taskf, lines=True) - dfolder = os.path.join(folder, 'dialogues') + dfolder = os.path.join(folder, "dialogues") data = [] for filename in PathManager.ls(dfolder): + domain = filename.replace(".txt", "") + if ( + self.opt["metalwoz_domains"] 
+ and domain not in self.opt["metalwoz_domains"] + ): + continue fullfn = os.path.join(dfolder, filename) with PathManager.open(fullfn) as dataf: - data.append(pd.read_json(dataf, lines=True)) + lines = pd.read_json(dataf, lines=True) + lines = lines.merge(tasks_table, on="task_id") + data.append(lines.to_dict("records")) - data = pd.concat(data, axis=0) - data = data.sample(frac=1.0, random_state=83741) # metal in l33t numbers, lol - data = data.merge(tasks_table, on='task_id') - data['fold'] = data['domain_x'].apply(self._hash) + # Quick check to make sure we didn't fat-finger the spelling of some domain + if self.opt["metalwoz_domains"]: + assert len(data) == len(self.opt["metalwoz_domains"]) - for _, row in data.iterrows(): - if fold == 'valid' and row['fold'] != 9: - continue - if fold == 'train' and row['fold'] == 9: - continue - texts = [row['bot_role']] + list(row['turns']) + if "test" in self.fold: + flat = [] + for domain in data: + flat.extend(domain) + return flat + + return DatatypeHelper.split_subset_data_by_fold(self.fold, data, 0.8, 0.1, 0.1) + + +class SystemTeacher(MetalWozTeacherBase): + def setup_data(self, datapath): + data = self.load_data(datapath) + for row in data: + texts = [row["bot_role"]] + list(row["turns"]) prompts, labels = texts[::2], texts[1::2] for i, (prompt, label) in enumerate(zip(prompts, labels)): yield { - 'text': prompt, - 'label': label, - 'bot_role': row['bot_role'], - 'bot_prompt': row['bot_prompt'], - 'user_role': row['user_role'], - 'user_prompt': row['user_prompt'], - 'utterance_id': row['id'], - 'domain': row['domain_x'], - 'task_id': row['task_id'], + "text": prompt, + "label": label, + "bot_role": row["bot_role"], + "bot_prompt": row["bot_prompt"], + "user_role": row["user_role"], + "user_prompt": row["user_prompt"], + "utterance_id": row["id"], + "domain": row["domain_x"], + "task_id": row["task_id"], }, i == 0 +class UserSimulatorTeacher(MetalWozTeacherBase): + def setup_data(self, datapath): + data = self.load_data(datapath) + for row in data: + texts = list(row["turns"]) + prompts, labels = [f"{row['user_role']}\n{texts[0]}"] + texts[2::2], texts[ + 1::2 + ] + for i, (prompt, label) in enumerate(zip(prompts, labels)): + yield { + "text": prompt, + "label": label, + "bot_role": row["bot_role"], + "bot_prompt": row["bot_prompt"], + "user_role": row["user_role"], + "user_prompt": row["user_prompt"], + "utterance_id": row["id"], + "domain": row["domain_x"], + "task_id": row["task_id"], + }, i == 0 + + +class MetalWozTeacher(SystemTeacher): + pass + + class DefaultTeacher(MetalWozTeacher): pass diff --git a/parlai/tasks/metalwoz/build.py b/parlai/tasks/metalwoz/build.py index aecfaaa6ee5..451fb34fc14 100644 --- a/parlai/tasks/metalwoz/build.py +++ b/parlai/tasks/metalwoz/build.py @@ -10,32 +10,32 @@ RESOURCES = [ DownloadableFile( - 'https://download.microsoft.com/download/E/B/8/EB84CB1A-D57D-455F-B905-3ABDE80404E5/metalwoz-v1.zip', - 'metalwoz-v1.zip', - '2a2ae3b25760aa2725e70bc6480562fa5d720c9689a508d28417631496d6764f', + "https://download.microsoft.com/download/E/B/8/EB84CB1A-D57D-455F-B905-3ABDE80404E5/metalwoz-v1.zip", + "metalwoz-v1.zip", + "2a2ae3b25760aa2725e70bc6480562fa5d720c9689a508d28417631496d6764f", ), DownloadableFile( - 'https://download.microsoft.com/download/0/c/4/0c4a8893-cbf9-4a43-a44a-09bab9539234/metalwoz-test-v1.zip', - 'metalwoz-test-v1.zip', - '6722d1d9ec05334dd801972767ae3bdefcd15f71bf73fea1d098f214a96a7c6c', + 
"https://download.microsoft.com/download/0/c/4/0c4a8893-cbf9-4a43-a44a-09bab9539234/metalwoz-test-v1.zip", + "metalwoz-test-v1.zip", + "6722d1d9ec05334dd801972767ae3bdefcd15f71bf73fea1d098f214a96a7c6c", ), ] def build(opt): - dpath = os.path.join(opt['datapath'], 'metalwoz') - version = '1.0' + dpath = os.path.join(opt["datapath"], "metalwoz") + version = "1.0" if not build_data.built(dpath, version_string=version): if build_data.built(dpath): build_data.remove_dir(dpath) build_data.make_dir(dpath) - build_data.make_dir(os.path.join(dpath, 'train', 'dialogues')) - build_data.make_dir(os.path.join(dpath, 'test', 'dialogues')) + build_data.make_dir(os.path.join(dpath, "train", "dialogues")) + build_data.make_dir(os.path.join(dpath, "test", "dialogues")) # Download the data. - RESOURCES[0].download_file(os.path.join(dpath, 'train')) - RESOURCES[1].download_file(os.path.join(dpath, 'test')) + RESOURCES[0].download_file(os.path.join(dpath, "train")) + RESOURCES[1].download_file(os.path.join(dpath, "test")) - build_data.untar(os.path.join(dpath, 'test'), 'dstc8_metalwoz_heldout.zip') + build_data.untar(os.path.join(dpath, "test"), "dstc8_metalwoz_heldout.zip") build_data.mark_done(dpath, version_string=version) diff --git a/parlai/tasks/metalwoz/test.py b/parlai/tasks/metalwoz/test.py index ffbd99d2248..310564b7969 100644 --- a/parlai/tasks/metalwoz/test.py +++ b/parlai/tasks/metalwoz/test.py @@ -9,3 +9,7 @@ class TestDefaultTeacher(AutoTeacherTest): task = "metalwoz" + + +class TestUserSimulatorTeacher(AutoTeacherTest): + task = "metalwoz:UserSimulatorTeacher" diff --git a/parlai/tasks/metalwoz/test/metalwoz_UserSimulatorTeacher_test.yml b/parlai/tasks/metalwoz/test/metalwoz_UserSimulatorTeacher_test.yml new file mode 100644 index 00000000000..5b4c5e92bb1 --- /dev/null +++ b/parlai/tasks/metalwoz/test/metalwoz_UserSimulatorTeacher_test.yml @@ -0,0 +1,78 @@ +acts: +- - bot_prompt: Offer to help the user book a flight to Greece. + bot_role: You are a bot designed to book flights + domain: BOOKING_FLIGHT + episode_done: false + eval_labels: + - Hello, I want to book a flight. + id: metalwoz:UserSimulatorTeacher + task_id: b5ca362f + text: 'You are interacting with a bot designed to book flights + + Hello how may I help you?' + user_prompt: You would like to know how to get a flight to Greece. If the bot + starts booking you a flight to Greece, tell them you were only curious and that + you do not wish to actually book the flight. + user_role: You are interacting with a bot designed to book flights + utterance_id: e0868745 +- - bot_prompt: Offer to help the user book a flight to Greece. + bot_role: You are a bot designed to book flights + domain: BOOKING_FLIGHT + episode_done: false + eval_labels: + - Yes that's where I want to go. + id: metalwoz:UserSimulatorTeacher + task_id: b5ca362f + text: Can I help you book a flight to greece? + user_prompt: You would like to know how to get a flight to Greece. If the bot + starts booking you a flight to Greece, tell them you were only curious and that + you do not wish to actually book the flight. + user_role: You are interacting with a bot designed to book flights + utterance_id: e0868745 +- - bot_prompt: Offer to help the user book a flight to Greece. + bot_role: You are a bot designed to book flights + domain: BOOKING_FLIGHT + episode_done: false + eval_labels: + - I want to leave from Dallas. + id: metalwoz:UserSimulatorTeacher + task_id: b5ca362f + text: Perfect, and when would you like to leave? 
+ user_prompt: You would like to know how to get a flight to Greece. If the bot + starts booking you a flight to Greece, tell them you were only curious and that + you do not wish to actually book the flight. + user_role: You are interacting with a bot designed to book flights + utterance_id: e0868745 +- - bot_prompt: Offer to help the user book a flight to Greece. + bot_role: You are a bot designed to book flights + domain: BOOKING_FLIGHT + episode_done: false + eval_labels: + - I want to leave by next Friday. + id: metalwoz:UserSimulatorTeacher + task_id: b5ca362f + text: Alright, i have your flight leaving from dallas and landing in athens. When + would you like to fly there? + user_prompt: You would like to know how to get a flight to Greece. If the bot + starts booking you a flight to Greece, tell them you were only curious and that + you do not wish to actually book the flight. + user_role: You are interacting with a bot designed to book flights + utterance_id: e0868745 +- - bot_prompt: Offer to help the user book a flight to Greece. + bot_role: You are a bot designed to book flights + domain: BOOKING_FLIGHT + episode_done: true + eval_labels: + - Wait please don't do that, I was only curious. I don't actually want to book + a flight + id: metalwoz:UserSimulatorTeacher + task_id: b5ca362f + text: Alright booking your flight between now and next friday. You will be alerted + when your flight is book and vased on the lowest available price. + user_prompt: You would like to know how to get a flight to Greece. If the bot + starts booking you a flight to Greece, tell them you were only curious and that + you do not wish to actually book the flight. + user_role: You are interacting with a bot designed to book flights + utterance_id: e0868745 +num_episodes: 2319 +num_examples: 14067 diff --git a/parlai/tasks/metalwoz/test/metalwoz_UserSimulatorTeacher_train.yml b/parlai/tasks/metalwoz/test/metalwoz_UserSimulatorTeacher_train.yml new file mode 100644 index 00000000000..9e571d8d45c --- /dev/null +++ b/parlai/tasks/metalwoz/test/metalwoz_UserSimulatorTeacher_train.yml @@ -0,0 +1,70 @@ +acts: +- - bot_prompt: Tell the user you don't have the information they are asking for + bot_role: You are a bot designed to fetch information from the internet + domain: LOOK_UP_INFO + episode_done: false + id: metalwoz:UserSimulatorTeacher + labels: + - I would like to ask a question Can you provide general information? + task_id: 4fdf58c3 + text: 'You are interacting with a bot designed to fetch information from the internet + + Hello how may I help you?' + user_prompt: Ask the bot a question about common world knowledge + user_role: You are interacting with a bot designed to fetch information from the + internet + utterance_id: 2454253d +- - bot_prompt: Tell the user you don't have the information they are asking for + bot_role: You are a bot designed to fetch information from the internet + domain: LOOK_UP_INFO + episode_done: false + id: metalwoz:UserSimulatorTeacher + labels: + - What is the meaning of life? + task_id: 4fdf58c3 + text: 'yes' + user_prompt: Ask the bot a question about common world knowledge + user_role: You are interacting with a bot designed to fetch information from the + internet + utterance_id: 2454253d +- - bot_prompt: Tell the user you don't have the information they are asking for + bot_role: You are a bot designed to fetch information from the internet + domain: LOOK_UP_INFO + episode_done: false + id: metalwoz:UserSimulatorTeacher + labels: + - Does life have meaning? 
+ task_id: 4fdf58c3 + text: Sorry, I don't have the information that you are asking. + user_prompt: Ask the bot a question about common world knowledge + user_role: You are interacting with a bot designed to fetch information from the + internet + utterance_id: 2454253d +- - bot_prompt: Tell the user you don't have the information they are asking for + bot_role: You are a bot designed to fetch information from the internet + domain: LOOK_UP_INFO + episode_done: false + id: metalwoz:UserSimulatorTeacher + labels: + - How many Oreos can a gorilla eat? + task_id: 4fdf58c3 + text: Sorry, I don't have the information that you are asking. + user_prompt: Ask the bot a question about common world knowledge + user_role: You are interacting with a bot designed to fetch information from the + internet + utterance_id: 2454253d +- - bot_prompt: Tell the user you don't have the information they are asking for + bot_role: You are a bot designed to fetch information from the internet + domain: LOOK_UP_INFO + episode_done: true + id: metalwoz:UserSimulatorTeacher + labels: + - Does my butt look big? + task_id: 4fdf58c3 + text: Sorry, I don't have the information that you are asking. + user_prompt: Ask the bot a question about common world knowledge + user_role: You are interacting with a bot designed to fetch information from the + internet + utterance_id: 2454253d +num_episodes: 30287 +num_examples: 185359 diff --git a/parlai/tasks/metalwoz/test/metalwoz_UserSimulatorTeacher_valid.yml b/parlai/tasks/metalwoz/test/metalwoz_UserSimulatorTeacher_valid.yml new file mode 100644 index 00000000000..8627d2b7041 --- /dev/null +++ b/parlai/tasks/metalwoz/test/metalwoz_UserSimulatorTeacher_valid.yml @@ -0,0 +1,71 @@ +acts: +- - bot_prompt: Ask the user to narrow down their request a little + bot_role: You are a bot that provides facts about different cities + domain: CITY_INFO + episode_done: false + eval_labels: + - Hey, I'd like to compare two cities + id: metalwoz:UserSimulatorTeacher + task_id: 284cf7ff + text: 'You are interacting with a bot that provides facts about different cities + + Hello how may I help you?' + user_prompt: Ask the bot to compare LA and New York City + user_role: You are interacting with a bot that provides facts about different + cities + utterance_id: 4c18684f +- - bot_prompt: Ask the user to narrow down their request a little + bot_role: You are a bot that provides facts about different cities + domain: CITY_INFO + episode_done: false + eval_labels: + - Please compare Los Angeles with New York City + id: metalwoz:UserSimulatorTeacher + task_id: 284cf7ff + text: Sure, which cities? + user_prompt: Ask the bot to compare LA and New York City + user_role: You are interacting with a bot that provides facts about different + cities + utterance_id: 4c18684f +- - bot_prompt: Ask the user to narrow down their request a little + bot_role: You are a bot that provides facts about different cities + domain: CITY_INFO + episode_done: false + eval_labels: + - Sure. What's the difference in population? + id: metalwoz:UserSimulatorTeacher + task_id: 284cf7ff + text: Can you narrow down what you want to compare? 
+ user_prompt: Ask the bot to compare LA and New York City + user_role: You are interacting with a bot that provides facts about different + cities + utterance_id: 4c18684f +- - bot_prompt: Ask the user to narrow down their request a little + bot_role: You are a bot that provides facts about different cities + domain: CITY_INFO + episode_done: false + eval_labels: + - Oh wow, that's a huge difference. How big is New York? + id: metalwoz:UserSimulatorTeacher + task_id: 284cf7ff + text: Los Angeles has a population of 3.976 million, NYC has a population of 8.538 + million + user_prompt: Ask the bot to compare LA and New York City + user_role: You are interacting with a bot that provides facts about different + cities + utterance_id: 4c18684f +- - bot_prompt: Ask the user to narrow down their request a little + bot_role: You are a bot that provides facts about different cities + domain: CITY_INFO + episode_done: false + eval_labels: + - And how big is LA? + id: metalwoz:UserSimulatorTeacher + task_id: 284cf7ff + text: 304.6 mi + user_prompt: Ask the bot to compare LA and New York City + user_role: You are interacting with a bot that provides facts about different + cities + utterance_id: 4c18684f +num_episodes: 3788 +num_examples: 23165 diff --git a/parlai/tasks/metalwoz/test/metalwoz_test.yml b/parlai/tasks/metalwoz/test/metalwoz_test.yml index abb9c9b56e9..3c187acfe38 100644 --- a/parlai/tasks/metalwoz/test/metalwoz_test.yml +++ b/parlai/tasks/metalwoz/test/metalwoz_test.yml @@ -1,69 +1,75 @@ acts: -- - bot_prompt: Fulfil the user's request - bot_role: You are a bot that provides tourism related advice - domain: TOURISM +- - bot_prompt: Offer to help the user book a flight to Greece. + bot_role: You are a bot designed to book flights + domain: BOOKING_FLIGHT episode_done: false eval_labels: - Hello how may I help you? id: metalwoz - task_id: 290f924c - text: You are a bot that provides tourism related advice - user_prompt: Tell the Bot that you are heading to Montreal in the summer, and - ask if there are any good festivals around that time - user_role: You are interacting with a bot that gives tourism related advice - utterance_id: b942824c -- - bot_prompt: Fulfil the user's request - bot_role: You are a bot that provides tourism related advice - domain: TOURISM + task_id: b5ca362f + text: You are a bot designed to book flights + user_prompt: You would like to know how to get a flight to Greece. If the bot + starts booking you a flight to Greece, tell them you were only curious and that + you do not wish to actually book the flight. + user_role: You are interacting with a bot designed to book flights + utterance_id: e0868745 +- - bot_prompt: Offer to help the user book a flight to Greece. + bot_role: You are a bot designed to book flights + domain: BOOKING_FLIGHT episode_done: false eval_labels: - - I can assist you with your upcoming Montreal trip. What do you need? + - Can I help you book a flight to greece? id: metalwoz - task_id: 290f924c - text: I have some questions about my upcoming travel to Montreal. - user_prompt: Tell the Bot that you are heading to Montreal in the summer, and - ask if there are any good festivals around that time - user_role: You are interacting with a bot that gives tourism related advice - utterance_id: b942824c -- - bot_prompt: Fulfil the user's request - bot_role: You are a bot that provides tourism related advice - domain: TOURISM + task_id: b5ca362f + text: Hello, I want to book a flight. 
+ user_prompt: You would like to know how to get a flight to Greece. If the bot + starts booking you a flight to Greece, tell them you were only curious and that + you do not wish to actually book the flight. + user_role: You are interacting with a bot designed to book flights + utterance_id: e0868745 +- - bot_prompt: Offer to help the user book a flight to Greece. + bot_role: You are a bot designed to book flights + domain: BOOKING_FLIGHT episode_done: false eval_labels: - - 'Here are a few festivals: BarnFest Steakfest Musicfest' + - Perfect, and when would you like to leave? id: metalwoz - task_id: 290f924c - text: I'm heading there shortly during what is there summer months and was wondering - if there are good festivals going on then. - user_prompt: Tell the Bot that you are heading to Montreal in the summer, and - ask if there are any good festivals around that time - user_role: You are interacting with a bot that gives tourism related advice - utterance_id: b942824c -- - bot_prompt: Fulfil the user's request - bot_role: You are a bot that provides tourism related advice - domain: TOURISM + task_id: b5ca362f + text: Yes that's where I want to go. + user_prompt: You would like to know how to get a flight to Greece. If the bot + starts booking you a flight to Greece, tell them you were only curious and that + you do not wish to actually book the flight. + user_role: You are interacting with a bot designed to book flights + utterance_id: e0868745 +- - bot_prompt: Offer to help the user book a flight to Greece. + bot_role: You are a bot designed to book flights + domain: BOOKING_FLIGHT episode_done: false eval_labels: - - They are all occuring in July. + - Alright, i have your flight leaving from dallas and landing in athens. When + would you like to fly there? id: metalwoz - task_id: 290f924c - text: Those sound good. Do you have the dates they are happening? - user_prompt: Tell the Bot that you are heading to Montreal in the summer, and - ask if there are any good festivals around that time - user_role: You are interacting with a bot that gives tourism related advice - utterance_id: b942824c -- - bot_prompt: Fulfil the user's request - bot_role: You are a bot that provides tourism related advice - domain: TOURISM + task_id: b5ca362f + text: I want to leave from Dallas. + user_prompt: You would like to know how to get a flight to Greece. If the bot + starts booking you a flight to Greece, tell them you were only curious and that + you do not wish to actually book the flight. + user_role: You are interacting with a bot designed to book flights + utterance_id: e0868745 +- - bot_prompt: Offer to help the user book a flight to Greece. + bot_role: You are a bot designed to book flights + domain: BOOKING_FLIGHT episode_done: false eval_labels: - - Sounds fun! Can I help you with anything else today? + - Alright booking your flight between now and next friday. You will be alerted + when your flight is book and vased on the lowest available price. id: metalwoz - task_id: 290f924c - text: That is when I plan on going. They all sound good to me. - user_prompt: Tell the Bot that you are heading to Montreal in the summer, and - ask if there are any good festivals around that time - user_role: You are interacting with a bot that gives tourism related advice - utterance_id: b942824c + task_id: b5ca362f + text: I want to leave by next Friday. + user_prompt: You would like to know how to get a flight to Greece. 
If the bot + starts booking you a flight to Greece, tell them you were only curious and that + you do not wish to actually book the flight. + user_role: You are interacting with a bot designed to book flights + utterance_id: e0868745 num_episodes: 2319 num_examples: 14067 diff --git a/parlai/tasks/metalwoz/test/metalwoz_train.yml b/parlai/tasks/metalwoz/test/metalwoz_train.yml index 2214acfb1e4..6a8256bb740 100644 --- a/parlai/tasks/metalwoz/test/metalwoz_train.yml +++ b/parlai/tasks/metalwoz/test/metalwoz_train.yml @@ -1,68 +1,68 @@ acts: -- - bot_prompt: Tell the user that there are no ski hills in their immediate location - bot_role: You are a bot that helps people book skiing trips - domain: SKI_BOT +- - bot_prompt: Tell the user you don't have the information they are asking for + bot_role: You are a bot designed to fetch information from the internet + domain: LOOK_UP_INFO episode_done: false id: metalwoz labels: - Hello how may I help you? - task_id: 2511cf64 - text: You are a bot that helps people book skiing trips - user_prompt: You want to know if there are good ski hills an hour's drive from - your current location - user_role: You are interacting with a bot designed to help you book a skiing trip - utterance_id: c3ec2179 -- - bot_prompt: Tell the user that there are no ski hills in their immediate location - bot_role: You are a bot that helps people book skiing trips - domain: SKI_BOT + task_id: 4fdf58c3 + text: You are a bot designed to fetch information from the internet + user_prompt: Ask the bot a question about common world knowledge + user_role: You are interacting with a bot designed to fetch information from the + internet + utterance_id: 2454253d +- - bot_prompt: Tell the user you don't have the information they are asking for + bot_role: You are a bot designed to fetch information from the internet + domain: LOOK_UP_INFO episode_done: false id: metalwoz labels: - - There are no ski hills in your location - task_id: 2511cf64 - text: Are there any ski resorts me? - user_prompt: You want to know if there are good ski hills an hour's drive from - your current location - user_role: You are interacting with a bot designed to help you book a skiing trip - utterance_id: c3ec2179 -- - bot_prompt: Tell the user that there are no ski hills in their immediate location - bot_role: You are a bot that helps people book skiing trips - domain: SKI_BOT + - 'yes' + task_id: 4fdf58c3 + text: I would like to ask a question Can you provide general information? + user_prompt: Ask the bot a question about common world knowledge + user_role: You are interacting with a bot designed to fetch information from the + internet + utterance_id: 2454253d +- - bot_prompt: Tell the user you don't have the information they are asking for + bot_role: You are a bot designed to fetch information from the internet + domain: LOOK_UP_INFO episode_done: false id: metalwoz labels: - - In the mount le - task_id: 2511cf64 - text: What's the nearest ski resort? - user_prompt: You want to know if there are good ski hills an hour's drive from - your current location - user_role: You are interacting with a bot designed to help you book a skiing trip - utterance_id: c3ec2179 -- - bot_prompt: Tell the user that there are no ski hills in their immediate location - bot_role: You are a bot that helps people book skiing trips - domain: SKI_BOT + - Sorry, I don't have the information that you are asking. + task_id: 4fdf58c3 + text: What is the meaning of life? 
+ user_prompt: Ask the bot a question about common world knowledge + user_role: You are interacting with a bot designed to fetch information from the + internet + utterance_id: 2454253d +- - bot_prompt: Tell the user you don't have the information they are asking for + bot_role: You are a bot designed to fetch information from the internet + domain: LOOK_UP_INFO episode_done: false id: metalwoz labels: - - 4 hrs - task_id: 2511cf64 - text: How many hours way is that from me? - user_prompt: You want to know if there are good ski hills an hour's drive from - your current location - user_role: You are interacting with a bot designed to help you book a skiing trip - utterance_id: c3ec2179 -- - bot_prompt: Tell the user that there are no ski hills in their immediate location - bot_role: You are a bot that helps people book skiing trips - domain: SKI_BOT + - Sorry, I don't have the information that you are asking. + task_id: 4fdf58c3 + text: Does life have meaning? + user_prompt: Ask the bot a question about common world knowledge + user_role: You are interacting with a bot designed to fetch information from the + internet + utterance_id: 2454253d +- - bot_prompt: Tell the user you don't have the information they are asking for + bot_role: You are a bot designed to fetch information from the internet + domain: LOOK_UP_INFO episode_done: false id: metalwoz labels: - - is there anything else? - task_id: 2511cf64 - text: Okay thanks - user_prompt: You want to know if there are good ski hills an hour's drive from - your current location - user_role: You are interacting with a bot designed to help you book a skiing trip - utterance_id: c3ec2179 -num_episodes: 31677 -num_examples: 194324 + - Sorry, I don't have the information that you are asking. + task_id: 4fdf58c3 + text: How many Oreos can a gorilla eat? + user_prompt: Ask the bot a question about common world knowledge + user_role: You are interacting with a bot designed to fetch information from the + internet + utterance_id: 2454253d +num_episodes: 30287 +num_examples: 185359 diff --git a/parlai/tasks/metalwoz/test/metalwoz_valid.yml b/parlai/tasks/metalwoz/test/metalwoz_valid.yml index a29b7debbcb..129fc3a980e 100644 --- a/parlai/tasks/metalwoz/test/metalwoz_valid.yml +++ b/parlai/tasks/metalwoz/test/metalwoz_valid.yml @@ -1,73 +1,69 @@ acts: -- - bot_prompt: Reply to the customer and try to fulfil their request. If you think - the request they are making goes beyond your role, inform the user that you - are not equipped to help them - bot_role: You are a bot that clarifies the rules for games - domain: GAME_RULES +- - bot_prompt: Ask the user to narrow down their request a little + bot_role: You are a bot that provides facts about different cities + domain: CITY_INFO episode_done: false eval_labels: - - Hello how may I help you? hello. + - Hello how may I help you? id: metalwoz - task_id: a5137c64 - text: You are a bot that clarifies the rules for games - user_prompt: Start a conversation based on you customerRole - user_role: You are interacting with a bot that clarifies the rules of games - utterance_id: 194e1958 -- - bot_prompt: Reply to the customer and try to fulfil their request. 
If you think - the request they are making goes beyond your role, inform the user that you - are not equipped to help them - bot_role: You are a bot that clarifies the rules for games - domain: GAME_RULES + task_id: 284cf7ff + text: You are a bot that provides facts about different cities + user_prompt: Ask the bot to compare LA and New York City + user_role: You are interacting with a bot that provides facts about different + cities + utterance_id: 4c18684f +- - bot_prompt: Ask the user to narrow down their request a little + bot_role: You are a bot that provides facts about different cities + domain: CITY_INFO episode_done: false eval_labels: - - Yes sure. What game ? + - Sure, which cities? id: metalwoz - task_id: a5137c64 - text: Hi, can you help me with a game? - user_prompt: Start a conversation based on you customerRole - user_role: You are interacting with a bot that clarifies the rules of games - utterance_id: 194e1958 -- - bot_prompt: Reply to the customer and try to fulfil their request. If you think - the request they are making goes beyond your role, inform the user that you - are not equipped to help them - bot_role: You are a bot that clarifies the rules for games - domain: GAME_RULES + task_id: 284cf7ff + text: Hey, I'd like to compare two cities + user_prompt: Ask the bot to compare LA and New York City + user_role: You are interacting with a bot that provides facts about different + cities + utterance_id: 4c18684f +- - bot_prompt: Ask the user to narrow down their request a little + bot_role: You are a bot that provides facts about different cities + domain: CITY_INFO episode_done: false eval_labels: - - No there isn't. + - Can you narrow down what you want to compare? id: metalwoz - task_id: a5137c64 - text: Okay, I need to know if theres a rule for who goes first in checkers - user_prompt: Start a conversation based on you customerRole - user_role: You are interacting with a bot that clarifies the rules of games - utterance_id: 194e1958 -- - bot_prompt: Reply to the customer and try to fulfil their request. If you think - the request they are making goes beyond your role, inform the user that you - are not equipped to help them - bot_role: You are a bot that clarifies the rules for games - domain: GAME_RULES + task_id: 284cf7ff + text: Please compare Los Angeles with New York City + user_prompt: Ask the bot to compare LA and New York City + user_role: You are interacting with a bot that provides facts about different + cities + utterance_id: 4c18684f +- - bot_prompt: Ask the user to narrow down their request a little + bot_role: You are a bot that provides facts about different cities + domain: CITY_INFO episode_done: false eval_labels: - - Yes as long as you take move. + - Los Angeles has a population of 3.976 million, NYC has a population of 8.538 + million id: metalwoz - task_id: a5137c64 - text: So it doesn't matter who goes first? - user_prompt: Start a conversation based on you customerRole - user_role: You are interacting with a bot that clarifies the rules of games - utterance_id: 194e1958 -- - bot_prompt: Reply to the customer and try to fulfil their request. If you think - the request they are making goes beyond your role, inform the user that you - are not equipped to help them - bot_role: You are a bot that clarifies the rules for games - domain: GAME_RULES + task_id: 284cf7ff + text: Sure. What's the difference in population? 
+ user_prompt: Ask the bot to compare LA and New York City + user_role: You are interacting with a bot that provides facts about different + cities + utterance_id: 4c18684f +- - bot_prompt: Ask the user to narrow down their request a little + bot_role: You are a bot that provides facts about different cities + domain: CITY_INFO episode_done: false eval_labels: - - Yes probably. + - 304.6 mi id: metalwoz - task_id: a5137c64 - text: Okay, I guess we'll just take turns then - user_prompt: Start a conversation based on you customerRole - user_role: You are interacting with a bot that clarifies the rules of games - utterance_id: 194e1958 -num_episodes: 6207 -num_examples: 37545 + task_id: 284cf7ff + text: Oh wow, that's a huge difference. How big is New York? + user_prompt: Ask the bot to compare LA and New York City + user_role: You are interacting with a bot that provides facts about different + cities + utterance_id: 4c18684f +num_episodes: 3788 +num_examples: 23165 From 9b2116cebf6055a85f958fee5052f3aecfae418b Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 13:39:47 -0800 Subject: [PATCH 15/57] [TOD][Datasets][Easy] MSR E2E into TOD Conversations format Title. I only include System + UserSimulator Teachers here since that's all we need right now from dataset. Datasets added in this substack: * Google SGD * Google SGD Simulation Splits (In-domain, Out-domain) * MetalWoz * **MSR_E2E** * Multidogo * MultiWoz V2.2 * Taskmaster * Taskmaster2 * Taskmaster3 (TicketTalk) Test plan: Regression test, `parlai dd` of dataset --- parlai/tasks/msr_e2e/agents.py | 326 ++++++++++++++++++ parlai/tasks/msr_e2e/build.py | 60 ++++ parlai/tasks/msr_e2e/test.py | 15 + .../test/msr_e2e_SystemTeacher_test.yml | 71 ++++ .../test/msr_e2e_SystemTeacher_train.yml | 56 +++ .../test/msr_e2e_SystemTeacher_valid.yml | 55 +++ .../msr_e2e_UserSimulatorTeacher_test.yml | 50 +++ .../msr_e2e_UserSimulatorTeacher_train.yml | 48 +++ .../msr_e2e_UserSimulatorTeacher_valid.yml | 48 +++ 9 files changed, 729 insertions(+) create mode 100644 parlai/tasks/msr_e2e/agents.py create mode 100644 parlai/tasks/msr_e2e/build.py create mode 100644 parlai/tasks/msr_e2e/test.py create mode 100644 parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_test.yml create mode 100644 parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_train.yml create mode 100644 parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_valid.yml create mode 100644 parlai/tasks/msr_e2e/test/msr_e2e_UserSimulatorTeacher_test.yml create mode 100644 parlai/tasks/msr_e2e/test/msr_e2e_UserSimulatorTeacher_train.yml create mode 100644 parlai/tasks/msr_e2e/test/msr_e2e_UserSimulatorTeacher_valid.yml diff --git a/parlai/tasks/msr_e2e/agents.py b/parlai/tasks/msr_e2e/agents.py new file mode 100644 index 00000000000..95e2122b896 --- /dev/null +++ b/parlai/tasks/msr_e2e/agents.py @@ -0,0 +1,326 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +MSR-E2E implementation for ParlAI. + +No official train/valid/test splits are available of public data, so we make our own. + +We assume inform slots from the agent are API responses and request/inform slots from the user is an API call. It's not quite how things are supposed to work, but the `dialogue act` setup is not super well standardized within the dataset. 
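+
+As an illustrative reading of that assumption (not an official mapping from the dataset
+authors): a user turn annotated with roughly `inform(moviename=Deadpool;numberofpeople=2)`
+in the movie domain becomes an API call containing those two slots plus `api_name = movie`
+(see `_parse_dialogue_act` below), while an agent `inform(...)` act becomes the
+corresponding API response.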
+""" + +from parlai.core.params import ParlaiParser +import copy +import os +import csv +from collections import Counter +from parlai.core.opt import Opt +import parlai.core.tod.tod_core as tod +from parlai.utils.misc import warn_once +from typing import Optional +from parlai.utils.data import DatatypeHelper +from parlai.utils.io import PathManager + +import parlai.tasks.msr_e2e.build as build_ +import parlai.core.tod.tod_agents as tod_agents + + +DOMAINS = [ + "movie", + "restaurant", + "taxi", +] + +# Just going to copy/paste these since it's faster than parsing 3 separate files +# They are in `system/src/deep_dialog/data_/_slots.txt` in the original data +SLOT_NAMES = { + "movie": [ + "actor", + "actress", + "city", + "closing", + "critic_rating", + "date", + "schema", + "distanceconstraints", + "genre", + "greeting", + "implicit_value", + "movie_series", + "moviename", + "mpaa_rating", + "numberofpeople", + "numberofkids", + "taskcomplete", + "other", + "price", + "seating", + "starttime", + "state", + "theater", + "theater_chain", + "video_format", + "zip", + "result", + "ticket", + "mc_list", + ], + "restaurant": [ + "address", + "atmosphere", + "choice", + "city", + "closing", + "cuisine", + "date", + "distanceconstraints", + "dress_code", + "food", + "greeting", + "implicit_value", + "mealtype", + "numberofpeople", + "numberofkids", + "occasion", + "other", + "personfullname", + "phonenumber", + "pricing", + "rating", + "restaurantname", + "restauranttype", + "seating", + "starttime", + "state", + "zip", + "result", + "mc_list", + "taskcomplete", + "reservation", + ], + "taxi": [ + "car_type", + "city", + "closing", + "date", + "distanceconstraints", + "dropoff_location", + "greeting", + "name", + "numberofpeople", + "other", + "pickup_location", + "dropoff_location_city", + "pickup_location_city", + "pickup_time", + "state", + "cost", + "taxi_company", + "mc_list", + "taskcomplete", + "taxi", + "zip", + "result", + "mc_list", + ], +} + +SLOT_NAMES = { + k: [{"api_name": k, "optArg": " | ".join(v)}] for k, v in SLOT_NAMES.items() +} + + +class MsrE2EParser(tod_agents.TodStructuredDataParser): + """ + Abstract data loader. + """ + + @classmethod + def add_cmdline_args( + cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None + ) -> ParlaiParser: + parser.add_argument( + "--msre2e-domains", + nargs="+", + default=DOMAINS, + choices=DOMAINS, + help="Uses last passed in configuration.", + ) + parser.add_argument( + "--use-cumulative-api-calls", + type=bool, + default=True, + help="Have API Call/API response turns only when an API response" + "slot exist. 
Accumulate all API call slots with same API call name", + ) + return super().add_cmdline_args(parser, partial_opt) + + def __init__(self, opt: Opt, shared=None): + self.fold = DatatypeHelper.fold(opt["datatype"]) + opt["datafile"] = self.fold + self.dpath = os.path.join(opt["datapath"], "msr_e2e") + if shared is None: + warn_once("MsrE2E is a beta dataset, and format may significantly change.") + build_.build(opt) + super().__init__(opt, shared) + + def _load_data(self, fold, domains): + chunks = [] + for section in domains: + domain = [] + with PathManager.open(os.path.join(self.dpath, section + "_all.tsv")) as f: + reader = csv.reader(f, delimiter="\t") + next(reader) + lines = list(reader) + episode = [] + prev_idx = 0 + for line in lines: + data = {} + data["id"] = line[0] + data["speaker"] = line[3] + data["text"] = line[4] + data["dialogue_acts"] = line[5:] + data["domain"] = section + if prev_idx != data["id"]: + domain.append(episode) + episode = [] + prev_idx = data["id"] + episode.append(data) + domain.append(episode) + chunks.append(domain) + # deterministic shuffle data for splits + return DatatypeHelper.split_subset_data_by_fold(fold, chunks, 0.8, 0.1, 0.1) + + def _parse_dialogue_act(self, act, domain): + act = ( + act.replace("inform(", "") + .replace("request(", "") + .replace("multiple_choice(", "") + ) + act = act[:-1] + + args = act.split(";") + result = {} + for arg in args: + params = arg.split("=") + key = params[0] + if ( + key == "other" + ): # This is all stuff the E2E model should be able to pick up on its own. + continue + if ( + len(params) == 1 + ): # MSR_E2E has this as a "what explicit information do we want" slot, but it's not super consistent + continue + result[key] = "=".join(params[1:]) + if len(result) > 0: + result[tod.STANDARD_API_NAME_SLOT] = domain + return result + + def _get_utterance_and_api_call_for_speaker(self, speaker, utterances, idx): + utts = [] + slots = {} + while idx < len(utterances): + here = utterances[idx] + if here["speaker"] != speaker: + break + utts.append(here["text"]) + for act in utterances[idx]["dialogue_acts"]: + if speaker == "agent" and not ( + act.startswith("inform") or act.startswith("multiple_choice") + ): + continue + if speaker == "user" and not ( + act.startswith("inform") or act.startswith("request") + ): + continue + slots.update(self._parse_dialogue_act(act, utterances[0]["domain"])) + idx += 1 + return idx, "\n".join(utts), slots + + def setup_episodes(self, fold): + """ + Parses into TodStructuredEpisode. 
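+
+        Consecutive utterances from the same speaker are merged into one turn; user
+        inform/request acts become that round's API call and agent inform/multiple_choice
+        acts become its API response. With --use-cumulative-api-calls (the default), call
+        slots are accumulated across the episode and only attached to rounds that have a
+        non-empty API response.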
+ """ + domains = self.opt.get("msre2e_domains", DOMAINS) + chunks = self._load_data(fold, domains) + domains_cnt = Counter() + episodes = [] + for utterances in chunks: + if len(utterances) < 1: + continue + domain = utterances[0]["domain"] + domains_cnt[domain] += 1 + idx = 0 + rounds = [] + goal_calls = [] + if len(utterances) > 0 and utterances[0]["speaker"] == "agent": + idx, sys_utt, api_resp = self._get_utterance_and_api_call_for_speaker( + "agent", utterances, idx + ) + r = tod.TodStructuredRound( + user_utt=tod.CONST_SILENCE, + api_resp_machine=api_resp, + sys_utt=sys_utt, + ) + rounds.append(r) + + cum_api_call = {} + while idx < len(utterances): + idx, user_utt, api_call = self._get_utterance_and_api_call_for_speaker( + "user", utterances, idx + ) + idx, sys_utt, api_resp = self._get_utterance_and_api_call_for_speaker( + "agent", utterances, idx + ) + if not self.opt["use_cumulative_api_calls"]: + r = tod.TodStructuredRound( + user_utt=user_utt, + api_call_machine=api_call, + api_resp_machine=api_resp, + sys_utt=sys_utt, + ) + else: + cum_api_call.update(api_call) + r = tod.TodStructuredRound( + user_utt=user_utt, + api_call_machine=copy.deepcopy(cum_api_call) + if len(api_resp) > 0 + else {}, + api_resp_machine=api_resp if len(api_resp) > 0 else {}, + sys_utt=sys_utt, + ) + + rounds.append(r) + if len(api_call) > 0: + goal_calls.append(api_call) + + episode = tod.TodStructuredEpisode( + domain=domain, + api_schemas_machine=SLOT_NAMES[domain], + goal_calls_machine=goal_calls, + rounds=rounds, + delex=self.opt.get("delex", False), + ) + episodes.append(episode) + return episodes + + def get_id_task_prefix(self): + return "MsrE2E" + + def _label_fold(self, chunks): + return chunks.conversation_id.apply(self._h) + + +class SystemTeacher(MsrE2EParser, tod_agents.TodSystemTeacher): + pass + + +class UserSimulatorTeacher(MsrE2EParser, tod_agents.TodUserSimulatorTeacher): + pass + +class DefaultTeacher(SystemTeacher): + pass diff --git a/parlai/tasks/msr_e2e/build.py b/parlai/tasks/msr_e2e/build.py new file mode 100644 index 00000000000..98eb99fd3ed --- /dev/null +++ b/parlai/tasks/msr_e2e/build.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import parlai.core.build_data as build_data +import os +from parlai.core.build_data import DownloadableFile + +ROOT_URL = ( + "https://raw.githubusercontent.com/xiul-msr/e2e_dialog_challenge/master/data/" +) + +RESOURCES = [ + # raw data files + DownloadableFile( + f"{ROOT_URL}/movie_all.tsv", + "movie_all.tsv", + "d2291fd898d8c2d92d7c92affa5601a0561a28f07f6147e9c196c5a573a222d6", + zipped=False, + ), + DownloadableFile( + f"{ROOT_URL}/restaurant_all.tsv", + "restaurant_all.tsv", + "0e297b2ac2e29f9771fed3cd348873b729eb079cc26f8c2333a28247671bdb28", + zipped=False, + ), + DownloadableFile( + f"{ROOT_URL}/taxi_all.tsv", + "taxi_all.tsv", + "6d8ee9719b3d294b558eb53516c897108d1276e9dbcac0101d4e19a2ad801d20", + zipped=False, + ), +] + + +def build(opt): + # get path to data directory + dpath = os.path.join(opt["datapath"], "msr_e2e") + # define version if any + version = "1.0" + + # check if data had been previously built + if not build_data.built(dpath, version_string=version): + print("[building data: " + dpath + "]") + + # make a clean directory if needed + if build_data.built(dpath): + # an older version exists, so remove these outdated files. 
+ build_data.remove_dir(dpath) + build_data.make_dir(dpath) + + # Download the data. + for downloadable_file in RESOURCES: + downloadable_file.download_file(dpath) + + # mark the data as built + build_data.mark_done(dpath, version_string=version) diff --git a/parlai/tasks/msr_e2e/test.py b/parlai/tasks/msr_e2e/test.py new file mode 100644 index 00000000000..d8d247836f8 --- /dev/null +++ b/parlai/tasks/msr_e2e/test.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from parlai.utils.testing import AutoTeacherTest + + +class TestSystemTeacher(AutoTeacherTest): + task = "msr_e2e:SystemTeacher" + + +class TestUserSimulatorTeacher(AutoTeacherTest): + task = "msr_e2e:UserSimulatorTeacher" diff --git a/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_test.yml b/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_test.yml new file mode 100644 index 00000000000..9fd946c1b75 --- /dev/null +++ b/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_test.yml @@ -0,0 +1,71 @@ +acts: +- - domain: taxi + episode_done: false + eval_labels: + - 'APIS: ' + id: MsrE2E_SystemTeacher + slots: {} + text: 'APIS: ' + type: 'APIS: ' +- - domain: taxi + episode_done: false + eval_labels: + - 'APICALL: api_name = taxi ; date = tomorrow ; dropoff_location = 2090 Woodwinds + Dr Woodbury ; greeting = Hello ; numberofpeople = 2 ; pickup_location = 525 + Portland Ave Minneapolis ; pickup_time = 9am ; state = MN' + id: MsrE2E_SystemTeacher + slots: + api_name: taxi + date: tomorrow + dropoff_location: 2090 Woodwinds Dr Woodbury + greeting: Hello + numberofpeople: '2' + pickup_location: 525 Portland Ave Minneapolis + pickup_time: 9am + state: MN + text: 'USER: Hello. Can you please find a taxi for 2 at 9 am tomorrow from 525 + Portland Ave Minneapolis, MN 55415 to .2090 Woodwinds Dr Woodbury, MN 55125. + Thanks' + type: 'APICALL: ' +- - domain: taxi + episode_done: false + eval_labels: + - 'SYSTEM: That''d be $20-27 with uberX. Would you like to proceed?' + id: MsrE2E_SystemTeacher + slots: + api_name: taxi + car_type: Uberx + cost: $20-27 + text: 'APIRESP: api_name = taxi ; car_type = Uberx ; cost = $20-27' + type: 'SYSTEM: ' +- - domain: taxi + episode_done: false + eval_labels: + - 'APICALL: api_name = taxi ; date = tomorrow ; dropoff_location = 2090 Woodwinds + Dr Woodbury ; greeting = Hello ; numberofpeople = 2 ; pickup_location = 525 + Portland Ave Minneapolis ; pickup_time = 9am ; state = MN' + id: MsrE2E_SystemTeacher + slots: + api_name: taxi + date: tomorrow + dropoff_location: 2090 Woodwinds Dr Woodbury + greeting: Hello + numberofpeople: '2' + pickup_location: 525 Portland Ave Minneapolis + pickup_time: 9am + state: MN + text: 'USER: Yes please' + type: 'APICALL: ' +- - domain: taxi + episode_done: false + eval_labels: + - 'SYSTEM: Your UberX ride was confirmed, have a safe trip.' 
+ id: MsrE2E_SystemTeacher + slots: + api_name: taxi + car_type: UberX + closing: have a safe trip + text: 'APIRESP: api_name = taxi ; car_type = UberX ; closing = have a safe trip' + type: 'SYSTEM: ' +num_episodes: 1011 +num_examples: 10705 diff --git a/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_train.yml b/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_train.yml new file mode 100644 index 00000000000..fd643a0d2cb --- /dev/null +++ b/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_train.yml @@ -0,0 +1,56 @@ +acts: +- - domain: taxi + episode_done: false + id: MsrE2E_SystemTeacher + labels: + - 'APIS: ' + slots: {} + text: 'APIS: ' + type: 'APIS: ' +- - domain: taxi + episode_done: false + id: MsrE2E_SystemTeacher + labels: + - 'APICALL: ' + slots: {} + text: 'USER: i want to eat sushi in dallas. Where should I go?' + type: 'APICALL: ' +- - domain: taxi + episode_done: false + id: MsrE2E_SystemTeacher + labels: + - 'SYSTEM: I''m here to help you book a taxi. Can I help with that?' + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +- - domain: taxi + episode_done: false + id: MsrE2E_SystemTeacher + labels: + - 'APICALL: api_name = taxi ; dropoff_location = mckinney ; numberofpeople = 1 + ; pickup_location = dallas airport ; pickup_location_city = dallas ; pickup_time + = right now ; state = texas' + slots: + api_name: taxi + dropoff_location: mckinney + numberofpeople: '1' + pickup_location: dallas airport + pickup_location_city: dallas + pickup_time: right now + state: texas + text: 'USER: yeah. Can I get a ride from the dallas airport to mckinney, texas? + Just me, right now!' + type: 'APICALL: ' +- - domain: taxi + episode_done: false + id: MsrE2E_SystemTeacher + labels: + - 'SYSTEM: Sure. That''s $37-48 for uberX. Shall I book that?' + slots: + api_name: taxi + car_type: uberX + cost: $37-48 + text: 'APIRESP: api_name = taxi ; car_type = uberX ; cost = $37-48' + type: 'SYSTEM: ' +num_episodes: 8068 +num_examples: 84524 diff --git a/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_valid.yml b/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_valid.yml new file mode 100644 index 00000000000..aec69b50c10 --- /dev/null +++ b/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_valid.yml @@ -0,0 +1,55 @@ +acts: +- - domain: movie + episode_done: false + eval_labels: + - 'APIS: ' + id: MsrE2E_SystemTeacher + slots: {} + text: 'APIS: ' + type: 'APIS: ' +- - domain: movie + episode_done: false + eval_labels: + - 'APICALL: api_name = movie ; city = Los Angeles ; date = Tuesday ; distanceconstraints + = closest ; moviename = The Brothers Grimsby ; theater = Baldwin Hills Crenshaw' + id: MsrE2E_SystemTeacher + slots: + api_name: movie + city: Los Angeles + date: Tuesday + distanceconstraints: closest + moviename: The Brothers Grimsby + theater: Baldwin Hills Crenshaw + text: 'USER: On Tuesday I''m hopefully going to see The Brothers Grimsby in Los + Angeles. I think Baldwin Hills Crenshaw is the theater closest to me' + type: 'APICALL: ' +- - domain: movie + episode_done: false + eval_labels: + - 'SYSTEM: It has a 2:30pm and 8:10pm show on Tue. Which one would you prefer?' + id: MsrE2E_SystemTeacher + slots: + api_name: movie + date: Tue + starttime: '{2:30pm#8:10pm}' + text: 'APIRESP: api_name = movie ; date = Tue ; starttime = {2:30pm#8:10pm}' + type: 'SYSTEM: ' +- - domain: movie + episode_done: false + eval_labels: + - 'APICALL: ' + id: MsrE2E_SystemTeacher + slots: {} + text: 'USER: I think the 8:10 would work best, can you book that one for me? 
I + need two tickets' + type: 'APICALL: ' +- - domain: movie + episode_done: false + eval_labels: + - 'SYSTEM: Great I can book those for you.' + id: MsrE2E_SystemTeacher + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +num_episodes: 1008 +num_examples: 10572 diff --git a/parlai/tasks/msr_e2e/test/msr_e2e_UserSimulatorTeacher_test.yml b/parlai/tasks/msr_e2e/test/msr_e2e_UserSimulatorTeacher_test.yml new file mode 100644 index 00000000000..47d38c225e3 --- /dev/null +++ b/parlai/tasks/msr_e2e/test/msr_e2e_UserSimulatorTeacher_test.yml @@ -0,0 +1,50 @@ +acts: +- - domain: taxi + episode_done: false + eval_labels: + - 'USER: Hello. Can you please find a taxi for 2 at 9 am tomorrow from 525 Portland + Ave Minneapolis, MN 55415 to .2090 Woodwinds Dr Woodbury, MN 55125. Thanks' + id: MsrE2E_UserSimulatorTeacher + text: 'GOAL: api_name = taxi ; date = tomorrow ; dropoff_location = 2090 Woodwinds + Dr Woodbury ; greeting = Hello ; numberofpeople = 2 ; pickup_location = 525 + Portland Ave Minneapolis ; pickup_time = 9am ; state = MN' + type: 'USER: ' +- - domain: taxi + episode_done: false + eval_labels: + - 'USER: Yes please' + id: MsrE2E_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: That''d be $20-27 with uberX. Would you like to proceed?' + type: 'USER: ' +- - domain: taxi + episode_done: true + eval_labels: + - 'USER: [DONE]' + id: MsrE2E_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Your UberX ride was confirmed, have a safe trip.' + type: 'USER: ' +- - domain: restaurant + episode_done: false + eval_labels: + - 'USER: I need to book a table of four at the elephant bar tonight in bakersfield + california' + id: MsrE2E_UserSimulatorTeacher + text: 'GOAL: api_name = restaurant ; city = bakersfield ; date = tonight ; numberofpeople + = 4 ; restaurantname = elephant bar ; state = california | api_name = restaurant + ; city = bakersfield ; state = california | api_name = restaurant ; numberofpeople + = 2 ; starttime = 6:30pm | api_name = restaurant ; personfullname = Donald Drumph' + type: 'USER: ' +- - domain: restaurant + episode_done: false + eval_labels: + - 'USER: are there any good places in bakersfield california?' + id: MsrE2E_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: That restaurant is not in the network, sorry. + + Can I help you with something else?' + type: 'USER: ' +num_episodes: 1011 +num_examples: 10705 diff --git a/parlai/tasks/msr_e2e/test/msr_e2e_UserSimulatorTeacher_train.yml b/parlai/tasks/msr_e2e/test/msr_e2e_UserSimulatorTeacher_train.yml new file mode 100644 index 00000000000..8046f82b47e --- /dev/null +++ b/parlai/tasks/msr_e2e/test/msr_e2e_UserSimulatorTeacher_train.yml @@ -0,0 +1,48 @@ +acts: +- - domain: taxi + episode_done: false + id: MsrE2E_UserSimulatorTeacher + labels: + - 'USER: i want to eat sushi in dallas. Where should I go?' + text: 'GOAL: api_name = taxi ; dropoff_location = in dallas | api_name = taxi + ; dropoff_location = mckinney ; numberofpeople = 1 ; pickup_location = dallas + airport ; pickup_location_city = dallas ; pickup_time = right now ; state = + texas | api_name = taxi ; name = tom thumb' + type: 'USER: ' +- - domain: taxi + episode_done: false + id: MsrE2E_UserSimulatorTeacher + labels: + - 'USER: yeah. Can I get a ride from the dallas airport to mckinney, texas? Just + me, right now!' + slots: {} + text: 'SYSTEM: I''m here to help you book a taxi. Can I help with that?' + type: 'USER: ' +- - domain: taxi + episode_done: false + id: MsrE2E_UserSimulatorTeacher + labels: + - 'USER: yes. 
my name is tom thumb' + slots: {} + text: 'SYSTEM: Sure. That''s $37-48 for uberX. Shall I book that?' + type: 'USER: ' +- - domain: taxi + episode_done: true + id: MsrE2E_UserSimulatorTeacher + labels: + - 'USER: [DONE]' + slots: {} + text: 'SYSTEM: You''re all set, Tom. Thanks for using our service!' + type: 'USER: ' +- - domain: movie + episode_done: false + id: MsrE2E_UserSimulatorTeacher + labels: + - 'USER: Hey there. I want to go watch Deadpool tomorrow. Can you help me find + tickets?' + text: 'GOAL: api_name = movie ; date = tomorrow ; moviename = Deadpool | api_name + = movie ; city = Glendale ; numberofpeople = 2 ; starttime = 12PM to 11PM ; + state = California ; theater = any theater' + type: 'USER: ' +num_episodes: 8068 +num_examples: 84524 diff --git a/parlai/tasks/msr_e2e/test/msr_e2e_UserSimulatorTeacher_valid.yml b/parlai/tasks/msr_e2e/test/msr_e2e_UserSimulatorTeacher_valid.yml new file mode 100644 index 00000000000..32ff24e5502 --- /dev/null +++ b/parlai/tasks/msr_e2e/test/msr_e2e_UserSimulatorTeacher_valid.yml @@ -0,0 +1,48 @@ +acts: +- - domain: movie + episode_done: false + eval_labels: + - 'USER: On Tuesday I''m hopefully going to see The Brothers Grimsby in Los Angeles. + I think Baldwin Hills Crenshaw is the theater closest to me' + id: MsrE2E_UserSimulatorTeacher + text: 'GOAL: api_name = movie ; city = Los Angeles ; date = Tuesday ; distanceconstraints + = closest ; moviename = The Brothers Grimsby ; theater = Baldwin Hills Crenshaw + | api_name = movie ; numberofpeople = 3 ; starttime = 8:10' + type: 'USER: ' +- - domain: movie + episode_done: false + eval_labels: + - 'USER: I think the 8:10 would work best, can you book that one for me? I need + two tickets' + id: MsrE2E_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: It has a 2:30pm and 8:10pm show on Tue. Which one would you prefer?' + type: 'USER: ' +- - domain: movie + episode_done: false + eval_labels: + - 'USER: I know you can. Next agent please report this' + id: MsrE2E_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Great I can book those for you.' + type: 'USER: ' +- - domain: movie + episode_done: true + eval_labels: + - 'USER: [DONE]' + id: MsrE2E_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: You are welcome. Enjoy the show.' + type: 'USER: ' +- - domain: taxi + episode_done: false + eval_labels: + - 'USER: Can you look for a ride from Philadelphia airport (PHL) to Independence + hall Today at 8pm?' + id: MsrE2E_UserSimulatorTeacher + text: 'GOAL: api_name = taxi ; date = Today ; dropoff_location = Independence + hall ; pickup_location = Philadelphia airport ; pickup_time = 8pm | api_name + = taxi ; numberofpeople = 4 | api_name = taxi ; car_type = UberX' + type: 'USER: ' +num_episodes: 1008 +num_examples: 10572 From 61f704112457257a3cb22604e11d7b739d3be74c Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 13:43:04 -0800 Subject: [PATCH 16/57] [TOD][Datasets][Easy] Multidogo -> TOD Conversations format Title. I only include System + UserSimulator Teachers here since that's all we need right now from dataset. 
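Both can be spot-checked the same way as the other datasets in this stack -- roughly the
following, assuming the usual ParlAI display_data entry point (the `parlai dd` run in the
test plan below is the CLI equivalent):

    from parlai.scripts.display_data import DisplayData

    # Print a few episodes from the system-side teacher; swap the task for
    # "multidogo:UserSimulatorTeacher" to see the user-simulator view instead.
    DisplayData.main(task="multidogo:SystemTeacher", datatype="valid", num_examples=10)
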
Datasets added in this substack: * Google SGD * Google SGD Simulation Splits (In-domain, Out-domain) * MetalWoz * **MSR_E2E** * Multidogo * MultiWoz V2.2 * Taskmaster * Taskmaster2 * Taskmaster3 (TicketTalk) Test plan: Regression test, `parlai dd` of dataset --- parlai/tasks/multidogo/__init__.py | 5 + parlai/tasks/multidogo/agents.py | 162 +++++++++ parlai/tasks/multidogo/build.py | 323 ++++++++++++++++++ parlai/tasks/multidogo/test.py | 15 + .../test/multidogo_SystemTeacher_test.yml | 47 +++ .../test/multidogo_SystemTeacher_train.yml | 49 +++ .../test/multidogo_SystemTeacher_valid.yml | 47 +++ .../multidogo_UserSimulatorTeacher_test.yml | 46 +++ .../multidogo_UserSimulatorTeacher_train.yml | 51 +++ .../multidogo_UserSimulatorTeacher_valid.yml | 47 +++ 10 files changed, 792 insertions(+) create mode 100644 parlai/tasks/multidogo/__init__.py create mode 100644 parlai/tasks/multidogo/agents.py create mode 100644 parlai/tasks/multidogo/build.py create mode 100644 parlai/tasks/multidogo/test.py create mode 100644 parlai/tasks/multidogo/test/multidogo_SystemTeacher_test.yml create mode 100644 parlai/tasks/multidogo/test/multidogo_SystemTeacher_train.yml create mode 100644 parlai/tasks/multidogo/test/multidogo_SystemTeacher_valid.yml create mode 100644 parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_test.yml create mode 100644 parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_train.yml create mode 100644 parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_valid.yml diff --git a/parlai/tasks/multidogo/__init__.py b/parlai/tasks/multidogo/__init__.py new file mode 100644 index 00000000000..240697e3247 --- /dev/null +++ b/parlai/tasks/multidogo/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/parlai/tasks/multidogo/agents.py b/parlai/tasks/multidogo/agents.py new file mode 100644 index 00000000000..1e92e073fb4 --- /dev/null +++ b/parlai/tasks/multidogo/agents.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +MultiDoGo implementation for ParlAI. + +NOTE: There is still missing data in the open source version of this; implementation is not complete. 
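+(Empty conversations and conversations missing role annotations are skipped outright in
+setup_episodes below.)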
See https://github.com/awslabs/multi-domain-goal-oriented-dialogues-dataset/issues/1 +""" + +from typing import Optional +from parlai.core.params import ParlaiParser +import copy +import json +import os +from parlai.core.opt import Opt +from parlai.utils.data import DatatypeHelper +import parlai.core.tod.tod_core as tod +import parlai.core.tod.tod_agents as tod_agents +import parlai.tasks.multidogo.build as build_ +from parlai.tasks.multidogo.build import get_processed_multidogo_folder +from parlai.tasks.multidogo.build import ( + DOMAINS, + SENTENCE_INTENT, + TURN_INTENT, + TURN_AND_SENTENCE_INTENT, +) + +INTENT_ANNOTATION_TYPES = [SENTENCE_INTENT, TURN_INTENT, TURN_AND_SENTENCE_INTENT] + + +class MultidogoParser(tod_agents.TodStructuredDataParser): + @classmethod + def add_cmdline_args( + cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None + ) -> ParlaiParser: + parser = super().add_cmdline_args(parser, partial_opt) + parser.add_argument( + "--multidogo-domains", + nargs="+", + default=DOMAINS, + choices=DOMAINS, + help="Uses last passed in configuration.", + ) + parser.add_argument( + "--intent-type", + default=TURN_INTENT, + choices=INTENT_ANNOTATION_TYPES, + help="Sets the type of intent classification labels. Sentence annotations represented as a list with adjacent entries of the same type deduped.", + ) + return parser + + def __init__(self, opt: Opt, shared=None): + self.fold = DatatypeHelper.fold(opt["datatype"]) + self.dpath = os.path.join(opt["datapath"], "multidogo") + opt["datafile"] = self.fold + build_.build(opt) + super().__init__(opt, shared) + + def setup_episodes(self, fold): + result = [] + domains = self.opt.get("multidogo_domains", DOMAINS) + intent_type = self.opt.get("intent-type", TURN_INTENT) + for _conv_id, domain, conversation in self._iterate_over_conversations( + domains, intent_type + ): + if len(conversation) == 0 or not ( + all(["role" in turn for turn in conversation.values()]) + ): + continue + rounds = [] + prev_role = conversation["0"]["role"] + if prev_role == "customer": + user_utt = [conversation["0"]["text"]] + api_call = conversation["0"].get("slots", {}) + api_resp = {} + sys_utt = [] + else: + user_utt = ["__SILENCE__"] + api_call = {} + api_resp = conversation["0"].get("slots", {}) + sys_utt = [conversation["0"]["text"]] + all_calls = api_call + api_call = {tod.STANDARD_API_NAME_SLOT: domain} + for i in range(1, len(conversation)): + turn = conversation[str(i)] + if prev_role == "agent" and prev_role != turn["role"]: + rounds.append( + tod.TodStructuredRound( + user_utt="\n".join(user_utt), + api_call_machine=api_call, + api_resp_machine=api_resp, + sys_utt="\n".join(sys_utt), + ) + ) + user_utt = [] + api_call = {tod.STANDARD_API_NAME_SLOT: domain} + api_resp = {} + sys_utt = [] + prev_role = turn["role"] + slot = turn.get("slots", {}) + if prev_role == "customer": + user_utt.append(turn["text"]) + api_call.update(slot) + all_calls.update(slot) + else: + api_resp.update(slot) + sys_utt.append(turn["text"]) + + rounds.append( + tod.TodStructuredRound( + user_utt=user_utt, + api_call_machine=api_call, + api_resp_machine=api_resp, + sys_utt=sys_utt, + ) + ) + goal_calls = copy.deepcopy(all_calls) + goal_calls[tod.STANDARD_API_NAME_SLOT] = domain + result.append( + tod.TodStructuredEpisode( + domain=domain, + api_schemas_machine=[ + { + tod.STANDARD_API_NAME_SLOT: domain, + tod.STANDARD_OPTIONAL_KEY: all_calls.keys(), + } + ], + goal_calls_machine=[goal_calls], + rounds=rounds, + ) + ) + return result + + def 
_iterate_over_conversations(self, domains, intent): + for domain in domains: + data_folder = get_processed_multidogo_folder( + self.dpath, domain, self.fold, intent + ) + for filename in os.listdir(data_folder): + if filename.endswith(".json"): + with open(data_folder + "/" + filename) as f: + data = json.load(f) + for conv_id, value in data.items(): + yield conv_id, domain, value + + def get_id_task_prefix(self): + return "Multidogo" + + +class SystemTeacher(MultidogoParser, tod_agents.TodSystemTeacher): + pass + + +class UserSimulatorTeacher(MultidogoParser, tod_agents.TodUserSimulatorTeacher): + pass + + +class DefaultTeacher(SystemTeacher): + pass diff --git a/parlai/tasks/multidogo/build.py b/parlai/tasks/multidogo/build.py new file mode 100644 index 00000000000..29c75298548 --- /dev/null +++ b/parlai/tasks/multidogo/build.py @@ -0,0 +1,323 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import parlai.core.build_data as build_data +from parlai.core.build_data import DownloadableFile + +import csv +from itertools import islice +from pathlib import Path +import os +import json +import re + +DEBUG_MISSING_RAW_CONVERSATIONS = False # Unnecessary once Amazon fixes multidogo + +RESOURCE = DownloadableFile( + "https://github.com/awslabs/multi-domain-goal-oriented-dialogues-dataset/archive/master.zip", + "raw_data.zip", + "fb59c7261da2d30d9d24b9af309ebb4bf0e5b39f97d718201a7160e591e76a3c", + zipped=True, +) + +RAW_DATA_PREFIX = "multi-domain-goal-oriented-dialogues-dataset-master/data/" + +RAW_DATA_ANNOTATED_DATA_PATH = "paper_splits" +RAW_DATA_UNANNOTATED_DATA_PATH = "unannotated" + +TURN_INTENT = "turn" +SENTENCE_INTENT = "sentence" +TURN_AND_SENTENCE_INTENT = "both" + +RAW_DATA_SENTENCE_INTENT_PATH = "splits_annotated_at_sentence_level" +RAW_DATA_TURN_INTENT_PATH = "splits_annotated_at_turn_level" + +RAW_DATA_INTENT_BY_TYPE_PATH = { + TURN_INTENT: RAW_DATA_TURN_INTENT_PATH, + SENTENCE_INTENT: RAW_DATA_SENTENCE_INTENT_PATH, +} + +DOMAINS = ["airline", "fastfood", "finance", "insurance", "media", "software"] + +DATATYPE_TO_RAW_DATA_FILE_NAME = { + "test": "test.tsv", + "train": "train.tsv", + "valid": "dev.tsv", +} + +PROCESSED = "processed/" + + +def _preprocess(opt, datapath, datatype): + """ + MultiDoGo conversations take place between an "agent" and a customer". Labeled + customer data is stored in one set of files while the agent data is in another. + There is a common conversation ID between the two, but the conversations are not + listed in a consistent way between the documents. Since we'll have to do work to + associate the data between the files anyway, we might as well process the data into + a new file that'll be easier to deal with. + + Stores the data as /processed//.txt. + Will skip preprocessing if this file already exists. + """ + domains = opt.get("domains", DOMAINS) + intent_type = opt.get("intent_type", TURN_INTENT) + + for domain in domains: + # to see which domain/datatype combo we've built, use a dummy file to mark + built_file = _get_processed_multidogo_built_file( + datapath, domain, datatype, intent_type + ) + if os.path.isfile(built_file): + continue + print( + f" Preprocessing '{domain}' data for '{datatype}' with '{intent_type}' intent labels." 
+ ) + + out_dir = get_processed_multidogo_folder( + datapath, domain, datatype, intent_type + ) + Path(out_dir).mkdir(parents=True, exist_ok=True) + + # The agent responses for *all* datatypes are in one file. + # We need to iterate through the datatype file to know which lines + # we'll actually need... so build a quick lookup table to know which + # lines in the tsv file we'll need to care about so we're not scanning + # through the whole thing a bunch + unannotated_id_map = _build_conversation_span_map( + _get_unannotated_tsv_data(datapath, domain) + ) + + # Actually do the work of collating all of the conversations + annotations + # For turn + sentence intent labels, we do two passes, one for sentence + # then one for turn so that we do not add two sets of labels for the + # same conversation ID. We can use this forced structure to do the + # separate categories of turn intent and sentence intent labels. We + # also do a bit of chuking + file_idx = 0 + seen_conversations_set = set() + if intent_type == TURN_AND_SENTENCE_INTENT or intent_type == SENTENCE_INTENT: + file_idx, seen_conversations_set = _aggregate_and_write_conversations( + intent_type, + SENTENCE_INTENT, + datapath, + domain, + datatype, + unannotated_id_map, + start_file_idx=file_idx, + skip_ids=set(), + ) + + if intent_type == TURN_AND_SENTENCE_INTENT or intent_type == TURN_INTENT: + _, _ = _aggregate_and_write_conversations( + intent_type, + TURN_INTENT, + datapath, + domain, + datatype, + unannotated_id_map, + start_file_idx=file_idx, + skip_ids=seen_conversations_set, + ) + + # mark that we've built this combinations + open(built_file, "a").close() + + +def get_processed_multidogo_folder(datapath, domain, datatype, intent_type): + return os.path.join(datapath, PROCESSED, domain, intent_type, datatype) + + +def _get_processed_multidogo_built_file(datapath, domain, datatype, intent_type): + return os.path.join( + get_processed_multidogo_folder(datapath, domain, datatype, intent_type), + ".build", + ) + + +# unannotated data is UNANNOTATED_DATA_PROFIX + + '.tsv' +# annotated data is ANNOTATED_DATA_PATH + + + '/' + + '.tsv' +def _get_unannotated_tsv_data(datapath, domain): + file_name = os.path.join( + datapath, RAW_DATA_PREFIX, RAW_DATA_UNANNOTATED_DATA_PATH, domain + ".tsv" + ) + return csv.reader(open(file_name, "r"), delimiter=",") # comma-separated tsv, lol + + +def _get_annotated_tsv_data(datapath, domain, datatype, annotation_type): + file_name = os.path.join( + datapath, + RAW_DATA_PREFIX, + RAW_DATA_ANNOTATED_DATA_PATH, + RAW_DATA_INTENT_BY_TYPE_PATH[annotation_type], + domain, + DATATYPE_TO_RAW_DATA_FILE_NAME[datatype], + ) + return csv.reader(open(file_name, "r"), delimiter="\t") + + +def _build_conversation_span_map(unannotated_tsv_object): + result = {} # conversationId to (start line, length) map + start = 0 + prev_conversation_id = "" + length = 0 + for i, row in enumerate(unannotated_tsv_object): + conversation_id = row[0][ + 4:-2 + ] # do substring cause conversationId has extra filler in unannotated + if conversation_id != prev_conversation_id: + result[prev_conversation_id] = (start, length) + start = i + prev_conversation_id = conversation_id + length = 0 + length += 1 + result[conversation_id] = (start, length) + return result + + +def _get_slots_map(utterance, slot_string): + values = slot_string.split(" ") + cleaned = re.sub(r"[^\w\s]", "", utterance) + words = cleaned.split(" ") + result = {} + for i in range(len(words)): + if values[i] != "O": + result[values[i]] = words[i] + return result + + +def 
_aggregate_and_write_conversations( + raw_intent_type, + fetch_intent_type, + datapath, + domain, + datatype, + unannotated_id_map, + skip_ids, + start_file_idx=0, +): + conversations_to_write = {} # conversationId -> list of turns + seen_conversations = set() + out_dir = get_processed_multidogo_folder( + datapath, domain, datatype, raw_intent_type + ) + file_idx = start_file_idx + intent_tsv = _get_annotated_tsv_data(datapath, domain, datatype, fetch_intent_type) + next(intent_tsv) # don't need the header in the first line + for labeled_line in intent_tsv: + conversation_id = labeled_line[0] + if conversation_id in skip_ids: + continue + if conversation_id not in seen_conversations: + # new conversation, add text of conversation to conversations_to_write + conversations_to_write[conversation_id] = {} + found_raw_conversation = _add_utterances( + unannotated_id_map, + conversation_id, + conversations_to_write, + datapath, + domain, + ) + seen_conversations.add(conversation_id) + if not found_raw_conversation: + if DEBUG_MISSING_RAW_CONVERSATIONS: + print(f"Could not find raw conversations for {conversation_id}") + skip_ids.add(conversation_id) + conversations_to_write.pop(conversation_id, None) + continue + if fetch_intent_type == SENTENCE_INTENT: + _get_sentence_labels_and_slots_map(labeled_line, conversations_to_write) + elif fetch_intent_type == TURN_INTENT: + _get_turn_labels_and_slots_map(labeled_line, conversations_to_write) + else: + raise KeyError( + "Invalid `fetch_intent_type`. This case should never be hit. Something is broken in the `build.py` file." + ) + # Don't forget to dump out last file + with open(f"{out_dir}/{file_idx}.json", "w+") as out_file: + json.dump(conversations_to_write, out_file, indent=4) + file_idx += 1 + # Return necessary outputs for next pass + return file_idx, seen_conversations + + +def _add_utterances( + unannotated_id_map, conversation_id, conversations_to_write, datapath, domain +): + try: + start, length = unannotated_id_map[conversation_id] + except KeyError: + return False + conversation_text = islice( + _get_unannotated_tsv_data(datapath, domain), start, start + length + ) + + for line in conversation_text: + # Format of unannotated: conversationId,turnNumber,utteranceId,utterance,authorRole + conversations_to_write[conversation_id] = { + **conversations_to_write[conversation_id], + int(line[1]): {"text": line[3], "role": line[4]}, + } + return True + + +def _get_sentence_labels_and_slots_map(labeled_line, output): + # Sentence tsv format: conversationId turnNumber sentenceNumber utteranceId utterance slot-labels intent + conversation_id = labeled_line[0] + turn_number = int(float(labeled_line[1])) # cause a few got saved as float. 
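+    # Sentence-level annotation rows repeat the same turn number: each row's slots
+    # overwrite the turn's `slots` entry, while its intent is appended to the turn's
+    # `intents` list.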
+ if conversation_id not in output: + raise RuntimeError("Should never happen; raw conversation text should be here") + if turn_number not in output[conversation_id]: + output[conversation_id][turn_number] = {} + output[conversation_id][turn_number] = { + **output[conversation_id][turn_number], + "slots": _get_slots_map(labeled_line[4], labeled_line[5]), + } + if "intents" not in output[conversation_id][turn_number]: + output[conversation_id][turn_number]["intents"] = [] + output[conversation_id][turn_number]["intents"].append(labeled_line[6]) + + +def _get_turn_labels_and_slots_map(labeled_line, output): + # Turn tsv format: conversationId turnNumber utteranceId utterance slot-labels intent + conversation_id = labeled_line[0] + turn_number = int(float(labeled_line[1])) # cause a few got saved as float + if conversation_id not in output: + raise RuntimeError("Should never happen; raw conversation text should be here") + if turn_number not in output[conversation_id]: + output[conversation_id][turn_number] = {} + output[conversation_id][turn_number] = { + **output[conversation_id][turn_number], + "slots": _get_slots_map(labeled_line[3], labeled_line[4]), + "intents": [labeled_line[5]], + } + + +def build(opt): + # get path to data directory + datapath = os.path.join(opt["datapath"], "multidogo") + # define version if any + version = "v1.1" + + # check if data had been previously downloaded + if not build_data.built(datapath, version_string=version): + print("[building data: " + datapath + "]") + + # make a clean directory if needed + if build_data.built(datapath): + # an older version exists, so remove these outdated files. + build_data.remove_dir(datapath) + build_data.make_dir(datapath) + + # Download the data. + RESOURCE.download_file(datapath) + + # mark the data as built + build_data.mark_done(datapath, version_string=version) + + # do preprocessing on the data to put it into FBDialogueData format + for fold in ["train", "valid", "test"]: + _preprocess(opt, datapath, fold) diff --git a/parlai/tasks/multidogo/test.py b/parlai/tasks/multidogo/test.py new file mode 100644 index 00000000000..9f3889a2263 --- /dev/null +++ b/parlai/tasks/multidogo/test.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from parlai.utils.testing import AutoTeacherTest + + +class TestSystemTeacher(AutoTeacherTest): + task = "multidogo:SystemTeacher" + + +class TestUserSimulatorTeacher(AutoTeacherTest): + task = "multidogo:UserSimulatorTeacher" diff --git a/parlai/tasks/multidogo/test/multidogo_SystemTeacher_test.yml b/parlai/tasks/multidogo/test/multidogo_SystemTeacher_test.yml new file mode 100644 index 00000000000..ffe539d83ad --- /dev/null +++ b/parlai/tasks/multidogo/test/multidogo_SystemTeacher_test.yml @@ -0,0 +1,47 @@ +acts: +- - domain: airline + episode_done: false + eval_labels: + - 'APIS: ' + id: Multidogo_SystemTeacher + slots: {} + text: 'APIS: ' + type: 'APIS: ' +- - domain: airline + episode_done: false + eval_labels: + - 'APICALL: api_name = airline' + id: Multidogo_SystemTeacher + slots: + api_name: airline + text: 'USER: HELLO ROBIN' + type: 'APICALL: ' +- - domain: airline + episode_done: false + eval_labels: + - 'SYSTEM: Hello! Good morning. You''ve reached LMT Airways. How may I assist + you today?' 
+ id: Multidogo_SystemTeacher + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +- - domain: airline + episode_done: false + eval_labels: + - 'APICALL: api_name = airline' + id: Multidogo_SystemTeacher + slots: + api_name: airline + text: 'USER: I NEED BOARDING PASS ' + type: 'APICALL: ' +- - domain: airline + episode_done: false + eval_labels: + - 'SYSTEM: Awesome! I''d be glad to help you with that. May I know your last name + please?' + id: Multidogo_SystemTeacher + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +num_episodes: 2316 +num_examples: 43104 diff --git a/parlai/tasks/multidogo/test/multidogo_SystemTeacher_train.yml b/parlai/tasks/multidogo/test/multidogo_SystemTeacher_train.yml new file mode 100644 index 00000000000..c86114ed7fa --- /dev/null +++ b/parlai/tasks/multidogo/test/multidogo_SystemTeacher_train.yml @@ -0,0 +1,49 @@ +acts: +- - domain: airline + episode_done: false + id: Multidogo_SystemTeacher + labels: + - 'APIS: ' + slots: {} + text: 'APIS: ' + type: 'APIS: ' +- - domain: airline + episode_done: false + id: Multidogo_SystemTeacher + labels: + - 'APICALL: api_name = airline' + slots: + api_name: airline + text: 'USER: __SILENCE__' + type: 'APICALL: ' +- - domain: airline + episode_done: false + id: Multidogo_SystemTeacher + labels: + - 'SYSTEM: Welcome to High flying customer service! You''re connected to our customer + associate! Good morning! My name is Sam, How may I help you?' + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +- - domain: airline + episode_done: false + id: Multidogo_SystemTeacher + labels: + - 'APICALL: api_name = airline' + slots: + api_name: airline + text: 'USER: HI,GOOD MORNING + + I WANTS TO BOOK A TICKET FOR FLIGHT' + type: 'APICALL: ' +- - domain: airline + episode_done: false + id: Multidogo_SystemTeacher + labels: + - 'SYSTEM: Absolutely!! I''d be happy to book your tickets, I''d request to please + share your details with me. May I know your full name please?' + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +num_episodes: 15616 +num_examples: 290050 diff --git a/parlai/tasks/multidogo/test/multidogo_SystemTeacher_valid.yml b/parlai/tasks/multidogo/test/multidogo_SystemTeacher_valid.yml new file mode 100644 index 00000000000..c1505d9821d --- /dev/null +++ b/parlai/tasks/multidogo/test/multidogo_SystemTeacher_valid.yml @@ -0,0 +1,47 @@ +acts: +- - domain: airline + episode_done: false + eval_labels: + - 'APIS: ' + id: Multidogo_SystemTeacher + slots: {} + text: 'APIS: ' + type: 'APIS: ' +- - domain: airline + episode_done: false + eval_labels: + - 'APICALL: api_name = airline' + id: Multidogo_SystemTeacher + slots: + api_name: airline + text: 'USER: HI GOOD MORNING' + type: 'APICALL: ' +- - domain: airline + episode_done: false + eval_labels: + - 'SYSTEM: Welcome to High flying customer service! You''re connected to our customer + associate! Good morning! My name is Sam, How may I help you?' + id: Multidogo_SystemTeacher + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +- - domain: airline + episode_done: false + eval_labels: + - 'APICALL: api_name = airline' + id: Multidogo_SystemTeacher + slots: + api_name: airline + text: 'USER: I WANT TO BOOK A TICKET IN FLIGHT' + type: 'APICALL: ' +- - domain: airline + episode_done: false + eval_labels: + - 'SYSTEM: Absolutely!! I''d be happy to book your tickets, I''d request to please + share your details with me. May I know your full name please?' 
+ id: Multidogo_SystemTeacher + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +num_episodes: 1590 +num_examples: 29662 diff --git a/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_test.yml b/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_test.yml new file mode 100644 index 00000000000..8c408905467 --- /dev/null +++ b/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_test.yml @@ -0,0 +1,46 @@ +acts: +- - domain: airline + episode_done: false + eval_labels: + - 'USER: HELLO ROBIN' + id: Multidogo_UserSimulatorTeacher + text: 'GOAL: api_name = airline ; booking_confirmation_number = 523 ; email_address + = gmailcom ; name = mohan' + type: 'USER: ' +- - domain: airline + episode_done: false + eval_labels: + - 'USER: I NEED BOARDING PASS ' + id: Multidogo_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Hello! Good morning. You''ve reached LMT Airways. How may I assist + you today?' + type: 'USER: ' +- - domain: airline + episode_done: false + eval_labels: + - 'USER: MOHAN' + id: Multidogo_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Awesome! I''d be glad to help you with that. May I know your last + name please?' + type: 'USER: ' +- - domain: airline + episode_done: false + eval_labels: + - 'USER: CONFIRMATION NUMBER : moh523' + id: Multidogo_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Alright Mohan! Could you please share the booking confirmation + number?' + type: 'USER: ' +- - domain: airline + episode_done: false + eval_labels: + - 'USER: Mohan283@gmail.com' + id: Multidogo_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Great! May I have your email address please?' + type: 'USER: ' +num_episodes: 2316 +num_examples: 43104 diff --git a/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_train.yml b/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_train.yml new file mode 100644 index 00000000000..f42028106f2 --- /dev/null +++ b/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_train.yml @@ -0,0 +1,51 @@ +acts: +- - domain: airline + episode_done: false + id: Multidogo_UserSimulatorTeacher + labels: + - 'USER: __SILENCE__' + text: 'GOAL: api_name = airline ; arrival_city = singapore ; departure_city = + thailand ; email_address = kavigmailcom ; name = kavisri ; number_of_passengers + = five' + type: 'USER: ' +- - domain: airline + episode_done: false + id: Multidogo_UserSimulatorTeacher + labels: + - 'USER: HI,GOOD MORNING + + I WANTS TO BOOK A TICKET FOR FLIGHT' + slots: {} + text: 'SYSTEM: Welcome to High flying customer service! You''re connected to our + customer associate! Good morning! My name is Sam, How may I help you?' + type: 'USER: ' +- - domain: airline + episode_done: false + id: Multidogo_UserSimulatorTeacher + labels: + - 'USER: KAVISRI' + slots: {} + text: 'SYSTEM: Absolutely!! I''d be happy to book your tickets, I''d request to + please share your details with me. May I know your full name please?' + type: 'USER: ' +- - domain: airline + episode_done: false + id: Multidogo_UserSimulatorTeacher + labels: + - 'USER: THAILAND AND SINGAPORE' + slots: {} + text: 'SYSTEM: It''s nice meeting you kavisri! Could you please share your departure + and arrival city?' + type: 'USER: ' +- - domain: airline + episode_done: false + id: Multidogo_UserSimulatorTeacher + labels: + - 'USER: OK' + slots: {} + text: 'SYSTEM: Perfect! I hope you enjoy the trip! 
As I''ve checked with my system + here, there is one flight of Jet airways operating on 09/20/2018, The timings + are, 6:00 Am to 8:00 Am and it is costing you $170 per head. ' + type: 'USER: ' +num_episodes: 15616 +num_examples: 290050 diff --git a/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_valid.yml b/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_valid.yml new file mode 100644 index 00000000000..410ff0590c3 --- /dev/null +++ b/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_valid.yml @@ -0,0 +1,47 @@ +acts: +- - domain: airline + episode_done: false + eval_labels: + - 'USER: HI GOOD MORNING' + id: Multidogo_UserSimulatorTeacher + text: 'GOAL: api_name = airline ; arrival_city = chennai ; departure_city = mumbai + ; email_address = gmailcom ; name = viswa ; start_date = 92018' + type: 'USER: ' +- - domain: airline + episode_done: false + eval_labels: + - 'USER: I WANT TO BOOK A TICKET IN FLIGHT' + id: Multidogo_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Welcome to High flying customer service! You''re connected to our + customer associate! Good morning! My name is Sam, How may I help you?' + type: 'USER: ' +- - domain: airline + episode_done: false + eval_labels: + - 'USER: VISWA ' + id: Multidogo_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Absolutely!! I''d be happy to book your tickets, I''d request to + please share your details with me. May I know your full name please?' + type: 'USER: ' +- - domain: airline + episode_done: false + eval_labels: + - 'USER: MUMBAI TO CHENNAI' + id: Multidogo_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Great! It''s nice meeting you Viswa! Could you please share your + departure and arrival city?' + type: 'USER: ' +- - domain: airline + episode_done: false + eval_labels: + - 'USER: 09/20/2018' + id: Multidogo_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: That''s amazing! I recently visited chennai, It''s such a beautiful + place! I hope you enjoy the trip! May I know your preferred date please?' + type: 'USER: ' +num_episodes: 1590 +num_examples: 29662 From e8efc52f166b16f2e076594d0cd52de42b6033c3 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 13:54:03 -0800 Subject: [PATCH 17/57] [TOD][Datasets][Easyish] MultiWoz V2.2 in Conversations Format Title. I only include System + UserSimulator Teachers here since that's all we need right now from dataset. There are so many versions of MultiWoz, but this one is closest to our simulator. 
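As a quick eyeball check (beyond the regression tests in the test plan below), the new teachers can be dumped with the standard display-data entry point; a sketch, assuming the task registers under `multiwoz_v22` and the usual ParlAI script wrapper:

    # CLI: parlai display_data --task multiwoz_v22 --datatype valid
    from parlai.scripts.display_data import DisplayData

    DisplayData.main(task="multiwoz_v22", datatype="valid")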
--------------------------------- Datasets added in this substack: * Google SGD * Google SGD Simulation Splits (In-domain, Out-domain) * MetalWoz * **MSR_E2E** * Multidogo * MultiWoz V2.2 * Taskmaster * Taskmaster2 * Taskmaster3 (TicketTalk) Test plan: Regression test, `parlai dd` of dataset --- parlai/tasks/multiwoz_v22/README.md | 12 + parlai/tasks/multiwoz_v22/__init__.py | 5 + parlai/tasks/multiwoz_v22/agents.py | 370 ++++++++++++++++++ parlai/tasks/multiwoz_v22/build.py | 226 +++++++++++ .../multiwoz_v22/build_sha_check_script.py | 81 ++++ parlai/tasks/multiwoz_v22/test.py | 15 + ...multiwoz_v22_UserSimulatorTeacher_test.yml | 50 +++ ...ultiwoz_v22_UserSimulatorTeacher_train.yml | 51 +++ ...ultiwoz_v22_UserSimulatorTeacher_valid.yml | 50 +++ .../multiwoz_v22/test/multiwoz_v22_test.yml | 66 ++++ .../multiwoz_v22/test/multiwoz_v22_train.yml | 73 ++++ .../multiwoz_v22/test/multiwoz_v22_valid.yml | 59 +++ 12 files changed, 1058 insertions(+) create mode 100644 parlai/tasks/multiwoz_v22/README.md create mode 100644 parlai/tasks/multiwoz_v22/__init__.py create mode 100644 parlai/tasks/multiwoz_v22/agents.py create mode 100644 parlai/tasks/multiwoz_v22/build.py create mode 100644 parlai/tasks/multiwoz_v22/build_sha_check_script.py create mode 100644 parlai/tasks/multiwoz_v22/test.py create mode 100644 parlai/tasks/multiwoz_v22/test/multiwoz_v22_UserSimulatorTeacher_test.yml create mode 100644 parlai/tasks/multiwoz_v22/test/multiwoz_v22_UserSimulatorTeacher_train.yml create mode 100644 parlai/tasks/multiwoz_v22/test/multiwoz_v22_UserSimulatorTeacher_valid.yml create mode 100644 parlai/tasks/multiwoz_v22/test/multiwoz_v22_test.yml create mode 100644 parlai/tasks/multiwoz_v22/test/multiwoz_v22_train.yml create mode 100644 parlai/tasks/multiwoz_v22/test/multiwoz_v22_valid.yml diff --git a/parlai/tasks/multiwoz_v22/README.md b/parlai/tasks/multiwoz_v22/README.md new file mode 100644 index 00000000000..642f4c2aef6 --- /dev/null +++ b/parlai/tasks/multiwoz_v22/README.md @@ -0,0 +1,12 @@ +Task: Multiwoz v2.2 +=============== +Description: Version of Multiwoz 2.0 dataset as cleaned by Google and structured to the Schema-Guided Dataset format. + +From https://github.com/budzianowski/multiwoz/tree/master/data/MultiWOZ\_2.2 + +License: MIT + +Link: https://aclanthology.org/2020.nlp4convai-1.13.pdf + +Tags: #TOD #Multiwoz + diff --git a/parlai/tasks/multiwoz_v22/__init__.py b/parlai/tasks/multiwoz_v22/__init__.py new file mode 100644 index 00000000000..240697e3247 --- /dev/null +++ b/parlai/tasks/multiwoz_v22/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/parlai/tasks/multiwoz_v22/agents.py b/parlai/tasks/multiwoz_v22/agents.py new file mode 100644 index 00000000000..e59942d1f3a --- /dev/null +++ b/parlai/tasks/multiwoz_v22/agents.py @@ -0,0 +1,370 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +implementation for ParlAI. 
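+Multiwoz V2.2, in the TOD conversations format. Only System and UserSimulator
+teachers are included; see the README in this directory for dataset details.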
+""" + +from parlai.core.params import ParlaiParser +import copy +import os +import pandas as pd +from parlai.core.opt import Opt +import parlai.core.tod.tod_core as tod +import json +from typing import Optional +from parlai.utils.data import DatatypeHelper +from parlai.utils.io import PathManager + +import parlai.tasks.multiwoz_v22.build as build_ +import parlai.core.tod.tod_agents as tod_agents + + +DOMAINS = [ + "attraction", + "bus", + "hospital", + "hotel", + "police", + "restaurant", + "taxi", + "train", +] + +WELL_FORMATTED_DOMAINS = [ + "attraction", + "bus", + "hotel", + "restaurant", + "train", + "taxi", +] + + +class MultiwozV22Parser(tod_agents.TodStructuredDataParser): + """ + Abstract data loader for Multiwoz V2.2 into TOD structured data format. + + Multiwoz 2.2 has 'find' and 'book' as the only intents. + + For API calls, we look for turns that are not 'NONE' `active_intents` in the USER's turn state. We then filter these for whether or not the SYSTSEM has actually made an api call by looking in the dialogue act of the SYSTEM turn. + * For 'find' intents, we make an API call if it does an "Inform" or gives a "NoOffer". We look in the corresponding `.db` file to return the relevant information. + * For 'book' intents, we make an API call if the SYSTEM's dialogue act includes booking and then offer the slots/values of that key as the API response. + """ + + @classmethod + def add_cmdline_args( + cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None + ) -> ParlaiParser: + group = parser.add_argument_group("Multiwoz args") + group.add_argument( + "--well-formatted-domains-only", + type=bool, + default=True, + help="Some of the domains in Multiwoz are not super well formatted. Use only the well formatted ones.", + ) + group.add_argument( + "--dialogue-id", + type=str, + default="", + help="If non-empty, filters for a particular dialogue id", + ) + return super().add_cmdline_args(parser, partial_opt) + + def __init__(self, opt: Opt, shared=None): + self.fold = DatatypeHelper.fold(opt["datatype"]) + opt["datafile"] = self.fold + self.dpath = os.path.join(opt["datapath"], "multiwoz_v22") + build_.build(opt) + self.last_call = {} + super().__init__(opt, shared) + + def load_schemas(self): + with PathManager.open(os.path.join(self.dpath, "schema.json")) as f: + raw = json.load(f) + result = {} + for service in raw: + domain = service["service_name"] + prefix_end_idx = len(domain) + 1 + all_slots = set([x["name"][prefix_end_idx:] for x in service["slots"]]) + + for intent in service["intents"]: + call_name = intent["name"] + result[call_name] = { + tod.STANDARD_API_NAME_SLOT: call_name, + } + req_slots = set([x[prefix_end_idx:] for x in intent["required_slots"]]) + if len(req_slots) > 0: + result[call_name][tod.STANDARD_REQUIRED_KEY] = list(req_slots) + + # Not fully trusting the original schema data... + optional_slots = set( + [x[prefix_end_idx:] for x in intent["optional_slots"].keys()] + ) + optional_slots = optional_slots | all_slots + optional_slots = optional_slots - req_slots + if len(optional_slots) > 0: + result[call_name][tod.STANDARD_OPTIONAL_KEY] = list(optional_slots) + + if domain == "police": # Multiwoz 2.2 only lists "police" + result["find_police"] = { + tod.STANDARD_OPTIONAL_KEY: list(all_slots), + tod.STANDARD_API_NAME_SLOT: "find_police", + } + if ( + domain == "taxi" + ): # Multiwoz 2.2 has "book taxi" in the schema but it's "find taxi" in the data... 
+ result["find_taxi"] = copy.deepcopy(result["book_taxi"]) + result["find_taxi"][tod.STANDARD_API_NAME_SLOT] = "find_taxi" + return result + + def load_dbs(self): + dbs = {} + for key in DOMAINS: + if ( + key == "hospital" + ): # has funky extra format, so we're gonna deal with it manually. + with PathManager.open( + os.path.join(self.dpath, "db", key + "_db.json") + ) as f: + file_lines = f.readlines() + hospital_address_lines = file_lines[1:4] + partial = [ + x.replace("#", "").strip().lower().split(":") + for x in hospital_address_lines + ] + self.hospital_address = {x[0]: x[1] for x in partial} + self.hospital_department_details = json.loads("".join(file_lines[6:])) + continue + if ( + key == "taxi" + ): # Taxi domain is funky and the db for it is just response slot options. + continue + with PathManager.open( + os.path.join(self.dpath, "db", key + "_db.json") + ) as f: + blob = json.load(f) + for i, entry in enumerate(blob): + cased = {} + for slot_name in entry: + cased[slot_name.lower().replace(" ", "")] = entry[slot_name] + blob[i] = cased + dbs[key] = pd.DataFrame.from_dict(blob) + + return dbs + + def load_chunks(self, fold): + if fold == "valid": + fold = "dev" # change name to match file structure + for path in PathManager.ls(os.path.join(self.dpath, fold)): + with PathManager.open(os.path.join(self.dpath, fold, path)) as f: + blob = json.load(f) + for convo in blob: + yield convo + + def _get_find_api_response(self, intent, raw_slots, sys_dialog_act): + """ + Get an API response out of the lookup databases. + """ + domain = "" + for cand in DOMAINS: + if cand in intent: + domain = cand + if domain == "taxi": # handle separately cause funky + for action in sys_dialog_act: + if action == "Taxi-Inform": + return {x[0]: x[1] for x in sys_dialog_act[action]} + return {domain: domain} # too much work to do this right... + if domain == "hospital": # handle separately cause funky + res = self.hospital_address + if "hospital-department" in raw_slots: + for blob in self.hospital_department_details: + if blob["department"] in raw_slots["hospital-department"]: + res[blob["department"]] = blob + return res + slots = {} + for raw_key in raw_slots: + key = raw_key[len(domain + "-") :] + slots[key] = raw_slots[raw_key] + for action in sys_dialog_act: + if "Recommend" in action: + add_slots = {} + for x in sys_dialog_act[action]: + name = x[0] + val = x[1] + if self._slot_in_schema(name, intent): + if name not in add_slots: + add_slots[name] = [] + add_slots[name].append(val) + for key in add_slots: + slots[key] = add_slots[key] + + find = self.dbs[domain] + for slot, values in slots.items(): + if slot == "arriveby": + condition = find[slot] < values[0] + elif slot == "leaveat": + condition = find[slot] > values[0] + else: + condition = find[slot].isin(values) + + find = find[condition] + + filtered = self.dbs[domain].iloc[find.index] + count = len(filtered.index) + if count == 0: + return {} + blob = filtered.head(1).to_dict('records') + + results = {} + results["COUNT"] = count + results["OPTIONS"] = json.dumps(blob) + return results + + def _slot_in_schema(self, slot, intent): + return slot in self.schemas[intent].get( + tod.STANDARD_OPTIONAL_KEY, [] + ) or slot in self.schemas[intent].get(tod.STANDARD_REQUIRED_KEY, []) + + def _get_round(self, dialogue_id, raw_episode, turn_id): + """ + Parse to TodStructuredRound. + + Assume User turn first. 
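+
+        Returns a tuple of (api_call_machine, TodStructuredRound); non-empty call
+        dicts are also collected by `setup_episodes` as the episode's goal calls.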
+ """ + user_turn = raw_episode[turn_id] + if user_turn["speaker"] != "USER": + raise RuntimeError( + f"Got non-user turn when it should have been in {dialogue_id}; turn id {turn_id}" + ) + sys_turn = raw_episode[turn_id + 1] + sys_dialog_act = self.dialog_acts[dialogue_id][str(turn_id + 1)]["dialog_act"] + if sys_turn["speaker"] != "SYSTEM": + raise RuntimeError( + f"Got non-system turn when it should have been in {dialogue_id}; turn id {turn_id}" + ) + frames = user_turn.get("frames", []) + call = {} + resp = {} + for frame in frames: + if frame.get("state", {}).get("active_intent", "NONE") != "NONE": + intent = frame["state"]["active_intent"] + domain = frame["service"] + maybe_call_raw = copy.deepcopy(frame["state"]["slot_values"]) + maybe_call = {} + truncate_length = len(domain) + 1 + for key in maybe_call_raw: + maybe_call[key[truncate_length:]] = maybe_call_raw[key][0] + maybe_call[tod.STANDARD_API_NAME_SLOT] = intent + if "find" in intent: + for key in sys_dialog_act: + if "Inform" in key or "NoOffer" in key: + # Gotta check to make sure if it's inform, that it's about the right topic + if "Inform" in key: + valid = True + slots = [x[0] for x in sys_dialog_act[key]] + for slot in slots: + valid &= self._slot_in_schema(slot, intent) | ( + slot == "choice" + ) + if not valid: + continue + call = maybe_call + resp = self._get_find_api_response( + intent, frame["state"]["slot_values"], sys_dialog_act + ) + elif "book" in intent: + for key in sys_dialog_act: + if "Book" in key: # and "Inform" not in key: + resp = {x[0]: x[1] for x in sys_dialog_act[key]} + call = maybe_call + if call == self.last_call: + call = {} + resp = {} + if len(call) > 0: + self.last_call = call + return call, tod.TodStructuredRound( + user_utt=user_turn["utterance"], + api_call_machine=call, + api_resp_machine=resp, + sys_utt=sys_turn["utterance"], + ) + + def _get_schemas_for_goal_calls(self, goals): + result = [] + seen = set() + for goal in goals: + call_name = goal[tod.STANDARD_API_NAME_SLOT] + if call_name not in seen: + result.append(self.schemas[call_name]) + seen.add(call_name) + return result + + def setup_episodes(self, fold): + """ + Parses into TodStructuredEpisode. + """ + self.dbs = self.load_dbs() + self.schemas = self.load_schemas() + with PathManager.open(os.path.join(self.dpath, "dialog_acts.json")) as f: + self.dialog_acts = json.load(f) + + chunks = self.load_chunks(fold) + + episodes = [] + for raw_episode in chunks: + domains = raw_episode["services"] + + if self.opt.get("dialogue_id", "") != "": + if raw_episode["dialogue_id"] != self.opt["dialogue_id"]: + continue + + skip = False # need to skip outer for loop while in `for domains` inner for loop + if self.opt.get("well_formatted_domains_only", True): + if len(domains) == 0: + skip = True + for domain in domains: + if domain not in WELL_FORMATTED_DOMAINS: + skip = True + if skip: + continue + + turn_id = 0 # matching naming in the `dialogues` files. 
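+            # Turns alternate USER then SYSTEM, so `_get_round` consumes them in
+            # pairs (hence `turn_id += 2` below).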
+ turns = raw_episode["turns"] + rounds = [] + goal_calls = [] + + while turn_id < len(turns): + goal, r = self._get_round(raw_episode['dialogue_id'], turns, turn_id) + turn_id += 2 + rounds.append(r) + + if len(goal) > 0: + goal_calls.append(goal) + + episode = tod.TodStructuredEpisode( + domain=tod.SerializationHelpers.inner_list_join(domains), + api_schemas_machine=self._get_schemas_for_goal_calls(goal_calls), + goal_calls_machine=goal_calls, + rounds=rounds, + ) + episodes.append(episode) + return episodes + + def get_id_task_prefix(self): + return "MultiwozV22" + + +class UserSimulatorTeacher(MultiwozV22Parser, tod_agents.TodUserSimulatorTeacher): + pass + + +class SystemTeacher(MultiwozV22Parser, tod_agents.TodSystemTeacher): + pass + + +class DefaultTeacher(SystemTeacher): + pass diff --git a/parlai/tasks/multiwoz_v22/build.py b/parlai/tasks/multiwoz_v22/build.py new file mode 100644 index 00000000000..cdbe57c5635 --- /dev/null +++ b/parlai/tasks/multiwoz_v22/build.py @@ -0,0 +1,226 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import parlai.core.build_data as build_data +import os +from parlai.core.build_data import DownloadableFile + +# Pin against a specific commit since folks occasionally post fixes +MULTIWOZ_URL_BASE = "https://raw.githubusercontent.com/budzianowski/multiwoz/01e689362833ce33427a771a21cefe253e8f5886/" + +MULTIWOZ_22_URL_BASE = MULTIWOZ_URL_BASE + "data/MultiWOZ_2.2/" + +RESOURCES = [ + DownloadableFile( + MULTIWOZ_22_URL_BASE + 'dialog_acts.json', + 'dialog_acts.json', + '328f392165e7826db9f827731b14b5cc04e79e9e3c6332bfb192a1ea17f8e9b6', + zipped=False, + ), + DownloadableFile( + MULTIWOZ_22_URL_BASE + 'schema.json', + 'schema.json', + 'ae9e2390f38fb967af64623c2f4f7e0c636fb377ad523b582a03161d3ddbdf68', + zipped=False, + ), + DownloadableFile( + MULTIWOZ_22_URL_BASE + 'dev/dialogues_001.json', + 'dev/dialogues_001.json', + 'e7ddb563e4da5766ea820cc826dead77e7ca219c19b761e218d62d9c999a252e', + zipped=False, + ), + DownloadableFile( + MULTIWOZ_22_URL_BASE + 'dev/dialogues_002.json', + 'dev/dialogues_002.json', + 'ede6a2c17fd6c5846214b8cabc1ef8f7cc8be01cfbacaa162bcafec9e87724e9', + zipped=False, + ), + DownloadableFile( + MULTIWOZ_22_URL_BASE + 'test/dialogues_001.json', + 'test/dialogues_001.json', + 'd6f43876cf130fdb2dfa8f96bc056b0995354137f02122e004925d01264ed386', + zipped=False, + ), + DownloadableFile( + MULTIWOZ_22_URL_BASE + 'test/dialogues_002.json', + 'test/dialogues_002.json', + '89af95d8f596a448e733d59b31be78f1dd1632eddd99d5cb298a3fcb1ac9d185', + zipped=False, + ), + DownloadableFile( + MULTIWOZ_22_URL_BASE + 'train/dialogues_001.json', + 'train/dialogues_001.json', + '895a8109bf01fa5ecf15ccdbd2dfe1628bd923f6b61dcd2e26b10ee5076a1596', + zipped=False, + ), + DownloadableFile( + MULTIWOZ_22_URL_BASE + 'train/dialogues_002.json', + 'train/dialogues_002.json', + '2f3ea771d4e01cb2780357738cff7f7496b87d34c221cc240df74501312438d3', + zipped=False, + ), + DownloadableFile( + MULTIWOZ_22_URL_BASE + 'train/dialogues_003.json', + 'train/dialogues_003.json', + 'da24961d28486be2d8462ee4d86a809db819d588ba90ae1a783383d95eb85daa', + zipped=False, + ), + DownloadableFile( + MULTIWOZ_22_URL_BASE + 'train/dialogues_004.json', + 'train/dialogues_004.json', + '30c1172db1071c853b16215d1946de908d68d2b6ff8da7801de307000a179106', + zipped=False, + ), + DownloadableFile( + MULTIWOZ_22_URL_BASE + 
'train/dialogues_005.json', + 'train/dialogues_005.json', + 'eaf58716df5de99524b3e0e7edf477b74749512788a6a51f53f2bdd76768d39a', + zipped=False, + ), + DownloadableFile( + MULTIWOZ_22_URL_BASE + 'train/dialogues_006.json', + 'train/dialogues_006.json', + '8e75fd543b1964bc5e7118085d977f479c98fcdf6d606b971f67a45fb1745c83', + zipped=False, + ), + DownloadableFile( + MULTIWOZ_22_URL_BASE + 'train/dialogues_007.json', + 'train/dialogues_007.json', + '02323f8298439d713c6d7d226f4bd7ec246ec993ee11911b54f98cb8a598f206', + zipped=False, + ), + DownloadableFile( + MULTIWOZ_22_URL_BASE + 'train/dialogues_008.json', + 'train/dialogues_008.json', + '1129fbed480352ae304f0ae5b4915c194e9619c43f1577ccb0d450e10d24ea98', + zipped=False, + ), + DownloadableFile( + MULTIWOZ_22_URL_BASE + 'train/dialogues_009.json', + 'train/dialogues_009.json', + '87d9e43b1ba51a4a4688703da79d3a09b14d8013570545da24c597daa18e2f45', + zipped=False, + ), + DownloadableFile( + MULTIWOZ_22_URL_BASE + 'train/dialogues_010.json', + 'train/dialogues_010.json', + 'e7ad0d5da2909b08197295e45fe4695b9dc2f67d458374b2aab8db5094e97b26', + zipped=False, + ), + DownloadableFile( + MULTIWOZ_22_URL_BASE + 'train/dialogues_011.json', + 'train/dialogues_011.json', + '82e2d2900a037b866a01d05974734dd419e89b329fe29ef93b35eea96d27feb8', + zipped=False, + ), + DownloadableFile( + MULTIWOZ_22_URL_BASE + 'train/dialogues_012.json', + 'train/dialogues_012.json', + 'b6bf292325db67682dd7b6fafbf1051cc2262e92f7c37cab213e975888594bb2', + zipped=False, + ), + DownloadableFile( + MULTIWOZ_22_URL_BASE + 'train/dialogues_013.json', + 'train/dialogues_013.json', + 'c33fe4b3952c016e1e1645f472a7097f93dfb476a19940fd9386865ef9adf685', + zipped=False, + ), + DownloadableFile( + MULTIWOZ_22_URL_BASE + 'train/dialogues_014.json', + 'train/dialogues_014.json', + 'ce33dbbf93a40d0dcc671a9d6a2ed1f987914f5b5f05f6df493a5b342efda954', + zipped=False, + ), + DownloadableFile( + MULTIWOZ_22_URL_BASE + 'train/dialogues_015.json', + 'train/dialogues_015.json', + 'd895c0439bc2ad89ef1689896e3be630eadebc33dae77b42a426fd16a271718e', + zipped=False, + ), + DownloadableFile( + MULTIWOZ_22_URL_BASE + 'train/dialogues_016.json', + 'train/dialogues_016.json', + '3e6bc0bca4262022ccbce0d5ce3150e536e7d21aeb9fbdef9e83cced4dfd124b', + zipped=False, + ), + DownloadableFile( + MULTIWOZ_22_URL_BASE + 'train/dialogues_017.json', + 'train/dialogues_017.json', + 'b6ab2cd9b6e8983364526845b7cbec1749338209bf7ac2313c25e1e2226ebab5', + zipped=False, + ), + DownloadableFile( + MULTIWOZ_URL_BASE + 'db/attraction_db.json', + 'db/attraction_db.json', + '2aacc620af4025f1eada5ec83057a866f8e8b72b529f71d2f8bf93bcdd8f8751', + zipped=False, + ), + DownloadableFile( + MULTIWOZ_URL_BASE + 'db/bus_db.json', + 'db/bus_db.json', + '4818e735bae20690f6d0d06bb2ae8eec1981c0b680258a970dc01c9073f3fec9', + zipped=False, + ), + DownloadableFile( + MULTIWOZ_URL_BASE + 'db/hospital_db.json', + 'db/hospital_db.json', + 'f28738bda15e750be912d653c5e68b06af41dba68d9cfa3febfdcfe972e14366', + zipped=False, + ), + DownloadableFile( + MULTIWOZ_URL_BASE + 'db/hotel_db.json', + 'db/hotel_db.json', + '972bbd65beada7c64f0b87322c694fd9173b46cf8e61ca3bbe951717ac1d1662', + zipped=False, + ), + DownloadableFile( + MULTIWOZ_URL_BASE + 'db/police_db.json', + 'db/police_db.json', + 'd9c2b200fa2dd61b04ce2fe520b0b79dfa68d8b895806cfec6e8a8d3cffa9193', + zipped=False, + ), + DownloadableFile( + MULTIWOZ_URL_BASE + 'db/restaurant_db.json', + 'db/restaurant_db.json', + '7864b4e36027af0907d6afe9ed962fecd6c6966cd626b8ae9341708a04ea201a', + zipped=False, + ), + 
DownloadableFile( + MULTIWOZ_URL_BASE + 'db/taxi_db.json', + 'db/taxi_db.json', + '08b8fb2436abec6d1fe9087054f943d2b31e1c4adc74d8202e4e2149a5630be3', + zipped=False, + ), + DownloadableFile( + MULTIWOZ_URL_BASE + 'db/train_db.json', + 'db/train_db.json', + '4818e735bae20690f6d0d06bb2ae8eec1981c0b680258a970dc01c9073f3fec9', + zipped=False, + ), +] + + +def build(opt): + dpath = os.path.join(opt['datapath'], 'multiwoz_v22') + version = '1.0' + + if not build_data.built(dpath, version_string=version): + print('[building data: ' + dpath + ']') + if build_data.built(dpath): + build_data.remove_dir(dpath) + build_data.make_dir(dpath) + + build_data.make_dir(dpath + "/dev") + build_data.make_dir(dpath + "/test") + build_data.make_dir(dpath + "/train") + + build_data.make_dir(dpath + "/db") + + # Download the data. + for downloadable_file in RESOURCES: + downloadable_file.download_file(dpath) + + build_data.mark_done(dpath, version_string=version) diff --git a/parlai/tasks/multiwoz_v22/build_sha_check_script.py b/parlai/tasks/multiwoz_v22/build_sha_check_script.py new file mode 100644 index 00000000000..51610689156 --- /dev/null +++ b/parlai/tasks/multiwoz_v22/build_sha_check_script.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Included as a convenience in case the files get updated. +""" + + +import hashlib +import wget + +MULTIWOZ_URL_BASE = "https://raw.githubusercontent.com/budzianowski/multiwoz/01e689362833ce33427a771a21cefe253e8f5886/" +MULTIWOZ_22_URL_BASE = MULTIWOZ_URL_BASE + "/data/MultiWOZ_2.2/" + +WANT = [ + "dialog_acts.json", + "schema.json", + "dev/dialogues_001.json", + "dev/dialogues_002.json", + "test/dialogues_001.json", + "test/dialogues_002.json", + "train/dialogues_001.json", + "train/dialogues_002.json", + "train/dialogues_003.json", + "train/dialogues_004.json", + "train/dialogues_005.json", + "train/dialogues_006.json", + "train/dialogues_007.json", + "train/dialogues_008.json", + "train/dialogues_009.json", + "train/dialogues_010.json", + "train/dialogues_011.json", + "train/dialogues_012.json", + "train/dialogues_013.json", + "train/dialogues_014.json", + "train/dialogues_015.json", + "train/dialogues_016.json", + "train/dialogues_017.json", +] + +FILES = [(x, MULTIWOZ_22_URL_BASE + x, "MULTIWOZ_22_URL_BASE") for x in WANT] + +WANT = [ + "db/attraction_db.json", + "db/bus_db.json", + "db/hospital_db.json", + "db/hotel_db.json", + "db/police_db.json", + "db/restaurant_db.json", + "db/taxi_db.json", + "db/train_db.json", +] + +FILES += [(x, MULTIWOZ_URL_BASE + x, "MULTIWOZ_URL_BASE") for x in WANT] + + +def checksum(dpath): + """ + Checksum on a given file. + + :param dpath: path to the downloaded file. 
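+    :return: hex-encoded sha256 digest of the file contents.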
+ """ + sha256_hash = hashlib.sha256() + with open(dpath, "rb") as f: + for byte_block in iter(lambda: f.read(65536), b""): + sha256_hash.update(byte_block) + return sha256_hash.hexdigest() + + +for f in FILES: + name, path, start = f + print(" DownloadableFile(") + print(f" {start} + '{name}',") + print(f" '{name}',") + filename = wget.download(path, bar=None) + print(f" '{checksum(filename)}',") + print(" zipped = False,") + print(" ),") diff --git a/parlai/tasks/multiwoz_v22/test.py b/parlai/tasks/multiwoz_v22/test.py new file mode 100644 index 00000000000..9b34a8a5391 --- /dev/null +++ b/parlai/tasks/multiwoz_v22/test.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from parlai.utils.testing import AutoTeacherTest + + +class TestSystemTeacher(AutoTeacherTest): + task = "multiwoz_v22" + + +class TestUserSimulatorTeacher(AutoTeacherTest): + task = "multiwoz_v22:UserSimulatorTeacher" diff --git a/parlai/tasks/multiwoz_v22/test/multiwoz_v22_UserSimulatorTeacher_test.yml b/parlai/tasks/multiwoz_v22/test/multiwoz_v22_UserSimulatorTeacher_test.yml new file mode 100644 index 00000000000..29fee769b4e --- /dev/null +++ b/parlai/tasks/multiwoz_v22/test/multiwoz_v22_UserSimulatorTeacher_test.yml @@ -0,0 +1,50 @@ +acts: +- - domain: attraction, train + episode_done: false + eval_labels: + - 'USER: I need train reservations from norwich to cambridge' + id: MultiwozV22_UserSimulatorTeacher + text: 'GOAL: api_name = find_train ; departure = norwich ; destination = cambridge + | api_name = find_train ; arriveby = 18:00 ; day = monday ; departure = norwich + ; destination = cambridge | api_name = find_attraction ; name = cineworld cinema' + type: 'USER: ' +- - domain: attraction, train + episode_done: false + eval_labels: + - 'USER: I''d like to leave on Monday and arrive by 18:00.' + id: MultiwozV22_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: I have 133 trains matching your request. Is there a specific day + and time you would like to travel?' + type: 'USER: ' +- - domain: attraction, train + episode_done: false + eval_labels: + - 'USER: Before booking, I would also like to know the travel time, price, and + departure time please.' + id: MultiwozV22_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: There are 12 trains for the day and time you request. Would you + like to book it now?' + type: 'USER: ' +- - domain: attraction, train + episode_done: false + eval_labels: + - 'USER: No hold off on booking for now. Can you help me find an attraction called + cineworld cinema?' + id: MultiwozV22_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: There are 12 trains meeting your needs with the first leaving at + 05:16 and the last one leaving at 16:16. Do you want to book one of these?' + type: 'USER: ' +- - domain: attraction, train + episode_done: false + eval_labels: + - 'USER: Yes, that was all I needed. Thank you very much!' + id: MultiwozV22_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Yes it is a cinema located in the south part of town what information + would you like on it?' 
+ type: 'USER: ' +num_episodes: 1000 +num_examples: 17744 diff --git a/parlai/tasks/multiwoz_v22/test/multiwoz_v22_UserSimulatorTeacher_train.yml b/parlai/tasks/multiwoz_v22/test/multiwoz_v22_UserSimulatorTeacher_train.yml new file mode 100644 index 00000000000..2aca735580f --- /dev/null +++ b/parlai/tasks/multiwoz_v22/test/multiwoz_v22_UserSimulatorTeacher_train.yml @@ -0,0 +1,51 @@ +acts: +- - domain: hotel, restaurant + episode_done: false + id: MultiwozV22_UserSimulatorTeacher + labels: + - 'USER: i need a place to dine in the center thats expensive' + text: 'GOAL: api_name = find_hotel | api_name = find_restaurant ; area = centre + ; pricerange = expensive | api_name = find_hotel ; pricerange = expensive ; + type = hotel | api_name = book_hotel ; bookday = saturday ; bookpeople = 2 ; + bookstay = 2 ; name = university arms hotel ; pricerange = expensive ; type + = hotel' + type: 'USER: ' +- - domain: hotel, restaurant + episode_done: false + id: MultiwozV22_UserSimulatorTeacher + labels: + - 'USER: Any sort of food would be fine, as long as it is a bit expensive. Could + I get the phone number for your recommendation?' + slots: {} + text: 'SYSTEM: I have several options for you; do you prefer African, Asian, or + British food?' + type: 'USER: ' +- - domain: hotel, restaurant + episode_done: false + id: MultiwozV22_UserSimulatorTeacher + labels: + - 'USER: Sounds good, could I get that phone number? Also, could you recommend + me an expensive hotel?' + slots: {} + text: 'SYSTEM: There is an Afrian place named Bedouin in the centre. How does + that sound?' + type: 'USER: ' +- - domain: hotel, restaurant + episode_done: false + id: MultiwozV22_UserSimulatorTeacher + labels: + - 'USER: Yes. Can you book it for me?' + slots: {} + text: 'SYSTEM: Bedouin''s phone is 01223367660. As far as hotels go, I recommend + the University Arms Hotel in the center of town.' + type: 'USER: ' +- - domain: hotel, restaurant + episode_done: false + id: MultiwozV22_UserSimulatorTeacher + labels: + - 'USER: i want to book it for 2 people and 2 nights starting from saturday.' + slots: {} + text: 'SYSTEM: Sure, when would you like that reservation?' + type: 'USER: ' +num_episodes: 7913 +num_examples: 133963 diff --git a/parlai/tasks/multiwoz_v22/test/multiwoz_v22_UserSimulatorTeacher_valid.yml b/parlai/tasks/multiwoz_v22/test/multiwoz_v22_UserSimulatorTeacher_valid.yml new file mode 100644 index 00000000000..802f5e89ebb --- /dev/null +++ b/parlai/tasks/multiwoz_v22/test/multiwoz_v22_UserSimulatorTeacher_valid.yml @@ -0,0 +1,50 @@ +acts: +- - domain: restaurant, train + episode_done: false + eval_labels: + - 'USER: I''m looking for a local place to dine in the centre that serves chinese + food.' + id: MultiwozV22_UserSimulatorTeacher + text: 'GOAL: api_name = find_restaurant ; area = centre ; food = chinese | api_name + = find_train ; day = sunday ; departure = cambridge ; destination = norwich + ; leaveat = 16:15 | api_name = book_train ; bookpeople = 5 ; day = sunday ; + departure = cambridge ; destination = norwich ; leaveat = 16:15' + type: 'USER: ' +- - domain: restaurant, train + episode_done: false + eval_labels: + - 'USER: I need the address, postcode and the price range.' + id: MultiwozV22_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: I have restaurants matching your criteria in all price ranges. + Do you have a preference on price?' + type: 'USER: ' +- - domain: restaurant, train + episode_done: false + eval_labels: + - 'USER: I also need a train. 
The train should leave after 16:15 and should leave + on sunday.' + id: MultiwozV22_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Ok how about Charlie Chan, located at Regent Street City Centre. + Postcode is cb21db with a cheap price. Can I help you further today?' + type: 'USER: ' +- - domain: restaurant, train + episode_done: false + eval_labels: + - 'USER: I am leaving from Cambridge and going to Norwich.' + id: MultiwozV22_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Can I have more information for the train you''re needing? Where + are you departing from and arriving to?' + type: 'USER: ' +- - domain: restaurant, train + episode_done: false + eval_labels: + - 'USER: book for 5 people and get me the reference number' + id: MultiwozV22_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: I have train TR1840 leaving at 16:36 is that okay?' + type: 'USER: ' +num_episodes: 999 +num_examples: 17731 diff --git a/parlai/tasks/multiwoz_v22/test/multiwoz_v22_test.yml b/parlai/tasks/multiwoz_v22/test/multiwoz_v22_test.yml new file mode 100644 index 00000000000..09c0e97c40d --- /dev/null +++ b/parlai/tasks/multiwoz_v22/test/multiwoz_v22_test.yml @@ -0,0 +1,66 @@ +acts: +- - domain: attraction, train + episode_done: false + eval_labels: + - 'APIS: ' + id: MultiwozV22_SystemTeacher + slots: {} + text: 'APIS: ' + type: 'APIS: ' +- - domain: attraction, train + episode_done: false + eval_labels: + - 'APICALL: api_name = find_train ; departure = norwich ; destination = cambridge' + id: MultiwozV22_SystemTeacher + slots: + api_name: find_train + departure: norwich + destination: cambridge + text: 'USER: I need train reservations from norwich to cambridge' + type: 'APICALL: ' +- - domain: attraction, train + episode_done: false + eval_labels: + - 'SYSTEM: I have 133 trains matching your request. Is there a specific day and + time you would like to travel?' + id: MultiwozV22_SystemTeacher + slots: + COUNT: 133 + OPTIONS: '[{"arriveby": "06:35", "day": "monday", "departure": "norwich", "destination": + "cambridge", "duration": "79 minutes", "leaveat": "05:16", "price": "17.60 + pounds", "trainid": "TR9020"}]' + text: 'APIRESP: COUNT = 133 ; OPTIONS = [{"arriveby": "06:35", "day": "monday", + "departure": "norwich", "destination": "cambridge", "duration": "79 minutes", + "leaveat": "05:16", "price": "17.60 pounds", "trainid": "TR9020"}]' + type: 'SYSTEM: ' +- - domain: attraction, train + episode_done: false + eval_labels: + - 'APICALL: api_name = find_train ; arriveby = 18:00 ; day = monday ; departure + = norwich ; destination = cambridge' + id: MultiwozV22_SystemTeacher + slots: + api_name: find_train + arriveby: '18:00' + day: monday + departure: norwich + destination: cambridge + text: 'USER: I''d like to leave on Monday and arrive by 18:00.' + type: 'APICALL: ' +- - domain: attraction, train + episode_done: false + eval_labels: + - 'SYSTEM: There are 12 trains for the day and time you request. Would you like + to book it now?' 
+ id: MultiwozV22_SystemTeacher + slots: + COUNT: 12 + OPTIONS: '[{"arriveby": "06:35", "day": "monday", "departure": "norwich", "destination": + "cambridge", "duration": "79 minutes", "leaveat": "05:16", "price": "17.60 + pounds", "trainid": "TR9020"}]' + text: 'APIRESP: COUNT = 12 ; OPTIONS = [{"arriveby": "06:35", "day": "monday", + "departure": "norwich", "destination": "cambridge", "duration": "79 minutes", + "leaveat": "05:16", "price": "17.60 pounds", "trainid": "TR9020"}]' + type: 'SYSTEM: ' +num_episodes: 1000 +num_examples: 17744 diff --git a/parlai/tasks/multiwoz_v22/test/multiwoz_v22_train.yml b/parlai/tasks/multiwoz_v22/test/multiwoz_v22_train.yml new file mode 100644 index 00000000000..4b0217a42bd --- /dev/null +++ b/parlai/tasks/multiwoz_v22/test/multiwoz_v22_train.yml @@ -0,0 +1,73 @@ +acts: +- - domain: hotel, restaurant + episode_done: false + id: MultiwozV22_SystemTeacher + labels: + - 'APIS: ' + slots: {} + text: 'APIS: ' + type: 'APIS: ' +- - domain: hotel, restaurant + episode_done: false + id: MultiwozV22_SystemTeacher + labels: + - 'APICALL: api_name = find_hotel' + slots: + api_name: find_hotel + text: 'USER: i need a place to dine in the center thats expensive' + type: 'APICALL: ' +- - domain: hotel, restaurant + episode_done: false + id: MultiwozV22_SystemTeacher + labels: + - 'SYSTEM: I have several options for you; do you prefer African, Asian, or British + food?' + slots: + COUNT: 33 + OPTIONS: '[{"address": "124 tenison road", "area": "east", "internet": "yes", + "parking": "no", "id": "0", "location": [52.1963733, 0.1987426], "name": "a + and b guest house", "phone": "01223315702", "postcode": "cb12dp", "price": + {"double": "70", "family": "90", "single": "50"}, "pricerange": "moderate", + "stars": "4", "takesbookings": "yes", "type": "guesthouse", "n": NaN}]' + text: 'APIRESP: COUNT = 33 ; OPTIONS = [{"address": "124 tenison road", "area": + "east", "internet": "yes", "parking": "no", "id": "0", "location": [52.1963733, + 0.1987426], "name": "a and b guest house", "phone": "01223315702", "postcode": + "cb12dp", "price": {"double": "70", "family": "90", "single": "50"}, "pricerange": + "moderate", "stars": "4", "takesbookings": "yes", "type": "guesthouse", "n": + NaN}]' + type: 'SYSTEM: ' +- - domain: hotel, restaurant + episode_done: false + id: MultiwozV22_SystemTeacher + labels: + - 'APICALL: api_name = find_restaurant ; area = centre ; pricerange = expensive' + slots: + api_name: find_restaurant + area: centre + pricerange: expensive + text: 'USER: Any sort of food would be fine, as long as it is a bit expensive. + Could I get the phone number for your recommendation?' + type: 'APICALL: ' +- - domain: hotel, restaurant + episode_done: false + id: MultiwozV22_SystemTeacher + labels: + - 'SYSTEM: There is an Afrian place named Bedouin in the centre. How does that + sound?' 
+ slots: + COUNT: 33 + OPTIONS: '[{"address": "106 Regent Street City Centre", "area": "centre", "food": + "indian", "id": "19214", "introduction": "curry garden serves traditional + indian and bangladeshi cuisine cooked with fresh produce delivered every day", + "location": [52.200187, 0.126407], "name": "curry garden", "phone": "01223302330", + "postcode": "cb21dp", "pricerange": "expensive", "type": "restaurant", "signature": + NaN}]' + text: 'APIRESP: COUNT = 33 ; OPTIONS = [{"address": "106 Regent Street City Centre", + "area": "centre", "food": "indian", "id": "19214", "introduction": "curry garden + serves traditional indian and bangladeshi cuisine cooked with fresh produce + delivered every day", "location": [52.200187, 0.126407], "name": "curry garden", + "phone": "01223302330", "postcode": "cb21dp", "pricerange": "expensive", "type": + "restaurant", "signature": NaN}]' + type: 'SYSTEM: ' +num_episodes: 7913 +num_examples: 133963 diff --git a/parlai/tasks/multiwoz_v22/test/multiwoz_v22_valid.yml b/parlai/tasks/multiwoz_v22/test/multiwoz_v22_valid.yml new file mode 100644 index 00000000000..5228dce4c2f --- /dev/null +++ b/parlai/tasks/multiwoz_v22/test/multiwoz_v22_valid.yml @@ -0,0 +1,59 @@ +acts: +- - domain: restaurant, train + episode_done: false + eval_labels: + - 'APIS: ' + id: MultiwozV22_SystemTeacher + slots: {} + text: 'APIS: ' + type: 'APIS: ' +- - domain: restaurant, train + episode_done: false + eval_labels: + - 'APICALL: api_name = find_restaurant ; area = centre ; food = chinese' + id: MultiwozV22_SystemTeacher + slots: + api_name: find_restaurant + area: centre + food: chinese + text: 'USER: I''m looking for a local place to dine in the centre that serves + chinese food.' + type: 'APICALL: ' +- - domain: restaurant, train + episode_done: false + eval_labels: + - 'SYSTEM: I have restaurants matching your criteria in all price ranges. Do you + have a preference on price?' + id: MultiwozV22_SystemTeacher + slots: + COUNT: 10 + OPTIONS: '[{"address": "Regent Street City Centre", "area": "centre", "food": + "chinese", "id": "19212", "introduction": "charlie chan serves a variety of + popular chinese dishes at their restaurant", "location": [52.201743, 0.124843], + "name": "charlie chan", "phone": "01223361763", "postcode": "cb21db", "pricerange": + "cheap", "type": "restaurant", "signature": NaN}]' + text: 'APIRESP: COUNT = 10 ; OPTIONS = [{"address": "Regent Street City Centre", + "area": "centre", "food": "chinese", "id": "19212", "introduction": "charlie + chan serves a variety of popular chinese dishes at their restaurant", "location": + [52.201743, 0.124843], "name": "charlie chan", "phone": "01223361763", "postcode": + "cb21db", "pricerange": "cheap", "type": "restaurant", "signature": NaN}]' + type: 'SYSTEM: ' +- - domain: restaurant, train + episode_done: false + eval_labels: + - 'APICALL: ' + id: MultiwozV22_SystemTeacher + slots: {} + text: 'USER: I need the address, postcode and the price range.' + type: 'APICALL: ' +- - domain: restaurant, train + episode_done: false + eval_labels: + - 'SYSTEM: Ok how about Charlie Chan, located at Regent Street City Centre. Postcode + is cb21db with a cheap price. Can I help you further today?' 
+ id: MultiwozV22_SystemTeacher + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +num_episodes: 999 +num_examples: 17731 From c17c3ccb5f7254cedf8634ad3443b04cf65a93cd Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 13:57:36 -0800 Subject: [PATCH 18/57] lint --- parlai/tasks/msr_e2e/agents.py | 1 + 1 file changed, 1 insertion(+) diff --git a/parlai/tasks/msr_e2e/agents.py b/parlai/tasks/msr_e2e/agents.py index 95e2122b896..aa7b13eb04d 100644 --- a/parlai/tasks/msr_e2e/agents.py +++ b/parlai/tasks/msr_e2e/agents.py @@ -322,5 +322,6 @@ class SystemTeacher(MsrE2EParser, tod_agents.TodSystemTeacher): class UserSimulatorTeacher(MsrE2EParser, tod_agents.TodUserSimulatorTeacher): pass + class DefaultTeacher(SystemTeacher): pass From 728c9bd7d012a920481337a7ca8b9bf5ac3aa93e Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 14:01:12 -0800 Subject: [PATCH 19/57] [TOD][Datasets][Easy] Taskmaster(1) in Conversations format Title. I only include System + UserSimulator Teachers here since that's all we need right now from dataset. There's non-fb people that made edits in the original version of Taskmaster, so keep those teachers around too. --------------- Datasets added in this substack: * Google SGD * Google SGD Simulation Splits (In-domain, Out-domain) * MetalWoz * MSR_E2E * Multidogo * MultiWoz V2.2 * **Taskmaster** * Taskmaster2 * Taskmaster3 (TicketTalk) Test plan: Regression test, `parlai dd` of dataset --- parlai/tasks/taskmaster/agents.py | 248 +++++++++++++++++- parlai/tasks/taskmaster/test.py | 8 + .../test/taskmaster_SystemTeacher_test.yml | 45 ++++ .../test/taskmaster_SystemTeacher_train.yml | 49 ++++ .../test/taskmaster_SystemTeacher_valid.yml | 50 ++++ .../taskmaster_UserSimulatorTeacher_test.yml | 57 ++++ .../taskmaster_UserSimulatorTeacher_train.yml | 46 ++++ .../taskmaster_UserSimulatorTeacher_valid.yml | 46 ++++ 8 files changed, 547 insertions(+), 2 deletions(-) create mode 100644 parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_test.yml create mode 100644 parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_train.yml create mode 100644 parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_valid.yml create mode 100644 parlai/tasks/taskmaster/test/taskmaster_UserSimulatorTeacher_test.yml create mode 100644 parlai/tasks/taskmaster/test/taskmaster_UserSimulatorTeacher_train.yml create mode 100644 parlai/tasks/taskmaster/test/taskmaster_UserSimulatorTeacher_valid.yml diff --git a/parlai/tasks/taskmaster/agents.py b/parlai/tasks/taskmaster/agents.py index 94792126ab7..237a9fd4a10 100644 --- a/parlai/tasks/taskmaster/agents.py +++ b/parlai/tasks/taskmaster/agents.py @@ -4,16 +4,260 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +""" +Taskmaster-1 implementation for ParlAI. + +Note that we have conversations structured both in the "TOD" format as well as those +from prior. +""" -from typing import Optional from parlai.core.params import ParlaiParser +import os +import pandas as pd from parlai.core.opt import Opt -from parlai.core.teachers import FixedDialogTeacher +import parlai.core.tod.tod_core as tod +from typing import Optional +from parlai.utils.data import DatatypeHelper from parlai.utils.io import PathManager + +import parlai.tasks.taskmaster.build as build_ +import parlai.core.tod.tod_agents as tod_agents + +# Following is for legacy format +from parlai.core.teachers import FixedDialogTeacher from . 
import tm_utils import json +################### TOD Conversation format + +SILENCE_TOKEN = "__SILENCE__" + +# Faster to copy/paste this than parse a json file +ONTOLOGY = { + "uber": { + "id": "uber_lyft", + "vertical": "ride_booking", + "required": ["location.from", "location.to", "type.ride", "num.people"], + "optional": [ + "price.estimate", + "duration.estimate", + "time.pickup", + "time.dropoff", + ], + }, + "movie": { + "id": "movie_ticket", + "vertical": "ticket_booking", + "required": [ + "name.movie", + "name.theater", + "num.tickets", + "time.start", + "location.theater", + "price.ticket", + ], + "optional": ["type.screening", "time.end", "time.duration"], + }, + "restaurant": { + "id": "restaurant_reservation", + "vertical": "reservation", + "required": [ + "name.restaurant", + "name.reservation", + "num.guests", + "time.reservation", + ], + "optional": ["type.seating", "location.restaurant"], + }, + "coffee": { + "id": "coffee_ordering", + "vertical": "coffee_order", + "required": ["location.store", "name.drink", "size.drink"], + "optional": ["num.drink", "type.milk", "preference"], + }, + "pizza": { + "id": "pizza_ordering", + "vertical": "pizza_order", + "required": ["name.store", "name.pizza", "size.pizza"], + "optional": ["type.topping", "type.crust", "preference", "location.store"], + }, + "auto": { + "id": "auto_repair", + "vertical": "appointment", + "required": ["name.store", "name.customer", "date.appt", "time.appt"], + "optional": ["reason.appt", "name.vehicle", "year.vehicle", "location.store"], + }, +} + + +class Taskmaster1Parser(tod_agents.TodStructuredDataParser): + """ + Abstract data loader. + """ + + @classmethod + def add_cmdline_args( + cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None + ) -> ParlaiParser: + parser = super().add_cmdline_args(parser, partial_opt) + return parser + + def __init__(self, opt: Opt, shared=None): + self.fold = DatatypeHelper.fold(opt["datatype"]) + opt["datafile"] = self.fold + self.dpath = os.path.join(opt["datapath"], "taskmaster-1") + if shared is None: + build_.build(opt) + super().__init__(opt, shared) + + def _load_data(self, fold): + chunks = [] + with PathManager.open(os.path.join(self.dpath, f"self-dialogs.json")) as f: + subset = pd.read_json(f) + chunks.append(subset) + with PathManager.open(os.path.join(self.dpath, f"woz-dialogs.json")) as f: + subset = pd.read_json(f) + chunks.append(subset) + chunks = pd.concat(chunks, axis=0) + # deterministic shuffle data for splits + chunks = chunks.sample(frac=1.0, random_state=42) + split_size = len(chunks) // 10 + if fold == "train": + chunks = chunks[: split_size * 8] + elif fold == "valid": + chunks = chunks[split_size * 8 : split_size * 9] + elif fold == "test": + chunks = chunks[split_size * 9 :] + return chunks, ONTOLOGY + + def _parse_segment_to_slots(self, segment_list): + result = {} + for segment in segment_list: + slot_name = segment["annotations"][0]["name"] + slot_value = segment["text"] + prefix_split_idx = slot_name.find(".") + api_name = slot_name[:prefix_split_idx] + slot_name = slot_name[prefix_split_idx + 1 :] + result[slot_name] = slot_value + result[tod.STANDARD_API_NAME_SLOT] = api_name + return result + + def _get_utterance_and_slots_for_speaker(self, speaker, utterances, idx): + utts = [] + slots = {} + while idx < len(utterances): + here = utterances[idx] + if here["speaker"] != speaker: + break + utts.append(here["text"]) + slots.update(self._parse_segment_to_slots(here.get("segments", []))) + idx += 1 + return idx, "\n".join(utts), slots + 
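+    # A concrete (hypothetical) illustration of what the two helpers above
+    # produce: assuming a USER segment whose annotation name is
+    # "uber_lyft.location.from.accept" and whose text is
+    # "the Wichita Dwight D. Eisenhower National Airport",
+    # _parse_segment_to_slots returns
+    #   {"location.from.accept": "the Wichita Dwight D. Eisenhower National Airport",
+    #    "api_name": "uber_lyft"},
+    # i.e. everything before the first "." becomes STANDARD_API_NAME_SLOT and
+    # the remainder is kept as the slot key. These dicts are what end up
+    # serialized into the APICALL lines of the regression ymls below.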
+ def _parse_to_api_schema(self, raw): + """ + NOTE: Format of ontology in this is different from TM2 + TM3. Need to figure out which is relevant for the domain. + """ + result = {} + for key, val in raw.items(): + here = {} + here[tod.STANDARD_API_NAME_SLOT] = val["id"] + here[tod.STANDARD_REQUIRED_KEY] = val.get("required", []) + here[tod.STANDARD_OPTIONAL_KEY] = val.get("optional", []) + result[key] = here + return result + + def _get_turns_from_parsed(self, user_utt, api_calls, api_resps, sys_utt): + result = [ + tod.TodStructuredRound( + user_utt=user_utt, + api_call_machine=api_calls, + api_resp_machine=api_resps, + sys_utt=sys_utt, + ) + ] + return result + + def setup_episodes(self, fold): + """ + Parses into TodStructuredEpisode. + """ + chunks, api_schema_raw = self._load_data(fold) + api_schemas_machine = self._parse_to_api_schema(api_schema_raw) + episodes = [] + for _, row in chunks.iterrows(): + utterances = row["utterances"][:] + if not all( + [ + x.get("speaker") == "ASSISTANT" or x.get("speaker") == "USER" + for x in utterances + ] + ): + # there's an example or two that causes things to infinite loop. >.> + continue + idx = 0 + rounds = [] + goal_calls = [] + if len(utterances) > 0 and utterances[0]["speaker"] == "ASSISTANT": + (idx, sys_utt, _,) = self._get_utterance_and_slots_for_speaker( + "ASSISTANT", utterances, idx + ) + + turns = self._get_turns_from_parsed(SILENCE_TOKEN, {}, {}, sys_utt) + for t in turns: + rounds.append(t) + + while idx < len(utterances): + ( + idx, + user_utt, + user_slots, + ) = self._get_utterance_and_slots_for_speaker("USER", utterances, idx) + ( + idx, + sys_utt, + system_slots, + ) = self._get_utterance_and_slots_for_speaker( + "ASSISTANT", utterances, idx + ) + # The annotations in this dataset don't make sense as api responses but... we'll just roll. 
+ turns = self._get_turns_from_parsed( + user_utt, user_slots, system_slots, sys_utt + ) + for t in turns: + rounds.append(t) + apis = [] + for candidate_api in api_schemas_machine: + if candidate_api in row["instruction_id"]: + apis.append(api_schemas_machine[candidate_api]) + episode = tod.TodStructuredEpisode( + api_schemas_machine=apis, + goal_calls_machine=goal_calls, + rounds=rounds, + delex=self.opt.get("delex", False), + ) + episodes.append(episode) + return episodes + + def get_id_task_prefix(self): + return "Taskmaster1" + + def _label_fold(self, chunks): + return chunks.conversation_id.apply(self._h) + + +class SystemTeacher(Taskmaster1Parser, tod_agents.TodSystemTeacher): + pass + + +class UserSimulatorTeacher(Taskmaster1Parser, tod_agents.TodUserSimulatorTeacher): + pass + + +############ Legacy defined teachers + + class SelfDialogueTeacher(FixedDialogTeacher): """ Teacher for written two-person dialogues with labels being responses for the diff --git a/parlai/tasks/taskmaster/test.py b/parlai/tasks/taskmaster/test.py index d757ec62575..42278ee331a 100644 --- a/parlai/tasks/taskmaster/test.py +++ b/parlai/tasks/taskmaster/test.py @@ -21,3 +21,11 @@ class TestWozDialogueTeacher(AutoTeacherTest): class TestSelfDialogueSegmentTeacher(AutoTeacherTest): task = "taskmaster:self_dialogue_segment" + + +class TestSystemTeacher(AutoTeacherTest): + task = "taskmaster:SystemTeacher" + + +class TestUserSimulatorTeacher(AutoTeacherTest): + task = "taskmaster:UserSimulatorTeacher" diff --git a/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_test.yml b/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_test.yml new file mode 100644 index 00000000000..998350f1020 --- /dev/null +++ b/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_test.yml @@ -0,0 +1,45 @@ +acts: +- - domain: '' + episode_done: false + eval_labels: + - 'APIS: ' + id: Taskmaster1_SystemTeacher + slots: {} + text: 'APIS: ' + type: 'APIS: ' +- - domain: '' + episode_done: false + eval_labels: + - 'APICALL: ' + id: Taskmaster1_SystemTeacher + slots: {} + text: 'USER: __SILENCE__' + type: 'APICALL: ' +- - domain: '' + episode_done: false + eval_labels: + - 'SYSTEM: hey there, how can i help you?' + id: Taskmaster1_SystemTeacher + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +- - domain: '' + episode_done: false + eval_labels: + - 'APICALL: ' + id: Taskmaster1_SystemTeacher + slots: {} + text: 'USER: Hi. I want to ride Uber in car.' + type: 'APICALL: ' +- - domain: '' + episode_done: false + eval_labels: + - 'SYSTEM: sure, where are you heading to? + + where is your destination?' + id: Taskmaster1_SystemTeacher + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +num_episodes: 1326 +num_examples: 33198 diff --git a/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_train.yml b/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_train.yml new file mode 100644 index 00000000000..417c226d0e4 --- /dev/null +++ b/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_train.yml @@ -0,0 +1,49 @@ +acts: +- - domain: '' + episode_done: false + id: Taskmaster1_SystemTeacher + labels: + - 'APIS: ' + slots: {} + text: 'APIS: ' + type: 'APIS: ' +- - domain: '' + episode_done: false + id: Taskmaster1_SystemTeacher + labels: + - 'APICALL: ' + slots: {} + text: 'USER: __SILENCE__' + type: 'APICALL: ' +- - domain: '' + episode_done: false + id: Taskmaster1_SystemTeacher + labels: + - 'SYSTEM: hi, how can i assist you?' 
+ slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +- - domain: '' + episode_done: false + id: Taskmaster1_SystemTeacher + labels: + - 'APICALL: api_name = uber_lyft ; location.from.accept = the Wichita Dwight D. + Eisenhower National Airport ; location.to.accept = the Wichita State University + campus' + slots: + api_name: uber_lyft + location.from.accept: the Wichita Dwight D. Eisenhower National Airport + location.to.accept: the Wichita State University campus + text: 'USER: Hi, I need an Uber pickup from the Wichita Dwight D. Eisenhower National + Airport to the Wichita State University campus.' + type: 'APICALL: ' +- - domain: '' + episode_done: false + id: Taskmaster1_SystemTeacher + labels: + - 'SYSTEM: what kind of ride would you like?' + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +num_episodes: 10555 +num_examples: 262021 diff --git a/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_valid.yml b/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_valid.yml new file mode 100644 index 00000000000..a541abf7d1f --- /dev/null +++ b/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_valid.yml @@ -0,0 +1,50 @@ +acts: +- - domain: '' + episode_done: false + eval_labels: + - 'APIS: ' + id: Taskmaster1_SystemTeacher + slots: {} + text: 'APIS: ' + type: 'APIS: ' +- - domain: '' + episode_done: false + eval_labels: + - 'APICALL: ' + id: Taskmaster1_SystemTeacher + slots: {} + text: 'USER: __SILENCE__' + type: 'APICALL: ' +- - domain: '' + episode_done: false + eval_labels: + - 'SYSTEM: hello, how can i assist you?' + id: Taskmaster1_SystemTeacher + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +- - domain: '' + episode_done: false + eval_labels: + - 'APICALL: api_name = coffee_ordering ; location.store = the Hilton Knoxville + Starbucks in Knoxville, Tennessee ; name.drink = cinnamon shortbread latte' + id: Taskmaster1_SystemTeacher + slots: + api_name: coffee_ordering + location.store: the Hilton Knoxville Starbucks in Knoxville, Tennessee + name.drink: cinnamon shortbread latte + text: 'USER: Hi. I''m trying to order a cinnamon shortbread latte from the Hilton + Knoxville Starbucks in Knoxville, Tennessee.' + type: 'APICALL: ' +- - domain: '' + episode_done: false + eval_labels: + - 'SYSTEM: did you say hilton knoxville?' + id: Taskmaster1_SystemTeacher + slots: + api_name: coffee_ordering + location.store: hilton knoxville + text: 'APIRESP: api_name = coffee_ordering ; location.store = hilton knoxville' + type: 'SYSTEM: ' +num_episodes: 1321 +num_examples: 32641 diff --git a/parlai/tasks/taskmaster/test/taskmaster_UserSimulatorTeacher_test.yml b/parlai/tasks/taskmaster/test/taskmaster_UserSimulatorTeacher_test.yml new file mode 100644 index 00000000000..fcafb66b9f8 --- /dev/null +++ b/parlai/tasks/taskmaster/test/taskmaster_UserSimulatorTeacher_test.yml @@ -0,0 +1,57 @@ +acts: +- - domain: '' + episode_done: false + eval_labels: + - 'USER: __SILENCE__' + id: Taskmaster1_UserSimulatorTeacher + text: 'GOAL: ' + type: 'USER: ' +- - domain: '' + episode_done: false + eval_labels: + - 'USER: Hi. I want to ride Uber in car.' + id: Taskmaster1_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: hey there, how can i help you?' + type: 'USER: ' +- - domain: '' + episode_done: false + eval_labels: + - 'USER: Mohit Tetar.' + id: Taskmaster1_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: sure, where are you heading to? + + where is your destination?' + type: 'USER: ' +- - domain: '' + episode_done: false + eval_labels: + - 'USER: Yeah. Got it.' 
+ id: Taskmaster1_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: and where would you like to be picked up? + + i''m sorry, i didn''t understand you. + + where would you like to be picked up? + + ok, so you''d like to go from a restaurant to a movie theater, is that correct?' + type: 'USER: ' +- - domain: '' + episode_done: false + eval_labels: + - 'USER: Yep, today at 9:00 p.m.' + id: Taskmaster1_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: How many people will you need a ride for? + + and what type of ride would you like? + + an uber ride for two people, shared ride, from a restaurant to a movie theater + will cost $7.94 and take about 8 minutes. + + would you like to book this right now?' + type: 'USER: ' +num_episodes: 1326 +num_examples: 33198 diff --git a/parlai/tasks/taskmaster/test/taskmaster_UserSimulatorTeacher_train.yml b/parlai/tasks/taskmaster/test/taskmaster_UserSimulatorTeacher_train.yml new file mode 100644 index 00000000000..fc424778230 --- /dev/null +++ b/parlai/tasks/taskmaster/test/taskmaster_UserSimulatorTeacher_train.yml @@ -0,0 +1,46 @@ +acts: +- - domain: '' + episode_done: false + id: Taskmaster1_UserSimulatorTeacher + labels: + - 'USER: __SILENCE__' + text: 'GOAL: ' + type: 'USER: ' +- - domain: '' + episode_done: false + id: Taskmaster1_UserSimulatorTeacher + labels: + - 'USER: Hi, I need an Uber pickup from the Wichita Dwight D. Eisenhower National + Airport to the Wichita State University campus.' + slots: {} + text: 'SYSTEM: hi, how can i assist you?' + type: 'USER: ' +- - domain: '' + episode_done: false + id: Taskmaster1_UserSimulatorTeacher + labels: + - 'USER: I just need if it if it''s I have four people but if it''s possible to + get a shared ride, I I''d like that.' + slots: {} + text: 'SYSTEM: what kind of ride would you like?' + type: 'USER: ' +- - domain: '' + episode_done: false + id: Taskmaster1_UserSimulatorTeacher + labels: + - 'USER: Okay, then we''re going to have to go with that.' + slots: {} + text: 'SYSTEM: for a party of 4, you can only request uberx or uberxl.' + type: 'USER: ' +- - domain: '' + episode_done: false + id: Taskmaster1_UserSimulatorTeacher + labels: + - 'USER: Tonight at 9:00 p.m. We''ll be landing at about 8:30, so 9:00 p.m. should + be good.' + slots: {} + text: 'SYSTEM: your fare estimate is $22.69. what time do you want to be picked + up?' + type: 'USER: ' +num_episodes: 10555 +num_examples: 262021 diff --git a/parlai/tasks/taskmaster/test/taskmaster_UserSimulatorTeacher_valid.yml b/parlai/tasks/taskmaster/test/taskmaster_UserSimulatorTeacher_valid.yml new file mode 100644 index 00000000000..3b8f69cb4c2 --- /dev/null +++ b/parlai/tasks/taskmaster/test/taskmaster_UserSimulatorTeacher_valid.yml @@ -0,0 +1,46 @@ +acts: +- - domain: '' + episode_done: false + eval_labels: + - 'USER: __SILENCE__' + id: Taskmaster1_UserSimulatorTeacher + text: 'GOAL: ' + type: 'USER: ' +- - domain: '' + episode_done: false + eval_labels: + - 'USER: Hi. I''m trying to order a cinnamon shortbread latte from the Hilton + Knoxville Starbucks in Knoxville, Tennessee.' + id: Taskmaster1_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: hello, how can i assist you?' + type: 'USER: ' +- - domain: '' + episode_done: false + eval_labels: + - 'USER: Yes, at The Hill Starbucks at The Hill, Knoxville.' + id: Taskmaster1_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: did you say hilton knoxville?' + type: 'USER: ' +- - domain: '' + episode_done: false + eval_labels: + - 'USER: I think just a tall with whole milk. 
+ + And Can you have they do add a little bit more than normal amount of cinnamon + on top or the cinnamon and nutmeg that they put on top of it?' + id: Taskmaster1_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: what size would you like your coffee?' + type: 'USER: ' +- - domain: '' + episode_done: false + eval_labels: + - 'USER: (s) Just a dab.' + id: Taskmaster1_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: sure. would you like whipped cream?' + type: 'USER: ' +num_episodes: 1321 +num_examples: 32641 From 5c226bdb0712aab476a0fef29327f9556cf34d47 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 14:05:46 -0800 Subject: [PATCH 20/57] [TOD][Datasets][Easy] Taskmaster2 to TOD Conversations format Title. I only include System + UserSimulator Teachers here since that's all we need right now from dataset. There's a legacy implementation of Taskmaster2 here, but doesn't seem to be anyone using, so clobbering. --------------------------------- Datasets added in this substack: * Google SGD * Google SGD Simulation Splits (In-domain, Out-domain) * MetalWoz * MSR_E2E * Multidogo * **MultiWoz V2.2** * Taskmaster * Taskmaster2 * Taskmaster3 (TicketTalk) Test plan: Regression test, `parlai dd` of dataset --- parlai/tasks/taskmaster2/README.md | 2 +- parlai/tasks/taskmaster2/agents.py | 447 ++++++------------ parlai/tasks/taskmaster2/build.py | 90 ++-- parlai/tasks/taskmaster2/test.py | 15 + .../taskmaster2_UserSimulatorTeacher_test.yml | 46 ++ ...taskmaster2_UserSimulatorTeacher_train.yml | 43 ++ ...taskmaster2_UserSimulatorTeacher_valid.yml | 45 ++ .../taskmaster2/test/taskmaster2_test.yml | 55 +++ .../taskmaster2/test/taskmaster2_train.yml | 48 ++ .../taskmaster2/test/taskmaster2_valid.yml | 43 ++ 10 files changed, 494 insertions(+), 340 deletions(-) create mode 100644 parlai/tasks/taskmaster2/test.py create mode 100644 parlai/tasks/taskmaster2/test/taskmaster2_UserSimulatorTeacher_test.yml create mode 100644 parlai/tasks/taskmaster2/test/taskmaster2_UserSimulatorTeacher_train.yml create mode 100644 parlai/tasks/taskmaster2/test/taskmaster2_UserSimulatorTeacher_valid.yml create mode 100644 parlai/tasks/taskmaster2/test/taskmaster2_test.yml create mode 100644 parlai/tasks/taskmaster2/test/taskmaster2_train.yml create mode 100644 parlai/tasks/taskmaster2/test/taskmaster2_valid.yml diff --git a/parlai/tasks/taskmaster2/README.md b/parlai/tasks/taskmaster2/README.md index ac138930708..a64ce6fd321 100644 --- a/parlai/tasks/taskmaster2/README.md +++ b/parlai/tasks/taskmaster2/README.md @@ -1,5 +1,5 @@ # Taskmaster 2 Originally from the -[Google Research Datasets](https://github.com/google-research-datasets/Taskmaster/blob/main/TM-2-2020/README.md). +[Google Research Datasets](https://github.com/google-research-datasets/Taskmaster/blob/master/TM-2-2020/README.md). See that page for details. 
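As a quick sanity check on the converted data (the `parlai dd` step in the test plan
above), a minimal sketch along these lines can be run once the patch is applied; the
argument values are illustrative only and not part of this change:

    from parlai.scripts.display_data import DisplayData

    # Print a handful of converted SystemTeacher episodes
    # (APIS / APICALL / APIRESP / SYSTEM turns).
    DisplayData.main(task="taskmaster2", datatype="valid", num_examples=10)

    # Same for the user-simulator view (USER turns grounded on GOAL).
    DisplayData.main(
        task="taskmaster2:UserSimulatorTeacher", datatype="valid", num_examples=10
    )
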
diff --git a/parlai/tasks/taskmaster2/agents.py b/parlai/tasks/taskmaster2/agents.py index ce761afc47c..f5cb8d81fa0 100644 --- a/parlai/tasks/taskmaster2/agents.py +++ b/parlai/tasks/taskmaster2/agents.py @@ -14,36 +14,31 @@ from parlai.core.params import ParlaiParser import os import pandas as pd -import hashlib from collections import Counter from parlai.core.opt import Opt -from parlai.core.teachers import DialogTeacher -from parlai.core.metrics import AverageMetric, F1Metric, BleuMetric +import parlai.core.tod.tod_core as tod from parlai.utils.misc import warn_once import json -import parlai.utils.logging as logging -from typing import Optional, Tuple -from parlai.core.message import Message +from typing import Optional +from parlai.utils.data import DatatypeHelper from parlai.utils.io import PathManager import parlai.tasks.taskmaster2.build as build_ +import parlai.core.tod.tod_agents as tod_agents + DOMAINS = [ - 'flights', - 'food-ordering', - 'hotels', - 'movies', - 'restaurant-search', - 'sports', - 'music', + "flights", + "food-ordering", + "hotels", + "movies", + "restaurant-search", + "sports", + "music", ] -ONTO_TOKEN = "Onto:" -CALL_TOKEN = "Call:" -RESP_TOKEN = "Result:" - -class _Abstract(DialogTeacher): +class Taskmaster2Parser(tod_agents.TodStructuredDataParser): """ Abstract data loader. """ @@ -52,21 +47,26 @@ class _Abstract(DialogTeacher): def add_cmdline_args( cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None ) -> ParlaiParser: - super().add_cmdline_args(parser, partial_opt) - parser.add_argument('--include-ontology', type=bool, default=False) parser.add_argument( - '--domains', - nargs='+', + "--taskmaster2-domains", + nargs="+", default=DOMAINS, choices=DOMAINS, - help='Uses last passed in configuration.', + help="Uses last passed in configuration.", + ) + parser.add_argument( + "--use-cumulative-api-calls", + type=bool, + default=True, + help="Have API Call/API response turns only when an API response" + "slot exist. Accumulate all API call slots with same API call name", ) - return parser + return super().add_cmdline_args(parser, partial_opt) def __init__(self, opt: Opt, shared=None): - self.fold = opt['datatype'].split(':')[0] - opt['datafile'] = self.fold - self.dpath = os.path.join(opt['datapath'], 'taskmaster-2') + self.fold = DatatypeHelper.fold(opt["datatype"]) + opt["datafile"] = self.fold + self.dpath = os.path.join(opt["datapath"], "taskmaster-2") if shared is None: warn_once( "Taskmaster2 is a beta dataset, and format may significantly change." @@ -74,298 +74,157 @@ def __init__(self, opt: Opt, shared=None): build_.build(opt) super().__init__(opt, shared) - def _h(self, x): - """ - Hash function. 
- """ - h = int(hashlib.sha1(x.encode('utf-8')).hexdigest(), 16) % 10 - if h == 0: - return 'valid' - elif h == 1: - return 'test' - else: - return 'train' - - def _normalize_annotation(self, anno): - return anno - def _load_data(self, fold, domains): # load up the ontology - ontology = {} + ontologies = {} for section in domains: - parts = [] - fn = os.path.join(self.dpath, section + '.onto.json') - with PathManager.open(fn, 'r') as f: - o = json.load(f) - assert len(o) == 1 - o = list(o.values())[0] - for sub in o: - prefix = sub['prefix'] - parts += [ - self._normalize_annotation(f'{prefix}.{a}') - for a in sub['annotations'] - ] - ontology[section] = ' ; '.join(parts) + fn = os.path.join(self.dpath, section + ".onto.json") + with PathManager.open(fn, "r") as f: + ontologies.update(json.load(f)) chunks = [] for section in domains: - with PathManager.open(os.path.join(self.dpath, section + '.json')) as f: + with PathManager.open(os.path.join(self.dpath, section + ".json")) as f: subset = pd.read_json(f) - subset['domain'] = section + subset["domain"] = section chunks.append(subset) chunks = pd.concat(chunks, axis=0) - # shuffle deterministically for randomness in few-shot training + # deterministic shuffle data for splits chunks = chunks.sample(frac=1.0, random_state=42) - chunks['fold'] = self._label_fold(chunks) - # only the fold we need here - chunks = chunks[chunks.fold == fold].reset_index() - chunks['ontology'] = chunks['domain'].apply(ontology.get) - return chunks - - def _segments2text(self, segments): - output = [] + split_size = len(chunks) // 10 + if fold == "train": + chunks = chunks[: split_size * 8] + elif fold == "valid": + chunks = chunks[split_size * 8 : split_size * 9] + elif fold == "test": + chunks = chunks[split_size * 9 :] + return chunks, ontologies + + def _parse_segment_to_slots(self, segment_list): + result = {} + for segment in segment_list: + slot_name = segment["annotations"][0]["name"] + slot_value = segment["text"] + prefix_split_idx = slot_name.find(".") + api_name = slot_name[:prefix_split_idx] + slot_name = slot_name[prefix_split_idx + 1 :] + result[slot_name] = slot_value + result[tod.STANDARD_API_NAME_SLOT] = api_name + return result + + def _get_utterance_and_api_call_for_speaker(self, speaker, utterances, idx): + utts = [] slots = {} - for segment in segments: - val = segment['text'] - for anno_ in segment['annotations']: - anno = anno_['name'] - anno = self._normalize_annotation(anno) - output.append(f'{anno} = {val}') - slots[anno] = val - return " ; ".join(output), slots - - def custom_evaluation( - self, - teacher_action: Message, - labels: Optional[Tuple[str]], - model_response: Message, - ): - if 'metrics' in model_response and 'type' in teacher_action: - # keep copies of metrics across both api calls/responses - prefix = teacher_action['type'] - keys = list(model_response['metrics'].keys()) - for k in keys: - self.metrics.add(f'{prefix}_{k}', model_response['metrics'][k]) - - if 'text' not in model_response or not labels or 'type' not in teacher_action: - return - - domain = teacher_action['domain'] - - if teacher_action['type'] == 'apicall': - # also count slot accuracy - text = model_response['text'] - slot_guesses = set( - text.replace(CALL_TOKEN + " ", "").split(' ; ') - ) # prevent cheating via repeated guesses - correct = 0 - for slot_guess in slot_guesses: - if ' = ' not in slot_guess: - continue - try: - slot, guess = slot_guess.split(' = ') - except ValueError: - continue - if teacher_action['slots'].get(slot) == guess: - 
self.metrics.add('slot_p', AverageMetric(1)) - self.metrics.add(f'{domain}_slot_p', AverageMetric(1)) - correct += 1 - else: - self.metrics.add('slot_p', AverageMetric(0)) - self.metrics.add(f'{domain}_slot_p', AverageMetric(0)) - logging.debug( - f"Bad slot guess '{slot_guess}' != {teacher_action['slots']}" - ) - if teacher_action['slots']: - self.metrics.add( - 'slot_r', AverageMetric(correct, len(teacher_action['slots'])) - ) - self.metrics.add( - f'{domain}_slot_r', - AverageMetric(correct, len(teacher_action['slots'])), - ) - self.metrics.add( - 'jga', AverageMetric(correct == len(teacher_action['slots'])) - ) - - elif teacher_action['type'] == 'apiresp': - # keep track of statistics by domain - f1_metric = F1Metric.compute(model_response['text'], labels) - bleu_metric = BleuMetric.compute(model_response['text'], labels) - self.metrics.add(f'{domain}_lex_f1', f1_metric) - self.metrics.add(f'{domain}_lex_bleu', bleu_metric) - - delex_text = model_response['text'] - delex_label = labels[0] - # compute delexicalized string metrics - for slot, value in teacher_action['slots'].items(): - delex_text = delex_text.replace(value, slot) - delex_label = delex_label.replace(value, slot) - f1_metric = F1Metric.compute(delex_text, (delex_label,)) - self.metrics.add('delex_f1', f1_metric) - self.metrics.add(f'{domain}_delex_f1', f1_metric) - bleu_metric = BleuMetric.compute(delex_text, [delex_label]) - self.metrics.add('delex_bleu', bleu_metric) - self.metrics.add(f'{domain}_delex_bleu', bleu_metric) - - def setup_data(self, fold): - domains = self.opt.get('domains', DOMAINS) - chunks = self._load_data(fold, domains) - domains_cnt = Counter() - for _, row in chunks.iterrows(): - domains_cnt[row['domain']] += 1 - first = True - utterances = row['utterances'][:] - if ( - len(utterances) >= 3 - and utterances[0]['speaker'] == 'USER' - and utterances[1]['speaker'] == 'ASSISTANT' - and utterances[2]['speaker'] == 'ASSISTANT' - and "help you?" 
in utterances[1]['text'] - ): - # skip this one - utterances.pop(1) - if self.opt['include_ontology']: - yield {'text': f"{ONTO_TOKEN} {row['ontology']}", 'label': ''}, True - first = False - while utterances: - utt = utterances.pop(0) - segtxt, slots = self._segments2text(utt.get('segments', [])) - if utt['speaker'] == 'USER': - yield { - 'text': utt['text'], - 'label': f'{CALL_TOKEN} {segtxt}', - 'domain': row['domain'], - 'slots': slots, - 'type': 'apicall', - }, first - first = False - elif utt['speaker'] == 'ASSISTANT': - yield { - 'text': f'{RESP_TOKEN} {segtxt}', - 'label': utt['text'], - 'domain': row['domain'], - 'slots': slots, - 'type': 'apiresp', - }, first - first = False - logging.debug(f"Fold {fold} domains: {domains_cnt}") - - -class DelexTeacher(_Abstract): - def _label_fold(self, chunks): - return chunks.conversation_id.apply(self._h) - - def _delexicalize(self, text, slots): - for key, value in slots.items(): - text = text.replace(value, key) - return text - - def setup_data(self, fold): + while idx < len(utterances): + here = utterances[idx] + if here["speaker"] != speaker: + break + utts.append(here["text"]) + slots.update(self._parse_segment_to_slots(here.get("segments", []))) + idx += 1 + return idx, "\n".join(utts), slots + + def _get_onto_list(self, onto_map, domain): + results = [] + domain = domain.replace( + "-", "_" + ) # cause they changed it for restaurant-search >.> + for data in onto_map[domain]: + call = {} + call[tod.STANDARD_API_NAME_SLOT] = data["prefix"] + call[tod.STANDARD_OPTIONAL_KEY] = data[ + "annotations" + ] # make all args optional since not specified + results.append(call) + return results + + def setup_episodes(self, fold): + """ + Parses into TodStructuredEpisode. + """ + domains = self.opt.get("taskmaster2_domains", DOMAINS) + chunks, ontologies = self._load_data(fold, domains) domains_cnt = Counter() - chunks = self._load_data(fold) + episodes = [] for _, row in chunks.iterrows(): - domains_cnt[row['domain']] += 1 - first = True - utterances = row['utterances'][:] - if ( - len(utterances) >= 3 - and utterances[0]['speaker'] == 'USER' - and utterances[1]['speaker'] == 'ASSISTANT' - and utterances[2]['speaker'] == 'ASSISTANT' - and "help you?" 
in utterances[1]['text'] - ): - # skip this one - utterances.pop(1) - - user_utterances = [] - asst_utterances = [] - while utterances: - utt = utterances.pop(0) - _, slots = self._segments2text(utt.get('segments', [])) - if utt['speaker'] == 'USER': - if asst_utterances: - yield { - 'text': ' __BREAK__ '.join(user_utterances), - 'label': ' __BREAK__ '.join(asst_utterances), - 'domain': row['domain'], - }, first - first = False - user_utterances = [] - asst_utterances = [] - user_utterances.append(self._delexicalize(utt['text'], slots)) - elif utt['speaker'] == 'ASSISTANT': - asst_utterances.append(self._delexicalize(utt['text'], slots)) - if not user_utterances: - user_utterances.append('__SILENCE__') - if asst_utterances: - yield { - 'text': ' __BREAK__ '.join(user_utterances), - 'label': ' __BREAK__ '.join(asst_utterances), - 'domain': row['domain'], - }, first + domains_cnt[row["domain"]] += 1 + utterances = row["utterances"][:] + + idx = 0 + rounds = [] + goal_calls = [] + if len(utterances) > 0 and utterances[0]["speaker"] == "ASSISTANT": + idx, sys_utt, api_resp = self._get_utterance_and_api_call_for_speaker( + "ASSISTANT", utterances, idx + ) + r = tod.TodStructuredRound(api_resp_machine=api_resp, sys_utt=sys_utt) + rounds.append(r) + cum_api_call = {} + while idx < len(utterances): + idx, user_utt, api_call = self._get_utterance_and_api_call_for_speaker( + "USER", utterances, idx + ) + idx, sys_utt, api_resp = self._get_utterance_and_api_call_for_speaker( + "ASSISTANT", utterances, idx + ) + if not self.opt["use_cumulative_api_calls"]: + r = tod.TodStructuredRound( + user_utt=user_utt, + api_call_machine=api_call, + api_resp_machine=api_resp, + sys_utt=sys_utt, + ) + else: + cum_api_call = self.process_call_for_cumlative_standalone_api( + api_call, cum_api_call + ) + r = tod.TodStructuredRound( + user_utt=user_utt, + api_call_machine=cum_api_call if len(api_resp) > 0 else {}, + api_resp_machine=api_resp if len(api_resp) > 0 else {}, + sys_utt=sys_utt, + ) -class TextOnlyTeacher(DelexTeacher): - def _delexicalize(self, text, slots): - return text + rounds.append(r) + if len(api_call) > 0: + goal_calls.append(api_call) + episode = tod.TodStructuredEpisode( + domain=tod.SerializationHelpers.inner_list_join(row["domain"]), + api_schemas_machine=self._get_onto_list(ontologies, row["domain"]), + goal_calls_machine=goal_calls, + rounds=rounds, + delex=self.opt.get("delex", False), + ) + episodes.append(episode) + return episodes -class FullShotTeacher(_Abstract): - """ - The full shot teacher uses a standard 80-10-10 split, without regarding domain. - """ + def get_id_task_prefix(self): + return "Taskmaster2" def _label_fold(self, chunks): return chunks.conversation_id.apply(self._h) + def process_call_for_cumlative_standalone_api(self, new_call, cum_calls): + if ( + len(new_call) > 0 + and len(cum_calls) > 0 + and new_call[tod.STANDARD_API_NAME_SLOT] + != cum_calls[tod.STANDARD_API_NAME_SLOT] + ): + cum_calls = {} + cum_calls.update(new_call) + return cum_calls -class FewShotTeacher(_Abstract): - """ - Few shot teacher tests for generalization to new domains. 
- """ - @classmethod - def add_cmdline_args( - cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None - ) -> ParlaiParser: - super().add_cmdline_args(parser, partial_opt) - parser.add_argument( - '--holdout', - default=DOMAINS[0], - choices=DOMAINS, - help='Domain which is held out from test', - ) - parser.add_argument( - '--n-shot', - default=100, - type=int, - help='Number of few shot examples to provide in training fold.', - ) - return super().add_cmdline_args(parser, partial_opt=partial_opt) +class UserSimulatorTeacher(Taskmaster2Parser, tod_agents.TodUserSimulatorTeacher): + pass - def _label_fold(self, chunks): - folds = [] - num_shots = 0 - for _, row in chunks.iterrows(): - if row['domain'] != self.opt['holdout']: - # if it's not in the holdout, always mark it train - folds.append('train') - else: - # keep the same valid/test sets as in fullshot, and only leak - # a small number of the training examples (i.e. throw away the - # vast majority of our data but keep test sets the same) - f = self._h(row['conversation_id']) - if f != 'train': - folds.append(f) - elif num_shots < self.opt['n_shot']: - folds.append('train') - num_shots += 1 - else: - folds.append('throwaway') - return folds +class SystemTeacher(Taskmaster2Parser, tod_agents.TodSystemTeacher): + pass -class DefaultTeacher(FullShotTeacher): +class DefaultTeacher(SystemTeacher): pass diff --git a/parlai/tasks/taskmaster2/build.py b/parlai/tasks/taskmaster2/build.py index 23b7a4845e8..1f71a2ae8ed 100644 --- a/parlai/tasks/taskmaster2/build.py +++ b/parlai/tasks/taskmaster2/build.py @@ -9,93 +9,93 @@ import os from parlai.core.build_data import DownloadableFile -ROOT_URL = 'https://github.com/google-research-datasets/Taskmaster/raw/master/TM-2-2020' +ROOT_URL = "https://github.com/google-research-datasets/Taskmaster/raw/master/TM-2-2020" RESOURCES = [ # raw data files DownloadableFile( - f'{ROOT_URL}/data/flights.json', - 'flights.json', - '86b37b5ae25f530fd18ced78800d30c3b54f7b34bb208ecb51842718f04e760b', + f"{ROOT_URL}/data/flights.json", + "flights.json", + "86b37b5ae25f530fd18ced78800d30c3b54f7b34bb208ecb51842718f04e760b", zipped=False, ), DownloadableFile( - f'{ROOT_URL}/data/food-ordering.json', - 'food-ordering.json', - '0a042e566a816a5d0abebe6f7e8cfd6abaa89729ffc42f433d327df7342b12f8', + f"{ROOT_URL}/data/food-ordering.json", + "food-ordering.json", + "0a042e566a816a5d0abebe6f7e8cfd6abaa89729ffc42f433d327df7342b12f8", zipped=False, ), DownloadableFile( - f'{ROOT_URL}/data/hotels.json', - 'hotels.json', - '975b0242f1e37ea1ab94ccedd7e0d6ee5831599d5df1f16143e71110d6c6006a', + f"{ROOT_URL}/data/hotels.json", + "hotels.json", + "975b0242f1e37ea1ab94ccedd7e0d6ee5831599d5df1f16143e71110d6c6006a", zipped=False, ), DownloadableFile( - f'{ROOT_URL}/data/movies.json', - 'movies.json', - '6f67c9a1f04abc111186e5bcfbe3050be01d0737fd6422901402715bc1f3dd0d', + f"{ROOT_URL}/data/movies.json", + "movies.json", + "6f67c9a1f04abc111186e5bcfbe3050be01d0737fd6422901402715bc1f3dd0d", zipped=False, ), DownloadableFile( - f'{ROOT_URL}/data/music.json', - 'music.json', - 'e5db60d6576fa010bef87a70a8b371d293d48cde8524c1d3ed7c3022f079d95d', + f"{ROOT_URL}/data/music.json", + "music.json", + "e5db60d6576fa010bef87a70a8b371d293d48cde8524c1d3ed7c3022f079d95d", zipped=False, ), DownloadableFile( - f'{ROOT_URL}/data/restaurant-search.json', - 'restaurant-search.json', - 'fb9735f89e7ebc7c877f976da4c30391af6a6277991b597c0755564657ff8f47', + f"{ROOT_URL}/data/restaurant-search.json", + "restaurant-search.json", + 
"fb9735f89e7ebc7c877f976da4c30391af6a6277991b597c0755564657ff8f47", zipped=False, ), DownloadableFile( - f'{ROOT_URL}/data/sports.json', - 'sports.json', - '8191531bfa5a8426b1508c396ab9886a19c7c620b443c436ec10d8d4708d0eac', + f"{ROOT_URL}/data/sports.json", + "sports.json", + "8191531bfa5a8426b1508c396ab9886a19c7c620b443c436ec10d8d4708d0eac", zipped=False, ), # ontology data files DownloadableFile( - f'{ROOT_URL}/ontology/flights.json', - 'flights.onto.json', - '1ebc5c982339d24b2dcf50677883fed65b7fcb95f01edbbd3be6357090893c33', + f"{ROOT_URL}/ontology/flights.json", + "flights.onto.json", + "1ebc5c982339d24b2dcf50677883fed65b7fcb95f01edbbd3be6357090893c33", zipped=False, ), DownloadableFile( - f'{ROOT_URL}/ontology/food-ordering.json', - 'food-ordering.onto.json', - '79c1189c16f0ab937bad558c70a0b9b99358f9ed91ea65ce4af37c4b7d999063', + f"{ROOT_URL}/ontology/food-ordering.json", + "food-ordering.onto.json", + "79c1189c16f0ab937bad558c70a0b9b99358f9ed91ea65ce4af37c4b7d999063", zipped=False, ), DownloadableFile( - f'{ROOT_URL}/ontology/hotels.json', - 'hotels.onto.json', - '22ae51ba546ee7ca03143097782817c4cdd0de74ac84893eaf40b8254aa866d3', + f"{ROOT_URL}/ontology/hotels.json", + "hotels.onto.json", + "22ae51ba546ee7ca03143097782817c4cdd0de74ac84893eaf40b8254aa866d3", zipped=False, ), DownloadableFile( - f'{ROOT_URL}/ontology/movies.json', - 'movies.onto.json', - '8403283526bb314e871850b98bb86a7987ef0af6fbbe4fb5a089ee9498584476', + f"{ROOT_URL}/ontology/movies.json", + "movies.onto.json", + "8403283526bb314e871850b98bb86a7987ef0af6fbbe4fb5a089ee9498584476", zipped=False, ), DownloadableFile( - f'{ROOT_URL}/ontology/music.json', - 'music.onto.json', - '4bcd6dcf1cdc6bdb717e5fdc08b3472dc3d1f4da8a0f8aee917494d79a7fe338', + f"{ROOT_URL}/ontology/music.json", + "music.onto.json", + "4bcd6dcf1cdc6bdb717e5fdc08b3472dc3d1f4da8a0f8aee917494d79a7fe338", zipped=False, ), DownloadableFile( - f'{ROOT_URL}/ontology/restaurant-search.json', - 'restaurant-search.onto.json', - 'c9ead7985695b3feba1fb955e8407d806e4095f5459485adc5448ae89989e609', + f"{ROOT_URL}/ontology/restaurant-search.json", + "restaurant-search.onto.json", + "c9ead7985695b3feba1fb955e8407d806e4095f5459485adc5448ae89989e609", zipped=False, ), DownloadableFile( - f'{ROOT_URL}/ontology/sports.json', - 'sports.onto.json', - '52f9bbb86ebd9e2b3916185ad4e3e9b8b77d2164d96bd3b98ad67cbaa653757d', + f"{ROOT_URL}/ontology/sports.json", + "sports.onto.json", + "52f9bbb86ebd9e2b3916185ad4e3e9b8b77d2164d96bd3b98ad67cbaa653757d", zipped=False, ), ] @@ -103,13 +103,13 @@ def build(opt): # get path to data directory - dpath = os.path.join(opt['datapath'], 'taskmaster-2') + dpath = os.path.join(opt["datapath"], "taskmaster-2") # define version if any version = "1.1" # check if data had been previously built if not build_data.built(dpath, version_string=version): - print('[building data: ' + dpath + ']') + print("[building data: " + dpath + "]") # make a clean directory if needed if build_data.built(dpath): diff --git a/parlai/tasks/taskmaster2/test.py b/parlai/tasks/taskmaster2/test.py new file mode 100644 index 00000000000..e0d2e79a87c --- /dev/null +++ b/parlai/tasks/taskmaster2/test.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from parlai.utils.testing import AutoTeacherTest + + +class TestDefaultTeacher(AutoTeacherTest): + task = "taskmaster2" + + +class TestUserSimulatorTeacher(AutoTeacherTest): + task = "taskmaster2:UserSimulatorTeacher" diff --git a/parlai/tasks/taskmaster2/test/taskmaster2_UserSimulatorTeacher_test.yml b/parlai/tasks/taskmaster2/test/taskmaster2_UserSimulatorTeacher_test.yml new file mode 100644 index 00000000000..318b3b416a8 --- /dev/null +++ b/parlai/tasks/taskmaster2/test/taskmaster2_UserSimulatorTeacher_test.yml @@ -0,0 +1,46 @@ +acts: +- - domain: sports + episode_done: false + eval_labels: + - 'USER: Hey, what''s the Denver Broncos record?' + id: Taskmaster2_UserSimulatorTeacher + text: 'GOAL: api_name = nfl ; name.team = Denver Broncos' + type: 'USER: ' +- - domain: sports + episode_done: false + eval_labels: + - 'USER: What Conference are they in?' + id: Taskmaster2_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: The Denver Broncos are in currently fourth in the AFC West with + a record of four wins and nine losses.' + type: 'USER: ' +- - domain: sports + episode_done: false + eval_labels: + - 'USER: Who do they play against next week?' + id: Taskmaster2_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: The Denver Broncos played in the American Football Conference in + the west division.' + type: 'USER: ' +- - domain: sports + episode_done: false + eval_labels: + - 'USER: When did they play last?' + id: Taskmaster2_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: In next week the Denver Broncos will be playing against the Indianapolis + Colts.' + type: 'USER: ' +- - domain: sports + episode_done: false + eval_labels: + - 'USER: How many games back from first place are they?' + id: Taskmaster2_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Their last game was yesterday, they beat the New York Jets by 23 + to 0.' + type: 'USER: ' +num_episodes: 1734 +num_examples: 36584 diff --git a/parlai/tasks/taskmaster2/test/taskmaster2_UserSimulatorTeacher_train.yml b/parlai/tasks/taskmaster2/test/taskmaster2_UserSimulatorTeacher_train.yml new file mode 100644 index 00000000000..16cd94caae9 --- /dev/null +++ b/parlai/tasks/taskmaster2/test/taskmaster2_UserSimulatorTeacher_train.yml @@ -0,0 +1,43 @@ +acts: +- - domain: sports + episode_done: false + id: Taskmaster2_UserSimulatorTeacher + labels: + - 'USER: Hey. How are the Denver Nuggets doing this year?' + text: 'GOAL: api_name = nba ; name.team = Denver Nuggets | api_name = nba ; name.player + = Nikola Jokic' + type: 'USER: ' +- - domain: sports + episode_done: false + id: Taskmaster2_UserSimulatorTeacher + labels: + - 'USER: Okay. And what division are they in?' + slots: {} + text: 'SYSTEM: Hello, They''re currently six place in the Western Conference.' + type: 'USER: ' +- - domain: sports + episode_done: false + id: Taskmaster2_UserSimulatorTeacher + labels: + - 'USER: Okay. And how they did last game?' + slots: {} + text: 'SYSTEM: There in the Northwest division.' + type: 'USER: ' +- - domain: sports + episode_done: false + id: Taskmaster2_UserSimulatorTeacher + labels: + - 'USER: Okay. And I need to start report that.' + slots: {} + text: 'SYSTEM: They lost the last game against the 76ers.' + type: 'USER: ' +- - domain: sports + episode_done: false + id: Taskmaster2_UserSimulatorTeacher + labels: + - 'USER: Okay, And how many points is the college York average in?' + slots: {} + text: 'SYSTEM: Starting point guard is Gary Harris.' 
+ type: 'USER: ' +num_episodes: 13840 +num_examples: 291032 diff --git a/parlai/tasks/taskmaster2/test/taskmaster2_UserSimulatorTeacher_valid.yml b/parlai/tasks/taskmaster2/test/taskmaster2_UserSimulatorTeacher_valid.yml new file mode 100644 index 00000000000..97dad21316a --- /dev/null +++ b/parlai/tasks/taskmaster2/test/taskmaster2_UserSimulatorTeacher_valid.yml @@ -0,0 +1,45 @@ +acts: +- - domain: sports + episode_done: false + eval_labels: + - 'USER: ' + id: Taskmaster2_UserSimulatorTeacher + text: 'GOAL: api_name = mls ; name.team = Vancouver Whitecaps FC? | api_name = + mls ; day.match = last Saturday | api_name = mls ; position.player = goalkeeper' + type: 'USER: ' +- - domain: sports + episode_done: false + eval_labels: + - 'USER: Hi Assistant. How are you?' + id: Taskmaster2_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Hello.' + type: 'USER: ' +- - domain: sports + episode_done: false + eval_labels: + - 'USER: I''m great. I''m a big fan of Major League Soccer, And my favorite team + is Vancouver Whitecaps FC. And I would love to know what place they are in, + the Vancouver Whitecaps FC?' + id: Taskmaster2_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: I''m good and yourself?' + type: 'USER: ' +- - domain: sports + episode_done: false + eval_labels: + - 'USER: Thank you. And are they playing right now?' + id: Taskmaster2_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Currently in 7th place in the Western Conference.' + type: 'USER: ' +- - domain: sports + episode_done: false + eval_labels: + - 'USER: Okay, thank you. And who did they play last Saturday?' + id: Taskmaster2_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: No, they''re not scheduled to play today.' + type: 'USER: ' +num_episodes: 1730 +num_examples: 36404 diff --git a/parlai/tasks/taskmaster2/test/taskmaster2_test.yml b/parlai/tasks/taskmaster2/test/taskmaster2_test.yml new file mode 100644 index 00000000000..eae250af098 --- /dev/null +++ b/parlai/tasks/taskmaster2/test/taskmaster2_test.yml @@ -0,0 +1,55 @@ +acts: +- - domain: sports + episode_done: false + eval_labels: + - 'APIS: ' + id: Taskmaster2_SystemTeacher + slots: {} + text: 'APIS: ' + type: 'APIS: ' +- - domain: sports + episode_done: false + eval_labels: + - 'APICALL: api_name = nfl ; name.team = Denver Broncos' + id: Taskmaster2_SystemTeacher + slots: + api_name: nfl + name.team: Denver Broncos + text: 'USER: Hey, what''s the Denver Broncos record?' + type: 'APICALL: ' +- - domain: sports + episode_done: false + eval_labels: + - 'SYSTEM: The Denver Broncos are in currently fourth in the AFC West with a record + of four wins and nine losses.' + id: Taskmaster2_SystemTeacher + slots: + api_name: nfl + name.team: Denver Broncos + record.team: four wins and nine losses + text: 'APIRESP: api_name = nfl ; name.team = Denver Broncos ; record.team = four + wins and nine losses' + type: 'SYSTEM: ' +- - domain: sports + episode_done: false + eval_labels: + - 'APICALL: api_name = nfl ; name.team = Denver Broncos' + id: Taskmaster2_SystemTeacher + slots: + api_name: nfl + name.team: Denver Broncos + text: 'USER: What Conference are they in?' + type: 'APICALL: ' +- - domain: sports + episode_done: false + eval_labels: + - 'SYSTEM: The Denver Broncos played in the American Football Conference in the + west division.' 
+ id: Taskmaster2_SystemTeacher + slots: + api_name: nfl + name.team: Denver Broncos + text: 'APIRESP: api_name = nfl ; name.team = Denver Broncos' + type: 'SYSTEM: ' +num_episodes: 1734 +num_examples: 36584 diff --git a/parlai/tasks/taskmaster2/test/taskmaster2_train.yml b/parlai/tasks/taskmaster2/test/taskmaster2_train.yml new file mode 100644 index 00000000000..5676f053795 --- /dev/null +++ b/parlai/tasks/taskmaster2/test/taskmaster2_train.yml @@ -0,0 +1,48 @@ +acts: +- - domain: sports + episode_done: false + id: Taskmaster2_SystemTeacher + labels: + - 'APIS: ' + slots: {} + text: 'APIS: ' + type: 'APIS: ' +- - domain: sports + episode_done: false + id: Taskmaster2_SystemTeacher + labels: + - 'APICALL: api_name = nba ; name.team = Denver Nuggets' + slots: + api_name: nba + name.player: Nikola Jokic + name.team: Denver Nuggets + text: 'USER: Hey. How are the Denver Nuggets doing this year?' + type: 'APICALL: ' +- - domain: sports + episode_done: false + id: Taskmaster2_SystemTeacher + labels: + - 'SYSTEM: Hello, They''re currently six place in the Western Conference.' + slots: + api_name: nba + place.team: six place + text: 'APIRESP: api_name = nba ; place.team = six place' + type: 'SYSTEM: ' +- - domain: sports + episode_done: false + id: Taskmaster2_SystemTeacher + labels: + - 'APICALL: ' + slots: {} + text: 'USER: Okay. And what division are they in?' + type: 'APICALL: ' +- - domain: sports + episode_done: false + id: Taskmaster2_SystemTeacher + labels: + - 'SYSTEM: There in the Northwest division.' + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +num_episodes: 13840 +num_examples: 291032 diff --git a/parlai/tasks/taskmaster2/test/taskmaster2_valid.yml b/parlai/tasks/taskmaster2/test/taskmaster2_valid.yml new file mode 100644 index 00000000000..3a1c36b2bdb --- /dev/null +++ b/parlai/tasks/taskmaster2/test/taskmaster2_valid.yml @@ -0,0 +1,43 @@ +acts: +- - domain: sports + episode_done: false + eval_labels: + - 'APIS: ' + id: Taskmaster2_SystemTeacher + slots: {} + text: 'APIS: ' + type: 'APIS: ' +- - domain: sports + episode_done: false + eval_labels: + - 'APICALL: ' + id: Taskmaster2_SystemTeacher + slots: {} + text: 'USER: ' + type: 'APICALL: ' +- - domain: sports + episode_done: false + eval_labels: + - 'SYSTEM: Hello.' + id: Taskmaster2_SystemTeacher + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +- - domain: sports + episode_done: false + eval_labels: + - 'APICALL: ' + id: Taskmaster2_SystemTeacher + slots: {} + text: 'USER: Hi Assistant. How are you?' + type: 'APICALL: ' +- - domain: sports + episode_done: false + eval_labels: + - 'SYSTEM: I''m good and yourself?' 
+ id: Taskmaster2_SystemTeacher + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +num_episodes: 1730 +num_examples: 36404 From 3675781fb5c7f5f9adba8c84f997a063f6c123cb Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 14:55:47 -0800 Subject: [PATCH 21/57] use same version of black as in the pre-commit hook --- parlai/core/tod/tod_agents.py | 11 ++-------- parlai/core/tod/tod_test_utils/test_agents.py | 9 +-------- tests/tod/test_tod_agents_and_teachers.py | 20 ++++--------------- tests/tod/test_tod_teacher_metrics.py | 5 +---- 4 files changed, 8 insertions(+), 37 deletions(-) diff --git a/parlai/core/tod/tod_agents.py b/parlai/core/tod/tod_agents.py index f22a8330760..ee06dd6766a 100644 --- a/parlai/core/tod/tod_agents.py +++ b/parlai/core/tod/tod_agents.py @@ -542,11 +542,7 @@ def _do_fetch(self, call_text): resp = self.data[call_text] else: resp = self.data.get(call_text, tod.STANDARD_RESP) - return { - "text": resp, - "id": self.id, - "episode_done": False, - } + return {"text": resp, "id": self.id, "episode_done": False} # Not exact case best_key = difflib.get_close_matches(call_text, self.data.keys(), 1) @@ -674,10 +670,7 @@ def custom_evaluation( [teacher_action["domain"]] if self.opt["domain_nlg_record"] else [] ) metrics = NlgMetrics( - guess=resp, - labels=labels, - prefixes=domains, - avg_jga_nlg_bleu=True, + guess=resp, labels=labels, prefixes=domains, avg_jga_nlg_bleu=True ).report() for key, value in metrics.items(): self.metrics.add(key, value) diff --git a/parlai/core/tod/tod_test_utils/test_agents.py b/parlai/core/tod/tod_test_utils/test_agents.py index b7639efea36..8d749fa40c7 100644 --- a/parlai/core/tod/tod_test_utils/test_agents.py +++ b/parlai/core/tod/tod_test_utils/test_agents.py @@ -91,14 +91,7 @@ def get_round_utts(episode_idx, max_rounds, filter_utts=None): f"SYSTEM: sys_utt_{episode_idx}_{i}", ] ) - utts.append( - [ - "USER: [DONE]", - "APICALL: ", - "APIRESP: ", - "SYSTEM: ", - ] - ) + utts.append(["USER: [DONE]", "APICALL: ", "APIRESP: ", "SYSTEM: "]) if filter_utts is not None: utts = [ [turn for i, turn in enumerate(round_data) if filter_utts[i]] diff --git a/tests/tod/test_tod_agents_and_teachers.py b/tests/tod/test_tod_agents_and_teachers.py index 5383c72416d..2914cb927b3 100644 --- a/tests/tod/test_tod_agents_and_teachers.py +++ b/tests/tod/test_tod_agents_and_teachers.py @@ -247,10 +247,7 @@ def helper(n_shot): values = self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, - { - "episodes_randomization_seed": 0, - "n_shot": n_shot, - }, + {"episodes_randomization_seed": 0, "n_shot": n_shot}, ) self.assertEqual(len(values), n_shot) @@ -269,10 +266,7 @@ def helper(n_shot, seed): return self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE, - { - "episodes_randomization_seed": seed, - "n_shot": n_shot, - }, + {"episodes_randomization_seed": seed, "n_shot": n_shot}, ) data_dumps_seed_zero = [helper(i, 0) for i in self.FEW_SHOT_SAMPLES] @@ -286,10 +280,7 @@ def helper(percent_shot, correct): values = self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, # 35 episodes - { - "episodes_randomization_seed": 0, - "percent_shot": percent_shot, - }, + {"episodes_randomization_seed": 0, "percent_shot": percent_shot}, ) self.assertEqual(len(values), correct) @@ -302,10 +293,7 @@ def helper(percent_shot, seed): return self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, # 35 episodes - { - 
"episodes_randomization_seed": seed, - "percent_shot": percent_shot, - }, + {"episodes_randomization_seed": seed, "percent_shot": percent_shot}, ) data_dumps_seed_zero = [helper(i, 0) for i in self.PERCENTAGES] diff --git a/tests/tod/test_tod_teacher_metrics.py b/tests/tod/test_tod_teacher_metrics.py index aec4aba40f6..419210fa18a 100644 --- a/tests/tod/test_tod_teacher_metrics.py +++ b/tests/tod/test_tod_teacher_metrics.py @@ -62,10 +62,7 @@ def test_base_slot_metrics(self): ), ] for teacher, predicted, result in cases: - metric = SlotMetrics( - teacher_slots=teacher, - predicted_slots=predicted, - ) + metric = SlotMetrics(teacher_slots=teacher, predicted_slots=predicted) for key in result: self.assertEqual(result[key], metric.report()[key]) From 0bc961e66e56250703b27cccb8e9c5f2e6ac2528 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 14:56:22 -0800 Subject: [PATCH 22/57] use same version of black as in the pre-commit hook --- parlai/core/tod/world_metrics.py | 7 ++----- parlai/core/tod/world_metrics_handlers.py | 9 ++------- parlai/scripts/distributed_tod_world_script.py | 4 +--- parlai/scripts/tod_world_script.py | 8 ++------ tests/tod/test_tod_world_metrics.py | 8 ++------ tests/tod/test_tod_world_metrics_in_script.py | 11 +++++------ tests/tod/test_tod_world_script_metrics.py | 4 +--- 7 files changed, 15 insertions(+), 36 deletions(-) diff --git a/parlai/core/tod/world_metrics.py b/parlai/core/tod/world_metrics.py index 4e8ba4555e4..17d96b8698f 100644 --- a/parlai/core/tod/world_metrics.py +++ b/parlai/core/tod/world_metrics.py @@ -10,9 +10,7 @@ """ from parlai.core.message import Message -from parlai.core.metrics import ( - Metrics, -) +from parlai.core.metrics import Metrics from parlai.core.tod.tod_core import ( TodAgentType, TOD_AGENT_TYPE_TO_PREFIX, @@ -73,8 +71,7 @@ def _handle_message_impl( ) if agent_type is TodAgentType.API_SCHEMA_GROUNDING_AGENT: return handler.handle_api_schemas( - message, - SerializationHelpers.str_to_api_schemas(prefix_stripped_text), + message, SerializationHelpers.str_to_api_schemas(prefix_stripped_text) ) if agent_type is TodAgentType.GOAL_GROUNDING_AGENT: return handler.handle_goals( diff --git a/parlai/core/tod/world_metrics_handlers.py b/parlai/core/tod/world_metrics_handlers.py index 3f4b68477fe..ffa7343f114 100644 --- a/parlai/core/tod/world_metrics_handlers.py +++ b/parlai/core/tod/world_metrics_handlers.py @@ -10,13 +10,8 @@ """ from parlai.core.message import Message -from parlai.core.metrics import ( - Metric, - AverageMetric, -) -from parlai.core.tod.tod_core import ( - STANDARD_DONE, -) +from parlai.core.metrics import Metric, AverageMetric +from parlai.core.tod.tod_core import STANDARD_DONE from typing import Dict, List, Optional METRICS_HANDLER_CLASSES_TEST_REGISTRY = set() # for tests diff --git a/parlai/scripts/distributed_tod_world_script.py b/parlai/scripts/distributed_tod_world_script.py index 87c8333ec41..8dc36a6feb7 100644 --- a/parlai/scripts/distributed_tod_world_script.py +++ b/parlai/scripts/distributed_tod_world_script.py @@ -9,9 +9,7 @@ Not to be called directly; should be called from SLURM """ -from parlai.scripts.tod_world_script import ( - TodWorldScript, -) +from parlai.scripts.tod_world_script import TodWorldScript from parlai.core.script import ParlaiScript import parlai.utils.distributed as distributed_utils diff --git a/parlai/scripts/tod_world_script.py b/parlai/scripts/tod_world_script.py index 42199bc4ff3..c360882e61c 100644 --- a/parlai/scripts/tod_world_script.py +++ 
b/parlai/scripts/tod_world_script.py @@ -168,9 +168,7 @@ def setup_tod_args(cls, parser: ParlaiParser): def setup_args(cls): # Use default parlai args for logging + the like, but don't need model args since we specify those manually via command-line parser = TodWorldParser( - True, - False, - "World for chatting with the TOD conversation structure", + True, False, "World for chatting with the TOD conversation structure" ) # Following params are same as the `eval_model` script parser.add_argument( @@ -261,9 +259,7 @@ def _get_tod_agents(self, opt: Opt): ).replace("Goal", "ApiSchema") agents[tod_world.API_SCHEMA_GROUNDING_IDX] = self._get_model_or_default_agent( - opt, - "api_schema_grounding_model", - tod_world_agents.EmptyApiSchemaAgent, + opt, "api_schema_grounding_model", tod_world_agents.EmptyApiSchemaAgent ) return agents diff --git a/tests/tod/test_tod_world_metrics.py b/tests/tod/test_tod_world_metrics.py index 0367f655994..1399050b3ca 100644 --- a/tests/tod/test_tod_world_metrics.py +++ b/tests/tod/test_tod_world_metrics.py @@ -19,12 +19,8 @@ TodAgentType, TOD_AGENT_TYPE_TO_PREFIX, ) -from parlai.core.tod.world_metrics import ( - TodMetrics, -) -from parlai.core.tod.world_metrics_handlers import ( - METRICS_HANDLER_CLASSES_TEST_REGISTRY, -) +from parlai.core.tod.world_metrics import TodMetrics +from parlai.core.tod.world_metrics_handlers import METRICS_HANDLER_CLASSES_TEST_REGISTRY # Ignore lint on following line; want to have registered classes show up for tests import projects.tod_simulator.world_metrics.extended_world_metrics # noqa: F401 diff --git a/tests/tod/test_tod_world_metrics_in_script.py b/tests/tod/test_tod_world_metrics_in_script.py index ffac2a25e12..459458af9d2 100644 --- a/tests/tod/test_tod_world_metrics_in_script.py +++ b/tests/tod/test_tod_world_metrics_in_script.py @@ -15,9 +15,7 @@ from parlai.core.opt import Opt from parlai.core.tod.tod_core import SerializationHelpers import parlai.core.tod.tod_test_utils.test_agents as test_agents -from parlai.core.tod.world_metrics_handlers import ( - METRICS_HANDLER_CLASSES_TEST_REGISTRY, -) +from parlai.core.tod.world_metrics_handlers import METRICS_HANDLER_CLASSES_TEST_REGISTRY import parlai.scripts.tod_world_script as tod_world_script # Ignore lint on following line; want to have registered classes show up for tests @@ -233,9 +231,10 @@ def get_episode_report(goal, episode_metric): metrics_dict["goal"] = goal return metrics_dict - return dict_report(script.world.report()), [ - get_episode_report(g, e) for g, e in script.episode_metrics - ] + return ( + dict_report(script.world.report()), + [get_episode_report(g, e) for g, e in script.episode_metrics], + ) def test_apiCallAttempts_usingGold(self): opt = copy.deepcopy(TEST_SETUP) diff --git a/tests/tod/test_tod_world_script_metrics.py b/tests/tod/test_tod_world_script_metrics.py index 9862b3c824f..44f060acb84 100644 --- a/tests/tod/test_tod_world_script_metrics.py +++ b/tests/tod/test_tod_world_script_metrics.py @@ -14,9 +14,7 @@ import parlai.core.tod.tod_test_utils.test_agents as test_agents import parlai.scripts.tod_world_script as tod_world_script from parlai.core.tod.tod_agents import StandaloneApiAgent -from parlai.core.tod.world_metrics_handlers import ( - METRICS_HANDLER_CLASSES_TEST_REGISTRY, -) +from parlai.core.tod.world_metrics_handlers import METRICS_HANDLER_CLASSES_TEST_REGISTRY from parlai.core.metrics import dict_report # Ignore lint on following line; want to have registered classes show up for tests From 24ee8984e6b655b52c849579e28efa7f498afee4 Mon Sep 
17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 15:06:19 -0800 Subject: [PATCH 23/57] black with version from pre-commit hook --- parlai/tasks/tod_json/agents.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/parlai/tasks/tod_json/agents.py b/parlai/tasks/tod_json/agents.py index 55e3514de6c..3e994befacf 100644 --- a/parlai/tasks/tod_json/agents.py +++ b/parlai/tasks/tod_json/agents.py @@ -73,10 +73,7 @@ def add_cmdline_args( help="Use all data or split into 8:1:1 fold", ) agent.add_argument( - "--split-folds-seed", - type=int, - default=42, - help="Seed for the fold split", + "--split-folds-seed", type=int, default=42, help="Seed for the fold split" ) return parser From 3145e0edcc09123b8e757fc1bd150f31bc039fd2 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 15:11:13 -0800 Subject: [PATCH 24/57] Shouldn't worry about tod_json being in task_list --- tests/test_zootasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_zootasks.py b/tests/test_zootasks.py index e5dc7023560..57ac70d90b5 100644 --- a/tests/test_zootasks.py +++ b/tests/test_zootasks.py @@ -76,7 +76,7 @@ def test_tasklist(self): task_list, "parlai/tasks", "task", - ignore=['fromfile', 'interactive', 'jsonfile', 'wrapper'], + ignore=['fromfile', 'interactive', 'jsonfile', 'tod_json', 'wrapper'], ) @pytest.mark.nofbcode From f44b17b4bb4e8eafb24baece30b154f6bea4cd64 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 15:15:59 -0800 Subject: [PATCH 25/57] add to task list; run lint with right version of black --- parlai/core/tod/tod_agents.py | 11 ++-------- parlai/core/tod/tod_test_utils/test_agents.py | 9 +-------- parlai/core/tod/world_metrics.py | 7 ++----- parlai/core/tod/world_metrics_handlers.py | 9 ++------- .../scripts/distributed_tod_world_script.py | 4 +--- parlai/scripts/tod_world_script.py | 8 ++------ parlai/tasks/task_list.py | 11 ++++++++++ parlai/tasks/tod_json/agents.py | 5 +---- tests/tod/test_tod_agents_and_teachers.py | 20 ++++--------------- tests/tod/test_tod_teacher_metrics.py | 5 +---- tests/tod/test_tod_world_metrics.py | 8 ++------ tests/tod/test_tod_world_metrics_in_script.py | 11 +++++----- tests/tod/test_tod_world_script_metrics.py | 4 +--- 13 files changed, 35 insertions(+), 77 deletions(-) diff --git a/parlai/core/tod/tod_agents.py b/parlai/core/tod/tod_agents.py index f22a8330760..ee06dd6766a 100644 --- a/parlai/core/tod/tod_agents.py +++ b/parlai/core/tod/tod_agents.py @@ -542,11 +542,7 @@ def _do_fetch(self, call_text): resp = self.data[call_text] else: resp = self.data.get(call_text, tod.STANDARD_RESP) - return { - "text": resp, - "id": self.id, - "episode_done": False, - } + return {"text": resp, "id": self.id, "episode_done": False} # Not exact case best_key = difflib.get_close_matches(call_text, self.data.keys(), 1) @@ -674,10 +670,7 @@ def custom_evaluation( [teacher_action["domain"]] if self.opt["domain_nlg_record"] else [] ) metrics = NlgMetrics( - guess=resp, - labels=labels, - prefixes=domains, - avg_jga_nlg_bleu=True, + guess=resp, labels=labels, prefixes=domains, avg_jga_nlg_bleu=True ).report() for key, value in metrics.items(): self.metrics.add(key, value) diff --git a/parlai/core/tod/tod_test_utils/test_agents.py b/parlai/core/tod/tod_test_utils/test_agents.py index b7639efea36..8d749fa40c7 100644 --- a/parlai/core/tod/tod_test_utils/test_agents.py +++ b/parlai/core/tod/tod_test_utils/test_agents.py @@ -91,14 +91,7 @@ def get_round_utts(episode_idx, max_rounds, filter_utts=None): f"SYSTEM: 
sys_utt_{episode_idx}_{i}", ] ) - utts.append( - [ - "USER: [DONE]", - "APICALL: ", - "APIRESP: ", - "SYSTEM: ", - ] - ) + utts.append(["USER: [DONE]", "APICALL: ", "APIRESP: ", "SYSTEM: "]) if filter_utts is not None: utts = [ [turn for i, turn in enumerate(round_data) if filter_utts[i]] diff --git a/parlai/core/tod/world_metrics.py b/parlai/core/tod/world_metrics.py index 4e8ba4555e4..17d96b8698f 100644 --- a/parlai/core/tod/world_metrics.py +++ b/parlai/core/tod/world_metrics.py @@ -10,9 +10,7 @@ """ from parlai.core.message import Message -from parlai.core.metrics import ( - Metrics, -) +from parlai.core.metrics import Metrics from parlai.core.tod.tod_core import ( TodAgentType, TOD_AGENT_TYPE_TO_PREFIX, @@ -73,8 +71,7 @@ def _handle_message_impl( ) if agent_type is TodAgentType.API_SCHEMA_GROUNDING_AGENT: return handler.handle_api_schemas( - message, - SerializationHelpers.str_to_api_schemas(prefix_stripped_text), + message, SerializationHelpers.str_to_api_schemas(prefix_stripped_text) ) if agent_type is TodAgentType.GOAL_GROUNDING_AGENT: return handler.handle_goals( diff --git a/parlai/core/tod/world_metrics_handlers.py b/parlai/core/tod/world_metrics_handlers.py index 3f4b68477fe..ffa7343f114 100644 --- a/parlai/core/tod/world_metrics_handlers.py +++ b/parlai/core/tod/world_metrics_handlers.py @@ -10,13 +10,8 @@ """ from parlai.core.message import Message -from parlai.core.metrics import ( - Metric, - AverageMetric, -) -from parlai.core.tod.tod_core import ( - STANDARD_DONE, -) +from parlai.core.metrics import Metric, AverageMetric +from parlai.core.tod.tod_core import STANDARD_DONE from typing import Dict, List, Optional METRICS_HANDLER_CLASSES_TEST_REGISTRY = set() # for tests diff --git a/parlai/scripts/distributed_tod_world_script.py b/parlai/scripts/distributed_tod_world_script.py index 87c8333ec41..8dc36a6feb7 100644 --- a/parlai/scripts/distributed_tod_world_script.py +++ b/parlai/scripts/distributed_tod_world_script.py @@ -9,9 +9,7 @@ Not to be called directly; should be called from SLURM """ -from parlai.scripts.tod_world_script import ( - TodWorldScript, -) +from parlai.scripts.tod_world_script import TodWorldScript from parlai.core.script import ParlaiScript import parlai.utils.distributed as distributed_utils diff --git a/parlai/scripts/tod_world_script.py b/parlai/scripts/tod_world_script.py index 42199bc4ff3..c360882e61c 100644 --- a/parlai/scripts/tod_world_script.py +++ b/parlai/scripts/tod_world_script.py @@ -168,9 +168,7 @@ def setup_tod_args(cls, parser: ParlaiParser): def setup_args(cls): # Use default parlai args for logging + the like, but don't need model args since we specify those manually via command-line parser = TodWorldParser( - True, - False, - "World for chatting with the TOD conversation structure", + True, False, "World for chatting with the TOD conversation structure" ) # Following params are same as the `eval_model` script parser.add_argument( @@ -261,9 +259,7 @@ def _get_tod_agents(self, opt: Opt): ).replace("Goal", "ApiSchema") agents[tod_world.API_SCHEMA_GROUNDING_IDX] = self._get_model_or_default_agent( - opt, - "api_schema_grounding_model", - tod_world_agents.EmptyApiSchemaAgent, + opt, "api_schema_grounding_model", tod_world_agents.EmptyApiSchemaAgent ) return agents diff --git a/parlai/tasks/task_list.py b/parlai/tasks/task_list.py index 295986ccb5d..97b38f15091 100644 --- a/parlai/tasks/task_list.py +++ b/parlai/tasks/task_list.py @@ -1354,6 +1354,17 @@ "human and a virtual assistant." 
), }, + { + "id": "GoogleSGDSimulationSplits", + "display_name": "GoogleSGD Simulation Splits", + "task": "google_sgd_simulation_splits", + "tags": ["Goal"], + "description": ( + "Custom processing of the Google SGD dataset into In-Domain and " + "Out-of-Domain splits for the use of zero and few-shotting with " + "other task-oriented data." + ), + }, { "id": "TaskMaster2", "display_name": "TaskMaster2", diff --git a/parlai/tasks/tod_json/agents.py b/parlai/tasks/tod_json/agents.py index 55e3514de6c..3e994befacf 100644 --- a/parlai/tasks/tod_json/agents.py +++ b/parlai/tasks/tod_json/agents.py @@ -73,10 +73,7 @@ def add_cmdline_args( help="Use all data or split into 8:1:1 fold", ) agent.add_argument( - "--split-folds-seed", - type=int, - default=42, - help="Seed for the fold split", + "--split-folds-seed", type=int, default=42, help="Seed for the fold split" ) return parser diff --git a/tests/tod/test_tod_agents_and_teachers.py b/tests/tod/test_tod_agents_and_teachers.py index 5383c72416d..2914cb927b3 100644 --- a/tests/tod/test_tod_agents_and_teachers.py +++ b/tests/tod/test_tod_agents_and_teachers.py @@ -247,10 +247,7 @@ def helper(n_shot): values = self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, - { - "episodes_randomization_seed": 0, - "n_shot": n_shot, - }, + {"episodes_randomization_seed": 0, "n_shot": n_shot}, ) self.assertEqual(len(values), n_shot) @@ -269,10 +266,7 @@ def helper(n_shot, seed): return self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE, - { - "episodes_randomization_seed": seed, - "n_shot": n_shot, - }, + {"episodes_randomization_seed": seed, "n_shot": n_shot}, ) data_dumps_seed_zero = [helper(i, 0) for i in self.FEW_SHOT_SAMPLES] @@ -286,10 +280,7 @@ def helper(percent_shot, correct): values = self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, # 35 episodes - { - "episodes_randomization_seed": 0, - "percent_shot": percent_shot, - }, + {"episodes_randomization_seed": 0, "percent_shot": percent_shot}, ) self.assertEqual(len(values), correct) @@ -302,10 +293,7 @@ def helper(percent_shot, seed): return self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, # 35 episodes - { - "episodes_randomization_seed": seed, - "percent_shot": percent_shot, - }, + {"episodes_randomization_seed": seed, "percent_shot": percent_shot}, ) data_dumps_seed_zero = [helper(i, 0) for i in self.PERCENTAGES] diff --git a/tests/tod/test_tod_teacher_metrics.py b/tests/tod/test_tod_teacher_metrics.py index aec4aba40f6..419210fa18a 100644 --- a/tests/tod/test_tod_teacher_metrics.py +++ b/tests/tod/test_tod_teacher_metrics.py @@ -62,10 +62,7 @@ def test_base_slot_metrics(self): ), ] for teacher, predicted, result in cases: - metric = SlotMetrics( - teacher_slots=teacher, - predicted_slots=predicted, - ) + metric = SlotMetrics(teacher_slots=teacher, predicted_slots=predicted) for key in result: self.assertEqual(result[key], metric.report()[key]) diff --git a/tests/tod/test_tod_world_metrics.py b/tests/tod/test_tod_world_metrics.py index 0367f655994..1399050b3ca 100644 --- a/tests/tod/test_tod_world_metrics.py +++ b/tests/tod/test_tod_world_metrics.py @@ -19,12 +19,8 @@ TodAgentType, TOD_AGENT_TYPE_TO_PREFIX, ) -from parlai.core.tod.world_metrics import ( - TodMetrics, -) -from parlai.core.tod.world_metrics_handlers import ( - METRICS_HANDLER_CLASSES_TEST_REGISTRY, -) +from parlai.core.tod.world_metrics import TodMetrics +from 
parlai.core.tod.world_metrics_handlers import METRICS_HANDLER_CLASSES_TEST_REGISTRY # Ignore lint on following line; want to have registered classes show up for tests import projects.tod_simulator.world_metrics.extended_world_metrics # noqa: F401 diff --git a/tests/tod/test_tod_world_metrics_in_script.py b/tests/tod/test_tod_world_metrics_in_script.py index ffac2a25e12..459458af9d2 100644 --- a/tests/tod/test_tod_world_metrics_in_script.py +++ b/tests/tod/test_tod_world_metrics_in_script.py @@ -15,9 +15,7 @@ from parlai.core.opt import Opt from parlai.core.tod.tod_core import SerializationHelpers import parlai.core.tod.tod_test_utils.test_agents as test_agents -from parlai.core.tod.world_metrics_handlers import ( - METRICS_HANDLER_CLASSES_TEST_REGISTRY, -) +from parlai.core.tod.world_metrics_handlers import METRICS_HANDLER_CLASSES_TEST_REGISTRY import parlai.scripts.tod_world_script as tod_world_script # Ignore lint on following line; want to have registered classes show up for tests @@ -233,9 +231,10 @@ def get_episode_report(goal, episode_metric): metrics_dict["goal"] = goal return metrics_dict - return dict_report(script.world.report()), [ - get_episode_report(g, e) for g, e in script.episode_metrics - ] + return ( + dict_report(script.world.report()), + [get_episode_report(g, e) for g, e in script.episode_metrics], + ) def test_apiCallAttempts_usingGold(self): opt = copy.deepcopy(TEST_SETUP) diff --git a/tests/tod/test_tod_world_script_metrics.py b/tests/tod/test_tod_world_script_metrics.py index 9862b3c824f..44f060acb84 100644 --- a/tests/tod/test_tod_world_script_metrics.py +++ b/tests/tod/test_tod_world_script_metrics.py @@ -14,9 +14,7 @@ import parlai.core.tod.tod_test_utils.test_agents as test_agents import parlai.scripts.tod_world_script as tod_world_script from parlai.core.tod.tod_agents import StandaloneApiAgent -from parlai.core.tod.world_metrics_handlers import ( - METRICS_HANDLER_CLASSES_TEST_REGISTRY, -) +from parlai.core.tod.world_metrics_handlers import METRICS_HANDLER_CLASSES_TEST_REGISTRY from parlai.core.metrics import dict_report # Ignore lint on following line; want to have registered classes show up for tests From 7c3ccf50128e6f83676ed79cd2624ce6d848b93f Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Wed, 17 Nov 2021 08:31:22 -0800 Subject: [PATCH 26/57] lint with right version --- parlai/tasks/metalwoz/agents.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/parlai/tasks/metalwoz/agents.py b/parlai/tasks/metalwoz/agents.py index f3f9d2c8238..c4ddab9cdbb 100644 --- a/parlai/tasks/metalwoz/agents.py +++ b/parlai/tasks/metalwoz/agents.py @@ -22,9 +22,7 @@ def add_cmdline_args( ) -> ParlaiParser: super().add_cmdline_args(parser, partial_opt) parser.add_argument( - "--metalwoz-domains", - nargs="+", - help="Use only a subset of the domains", + "--metalwoz-domains", nargs="+", help="Use only a subset of the domains" ) return parser @@ -104,9 +102,10 @@ def setup_data(self, datapath): data = self.load_data(datapath) for row in data: texts = list(row["turns"]) - prompts, labels = [f"{row['user_role']}\n{texts[0]}"] + texts[2::2], texts[ - 1::2 - ] + prompts, labels = ( + [f"{row['user_role']}\n{texts[0]}"] + texts[2::2], + texts[1::2], + ) for i, (prompt, label) in enumerate(zip(prompts, labels)): yield { "text": prompt, From ab19cc20a05a4ab19a228298f10e7f71eef8362b Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Wed, 17 Nov 2021 08:34:38 -0800 Subject: [PATCH 27/57] lint --- parlai/core/tod/tod_agents.py | 11 ++-------- 
parlai/core/tod/tod_test_utils/test_agents.py | 9 +-------- parlai/core/tod/world_metrics.py | 7 ++----- parlai/core/tod/world_metrics_handlers.py | 9 ++------- .../scripts/distributed_tod_world_script.py | 4 +--- parlai/scripts/tod_world_script.py | 8 ++------ parlai/tasks/metalwoz/agents.py | 11 +++++----- parlai/tasks/msr_e2e/agents.py | 6 +----- parlai/tasks/tod_json/agents.py | 5 +---- tests/tod/test_tod_agents_and_teachers.py | 20 ++++--------------- tests/tod/test_tod_teacher_metrics.py | 5 +---- tests/tod/test_tod_world_metrics.py | 8 ++------ tests/tod/test_tod_world_metrics_in_script.py | 11 +++++----- tests/tod/test_tod_world_script_metrics.py | 4 +--- 14 files changed, 30 insertions(+), 88 deletions(-) diff --git a/parlai/core/tod/tod_agents.py b/parlai/core/tod/tod_agents.py index f22a8330760..ee06dd6766a 100644 --- a/parlai/core/tod/tod_agents.py +++ b/parlai/core/tod/tod_agents.py @@ -542,11 +542,7 @@ def _do_fetch(self, call_text): resp = self.data[call_text] else: resp = self.data.get(call_text, tod.STANDARD_RESP) - return { - "text": resp, - "id": self.id, - "episode_done": False, - } + return {"text": resp, "id": self.id, "episode_done": False} # Not exact case best_key = difflib.get_close_matches(call_text, self.data.keys(), 1) @@ -674,10 +670,7 @@ def custom_evaluation( [teacher_action["domain"]] if self.opt["domain_nlg_record"] else [] ) metrics = NlgMetrics( - guess=resp, - labels=labels, - prefixes=domains, - avg_jga_nlg_bleu=True, + guess=resp, labels=labels, prefixes=domains, avg_jga_nlg_bleu=True ).report() for key, value in metrics.items(): self.metrics.add(key, value) diff --git a/parlai/core/tod/tod_test_utils/test_agents.py b/parlai/core/tod/tod_test_utils/test_agents.py index b7639efea36..8d749fa40c7 100644 --- a/parlai/core/tod/tod_test_utils/test_agents.py +++ b/parlai/core/tod/tod_test_utils/test_agents.py @@ -91,14 +91,7 @@ def get_round_utts(episode_idx, max_rounds, filter_utts=None): f"SYSTEM: sys_utt_{episode_idx}_{i}", ] ) - utts.append( - [ - "USER: [DONE]", - "APICALL: ", - "APIRESP: ", - "SYSTEM: ", - ] - ) + utts.append(["USER: [DONE]", "APICALL: ", "APIRESP: ", "SYSTEM: "]) if filter_utts is not None: utts = [ [turn for i, turn in enumerate(round_data) if filter_utts[i]] diff --git a/parlai/core/tod/world_metrics.py b/parlai/core/tod/world_metrics.py index 4e8ba4555e4..17d96b8698f 100644 --- a/parlai/core/tod/world_metrics.py +++ b/parlai/core/tod/world_metrics.py @@ -10,9 +10,7 @@ """ from parlai.core.message import Message -from parlai.core.metrics import ( - Metrics, -) +from parlai.core.metrics import Metrics from parlai.core.tod.tod_core import ( TodAgentType, TOD_AGENT_TYPE_TO_PREFIX, @@ -73,8 +71,7 @@ def _handle_message_impl( ) if agent_type is TodAgentType.API_SCHEMA_GROUNDING_AGENT: return handler.handle_api_schemas( - message, - SerializationHelpers.str_to_api_schemas(prefix_stripped_text), + message, SerializationHelpers.str_to_api_schemas(prefix_stripped_text) ) if agent_type is TodAgentType.GOAL_GROUNDING_AGENT: return handler.handle_goals( diff --git a/parlai/core/tod/world_metrics_handlers.py b/parlai/core/tod/world_metrics_handlers.py index 3f4b68477fe..ffa7343f114 100644 --- a/parlai/core/tod/world_metrics_handlers.py +++ b/parlai/core/tod/world_metrics_handlers.py @@ -10,13 +10,8 @@ """ from parlai.core.message import Message -from parlai.core.metrics import ( - Metric, - AverageMetric, -) -from parlai.core.tod.tod_core import ( - STANDARD_DONE, -) +from parlai.core.metrics import Metric, AverageMetric +from 
parlai.core.tod.tod_core import STANDARD_DONE from typing import Dict, List, Optional METRICS_HANDLER_CLASSES_TEST_REGISTRY = set() # for tests diff --git a/parlai/scripts/distributed_tod_world_script.py b/parlai/scripts/distributed_tod_world_script.py index 87c8333ec41..8dc36a6feb7 100644 --- a/parlai/scripts/distributed_tod_world_script.py +++ b/parlai/scripts/distributed_tod_world_script.py @@ -9,9 +9,7 @@ Not to be called directly; should be called from SLURM """ -from parlai.scripts.tod_world_script import ( - TodWorldScript, -) +from parlai.scripts.tod_world_script import TodWorldScript from parlai.core.script import ParlaiScript import parlai.utils.distributed as distributed_utils diff --git a/parlai/scripts/tod_world_script.py b/parlai/scripts/tod_world_script.py index 42199bc4ff3..c360882e61c 100644 --- a/parlai/scripts/tod_world_script.py +++ b/parlai/scripts/tod_world_script.py @@ -168,9 +168,7 @@ def setup_tod_args(cls, parser: ParlaiParser): def setup_args(cls): # Use default parlai args for logging + the like, but don't need model args since we specify those manually via command-line parser = TodWorldParser( - True, - False, - "World for chatting with the TOD conversation structure", + True, False, "World for chatting with the TOD conversation structure" ) # Following params are same as the `eval_model` script parser.add_argument( @@ -261,9 +259,7 @@ def _get_tod_agents(self, opt: Opt): ).replace("Goal", "ApiSchema") agents[tod_world.API_SCHEMA_GROUNDING_IDX] = self._get_model_or_default_agent( - opt, - "api_schema_grounding_model", - tod_world_agents.EmptyApiSchemaAgent, + opt, "api_schema_grounding_model", tod_world_agents.EmptyApiSchemaAgent ) return agents diff --git a/parlai/tasks/metalwoz/agents.py b/parlai/tasks/metalwoz/agents.py index f3f9d2c8238..c4ddab9cdbb 100644 --- a/parlai/tasks/metalwoz/agents.py +++ b/parlai/tasks/metalwoz/agents.py @@ -22,9 +22,7 @@ def add_cmdline_args( ) -> ParlaiParser: super().add_cmdline_args(parser, partial_opt) parser.add_argument( - "--metalwoz-domains", - nargs="+", - help="Use only a subset of the domains", + "--metalwoz-domains", nargs="+", help="Use only a subset of the domains" ) return parser @@ -104,9 +102,10 @@ def setup_data(self, datapath): data = self.load_data(datapath) for row in data: texts = list(row["turns"]) - prompts, labels = [f"{row['user_role']}\n{texts[0]}"] + texts[2::2], texts[ - 1::2 - ] + prompts, labels = ( + [f"{row['user_role']}\n{texts[0]}"] + texts[2::2], + texts[1::2], + ) for i, (prompt, label) in enumerate(zip(prompts, labels)): yield { "text": prompt, diff --git a/parlai/tasks/msr_e2e/agents.py b/parlai/tasks/msr_e2e/agents.py index aa7b13eb04d..293cdc5cb88 100644 --- a/parlai/tasks/msr_e2e/agents.py +++ b/parlai/tasks/msr_e2e/agents.py @@ -28,11 +28,7 @@ import parlai.core.tod.tod_agents as tod_agents -DOMAINS = [ - "movie", - "restaurant", - "taxi", -] +DOMAINS = ["movie", "restaurant", "taxi"] # Just going to copy/paste these since it's faster than parsing 3 separate files # They are in `system/src/deep_dialog/data_/_slots.txt` in the original data diff --git a/parlai/tasks/tod_json/agents.py b/parlai/tasks/tod_json/agents.py index 55e3514de6c..3e994befacf 100644 --- a/parlai/tasks/tod_json/agents.py +++ b/parlai/tasks/tod_json/agents.py @@ -73,10 +73,7 @@ def add_cmdline_args( help="Use all data or split into 8:1:1 fold", ) agent.add_argument( - "--split-folds-seed", - type=int, - default=42, - help="Seed for the fold split", + "--split-folds-seed", type=int, default=42, help="Seed for the fold 
split" ) return parser diff --git a/tests/tod/test_tod_agents_and_teachers.py b/tests/tod/test_tod_agents_and_teachers.py index 5383c72416d..2914cb927b3 100644 --- a/tests/tod/test_tod_agents_and_teachers.py +++ b/tests/tod/test_tod_agents_and_teachers.py @@ -247,10 +247,7 @@ def helper(n_shot): values = self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, - { - "episodes_randomization_seed": 0, - "n_shot": n_shot, - }, + {"episodes_randomization_seed": 0, "n_shot": n_shot}, ) self.assertEqual(len(values), n_shot) @@ -269,10 +266,7 @@ def helper(n_shot, seed): return self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE, - { - "episodes_randomization_seed": seed, - "n_shot": n_shot, - }, + {"episodes_randomization_seed": seed, "n_shot": n_shot}, ) data_dumps_seed_zero = [helper(i, 0) for i in self.FEW_SHOT_SAMPLES] @@ -286,10 +280,7 @@ def helper(percent_shot, correct): values = self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, # 35 episodes - { - "episodes_randomization_seed": 0, - "percent_shot": percent_shot, - }, + {"episodes_randomization_seed": 0, "percent_shot": percent_shot}, ) self.assertEqual(len(values), correct) @@ -302,10 +293,7 @@ def helper(percent_shot, seed): return self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, # 35 episodes - { - "episodes_randomization_seed": seed, - "percent_shot": percent_shot, - }, + {"episodes_randomization_seed": seed, "percent_shot": percent_shot}, ) data_dumps_seed_zero = [helper(i, 0) for i in self.PERCENTAGES] diff --git a/tests/tod/test_tod_teacher_metrics.py b/tests/tod/test_tod_teacher_metrics.py index aec4aba40f6..419210fa18a 100644 --- a/tests/tod/test_tod_teacher_metrics.py +++ b/tests/tod/test_tod_teacher_metrics.py @@ -62,10 +62,7 @@ def test_base_slot_metrics(self): ), ] for teacher, predicted, result in cases: - metric = SlotMetrics( - teacher_slots=teacher, - predicted_slots=predicted, - ) + metric = SlotMetrics(teacher_slots=teacher, predicted_slots=predicted) for key in result: self.assertEqual(result[key], metric.report()[key]) diff --git a/tests/tod/test_tod_world_metrics.py b/tests/tod/test_tod_world_metrics.py index 0367f655994..1399050b3ca 100644 --- a/tests/tod/test_tod_world_metrics.py +++ b/tests/tod/test_tod_world_metrics.py @@ -19,12 +19,8 @@ TodAgentType, TOD_AGENT_TYPE_TO_PREFIX, ) -from parlai.core.tod.world_metrics import ( - TodMetrics, -) -from parlai.core.tod.world_metrics_handlers import ( - METRICS_HANDLER_CLASSES_TEST_REGISTRY, -) +from parlai.core.tod.world_metrics import TodMetrics +from parlai.core.tod.world_metrics_handlers import METRICS_HANDLER_CLASSES_TEST_REGISTRY # Ignore lint on following line; want to have registered classes show up for tests import projects.tod_simulator.world_metrics.extended_world_metrics # noqa: F401 diff --git a/tests/tod/test_tod_world_metrics_in_script.py b/tests/tod/test_tod_world_metrics_in_script.py index ffac2a25e12..459458af9d2 100644 --- a/tests/tod/test_tod_world_metrics_in_script.py +++ b/tests/tod/test_tod_world_metrics_in_script.py @@ -15,9 +15,7 @@ from parlai.core.opt import Opt from parlai.core.tod.tod_core import SerializationHelpers import parlai.core.tod.tod_test_utils.test_agents as test_agents -from parlai.core.tod.world_metrics_handlers import ( - METRICS_HANDLER_CLASSES_TEST_REGISTRY, -) +from parlai.core.tod.world_metrics_handlers import METRICS_HANDLER_CLASSES_TEST_REGISTRY import 
parlai.scripts.tod_world_script as tod_world_script # Ignore lint on following line; want to have registered classes show up for tests @@ -233,9 +231,10 @@ def get_episode_report(goal, episode_metric): metrics_dict["goal"] = goal return metrics_dict - return dict_report(script.world.report()), [ - get_episode_report(g, e) for g, e in script.episode_metrics - ] + return ( + dict_report(script.world.report()), + [get_episode_report(g, e) for g, e in script.episode_metrics], + ) def test_apiCallAttempts_usingGold(self): opt = copy.deepcopy(TEST_SETUP) diff --git a/tests/tod/test_tod_world_script_metrics.py b/tests/tod/test_tod_world_script_metrics.py index 9862b3c824f..44f060acb84 100644 --- a/tests/tod/test_tod_world_script_metrics.py +++ b/tests/tod/test_tod_world_script_metrics.py @@ -14,9 +14,7 @@ import parlai.core.tod.tod_test_utils.test_agents as test_agents import parlai.scripts.tod_world_script as tod_world_script from parlai.core.tod.tod_agents import StandaloneApiAgent -from parlai.core.tod.world_metrics_handlers import ( - METRICS_HANDLER_CLASSES_TEST_REGISTRY, -) +from parlai.core.tod.world_metrics_handlers import METRICS_HANDLER_CLASSES_TEST_REGISTRY from parlai.core.metrics import dict_report # Ignore lint on following line; want to have registered classes show up for tests From 724b255265bf0aa5e3a7b80af2978b4c83814ad3 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Wed, 17 Nov 2021 08:53:23 -0800 Subject: [PATCH 28/57] add to task list --- parlai/tasks/task_list.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/parlai/tasks/task_list.py b/parlai/tasks/task_list.py index 97b38f15091..bceaeb9aeba 100644 --- a/parlai/tasks/task_list.py +++ b/parlai/tasks/task_list.py @@ -876,6 +876,18 @@ ), "links": {"website": "https://ai.google/tools/datasets/taskmaster-1"}, }, + { + "id": "MSR-E2E", + "display_name": "MSR End-to-End", + "task": "msr_e2e", + "tags": ["ChitChat"], + "description": ( + "MSR-E2E is a dataset of human-human conversations in which one " + "human plays the role of an Agent and the other one plays the role" + "of a User. Data is collected from Amazon Mechanical Turk. 
" + ), + "links": {"website": "https://github.com/xiul-msr/e2e_dialog_challenge"}, + }, { "id": "Twitter", "display_name": "Twitter", From c79422c49cebf686eb38768cf9938acd322d0e3b Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Wed, 17 Nov 2021 08:54:53 -0800 Subject: [PATCH 29/57] lint right version --- parlai/core/tod/tod_agents.py | 11 ++-------- parlai/core/tod/tod_test_utils/test_agents.py | 9 +-------- parlai/core/tod/world_metrics.py | 7 ++----- parlai/core/tod/world_metrics_handlers.py | 9 ++------- .../scripts/distributed_tod_world_script.py | 4 +--- parlai/scripts/tod_world_script.py | 8 ++------ parlai/tasks/metalwoz/agents.py | 11 +++++----- parlai/tasks/msr_e2e/agents.py | 6 +----- parlai/tasks/tod_json/agents.py | 5 +---- tests/tod/test_tod_agents_and_teachers.py | 20 ++++--------------- tests/tod/test_tod_teacher_metrics.py | 5 +---- tests/tod/test_tod_world_metrics.py | 8 ++------ tests/tod/test_tod_world_metrics_in_script.py | 11 +++++----- tests/tod/test_tod_world_script_metrics.py | 4 +--- 14 files changed, 30 insertions(+), 88 deletions(-) diff --git a/parlai/core/tod/tod_agents.py b/parlai/core/tod/tod_agents.py index f22a8330760..ee06dd6766a 100644 --- a/parlai/core/tod/tod_agents.py +++ b/parlai/core/tod/tod_agents.py @@ -542,11 +542,7 @@ def _do_fetch(self, call_text): resp = self.data[call_text] else: resp = self.data.get(call_text, tod.STANDARD_RESP) - return { - "text": resp, - "id": self.id, - "episode_done": False, - } + return {"text": resp, "id": self.id, "episode_done": False} # Not exact case best_key = difflib.get_close_matches(call_text, self.data.keys(), 1) @@ -674,10 +670,7 @@ def custom_evaluation( [teacher_action["domain"]] if self.opt["domain_nlg_record"] else [] ) metrics = NlgMetrics( - guess=resp, - labels=labels, - prefixes=domains, - avg_jga_nlg_bleu=True, + guess=resp, labels=labels, prefixes=domains, avg_jga_nlg_bleu=True ).report() for key, value in metrics.items(): self.metrics.add(key, value) diff --git a/parlai/core/tod/tod_test_utils/test_agents.py b/parlai/core/tod/tod_test_utils/test_agents.py index b7639efea36..8d749fa40c7 100644 --- a/parlai/core/tod/tod_test_utils/test_agents.py +++ b/parlai/core/tod/tod_test_utils/test_agents.py @@ -91,14 +91,7 @@ def get_round_utts(episode_idx, max_rounds, filter_utts=None): f"SYSTEM: sys_utt_{episode_idx}_{i}", ] ) - utts.append( - [ - "USER: [DONE]", - "APICALL: ", - "APIRESP: ", - "SYSTEM: ", - ] - ) + utts.append(["USER: [DONE]", "APICALL: ", "APIRESP: ", "SYSTEM: "]) if filter_utts is not None: utts = [ [turn for i, turn in enumerate(round_data) if filter_utts[i]] diff --git a/parlai/core/tod/world_metrics.py b/parlai/core/tod/world_metrics.py index 4e8ba4555e4..17d96b8698f 100644 --- a/parlai/core/tod/world_metrics.py +++ b/parlai/core/tod/world_metrics.py @@ -10,9 +10,7 @@ """ from parlai.core.message import Message -from parlai.core.metrics import ( - Metrics, -) +from parlai.core.metrics import Metrics from parlai.core.tod.tod_core import ( TodAgentType, TOD_AGENT_TYPE_TO_PREFIX, @@ -73,8 +71,7 @@ def _handle_message_impl( ) if agent_type is TodAgentType.API_SCHEMA_GROUNDING_AGENT: return handler.handle_api_schemas( - message, - SerializationHelpers.str_to_api_schemas(prefix_stripped_text), + message, SerializationHelpers.str_to_api_schemas(prefix_stripped_text) ) if agent_type is TodAgentType.GOAL_GROUNDING_AGENT: return handler.handle_goals( diff --git a/parlai/core/tod/world_metrics_handlers.py b/parlai/core/tod/world_metrics_handlers.py index 3f4b68477fe..ffa7343f114 100644 --- 
a/parlai/core/tod/world_metrics_handlers.py +++ b/parlai/core/tod/world_metrics_handlers.py @@ -10,13 +10,8 @@ """ from parlai.core.message import Message -from parlai.core.metrics import ( - Metric, - AverageMetric, -) -from parlai.core.tod.tod_core import ( - STANDARD_DONE, -) +from parlai.core.metrics import Metric, AverageMetric +from parlai.core.tod.tod_core import STANDARD_DONE from typing import Dict, List, Optional METRICS_HANDLER_CLASSES_TEST_REGISTRY = set() # for tests diff --git a/parlai/scripts/distributed_tod_world_script.py b/parlai/scripts/distributed_tod_world_script.py index 87c8333ec41..8dc36a6feb7 100644 --- a/parlai/scripts/distributed_tod_world_script.py +++ b/parlai/scripts/distributed_tod_world_script.py @@ -9,9 +9,7 @@ Not to be called directly; should be called from SLURM """ -from parlai.scripts.tod_world_script import ( - TodWorldScript, -) +from parlai.scripts.tod_world_script import TodWorldScript from parlai.core.script import ParlaiScript import parlai.utils.distributed as distributed_utils diff --git a/parlai/scripts/tod_world_script.py b/parlai/scripts/tod_world_script.py index 42199bc4ff3..c360882e61c 100644 --- a/parlai/scripts/tod_world_script.py +++ b/parlai/scripts/tod_world_script.py @@ -168,9 +168,7 @@ def setup_tod_args(cls, parser: ParlaiParser): def setup_args(cls): # Use default parlai args for logging + the like, but don't need model args since we specify those manually via command-line parser = TodWorldParser( - True, - False, - "World for chatting with the TOD conversation structure", + True, False, "World for chatting with the TOD conversation structure" ) # Following params are same as the `eval_model` script parser.add_argument( @@ -261,9 +259,7 @@ def _get_tod_agents(self, opt: Opt): ).replace("Goal", "ApiSchema") agents[tod_world.API_SCHEMA_GROUNDING_IDX] = self._get_model_or_default_agent( - opt, - "api_schema_grounding_model", - tod_world_agents.EmptyApiSchemaAgent, + opt, "api_schema_grounding_model", tod_world_agents.EmptyApiSchemaAgent ) return agents diff --git a/parlai/tasks/metalwoz/agents.py b/parlai/tasks/metalwoz/agents.py index f3f9d2c8238..c4ddab9cdbb 100644 --- a/parlai/tasks/metalwoz/agents.py +++ b/parlai/tasks/metalwoz/agents.py @@ -22,9 +22,7 @@ def add_cmdline_args( ) -> ParlaiParser: super().add_cmdline_args(parser, partial_opt) parser.add_argument( - "--metalwoz-domains", - nargs="+", - help="Use only a subset of the domains", + "--metalwoz-domains", nargs="+", help="Use only a subset of the domains" ) return parser @@ -104,9 +102,10 @@ def setup_data(self, datapath): data = self.load_data(datapath) for row in data: texts = list(row["turns"]) - prompts, labels = [f"{row['user_role']}\n{texts[0]}"] + texts[2::2], texts[ - 1::2 - ] + prompts, labels = ( + [f"{row['user_role']}\n{texts[0]}"] + texts[2::2], + texts[1::2], + ) for i, (prompt, label) in enumerate(zip(prompts, labels)): yield { "text": prompt, diff --git a/parlai/tasks/msr_e2e/agents.py b/parlai/tasks/msr_e2e/agents.py index aa7b13eb04d..293cdc5cb88 100644 --- a/parlai/tasks/msr_e2e/agents.py +++ b/parlai/tasks/msr_e2e/agents.py @@ -28,11 +28,7 @@ import parlai.core.tod.tod_agents as tod_agents -DOMAINS = [ - "movie", - "restaurant", - "taxi", -] +DOMAINS = ["movie", "restaurant", "taxi"] # Just going to copy/paste these since it's faster than parsing 3 separate files # They are in `system/src/deep_dialog/data_/_slots.txt` in the original data diff --git a/parlai/tasks/tod_json/agents.py b/parlai/tasks/tod_json/agents.py index 55e3514de6c..3e994befacf 
100644 --- a/parlai/tasks/tod_json/agents.py +++ b/parlai/tasks/tod_json/agents.py @@ -73,10 +73,7 @@ def add_cmdline_args( help="Use all data or split into 8:1:1 fold", ) agent.add_argument( - "--split-folds-seed", - type=int, - default=42, - help="Seed for the fold split", + "--split-folds-seed", type=int, default=42, help="Seed for the fold split" ) return parser diff --git a/tests/tod/test_tod_agents_and_teachers.py b/tests/tod/test_tod_agents_and_teachers.py index 5383c72416d..2914cb927b3 100644 --- a/tests/tod/test_tod_agents_and_teachers.py +++ b/tests/tod/test_tod_agents_and_teachers.py @@ -247,10 +247,7 @@ def helper(n_shot): values = self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, - { - "episodes_randomization_seed": 0, - "n_shot": n_shot, - }, + {"episodes_randomization_seed": 0, "n_shot": n_shot}, ) self.assertEqual(len(values), n_shot) @@ -269,10 +266,7 @@ def helper(n_shot, seed): return self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE, - { - "episodes_randomization_seed": seed, - "n_shot": n_shot, - }, + {"episodes_randomization_seed": seed, "n_shot": n_shot}, ) data_dumps_seed_zero = [helper(i, 0) for i in self.FEW_SHOT_SAMPLES] @@ -286,10 +280,7 @@ def helper(percent_shot, correct): values = self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, # 35 episodes - { - "episodes_randomization_seed": 0, - "percent_shot": percent_shot, - }, + {"episodes_randomization_seed": 0, "percent_shot": percent_shot}, ) self.assertEqual(len(values), correct) @@ -302,10 +293,7 @@ def helper(percent_shot, seed): return self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, # 35 episodes - { - "episodes_randomization_seed": seed, - "percent_shot": percent_shot, - }, + {"episodes_randomization_seed": seed, "percent_shot": percent_shot}, ) data_dumps_seed_zero = [helper(i, 0) for i in self.PERCENTAGES] diff --git a/tests/tod/test_tod_teacher_metrics.py b/tests/tod/test_tod_teacher_metrics.py index aec4aba40f6..419210fa18a 100644 --- a/tests/tod/test_tod_teacher_metrics.py +++ b/tests/tod/test_tod_teacher_metrics.py @@ -62,10 +62,7 @@ def test_base_slot_metrics(self): ), ] for teacher, predicted, result in cases: - metric = SlotMetrics( - teacher_slots=teacher, - predicted_slots=predicted, - ) + metric = SlotMetrics(teacher_slots=teacher, predicted_slots=predicted) for key in result: self.assertEqual(result[key], metric.report()[key]) diff --git a/tests/tod/test_tod_world_metrics.py b/tests/tod/test_tod_world_metrics.py index 0367f655994..1399050b3ca 100644 --- a/tests/tod/test_tod_world_metrics.py +++ b/tests/tod/test_tod_world_metrics.py @@ -19,12 +19,8 @@ TodAgentType, TOD_AGENT_TYPE_TO_PREFIX, ) -from parlai.core.tod.world_metrics import ( - TodMetrics, -) -from parlai.core.tod.world_metrics_handlers import ( - METRICS_HANDLER_CLASSES_TEST_REGISTRY, -) +from parlai.core.tod.world_metrics import TodMetrics +from parlai.core.tod.world_metrics_handlers import METRICS_HANDLER_CLASSES_TEST_REGISTRY # Ignore lint on following line; want to have registered classes show up for tests import projects.tod_simulator.world_metrics.extended_world_metrics # noqa: F401 diff --git a/tests/tod/test_tod_world_metrics_in_script.py b/tests/tod/test_tod_world_metrics_in_script.py index ffac2a25e12..459458af9d2 100644 --- a/tests/tod/test_tod_world_metrics_in_script.py +++ b/tests/tod/test_tod_world_metrics_in_script.py @@ -15,9 +15,7 @@ from 
parlai.core.opt import Opt from parlai.core.tod.tod_core import SerializationHelpers import parlai.core.tod.tod_test_utils.test_agents as test_agents -from parlai.core.tod.world_metrics_handlers import ( - METRICS_HANDLER_CLASSES_TEST_REGISTRY, -) +from parlai.core.tod.world_metrics_handlers import METRICS_HANDLER_CLASSES_TEST_REGISTRY import parlai.scripts.tod_world_script as tod_world_script # Ignore lint on following line; want to have registered classes show up for tests @@ -233,9 +231,10 @@ def get_episode_report(goal, episode_metric): metrics_dict["goal"] = goal return metrics_dict - return dict_report(script.world.report()), [ - get_episode_report(g, e) for g, e in script.episode_metrics - ] + return ( + dict_report(script.world.report()), + [get_episode_report(g, e) for g, e in script.episode_metrics], + ) def test_apiCallAttempts_usingGold(self): opt = copy.deepcopy(TEST_SETUP) diff --git a/tests/tod/test_tod_world_script_metrics.py b/tests/tod/test_tod_world_script_metrics.py index 9862b3c824f..44f060acb84 100644 --- a/tests/tod/test_tod_world_script_metrics.py +++ b/tests/tod/test_tod_world_script_metrics.py @@ -14,9 +14,7 @@ import parlai.core.tod.tod_test_utils.test_agents as test_agents import parlai.scripts.tod_world_script as tod_world_script from parlai.core.tod.tod_agents import StandaloneApiAgent -from parlai.core.tod.world_metrics_handlers import ( - METRICS_HANDLER_CLASSES_TEST_REGISTRY, -) +from parlai.core.tod.world_metrics_handlers import METRICS_HANDLER_CLASSES_TEST_REGISTRY from parlai.core.metrics import dict_report # Ignore lint on following line; want to have registered classes show up for tests From c86760c0adaf48bf76cf6f63d9badb38058f4e50 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Wed, 17 Nov 2021 08:57:47 -0800 Subject: [PATCH 30/57] add to task list --- parlai/tasks/task_list.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/parlai/tasks/task_list.py b/parlai/tasks/task_list.py index bceaeb9aeba..66d11f4b669 100644 --- a/parlai/tasks/task_list.py +++ b/parlai/tasks/task_list.py @@ -1175,6 +1175,20 @@ ), "links": {"arXiv": "https://arxiv.org/abs/1908.06083"}, }, + { + "id": "MultiDoGo", + "display_name": "MultiDoGo", + "task": "multidogo", + "tags": ["TOD"], + "description": ( + "MultiDoGo is a large task-oriented dataset from Amazon collected " + "in a Wizard of Oz fashion, using both crowd and expert annotators " + "with annotations at varying levels of granularity." 
+ ), + "links": { + "website": "https://github.com/awslabs/multi-domain-goal-oriented-dialogues-dataset" + }, + }, { "id": "MultiWOZv2.0", "display_name": "MultiWOZ 2.0", From 625e632faf14d174ec7c389b252376d77f288968 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Wed, 17 Nov 2021 08:58:43 -0800 Subject: [PATCH 31/57] right lint --- parlai/core/tod/tod_agents.py | 11 ++----- parlai/core/tod/tod_test_utils/test_agents.py | 9 +----- parlai/core/tod/world_metrics.py | 7 ++--- parlai/core/tod/world_metrics_handlers.py | 9 ++---- .../scripts/distributed_tod_world_script.py | 4 +-- parlai/scripts/tod_world_script.py | 8 ++--- parlai/tasks/metalwoz/agents.py | 11 ++++--- parlai/tasks/msr_e2e/agents.py | 6 +--- parlai/tasks/multiwoz_v22/agents.py | 30 ++++++++----------- parlai/tasks/tod_json/agents.py | 5 +--- tests/tod/test_tod_agents_and_teachers.py | 20 +++---------- tests/tod/test_tod_teacher_metrics.py | 5 +--- tests/tod/test_tod_world_metrics.py | 8 ++--- tests/tod/test_tod_world_metrics_in_script.py | 11 ++++--- tests/tod/test_tod_world_script_metrics.py | 4 +-- 15 files changed, 43 insertions(+), 105 deletions(-) diff --git a/parlai/core/tod/tod_agents.py b/parlai/core/tod/tod_agents.py index f22a8330760..ee06dd6766a 100644 --- a/parlai/core/tod/tod_agents.py +++ b/parlai/core/tod/tod_agents.py @@ -542,11 +542,7 @@ def _do_fetch(self, call_text): resp = self.data[call_text] else: resp = self.data.get(call_text, tod.STANDARD_RESP) - return { - "text": resp, - "id": self.id, - "episode_done": False, - } + return {"text": resp, "id": self.id, "episode_done": False} # Not exact case best_key = difflib.get_close_matches(call_text, self.data.keys(), 1) @@ -674,10 +670,7 @@ def custom_evaluation( [teacher_action["domain"]] if self.opt["domain_nlg_record"] else [] ) metrics = NlgMetrics( - guess=resp, - labels=labels, - prefixes=domains, - avg_jga_nlg_bleu=True, + guess=resp, labels=labels, prefixes=domains, avg_jga_nlg_bleu=True ).report() for key, value in metrics.items(): self.metrics.add(key, value) diff --git a/parlai/core/tod/tod_test_utils/test_agents.py b/parlai/core/tod/tod_test_utils/test_agents.py index b7639efea36..8d749fa40c7 100644 --- a/parlai/core/tod/tod_test_utils/test_agents.py +++ b/parlai/core/tod/tod_test_utils/test_agents.py @@ -91,14 +91,7 @@ def get_round_utts(episode_idx, max_rounds, filter_utts=None): f"SYSTEM: sys_utt_{episode_idx}_{i}", ] ) - utts.append( - [ - "USER: [DONE]", - "APICALL: ", - "APIRESP: ", - "SYSTEM: ", - ] - ) + utts.append(["USER: [DONE]", "APICALL: ", "APIRESP: ", "SYSTEM: "]) if filter_utts is not None: utts = [ [turn for i, turn in enumerate(round_data) if filter_utts[i]] diff --git a/parlai/core/tod/world_metrics.py b/parlai/core/tod/world_metrics.py index 4e8ba4555e4..17d96b8698f 100644 --- a/parlai/core/tod/world_metrics.py +++ b/parlai/core/tod/world_metrics.py @@ -10,9 +10,7 @@ """ from parlai.core.message import Message -from parlai.core.metrics import ( - Metrics, -) +from parlai.core.metrics import Metrics from parlai.core.tod.tod_core import ( TodAgentType, TOD_AGENT_TYPE_TO_PREFIX, @@ -73,8 +71,7 @@ def _handle_message_impl( ) if agent_type is TodAgentType.API_SCHEMA_GROUNDING_AGENT: return handler.handle_api_schemas( - message, - SerializationHelpers.str_to_api_schemas(prefix_stripped_text), + message, SerializationHelpers.str_to_api_schemas(prefix_stripped_text) ) if agent_type is TodAgentType.GOAL_GROUNDING_AGENT: return handler.handle_goals( diff --git a/parlai/core/tod/world_metrics_handlers.py 
b/parlai/core/tod/world_metrics_handlers.py index 3f4b68477fe..ffa7343f114 100644 --- a/parlai/core/tod/world_metrics_handlers.py +++ b/parlai/core/tod/world_metrics_handlers.py @@ -10,13 +10,8 @@ """ from parlai.core.message import Message -from parlai.core.metrics import ( - Metric, - AverageMetric, -) -from parlai.core.tod.tod_core import ( - STANDARD_DONE, -) +from parlai.core.metrics import Metric, AverageMetric +from parlai.core.tod.tod_core import STANDARD_DONE from typing import Dict, List, Optional METRICS_HANDLER_CLASSES_TEST_REGISTRY = set() # for tests diff --git a/parlai/scripts/distributed_tod_world_script.py b/parlai/scripts/distributed_tod_world_script.py index 87c8333ec41..8dc36a6feb7 100644 --- a/parlai/scripts/distributed_tod_world_script.py +++ b/parlai/scripts/distributed_tod_world_script.py @@ -9,9 +9,7 @@ Not to be called directly; should be called from SLURM """ -from parlai.scripts.tod_world_script import ( - TodWorldScript, -) +from parlai.scripts.tod_world_script import TodWorldScript from parlai.core.script import ParlaiScript import parlai.utils.distributed as distributed_utils diff --git a/parlai/scripts/tod_world_script.py b/parlai/scripts/tod_world_script.py index 42199bc4ff3..c360882e61c 100644 --- a/parlai/scripts/tod_world_script.py +++ b/parlai/scripts/tod_world_script.py @@ -168,9 +168,7 @@ def setup_tod_args(cls, parser: ParlaiParser): def setup_args(cls): # Use default parlai args for logging + the like, but don't need model args since we specify those manually via command-line parser = TodWorldParser( - True, - False, - "World for chatting with the TOD conversation structure", + True, False, "World for chatting with the TOD conversation structure" ) # Following params are same as the `eval_model` script parser.add_argument( @@ -261,9 +259,7 @@ def _get_tod_agents(self, opt: Opt): ).replace("Goal", "ApiSchema") agents[tod_world.API_SCHEMA_GROUNDING_IDX] = self._get_model_or_default_agent( - opt, - "api_schema_grounding_model", - tod_world_agents.EmptyApiSchemaAgent, + opt, "api_schema_grounding_model", tod_world_agents.EmptyApiSchemaAgent ) return agents diff --git a/parlai/tasks/metalwoz/agents.py b/parlai/tasks/metalwoz/agents.py index f3f9d2c8238..c4ddab9cdbb 100644 --- a/parlai/tasks/metalwoz/agents.py +++ b/parlai/tasks/metalwoz/agents.py @@ -22,9 +22,7 @@ def add_cmdline_args( ) -> ParlaiParser: super().add_cmdline_args(parser, partial_opt) parser.add_argument( - "--metalwoz-domains", - nargs="+", - help="Use only a subset of the domains", + "--metalwoz-domains", nargs="+", help="Use only a subset of the domains" ) return parser @@ -104,9 +102,10 @@ def setup_data(self, datapath): data = self.load_data(datapath) for row in data: texts = list(row["turns"]) - prompts, labels = [f"{row['user_role']}\n{texts[0]}"] + texts[2::2], texts[ - 1::2 - ] + prompts, labels = ( + [f"{row['user_role']}\n{texts[0]}"] + texts[2::2], + texts[1::2], + ) for i, (prompt, label) in enumerate(zip(prompts, labels)): yield { "text": prompt, diff --git a/parlai/tasks/msr_e2e/agents.py b/parlai/tasks/msr_e2e/agents.py index aa7b13eb04d..293cdc5cb88 100644 --- a/parlai/tasks/msr_e2e/agents.py +++ b/parlai/tasks/msr_e2e/agents.py @@ -28,11 +28,7 @@ import parlai.core.tod.tod_agents as tod_agents -DOMAINS = [ - "movie", - "restaurant", - "taxi", -] +DOMAINS = ["movie", "restaurant", "taxi"] # Just going to copy/paste these since it's faster than parsing 3 separate files # They are in `system/src/deep_dialog/data_/_slots.txt` in the original data diff --git 
a/parlai/tasks/multiwoz_v22/agents.py b/parlai/tasks/multiwoz_v22/agents.py index e59942d1f3a..4d6e3be69be 100644 --- a/parlai/tasks/multiwoz_v22/agents.py +++ b/parlai/tasks/multiwoz_v22/agents.py @@ -34,14 +34,7 @@ "train", ] -WELL_FORMATTED_DOMAINS = [ - "attraction", - "bus", - "hotel", - "restaurant", - "train", - "taxi", -] +WELL_FORMATTED_DOMAINS = ["attraction", "bus", "hotel", "restaurant", "train", "taxi"] class MultiwozV22Parser(tod_agents.TodStructuredDataParser): @@ -93,9 +86,7 @@ def load_schemas(self): for intent in service["intents"]: call_name = intent["name"] - result[call_name] = { - tod.STANDARD_API_NAME_SLOT: call_name, - } + result[call_name] = {tod.STANDARD_API_NAME_SLOT: call_name} req_slots = set([x[prefix_end_idx:] for x in intent["required_slots"]]) if len(req_slots) > 0: result[call_name][tod.STANDARD_REQUIRED_KEY] = list(req_slots) @@ -286,11 +277,14 @@ def _get_round(self, dialogue_id, raw_episode, turn_id): resp = {} if len(call) > 0: self.last_call = call - return call, tod.TodStructuredRound( - user_utt=user_turn["utterance"], - api_call_machine=call, - api_resp_machine=resp, - sys_utt=sys_turn["utterance"], + return ( + call, + tod.TodStructuredRound( + user_utt=user_turn["utterance"], + api_call_machine=call, + api_resp_machine=resp, + sys_utt=sys_turn["utterance"], + ), ) def _get_schemas_for_goal_calls(self, goals): @@ -322,7 +316,9 @@ def setup_episodes(self, fold): if raw_episode["dialogue_id"] != self.opt["dialogue_id"]: continue - skip = False # need to skip outer for loop while in `for domains` inner for loop + skip = ( + False + ) # need to skip outer for loop while in `for domains` inner for loop if self.opt.get("well_formatted_domains_only", True): if len(domains) == 0: skip = True diff --git a/parlai/tasks/tod_json/agents.py b/parlai/tasks/tod_json/agents.py index 55e3514de6c..3e994befacf 100644 --- a/parlai/tasks/tod_json/agents.py +++ b/parlai/tasks/tod_json/agents.py @@ -73,10 +73,7 @@ def add_cmdline_args( help="Use all data or split into 8:1:1 fold", ) agent.add_argument( - "--split-folds-seed", - type=int, - default=42, - help="Seed for the fold split", + "--split-folds-seed", type=int, default=42, help="Seed for the fold split" ) return parser diff --git a/tests/tod/test_tod_agents_and_teachers.py b/tests/tod/test_tod_agents_and_teachers.py index 5383c72416d..2914cb927b3 100644 --- a/tests/tod/test_tod_agents_and_teachers.py +++ b/tests/tod/test_tod_agents_and_teachers.py @@ -247,10 +247,7 @@ def helper(n_shot): values = self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, - { - "episodes_randomization_seed": 0, - "n_shot": n_shot, - }, + {"episodes_randomization_seed": 0, "n_shot": n_shot}, ) self.assertEqual(len(values), n_shot) @@ -269,10 +266,7 @@ def helper(n_shot, seed): return self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE, - { - "episodes_randomization_seed": seed, - "n_shot": n_shot, - }, + {"episodes_randomization_seed": seed, "n_shot": n_shot}, ) data_dumps_seed_zero = [helper(i, 0) for i in self.FEW_SHOT_SAMPLES] @@ -286,10 +280,7 @@ def helper(percent_shot, correct): values = self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, # 35 episodes - { - "episodes_randomization_seed": 0, - "percent_shot": percent_shot, - }, + {"episodes_randomization_seed": 0, "percent_shot": percent_shot}, ) self.assertEqual(len(values), correct) @@ -302,10 +293,7 @@ def helper(percent_shot, seed): return 
self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, # 35 episodes - { - "episodes_randomization_seed": seed, - "percent_shot": percent_shot, - }, + {"episodes_randomization_seed": seed, "percent_shot": percent_shot}, ) data_dumps_seed_zero = [helper(i, 0) for i in self.PERCENTAGES] diff --git a/tests/tod/test_tod_teacher_metrics.py b/tests/tod/test_tod_teacher_metrics.py index aec4aba40f6..419210fa18a 100644 --- a/tests/tod/test_tod_teacher_metrics.py +++ b/tests/tod/test_tod_teacher_metrics.py @@ -62,10 +62,7 @@ def test_base_slot_metrics(self): ), ] for teacher, predicted, result in cases: - metric = SlotMetrics( - teacher_slots=teacher, - predicted_slots=predicted, - ) + metric = SlotMetrics(teacher_slots=teacher, predicted_slots=predicted) for key in result: self.assertEqual(result[key], metric.report()[key]) diff --git a/tests/tod/test_tod_world_metrics.py b/tests/tod/test_tod_world_metrics.py index 0367f655994..1399050b3ca 100644 --- a/tests/tod/test_tod_world_metrics.py +++ b/tests/tod/test_tod_world_metrics.py @@ -19,12 +19,8 @@ TodAgentType, TOD_AGENT_TYPE_TO_PREFIX, ) -from parlai.core.tod.world_metrics import ( - TodMetrics, -) -from parlai.core.tod.world_metrics_handlers import ( - METRICS_HANDLER_CLASSES_TEST_REGISTRY, -) +from parlai.core.tod.world_metrics import TodMetrics +from parlai.core.tod.world_metrics_handlers import METRICS_HANDLER_CLASSES_TEST_REGISTRY # Ignore lint on following line; want to have registered classes show up for tests import projects.tod_simulator.world_metrics.extended_world_metrics # noqa: F401 diff --git a/tests/tod/test_tod_world_metrics_in_script.py b/tests/tod/test_tod_world_metrics_in_script.py index ffac2a25e12..459458af9d2 100644 --- a/tests/tod/test_tod_world_metrics_in_script.py +++ b/tests/tod/test_tod_world_metrics_in_script.py @@ -15,9 +15,7 @@ from parlai.core.opt import Opt from parlai.core.tod.tod_core import SerializationHelpers import parlai.core.tod.tod_test_utils.test_agents as test_agents -from parlai.core.tod.world_metrics_handlers import ( - METRICS_HANDLER_CLASSES_TEST_REGISTRY, -) +from parlai.core.tod.world_metrics_handlers import METRICS_HANDLER_CLASSES_TEST_REGISTRY import parlai.scripts.tod_world_script as tod_world_script # Ignore lint on following line; want to have registered classes show up for tests @@ -233,9 +231,10 @@ def get_episode_report(goal, episode_metric): metrics_dict["goal"] = goal return metrics_dict - return dict_report(script.world.report()), [ - get_episode_report(g, e) for g, e in script.episode_metrics - ] + return ( + dict_report(script.world.report()), + [get_episode_report(g, e) for g, e in script.episode_metrics], + ) def test_apiCallAttempts_usingGold(self): opt = copy.deepcopy(TEST_SETUP) diff --git a/tests/tod/test_tod_world_script_metrics.py b/tests/tod/test_tod_world_script_metrics.py index 9862b3c824f..44f060acb84 100644 --- a/tests/tod/test_tod_world_script_metrics.py +++ b/tests/tod/test_tod_world_script_metrics.py @@ -14,9 +14,7 @@ import parlai.core.tod.tod_test_utils.test_agents as test_agents import parlai.scripts.tod_world_script as tod_world_script from parlai.core.tod.tod_agents import StandaloneApiAgent -from parlai.core.tod.world_metrics_handlers import ( - METRICS_HANDLER_CLASSES_TEST_REGISTRY, -) +from parlai.core.tod.world_metrics_handlers import METRICS_HANDLER_CLASSES_TEST_REGISTRY from parlai.core.metrics import dict_report # Ignore lint on following line; want to have registered classes show up for tests From 
dcadadbc2e3277bbaaa5af81e617e38e640d0b02 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Wed, 17 Nov 2021 09:00:12 -0800 Subject: [PATCH 32/57] task_list --- parlai/tasks/task_list.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/parlai/tasks/task_list.py b/parlai/tasks/task_list.py index 66d11f4b669..1a55c67059e 100644 --- a/parlai/tasks/task_list.py +++ b/parlai/tasks/task_list.py @@ -1211,6 +1211,19 @@ ), "links": {"website": "http://dialogue.mi.eng.cam.ac.uk/index.php/corpus/"}, }, + { + "id": "MultiWOZv2.2", + "display_name": "MultiWOZ 2.2", + "task": "multiwoz_v22", + "tags": ["Goal"], + "description": ( + "A fully labeled collection of human-written conversations spanning " + "over multiple domains and topics. Schemas are included." + ), + "links": { + "website": "https://github.com/budzianowski/multiwoz/tree/master/data/MultiWOZ_2.2" + }, + }, { "id": "SelfChat", "display_name": "SelfChat", From fff653f34540d3ef08868921a485a23cadf8afd2 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Wed, 17 Nov 2021 09:00:40 -0800 Subject: [PATCH 33/57] right lint --- parlai/core/tod/tod_agents.py | 11 ++----- parlai/core/tod/tod_test_utils/test_agents.py | 9 +----- parlai/core/tod/world_metrics.py | 7 ++--- parlai/core/tod/world_metrics_handlers.py | 9 ++---- .../scripts/distributed_tod_world_script.py | 4 +-- parlai/scripts/tod_world_script.py | 8 ++--- parlai/tasks/metalwoz/agents.py | 11 ++++--- parlai/tasks/msr_e2e/agents.py | 6 +--- parlai/tasks/multiwoz_v22/agents.py | 30 ++++++++----------- parlai/tasks/taskmaster/agents.py | 10 +++---- parlai/tasks/tod_json/agents.py | 5 +--- tests/tod/test_tod_agents_and_teachers.py | 20 +++---------- tests/tod/test_tod_teacher_metrics.py | 5 +--- tests/tod/test_tod_world_metrics.py | 8 ++--- tests/tod/test_tod_world_metrics_in_script.py | 11 ++++--- tests/tod/test_tod_world_script_metrics.py | 4 +-- 16 files changed, 47 insertions(+), 111 deletions(-) diff --git a/parlai/core/tod/tod_agents.py b/parlai/core/tod/tod_agents.py index f22a8330760..ee06dd6766a 100644 --- a/parlai/core/tod/tod_agents.py +++ b/parlai/core/tod/tod_agents.py @@ -542,11 +542,7 @@ def _do_fetch(self, call_text): resp = self.data[call_text] else: resp = self.data.get(call_text, tod.STANDARD_RESP) - return { - "text": resp, - "id": self.id, - "episode_done": False, - } + return {"text": resp, "id": self.id, "episode_done": False} # Not exact case best_key = difflib.get_close_matches(call_text, self.data.keys(), 1) @@ -674,10 +670,7 @@ def custom_evaluation( [teacher_action["domain"]] if self.opt["domain_nlg_record"] else [] ) metrics = NlgMetrics( - guess=resp, - labels=labels, - prefixes=domains, - avg_jga_nlg_bleu=True, + guess=resp, labels=labels, prefixes=domains, avg_jga_nlg_bleu=True ).report() for key, value in metrics.items(): self.metrics.add(key, value) diff --git a/parlai/core/tod/tod_test_utils/test_agents.py b/parlai/core/tod/tod_test_utils/test_agents.py index b7639efea36..8d749fa40c7 100644 --- a/parlai/core/tod/tod_test_utils/test_agents.py +++ b/parlai/core/tod/tod_test_utils/test_agents.py @@ -91,14 +91,7 @@ def get_round_utts(episode_idx, max_rounds, filter_utts=None): f"SYSTEM: sys_utt_{episode_idx}_{i}", ] ) - utts.append( - [ - "USER: [DONE]", - "APICALL: ", - "APIRESP: ", - "SYSTEM: ", - ] - ) + utts.append(["USER: [DONE]", "APICALL: ", "APIRESP: ", "SYSTEM: "]) if filter_utts is not None: utts = [ [turn for i, turn in enumerate(round_data) if filter_utts[i]] diff --git a/parlai/core/tod/world_metrics.py b/parlai/core/tod/world_metrics.py index 
4e8ba4555e4..17d96b8698f 100644 --- a/parlai/core/tod/world_metrics.py +++ b/parlai/core/tod/world_metrics.py @@ -10,9 +10,7 @@ """ from parlai.core.message import Message -from parlai.core.metrics import ( - Metrics, -) +from parlai.core.metrics import Metrics from parlai.core.tod.tod_core import ( TodAgentType, TOD_AGENT_TYPE_TO_PREFIX, @@ -73,8 +71,7 @@ def _handle_message_impl( ) if agent_type is TodAgentType.API_SCHEMA_GROUNDING_AGENT: return handler.handle_api_schemas( - message, - SerializationHelpers.str_to_api_schemas(prefix_stripped_text), + message, SerializationHelpers.str_to_api_schemas(prefix_stripped_text) ) if agent_type is TodAgentType.GOAL_GROUNDING_AGENT: return handler.handle_goals( diff --git a/parlai/core/tod/world_metrics_handlers.py b/parlai/core/tod/world_metrics_handlers.py index 3f4b68477fe..ffa7343f114 100644 --- a/parlai/core/tod/world_metrics_handlers.py +++ b/parlai/core/tod/world_metrics_handlers.py @@ -10,13 +10,8 @@ """ from parlai.core.message import Message -from parlai.core.metrics import ( - Metric, - AverageMetric, -) -from parlai.core.tod.tod_core import ( - STANDARD_DONE, -) +from parlai.core.metrics import Metric, AverageMetric +from parlai.core.tod.tod_core import STANDARD_DONE from typing import Dict, List, Optional METRICS_HANDLER_CLASSES_TEST_REGISTRY = set() # for tests diff --git a/parlai/scripts/distributed_tod_world_script.py b/parlai/scripts/distributed_tod_world_script.py index 87c8333ec41..8dc36a6feb7 100644 --- a/parlai/scripts/distributed_tod_world_script.py +++ b/parlai/scripts/distributed_tod_world_script.py @@ -9,9 +9,7 @@ Not to be called directly; should be called from SLURM """ -from parlai.scripts.tod_world_script import ( - TodWorldScript, -) +from parlai.scripts.tod_world_script import TodWorldScript from parlai.core.script import ParlaiScript import parlai.utils.distributed as distributed_utils diff --git a/parlai/scripts/tod_world_script.py b/parlai/scripts/tod_world_script.py index 42199bc4ff3..c360882e61c 100644 --- a/parlai/scripts/tod_world_script.py +++ b/parlai/scripts/tod_world_script.py @@ -168,9 +168,7 @@ def setup_tod_args(cls, parser: ParlaiParser): def setup_args(cls): # Use default parlai args for logging + the like, but don't need model args since we specify those manually via command-line parser = TodWorldParser( - True, - False, - "World for chatting with the TOD conversation structure", + True, False, "World for chatting with the TOD conversation structure" ) # Following params are same as the `eval_model` script parser.add_argument( @@ -261,9 +259,7 @@ def _get_tod_agents(self, opt: Opt): ).replace("Goal", "ApiSchema") agents[tod_world.API_SCHEMA_GROUNDING_IDX] = self._get_model_or_default_agent( - opt, - "api_schema_grounding_model", - tod_world_agents.EmptyApiSchemaAgent, + opt, "api_schema_grounding_model", tod_world_agents.EmptyApiSchemaAgent ) return agents diff --git a/parlai/tasks/metalwoz/agents.py b/parlai/tasks/metalwoz/agents.py index f3f9d2c8238..c4ddab9cdbb 100644 --- a/parlai/tasks/metalwoz/agents.py +++ b/parlai/tasks/metalwoz/agents.py @@ -22,9 +22,7 @@ def add_cmdline_args( ) -> ParlaiParser: super().add_cmdline_args(parser, partial_opt) parser.add_argument( - "--metalwoz-domains", - nargs="+", - help="Use only a subset of the domains", + "--metalwoz-domains", nargs="+", help="Use only a subset of the domains" ) return parser @@ -104,9 +102,10 @@ def setup_data(self, datapath): data = self.load_data(datapath) for row in data: texts = list(row["turns"]) - prompts, labels = 
[f"{row['user_role']}\n{texts[0]}"] + texts[2::2], texts[ - 1::2 - ] + prompts, labels = ( + [f"{row['user_role']}\n{texts[0]}"] + texts[2::2], + texts[1::2], + ) for i, (prompt, label) in enumerate(zip(prompts, labels)): yield { "text": prompt, diff --git a/parlai/tasks/msr_e2e/agents.py b/parlai/tasks/msr_e2e/agents.py index aa7b13eb04d..293cdc5cb88 100644 --- a/parlai/tasks/msr_e2e/agents.py +++ b/parlai/tasks/msr_e2e/agents.py @@ -28,11 +28,7 @@ import parlai.core.tod.tod_agents as tod_agents -DOMAINS = [ - "movie", - "restaurant", - "taxi", -] +DOMAINS = ["movie", "restaurant", "taxi"] # Just going to copy/paste these since it's faster than parsing 3 separate files # They are in `system/src/deep_dialog/data_/_slots.txt` in the original data diff --git a/parlai/tasks/multiwoz_v22/agents.py b/parlai/tasks/multiwoz_v22/agents.py index e59942d1f3a..4d6e3be69be 100644 --- a/parlai/tasks/multiwoz_v22/agents.py +++ b/parlai/tasks/multiwoz_v22/agents.py @@ -34,14 +34,7 @@ "train", ] -WELL_FORMATTED_DOMAINS = [ - "attraction", - "bus", - "hotel", - "restaurant", - "train", - "taxi", -] +WELL_FORMATTED_DOMAINS = ["attraction", "bus", "hotel", "restaurant", "train", "taxi"] class MultiwozV22Parser(tod_agents.TodStructuredDataParser): @@ -93,9 +86,7 @@ def load_schemas(self): for intent in service["intents"]: call_name = intent["name"] - result[call_name] = { - tod.STANDARD_API_NAME_SLOT: call_name, - } + result[call_name] = {tod.STANDARD_API_NAME_SLOT: call_name} req_slots = set([x[prefix_end_idx:] for x in intent["required_slots"]]) if len(req_slots) > 0: result[call_name][tod.STANDARD_REQUIRED_KEY] = list(req_slots) @@ -286,11 +277,14 @@ def _get_round(self, dialogue_id, raw_episode, turn_id): resp = {} if len(call) > 0: self.last_call = call - return call, tod.TodStructuredRound( - user_utt=user_turn["utterance"], - api_call_machine=call, - api_resp_machine=resp, - sys_utt=sys_turn["utterance"], + return ( + call, + tod.TodStructuredRound( + user_utt=user_turn["utterance"], + api_call_machine=call, + api_resp_machine=resp, + sys_utt=sys_turn["utterance"], + ), ) def _get_schemas_for_goal_calls(self, goals): @@ -322,7 +316,9 @@ def setup_episodes(self, fold): if raw_episode["dialogue_id"] != self.opt["dialogue_id"]: continue - skip = False # need to skip outer for loop while in `for domains` inner for loop + skip = ( + False + ) # need to skip outer for loop while in `for domains` inner for loop if self.opt.get("well_formatted_domains_only", True): if len(domains) == 0: skip = True diff --git a/parlai/tasks/taskmaster/agents.py b/parlai/tasks/taskmaster/agents.py index 237a9fd4a10..dfea1277dda 100644 --- a/parlai/tasks/taskmaster/agents.py +++ b/parlai/tasks/taskmaster/agents.py @@ -200,7 +200,7 @@ def setup_episodes(self, fold): rounds = [] goal_calls = [] if len(utterances) > 0 and utterances[0]["speaker"] == "ASSISTANT": - (idx, sys_utt, _,) = self._get_utterance_and_slots_for_speaker( + (idx, sys_utt, _) = self._get_utterance_and_slots_for_speaker( "ASSISTANT", utterances, idx ) @@ -209,11 +209,9 @@ def setup_episodes(self, fold): rounds.append(t) while idx < len(utterances): - ( - idx, - user_utt, - user_slots, - ) = self._get_utterance_and_slots_for_speaker("USER", utterances, idx) + (idx, user_utt, user_slots) = self._get_utterance_and_slots_for_speaker( + "USER", utterances, idx + ) ( idx, sys_utt, diff --git a/parlai/tasks/tod_json/agents.py b/parlai/tasks/tod_json/agents.py index 55e3514de6c..3e994befacf 100644 --- a/parlai/tasks/tod_json/agents.py +++ 
b/parlai/tasks/tod_json/agents.py @@ -73,10 +73,7 @@ def add_cmdline_args( help="Use all data or split into 8:1:1 fold", ) agent.add_argument( - "--split-folds-seed", - type=int, - default=42, - help="Seed for the fold split", + "--split-folds-seed", type=int, default=42, help="Seed for the fold split" ) return parser diff --git a/tests/tod/test_tod_agents_and_teachers.py b/tests/tod/test_tod_agents_and_teachers.py index 5383c72416d..2914cb927b3 100644 --- a/tests/tod/test_tod_agents_and_teachers.py +++ b/tests/tod/test_tod_agents_and_teachers.py @@ -247,10 +247,7 @@ def helper(n_shot): values = self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, - { - "episodes_randomization_seed": 0, - "n_shot": n_shot, - }, + {"episodes_randomization_seed": 0, "n_shot": n_shot}, ) self.assertEqual(len(values), n_shot) @@ -269,10 +266,7 @@ def helper(n_shot, seed): return self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE, - { - "episodes_randomization_seed": seed, - "n_shot": n_shot, - }, + {"episodes_randomization_seed": seed, "n_shot": n_shot}, ) data_dumps_seed_zero = [helper(i, 0) for i in self.FEW_SHOT_SAMPLES] @@ -286,10 +280,7 @@ def helper(percent_shot, correct): values = self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, # 35 episodes - { - "episodes_randomization_seed": 0, - "percent_shot": percent_shot, - }, + {"episodes_randomization_seed": 0, "percent_shot": percent_shot}, ) self.assertEqual(len(values), correct) @@ -302,10 +293,7 @@ def helper(percent_shot, seed): return self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, # 35 episodes - { - "episodes_randomization_seed": seed, - "percent_shot": percent_shot, - }, + {"episodes_randomization_seed": seed, "percent_shot": percent_shot}, ) data_dumps_seed_zero = [helper(i, 0) for i in self.PERCENTAGES] diff --git a/tests/tod/test_tod_teacher_metrics.py b/tests/tod/test_tod_teacher_metrics.py index aec4aba40f6..419210fa18a 100644 --- a/tests/tod/test_tod_teacher_metrics.py +++ b/tests/tod/test_tod_teacher_metrics.py @@ -62,10 +62,7 @@ def test_base_slot_metrics(self): ), ] for teacher, predicted, result in cases: - metric = SlotMetrics( - teacher_slots=teacher, - predicted_slots=predicted, - ) + metric = SlotMetrics(teacher_slots=teacher, predicted_slots=predicted) for key in result: self.assertEqual(result[key], metric.report()[key]) diff --git a/tests/tod/test_tod_world_metrics.py b/tests/tod/test_tod_world_metrics.py index 0367f655994..1399050b3ca 100644 --- a/tests/tod/test_tod_world_metrics.py +++ b/tests/tod/test_tod_world_metrics.py @@ -19,12 +19,8 @@ TodAgentType, TOD_AGENT_TYPE_TO_PREFIX, ) -from parlai.core.tod.world_metrics import ( - TodMetrics, -) -from parlai.core.tod.world_metrics_handlers import ( - METRICS_HANDLER_CLASSES_TEST_REGISTRY, -) +from parlai.core.tod.world_metrics import TodMetrics +from parlai.core.tod.world_metrics_handlers import METRICS_HANDLER_CLASSES_TEST_REGISTRY # Ignore lint on following line; want to have registered classes show up for tests import projects.tod_simulator.world_metrics.extended_world_metrics # noqa: F401 diff --git a/tests/tod/test_tod_world_metrics_in_script.py b/tests/tod/test_tod_world_metrics_in_script.py index ffac2a25e12..459458af9d2 100644 --- a/tests/tod/test_tod_world_metrics_in_script.py +++ b/tests/tod/test_tod_world_metrics_in_script.py @@ -15,9 +15,7 @@ from parlai.core.opt import Opt from 
parlai.core.tod.tod_core import SerializationHelpers import parlai.core.tod.tod_test_utils.test_agents as test_agents -from parlai.core.tod.world_metrics_handlers import ( - METRICS_HANDLER_CLASSES_TEST_REGISTRY, -) +from parlai.core.tod.world_metrics_handlers import METRICS_HANDLER_CLASSES_TEST_REGISTRY import parlai.scripts.tod_world_script as tod_world_script # Ignore lint on following line; want to have registered classes show up for tests @@ -233,9 +231,10 @@ def get_episode_report(goal, episode_metric): metrics_dict["goal"] = goal return metrics_dict - return dict_report(script.world.report()), [ - get_episode_report(g, e) for g, e in script.episode_metrics - ] + return ( + dict_report(script.world.report()), + [get_episode_report(g, e) for g, e in script.episode_metrics], + ) def test_apiCallAttempts_usingGold(self): opt = copy.deepcopy(TEST_SETUP) diff --git a/tests/tod/test_tod_world_script_metrics.py b/tests/tod/test_tod_world_script_metrics.py index 9862b3c824f..44f060acb84 100644 --- a/tests/tod/test_tod_world_script_metrics.py +++ b/tests/tod/test_tod_world_script_metrics.py @@ -14,9 +14,7 @@ import parlai.core.tod.tod_test_utils.test_agents as test_agents import parlai.scripts.tod_world_script as tod_world_script from parlai.core.tod.tod_agents import StandaloneApiAgent -from parlai.core.tod.world_metrics_handlers import ( - METRICS_HANDLER_CLASSES_TEST_REGISTRY, -) +from parlai.core.tod.world_metrics_handlers import METRICS_HANDLER_CLASSES_TEST_REGISTRY from parlai.core.metrics import dict_report # Ignore lint on following line; want to have registered classes show up for tests From bbc10cab2a4b50f83bcd84e3c3622ebc7cd41467 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Wed, 17 Nov 2021 09:01:15 -0800 Subject: [PATCH 34/57] right lint --- parlai/core/tod/tod_agents.py | 11 ++----- parlai/core/tod/tod_test_utils/test_agents.py | 9 +----- parlai/core/tod/world_metrics.py | 7 ++--- parlai/core/tod/world_metrics_handlers.py | 9 ++---- .../scripts/distributed_tod_world_script.py | 4 +-- parlai/scripts/tod_world_script.py | 8 ++--- parlai/tasks/metalwoz/agents.py | 11 ++++--- parlai/tasks/msr_e2e/agents.py | 6 +--- parlai/tasks/multiwoz_v22/agents.py | 30 ++++++++----------- parlai/tasks/taskmaster/agents.py | 10 +++---- parlai/tasks/tod_json/agents.py | 5 +--- tests/tod/test_tod_agents_and_teachers.py | 20 +++---------- tests/tod/test_tod_teacher_metrics.py | 5 +--- tests/tod/test_tod_world_metrics.py | 8 ++--- tests/tod/test_tod_world_metrics_in_script.py | 11 ++++--- tests/tod/test_tod_world_script_metrics.py | 4 +-- 16 files changed, 47 insertions(+), 111 deletions(-) diff --git a/parlai/core/tod/tod_agents.py b/parlai/core/tod/tod_agents.py index f22a8330760..ee06dd6766a 100644 --- a/parlai/core/tod/tod_agents.py +++ b/parlai/core/tod/tod_agents.py @@ -542,11 +542,7 @@ def _do_fetch(self, call_text): resp = self.data[call_text] else: resp = self.data.get(call_text, tod.STANDARD_RESP) - return { - "text": resp, - "id": self.id, - "episode_done": False, - } + return {"text": resp, "id": self.id, "episode_done": False} # Not exact case best_key = difflib.get_close_matches(call_text, self.data.keys(), 1) @@ -674,10 +670,7 @@ def custom_evaluation( [teacher_action["domain"]] if self.opt["domain_nlg_record"] else [] ) metrics = NlgMetrics( - guess=resp, - labels=labels, - prefixes=domains, - avg_jga_nlg_bleu=True, + guess=resp, labels=labels, prefixes=domains, avg_jga_nlg_bleu=True ).report() for key, value in metrics.items(): self.metrics.add(key, value) diff --git 
a/parlai/core/tod/tod_test_utils/test_agents.py b/parlai/core/tod/tod_test_utils/test_agents.py index b7639efea36..8d749fa40c7 100644 --- a/parlai/core/tod/tod_test_utils/test_agents.py +++ b/parlai/core/tod/tod_test_utils/test_agents.py @@ -91,14 +91,7 @@ def get_round_utts(episode_idx, max_rounds, filter_utts=None): f"SYSTEM: sys_utt_{episode_idx}_{i}", ] ) - utts.append( - [ - "USER: [DONE]", - "APICALL: ", - "APIRESP: ", - "SYSTEM: ", - ] - ) + utts.append(["USER: [DONE]", "APICALL: ", "APIRESP: ", "SYSTEM: "]) if filter_utts is not None: utts = [ [turn for i, turn in enumerate(round_data) if filter_utts[i]] diff --git a/parlai/core/tod/world_metrics.py b/parlai/core/tod/world_metrics.py index 4e8ba4555e4..17d96b8698f 100644 --- a/parlai/core/tod/world_metrics.py +++ b/parlai/core/tod/world_metrics.py @@ -10,9 +10,7 @@ """ from parlai.core.message import Message -from parlai.core.metrics import ( - Metrics, -) +from parlai.core.metrics import Metrics from parlai.core.tod.tod_core import ( TodAgentType, TOD_AGENT_TYPE_TO_PREFIX, @@ -73,8 +71,7 @@ def _handle_message_impl( ) if agent_type is TodAgentType.API_SCHEMA_GROUNDING_AGENT: return handler.handle_api_schemas( - message, - SerializationHelpers.str_to_api_schemas(prefix_stripped_text), + message, SerializationHelpers.str_to_api_schemas(prefix_stripped_text) ) if agent_type is TodAgentType.GOAL_GROUNDING_AGENT: return handler.handle_goals( diff --git a/parlai/core/tod/world_metrics_handlers.py b/parlai/core/tod/world_metrics_handlers.py index 3f4b68477fe..ffa7343f114 100644 --- a/parlai/core/tod/world_metrics_handlers.py +++ b/parlai/core/tod/world_metrics_handlers.py @@ -10,13 +10,8 @@ """ from parlai.core.message import Message -from parlai.core.metrics import ( - Metric, - AverageMetric, -) -from parlai.core.tod.tod_core import ( - STANDARD_DONE, -) +from parlai.core.metrics import Metric, AverageMetric +from parlai.core.tod.tod_core import STANDARD_DONE from typing import Dict, List, Optional METRICS_HANDLER_CLASSES_TEST_REGISTRY = set() # for tests diff --git a/parlai/scripts/distributed_tod_world_script.py b/parlai/scripts/distributed_tod_world_script.py index 87c8333ec41..8dc36a6feb7 100644 --- a/parlai/scripts/distributed_tod_world_script.py +++ b/parlai/scripts/distributed_tod_world_script.py @@ -9,9 +9,7 @@ Not to be called directly; should be called from SLURM """ -from parlai.scripts.tod_world_script import ( - TodWorldScript, -) +from parlai.scripts.tod_world_script import TodWorldScript from parlai.core.script import ParlaiScript import parlai.utils.distributed as distributed_utils diff --git a/parlai/scripts/tod_world_script.py b/parlai/scripts/tod_world_script.py index 42199bc4ff3..c360882e61c 100644 --- a/parlai/scripts/tod_world_script.py +++ b/parlai/scripts/tod_world_script.py @@ -168,9 +168,7 @@ def setup_tod_args(cls, parser: ParlaiParser): def setup_args(cls): # Use default parlai args for logging + the like, but don't need model args since we specify those manually via command-line parser = TodWorldParser( - True, - False, - "World for chatting with the TOD conversation structure", + True, False, "World for chatting with the TOD conversation structure" ) # Following params are same as the `eval_model` script parser.add_argument( @@ -261,9 +259,7 @@ def _get_tod_agents(self, opt: Opt): ).replace("Goal", "ApiSchema") agents[tod_world.API_SCHEMA_GROUNDING_IDX] = self._get_model_or_default_agent( - opt, - "api_schema_grounding_model", - tod_world_agents.EmptyApiSchemaAgent, + opt, "api_schema_grounding_model", 
tod_world_agents.EmptyApiSchemaAgent ) return agents diff --git a/parlai/tasks/metalwoz/agents.py b/parlai/tasks/metalwoz/agents.py index f3f9d2c8238..c4ddab9cdbb 100644 --- a/parlai/tasks/metalwoz/agents.py +++ b/parlai/tasks/metalwoz/agents.py @@ -22,9 +22,7 @@ def add_cmdline_args( ) -> ParlaiParser: super().add_cmdline_args(parser, partial_opt) parser.add_argument( - "--metalwoz-domains", - nargs="+", - help="Use only a subset of the domains", + "--metalwoz-domains", nargs="+", help="Use only a subset of the domains" ) return parser @@ -104,9 +102,10 @@ def setup_data(self, datapath): data = self.load_data(datapath) for row in data: texts = list(row["turns"]) - prompts, labels = [f"{row['user_role']}\n{texts[0]}"] + texts[2::2], texts[ - 1::2 - ] + prompts, labels = ( + [f"{row['user_role']}\n{texts[0]}"] + texts[2::2], + texts[1::2], + ) for i, (prompt, label) in enumerate(zip(prompts, labels)): yield { "text": prompt, diff --git a/parlai/tasks/msr_e2e/agents.py b/parlai/tasks/msr_e2e/agents.py index aa7b13eb04d..293cdc5cb88 100644 --- a/parlai/tasks/msr_e2e/agents.py +++ b/parlai/tasks/msr_e2e/agents.py @@ -28,11 +28,7 @@ import parlai.core.tod.tod_agents as tod_agents -DOMAINS = [ - "movie", - "restaurant", - "taxi", -] +DOMAINS = ["movie", "restaurant", "taxi"] # Just going to copy/paste these since it's faster than parsing 3 separate files # They are in `system/src/deep_dialog/data_/_slots.txt` in the original data diff --git a/parlai/tasks/multiwoz_v22/agents.py b/parlai/tasks/multiwoz_v22/agents.py index e59942d1f3a..4d6e3be69be 100644 --- a/parlai/tasks/multiwoz_v22/agents.py +++ b/parlai/tasks/multiwoz_v22/agents.py @@ -34,14 +34,7 @@ "train", ] -WELL_FORMATTED_DOMAINS = [ - "attraction", - "bus", - "hotel", - "restaurant", - "train", - "taxi", -] +WELL_FORMATTED_DOMAINS = ["attraction", "bus", "hotel", "restaurant", "train", "taxi"] class MultiwozV22Parser(tod_agents.TodStructuredDataParser): @@ -93,9 +86,7 @@ def load_schemas(self): for intent in service["intents"]: call_name = intent["name"] - result[call_name] = { - tod.STANDARD_API_NAME_SLOT: call_name, - } + result[call_name] = {tod.STANDARD_API_NAME_SLOT: call_name} req_slots = set([x[prefix_end_idx:] for x in intent["required_slots"]]) if len(req_slots) > 0: result[call_name][tod.STANDARD_REQUIRED_KEY] = list(req_slots) @@ -286,11 +277,14 @@ def _get_round(self, dialogue_id, raw_episode, turn_id): resp = {} if len(call) > 0: self.last_call = call - return call, tod.TodStructuredRound( - user_utt=user_turn["utterance"], - api_call_machine=call, - api_resp_machine=resp, - sys_utt=sys_turn["utterance"], + return ( + call, + tod.TodStructuredRound( + user_utt=user_turn["utterance"], + api_call_machine=call, + api_resp_machine=resp, + sys_utt=sys_turn["utterance"], + ), ) def _get_schemas_for_goal_calls(self, goals): @@ -322,7 +316,9 @@ def setup_episodes(self, fold): if raw_episode["dialogue_id"] != self.opt["dialogue_id"]: continue - skip = False # need to skip outer for loop while in `for domains` inner for loop + skip = ( + False + ) # need to skip outer for loop while in `for domains` inner for loop if self.opt.get("well_formatted_domains_only", True): if len(domains) == 0: skip = True diff --git a/parlai/tasks/taskmaster/agents.py b/parlai/tasks/taskmaster/agents.py index 237a9fd4a10..dfea1277dda 100644 --- a/parlai/tasks/taskmaster/agents.py +++ b/parlai/tasks/taskmaster/agents.py @@ -200,7 +200,7 @@ def setup_episodes(self, fold): rounds = [] goal_calls = [] if len(utterances) > 0 and utterances[0]["speaker"] == 
"ASSISTANT": - (idx, sys_utt, _,) = self._get_utterance_and_slots_for_speaker( + (idx, sys_utt, _) = self._get_utterance_and_slots_for_speaker( "ASSISTANT", utterances, idx ) @@ -209,11 +209,9 @@ def setup_episodes(self, fold): rounds.append(t) while idx < len(utterances): - ( - idx, - user_utt, - user_slots, - ) = self._get_utterance_and_slots_for_speaker("USER", utterances, idx) + (idx, user_utt, user_slots) = self._get_utterance_and_slots_for_speaker( + "USER", utterances, idx + ) ( idx, sys_utt, diff --git a/parlai/tasks/tod_json/agents.py b/parlai/tasks/tod_json/agents.py index 55e3514de6c..3e994befacf 100644 --- a/parlai/tasks/tod_json/agents.py +++ b/parlai/tasks/tod_json/agents.py @@ -73,10 +73,7 @@ def add_cmdline_args( help="Use all data or split into 8:1:1 fold", ) agent.add_argument( - "--split-folds-seed", - type=int, - default=42, - help="Seed for the fold split", + "--split-folds-seed", type=int, default=42, help="Seed for the fold split" ) return parser diff --git a/tests/tod/test_tod_agents_and_teachers.py b/tests/tod/test_tod_agents_and_teachers.py index 5383c72416d..2914cb927b3 100644 --- a/tests/tod/test_tod_agents_and_teachers.py +++ b/tests/tod/test_tod_agents_and_teachers.py @@ -247,10 +247,7 @@ def helper(n_shot): values = self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, - { - "episodes_randomization_seed": 0, - "n_shot": n_shot, - }, + {"episodes_randomization_seed": 0, "n_shot": n_shot}, ) self.assertEqual(len(values), n_shot) @@ -269,10 +266,7 @@ def helper(n_shot, seed): return self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE, - { - "episodes_randomization_seed": seed, - "n_shot": n_shot, - }, + {"episodes_randomization_seed": seed, "n_shot": n_shot}, ) data_dumps_seed_zero = [helper(i, 0) for i in self.FEW_SHOT_SAMPLES] @@ -286,10 +280,7 @@ def helper(percent_shot, correct): values = self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, # 35 episodes - { - "episodes_randomization_seed": 0, - "percent_shot": percent_shot, - }, + {"episodes_randomization_seed": 0, "percent_shot": percent_shot}, ) self.assertEqual(len(values), correct) @@ -302,10 +293,7 @@ def helper(percent_shot, seed): return self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, # 35 episodes - { - "episodes_randomization_seed": seed, - "percent_shot": percent_shot, - }, + {"episodes_randomization_seed": seed, "percent_shot": percent_shot}, ) data_dumps_seed_zero = [helper(i, 0) for i in self.PERCENTAGES] diff --git a/tests/tod/test_tod_teacher_metrics.py b/tests/tod/test_tod_teacher_metrics.py index aec4aba40f6..419210fa18a 100644 --- a/tests/tod/test_tod_teacher_metrics.py +++ b/tests/tod/test_tod_teacher_metrics.py @@ -62,10 +62,7 @@ def test_base_slot_metrics(self): ), ] for teacher, predicted, result in cases: - metric = SlotMetrics( - teacher_slots=teacher, - predicted_slots=predicted, - ) + metric = SlotMetrics(teacher_slots=teacher, predicted_slots=predicted) for key in result: self.assertEqual(result[key], metric.report()[key]) diff --git a/tests/tod/test_tod_world_metrics.py b/tests/tod/test_tod_world_metrics.py index 0367f655994..1399050b3ca 100644 --- a/tests/tod/test_tod_world_metrics.py +++ b/tests/tod/test_tod_world_metrics.py @@ -19,12 +19,8 @@ TodAgentType, TOD_AGENT_TYPE_TO_PREFIX, ) -from parlai.core.tod.world_metrics import ( - TodMetrics, -) -from parlai.core.tod.world_metrics_handlers import ( - 
METRICS_HANDLER_CLASSES_TEST_REGISTRY, -) +from parlai.core.tod.world_metrics import TodMetrics +from parlai.core.tod.world_metrics_handlers import METRICS_HANDLER_CLASSES_TEST_REGISTRY # Ignore lint on following line; want to have registered classes show up for tests import projects.tod_simulator.world_metrics.extended_world_metrics # noqa: F401 diff --git a/tests/tod/test_tod_world_metrics_in_script.py b/tests/tod/test_tod_world_metrics_in_script.py index ffac2a25e12..459458af9d2 100644 --- a/tests/tod/test_tod_world_metrics_in_script.py +++ b/tests/tod/test_tod_world_metrics_in_script.py @@ -15,9 +15,7 @@ from parlai.core.opt import Opt from parlai.core.tod.tod_core import SerializationHelpers import parlai.core.tod.tod_test_utils.test_agents as test_agents -from parlai.core.tod.world_metrics_handlers import ( - METRICS_HANDLER_CLASSES_TEST_REGISTRY, -) +from parlai.core.tod.world_metrics_handlers import METRICS_HANDLER_CLASSES_TEST_REGISTRY import parlai.scripts.tod_world_script as tod_world_script # Ignore lint on following line; want to have registered classes show up for tests @@ -233,9 +231,10 @@ def get_episode_report(goal, episode_metric): metrics_dict["goal"] = goal return metrics_dict - return dict_report(script.world.report()), [ - get_episode_report(g, e) for g, e in script.episode_metrics - ] + return ( + dict_report(script.world.report()), + [get_episode_report(g, e) for g, e in script.episode_metrics], + ) def test_apiCallAttempts_usingGold(self): opt = copy.deepcopy(TEST_SETUP) diff --git a/tests/tod/test_tod_world_script_metrics.py b/tests/tod/test_tod_world_script_metrics.py index 9862b3c824f..44f060acb84 100644 --- a/tests/tod/test_tod_world_script_metrics.py +++ b/tests/tod/test_tod_world_script_metrics.py @@ -14,9 +14,7 @@ import parlai.core.tod.tod_test_utils.test_agents as test_agents import parlai.scripts.tod_world_script as tod_world_script from parlai.core.tod.tod_agents import StandaloneApiAgent -from parlai.core.tod.world_metrics_handlers import ( - METRICS_HANDLER_CLASSES_TEST_REGISTRY, -) +from parlai.core.tod.world_metrics_handlers import METRICS_HANDLER_CLASSES_TEST_REGISTRY from parlai.core.metrics import dict_report # Ignore lint on following line; want to have registered classes show up for tests From e8e366f2ef11a265d257866060f9ee385391d5cc Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Wed, 17 Nov 2021 09:10:13 -0800 Subject: [PATCH 35/57] add init --- parlai/tasks/msr_e2e/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 parlai/tasks/msr_e2e/__init__.py diff --git a/parlai/tasks/msr_e2e/__init__.py b/parlai/tasks/msr_e2e/__init__.py new file mode 100644 index 00000000000..240697e3247 --- /dev/null +++ b/parlai/tasks/msr_e2e/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
From 2f1544806dbf6daf2ea60c0bf196352cce3e3e4c Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 30 Nov 2021 08:45:59 -0800 Subject: [PATCH 36/57] address eric comments; add new readme + more documentation --- parlai/core/tod/README.md | 81 +++++++++++++++++++++++ parlai/core/tod/teacher_metrics.py | 11 --- parlai/core/tod/tod_agents.py | 55 ++------------- tests/tod/test_tod_agents_and_teachers.py | 65 +++++++++++++++++- tests/tod/test_tod_teacher_metrics.py | 4 ++ 5 files changed, 154 insertions(+), 62 deletions(-) create mode 100644 parlai/core/tod/README.md diff --git a/parlai/core/tod/README.md b/parlai/core/tod/README.md new file mode 100644 index 00000000000..d2ffbe33570 --- /dev/null +++ b/parlai/core/tod/README.md @@ -0,0 +1,81 @@ +# Tod Core README + +For the quickest getting-to-use of core classes, start with the "Teachers + Agents Usage" section below (for understanding how to setup agents such that they work with new datasets) and `parlai/scripts/tod_world_script.py` (for understanding how to run simulations with the TOD conversations format). + +See `projects/tod_simulator/README` for a higher-level usage-focused README. This document also describes the structure of the contents of this directory. + +As a convention, files referenced externally to this directory are prefixed with `tod` whereas those only referenced by other files within the directory are not. + +--- + +# Teachers + Agents Usage + +See `tod_agents.py` for the classes. + +For a given dataset, extend `TodStructuredDataParser` and implement `generate_episodes()` and `get_id_task_prefix()`. The former of these is expected to do the data processing to convert a dataset to `List[TodStructuredEpisode]`. From here, multiple inheritance can be used to define Agents and Teachers that utilize the data. + +For example, given a `class XX_DataParser(TodStructuredDataParser)`, `class XX_UserSimulatorTeacher(XX_DataParser, TodUserSimulatorTeacher)` would be how one would define a teacher that generates training data for a User Simulator model. + +Once the relevant agents have been created (or relevant models have been fine-tuned), see `parlai.scripts.tod_world_script` for generating the simulations themselves. + +## Why we do this +These files aid in consistency between Teachers and Agents for simulation. Rather than having to align multiple different agents to be consistent about assumptions about data formatting, tokens, spacing, etc, we do this once (via converting everything to `TodStructuredEpisode`) and let the code handle the rest. + +# Description of Agents + Teachers useful for Simulation +## Teachers for training (generative) models + * TodSystemTeacher + * TodUserSimulatorTeacher + +## Agents for Grounding +For goal grounding for the User for simulation: + * TodGoalAgent + * Dumps goals as is from the dataset, possibly multiple per episode + * TodSingleGoalAgent + * Flattens goals such that a single one is used to seed a conversation. For datasets that include multiple goals per conversation, each individual goal is used as a seed. + +For (optional) API schema grounding for the System: + * TodApiSchemaAgent (must be used with `TodGoalAgent` only) + * TodSingleApiSchemaAgent (must be used with `TodSingleGoalAgent` only) + * EmptyApiSchemaAgent + * Used for simulations where the expectation is `no schema`, ie, evaluation simulations. + +## Agents for mocking APIs: + * StandaloneApiAgent + * Assumed to be provided a .pickle file 'trained' by `TodStandaloneApiTeacher`. 
(See comments in-line on classes for train command example) + +# Agents for dumping data from a ground truth dataset +The following are for extracting TOD World metrics from a ground truth dataset. These are generally used sparingly and only for calculating baselines. + * TodApiCallAndSysUttAgent + * TodApiResponseAgent + * TodUserUttAgent + +For this metrics extraction, `TodGoalAgent` and `TodApiSchemaAgent` should be used. + +# Other agents +There is an `EmptyGoalAgent` for use in human-human conversations where a goal is unnecessary. + +--- + +# Directory contents + +This directory is split into 3 main components: files to support agents + teachers, files to support the simulation world, and files to store functionality common to both of these. We describe the common functionality first then go to the other two. + +Tests for all files in this directory are stored in `tests/tod` + +## Files for common functionality +`tod_core.py` defines consts and enums used across TOD agents, teachers, and world. It also defines dataclasses for storing the intermediate data format used when parsing a dataset to the TOD structure as well as a `SerializationHelpers` class for going from machine structured data (ex. API Calls) to flattened versions used by the models. + + +## Files for agents and teachers +Usage of `tod_agents.py` is described above. It references `teacher_metrics.py` which stores Metrics objects. + +## Files for simulation world +Description of usage of the simulation world is primarily stored in the script running the world, stored in `parlai/scripts/tod_world_script.py`. The script is responsible for running multiple episodes of simulation and saving simulation output data. + +The world itself is stored in `tod_world.py`. The world follows the same intermediate data formats for episodes as described in `tod_core.py` and does the correct calling of different agents to support this. It is generally recommended that this file not be touched. + +A general class for collecting metrics out of `TODWorld` is stored within `world_metrics.py` with individual 'metric handlers' responsible for calculating a given metric stored in `world_metrics_handlers.py`. + + + + diff --git a/parlai/core/tod/teacher_metrics.py b/parlai/core/tod/teacher_metrics.py index 3fc85c7d107..83c140a3ea5 100644 --- a/parlai/core/tod/teacher_metrics.py +++ b/parlai/core/tod/teacher_metrics.py @@ -18,10 +18,6 @@ class SlotMetrics(Metrics): Due to differences in dialogue representations between tasks, the input is pre- parsed ground truth and predicted slot dictionaries. - - The 'jga+nlg' metric assumes a balanced set of JGA and NLG scores such that - 2 * Avg(JGA, NLG_BLEU) = Avg(JGA + NLG_BLEU) - The `jga+nlg` metric assumes that `NlgMetrics` is used to calculated the other side. 
""" def __init__( @@ -30,7 +26,6 @@ def __init__( predicted_slots: Dict[str, str], prefixes: Optional[List] = None, shared: Dict[str, Any] = None, - avg_jga_nlg_bleu: bool = False, ) -> None: super().__init__(shared=shared) self.prefixes = prefixes if prefixes else [] @@ -45,9 +40,6 @@ def __init__( "jga_empty", AverageMetric(teacher_slots == predicted_slots) ) - if avg_jga_nlg_bleu: - # add one half of Avg(jga,nlg_bleu), NlgMetrics class (below) adds NLG-BLEU - self.add("jga+nlg", AverageMetric(teacher_slots == predicted_slots)) # precision for pred_slot_name, pred_value in predicted_slots.items(): slot_p = AverageMetric(teacher_slots.get(pred_slot_name) == pred_value) @@ -86,9 +78,6 @@ def __init__( f1 = F1Metric.compute(guess, labels) self.add_with_prefixes("nlg_bleu", bleu) self.add_with_prefixes("nlg_f1", f1) - if avg_jga_nlg_bleu: - # add one half of Avg(jga,nlg_bleu), SlotMetrics class (above) adds JGA - self.add("jga+nlg", bleu) def add_with_prefixes(self, name, value): self.add(name, value) diff --git a/parlai/core/tod/tod_agents.py b/parlai/core/tod/tod_agents.py index ee06dd6766a..80d9012a1f2 100644 --- a/parlai/core/tod/tod_agents.py +++ b/parlai/core/tod/tod_agents.py @@ -7,49 +7,9 @@ Agents (used for dumping data) and Teachers (for training models) related to the TOD conversation setup. -# Usage - -For a given dataset, extend `TodStructuredDataParser` and implement `generate_episodes()` and `get_id_task_prefix()`. The former of these is expected to do the data processing to convert a dataset to `List[TodStructuredEpisode]`. From here, multiple inheritance can be used to define Agents and Teachers that utilize the data. - -For example, given a `class XX_DataParser(TodStructuredDataParser)`, `class XX_UserSimulatorTeacher(XX_DataParser, TodUserSimulatorTeacher)` would be how one would define a teacher that generates training data for a User Simulator model. - -Once the relevant agents have been created (or relevant models have been fine-tuned), see `parlai.scripts.tod_world_script` for usage in generating simulations. - -As a convention, agents and teachers that are inheritable are prefixed with "Tod" whereas those that can be used as-is are not. Similarly, classes and functions that do not need to be exposed outside of this file are prefixed with a single underscore ('_'). - -## Why we do this -These files aid in consistency between Teachers and Agents for simulation. Rather than having to align multiple different agents to be consistent about assuptions about data formatting, tokens, spacing, etc, we do this once (via converting everything to `TodStructuredEpisode`) and let the code handle the rest. - -# Description of Agents + Teachers useful for Simulation -## Teachers for training (generative) models - * TodSystemTeacher - * TodUserSimulatorTeacher - -## Agents for Grounding -For goal grounding for the User for simulation: - * TodGoalAgent - * TodSingleGoalAgent - -For (optional) API schema grounding for the System: - * TodApiSchemaAgent (must be used with `TodGoalAgent` only) - * TodSingleApiSchemaAgent (must be used with `TodSingleGoalAgent` only) - * EmptyApiSchemaAgent - * Used for simulations where the expectation is `no schema`, ie, evaluation simulations. - -## Agents for mocking APIs: - * StandaloneApiAgent - * Assumed to be provided a .pickle file 'trained' by `TodStandaloneApiTeacher` - -# Agents for dumping data from a ground truth dataset -The following are for extracting TOD World metrics from a ground truth dataset. 
These are generally used sparingly and only for calculating baselines. - * TodApiCallAndSysUttAgent - * TodApiResponseAgent - * TodUserUttAgent - -For this metrics extraction, `TodGoalAgent` and `TodApiSchemaAgent` should be used. - -# Other agents -There is a `EmptyGoalAgent` for use in human-human conversations where a goal is unnecessary. +As a convention, agents and teachers that are inheritable are prefixed with "Tod" +whereas those that can be used as-is are not. Similarly, classes and functions that do +not need to be exposed outside of this file are prefixed with a single underscore ('_') """ from parlai.core.agents import Agent @@ -451,7 +411,7 @@ class StandaloneApiAgent(Agent): Use `TodStandaloneApiTeacher` to train this class. For example for a MultiWoz V2.2 standalone API, use ``` parlai train -t multiwoz_v22:StandaloneApiTeacher -m - parlai_fb.agents.tod.agents:StandaloneApiAgent -eps 4 -mf output ``` to generate the + parlai.core.tod.tod_agents:StandaloneApiAgent -eps 4 -mf output ``` to generate the `.pickle` file to use. """ @@ -647,7 +607,6 @@ def custom_evaluation( metrics = SlotMetrics( teacher_slots=teacher_action["slots"], predicted_slots=predicted, - avg_jga_nlg_bleu=True, prefixes=domains, ).report() for key, value in metrics.items(): @@ -669,9 +628,7 @@ def custom_evaluation( domains = ( [teacher_action["domain"]] if self.opt["domain_nlg_record"] else [] ) - metrics = NlgMetrics( - guess=resp, labels=labels, prefixes=domains, avg_jga_nlg_bleu=True - ).report() + metrics = NlgMetrics(guess=resp, labels=labels, prefixes=domains).report() for key, value in metrics.items(): self.metrics.add(key, value) @@ -765,7 +722,7 @@ class TodStandaloneApiTeacher(TodStructuredDataParser, DialogTeacher): Set this as the teacher with `StandaloneApiAgent` as the agent. Ex for a MultiWoz V2.2 standalone API, use ``` parlai train -t multiwoz_v22:StandaloneApiTeacher -m - parlai_fb.agents.tod.agents:StandaloneApiAgent -eps 4 -mf output ``` + parlai.core.tod.tod_agents:StandaloneApiAgent -eps 4 -mf output ``` """ def setup_data(self, fold): diff --git a/tests/tod/test_tod_agents_and_teachers.py b/tests/tod/test_tod_agents_and_teachers.py index 2914cb927b3..991d165f278 100644 --- a/tests/tod/test_tod_agents_and_teachers.py +++ b/tests/tod/test_tod_agents_and_teachers.py @@ -5,7 +5,11 @@ # LICENSE file in the root directory of this source tree. """ -Tests different (more complicated) slot metrics. +Tests teachers + agent implementations, assuming parser to conversations format has +already been done and teachers/agents already created. + +`test_agents.py` includes functions for generating the raw data used in this file as +well as the data parser. """ import unittest @@ -16,6 +20,10 @@ class TestTodAgentsAndTeachersBase(unittest.TestCase): + """ + Base class with convenience functions for setting up agents, dumping text, etc. + """ + def setup_agent_or_teacher(self, class_type, round_opt, opt): full_opts = {**round_opt, **opt} full_opts["datatype"] = "DUMMY" @@ -24,6 +32,9 @@ def setup_agent_or_teacher(self, class_type, round_opt, opt): return class_type(full_opts) def dump_single_utt_per_episode_agent_text(self, class_type, round_opt, opt): + """ + Continuously dumps data from an agent until it's done. 
+ """ agent = self.setup_agent_or_teacher(class_type, round_opt, opt) result = [] while not agent.epoch_done(): @@ -48,14 +59,30 @@ def dump_teacher_text(self, class_type, round_opt, opt): return data def _test_roundDataCorrect(self): + """ + Convenience function that runs on different episode setups. + + Prefix with `_` since not all tests necessarily need this + """ self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__UTTERANCES_ONLY) self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__SINGLE_API_CALL) self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__MULTI_ROUND) self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__MULTI_EPISODE) + def _test_roundDataCorrect_helper(self, config): + """ + Implement this in downstream classes to define what is "correct" for a round (Ie + checking serialization data for a given class vs only checking utterances) + """ + raise RuntimeError("Not implemented") + class TestSystemTeacher(TestTodAgentsAndTeachersBase): def test_apiSchemas_with_yesApiSchemas(self): + """ + Tests to make sure that data from first turn is correct when we include API + Schemas. + """ values = self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__SINGLE_API_CALL, @@ -70,6 +97,10 @@ def test_apiSchemas_with_yesApiSchemas(self): ) def test_apiSchemas_with_noApiSchemas(self): + """ + Tests to make sure that data from first turn is correct when we do not include + API Schemas. + """ values = self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__SINGLE_API_CALL, @@ -86,7 +117,7 @@ def _test_roundDataCorrect_helper(self, config): for utt in utts: comp.append([utt[0], utt[1]]) comp.append([utt[2], utt[3]]) - # Skip context turn cause we check it above + # Skip grounding turn cause we check it in the other teachers self.assertEqual(episode[1:], comp) def test_roundDataCorrect(self): @@ -95,6 +126,10 @@ def test_roundDataCorrect(self): class TestUserTeacher(TestTodAgentsAndTeachersBase): def _test_roundDataCorrect_helper(self, config): + """ + Make sure that all of the User teacher data is correct relative to ground truth, + including grounding turn. + """ max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] values = self.dump_teacher_text(test_agents.UserSimulatorTeacher, config, {}) for episode_idx, episode in enumerate(values): @@ -121,6 +156,9 @@ def test_roundDataCorrect(self): class TestGoalAgent(TestTodAgentsAndTeachersBase): def _test_roundDataCorrect_helper(self, config): + """ + Make sure goal agent data is correct with (possibly) multiple goals. + """ max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] max_episodes = config[test_agents.TEST_NUM_EPISODES_OPT_KEY] values = self.dump_single_utt_per_episode_agent_text( @@ -143,6 +181,9 @@ def test_roundDataCorrect(self): class TestApiSchemaAgent(TestTodAgentsAndTeachersBase): def _test_roundDataCorrect_helper(self, config): + """ + Make sure api schema information is correct with (possibly) multiple goals. + """ max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] max_episodes = config[test_agents.TEST_NUM_EPISODES_OPT_KEY] values = self.dump_single_utt_per_episode_agent_text( @@ -165,6 +206,10 @@ def test_roundDataCorrect(self): class TestSingleGoalAgent(TestTodAgentsAndTeachersBase): def _test_roundDataCorrect_helper(self, config): + """ + Make sure single goal agent correctly splits conversations with multiple goals + into single goals for the agent. 
+ """ max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] max_episodes = config[test_agents.TEST_NUM_EPISODES_OPT_KEY] values = self.dump_single_utt_per_episode_agent_text( @@ -187,6 +232,10 @@ def test_roundDataCorrect(self): class TestSingleApiSchemaAgent(TestTodAgentsAndTeachersBase): def _test_roundDataCorrect_helper(self, config): + """ + Make sure single api schema agent correctly splits conversations with multiple + goals into single goals for the agent. + """ max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] max_episodes = config[test_agents.TEST_NUM_EPISODES_OPT_KEY] values = self.dump_single_utt_per_episode_agent_text( @@ -262,6 +311,10 @@ def _test_subsets(self, data_dumps): self.assertEqual(episode, larger[i]) def test_few_shot_subset(self): + """ + Make sure specifying few-shot by n-shot works correctly. + """ + def helper(n_shot, seed): return self.dump_teacher_text( test_agents.SystemTeacher, @@ -276,6 +329,10 @@ def helper(n_shot, seed): self.assertNotEqual(data_dumps_seed_zero[-1], data_dumps_seed_three[-1]) def test_percent_shot_lengths_correct(self): + """ + Make sure specifying few-shot by percentages works correctly. + """ + def helper(percent_shot, correct): values = self.dump_teacher_text( test_agents.SystemTeacher, @@ -289,6 +346,10 @@ def helper(percent_shot, correct): helper(0.3, 10) def test_percent_shot_subset(self): + """ + Make sure specifying few-shot by percentages works correctly. + """ + def helper(percent_shot, seed): return self.dump_teacher_text( test_agents.SystemTeacher, diff --git a/tests/tod/test_tod_teacher_metrics.py b/tests/tod/test_tod_teacher_metrics.py index 419210fa18a..60bbb68c5d9 100644 --- a/tests/tod/test_tod_teacher_metrics.py +++ b/tests/tod/test_tod_teacher_metrics.py @@ -4,6 +4,10 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +""" +File that includes tests for teacher metrics. +""" + import unittest from math import isnan From 5d0197df5276aa7f6213e782328330ebb9ca97a1 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 30 Nov 2021 09:06:14 -0800 Subject: [PATCH 37/57] minor wording change --- parlai/core/tod/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/parlai/core/tod/README.md b/parlai/core/tod/README.md index d2ffbe33570..22e337f0b0f 100644 --- a/parlai/core/tod/README.md +++ b/parlai/core/tod/README.md @@ -1,6 +1,6 @@ -# Tod Core README +# Task-Oriented Dialog (TOD) Core README -For the quickest getting-to-use of core classes, start with the "Teachers + Agents Usage" section below (for understanding how to setup agents such that they work with new datasets) and `parlai/scripts/tod_world_script.py` (for understanding how to run simulations with the TOD conversations format). +For the quickest getting-to-use of TOD classes, start with the "Teachers + Agents Usage" section below (for understanding how to setup agents such that they work with new datasets) and `parlai/scripts/tod_world_script.py` (for understanding how to run simulations with the TOD conversations format). See `projects/tod_simulator/README` for a higher-level usage-focused README. This document also describes the structure of the contents of this directory. 
From 76bfa898f53badd0b072efe08ffc269a53ec545a Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 30 Nov 2021 09:34:13 -0800 Subject: [PATCH 38/57] add more documtnation to world tests (following comment on teacher tests) --- tests/tod/test_tod_world_and_script.py | 36 ++++++++++++++++--- tests/tod/test_tod_world_metrics.py | 7 +++- tests/tod/test_tod_world_metrics_in_script.py | 25 ++++++++++++- tests/tod/test_tod_world_script_metrics.py | 9 ++--- 4 files changed, 66 insertions(+), 11 deletions(-) diff --git a/tests/tod/test_tod_world_and_script.py b/tests/tod/test_tod_world_and_script.py index d3f764bc13b..2d8e8944024 100644 --- a/tests/tod/test_tod_world_and_script.py +++ b/tests/tod/test_tod_world_and_script.py @@ -5,7 +5,10 @@ # LICENSE file in the root directory of this source tree. """ -Tests tod world, notably for batching. +Tests tod world + script, notably for batching, by comparing saved script logs to the +data that should have been generated. + +Metrics are handled in separate files. """ import copy @@ -114,6 +117,10 @@ def _check_correctness_from_script_logs( class TodWorldSingleBatchTest(TodWorldInScriptTestBase): + """ + Checks that saved data is correct with a single batch. + """ + def _test_roundDataCorrect_helper(self, config): config["batchsize"] = 1 config["max_turns"] = 10 @@ -135,7 +142,7 @@ def _test_max_turn_helper(self, max_turns): config["batchsize"] = 1 config["max_turns"] = max_turns config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] = 10 - config[test_agents.TEST_NUM_EPISODES_OPT_KEY] = 2 # cause why not + config[test_agents.TEST_NUM_EPISODES_OPT_KEY] = 5 # cause why not agents, opt = self.setup_agents(config) script = TestTodWorldScript(opt) script.agents = agents @@ -150,6 +157,10 @@ def filter_round_utt(utts): class TodWorldNonSingleBatchTest(TodWorldInScriptTestBase): + """ + Checks saved data is correct with larger batchsizes. + """ + def _test_roundDataCorrect_helper(self, config): config["batchsize"] = 4 config["max_turns"] = 10 @@ -164,6 +175,13 @@ def test_roundDataCorrect(self): class TodWorldTestSingleDumpAgents(TodWorldInScriptTestBase): + """ + Just to be safe, make sure that the agents with "single" versions (ex goal + api + schema) are correctly aligned. + + (Already tested in the agents test file as well, but to be safe.) 
+ """ + def setup_agents(self, added_opts, api_agent, goal_agent): full_opts = self.add_tod_world_opts(added_opts) full_opts["fixed_response"] = "USER: [DONE]" @@ -178,11 +196,11 @@ def setup_agents(self, added_opts, api_agent, goal_agent): ] return agents, full_opts - def test_SingleGoalApiResp_noBatching(self): + def _test_SingleGoalApiResp_helper(self, batchsize, num_episodes): config = {} - config["batchsize"] = 1 + config["batchsize"] = batchsize config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] = 10 - config[test_agents.TEST_NUM_EPISODES_OPT_KEY] = 2 # cause why not + config[test_agents.TEST_NUM_EPISODES_OPT_KEY] = num_episodes single_agents, opt = self.setup_agents( config, test_agents.SingleApiSchemaAgent, test_agents.SingleGoalAgent ) @@ -220,6 +238,14 @@ def test_SingleGoalApiResp_noBatching(self): single_idx += 1 + def test_SingleGoalApiResp_helper_singleBatch(self): + self._test_SingleGoalApiResp_helper(1, 2) + self._test_SingleGoalApiResp_helper(1, 5) + + def test_SingleGoalApiResp_helper_multiBatch(self): + self._test_SingleGoalApiResp_helper(4, 8) + self._test_SingleGoalApiResp_helper(4, 11) + if __name__ == "__main__": unittest.main() diff --git a/tests/tod/test_tod_world_metrics.py b/tests/tod/test_tod_world_metrics.py index 1399050b3ca..293c677834f 100644 --- a/tests/tod/test_tod_world_metrics.py +++ b/tests/tod/test_tod_world_metrics.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. """ -Tests different (more complicated) slot metrics. +Test world metrics + world metrics handlers against dummy conversations. """ import unittest @@ -83,6 +83,11 @@ class TodMetricsTestHelper: + """ + Given a synthetic intermediate converesation, calculates the metrics for said + conversation. + """ + def __init__(self, e: TodStructuredEpisode): self.m = TodMetrics() self.m.handlers = [ diff --git a/tests/tod/test_tod_world_metrics_in_script.py b/tests/tod/test_tod_world_metrics_in_script.py index 459458af9d2..19cc0cd20af 100644 --- a/tests/tod/test_tod_world_metrics_in_script.py +++ b/tests/tod/test_tod_world_metrics_in_script.py @@ -5,7 +5,11 @@ # LICENSE file in the root directory of this source tree. """ -sTests tod world, notably for batching. +Tests tod world metrics in the full script, *including* making the script properly set +up the agents on its own. + +Use a few of the API Call + goal hit metrics as the metric handlers to test proper +functionality. """ import copy @@ -85,6 +89,9 @@ def _save_outputs(self, opt, world, logger, episode_metrics): class TodMetricsInScriptTests(unittest.TestCase): def test_all_goals_hit_all_success(self): + """ + For a setup where all the goals should be successfully hit, is it? + """ self._check_all_goals_hit_by_opt_and_batchsize( TEST_SETUP, batchsize=1, num_episodes=1, target_all_goals_hit=1 ) @@ -106,6 +113,9 @@ def test_all_goals_hit_all_success(self): ) def test_all_goals_hit_all_fail(self): + """ + For a setup where all the goals should *not* be successfully hit, do they fail? 
+ """ self._check_all_goals_hit_by_opt_and_batchsize( TEST_SETUP_BROKEN_USER_SYSTEM, batchsize=1, @@ -139,6 +149,12 @@ def test_all_goals_hit_all_fail(self): ) def test_all_goals_hit_all_success_emptySchema(self): + """ + Check to make sure empty API schema doesn't have any impact on goal (Necessary + cause original, more exhaustive implementation of goal success would separate + between required + optional opts using the schema; make sure it doesn't impact + anything broader) + """ self._check_all_goals_hit_by_opt_and_batchsize( TEST_SETUP_EMPTY_APISCHEMA, batchsize=1, @@ -172,6 +188,13 @@ def test_all_goals_hit_all_success_emptySchema(self): ) def test_all_goals_hit_all_fail_emptySchema(self): + """ + Make sure empty schema has no impact on goal success. + + (Necessary cause original, more exhaustive implementation of goal success would + separate between required + optional opts using the schema; make sure it doesn't + impact anything broader) + """ self._check_all_goals_hit_by_opt_and_batchsize( TEST_SETUP_BROKEN_USER_SYSTEM_EMPTY_APISCHEMA, batchsize=1, diff --git a/tests/tod/test_tod_world_script_metrics.py b/tests/tod/test_tod_world_script_metrics.py index 44f060acb84..5152a7bfa23 100644 --- a/tests/tod/test_tod_world_script_metrics.py +++ b/tests/tod/test_tod_world_script_metrics.py @@ -5,7 +5,11 @@ # LICENSE file in the root directory of this source tree. """ -Tests tod world, notably for batching. +Tests tod world metrics in the full script, *without* making the script properly set up +the agents on its own. + +Use a few of the API Call + goal hit metrics as the metric handlers to test proper +functionality. """ import copy @@ -104,9 +108,6 @@ def _check_metrics_correct(self, script, opt): max_episodes = opt[test_agents.TEST_NUM_EPISODES_OPT_KEY] episode_metrics = script.episode_metrics for episode_idx, episode in enumerate(episode_metrics): - # if episode_idx >= max_episodes: - # break - # See how we make broken mock api calls in the test_agents. goal, episode_metric = episode episode_metric = dict_report(episode_metric.report()) self.assertAlmostEqual( From 73c5c7a58f15eac6fe27320ed565c453e2d04e60 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 30 Nov 2021 09:45:00 -0800 Subject: [PATCH 39/57] minor comment update --- parlai/scripts/tod_world_script.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parlai/scripts/tod_world_script.py b/parlai/scripts/tod_world_script.py index c360882e61c..e2d50a1f445 100644 --- a/parlai/scripts/tod_world_script.py +++ b/parlai/scripts/tod_world_script.py @@ -6,7 +6,7 @@ """ Base script for running TOD model-model chats. 
-For example, to extract gold ground truth data from Google SGD, run +For example, to extract gold ground truth data from the holdout version of Google SGD, run ``` python -u -m parlai.scripts.tod_world_script --api-schema-grounding-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainApiSchemaAgent --goal-grounding-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainGoalAgent --user-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainUserUttAgent --system-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainApiCallAndSysUttAgent --api-resp-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainApiResponseAgent -dt valid --num-episodes -1 --episodes-randomization-seed 42 --world-logs gold-valid From 0b846d786e43b38c6dc3a3780dc245ff4e6f6a76 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 30 Nov 2021 12:19:14 -0800 Subject: [PATCH 40/57] make build file less dumb; minor bug in agents.py --- parlai/tasks/multidogo/agents.py | 4 +-- parlai/tasks/multidogo/build.py | 47 +++++++++++++++++++------------- 2 files changed, 30 insertions(+), 21 deletions(-) diff --git a/parlai/tasks/multidogo/agents.py b/parlai/tasks/multidogo/agents.py index 1e92e073fb4..eaf9f3cff7a 100644 --- a/parlai/tasks/multidogo/agents.py +++ b/parlai/tasks/multidogo/agents.py @@ -111,10 +111,10 @@ def setup_episodes(self, fold): rounds.append( tod.TodStructuredRound( - user_utt=user_utt, + user_utt="".join(user_utt), api_call_machine=api_call, api_resp_machine=api_resp, - sys_utt=sys_utt, + sys_utt="".join(sys_utt), ) ) goal_calls = copy.deepcopy(all_calls) diff --git a/parlai/tasks/multidogo/build.py b/parlai/tasks/multidogo/build.py index 29c75298548..e8adedc05b8 100644 --- a/parlai/tasks/multidogo/build.py +++ b/parlai/tasks/multidogo/build.py @@ -13,6 +13,7 @@ import os import json import re +import tqdm DEBUG_MISSING_RAW_CONVERSATIONS = False # Unnecessary once Amazon fixes multidogo @@ -51,7 +52,7 @@ PROCESSED = "processed/" -def _preprocess(opt, datapath, datatype): +def _preprocess(opt, datapath, datatype, version): """ MultiDoGo conversations take place between an "agent" and a customer". Labeled customer data is stored in one set of files while the agent data is in another. @@ -67,19 +68,15 @@ def _preprocess(opt, datapath, datatype): intent_type = opt.get("intent_type", TURN_INTENT) for domain in domains: - # to see which domain/datatype combo we've built, use a dummy file to mark - built_file = _get_processed_multidogo_built_file( + out_dir = get_processed_multidogo_folder( datapath, domain, datatype, intent_type ) - if os.path.isfile(built_file): + if build_data.built(out_dir, version): continue print( f" Preprocessing '{domain}' data for '{datatype}' with '{intent_type}' intent labels." ) - out_dir = get_processed_multidogo_folder( - datapath, domain, datatype, intent_type - ) Path(out_dir).mkdir(parents=True, exist_ok=True) # The agent responses for *all* datatypes are in one file. 
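For context, the hunk above drops the ad-hoc `.build` marker file in favor of ParlAI's versioned-build idiom. A minimal sketch of that idiom follows; the module path for the `build_data` helper is assumed, and the body is simplified rather than copied from `build.py`:

```
# Sketch only: versioned "skip if already built" pattern used above.
import parlai.core.build_data as build_data  # assumed standard ParlAI helper module
from pathlib import Path


def maybe_preprocess(out_dir: str, version: str) -> None:
    if build_data.built(out_dir, version):
        return  # this version of the processed data already exists; skip the work
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    # ... write the processed conversation files into out_dir here ...
    build_data.mark_done(out_dir, version_string=version)
```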
@@ -124,20 +121,13 @@ def _preprocess(opt, datapath, datatype): ) # mark that we've built this combinations - open(built_file, "a").close() + build_data.mark_done(out_dir, version_string=version) def get_processed_multidogo_folder(datapath, domain, datatype, intent_type): return os.path.join(datapath, PROCESSED, domain, intent_type, datatype) -def _get_processed_multidogo_built_file(datapath, domain, datatype, intent_type): - return os.path.join( - get_processed_multidogo_folder(datapath, domain, datatype, intent_type), - ".build", - ) - - # unannotated data is UNANNOTATED_DATA_PROFIX + + '.tsv' # annotated data is ANNOTATED_DATA_PATH + + + '/' + + '.tsv' def _get_unannotated_tsv_data(datapath, domain): @@ -159,6 +149,18 @@ def _get_annotated_tsv_data(datapath, domain, datatype, annotation_type): return csv.reader(open(file_name, "r"), delimiter="\t") +def _get_annotated_tsv_data_size(datapath, domain, datatype, annotation_type): + file_name = os.path.join( + datapath, + RAW_DATA_PREFIX, + RAW_DATA_ANNOTATED_DATA_PATH, + RAW_DATA_INTENT_BY_TYPE_PATH[annotation_type], + domain, + DATATYPE_TO_RAW_DATA_FILE_NAME[datatype], + ) + return sum(1 for line in open(file_name, 'r')) + + def _build_conversation_span_map(unannotated_tsv_object): result = {} # conversationId to (start line, length) map start = 0 @@ -207,7 +209,14 @@ def _aggregate_and_write_conversations( file_idx = start_file_idx intent_tsv = _get_annotated_tsv_data(datapath, domain, datatype, fetch_intent_type) next(intent_tsv) # don't need the header in the first line - for labeled_line in intent_tsv: + print(f"Processing for {domain}, {fetch_intent_type}, {datatype}") + for labeled_line in tqdm.tqdm( + intent_tsv, + total=_get_annotated_tsv_data_size( + datapath, domain, datatype, fetch_intent_type + ) + - 1, + ): conversation_id = labeled_line[0] if conversation_id in skip_ids: continue @@ -318,6 +327,6 @@ def build(opt): # mark the data as built build_data.mark_done(datapath, version_string=version) - # do preprocessing on the data to put it into FBDialogueData format - for fold in ["train", "valid", "test"]: - _preprocess(opt, datapath, fold) + # do preprocessing on the data to put it into FBDialogueData format. 
There's a lot so check to make sure it's okay + for fold in ["train", "valid", "test"]: + _preprocess(opt, datapath, fold, version) From 65d0a423ea6fb145b16b912f76a036be69f240b2 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Wed, 1 Dec 2021 08:50:33 -0800 Subject: [PATCH 41/57] remove + rerun regression test data for multidogo --- parlai/tasks/multidogo/test/multidogo_SystemTeacher_train.yml | 4 ++-- parlai/tasks/multidogo/test/multidogo_SystemTeacher_valid.yml | 4 ++-- .../multidogo/test/multidogo_UserSimulatorTeacher_train.yml | 4 ++-- .../multidogo/test/multidogo_UserSimulatorTeacher_valid.yml | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/parlai/tasks/multidogo/test/multidogo_SystemTeacher_train.yml b/parlai/tasks/multidogo/test/multidogo_SystemTeacher_train.yml index c86114ed7fa..b3e7a1a7ee4 100644 --- a/parlai/tasks/multidogo/test/multidogo_SystemTeacher_train.yml +++ b/parlai/tasks/multidogo/test/multidogo_SystemTeacher_train.yml @@ -45,5 +45,5 @@ acts: slots: {} text: 'APIRESP: ' type: 'SYSTEM: ' -num_episodes: 15616 -num_examples: 290050 +num_episodes: 8180 +num_examples: 151916 diff --git a/parlai/tasks/multidogo/test/multidogo_SystemTeacher_valid.yml b/parlai/tasks/multidogo/test/multidogo_SystemTeacher_valid.yml index c1505d9821d..8b584c36e72 100644 --- a/parlai/tasks/multidogo/test/multidogo_SystemTeacher_valid.yml +++ b/parlai/tasks/multidogo/test/multidogo_SystemTeacher_valid.yml @@ -43,5 +43,5 @@ acts: slots: {} text: 'APIRESP: ' type: 'SYSTEM: ' -num_episodes: 1590 -num_examples: 29662 +num_episodes: 1148 +num_examples: 21378 diff --git a/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_train.yml b/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_train.yml index f42028106f2..6b558483d2f 100644 --- a/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_train.yml +++ b/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_train.yml @@ -47,5 +47,5 @@ acts: here, there is one flight of Jet airways operating on 09/20/2018, The timings are, 6:00 Am to 8:00 Am and it is costing you $170 per head. ' type: 'USER: ' -num_episodes: 15616 -num_examples: 290050 +num_episodes: 8180 +num_examples: 151916 diff --git a/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_valid.yml b/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_valid.yml index 410ff0590c3..36d321a9557 100644 --- a/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_valid.yml +++ b/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_valid.yml @@ -43,5 +43,5 @@ acts: text: 'SYSTEM: That''s amazing! I recently visited chennai, It''s such a beautiful place! I hope you enjoy the trip! May I know your preferred date please?' 
type: 'USER: ' -num_episodes: 1590 -num_examples: 29662 +num_episodes: 1148 +num_examples: 21378 From a1aba6a71117abba03dfd026bc19eeb7fa7d56e7 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Wed, 1 Dec 2021 11:12:42 -0800 Subject: [PATCH 42/57] see what happens if I bump up the build # (hoping tests work) --- parlai/tasks/taskmaster/build.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parlai/tasks/taskmaster/build.py b/parlai/tasks/taskmaster/build.py index c60bf3c9862..d931f03be89 100644 --- a/parlai/tasks/taskmaster/build.py +++ b/parlai/tasks/taskmaster/build.py @@ -29,7 +29,7 @@ def build(opt): # get path to data directory dpath = os.path.join(opt['datapath'], 'taskmaster-1') # define version if any - version = "1.01" + version = "1.02" # check if data had been previously built if not build_data.built(dpath, version_string=version): From c9ef957403aa780e3d707df3558c8cda9f6587bd Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Wed, 1 Dec 2021 12:14:34 -0800 Subject: [PATCH 43/57] make the multidogo test not take forever --- parlai/tasks/multidogo/agents.py | 2 + parlai/tasks/multidogo/test.py | 4 +- ...acher_multidogo-domains=software_test.yml} | 37 +++++++------- ...acher_multidogo-domains=software_train.yml | 46 +++++++++++++++++ ...acher_multidogo-domains=software_valid.yml | 49 ++++++++++++++++++ .../test/multidogo_SystemTeacher_train.yml | 49 ------------------ .../test/multidogo_SystemTeacher_valid.yml | 47 ----------------- ...eacher_multidogo-domains=software_test.yml | 51 +++++++++++++++++++ ...acher_multidogo-domains=software_train.yml | 49 ++++++++++++++++++ ...acher_multidogo-domains=software_valid.yml | 48 +++++++++++++++++ .../multidogo_UserSimulatorTeacher_test.yml | 46 ----------------- .../multidogo_UserSimulatorTeacher_train.yml | 51 ------------------- .../multidogo_UserSimulatorTeacher_valid.yml | 47 ----------------- 13 files changed, 267 insertions(+), 259 deletions(-) rename parlai/tasks/multidogo/test/{multidogo_SystemTeacher_test.yml => multidogo_SystemTeacher_multidogo-domains=software_test.yml} (53%) create mode 100644 parlai/tasks/multidogo/test/multidogo_SystemTeacher_multidogo-domains=software_train.yml create mode 100644 parlai/tasks/multidogo/test/multidogo_SystemTeacher_multidogo-domains=software_valid.yml delete mode 100644 parlai/tasks/multidogo/test/multidogo_SystemTeacher_train.yml delete mode 100644 parlai/tasks/multidogo/test/multidogo_SystemTeacher_valid.yml create mode 100644 parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_multidogo-domains=software_test.yml create mode 100644 parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_multidogo-domains=software_train.yml create mode 100644 parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_multidogo-domains=software_valid.yml delete mode 100644 parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_test.yml delete mode 100644 parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_train.yml delete mode 100644 parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_valid.yml diff --git a/parlai/tasks/multidogo/agents.py b/parlai/tasks/multidogo/agents.py index eaf9f3cff7a..9e9113dfe8b 100644 --- a/parlai/tasks/multidogo/agents.py +++ b/parlai/tasks/multidogo/agents.py @@ -62,6 +62,8 @@ def __init__(self, opt: Opt, shared=None): def setup_episodes(self, fold): result = [] domains = self.opt.get("multidogo_domains", DOMAINS) + if type(domains) is str: + domains = [domains] intent_type = self.opt.get("intent-type", TURN_INTENT) for _conv_id, domain, 
conversation in self._iterate_over_conversations( domains, intent_type diff --git a/parlai/tasks/multidogo/test.py b/parlai/tasks/multidogo/test.py index 9f3889a2263..b7f862f7074 100644 --- a/parlai/tasks/multidogo/test.py +++ b/parlai/tasks/multidogo/test.py @@ -8,8 +8,8 @@ class TestSystemTeacher(AutoTeacherTest): - task = "multidogo:SystemTeacher" + task = "multidogo:SystemTeacher:multidogo-domains=software" class TestUserSimulatorTeacher(AutoTeacherTest): - task = "multidogo:UserSimulatorTeacher" + task = "multidogo:UserSimulatorTeacher:multidogo-domains=software" diff --git a/parlai/tasks/multidogo/test/multidogo_SystemTeacher_test.yml b/parlai/tasks/multidogo/test/multidogo_SystemTeacher_multidogo-domains=software_test.yml similarity index 53% rename from parlai/tasks/multidogo/test/multidogo_SystemTeacher_test.yml rename to parlai/tasks/multidogo/test/multidogo_SystemTeacher_multidogo-domains=software_test.yml index ffe539d83ad..165410557f5 100644 --- a/parlai/tasks/multidogo/test/multidogo_SystemTeacher_test.yml +++ b/parlai/tasks/multidogo/test/multidogo_SystemTeacher_multidogo-domains=software_test.yml @@ -1,5 +1,5 @@ acts: -- - domain: airline +- - domain: software episode_done: false eval_labels: - 'APIS: ' @@ -7,41 +7,44 @@ acts: slots: {} text: 'APIS: ' type: 'APIS: ' -- - domain: airline +- - domain: software episode_done: false eval_labels: - - 'APICALL: api_name = airline' + - 'APICALL: api_name = software' id: Multidogo_SystemTeacher slots: - api_name: airline - text: 'USER: HELLO ROBIN' + api_name: software + text: 'USER: HELLO RORY' type: 'APICALL: ' -- - domain: airline +- - domain: software episode_done: false eval_labels: - - 'SYSTEM: Hello! Good morning. You''ve reached LMT Airways. How may I assist - you today?' + - 'SYSTEM: Hello! Welcome to Prodesk financial department. How may I help you? + + ' id: Multidogo_SystemTeacher slots: {} text: 'APIRESP: ' type: 'SYSTEM: ' -- - domain: airline +- - domain: software episode_done: false eval_labels: - - 'APICALL: api_name = airline' + - 'APICALL: api_name = software' id: Multidogo_SystemTeacher slots: - api_name: airline - text: 'USER: I NEED BOARDING PASS ' + api_name: software + text: 'USER: I WANT REIMBURSE.' type: 'APICALL: ' -- - domain: airline +- - domain: software episode_done: false eval_labels: - - 'SYSTEM: Awesome! I''d be glad to help you with that. May I know your last name - please?' + - 'SYSTEM: Sure! I''ll help you with that. Could you please help me with your + name and the password? + + ' id: Multidogo_SystemTeacher slots: {} text: 'APIRESP: ' type: 'SYSTEM: ' -num_episodes: 2316 -num_examples: 43104 +num_episodes: 155 +num_examples: 2604 diff --git a/parlai/tasks/multidogo/test/multidogo_SystemTeacher_multidogo-domains=software_train.yml b/parlai/tasks/multidogo/test/multidogo_SystemTeacher_multidogo-domains=software_train.yml new file mode 100644 index 00000000000..a5d24a99736 --- /dev/null +++ b/parlai/tasks/multidogo/test/multidogo_SystemTeacher_multidogo-domains=software_train.yml @@ -0,0 +1,46 @@ +acts: +- - domain: software + episode_done: false + id: Multidogo_SystemTeacher + labels: + - 'APIS: ' + slots: {} + text: 'APIS: ' + type: 'APIS: ' +- - domain: software + episode_done: false + id: Multidogo_SystemTeacher + labels: + - 'APICALL: api_name = software' + slots: + api_name: software + text: 'USER: HI' + type: 'APICALL: ' +- - domain: software + episode_done: false + id: Multidogo_SystemTeacher + labels: + - 'SYSTEM: Hello! Welcome to Music World! How may I help you today?' 
+ slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +- - domain: software + episode_done: false + id: Multidogo_SystemTeacher + labels: + - 'APICALL: api_name = software' + slots: + api_name: software + text: 'USER: I can set up new recurring orders' + type: 'APICALL: ' +- - domain: software + episode_done: false + id: Multidogo_SystemTeacher + labels: + - 'SYSTEM: Sure! Could you please help me with your company name and one time + password?' + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +num_episodes: 560 +num_examples: 9506 diff --git a/parlai/tasks/multidogo/test/multidogo_SystemTeacher_multidogo-domains=software_valid.yml b/parlai/tasks/multidogo/test/multidogo_SystemTeacher_multidogo-domains=software_valid.yml new file mode 100644 index 00000000000..2986f1c6672 --- /dev/null +++ b/parlai/tasks/multidogo/test/multidogo_SystemTeacher_multidogo-domains=software_valid.yml @@ -0,0 +1,49 @@ +acts: +- - domain: software + episode_done: false + eval_labels: + - 'APIS: ' + id: Multidogo_SystemTeacher + slots: {} + text: 'APIS: ' + type: 'APIS: ' +- - domain: software + episode_done: false + eval_labels: + - 'APICALL: api_name = software' + id: Multidogo_SystemTeacher + slots: + api_name: software + text: 'USER: __SILENCE__' + type: 'APICALL: ' +- - domain: software + episode_done: false + eval_labels: + - 'SYSTEM: Hello! Good morning, You''re connected to finance department of ATA. + How may I help you today?' + id: Multidogo_SystemTeacher + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +- - domain: software + episode_done: false + eval_labels: + - 'APICALL: api_name = software' + id: Multidogo_SystemTeacher + slots: + api_name: software + text: 'USER: Good morning. + + I like to report my travel expenses.' + type: 'APICALL: ' +- - domain: software + episode_done: false + eval_labels: + - 'SYSTEM: Definitely! I''ll make a record of your expenses, May I know your name + and password? ' + id: Multidogo_SystemTeacher + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +num_episodes: 67 +num_examples: 1136 diff --git a/parlai/tasks/multidogo/test/multidogo_SystemTeacher_train.yml b/parlai/tasks/multidogo/test/multidogo_SystemTeacher_train.yml deleted file mode 100644 index b3e7a1a7ee4..00000000000 --- a/parlai/tasks/multidogo/test/multidogo_SystemTeacher_train.yml +++ /dev/null @@ -1,49 +0,0 @@ -acts: -- - domain: airline - episode_done: false - id: Multidogo_SystemTeacher - labels: - - 'APIS: ' - slots: {} - text: 'APIS: ' - type: 'APIS: ' -- - domain: airline - episode_done: false - id: Multidogo_SystemTeacher - labels: - - 'APICALL: api_name = airline' - slots: - api_name: airline - text: 'USER: __SILENCE__' - type: 'APICALL: ' -- - domain: airline - episode_done: false - id: Multidogo_SystemTeacher - labels: - - 'SYSTEM: Welcome to High flying customer service! You''re connected to our customer - associate! Good morning! My name is Sam, How may I help you?' - slots: {} - text: 'APIRESP: ' - type: 'SYSTEM: ' -- - domain: airline - episode_done: false - id: Multidogo_SystemTeacher - labels: - - 'APICALL: api_name = airline' - slots: - api_name: airline - text: 'USER: HI,GOOD MORNING - - I WANTS TO BOOK A TICKET FOR FLIGHT' - type: 'APICALL: ' -- - domain: airline - episode_done: false - id: Multidogo_SystemTeacher - labels: - - 'SYSTEM: Absolutely!! I''d be happy to book your tickets, I''d request to please - share your details with me. May I know your full name please?' 
- slots: {} - text: 'APIRESP: ' - type: 'SYSTEM: ' -num_episodes: 8180 -num_examples: 151916 diff --git a/parlai/tasks/multidogo/test/multidogo_SystemTeacher_valid.yml b/parlai/tasks/multidogo/test/multidogo_SystemTeacher_valid.yml deleted file mode 100644 index 8b584c36e72..00000000000 --- a/parlai/tasks/multidogo/test/multidogo_SystemTeacher_valid.yml +++ /dev/null @@ -1,47 +0,0 @@ -acts: -- - domain: airline - episode_done: false - eval_labels: - - 'APIS: ' - id: Multidogo_SystemTeacher - slots: {} - text: 'APIS: ' - type: 'APIS: ' -- - domain: airline - episode_done: false - eval_labels: - - 'APICALL: api_name = airline' - id: Multidogo_SystemTeacher - slots: - api_name: airline - text: 'USER: HI GOOD MORNING' - type: 'APICALL: ' -- - domain: airline - episode_done: false - eval_labels: - - 'SYSTEM: Welcome to High flying customer service! You''re connected to our customer - associate! Good morning! My name is Sam, How may I help you?' - id: Multidogo_SystemTeacher - slots: {} - text: 'APIRESP: ' - type: 'SYSTEM: ' -- - domain: airline - episode_done: false - eval_labels: - - 'APICALL: api_name = airline' - id: Multidogo_SystemTeacher - slots: - api_name: airline - text: 'USER: I WANT TO BOOK A TICKET IN FLIGHT' - type: 'APICALL: ' -- - domain: airline - episode_done: false - eval_labels: - - 'SYSTEM: Absolutely!! I''d be happy to book your tickets, I''d request to please - share your details with me. May I know your full name please?' - id: Multidogo_SystemTeacher - slots: {} - text: 'APIRESP: ' - type: 'SYSTEM: ' -num_episodes: 1148 -num_examples: 21378 diff --git a/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_multidogo-domains=software_test.yml b/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_multidogo-domains=software_test.yml new file mode 100644 index 00000000000..710e0cf5f2d --- /dev/null +++ b/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_multidogo-domains=software_test.yml @@ -0,0 +1,51 @@ +acts: +- - domain: software + episode_done: false + eval_labels: + - 'USER: HELLO RORY' + id: Multidogo_UserSimulatorTeacher + text: 'GOAL: api_name = software ; cost = 65 ; name = david ; password = 44685' + type: 'USER: ' +- - domain: software + episode_done: false + eval_labels: + - 'USER: I WANT REIMBURSE.' + id: Multidogo_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Hello! Welcome to Prodesk financial department. How may I help + you? + + ' + type: 'USER: ' +- - domain: software + episode_done: false + eval_labels: + - 'USER: OKAY.MY NAME IS JOHN DAVID AND THE PASSWORD IS 44685' + id: Multidogo_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Sure! I''ll help you with that. Could you please help me with your + name and the password? + + ' + type: 'USER: ' +- - domain: software + episode_done: false + eval_labels: + - 'USER: MAY I ATTACH THE BILLS OF MY EXPENSES.' + id: Multidogo_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Thank you for sharing the details. Could you please let me know + the cost of the expenses, so that I can update that for you? + + ' + type: 'USER: ' +- - domain: software + episode_done: false + eval_labels: + - 'USER: OKAY THE TOTAL COST IS $355.' + id: Multidogo_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: You should do that. But, please let me know the cost of the expenses.' 
+ type: 'USER: ' +num_episodes: 155 +num_examples: 1302 diff --git a/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_multidogo-domains=software_train.yml b/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_multidogo-domains=software_train.yml new file mode 100644 index 00000000000..1f59f5bbae7 --- /dev/null +++ b/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_multidogo-domains=software_train.yml @@ -0,0 +1,49 @@ +acts: +- - domain: software + episode_done: false + id: Multidogo_UserSimulatorTeacher + labels: + - 'USER: HI' + text: 'GOAL: api_name = software ; company_name = yamaha ; password = 56234 ; + quantity = 1' + type: 'USER: ' +- - domain: software + episode_done: false + id: Multidogo_UserSimulatorTeacher + labels: + - 'USER: I can set up new recurring orders' + slots: {} + text: 'SYSTEM: Hello! Welcome to Music World! How may I help you today?' + type: 'USER: ' +- - domain: software + episode_done: false + id: Multidogo_UserSimulatorTeacher + labels: + - 'USER: YAMAHA + + 56234' + slots: {} + text: 'SYSTEM: Sure! Could you please help me with your company name and one time + password?' + type: 'USER: ' +- - domain: software + episode_done: false + id: Multidogo_UserSimulatorTeacher + labels: + - 'USER: YAMAHA + + PSR-E453' + slots: {} + text: 'SYSTEM: Thank you! Could you please let me know the model name of the keyboard + you would like to purchase? ' + type: 'USER: ' +- - domain: software + episode_done: false + id: Multidogo_UserSimulatorTeacher + labels: + - 'USER: 1' + slots: {} + text: 'SYSTEM: Perfect! How many sets would you want to order?' + type: 'USER: ' +num_episodes: 560 +num_examples: 4753 diff --git a/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_multidogo-domains=software_valid.yml b/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_multidogo-domains=software_valid.yml new file mode 100644 index 00000000000..7937a547b45 --- /dev/null +++ b/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_multidogo-domains=software_valid.yml @@ -0,0 +1,48 @@ +acts: +- - domain: software + episode_done: false + eval_labels: + - 'USER: __SILENCE__' + id: Multidogo_UserSimulatorTeacher + text: 'GOAL: api_name = software ; cost = 42 ; expense_type = flight ; name = + hagam ; password = 8167' + type: 'USER: ' +- - domain: software + episode_done: false + eval_labels: + - 'USER: Good morning. + + I like to report my travel expenses.' + id: Multidogo_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Hello! Good morning, You''re connected to finance department of + ATA. How may I help you today?' + type: 'USER: ' +- - domain: software + episode_done: false + eval_labels: + - 'USER: Name: Hagam and Password:8167.' + id: Multidogo_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Definitely! I''ll make a record of your expenses, May I know your + name and password? ' + type: 'USER: ' +- - domain: software + episode_done: false + eval_labels: + - 'USER: My total expenses over $197.. Hotel: 65, Food: 90 and Flight:$42.' + id: Multidogo_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Thank you for helping me with your details. Could you please share + the total sum of money that you want to be reimbursed? ' + type: 'USER: ' +- - domain: software + episode_done: false + eval_labels: + - 'USER: I have attached the attachment.' + id: Multidogo_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Okay. Could you please share the snap shot of the receipts here? 
' + type: 'USER: ' +num_episodes: 67 +num_examples: 568 diff --git a/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_test.yml b/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_test.yml deleted file mode 100644 index 8c408905467..00000000000 --- a/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_test.yml +++ /dev/null @@ -1,46 +0,0 @@ -acts: -- - domain: airline - episode_done: false - eval_labels: - - 'USER: HELLO ROBIN' - id: Multidogo_UserSimulatorTeacher - text: 'GOAL: api_name = airline ; booking_confirmation_number = 523 ; email_address - = gmailcom ; name = mohan' - type: 'USER: ' -- - domain: airline - episode_done: false - eval_labels: - - 'USER: I NEED BOARDING PASS ' - id: Multidogo_UserSimulatorTeacher - slots: {} - text: 'SYSTEM: Hello! Good morning. You''ve reached LMT Airways. How may I assist - you today?' - type: 'USER: ' -- - domain: airline - episode_done: false - eval_labels: - - 'USER: MOHAN' - id: Multidogo_UserSimulatorTeacher - slots: {} - text: 'SYSTEM: Awesome! I''d be glad to help you with that. May I know your last - name please?' - type: 'USER: ' -- - domain: airline - episode_done: false - eval_labels: - - 'USER: CONFIRMATION NUMBER : moh523' - id: Multidogo_UserSimulatorTeacher - slots: {} - text: 'SYSTEM: Alright Mohan! Could you please share the booking confirmation - number?' - type: 'USER: ' -- - domain: airline - episode_done: false - eval_labels: - - 'USER: Mohan283@gmail.com' - id: Multidogo_UserSimulatorTeacher - slots: {} - text: 'SYSTEM: Great! May I have your email address please?' - type: 'USER: ' -num_episodes: 2316 -num_examples: 43104 diff --git a/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_train.yml b/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_train.yml deleted file mode 100644 index 6b558483d2f..00000000000 --- a/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_train.yml +++ /dev/null @@ -1,51 +0,0 @@ -acts: -- - domain: airline - episode_done: false - id: Multidogo_UserSimulatorTeacher - labels: - - 'USER: __SILENCE__' - text: 'GOAL: api_name = airline ; arrival_city = singapore ; departure_city = - thailand ; email_address = kavigmailcom ; name = kavisri ; number_of_passengers - = five' - type: 'USER: ' -- - domain: airline - episode_done: false - id: Multidogo_UserSimulatorTeacher - labels: - - 'USER: HI,GOOD MORNING - - I WANTS TO BOOK A TICKET FOR FLIGHT' - slots: {} - text: 'SYSTEM: Welcome to High flying customer service! You''re connected to our - customer associate! Good morning! My name is Sam, How may I help you?' - type: 'USER: ' -- - domain: airline - episode_done: false - id: Multidogo_UserSimulatorTeacher - labels: - - 'USER: KAVISRI' - slots: {} - text: 'SYSTEM: Absolutely!! I''d be happy to book your tickets, I''d request to - please share your details with me. May I know your full name please?' - type: 'USER: ' -- - domain: airline - episode_done: false - id: Multidogo_UserSimulatorTeacher - labels: - - 'USER: THAILAND AND SINGAPORE' - slots: {} - text: 'SYSTEM: It''s nice meeting you kavisri! Could you please share your departure - and arrival city?' - type: 'USER: ' -- - domain: airline - episode_done: false - id: Multidogo_UserSimulatorTeacher - labels: - - 'USER: OK' - slots: {} - text: 'SYSTEM: Perfect! I hope you enjoy the trip! As I''ve checked with my system - here, there is one flight of Jet airways operating on 09/20/2018, The timings - are, 6:00 Am to 8:00 Am and it is costing you $170 per head. 
' - type: 'USER: ' -num_episodes: 8180 -num_examples: 151916 diff --git a/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_valid.yml b/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_valid.yml deleted file mode 100644 index 36d321a9557..00000000000 --- a/parlai/tasks/multidogo/test/multidogo_UserSimulatorTeacher_valid.yml +++ /dev/null @@ -1,47 +0,0 @@ -acts: -- - domain: airline - episode_done: false - eval_labels: - - 'USER: HI GOOD MORNING' - id: Multidogo_UserSimulatorTeacher - text: 'GOAL: api_name = airline ; arrival_city = chennai ; departure_city = mumbai - ; email_address = gmailcom ; name = viswa ; start_date = 92018' - type: 'USER: ' -- - domain: airline - episode_done: false - eval_labels: - - 'USER: I WANT TO BOOK A TICKET IN FLIGHT' - id: Multidogo_UserSimulatorTeacher - slots: {} - text: 'SYSTEM: Welcome to High flying customer service! You''re connected to our - customer associate! Good morning! My name is Sam, How may I help you?' - type: 'USER: ' -- - domain: airline - episode_done: false - eval_labels: - - 'USER: VISWA ' - id: Multidogo_UserSimulatorTeacher - slots: {} - text: 'SYSTEM: Absolutely!! I''d be happy to book your tickets, I''d request to - please share your details with me. May I know your full name please?' - type: 'USER: ' -- - domain: airline - episode_done: false - eval_labels: - - 'USER: MUMBAI TO CHENNAI' - id: Multidogo_UserSimulatorTeacher - slots: {} - text: 'SYSTEM: Great! It''s nice meeting you Viswa! Could you please share your - departure and arrival city?' - type: 'USER: ' -- - domain: airline - episode_done: false - eval_labels: - - 'USER: 09/20/2018' - id: Multidogo_UserSimulatorTeacher - slots: {} - text: 'SYSTEM: That''s amazing! I recently visited chennai, It''s such a beautiful - place! I hope you enjoy the trip! May I know your preferred date please?' - type: 'USER: ' -num_episodes: 1148 -num_examples: 21378 From 7ab9d70476be5831aca05322ab1d2954aac0f698 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Wed, 1 Dec 2021 12:23:53 -0800 Subject: [PATCH 44/57] update to respect actual count of episodes (I think this might have implications for mutators too?) --- parlai/core/teachers.py | 2 ++ parlai/core/tod/tod_agents.py | 11 +++++++++++ 2 files changed, 13 insertions(+) diff --git a/parlai/core/teachers.py b/parlai/core/teachers.py index 2b3c1d71741..1f8829324db 100644 --- a/parlai/core/teachers.py +++ b/parlai/core/teachers.py @@ -726,6 +726,8 @@ def num_episodes(self) -> int: """ Return the number of episodes in the data. """ + if hasattr(self, "_num_episodes_cache"): + return self._num_episodes_cache try: return self.data.num_episodes() except AttributeError: diff --git a/parlai/core/tod/tod_agents.py b/parlai/core/tod/tod_agents.py index 80d9012a1f2..0dd26cd270e 100644 --- a/parlai/core/tod/tod_agents.py +++ b/parlai/core/tod/tod_agents.py @@ -590,6 +590,11 @@ def add_cmdline_args( ) return parser + def __init__(self, opt, shared=None): + super().__init__(opt, shared) + self._num_examples_cache = sum([len(x.rounds) * 2 for x in self.episodes]) + self._num_episodes_cache = len(self.episodes) + def custom_evaluation( self, teacher_action: Message, labels, model_response: Message ): @@ -671,6 +676,12 @@ class TodUserSimulatorTeacher(TodStructuredDataParser, DialogTeacher): Utterance->User Utterance` for all subsequent turns. 
""" + def __init__(self, opt, shared=None): + super().__init__(opt, shared) + # Manually set number of examples + number of episodes + self._num_examples_cache = sum([len(x.rounds) for x in self.episodes]) + self._num_episodes_cache = len(self.episodes) + def setup_data(self, fold): for episode in self.generate_episodes(): if len(episode.rounds) < 1: From 94661448616a94dcfbe5f659eeaebbd7f6000c34 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Wed, 1 Dec 2021 22:30:53 -0800 Subject: [PATCH 45/57] regen after changing tod teacher logic to respect episode/examples length --- .../google_sgd/test/google_sgd_UserSimulatorTeacher_test.yml | 2 +- .../google_sgd/test/google_sgd_UserSimulatorTeacher_train.yml | 2 +- .../google_sgd/test/google_sgd_UserSimulatorTeacher_valid.yml | 2 +- parlai/tasks/google_sgd/test/google_sgd_test.yml | 2 +- parlai/tasks/google_sgd/test/google_sgd_train.yml | 2 +- parlai/tasks/google_sgd/test/google_sgd_valid.yml | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_test.yml b/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_test.yml index dad04075246..8ea369e9cad 100644 --- a/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_test.yml +++ b/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_test.yml @@ -48,4 +48,4 @@ acts: Bar in Corte Madera at 12 pm for 2 on March 8th.' type: 'USER: ' num_episodes: 4201 -num_examples: 97197 +num_examples: 46498 diff --git a/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_train.yml b/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_train.yml index fca91a3b09f..d4ea92d7d1b 100644 --- a/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_train.yml +++ b/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_train.yml @@ -46,4 +46,4 @@ acts: San Pedro Street.' type: 'USER: ' num_episodes: 16142 -num_examples: 378390 +num_examples: 181124 diff --git a/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_valid.yml b/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_valid.yml index e9a511d9519..59d81a949c9 100644 --- a/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_valid.yml +++ b/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_valid.yml @@ -43,4 +43,4 @@ acts: options.' 
type: 'USER: ' num_episodes: 2482 -num_examples: 56172 +num_examples: 26845 diff --git a/parlai/tasks/google_sgd/test/google_sgd_test.yml b/parlai/tasks/google_sgd/test/google_sgd_test.yml index 4d8e4e2578e..d55849a814b 100644 --- a/parlai/tasks/google_sgd/test/google_sgd_test.yml +++ b/parlai/tasks/google_sgd/test/google_sgd_test.yml @@ -42,4 +42,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 4201 -num_examples: 97197 +num_examples: 92996 diff --git a/parlai/tasks/google_sgd/test/google_sgd_train.yml b/parlai/tasks/google_sgd/test/google_sgd_train.yml index 596025f82da..584b6a56561 100644 --- a/parlai/tasks/google_sgd/test/google_sgd_train.yml +++ b/parlai/tasks/google_sgd/test/google_sgd_train.yml @@ -42,4 +42,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 16142 -num_examples: 378390 +num_examples: 362248 diff --git a/parlai/tasks/google_sgd/test/google_sgd_valid.yml b/parlai/tasks/google_sgd/test/google_sgd_valid.yml index 2cd09f89ce7..3d58a2b919d 100644 --- a/parlai/tasks/google_sgd/test/google_sgd_valid.yml +++ b/parlai/tasks/google_sgd/test/google_sgd_valid.yml @@ -42,4 +42,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 2482 -num_examples: 56172 +num_examples: 53690 From 1392d99676d69f8c72e8874ad4a0c1bd464b54a8 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Wed, 1 Dec 2021 22:33:11 -0800 Subject: [PATCH 46/57] regen after changing tod teacher logic to respect episode/examples length --- ...oogle_sgd_simulation_splits_InDomainSystemTeacher_test.yml | 2 +- ...ogle_sgd_simulation_splits_InDomainSystemTeacher_train.yml | 2 +- ...ogle_sgd_simulation_splits_InDomainSystemTeacher_valid.yml | 2 +- ...gd_simulation_splits_InDomainUserSimulatorTeacher_test.yml | 2 +- ...d_simulation_splits_InDomainUserSimulatorTeacher_train.yml | 2 +- ...d_simulation_splits_InDomainUserSimulatorTeacher_valid.yml | 2 +- ...ogle_sgd_simulation_splits_OutDomainSystemTeacher_test.yml | 4 ++-- ...gle_sgd_simulation_splits_OutDomainSystemTeacher_train.yml | 4 ++-- ...gle_sgd_simulation_splits_OutDomainSystemTeacher_valid.yml | 4 ++-- ...d_simulation_splits_OutDomainUserSimulatorTeacher_test.yml | 4 ++-- ..._simulation_splits_OutDomainUserSimulatorTeacher_train.yml | 4 ++-- ..._simulation_splits_OutDomainUserSimulatorTeacher_valid.yml | 4 ++-- 12 files changed, 18 insertions(+), 18 deletions(-) diff --git a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_test.yml b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_test.yml index d61223204be..0edd1c96ed7 100644 --- a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_test.yml +++ b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_test.yml @@ -42,4 +42,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 3132 -num_examples: 67286 +num_examples: 64154 diff --git a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_train.yml b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_train.yml index 08800de222d..8177b877b5a 100644 --- a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_train.yml +++ b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_train.yml @@ -42,4 +42,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 13888 -num_examples: 
320622 +num_examples: 306734 diff --git a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_valid.yml b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_valid.yml index 92f565fea21..444c6b1068d 100644 --- a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_valid.yml +++ b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_valid.yml @@ -42,4 +42,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 1966 -num_examples: 42808 +num_examples: 40842 diff --git a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainUserSimulatorTeacher_test.yml b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainUserSimulatorTeacher_test.yml index 2934f561f8c..de663c3c02a 100644 --- a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainUserSimulatorTeacher_test.yml +++ b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainUserSimulatorTeacher_test.yml @@ -48,4 +48,4 @@ acts: Bar in Corte Madera at 12 pm for 2 on March 8th.' type: 'USER: ' num_episodes: 3132 -num_examples: 67286 +num_examples: 32077 diff --git a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainUserSimulatorTeacher_train.yml b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainUserSimulatorTeacher_train.yml index 37257b324ef..4ad5df8583b 100644 --- a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainUserSimulatorTeacher_train.yml +++ b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainUserSimulatorTeacher_train.yml @@ -46,4 +46,4 @@ acts: San Pedro Street.' type: 'USER: ' num_episodes: 13888 -num_examples: 320622 +num_examples: 153367 diff --git a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainUserSimulatorTeacher_valid.yml b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainUserSimulatorTeacher_valid.yml index c6d888e6f57..cf7cbb3f0be 100644 --- a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainUserSimulatorTeacher_valid.yml +++ b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainUserSimulatorTeacher_valid.yml @@ -43,4 +43,4 @@ acts: options.' 
type: 'USER: ' num_episodes: 1966 -num_examples: 42808 +num_examples: 20421 diff --git a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_test.yml b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_test.yml index 4d9686d40f2..05f1a2c83e5 100644 --- a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_test.yml +++ b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_test.yml @@ -50,5 +50,5 @@ acts: Landmark ; free_entry = False ; good_for_kids = True ; location = London ; phone_number = 20 7071 5029' type: 'SYSTEM: ' -num_episodes: 3132 -num_examples: 67286 +num_episodes: 768 +num_examples: 19362 diff --git a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_train.yml b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_train.yml index 2792ba26440..85ed0d88000 100644 --- a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_train.yml +++ b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_train.yml @@ -59,5 +59,5 @@ acts: ; end_date = 2019-03-07 ; pickup_location = Downtown Station ; pickup_time = 11:00 ; price_per_day = 38.00 ; start_date = 2019-03-04' type: 'SYSTEM: ' -num_episodes: 13888 -num_examples: 320622 +num_episodes: 2303 +num_examples: 58286 diff --git a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_valid.yml b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_valid.yml index be07f99ea7e..cd166a17652 100644 --- a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_valid.yml +++ b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_valid.yml @@ -39,5 +39,5 @@ acts: slots: {} text: 'APIRESP: ' type: 'SYSTEM: ' -num_episodes: 1966 -num_examples: 42808 +num_episodes: 768 +num_examples: 19556 diff --git a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainUserSimulatorTeacher_test.yml b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainUserSimulatorTeacher_test.yml index 9ba470ac70d..2bf74be76b0 100644 --- a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainUserSimulatorTeacher_test.yml +++ b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainUserSimulatorTeacher_test.yml @@ -48,5 +48,5 @@ acts: slots: {} text: 'SYSTEM: Alright, will you be picking it up in London?' 
type: 'USER: ' -num_episodes: 3132 -num_examples: 67286 +num_episodes: 768 +num_examples: 9681 diff --git a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainUserSimulatorTeacher_train.yml b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainUserSimulatorTeacher_train.yml index 289b95c3333..500e6bfba3e 100644 --- a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainUserSimulatorTeacher_train.yml +++ b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainUserSimulatorTeacher_train.yml @@ -45,5 +45,5 @@ acts: Station ; pickup_time = 16:30 ; type = Standard | api_name = FindAttractions ; location = Fresno' type: 'USER: ' -num_episodes: 13888 -num_examples: 320622 +num_episodes: 2303 +num_examples: 29143 diff --git a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainUserSimulatorTeacher_valid.yml b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainUserSimulatorTeacher_valid.yml index d63e986ddb5..995be3d551a 100644 --- a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainUserSimulatorTeacher_valid.yml +++ b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainUserSimulatorTeacher_valid.yml @@ -43,5 +43,5 @@ acts: text: 'SYSTEM: I have 10 properties that might suit, including Apricot pit apartments 400 east remington drive. The listing price is $3,650,000' type: 'USER: ' -num_episodes: 1966 -num_examples: 42808 +num_episodes: 768 +num_examples: 9778 From 77dccb7c887b82cad91e73d757c0a4fd9642346d Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Wed, 1 Dec 2021 22:34:35 -0800 Subject: [PATCH 47/57] regen after changing tod teacher logic to respect episode/examples length --- parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_test.yml | 2 +- parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_train.yml | 2 +- parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_valid.yml | 2 +- parlai/tasks/msr_e2e/test/msr_e2e_UserSimulatorTeacher_test.yml | 2 +- .../tasks/msr_e2e/test/msr_e2e_UserSimulatorTeacher_train.yml | 2 +- .../tasks/msr_e2e/test/msr_e2e_UserSimulatorTeacher_valid.yml | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_test.yml b/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_test.yml index 9fd946c1b75..17375135a9a 100644 --- a/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_test.yml +++ b/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_test.yml @@ -68,4 +68,4 @@ acts: text: 'APIRESP: api_name = taxi ; car_type = UberX ; closing = have a safe trip' type: 'SYSTEM: ' num_episodes: 1011 -num_examples: 10705 +num_examples: 9694 diff --git a/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_train.yml b/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_train.yml index fd643a0d2cb..52d60c1740b 100644 --- a/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_train.yml +++ b/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_train.yml @@ -53,4 +53,4 @@ acts: text: 'APIRESP: api_name = taxi ; car_type = uberX ; cost = $37-48' type: 'SYSTEM: ' num_episodes: 8068 -num_examples: 84524 +num_examples: 76456 diff --git a/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_valid.yml b/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_valid.yml index aec69b50c10..d58f9879cbd 100644 --- a/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_valid.yml +++ b/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_valid.yml 
@@ -52,4 +52,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 1008 -num_examples: 10572 +num_examples: 9564 diff --git a/parlai/tasks/msr_e2e/test/msr_e2e_UserSimulatorTeacher_test.yml b/parlai/tasks/msr_e2e/test/msr_e2e_UserSimulatorTeacher_test.yml index 47d38c225e3..a03425c3f37 100644 --- a/parlai/tasks/msr_e2e/test/msr_e2e_UserSimulatorTeacher_test.yml +++ b/parlai/tasks/msr_e2e/test/msr_e2e_UserSimulatorTeacher_test.yml @@ -47,4 +47,4 @@ acts: Can I help you with something else?' type: 'USER: ' num_episodes: 1011 -num_examples: 10705 +num_examples: 4847 diff --git a/parlai/tasks/msr_e2e/test/msr_e2e_UserSimulatorTeacher_train.yml b/parlai/tasks/msr_e2e/test/msr_e2e_UserSimulatorTeacher_train.yml index 8046f82b47e..30db1448a08 100644 --- a/parlai/tasks/msr_e2e/test/msr_e2e_UserSimulatorTeacher_train.yml +++ b/parlai/tasks/msr_e2e/test/msr_e2e_UserSimulatorTeacher_train.yml @@ -45,4 +45,4 @@ acts: state = California ; theater = any theater' type: 'USER: ' num_episodes: 8068 -num_examples: 84524 +num_examples: 38228 diff --git a/parlai/tasks/msr_e2e/test/msr_e2e_UserSimulatorTeacher_valid.yml b/parlai/tasks/msr_e2e/test/msr_e2e_UserSimulatorTeacher_valid.yml index 32ff24e5502..ebbb75d407b 100644 --- a/parlai/tasks/msr_e2e/test/msr_e2e_UserSimulatorTeacher_valid.yml +++ b/parlai/tasks/msr_e2e/test/msr_e2e_UserSimulatorTeacher_valid.yml @@ -45,4 +45,4 @@ acts: = taxi ; numberofpeople = 4 | api_name = taxi ; car_type = UberX' type: 'USER: ' num_episodes: 1008 -num_examples: 10572 +num_examples: 4782 From 98aa5f719e0ff1f69611720c9125173db64e21b8 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Wed, 1 Dec 2021 22:39:56 -0800 Subject: [PATCH 48/57] regen after changing tod teacher logic to respect episode/examples length --- .../test/multiwoz_v22_UserSimulatorTeacher_test.yml | 2 +- .../test/multiwoz_v22_UserSimulatorTeacher_train.yml | 2 +- .../test/multiwoz_v22_UserSimulatorTeacher_valid.yml | 2 +- parlai/tasks/multiwoz_v22/test/multiwoz_v22_test.yml | 2 +- parlai/tasks/multiwoz_v22/test/multiwoz_v22_train.yml | 2 +- parlai/tasks/multiwoz_v22/test/multiwoz_v22_valid.yml | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/parlai/tasks/multiwoz_v22/test/multiwoz_v22_UserSimulatorTeacher_test.yml b/parlai/tasks/multiwoz_v22/test/multiwoz_v22_UserSimulatorTeacher_test.yml index 29fee769b4e..41ace2b678b 100644 --- a/parlai/tasks/multiwoz_v22/test/multiwoz_v22_UserSimulatorTeacher_test.yml +++ b/parlai/tasks/multiwoz_v22/test/multiwoz_v22_UserSimulatorTeacher_test.yml @@ -47,4 +47,4 @@ acts: would you like on it?' type: 'USER: ' num_episodes: 1000 -num_examples: 17744 +num_examples: 8372 diff --git a/parlai/tasks/multiwoz_v22/test/multiwoz_v22_UserSimulatorTeacher_train.yml b/parlai/tasks/multiwoz_v22/test/multiwoz_v22_UserSimulatorTeacher_train.yml index 2aca735580f..e59321f4ed5 100644 --- a/parlai/tasks/multiwoz_v22/test/multiwoz_v22_UserSimulatorTeacher_train.yml +++ b/parlai/tasks/multiwoz_v22/test/multiwoz_v22_UserSimulatorTeacher_train.yml @@ -48,4 +48,4 @@ acts: text: 'SYSTEM: Sure, when would you like that reservation?' 
type: 'USER: ' num_episodes: 7913 -num_examples: 133963 +num_examples: 63025 diff --git a/parlai/tasks/multiwoz_v22/test/multiwoz_v22_UserSimulatorTeacher_valid.yml b/parlai/tasks/multiwoz_v22/test/multiwoz_v22_UserSimulatorTeacher_valid.yml index 802f5e89ebb..1853cab50a3 100644 --- a/parlai/tasks/multiwoz_v22/test/multiwoz_v22_UserSimulatorTeacher_valid.yml +++ b/parlai/tasks/multiwoz_v22/test/multiwoz_v22_UserSimulatorTeacher_valid.yml @@ -47,4 +47,4 @@ acts: text: 'SYSTEM: I have train TR1840 leaving at 16:36 is that okay?' type: 'USER: ' num_episodes: 999 -num_examples: 17731 +num_examples: 8366 diff --git a/parlai/tasks/multiwoz_v22/test/multiwoz_v22_test.yml b/parlai/tasks/multiwoz_v22/test/multiwoz_v22_test.yml index 09c0e97c40d..3b44eb8ecf9 100644 --- a/parlai/tasks/multiwoz_v22/test/multiwoz_v22_test.yml +++ b/parlai/tasks/multiwoz_v22/test/multiwoz_v22_test.yml @@ -63,4 +63,4 @@ acts: "leaveat": "05:16", "price": "17.60 pounds", "trainid": "TR9020"}]' type: 'SYSTEM: ' num_episodes: 1000 -num_examples: 17744 +num_examples: 16744 diff --git a/parlai/tasks/multiwoz_v22/test/multiwoz_v22_train.yml b/parlai/tasks/multiwoz_v22/test/multiwoz_v22_train.yml index 4b0217a42bd..0d7b74f5d23 100644 --- a/parlai/tasks/multiwoz_v22/test/multiwoz_v22_train.yml +++ b/parlai/tasks/multiwoz_v22/test/multiwoz_v22_train.yml @@ -70,4 +70,4 @@ acts: "restaurant", "signature": NaN}]' type: 'SYSTEM: ' num_episodes: 7913 -num_examples: 133963 +num_examples: 126050 diff --git a/parlai/tasks/multiwoz_v22/test/multiwoz_v22_valid.yml b/parlai/tasks/multiwoz_v22/test/multiwoz_v22_valid.yml index 5228dce4c2f..376bfda659c 100644 --- a/parlai/tasks/multiwoz_v22/test/multiwoz_v22_valid.yml +++ b/parlai/tasks/multiwoz_v22/test/multiwoz_v22_valid.yml @@ -56,4 +56,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 999 -num_examples: 17731 +num_examples: 16732 From 8291321ba2b4c0029eed4ea696b1c22e1460872f Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Wed, 1 Dec 2021 22:44:44 -0800 Subject: [PATCH 49/57] regen after changing tod teacher logic to respect episode/examples length --- parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_test.yml | 2 +- parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_train.yml | 2 +- parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_valid.yml | 2 +- .../taskmaster/test/taskmaster_UserSimulatorTeacher_test.yml | 2 +- .../taskmaster/test/taskmaster_UserSimulatorTeacher_train.yml | 2 +- .../taskmaster/test/taskmaster_UserSimulatorTeacher_valid.yml | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_test.yml b/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_test.yml index 998350f1020..974c39f1d83 100644 --- a/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_test.yml +++ b/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_test.yml @@ -42,4 +42,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 1326 -num_examples: 33198 +num_examples: 31872 diff --git a/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_train.yml b/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_train.yml index 417c226d0e4..923c0549f92 100644 --- a/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_train.yml +++ b/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_train.yml @@ -46,4 +46,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 10555 -num_examples: 262021 +num_examples: 251466 diff --git a/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_valid.yml 
b/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_valid.yml index a541abf7d1f..a58c16ba4bd 100644 --- a/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_valid.yml +++ b/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_valid.yml @@ -47,4 +47,4 @@ acts: text: 'APIRESP: api_name = coffee_ordering ; location.store = hilton knoxville' type: 'SYSTEM: ' num_episodes: 1321 -num_examples: 32641 +num_examples: 31320 diff --git a/parlai/tasks/taskmaster/test/taskmaster_UserSimulatorTeacher_test.yml b/parlai/tasks/taskmaster/test/taskmaster_UserSimulatorTeacher_test.yml index fcafb66b9f8..a32b210fccb 100644 --- a/parlai/tasks/taskmaster/test/taskmaster_UserSimulatorTeacher_test.yml +++ b/parlai/tasks/taskmaster/test/taskmaster_UserSimulatorTeacher_test.yml @@ -54,4 +54,4 @@ acts: would you like to book this right now?' type: 'USER: ' num_episodes: 1326 -num_examples: 33198 +num_examples: 15936 diff --git a/parlai/tasks/taskmaster/test/taskmaster_UserSimulatorTeacher_train.yml b/parlai/tasks/taskmaster/test/taskmaster_UserSimulatorTeacher_train.yml index fc424778230..5317354dea8 100644 --- a/parlai/tasks/taskmaster/test/taskmaster_UserSimulatorTeacher_train.yml +++ b/parlai/tasks/taskmaster/test/taskmaster_UserSimulatorTeacher_train.yml @@ -43,4 +43,4 @@ acts: up?' type: 'USER: ' num_episodes: 10555 -num_examples: 262021 +num_examples: 125733 diff --git a/parlai/tasks/taskmaster/test/taskmaster_UserSimulatorTeacher_valid.yml b/parlai/tasks/taskmaster/test/taskmaster_UserSimulatorTeacher_valid.yml index 3b8f69cb4c2..d4e36117612 100644 --- a/parlai/tasks/taskmaster/test/taskmaster_UserSimulatorTeacher_valid.yml +++ b/parlai/tasks/taskmaster/test/taskmaster_UserSimulatorTeacher_valid.yml @@ -43,4 +43,4 @@ acts: text: 'SYSTEM: sure. would you like whipped cream?' type: 'USER: ' num_episodes: 1321 -num_examples: 32641 +num_examples: 15660 From 1d3d0c6f8e7b17fee7c9a04fb65a22aa2873ca66 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Wed, 1 Dec 2021 22:46:28 -0800 Subject: [PATCH 50/57] regen after changing tod teacher logic to respect episode/examples length --- .../taskmaster2/test/taskmaster2_UserSimulatorTeacher_test.yml | 2 +- .../taskmaster2/test/taskmaster2_UserSimulatorTeacher_train.yml | 2 +- .../taskmaster2/test/taskmaster2_UserSimulatorTeacher_valid.yml | 2 +- parlai/tasks/taskmaster2/test/taskmaster2_test.yml | 2 +- parlai/tasks/taskmaster2/test/taskmaster2_train.yml | 2 +- parlai/tasks/taskmaster2/test/taskmaster2_valid.yml | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/parlai/tasks/taskmaster2/test/taskmaster2_UserSimulatorTeacher_test.yml b/parlai/tasks/taskmaster2/test/taskmaster2_UserSimulatorTeacher_test.yml index 318b3b416a8..cabdc76d438 100644 --- a/parlai/tasks/taskmaster2/test/taskmaster2_UserSimulatorTeacher_test.yml +++ b/parlai/tasks/taskmaster2/test/taskmaster2_UserSimulatorTeacher_test.yml @@ -43,4 +43,4 @@ acts: to 0.' type: 'USER: ' num_episodes: 1734 -num_examples: 36584 +num_examples: 17425 diff --git a/parlai/tasks/taskmaster2/test/taskmaster2_UserSimulatorTeacher_train.yml b/parlai/tasks/taskmaster2/test/taskmaster2_UserSimulatorTeacher_train.yml index 16cd94caae9..d4ad22a81c3 100644 --- a/parlai/tasks/taskmaster2/test/taskmaster2_UserSimulatorTeacher_train.yml +++ b/parlai/tasks/taskmaster2/test/taskmaster2_UserSimulatorTeacher_train.yml @@ -40,4 +40,4 @@ acts: text: 'SYSTEM: Starting point guard is Gary Harris.' 
type: 'USER: ' num_episodes: 13840 -num_examples: 291032 +num_examples: 138596 diff --git a/parlai/tasks/taskmaster2/test/taskmaster2_UserSimulatorTeacher_valid.yml b/parlai/tasks/taskmaster2/test/taskmaster2_UserSimulatorTeacher_valid.yml index 97dad21316a..95ef7136d19 100644 --- a/parlai/tasks/taskmaster2/test/taskmaster2_UserSimulatorTeacher_valid.yml +++ b/parlai/tasks/taskmaster2/test/taskmaster2_UserSimulatorTeacher_valid.yml @@ -42,4 +42,4 @@ acts: text: 'SYSTEM: No, they''re not scheduled to play today.' type: 'USER: ' num_episodes: 1730 -num_examples: 36404 +num_examples: 17337 diff --git a/parlai/tasks/taskmaster2/test/taskmaster2_test.yml b/parlai/tasks/taskmaster2/test/taskmaster2_test.yml index eae250af098..7299da10eee 100644 --- a/parlai/tasks/taskmaster2/test/taskmaster2_test.yml +++ b/parlai/tasks/taskmaster2/test/taskmaster2_test.yml @@ -52,4 +52,4 @@ acts: text: 'APIRESP: api_name = nfl ; name.team = Denver Broncos' type: 'SYSTEM: ' num_episodes: 1734 -num_examples: 36584 +num_examples: 34850 diff --git a/parlai/tasks/taskmaster2/test/taskmaster2_train.yml b/parlai/tasks/taskmaster2/test/taskmaster2_train.yml index 5676f053795..11346f17023 100644 --- a/parlai/tasks/taskmaster2/test/taskmaster2_train.yml +++ b/parlai/tasks/taskmaster2/test/taskmaster2_train.yml @@ -45,4 +45,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 13840 -num_examples: 291032 +num_examples: 277192 diff --git a/parlai/tasks/taskmaster2/test/taskmaster2_valid.yml b/parlai/tasks/taskmaster2/test/taskmaster2_valid.yml index 3a1c36b2bdb..1571b309003 100644 --- a/parlai/tasks/taskmaster2/test/taskmaster2_valid.yml +++ b/parlai/tasks/taskmaster2/test/taskmaster2_valid.yml @@ -40,4 +40,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 1730 -num_examples: 36404 +num_examples: 34674 From acd6ffe40aa8c1da18a310a43e81214ac4d48058 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Wed, 22 Dec 2021 09:49:14 -0800 Subject: [PATCH 51/57] not sure why this comment keeps not being merged correctly ugh... --- parlai/core/tod/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parlai/core/tod/README.md b/parlai/core/tod/README.md index 11abaa1ad92..f29e5d2778a 100644 --- a/parlai/core/tod/README.md +++ b/parlai/core/tod/README.md @@ -10,7 +10,7 @@ As a convention, files referenced externally to this directory are prefixed with # Teachers + Agents Usage -tl;dr Extend `TodStructuredDataParser` for your particular dataset and implement `generate_episodes()` that converts the dataset into a list of episodes (`List[TodStructuredEpisode]`). Use multiple inheritence to generate teachers for training models. See files like `parlai/tasks/multiwoz_v22/agents.py` for an example. +tl;dr Extend `TodStructuredDataParser` for your particular dataset and implement `setup_episodes()` that converts the dataset into a list of episodes (`List[TodStructuredEpisode]`). Use multiple inheritence to generate teachers for training models. See files like `parlai/tasks/multiwoz_v22/agents.py` for an example. See `tod_agents.py` for the classes. 
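For reference, a minimal sketch of the `setup_episodes()` pattern the README above describes, written against the classes in `parlai/core/tod/tod_core.py` and `parlai/core/tod/tod_agents.py`. This is illustrative only: the `fold` argument, the `_load_raw_dialogues()` helper, and the raw-dialogue field names are assumptions, not part of the actual API.

    from typing import List

    from parlai.core.tod.tod_core import TodStructuredEpisode, TodStructuredRound
    import parlai.core.tod.tod_agents as tod_agents


    class MyDatasetParser(tod_agents.TodStructuredDataParser):
        """Sketch: convert one raw dataset into TOD structured episodes."""

        def setup_episodes(self, fold) -> List[TodStructuredEpisode]:
            episodes = []
            for dialogue in self._load_raw_dialogues(fold):  # hypothetical loader
                rounds = [
                    TodStructuredRound(
                        user_utt=turn["user"],
                        api_call_machine=turn.get("api_call", {}),
                        api_resp_machine=turn.get("api_resp", {}),
                        sys_utt=turn["system"],
                    )
                    for turn in dialogue["turns"]
                ]
                episodes.append(
                    TodStructuredEpisode(
                        domain=dialogue.get("domain", ""),
                        goal_calls_machine=dialogue.get("goal_calls", []),
                        rounds=rounds,
                    )
                )
            return episodes

Concrete teachers are then built from such a parser via multiple inheritance, as in `parlai/tasks/multiwoz_v22/agents.py`.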
From 0f49cb58610064eb7dc25b082f864c51b4de6c5f Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Wed, 22 Dec 2021 12:52:03 -0800 Subject: [PATCH 52/57] noticed a different in episode lengths between old version of this data and new one; relized I was missing a +1 in the episode length count --- parlai/core/tod/tod_agents.py | 2 +- parlai/tasks/google_sgd/test/google_sgd_test.yml | 2 +- parlai/tasks/google_sgd/test/google_sgd_train.yml | 2 +- parlai/tasks/google_sgd/test/google_sgd_valid.yml | 2 +- .../google_sgd_simulation_splits_InDomainSystemTeacher_test.yml | 2 +- ...google_sgd_simulation_splits_InDomainSystemTeacher_train.yml | 2 +- ...google_sgd_simulation_splits_InDomainSystemTeacher_valid.yml | 2 +- ...google_sgd_simulation_splits_OutDomainSystemTeacher_test.yml | 2 +- ...oogle_sgd_simulation_splits_OutDomainSystemTeacher_train.yml | 2 +- ...oogle_sgd_simulation_splits_OutDomainSystemTeacher_valid.yml | 2 +- 10 files changed, 10 insertions(+), 10 deletions(-) diff --git a/parlai/core/tod/tod_agents.py b/parlai/core/tod/tod_agents.py index 574465aaa09..dcb6ee260aa 100644 --- a/parlai/core/tod/tod_agents.py +++ b/parlai/core/tod/tod_agents.py @@ -644,7 +644,7 @@ def add_cmdline_args( def __init__(self, opt, shared=None): super().__init__(opt, shared) - self._num_examples_cache = sum([len(x.rounds) * 2 for x in self.episodes]) + self._num_examples_cache = sum([len(x.rounds) * 2 + 1 for x in self.episodes]) self._num_episodes_cache = len(self.episodes) def custom_evaluation( diff --git a/parlai/tasks/google_sgd/test/google_sgd_test.yml b/parlai/tasks/google_sgd/test/google_sgd_test.yml index d55849a814b..4d8e4e2578e 100644 --- a/parlai/tasks/google_sgd/test/google_sgd_test.yml +++ b/parlai/tasks/google_sgd/test/google_sgd_test.yml @@ -42,4 +42,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 4201 -num_examples: 92996 +num_examples: 97197 diff --git a/parlai/tasks/google_sgd/test/google_sgd_train.yml b/parlai/tasks/google_sgd/test/google_sgd_train.yml index 584b6a56561..596025f82da 100644 --- a/parlai/tasks/google_sgd/test/google_sgd_train.yml +++ b/parlai/tasks/google_sgd/test/google_sgd_train.yml @@ -42,4 +42,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 16142 -num_examples: 362248 +num_examples: 378390 diff --git a/parlai/tasks/google_sgd/test/google_sgd_valid.yml b/parlai/tasks/google_sgd/test/google_sgd_valid.yml index 3d58a2b919d..2cd09f89ce7 100644 --- a/parlai/tasks/google_sgd/test/google_sgd_valid.yml +++ b/parlai/tasks/google_sgd/test/google_sgd_valid.yml @@ -42,4 +42,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 2482 -num_examples: 53690 +num_examples: 56172 diff --git a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_test.yml b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_test.yml index 0edd1c96ed7..d61223204be 100644 --- a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_test.yml +++ b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_test.yml @@ -42,4 +42,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 3132 -num_examples: 64154 +num_examples: 67286 diff --git a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_train.yml b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_train.yml index 8177b877b5a..08800de222d 100644 --- 
a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_train.yml +++ b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_train.yml @@ -42,4 +42,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 13888 -num_examples: 306734 +num_examples: 320622 diff --git a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_valid.yml b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_valid.yml index 444c6b1068d..92f565fea21 100644 --- a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_valid.yml +++ b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_InDomainSystemTeacher_valid.yml @@ -42,4 +42,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 1966 -num_examples: 40842 +num_examples: 42808 diff --git a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_test.yml b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_test.yml index 05f1a2c83e5..cb9af418dc6 100644 --- a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_test.yml +++ b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_test.yml @@ -51,4 +51,4 @@ acts: = 20 7071 5029' type: 'SYSTEM: ' num_episodes: 768 -num_examples: 19362 +num_examples: 20130 diff --git a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_train.yml b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_train.yml index 85ed0d88000..c7be6ba21f8 100644 --- a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_train.yml +++ b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_train.yml @@ -60,4 +60,4 @@ acts: 11:00 ; price_per_day = 38.00 ; start_date = 2019-03-04' type: 'SYSTEM: ' num_episodes: 2303 -num_examples: 58286 +num_examples: 60589 diff --git a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_valid.yml b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_valid.yml index cd166a17652..e140b5e23ee 100644 --- a/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_valid.yml +++ b/parlai/tasks/google_sgd_simulation_splits/test/google_sgd_simulation_splits_OutDomainSystemTeacher_valid.yml @@ -40,4 +40,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 768 -num_examples: 19556 +num_examples: 20324 From adff9499ee71e7ea5ccc721557a4a2a0856f05d0 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Wed, 22 Dec 2021 13:36:11 -0800 Subject: [PATCH 53/57] regen after changing tod teacher logic to respect episode/examples length --- parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_test.yml | 2 +- parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_train.yml | 2 +- parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_valid.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_test.yml b/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_test.yml index 17375135a9a..9fd946c1b75 100644 --- 
a/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_test.yml +++ b/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_test.yml @@ -68,4 +68,4 @@ acts: text: 'APIRESP: api_name = taxi ; car_type = UberX ; closing = have a safe trip' type: 'SYSTEM: ' num_episodes: 1011 -num_examples: 9694 +num_examples: 10705 diff --git a/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_train.yml b/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_train.yml index 52d60c1740b..fd643a0d2cb 100644 --- a/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_train.yml +++ b/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_train.yml @@ -53,4 +53,4 @@ acts: text: 'APIRESP: api_name = taxi ; car_type = uberX ; cost = $37-48' type: 'SYSTEM: ' num_episodes: 8068 -num_examples: 76456 +num_examples: 84524 diff --git a/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_valid.yml b/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_valid.yml index d58f9879cbd..aec69b50c10 100644 --- a/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_valid.yml +++ b/parlai/tasks/msr_e2e/test/msr_e2e_SystemTeacher_valid.yml @@ -52,4 +52,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 1008 -num_examples: 9564 +num_examples: 10572 From d724cd89d472f79b1ac7859d308232009385b789 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Wed, 22 Dec 2021 13:44:08 -0800 Subject: [PATCH 54/57] regen after changing tod teacher logic to respect episode/examples length --- .../multidogo_SystemTeacher_multidogo-domains=software_test.yml | 2 +- ...multidogo_SystemTeacher_multidogo-domains=software_train.yml | 2 +- ...multidogo_SystemTeacher_multidogo-domains=software_valid.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/parlai/tasks/multidogo/test/multidogo_SystemTeacher_multidogo-domains=software_test.yml b/parlai/tasks/multidogo/test/multidogo_SystemTeacher_multidogo-domains=software_test.yml index 165410557f5..608099ae2f3 100644 --- a/parlai/tasks/multidogo/test/multidogo_SystemTeacher_multidogo-domains=software_test.yml +++ b/parlai/tasks/multidogo/test/multidogo_SystemTeacher_multidogo-domains=software_test.yml @@ -47,4 +47,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 155 -num_examples: 2604 +num_examples: 2759 diff --git a/parlai/tasks/multidogo/test/multidogo_SystemTeacher_multidogo-domains=software_train.yml b/parlai/tasks/multidogo/test/multidogo_SystemTeacher_multidogo-domains=software_train.yml index a5d24a99736..8661c3c467b 100644 --- a/parlai/tasks/multidogo/test/multidogo_SystemTeacher_multidogo-domains=software_train.yml +++ b/parlai/tasks/multidogo/test/multidogo_SystemTeacher_multidogo-domains=software_train.yml @@ -43,4 +43,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 560 -num_examples: 9506 +num_examples: 10066 diff --git a/parlai/tasks/multidogo/test/multidogo_SystemTeacher_multidogo-domains=software_valid.yml b/parlai/tasks/multidogo/test/multidogo_SystemTeacher_multidogo-domains=software_valid.yml index 2986f1c6672..4f350df94fd 100644 --- a/parlai/tasks/multidogo/test/multidogo_SystemTeacher_multidogo-domains=software_valid.yml +++ b/parlai/tasks/multidogo/test/multidogo_SystemTeacher_multidogo-domains=software_valid.yml @@ -46,4 +46,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 67 -num_examples: 1136 +num_examples: 1203 From b4e5c1ffdd5c29d6fd83de476bc457e52dda4d24 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Wed, 22 Dec 2021 13:48:55 -0800 Subject: [PATCH 55/57] regen after changing tod teacher logic to respect episode/examples length --- 
parlai/tasks/multiwoz_v22/test/multiwoz_v22_test.yml | 2 +- parlai/tasks/multiwoz_v22/test/multiwoz_v22_train.yml | 2 +- parlai/tasks/multiwoz_v22/test/multiwoz_v22_valid.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/parlai/tasks/multiwoz_v22/test/multiwoz_v22_test.yml b/parlai/tasks/multiwoz_v22/test/multiwoz_v22_test.yml index 3b44eb8ecf9..09c0e97c40d 100644 --- a/parlai/tasks/multiwoz_v22/test/multiwoz_v22_test.yml +++ b/parlai/tasks/multiwoz_v22/test/multiwoz_v22_test.yml @@ -63,4 +63,4 @@ acts: "leaveat": "05:16", "price": "17.60 pounds", "trainid": "TR9020"}]' type: 'SYSTEM: ' num_episodes: 1000 -num_examples: 16744 +num_examples: 17744 diff --git a/parlai/tasks/multiwoz_v22/test/multiwoz_v22_train.yml b/parlai/tasks/multiwoz_v22/test/multiwoz_v22_train.yml index 0d7b74f5d23..4b0217a42bd 100644 --- a/parlai/tasks/multiwoz_v22/test/multiwoz_v22_train.yml +++ b/parlai/tasks/multiwoz_v22/test/multiwoz_v22_train.yml @@ -70,4 +70,4 @@ acts: "restaurant", "signature": NaN}]' type: 'SYSTEM: ' num_episodes: 7913 -num_examples: 126050 +num_examples: 133963 diff --git a/parlai/tasks/multiwoz_v22/test/multiwoz_v22_valid.yml b/parlai/tasks/multiwoz_v22/test/multiwoz_v22_valid.yml index 376bfda659c..5228dce4c2f 100644 --- a/parlai/tasks/multiwoz_v22/test/multiwoz_v22_valid.yml +++ b/parlai/tasks/multiwoz_v22/test/multiwoz_v22_valid.yml @@ -56,4 +56,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 999 -num_examples: 16732 +num_examples: 17731 From 21a05d2db3c434d6a6adecc1aa37f1daabf8aeb7 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Wed, 22 Dec 2021 13:52:11 -0800 Subject: [PATCH 56/57] regen after changing tod teacher logic to respect episode/examples length --- parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_test.yml | 2 +- parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_train.yml | 2 +- parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_valid.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_test.yml b/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_test.yml index 974c39f1d83..998350f1020 100644 --- a/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_test.yml +++ b/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_test.yml @@ -42,4 +42,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 1326 -num_examples: 31872 +num_examples: 33198 diff --git a/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_train.yml b/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_train.yml index 923c0549f92..417c226d0e4 100644 --- a/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_train.yml +++ b/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_train.yml @@ -46,4 +46,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 10555 -num_examples: 251466 +num_examples: 262021 diff --git a/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_valid.yml b/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_valid.yml index a58c16ba4bd..a541abf7d1f 100644 --- a/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_valid.yml +++ b/parlai/tasks/taskmaster/test/taskmaster_SystemTeacher_valid.yml @@ -47,4 +47,4 @@ acts: text: 'APIRESP: api_name = coffee_ordering ; location.store = hilton knoxville' type: 'SYSTEM: ' num_episodes: 1321 -num_examples: 31320 +num_examples: 32641 From e9ea6ac22ed44d63fdb6914bc4594fbe8af14bfd Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Wed, 22 Dec 2021 13:53:44 -0800 Subject: [PATCH 57/57] regen after changing tod teacher logic 
to respect episode/examples length --- parlai/tasks/taskmaster2/test/taskmaster2_test.yml | 2 +- parlai/tasks/taskmaster2/test/taskmaster2_train.yml | 2 +- parlai/tasks/taskmaster2/test/taskmaster2_valid.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/parlai/tasks/taskmaster2/test/taskmaster2_test.yml b/parlai/tasks/taskmaster2/test/taskmaster2_test.yml index 7299da10eee..eae250af098 100644 --- a/parlai/tasks/taskmaster2/test/taskmaster2_test.yml +++ b/parlai/tasks/taskmaster2/test/taskmaster2_test.yml @@ -52,4 +52,4 @@ acts: text: 'APIRESP: api_name = nfl ; name.team = Denver Broncos' type: 'SYSTEM: ' num_episodes: 1734 -num_examples: 34850 +num_examples: 36584 diff --git a/parlai/tasks/taskmaster2/test/taskmaster2_train.yml b/parlai/tasks/taskmaster2/test/taskmaster2_train.yml index 11346f17023..5676f053795 100644 --- a/parlai/tasks/taskmaster2/test/taskmaster2_train.yml +++ b/parlai/tasks/taskmaster2/test/taskmaster2_train.yml @@ -45,4 +45,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 13840 -num_examples: 277192 +num_examples: 291032 diff --git a/parlai/tasks/taskmaster2/test/taskmaster2_valid.yml b/parlai/tasks/taskmaster2/test/taskmaster2_valid.yml index 1571b309003..3a1c36b2bdb 100644 --- a/parlai/tasks/taskmaster2/test/taskmaster2_valid.yml +++ b/parlai/tasks/taskmaster2/test/taskmaster2_valid.yml @@ -40,4 +40,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 1730 -num_examples: 34674 +num_examples: 36404
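As a sanity check on the `* 2 + 1` counting change from PATCH 52, a small sketch of the per-episode arithmetic, assuming (as the tod_agents.py hunk shows) two examples per round plus one additional example per episode; the concrete numbers are taken from the regenerated google_sgd test fixture above.

    # Per-episode example count cached in tod_agents.py (PATCH 52):
    # two examples per round, plus one more per episode.
    def num_examples(episodes):
        return sum(len(episode.rounds) * 2 + 1 for episode in episodes)

    # google_sgd test split: 4201 episodes, 92996 examples before the fix,
    # 97197 after -- exactly one additional example per episode.
    assert 92996 + 4201 == 97197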