From e365e48c16e868efbdd2cbc4ad3568a3ed2021b5 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Mon, 15 Nov 2021 15:50:00 -0800 Subject: [PATCH 01/23] [TOD] Core converesation structure, serialization, const tokens --- parlai/core/tod/tod_core.py | 227 ++++++++++++++++++++++++++++++++++++ 1 file changed, 227 insertions(+) create mode 100644 parlai/core/tod/tod_core.py diff --git a/parlai/core/tod/tod_core.py b/parlai/core/tod/tod_core.py new file mode 100644 index 00000000000..76ee8005a74 --- /dev/null +++ b/parlai/core/tod/tod_core.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Task Oriented Dialogue (TOD) enums and base classes. + +This file defines standard tokens, classes for round and conversation structure, and a serialization class to aid in converting between these. + +See `tod_agents.py` for usage of these classes to generate training data and `tod_world_script.py` for usage of these classes in simulated conversations. +""" +from enum import Enum +from typing import List, Dict +from dataclasses import dataclass, field +from collections.abc import Iterable +from parlai.utils.misc import warn_once + +STANDARD_CALL = "APICALL: " +STANDARD_RESP = "APIRESP: " +STANDARD_SYSTEM_UTTERANCE = "SYSTEM: " +STANDARD_USER_UTTERANCE = "USER: " + +STANDARD_GOAL = "GOAL: " +STANDARD_API_SCHEMAS = "APIS: " + +STANDARD_API_NAME_SLOT = "api_name" +STANDARD_REQUIRED_KEY = "reqArg" +STANDARD_OPTIONAL_KEY = "optArg" +STANDARD_DONE = "[DONE]" + +CONST_SILENCE = "__SILENCE__" + + +class TodAgentType(str, Enum): + USER_UTT_AGENT = "user_utt_model" + API_CALL_AGENT = "api_call_model" + API_RESP_AGENT = "api_resp_model" + SYSTEM_UTT_AGENT = "system_utt_model" + API_SCHEMA_GROUNDING_AGENT = "api_schema_grounding_model" + GOAL_GROUNDING_AGENT = "goal_grounding_model" + + +TOD_AGENT_TYPE_TO_PREFIX = { + TodAgentType.USER_UTT_AGENT: STANDARD_USER_UTTERANCE, + TodAgentType.API_CALL_AGENT: STANDARD_CALL, + TodAgentType.API_RESP_AGENT: STANDARD_RESP, + TodAgentType.SYSTEM_UTT_AGENT: STANDARD_SYSTEM_UTTERANCE, + TodAgentType.API_SCHEMA_GROUNDING_AGENT: STANDARD_API_SCHEMAS, + TodAgentType.GOAL_GROUNDING_AGENT: STANDARD_GOAL, +} + + +@dataclass +class TodStructuredRound: + """ + Dataclass for rounds. + """ + + # Variables set by those using this class + user_utt: str = "" + api_call_machine: Dict = field( + default_factory=dict + ) # Hashmap of slot keys and slot values. Note that STANDARD_API_NAME_SLOT (`api_name`) is expected to be one of the keys here when this is nonempty; simulation metrics wonky without + api_resp_machine: Dict = field(default_factory=dict) + sys_utt: str = "" + extras: Dict = field( + default_factory=dict + ) # Grab bag for extra data. Not currently referenced in any TOD core code, but a convenient leaky abstraction for passing dataset-specific data between Parser classes and realized agents/teachers. + + # Variables derived by class + api_call_utt: str = field(init=False) + api_resp_utt: str = field(init=False) + + def __post_init__(self): + self.api_call_utt = SerializationHelpers.api_dict_to_str(self.api_call_machine) + self.api_resp_utt = SerializationHelpers.api_dict_to_str(self.api_resp_machine) + if ( + len(self.api_call_machine) > 0 + and STANDARD_API_NAME_SLOT not in self.api_call_machine + ): + warn_once( + f"{STANDARD_API_NAME_SLOT} missing when API Call present. This may cause issues for simulation metrics." + ) + + +@dataclass +class TodStructuredEpisode: + """ + Dataclass for episode-level data. + """ + + # Variables set by those using this class + delex: bool = False # Set to true and this class will handle delexicalizing call + response utterances based on API calls and responses exposed to this class. + domain: str = "" # self-explanatory + api_schemas_machine: List[Dict[str, List]] = field( + default_factory=list + ) # Expected to be a List of Dicts with the API name, required arguments, and optional arguments (specified by consts at the top of this file) as keys + goal_calls_machine: List[Dict[str, str]] = field( + default_factory=list + ) # Machine-formatted API calls + rounds: List[TodStructuredRound] = field(default_factory=list) # self explanatory + extras: Dict = field( + default_factory=dict + ) # Grab bag for extra data. Not currently referenced in any TOD core code, but a convenient leaky abstraction for passing dataset-specific data between Parser classes and realized agents/teachers. + + # Variables derived by class + api_schemas_utt: str = field(init=False) + goal_calls_utt: str = field(init=False) + + def __post_init__(self): + self.api_schemas_utt = SerializationHelpers.list_of_maps_to_str( + self.api_schemas_machine + ) + self.goal_calls_machine = [ + call for call in self.goal_calls_machine if len(call) > 0 + ] + self.goal_calls_utt = SerializationHelpers.list_of_maps_to_str( + self.goal_calls_machine + ) + # Add a done turn at the end + self.rounds.append(TodStructuredRound(user_utt=STANDARD_DONE)) + if self.delex: + accum_slots = ( + {} + ) # separate since some slot values change as we go. Use this for delex first + cum_slots = self.get_all_slots() + for r in self.rounds: + accum_slots.update(r.api_call_machine) + accum_slots.update(r.api_resp_machine) + r.sys_utt = SerializationHelpers.delex(r.sys_utt, accum_slots) + r.sys_utt = SerializationHelpers.delex(r.sys_utt, cum_slots) + + def get_all_slots(self): + result = {} + for r in self.rounds: + result.update(r.api_call_machine) + result.update(r.api_resp_machine) + return result + + +class SerializationHelpers: + @classmethod + def delex(cls, text, slots): + delex = text + for slot, value in slots.items(): + if isinstance(value, str): + delex = delex.replace(value, f"[{slot}]") + else: + for v in value: + delex = delex.replace(v, f"[{slot}]") + return delex + + @classmethod + def inner_list_join(cls, values): + if isinstance(values, str): + return values + return ", ".join(sorted([v.strip() for v in values])) + + @classmethod + def inner_list_split(cls, s): + return s.split(", ") + + @classmethod + def maybe_inner_list_join(cls, values): + if isinstance(values, str) or isinstance(values, int): + return values + elif isinstance(values, Iterable): + return SerializationHelpers.inner_list_join(values) + else: + raise RuntimeError("invalid type of argument for maybe_inner_list_join") + + @classmethod + def api_dict_to_str(cls, apidict): + """ + Used for API Calls and Responses -> Utterance. + """ + return " ; ".join( + f"{k} = {SerializationHelpers.maybe_inner_list_join(v)}" + for k, v in sorted(apidict.items()) + ) + + @classmethod + def str_to_api_dict(cls, string): + """ + Used for API Call and Response Utterances -> Dict. + """ + slot_strs = string.split(" ; ") + result = {} + for slot_str in slot_strs: + if " = " not in slot_str: + continue + name, value = slot_str.split(" = ", 1) + name = name.strip() + value = value.strip() + result[name] = value + return result + + @classmethod + def outer_list_join(cls, s): + return " | ".join(s) + + @classmethod + def outer_list_split(cls, s): + return s.split(" | ") + + @classmethod + def str_to_list_of_maps(cls, s): + return [ + SerializationHelpers.str_to_api_dict(x) + for x in SerializationHelpers.outer_list_split(s) + ] + + @classmethod + def list_of_maps_to_str(cls, list_of_maps): + return SerializationHelpers.outer_list_join( + [SerializationHelpers.api_dict_to_str(m) for m in list_of_maps] + ) + + @classmethod + def str_to_goals(cls, s): # convenience + return SerializationHelpers.str_to_list_of_maps(s) + + @classmethod + def str_to_api_schemas(cls, s): # convenience + return SerializationHelpers.str_to_list_of_maps(s) From c939174b44a8f4954a53e7327b7a66add28f7fce Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Mon, 15 Nov 2021 20:21:58 -0800 Subject: [PATCH 02/23] [Tod] Agents, teacher metrics, and tests for these See documentation block in `tod_agents.py` --- conftest.py | 1 + parlai/core/tod/teacher_metrics.py | 159 ++++ parlai/core/tod/tod_agents.py | 794 ++++++++++++++++++ parlai/core/tod/tod_test_utils/test_agents.py | 216 +++++ pytest.ini | 1 + tests/tod/test_tod_agents_and_teachers.py | 327 ++++++++ tests/tod/test_tod_teacher_metrics.py | 74 ++ 7 files changed, 1572 insertions(+) create mode 100644 parlai/core/tod/teacher_metrics.py create mode 100644 parlai/core/tod/tod_agents.py create mode 100644 parlai/core/tod/tod_test_utils/test_agents.py create mode 100644 tests/tod/test_tod_agents_and_teachers.py create mode 100644 tests/tod/test_tod_teacher_metrics.py diff --git a/conftest.py b/conftest.py index 7cc1e262461..8970273a460 100644 --- a/conftest.py +++ b/conftest.py @@ -67,6 +67,7 @@ def filter_tests_with_circleci(test_list): ('datatests/', 'data'), ('parlai/tasks/', 'teacher'), ('tasks/', 'tasks'), + ('tod/', 'tod'), ] diff --git a/parlai/core/tod/teacher_metrics.py b/parlai/core/tod/teacher_metrics.py new file mode 100644 index 00000000000..3fc85c7d107 --- /dev/null +++ b/parlai/core/tod/teacher_metrics.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Task Oriented Dialogue (TOD) teacher metrics. +""" +from typing import Optional, List, Dict, Any +from parlai.core.metrics import AverageMetric, BleuMetric, F1Metric, Metric, Metrics + + +class SlotMetrics(Metrics): + """ + Helper container which encapsulates standard slot metrics in task oriented learning + (jga, slot_p, slot_r, etc). + + Due to differences in dialogue representations between tasks, the input is pre- + parsed ground truth and predicted slot dictionaries. + + The 'jga+nlg' metric assumes a balanced set of JGA and NLG scores such that + 2 * Avg(JGA, NLG_BLEU) = Avg(JGA + NLG_BLEU) + The `jga+nlg` metric assumes that `NlgMetrics` is used to calculated the other side. + """ + + def __init__( + self, + teacher_slots: Dict[str, str], + predicted_slots: Dict[str, str], + prefixes: Optional[List] = None, + shared: Dict[str, Any] = None, + avg_jga_nlg_bleu: bool = False, + ) -> None: + super().__init__(shared=shared) + self.prefixes = prefixes if prefixes else [] + # jga and optionally Avg(jga,nlg_bleu) + self.add_with_prefixes("jga", AverageMetric(teacher_slots == predicted_slots)) + if len(teacher_slots) > 0: + self.add_with_prefixes( + "jga_noempty", AverageMetric(teacher_slots == predicted_slots) + ) + else: + self.add_with_prefixes( + "jga_empty", AverageMetric(teacher_slots == predicted_slots) + ) + + if avg_jga_nlg_bleu: + # add one half of Avg(jga,nlg_bleu), NlgMetrics class (below) adds NLG-BLEU + self.add("jga+nlg", AverageMetric(teacher_slots == predicted_slots)) + # precision + for pred_slot_name, pred_value in predicted_slots.items(): + slot_p = AverageMetric(teacher_slots.get(pred_slot_name) == pred_value) + self.add_with_prefixes("slot_p", slot_p) + self.add_with_prefixes("slot_f1", SlotF1Metric(slot_p=slot_p)) + # recall + for teacher_slot_name, teacher_value in teacher_slots.items(): + slot_r = AverageMetric( + predicted_slots.get(teacher_slot_name) == teacher_value + ) + self.add_with_prefixes("slot_r", slot_r) + self.add_with_prefixes("slot_f1", SlotF1Metric(slot_r=slot_r)) + + def add_with_prefixes(self, name, value): + self.add(name, value) + for prefix in self.prefixes: + self.add(f"{prefix}/{name}", value) + + +class NlgMetrics(Metrics): + """ + Helper container for generation version of standard metrics (F1, BLEU, ..). + """ + + def __init__( + self, + guess: str, + labels: Optional[List[str]], + prefixes: Optional[List[str]] = None, + shared: Dict[str, Any] = None, + avg_jga_nlg_bleu: bool = False, + ) -> None: + super().__init__(shared=shared) + self.prefixes = prefixes if prefixes else [] + bleu = BleuMetric.compute(guess, labels) + f1 = F1Metric.compute(guess, labels) + self.add_with_prefixes("nlg_bleu", bleu) + self.add_with_prefixes("nlg_f1", f1) + if avg_jga_nlg_bleu: + # add one half of Avg(jga,nlg_bleu), SlotMetrics class (above) adds JGA + self.add("jga+nlg", bleu) + + def add_with_prefixes(self, name, value): + self.add(name, value) + for prefix in self.prefixes: + self.add(f"{prefix}/{name}", value) + + +AverageType = Optional[AverageMetric] + + +def _average_type_sum_helper(first: AverageType, second: AverageType) -> AverageType: + """ + Helper to deal with Nones. + + We are "clever" in how we aggregate SlotF1Metrics (See SlotMetrics `__init__`) in + that we add precision and recall values separately, but this means we need to handle + None. + """ + if first is None: + return second + if second is None: + return first + return first + second + + +class SlotF1Metric(Metric): + """ + Metric to keep track of slot F1. + + Keeps track of slot precision and slot recall as running metrics. + """ + + __slots__ = ("_slot_p", "_slot_r") + + @property + def macro_average(self) -> bool: + """ + Indicates whether this metric should be macro-averaged when globally reported. + """ + return True + + def __init__(self, slot_p: AverageType = None, slot_r: AverageType = None): + if not isinstance(slot_p, AverageMetric) and slot_p is not None: + slot_p = AverageMetric(slot_p) + if not isinstance(slot_r, AverageMetric) and slot_r is not None: + slot_r = AverageMetric(slot_r) + self._slot_p = slot_p + self._slot_r = slot_r + + def __add__(self, other: Optional["SlotF1Metric"]) -> "SlotF1Metric": + # NOTE: hinting can be cleaned up with "from __future__ import annotations" when + # we drop Python 3.6 + if other is None: + return self + slot_p = _average_type_sum_helper(self._slot_p, other._slot_p) + slot_r = _average_type_sum_helper(self._slot_r, other._slot_r) + return type(self)(slot_p=slot_p, slot_r=slot_r) + + def value(self) -> float: + if self._slot_p is None or self._slot_r is None: + return float("nan") + else: + slot_p = self._slot_p.value() + slot_r = self._slot_r.value() + if slot_p == 0.0 and slot_r == 0.0: + return float("nan") + else: + return 2 * (slot_p * slot_r) / (slot_p + slot_r) diff --git a/parlai/core/tod/tod_agents.py b/parlai/core/tod/tod_agents.py new file mode 100644 index 00000000000..f22a8330760 --- /dev/null +++ b/parlai/core/tod/tod_agents.py @@ -0,0 +1,794 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Agents (used for dumping data) and Teachers (for training models) related to the TOD +conversation setup. + +# Usage + +For a given dataset, extend `TodStructuredDataParser` and implement `generate_episodes()` and `get_id_task_prefix()`. The former of these is expected to do the data processing to convert a dataset to `List[TodStructuredEpisode]`. From here, multiple inheritance can be used to define Agents and Teachers that utilize the data. + +For example, given a `class XX_DataParser(TodStructuredDataParser)`, `class XX_UserSimulatorTeacher(XX_DataParser, TodUserSimulatorTeacher)` would be how one would define a teacher that generates training data for a User Simulator model. + +Once the relevant agents have been created (or relevant models have been fine-tuned), see `parlai.scripts.tod_world_script` for usage in generating simulations. + +As a convention, agents and teachers that are inheritable are prefixed with "Tod" whereas those that can be used as-is are not. Similarly, classes and functions that do not need to be exposed outside of this file are prefixed with a single underscore ('_'). + +## Why we do this +These files aid in consistency between Teachers and Agents for simulation. Rather than having to align multiple different agents to be consistent about assuptions about data formatting, tokens, spacing, etc, we do this once (via converting everything to `TodStructuredEpisode`) and let the code handle the rest. + +# Description of Agents + Teachers useful for Simulation +## Teachers for training (generative) models + * TodSystemTeacher + * TodUserSimulatorTeacher + +## Agents for Grounding +For goal grounding for the User for simulation: + * TodGoalAgent + * TodSingleGoalAgent + +For (optional) API schema grounding for the System: + * TodApiSchemaAgent (must be used with `TodGoalAgent` only) + * TodSingleApiSchemaAgent (must be used with `TodSingleGoalAgent` only) + * EmptyApiSchemaAgent + * Used for simulations where the expectation is `no schema`, ie, evaluation simulations. + +## Agents for mocking APIs: + * StandaloneApiAgent + * Assumed to be provided a .pickle file 'trained' by `TodStandaloneApiTeacher` + +# Agents for dumping data from a ground truth dataset +The following are for extracting TOD World metrics from a ground truth dataset. These are generally used sparingly and only for calculating baselines. + * TodApiCallAndSysUttAgent + * TodApiResponseAgent + * TodUserUttAgent + +For this metrics extraction, `TodGoalAgent` and `TodApiSchemaAgent` should be used. + +# Other agents +There is a `EmptyGoalAgent` for use in human-human conversations where a goal is unnecessary. +""" + +from parlai.core.agents import Agent +from parlai.core.message import Message +from parlai.core.metrics import AverageMetric +from parlai.core.params import ParlaiParser +from parlai.core.opt import Opt +from parlai.core.teachers import DialogTeacher +from parlai.utils.distributed import is_distributed, get_rank, num_workers + +import parlai.core.tod.tod_core as tod +from parlai.core.tod.tod_core import SerializationHelpers +from parlai.core.tod.teacher_metrics import SlotMetrics, NlgMetrics + +from typing import Optional, List +import json +import pickle +import difflib +import random +from math import ceil + + +######### Agents that dump information from a dataset; base classes +class TodStructuredDataParser(Agent): + """ + Base class that specifies intermediate representations for Tod conversations. + """ + + @classmethod + def add_cmdline_args( + cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None + ) -> ParlaiParser: + if hasattr(super(), "add_cmdline_args"): + parser = super().add_cmdline_args(parser, partial_opt) + group = parser.add_argument_group("TOD StructuredData agent") + group.add_argument( + "--episodes-randomization-seed", + type=int, + default=-1, + help="Randomize episodes in a predictable way (eg, for few shot). Set to -1 for no randomization. ", + ) + parser.add_argument( + "--n-shot", + default=-1, + type=int, + help="Number of dialogues to keep for each of train/valid/test. -1 means all. Dialogues of lower numbers are strict subsets of larger numbers. Do not use in conjunction with `--percent-shot`. Use `--episodes-randomization-seed` to change seed. NOTE: Beware of using this flag when multitasking as this will apply to *all* datasets unless the ':' syntax for specifying per-dataset flags is used.", + ) + parser.add_argument( + "--percent-shot", + default=-1, + type=float, + help="Percentage of dialogues to keep for each of train/valid/test. -1 means all. Dialogues of lower numbers are strict subsets of larger numbers. Do not use in conjunction with `--n-shot`. Use `--episodes-randomization-seed` to change seed. NOTE: Beware of using this flag when multitasking as this will apply to *all* datasets unless the ':' syntax for specifying per-dataset flags is used.", + ) + return parser + + def __init__(self, opt: Opt, shared=None): + super().__init__(opt, shared) + self.id = self.get_id_task_prefix() + "_" + self._get_agent_type_suffix() + if shared is None: + self.episodes = self.generate_episodes() + else: + self.episodes = shared["episodes"] + + def share(self): + share = super().share() + share["episodes"] = self.episodes + return share + + def setup_episodes(self, fold: str) -> List[tod.TodStructuredEpisode]: + """ + Fold here is a data fold. + """ + raise NotImplementedError( + "Must have method for generating an episode. Must be set in downstream Parser for a given task" + ) + + def generate_episodes(self) -> List[tod.TodStructuredEpisode]: + if self.opt.get("n_shot", -1) >= 0 and self.opt.get("percent_shot", -1) >= 0: + # Validate before spending a while to load eeverything + raise RuntimeError("Both `--n-shot` and `--percent-shot` in use!") + episodes = list(self.setup_episodes(self.fold)) + if self.opt.get("episodes_randomization_seed", -1) != -1: + random.Random(self.opt["episodes_randomization_seed"]).shuffle(episodes) + if self.opt.get("n_shot", -1) != -1: + episodes = episodes[: self.opt["n_shot"]] + elif self.opt.get("percent_shot", -1) >= 0: + episodes = episodes[: int(len(episodes) * self.opt["percent_shot"])] + return episodes + + def get_id_task_prefix(self) -> str: + """ + Convenience for setting IDs. + """ + raise NotImplementedError( + "Must set ID prefix in downstream task agent. Must be set in downsream Parser for a given task" + ) + + def _get_agent_type_suffix(self) -> str: + """ + Convenience for setting IDs. + """ + raise NotImplementedError( + "Must set in downstream agent within `tod_agents`. If you see this error, something is wrong with TOD Infrastructure" + ) + + +######### Agents that dump information from a dataset as gold (explicitly should *not* be used with teachers) +class _TodDataDumpAgent(TodStructuredDataParser): + """ + For agents which dump data from some dataset, without training/other modifications. + + Implements an "epoch done" + + Member variables assumed to be set in init downstream: + self.fold + """ + + def __init__(self, opt: Opt, shared=None): + super().__init__(opt, shared) + self.epochDone = False + self.batchsize = opt.get("batchsize", 1) + self.max_episodes = len(self.episodes) + if opt.get("num_episodes", 0) > 0: + self.max_episodes = min(self.max_episodes, opt.get("num_episodes")) + self.episode_idx = opt.get("batchindex", 0) + self._setup_next_episode() + self.round_idx = 0 # for some downstream utt + sysUttAndApiCallAgents. + if is_distributed(): # cause gotta manually handle + rank = get_rank() + chunk_size = ceil(self.max_episodes / num_workers()) + self.episode_idx += rank * chunk_size + self.max_episodes = min(self.max_episodes, (rank + 1) * chunk_size) + + def _setup_next_episode(self): + self.epochDone = not self.episode_idx < self.max_episodes + self.episode = None + if not self.epochDone: + self.episode = self.episodes[self.episode_idx] + self.round_idx = ( + 0 # so downstream agents know which round they are in. Update in `act()` + ) + + def epoch_done(self) -> bool: + return self.epochDone + + def episode_done(self) -> bool: + """ + This is not actually "episode_done" so much as "we want to signify to the world + that we have gone past the batch". + + This class should not control whether or not the episode is actually done since + the TodWorld expects that to come from the User agent. + """ + return self.epochDone + + def num_episodes(self) -> int: + return len(self.episodes) + + def reset(self): + self.episode_idx += self.batchsize + self._setup_next_episode() + + +class TodGoalAgent(_TodDataDumpAgent): + """ + Use as a mixin with classes that also extend + implement TodStructuredDataParser. + """ + + def act(self): + return { + "text": f"{tod.STANDARD_GOAL}{self.episode.goal_calls_utt}", + "id": self.id, + "domain": self.episode.domain, + "episode_done": False, + } + + def _get_agent_type_suffix(self): + return "Goal" + + +class TodApiSchemaAgent(_TodDataDumpAgent): + def act(self): + return { + "text": f"{tod.STANDARD_API_SCHEMAS}{self.episode.api_schemas_utt}", + "id": self.id, + "domain": self.episode.domain, + "episode_done": False, + } + + def _get_agent_type_suffix(self): + return "ApiSchema" + + +############# Single Goal + Api Schema Agent +class _EpisodeToSingleGoalProcessor(_TodDataDumpAgent): + """ + Iterate through all of the goals of a dataset, one by one. + + Slightly different logic than the dump agent since how we count + setup examples for + an episode are different + + Used as a mixin in the SingleGoal and SingleApiSchema agents below. + + This class exposes a `filter_goals()` function that can be overridden by downstream agents. + """ + + def __init__(self, opt: Opt, shared=None): + super().__init__(opt, shared) + self.epochDone = False + if shared is None: + self.episodes = self._setup_single_goal_episodes() + else: + # Handled fine in _TodDataDumpAgent + pass + + self.max_episodes = len(self.episodes) + if opt.get("num_episodes", 0) > 0: + self.max_episodes = min(self.max_episodes, opt.get("num_episodes")) + if is_distributed(): # cause gotta manually handle + rank = get_rank() + chunk_size = ceil(self.max_episodes / num_workers()) + self.max_episodes = min(self.max_episodes, (rank + 1) * chunk_size) + + self._setup_next_episode() + + def _setup_single_goal_episodes(self) -> List[tod.TodStructuredEpisode]: + """ + This function assumes that `self.setup_episodes()` has already been called + prior. + + Based on the `__init__` order of this class, it should be done in + `TodStructuredDataParser` by this point. + """ + raw_episodes = self.episodes + result = [] + for raw in raw_episodes: + for call in self.filter_goals(raw.goal_calls_machine): + schema = {} + for cand in raw.api_schemas_machine: + if ( + cand[tod.STANDARD_API_NAME_SLOT] + == call[tod.STANDARD_API_NAME_SLOT] + ): + schema = cand + + result.append( + tod.TodStructuredEpisode( + domain=raw.domain, + api_schemas_machine=[schema], + goal_calls_machine=[call], + rounds=[], + ) + ) + return result + + def filter_goals(self, goals): + """ + Some downstream agents may want to filter the goals. + + Override this if so. + """ + return goals + + +class TodSingleGoalAgent(_EpisodeToSingleGoalProcessor, TodGoalAgent): + """ + Use as a mixin with classes that also extend + implement TodStructuredDataParser. + + NOTE: If an API schema agent is used, this *must* be used with `TodSingleApiSchemaAgent` since it will be nonsensicle otherwise. Additionally, this agent will not function properly with UserUtt + SystemUttAndApiCall agent, since episodes will not align. + """ + + def _get_agent_type_suffix(self): + return "SingleGoal" + + +class TodSingleApiSchemaAgent(_EpisodeToSingleGoalProcessor, TodApiSchemaAgent): + """ + Use as a mixin with classes that also extend + implement TodStructuredDataParser. + + NOTE: Must be used with TodSingleGoalAgent since nonsensicle otherwise. Additionally, this agent will not function properly with UserUtt + SystemUttAndApiCall agent, since episodes will not align. + """ + + def _get_agent_type_suffix(self): + return "SingleApiSchema" + + +###### Agents used for calculating TOD World Metrics based on a dataset. See `tod_world_script` or `parlai/projects/tod_simulator/` for examples. +class TodUserUttAgent(_TodDataDumpAgent): + """ + Agent used to calculate TOD World Metrics on a dataset. Represents the "User" agent. + + This class should only ever be used with the model-model chat world which will stop + upon seeing the '[DONE]' utterance; may go out of bounds otherwise. + """ + + def act(self): + result = { + "text": f"{tod.STANDARD_USER_UTTERANCE}{self.episode.rounds[self.round_idx].user_utt}", + "id": self.id, + "domain": self.episode.domain, + "episode_done": False, + } + self.round_idx += 1 + return result + + def reset(self): + super().reset() # setup next episode + self.round_idx = 0 + + def _get_agent_type_suffix(self): + return "User" + + +class TodApiCallAndSysUttAgent(_TodDataDumpAgent): + """ + Agent used to calculate TOD World Metrics on a dataset. Represents the "System" + agent. + + This class should only ever be used with the model-model chat world which will stop + upon seeing the '[DONE]' utterance; may go out of bounds otherwise. + """ + + def __init__(self, opt: Opt, shared=None): + # This class represents two "agents" so need to make sure we don't increment episode number (reset) twice + self.already_reset = False + self.api_call_turn = True + super().__init__(opt, shared) + + def act(self): + self.already_reset = False + if tod.STANDARD_API_SCHEMAS in self.observation.get("text", ""): + return { + "text": tod.STANDARD_API_SCHEMAS, + "id": self.id, + "domain": self.episode.domain, + "episode_down": False, + } + + if self.api_call_turn: # comes first, don't iterate round # + result = { + "text": f"{tod.STANDARD_CALL}{self.episode.rounds[self.round_idx].api_call_utt}", + "id": self.id, + "domain": self.episode.domain, + "episode_done": False, + } + else: + result = { + "text": f"{tod.STANDARD_SYSTEM_UTTERANCE}{self.episode.rounds[self.round_idx].sys_utt}", + "id": self.id, + "domain": self.episode.domain, + "episode_done": False, + } + self.round_idx += 1 + + self.api_call_turn ^= True + return result + + def reset(self): + if not self.already_reset: + super().reset() # setup next episode + self.api_call_turn = True + self.already_reset = True + + def _get_agent_type_suffix(self): + return "System" + + +class TodApiResponseAgent(_TodDataDumpAgent): + """ + Agent used to calculate TOD World Metrics on a dataset. Represents the API + Simulator. + + This class should only ever be used with the model-model chat world which will stop + upon seeing the '[DONE]' utterance; may go out of bounds otherwise. + """ + + def act(self): + result = { + "text": f"{tod.STANDARD_RESP}{self.episode.rounds[self.round_idx].api_resp_utt}", + "id": self.id, + "domain": self.episode.domain, + "episode_done": False, + } + self.round_idx += 1 + return result + + def reset(self): + super().reset() # setup next episode + self.round_idx = 0 + + def _get_agent_type_suffix(self): + return "ApiResponse" + + +###### Standalone API agent +class StandaloneApiAgent(Agent): + """ + Trainable agent that saves API calls and responses. + + Use `TodStandaloneApiTeacher` to train this class. For example for a MultiWoz V2.2 + standalone API, use ``` parlai train -t multiwoz_v22:StandaloneApiTeacher -m + parlai_fb.agents.tod.agents:StandaloneApiAgent -eps 4 -mf output ``` to generate the + `.pickle` file to use. + """ + + EMPTY_RESP = { + "text": tod.STANDARD_RESP, + "id": "StandaloneApiAgent", + "episode_done": False, + } + + @classmethod + def add_cmdline_args( + cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None + ) -> ParlaiParser: + group = parser.add_argument_group("TOD Standalone API args") + group.add_argument( + "--exact-api-call", + type=bool, + default=True, + help="Validation-time flag. If true, will return '' if exact api call values not found. If false, will pick response from the same intent with similar api parameters (assuming intent is the same when available)", + ) + + group.add_argument( + "--fail-hard", + type=bool, + default=False, + help="Aids in deugging. Will throw exception if API call not found and '--exact-api-call' is set.", + ) + + group.add_argument( + "--standalone-api-file", + type=str, + default=None, + help="Path to file holding `.pickle` of standalone api for validation (will intelligently strip if suffix included). If not set, assumes the `model_file` argument will contain the `.pickle` file. ", + ) + return parser + + def __init__(self, opt, shared=None): + super().__init__(opt, shared) + self.id = "StandaloneApiAgent" + file_key = "model_file" + if self.opt["standalone_api_file"] is not None: + file_key = "standalone_api_file" + self.path_base = self.opt[file_key].replace(".pickle", "") + self.db_path = self.path_base + ".pickle" + self.exact_api_call = self.opt["exact_api_call"] + try: + with (open(self.db_path, "rb")) as openfile: + self.data = pickle.load(openfile) + self.training = True + print("Loaded Standalone API data successfully") + if self.exact_api_call != self.data.get("exact_api_call", True): + raise RuntimeError( + f"Standalone API .pickle file generated with `exact_api_call` of {self.data.get('exact_api_call', False)} but StandaloneApiAgent sets it to {self.exact_api_call}" + ) + except Exception: + print(f"No file at {self.db_path}; ASSUMING WE ARE TRAINING") + self.data = {} + self.data["exact_api_call"] = self.exact_api_call + self.training = True + + def _maybe_filter_prefix(self, text, prefix): + if prefix in text: + return text[len(prefix) :].strip() + return text.strip() + + def act(self): + if not self.observation["text"].startswith(tod.STANDARD_CALL): + return self.EMPTY_RESP + call_text_raw = self.observation["text"] + # decode then reencode the API call so that we get the API calls in a consistent order + call_text = SerializationHelpers.api_dict_to_str( + SerializationHelpers.str_to_api_dict( + call_text_raw[len(tod.STANDARD_CALL) :] + ) + ) + if "labels" in self.observation: + return self._do_train(call_text) + return self._do_fetch(call_text) + + def _do_train(self, call_text): + assert self.training is True + self.data[call_text] = self.observation["labels"][0] + return self.EMPTY_RESP + + def _do_fetch(self, call_text): + if self.exact_api_call: + if self.opt.get("fail_hard", False): + resp = self.data[call_text] + else: + resp = self.data.get(call_text, tod.STANDARD_RESP) + return { + "text": resp, + "id": self.id, + "episode_done": False, + } + + # Not exact case + best_key = difflib.get_close_matches(call_text, self.data.keys(), 1) + if len(best_key) == 0: + return self.EMPTY_RESP + return { + "text": self.data.get(best_key[0], tod.STANDARD_RESP), + "id": self.id, + "episode_done": False, + } + + def shutdown(self): + if self.training: + with (open(self.db_path, "wb")) as openfile: + pickle.dump(self.data, openfile) + print(f"Dumped output to {self.db_path}") + with open(self.path_base + ".opt", "w") as f: + json.dump(self.opt, f) + + +######### Empty agents +class EmptyApiSchemaAgent(Agent): + def __init__(self, opt, shared=None): + super().__init__(opt) + self.id = "EmptyApiSchemaAgent" + + def act(self): + msg = { + "id": self.getID(), + "text": tod.STANDARD_API_SCHEMAS, + "episode_done": False, + } + return Message(msg) + + +class EmptyGoalAgent(Agent): + def __init__(self, opt, shared=None): + super().__init__(opt) + self.id = "EmptyGoalAgent" + + def act(self): + msg = {"id": self.getID(), "text": tod.STANDARD_GOAL, "episode_done": False} + return Message(msg) + + +############# Teachers +class TodSystemTeacher(TodStructuredDataParser, DialogTeacher): + """ + TOD agent teacher which produces both API calls and NLG responses. + + First turn is API Schema grounding, which may be a an empty schema. + Subsequent turns alternate between + 1. User utterance -> API Call + 2. API Response -> System Utterance + """ + + @classmethod + def add_cmdline_args( + cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None + ) -> ParlaiParser: + parser = super().add_cmdline_args(parser, partial_opt) + parser.add_argument( + "--api-schemas", + type="bool", + default=False, + help="Preempt first turn with intents + required/optional parameters as key/value for given domain", + ) + parser.add_argument( + "--api-jga-record", + type=bool, + default=True, + help="Should we save jga information per api schema?", + ) + parser.add_argument( + "--domain-jga-record", + type=bool, + default=False, + help="Should we save jga information per domain?", + ) + parser.add_argument( + "--domain-nlg-record", + type=bool, + default=False, + help="Should we save nlg information per domain?", + ) + return parser + + def custom_evaluation( + self, teacher_action: Message, labels, model_response: Message + ): + resp = model_response.get("text") + if not resp: + return + if teacher_action["type"] == tod.STANDARD_CALL: + if resp.startswith(tod.STANDARD_CALL): + resp = resp[len(tod.STANDARD_CALL) :] + predicted = SerializationHelpers.str_to_api_dict(resp) + domains = ( + [teacher_action["domain"]] if self.opt["domain_jga_record"] else [] + ) + + metrics = SlotMetrics( + teacher_slots=teacher_action["slots"], + predicted_slots=predicted, + avg_jga_nlg_bleu=True, + prefixes=domains, + ).report() + for key, value in metrics.items(): + self.metrics.add(key, value) + + if self.opt["api_jga_record"] and len(teacher_action["slots"]) > 0: + teacher = teacher_action["slots"] + slots = list(teacher.keys()) + slots.remove(tod.STANDARD_API_NAME_SLOT) + api_here = ( + "api-" + + teacher[tod.STANDARD_API_NAME_SLOT] + + "--" + + "-".join(slots) + ) + self.metrics.add(f"{api_here}/jga", AverageMetric(teacher == predicted)) + + elif teacher_action["type"] == tod.STANDARD_SYSTEM_UTTERANCE: + domains = ( + [teacher_action["domain"]] if self.opt["domain_nlg_record"] else [] + ) + metrics = NlgMetrics( + guess=resp, + labels=labels, + prefixes=domains, + avg_jga_nlg_bleu=True, + ).report() + for key, value in metrics.items(): + self.metrics.add(key, value) + + def setup_data(self, fold): + for episode in self.generate_episodes(): + if self.opt.get("api_schemas"): + schemas = episode.api_schemas_utt + else: + schemas = "" + yield { + "text": f"{tod.STANDARD_API_SCHEMAS}{schemas}", + "label": f"{tod.STANDARD_API_SCHEMAS}", + "domain": episode.domain, + "type": tod.STANDARD_API_SCHEMAS, + "slots": {}, + }, True + for r in episode.rounds: + yield { + "text": f"{tod.STANDARD_USER_UTTERANCE}{r.user_utt}", + "label": f"{tod.STANDARD_CALL}{r.api_call_utt}", + "domain": episode.domain, + "type": tod.STANDARD_CALL, + "slots": r.api_call_machine, + }, False + yield { + "text": f"{tod.STANDARD_RESP}{r.api_resp_utt}", + "label": f"{tod.STANDARD_SYSTEM_UTTERANCE}{r.sys_utt}", + "domain": episode.domain, + "slots": r.api_resp_machine, + "type": tod.STANDARD_SYSTEM_UTTERANCE, + }, False + + def _get_agent_type_suffix(self): + return "SystemTeacher" + + +class TodUserSimulatorTeacher(TodStructuredDataParser, DialogTeacher): + """ + Teacher that has `Goal->User Utterance` for its first turn, then `System + Utterance->User Utterance` for all subsequent turns. + """ + + def setup_data(self, fold): + for episode in self.generate_episodes(): + if len(episode.rounds) < 1: + continue + yield { + "text": f"{tod.STANDARD_GOAL}{episode.goal_calls_utt}", + "label": f"{tod.STANDARD_USER_UTTERANCE}{episode.rounds[0].user_utt}", + "domain": episode.domain, + "type": tod.STANDARD_USER_UTTERANCE, + }, True + for i, r in enumerate(episode.rounds): + if i == len(episode.rounds) - 1: + continue + yield { + "text": f"{tod.STANDARD_SYSTEM_UTTERANCE}{r.sys_utt}", + "label": f"{tod.STANDARD_USER_UTTERANCE}{episode.rounds[i+1].user_utt}", + "domain": episode.domain, + "type": tod.STANDARD_USER_UTTERANCE, + "slots": {}, # slots in agent/user turns are meaningless + }, False + + def custom_evaluation( + self, teacher_action: Message, labels, model_response: Message + ): + resp = model_response.get("text") + if not resp: + return + if teacher_action["type"] == tod.STANDARD_RESP: + if resp.startswith(tod.STANDARD_RESP): + resp = resp[len(tod.STANDARD_RESP) :] + predicted = SerializationHelpers.str_to_api_dict(resp) + + metrics = SlotMetrics(teacher_action["slots"], predicted).report() + for key, value in metrics.items(): + self.metrics.add(key, value) + + elif teacher_action["type"] == tod.STANDARD_USER_UTTERANCE: + metrics = NlgMetrics(resp, labels).report() + for key, value in metrics.items(): + self.metrics.add(key, value) + + def _get_agent_type_suffix(self): + return "UserSimulatorTeacher" + + +class TodStandaloneApiTeacher(TodStructuredDataParser, DialogTeacher): + """ + Use this to generate a database for `StandaloneApiAgent`. + + Set this as the teacher with `StandaloneApiAgent` as the agent. Ex for a MultiWoz + V2.2 standalone API, use ``` parlai train -t multiwoz_v22:StandaloneApiTeacher -m + parlai_fb.agents.tod.agents:StandaloneApiAgent -eps 4 -mf output ``` + """ + + def setup_data(self, fold): + # As a default, just put everything in + for fold_overwrite in ["train", "valid", "test"]: + for episode in self.setup_episodes(fold_overwrite): + first = True + for r in episode.rounds: + if len(r.api_call_machine) > 0: + yield { + "text": f"{tod.STANDARD_CALL}{r.api_call_utt}", + "label": f"{tod.STANDARD_RESP}{r.api_resp_utt}", + "id": self.id, + "domain": episode.domain, + }, first + first = False + + def _get_agent_type_suffix(self): + return "StandaloneApiTeacher" diff --git a/parlai/core/tod/tod_test_utils/test_agents.py b/parlai/core/tod/tod_test_utils/test_agents.py new file mode 100644 index 00000000000..b1339052764 --- /dev/null +++ b/parlai/core/tod/tod_test_utils/test_agents.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Helpers so we don't need to create agents all over. +""" + +import parlai.core.tod.tod_agents as tod_agents +import parlai.core.tod.tod_core as tod_core + +import os + +API_DATABASE_FILE = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "standalone_api_file.pickle" +) + + +def episode_has_broken_api_turn(episode_idx, max_turns): + return episode_idx % 2 == 1 and max_turns > 0 + + +def use_broken_api_calls_this_turn(round_idx, episode_idx): + return episode_idx % 2 == 1 and round_idx % 3 == 1 + + +def make_api_call_machine(round_idx, episode_idx=0, use_broken_mock_api_calls=False): + if round_idx == 0: + return {} + if use_broken_mock_api_calls: + # Hack as a way to test metrics reporting in tod world script + if use_broken_api_calls_this_turn(round_idx, episode_idx): + round_idx = -1 * round_idx + return {tod_core.STANDARD_API_NAME_SLOT: f"name_{round_idx}", "in": round_idx} + + +def make_api_resp_machine(round_idx): + if round_idx == 0: + return {} + return {"out": round_idx} + + +def make_api_schemas_machine(max_rounds): + return [ + { + tod_core.STANDARD_API_NAME_SLOT: f"name_{round_idx}", + tod_core.STANDARD_REQUIRED_KEY: ["in"], + tod_core.STANDARD_OPTIONAL_KEY: [], + } + for round_idx in range(1, max_rounds) + ] + + +def make_goal_calls_machine(max_rounds): + return [make_api_call_machine(x) for x in range(1, max_rounds)] + + +def get_rounds(episode_idx, max_rounds, use_broken_mock_api_calls=False): + return [ + tod_core.TodStructuredRound( + user_utt=f"user_utt_{episode_idx}_{round_idx}", + api_call_machine=make_api_call_machine( + round_idx, episode_idx, use_broken_mock_api_calls + ), + api_resp_machine=make_api_resp_machine(round_idx), + sys_utt=f"sys_utt_{episode_idx}_{round_idx}", + ) + for round_idx in range(max_rounds) + ] + + +def get_round_utts(episode_idx, max_rounds, filter_utts=None): + if max_rounds < 1: + return [] + utts = [ + [ + f"USER: user_utt_{episode_idx}_0", + "APICALL: ", + "APIRESP: ", + f"SYSTEM: sys_utt_{episode_idx}_0", + ] + ] + for i in range(1, max_rounds): + utts.append( + [ + f"USER: user_utt_{episode_idx}_{i}", + f"APICALL: api_name = name_{i} ; in = {i}", + f"APIRESP: out = {i}", + f"SYSTEM: sys_utt_{episode_idx}_{i}", + ] + ) + utts.append( + [ + "USER: [DONE]", + "APICALL: ", + "APIRESP: ", + "SYSTEM: ", + ] + ) + if filter_utts is not None: + utts = [ + [turn for i, turn in enumerate(round_data) if filter_utts[i]] + for round_data in utts + ] + return utts + + +TEST_NUM_EPISODES_OPT_KEY = "test_num_episodes" +TEST_NUM_ROUNDS_OPT_KEY = "test_num_rounds" + +# No api calls in this setup +EPISODE_SETUP__UTTERANCES_ONLY = { + TEST_NUM_ROUNDS_OPT_KEY: 1, + TEST_NUM_EPISODES_OPT_KEY: 1, +} + +# No one call, one goal, one api desscription in this setup +EPISODE_SETUP__SINGLE_API_CALL = { + TEST_NUM_ROUNDS_OPT_KEY: 2, + TEST_NUM_EPISODES_OPT_KEY: 1, +} +# Will start testing multiple api calls + schemas, multi-round logic +EPISODE_SETUP__MULTI_ROUND = {TEST_NUM_ROUNDS_OPT_KEY: 5, TEST_NUM_EPISODES_OPT_KEY: 1} + +# Test that episode logic is correct +EPISODE_SETUP__MULTI_EPISODE = { + TEST_NUM_ROUNDS_OPT_KEY: 5, + TEST_NUM_EPISODES_OPT_KEY: 8, +} + +# Test that episode + pesky-off-by-one batchinglogic is correct +EPISODE_SETUP__MULTI_EPISODE_BS = { + TEST_NUM_ROUNDS_OPT_KEY: 5, + TEST_NUM_EPISODES_OPT_KEY: 35, +} + + +class TestDataParser(tod_agents.TodStructuredDataParser): + """ + Assume that when we init, we init w/ num of episodes + rounds as opts. + """ + + def __init__(self, opt, shared=None): + opt["datafile"] = "DUMMY" + self.fold = "DUMMY" + # Following lines are only reelvant in training the standalone api teacher + if TEST_NUM_EPISODES_OPT_KEY not in opt: + opt[TEST_NUM_EPISODES_OPT_KEY] = 35 + if TEST_NUM_ROUNDS_OPT_KEY not in opt: + opt[TEST_NUM_ROUNDS_OPT_KEY] = 5 + super().__init__(opt, shared) + + def setup_episodes(self, _): + result = [] + for ep_idx in range(0, self.opt[TEST_NUM_EPISODES_OPT_KEY]): + result.append( + tod_core.TodStructuredEpisode( + goal_calls_machine=[ + make_api_call_machine(x) + for x in range(1, self.opt[TEST_NUM_ROUNDS_OPT_KEY]) + ], + api_schemas_machine=make_api_schemas_machine( + self.opt[TEST_NUM_ROUNDS_OPT_KEY] + ), + rounds=get_rounds( + ep_idx, + self.opt[TEST_NUM_ROUNDS_OPT_KEY], + self.opt.get("use_broken_mock_api_calls", False), + ), + ) + ) + # print(result, self.opt) + return result + + def get_id_task_prefix(self): + return "Test" + + +class SystemTeacher(TestDataParser, tod_agents.TodSystemTeacher): + pass + + +class UserSimulatorTeacher(TestDataParser, tod_agents.TodUserSimulatorTeacher): + pass + + +class StandaloneApiTeacher(TestDataParser, tod_agents.TodStandaloneApiTeacher): + pass + + +class GoalAgent(TestDataParser, tod_agents.TodGoalAgent): + pass + + +class ApiSchemaAgent(TestDataParser, tod_agents.TodApiSchemaAgent): + pass + + +class SingleGoalAgent(TestDataParser, tod_agents.TodSingleGoalAgent): + pass + + +class SingleApiSchemaAgent(TestDataParser, tod_agents.TodSingleApiSchemaAgent): + pass + + +# Tested in tod world code +class UserUttAgent(TestDataParser, tod_agents.TodUserUttAgent): + pass + + +# Tested in tod world code +class ApiCallAndSysUttAgent(TestDataParser, tod_agents.TodApiCallAndSysUttAgent): + pass diff --git a/pytest.ini b/pytest.ini index 9aec92c0a89..d4095288194 100644 --- a/pytest.ini +++ b/pytest.ini @@ -12,3 +12,4 @@ markers = unit internal nofbcode + tod diff --git a/tests/tod/test_tod_agents_and_teachers.py b/tests/tod/test_tod_agents_and_teachers.py new file mode 100644 index 00000000000..5383c72416d --- /dev/null +++ b/tests/tod/test_tod_agents_and_teachers.py @@ -0,0 +1,327 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests different (more complicated) slot metrics. +""" + +import unittest + +import copy +import parlai.core.tod.tod_core as tod_core +import parlai.core.tod.tod_test_utils.test_agents as test_agents + + +class TestTodAgentsAndTeachersBase(unittest.TestCase): + def setup_agent_or_teacher(self, class_type, round_opt, opt): + full_opts = {**round_opt, **opt} + full_opts["datatype"] = "DUMMY" + full_opts["datafile"] = "DUMMY" + full_opts["episodes_randomization_seed"] = -1 # no random here + return class_type(full_opts) + + def dump_single_utt_per_episode_agent_text(self, class_type, round_opt, opt): + agent = self.setup_agent_or_teacher(class_type, round_opt, opt) + result = [] + while not agent.epoch_done(): + result.append(agent.act()["text"]) + agent.reset() + return result + + def dump_teacher_text(self, class_type, round_opt, opt): + """ + Array where [episode_idx][turn_idx][text=0,label=1] + """ + teacher = self.setup_agent_or_teacher(class_type, round_opt, opt) + data = [] + here = [] + for x, new in teacher.setup_data("dummy"): + if new and len(here) > 0: + data.append(copy.deepcopy(here)) + here = [] + here.append([x["text"], x["label"]]) + if len(here) > 0: + data.append(here) + return data + + def _test_roundDataCorrect(self): + self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__UTTERANCES_ONLY) + self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__SINGLE_API_CALL) + self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__MULTI_ROUND) + self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__MULTI_EPISODE) + + +class TestSystemTeacher(TestTodAgentsAndTeachersBase): + def test_apiSchemas_with_yesApiSchemas(self): + values = self.dump_teacher_text( + test_agents.SystemTeacher, + test_agents.EPISODE_SETUP__SINGLE_API_CALL, + {"api_schemas": True}, + ) + self.assertEqual( + values[0][0][0], + "APIS: " + + tod_core.SerializationHelpers.list_of_maps_to_str( + test_agents.make_api_schemas_machine(2) + ), + ) + + def test_apiSchemas_with_noApiSchemas(self): + values = self.dump_teacher_text( + test_agents.SystemTeacher, + test_agents.EPISODE_SETUP__SINGLE_API_CALL, + {"api_schemas": False}, + ) + self.assertEqual(values[0][0][0], "APIS: ") + + def _test_roundDataCorrect_helper(self, config): + max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] + values = self.dump_teacher_text(test_agents.SystemTeacher, config, {}) + for episode_idx, episode in enumerate(values): + utts = test_agents.get_round_utts(episode_idx, max_rounds) + comp = [] + for utt in utts: + comp.append([utt[0], utt[1]]) + comp.append([utt[2], utt[3]]) + # Skip context turn cause we check it above + self.assertEqual(episode[1:], comp) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TestUserTeacher(TestTodAgentsAndTeachersBase): + def _test_roundDataCorrect_helper(self, config): + max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] + values = self.dump_teacher_text(test_agents.UserSimulatorTeacher, config, {}) + for episode_idx, episode in enumerate(values): + utts = test_agents.get_round_utts(episode_idx, max_rounds) + comp = [] + comp.append( + [ + "GOAL: " + + tod_core.SerializationHelpers.list_of_maps_to_str( + test_agents.make_goal_calls_machine(max_rounds) + ), + utts[0][0], + ] + ) + last_sys = utts[0][3] + for i in range(1, len(utts)): + comp.append([last_sys, utts[i][0]]) + last_sys = utts[i][3] + self.assertEqual(episode, comp) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TestGoalAgent(TestTodAgentsAndTeachersBase): + def _test_roundDataCorrect_helper(self, config): + max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] + max_episodes = config[test_agents.TEST_NUM_EPISODES_OPT_KEY] + values = self.dump_single_utt_per_episode_agent_text( + test_agents.GoalAgent, config, {} + ) + + goal_text = [ + "GOAL: " + + tod_core.SerializationHelpers.list_of_maps_to_str( + test_agents.make_goal_calls_machine(max_rounds) + ) + for _ in range(max_episodes) + ] + + self.assertEqual(values, goal_text) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TestApiSchemaAgent(TestTodAgentsAndTeachersBase): + def _test_roundDataCorrect_helper(self, config): + max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] + max_episodes = config[test_agents.TEST_NUM_EPISODES_OPT_KEY] + values = self.dump_single_utt_per_episode_agent_text( + test_agents.ApiSchemaAgent, config, {} + ) + + apis_texts = [ + "APIS: " + + tod_core.SerializationHelpers.list_of_maps_to_str( + test_agents.make_api_schemas_machine(max_rounds) + ) + for _ in range(max_episodes) + ] + + self.assertEqual(values, apis_texts) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TestSingleGoalAgent(TestTodAgentsAndTeachersBase): + def _test_roundDataCorrect_helper(self, config): + max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] + max_episodes = config[test_agents.TEST_NUM_EPISODES_OPT_KEY] + values = self.dump_single_utt_per_episode_agent_text( + test_agents.SingleGoalAgent, config, {} + ) + + goal_text = [] + for _ in range(max_episodes): + goals = test_agents.make_goal_calls_machine(max_rounds) + for x in goals: + goal_text.append( + "GOAL: " + tod_core.SerializationHelpers.list_of_maps_to_str([x]) + ) + + self.assertEqual(values, goal_text) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TestSingleApiSchemaAgent(TestTodAgentsAndTeachersBase): + def _test_roundDataCorrect_helper(self, config): + max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] + max_episodes = config[test_agents.TEST_NUM_EPISODES_OPT_KEY] + values = self.dump_single_utt_per_episode_agent_text( + test_agents.SingleApiSchemaAgent, config, {} + ) + + apis_text = [] + for _ in range(max_episodes): + apis = test_agents.make_api_schemas_machine(max_rounds) + for x in apis: + apis_text.append( + "APIS: " + tod_core.SerializationHelpers.list_of_maps_to_str([x]) + ) + self.assertEqual(values, apis_text) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TestSingleGoalWithSingleApiSchemaAgent(TestTodAgentsAndTeachersBase): + """ + Make sure the SingleGoal + SingleApiSchema agents correspond. + """ + + def _test_roundDataCorrect_helper(self, config): + goals = self.dump_single_utt_per_episode_agent_text( + test_agents.SingleGoalAgent, config, {} + ) + apis = self.dump_single_utt_per_episode_agent_text( + test_agents.SingleApiSchemaAgent, config, {} + ) + + for i in range(len(goals)): + goal = tod_core.SerializationHelpers.str_to_goals(goals[i][len("GOALS:") :]) + api = tod_core.SerializationHelpers.str_to_api_schemas( + apis[i][len("APIS:") :] + ) + self.assertEqual( + goal[0].get("api_name", None), api[0].get("api_name", None) + ) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TestLowShot(TestTodAgentsAndTeachersBase): + FEW_SHOT_SAMPLES = [0, 1, 5, 15] + PERCENTAGES = [0, 0.1, 0.3, 0.5] + + def setup_agent_or_teacher(self, class_type, round_opt, opt): + full_opts = {**round_opt, **opt} + full_opts["datatype"] = "DUMMY" + full_opts["datafile"] = "DUMMY" + return class_type(full_opts) + + def test_few_shot_lengths_correct(self): + def helper(n_shot): + values = self.dump_teacher_text( + test_agents.SystemTeacher, + test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, + { + "episodes_randomization_seed": 0, + "n_shot": n_shot, + }, + ) + self.assertEqual(len(values), n_shot) + + for i in self.FEW_SHOT_SAMPLES: + helper(i) + + def _test_subsets(self, data_dumps): + for i in range(len(data_dumps) - 1): + small = data_dumps[i] + larger = data_dumps[i + 1] + for i, episode in enumerate(small): + self.assertEqual(episode, larger[i]) + + def test_few_shot_subset(self): + def helper(n_shot, seed): + return self.dump_teacher_text( + test_agents.SystemTeacher, + test_agents.EPISODE_SETUP__MULTI_EPISODE, + { + "episodes_randomization_seed": seed, + "n_shot": n_shot, + }, + ) + + data_dumps_seed_zero = [helper(i, 0) for i in self.FEW_SHOT_SAMPLES] + self._test_subsets(data_dumps_seed_zero) + data_dumps_seed_three = [helper(i, 3) for i in self.FEW_SHOT_SAMPLES] + self._test_subsets(data_dumps_seed_three) + self.assertNotEqual(data_dumps_seed_zero[-1], data_dumps_seed_three[-1]) + + def test_percent_shot_lengths_correct(self): + def helper(percent_shot, correct): + values = self.dump_teacher_text( + test_agents.SystemTeacher, + test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, # 35 episodes + { + "episodes_randomization_seed": 0, + "percent_shot": percent_shot, + }, + ) + self.assertEqual(len(values), correct) + + helper(0, 0) + helper(0.1, 3) + helper(0.3, 10) + + def test_percent_shot_subset(self): + def helper(percent_shot, seed): + return self.dump_teacher_text( + test_agents.SystemTeacher, + test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, # 35 episodes + { + "episodes_randomization_seed": seed, + "percent_shot": percent_shot, + }, + ) + + data_dumps_seed_zero = [helper(i, 0) for i in self.PERCENTAGES] + self._test_subsets(data_dumps_seed_zero) + data_dumps_seed_three = [helper(i, 3) for i in self.PERCENTAGES] + self._test_subsets(data_dumps_seed_three) + + def test_correct_throw_when_both_shots_defined(self): + self.assertRaises( + RuntimeError, + self.dump_teacher_text, + test_agents.SystemTeacher, + test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, # 35 episodes + {"episodes_randomization_seed": 0, "percent_shot": 0.3, "n_shot": 3}, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/tod/test_tod_teacher_metrics.py b/tests/tod/test_tod_teacher_metrics.py new file mode 100644 index 00000000000..aec4aba40f6 --- /dev/null +++ b/tests/tod/test_tod_teacher_metrics.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from math import isnan + +from parlai.core.metrics import AverageMetric +from parlai.core.tod.teacher_metrics import SlotF1Metric, SlotMetrics + + +class TestSlotF1Metric(unittest.TestCase): + """ + Test SlotF1Metric. + """ + + def test_slot_f1_metric_inputs(self): + slots_p_r_and_f1 = [ + (None, None, float("nan")), + (None, AverageMetric(0.0), float("nan")), + (AverageMetric(0.0), AverageMetric(0.0), float("nan")), + (AverageMetric(1), AverageMetric(1), 1.0), + (AverageMetric(1), AverageMetric(0), 0.0), + (AverageMetric(0.25), AverageMetric(0.75), 0.375), + ] + for slot_p, slot_r, slot_f1 in slots_p_r_and_f1: + actual_slot_f1 = SlotF1Metric(slot_p=slot_p, slot_r=slot_r).value() + if isnan(slot_f1): + self.assertTrue(isnan(actual_slot_f1)) + else: + self.assertEqual(slot_f1, actual_slot_f1) + + def test_slot_f1_metric_addition(self): + a = SlotF1Metric(slot_p=1) + b = SlotF1Metric(slot_r=0) + c = SlotF1Metric(slot_p=AverageMetric(numer=2, denom=3), slot_r=1) + d = a + b + c + # Slot P should be 3/4 = 0.75; slot R should be 1/2 = 0.5 + self.assertEqual(0.6, d.value()) + + +empty_slots = {} +basic_slots = {"a": "a_val", "b": "b_val", "c": "c_val"} +partial_slots = {"a": "a_val", "other": "other_val"} + + +class TestSlotMetrics(unittest.TestCase): + def test_base_slot_metrics(self): + cases = [ + (empty_slots, empty_slots, {"jga": 1}), + ( + basic_slots, + basic_slots, + {"jga": 1, "slot_p": 1, "slot_r": 1, "slot_f1": 1}, + ), + ( + basic_slots, + partial_slots, + {"jga": 0, "slot_p": 0.5, "slot_r": float(1.0 / 3), "slot_f1": 0.4}, + ), + ] + for teacher, predicted, result in cases: + metric = SlotMetrics( + teacher_slots=teacher, + predicted_slots=predicted, + ) + for key in result: + self.assertEqual(result[key], metric.report()[key]) + + +if __name__ == "__main__": + unittest.main() From 3bf655fa156bb13928a4d72b7f612127cfeefcb2 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 09:19:16 -0800 Subject: [PATCH 03/23] [TOD] Tod json structure to teacher task As noted in the README, this agent takes data generated from `tod_world_script.py` and dumps it out to a teacher. (Note that I tried setting up a regression test for this teacher, but I ran into issues getting it to save the output directory to not be something that included my local homedir name in it..) --- parlai/tasks/tod_json/README.md | 11 ++ parlai/tasks/tod_json/__init__.py | 5 + parlai/tasks/tod_json/agents.py | 188 +++++++++++++++++++++++ parlai/tasks/tod_json/example_data.jsonl | 15 ++ 4 files changed, 219 insertions(+) create mode 100644 parlai/tasks/tod_json/README.md create mode 100644 parlai/tasks/tod_json/__init__.py create mode 100644 parlai/tasks/tod_json/agents.py create mode 100644 parlai/tasks/tod_json/example_data.jsonl diff --git a/parlai/tasks/tod_json/README.md b/parlai/tasks/tod_json/README.md new file mode 100644 index 00000000000..f96d41e7448 --- /dev/null +++ b/parlai/tasks/tod_json/README.md @@ -0,0 +1,11 @@ +# TOD Json Task Agent + +Takes a .jsonl conversation output from model-model chats from `tod_world_script.py`, puts it into the TOD intermediate conversations format so we can use it in a variety of different teachers. + +For example, to see the display of the data: +``` +parlai dd -t tod_json:SystemTeacher --jsonfile-datapath example_data.jsonl +parlai dd -t tod_json:UserSimulatorTeacher --jsonfile-datapath example_data.jsonl +``` + +See the file `example_data.json` in this directory for the format. diff --git a/parlai/tasks/tod_json/__init__.py b/parlai/tasks/tod_json/__init__.py new file mode 100644 index 00000000000..240697e3247 --- /dev/null +++ b/parlai/tasks/tod_json/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/parlai/tasks/tod_json/agents.py b/parlai/tasks/tod_json/agents.py new file mode 100644 index 00000000000..55e3514de6c --- /dev/null +++ b/parlai/tasks/tod_json/agents.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# +# This task simply loads the specified file: useful for quick tests without +# setting up a new task. + +from typing import Optional +from parlai.core.params import ParlaiParser +from parlai.core.opt import Opt +from parlai.utils.data import DatatypeHelper + +import parlai.core.tod.tod_agents as tod_agents +import parlai.core.tod.tod_core as tod + +import json +import os + +PREFIXES = [ + tod.STANDARD_USER_UTTERANCE, + tod.STANDARD_CALL, + tod.STANDARD_RESP, + tod.STANDARD_SYSTEM_UTTERANCE, +] + +PREFIXES_PREEMPT = [ + tod.STANDARD_API_SCHEMAS, + tod.STANDARD_API_SCHEMAS, + tod.STANDARD_API_SCHEMAS, + tod.STANDARD_GOAL, +] + + +class JsonTodParser(tod_agents.TodStructuredDataParser): + """ + This module provides access to data in the TOD conversations format. + + See core/tod.py for more info about the format. + """ + + @classmethod + def add_cmdline_args( + cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None + ) -> ParlaiParser: + super().add_cmdline_args(parser, partial_opt) + agent = parser.add_argument_group("Tod Json Task Arguments") + agent.add_argument( + "-jfdp", + "--jsonfile-datapath", + type=str, + help="Data file. (Assumed to be in .jsonl)", + ) + agent.add_argument( + "-tmdp", + "--tod-metrics-datapath", + type=str, + default=None, + help="Filter which examples to use from a report including per-turn tod metrics", + ) + agent.add_argument( + "-f-agh", + "--filter-all-goals-hit", + type=bool, + default=False, + help="Filter episodes by all-goals-hit metric being 1. Assumes `tod-metrics-datapath` is set.", + ) + agent.add_argument( + "--split-to-folds", + type=bool, + default=True, + help="Use all data or split into 8:1:1 fold", + ) + agent.add_argument( + "--split-folds-seed", + type=int, + default=42, + help="Seed for the fold split", + ) + return parser + + def __init__(self, opt, shared=None): + if not opt.get("jsonfile_datapath"): + raise RuntimeError("jsonfile_datapath not specified") + if not hasattr(self, "opt"): + self.opt = opt + self.opt["datafile"] = opt["jsonfile_datapath"] + self.fold = self.opt["datatype"] # don't care + # Truncate datafile to just the immediate enclosing folder name and file name + dirname, basename = os.path.split(self.opt["datafile"]) + self.id = os.path.join(os.path.split(dirname)[1], basename) + super().__init__(opt, shared) + + def _process_line(self, line): + blob = json.loads(line) + if "dialog" not in blob or len(blob["dialog"]) < 1: + return None + rounds = [] + for raw_round in blob["dialog"][1:]: + if "prefix_stripped_text" not in raw_round[0]: + for i in range(len(raw_round)): + raw_round[i]["prefix_stripped_text"] = raw_round[i].get( + "text", PREFIXES[i] + )[len(PREFIXES[i]) :] + if len(raw_round) != 4: + if raw_round[0]["prefix_stripped_text"] != tod.STANDARD_DONE: + return None # misformatted convo, don't learn this. + break # TodStructuredEpisode will add in [DONE] + r = tod.TodStructuredRound( + user_utt=raw_round[0]["prefix_stripped_text"], + api_call_machine=tod.SerializationHelpers.str_to_api_dict( + raw_round[1]["prefix_stripped_text"] + ), + api_resp_machine=tod.SerializationHelpers.str_to_api_dict( + raw_round[2]["prefix_stripped_text"] + ), + sys_utt=raw_round[3]["prefix_stripped_text"], + ) + rounds.append(r) + preempt_round = blob["dialog"][0] + if len(preempt_round) != 4: + return None + for i in range(len(preempt_round)): + if "prefix_stripped_text" not in preempt_round[i]: + preempt_round[i]["prefix_stripped_text"] = preempt_round[i].get( + "text", PREFIXES_PREEMPT[i] + )[len(PREFIXES_PREEMPT[i]) :] + + episode = tod.TodStructuredEpisode( + domain=preempt_round[0].get("domain", ""), + api_schemas_machine=tod.SerializationHelpers.str_to_api_schemas( + preempt_round[0].get("prefix_stripped_text", "") + ), + goal_calls_machine=tod.SerializationHelpers.str_to_goals( + preempt_round[3].get("prefix_stripped_text") + ), + rounds=rounds, + ) + return episode + + def setup_episodes(self, fold): + result = [] + if self.opt["tod_metrics_datapath"] is not None: + with open(self.opt["tod_metrics_datapath"]) as f: + report_data = json.load(f) + tod_metrics = report_data["report"]["tod_metrics"] + lines_to_process = [] + with open(self.opt["datafile"], "r") as f: + result = [] + for i, line in enumerate(f.readlines()): + if ( + self.opt["filter_all_goals_hit"] + and tod_metrics[i]["all_goals_hit"] < 0.5 + ): + continue + if line: + lines_to_process.append(line) + + if self.opt["split_to_folds"]: + lines_to_process = DatatypeHelper.split_data_by_fold( + fold, lines_to_process, 0.8, 0.1, 0.1, self.opt["split_folds_seed"] + ) + + for line in lines_to_process: + processed = self._process_line(line) + if processed is not None: + result.append(processed) + return result + + def get_id_task_prefix(self): + return ( + "TodJson_#" + + os.path.basename(self.opt["jsonfile_datapath"]).split(".")[0] + + "#" + ) + + +class SystemTeacher(JsonTodParser, tod_agents.TodSystemTeacher): + pass + + +class DefaultTeacher(SystemTeacher): + pass + + +class UserSimulatorTeacher(JsonTodParser, tod_agents.TodUserSimulatorTeacher): + pass diff --git a/parlai/tasks/tod_json/example_data.jsonl b/parlai/tasks/tod_json/example_data.jsonl new file mode 100644 index 00000000000..8efd5473d82 --- /dev/null +++ b/parlai/tasks/tod_json/example_data.jsonl @@ -0,0 +1,15 @@ +{"dialog": [[{"text": "APIS: api_name = GetTrainTickets ; optArg = class ; reqArg = date_of_journey, from, journey_start_time, number_of_adults, to, trip_protection | api_name = FindTrains ; optArg = class, number_of_adults ; reqArg = date_of_journey, from, to | api_name = FindEvents ; optArg = date ; reqArg = city, event_type | api_name = BuyEventTickets ; reqArg = city, date, event_name, number_of_tickets | api_name = ShareLocation ; reqArg = contact_name, location", "id": "GoogleSGD_ApiDescription", "domain": "Events, Messaging, Trains", "episode_done": false, "prefix_stripped_text": "api_name = GetTrainTickets ; optArg = class ; reqArg = date_of_journey, from, journey_start_time, number_of_adults, to, trip_protection | api_name = FindTrains ; optArg = class, number_of_adults ; reqArg = date_of_journey, from, to | api_name = FindEvents ; optArg = date ; reqArg = city, event_type | api_name = BuyEventTickets ; reqArg = city, date, event_name, number_of_tickets | api_name = ShareLocation ; reqArg = contact_name, location"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APIS: ", "beam_texts": [["APIS: ", 0.0]], "metrics": {"clen": 91.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "APIS: "}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"text": "GOAL: api_name = ShareLocation ; contact_name = Helen ; location = August Hall", "id": "GoogleSGD_Goal", "domain": "Events, Messaging, Trains", "episode_done": false, "prefix_stripped_text": "api_name = ShareLocation ; contact_name = Helen ; location = August Hall"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need to find a contact number for a friend.", "beam_texts": [["USER: I need to find a contact number for a friend.", -4.2508697509765625]], "metrics": {"clen": 24.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need to find a contact number for a friend."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.000830456439871341]], "metrics": {"clen": 301.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Where would you like to find the contact number?", "beam_texts": [["SYSTEM: Where would you like to find the contact number?", -3.0284223556518555]], "metrics": {"clen": 314.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Where would you like to find the contact number?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need to find a contact number for Helen in August Hall.", "beam_texts": [["USER: I need to find a contact number for Helen in August Hall.", -2.9425835609436035]], "metrics": {"clen": 62.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need to find a contact number for Helen in August Hall."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: api_name = FindContact ; contact_name = Helen ; location = August Hall", "beam_texts": [["APICALL: api_name = FindContact ; contact_name = Helen ; location = August Hall", -0.3795860707759857]], "metrics": {"clen": 334.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "api_name = FindContact ; contact_name = Helen ; location = August Hall"}, {"text": "APIRESP: ", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I was unable to find a contact number for Helen. Is there anything else I can help you with?", "beam_texts": [["SYSTEM: I was unable to find a contact number for Helen. Is there anything else I can help you with?", -2.7589914798736572]], "metrics": {"clen": 347.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I was unable to find a contact number for Helen. Is there anything else I can help you with?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, thank you.", "beam_texts": [["USER: No, thank you.", -1.1491308212280273]], "metrics": {"clen": 97.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, thank you."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0011205199407413602]], "metrics": {"clen": 373.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Have a great day.", "beam_texts": [["SYSTEM: Have a great day.", -0.7793283462524414]], "metrics": {"clen": 386.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Have a great day."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: [DONE]", "beam_texts": [["USER: [DONE]", -6.41131464362843e-06]], "metrics": {"clen": 116.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "[DONE]"}]], "context": [], "metadata_path": "/checkpoint/mpchen/projects/taskoriented/user_generator/google_sgd/mm8_gsgd_bartR3f_protoTestTrainTue_May_11/withApi_conversations.metadata"} +{"dialog": [[{"text": "APIS: api_name = BookHouse ; reqArg = check_in_date, check_out_date, number_of_adults, where_to | api_name = SearchHouse ; optArg = has_laundry_service, number_of_adults, rating ; reqArg = where_to | api_name = SearchOnewayFlight ; optArg = airlines, number_of_tickets, seating_class ; reqArg = departure_date, destination_airport, origin_airport | api_name = SearchRoundtripFlights ; optArg = airlines, number_of_tickets, seating_class ; reqArg = departure_date, destination_airport, origin_airport, return_date | api_name = FindAttractions ; optArg = category, free_entry, good_for_kids ; reqArg = location", "id": "GoogleSGD_ApiDescription", "domain": "Flights, Hotels, Travel", "episode_done": false, "prefix_stripped_text": "api_name = BookHouse ; reqArg = check_in_date, check_out_date, number_of_adults, where_to | api_name = SearchHouse ; optArg = has_laundry_service, number_of_adults, rating ; reqArg = where_to | api_name = SearchOnewayFlight ; optArg = airlines, number_of_tickets, seating_class ; reqArg = departure_date, destination_airport, origin_airport | api_name = SearchRoundtripFlights ; optArg = airlines, number_of_tickets, seating_class ; reqArg = departure_date, destination_airport, origin_airport, return_date | api_name = FindAttractions ; optArg = category, free_entry, good_for_kids ; reqArg = location"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APIS: ", "beam_texts": [["APIS: ", -2.1568732222476683e-07]], "metrics": {"clen": 143.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "APIS: "}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"text": "GOAL: api_name = SearchRoundtripFlights ; departure_date = 2019-03-02 ; destination_airport = Toronto ; origin_airport = Atlanta ; return_date = 2019-03-11", "id": "GoogleSGD_Goal", "domain": "Flights, Hotels, Travel", "episode_done": false, "prefix_stripped_text": "api_name = SearchRoundtripFlights ; departure_date = 2019-03-02 ; destination_airport = Toronto ; origin_airport = Atlanta ; return_date = 2019-03-11"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need to find a flight.", "beam_texts": [["USER: I need to find a flight.", -3.2091355323791504]], "metrics": {"clen": 32.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need to find a flight."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0005915741785429418]], "metrics": {"clen": 305.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Where are you going?", "beam_texts": [["SYSTEM: Where are you going?", -2.660393476486206]], "metrics": {"clen": 318.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Where are you going?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I'm going to Toronto, Canada.", "beam_texts": [["USER: I'm going to Toronto, Canada.", -1.6486107110977173]], "metrics": {"clen": 115.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I'm going to Toronto, Canada."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0026782331988215446]], "metrics": {"clen": 348.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: When are you leaving?", "beam_texts": [["SYSTEM: When are you leaving?", -2.153712749481201]], "metrics": {"clen": 377.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "When are you leaving?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I'm leaving on the 2nd of this month.", "beam_texts": [["USER: I'm leaving on the 2nd of this month.", -2.099220037460327]], "metrics": {"clen": 134.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I'm leaving on the 2nd of this month."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0051135048270225525]], "metrics": {"clen": 410.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Where are you leaving from?", "beam_texts": [["SYSTEM: Where are you leaving from?", -1.2600269317626953]], "metrics": {"clen": 423.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Where are you leaving from?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I'm leaving from Atlanta, GA.", "beam_texts": [["USER: I'm leaving from Atlanta, GA.", -1.2180325984954834]], "metrics": {"clen": 128.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I'm leaving from Atlanta, GA."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.260057270526886]], "metrics": {"clen": 490.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: When are you returning?", "beam_texts": [["SYSTEM: When are you returning?", -0.7014266848564148]], "metrics": {"clen": 516.0, "ctrunc": 1.0, "ctrunclen": 4.0}, "prefix_stripped_text": "When are you returning?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to return on the 11th of this month.", "beam_texts": [["USER: I want to return on the 11th of this month.", -1.6878958940505981]], "metrics": {"clen": 175.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to return on the 11th of this month."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: api_name = SearchRoundtripFlights ; departure_date = 2019-03-02 ; destination_airport = Toronto ; origin_airport = Atlanta ; return_date = 2019-03-11", "beam_texts": [["APICALL: api_name = SearchRoundtripFlights ; departure_date = 2019-03-02 ; destination_airport = Toronto ; origin_airport = Atlanta ; return_date = 2019-03-11", -0.006989812944084406]], "metrics": {"clen": 550.0, "ctrunc": 1.0, "ctrunclen": 38.0}, "prefix_stripped_text": "api_name = SearchRoundtripFlights ; departure_date = 2019-03-02 ; destination_airport = Toronto ; origin_airport = Atlanta ; return_date = 2019-03-11"}, {"text": "APIRESP: airlines = American Airlines ; departure_date = 2019-03-02 ; destination_airport = Toronto ; inbound_arrival_time = 11:17 ; inbound_departure_time = 04:15 ; is_nonstop = False ; number_of_tickets = 1 ; origin_airport = Atlanta ; outbound_arrival_time = 12:25 ; outbound_departure_time = 05:10 ; price = 332 ; return_date = 2019-03-11 ; seating_class = Economy", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": "airlines = American Airlines ; departure_date = 2019-03-02 ; destination_airport = Toronto ; inbound_arrival_time = 11:17 ; inbound_departure_time = 04:15 ; is_nonstop = False ; number_of_tickets = 1 ; origin_airport = Atlanta ; outbound_arrival_time = 12:25 ; outbound_departure_time = 05:10 ; price = 332 ; return_date = 2019-03-11 ; seating_class = Economy"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I found 4 flights. There is an American Airlines flight that leaves at 5:10 am and returns at 4:15 am. The cost is $332.", "beam_texts": [["SYSTEM: I found 4 flights. There is an American Airlines flight that leaves at 5:10 am and returns at 4:15 am. The cost is $332.", -2.6480491161346436]], "metrics": {"clen": 507.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I found 4 flights. There is an American Airlines flight that leaves at 5:10 am and returns at 4:15 am. The cost is $332."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: That sounds good.", "beam_texts": [["USER: That sounds good.", -1.531357765197754]], "metrics": {"clen": 228.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "That sounds good."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.00013564960681833327]], "metrics": {"clen": 419.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Do you want to buy tickets?", "beam_texts": [["SYSTEM: Do you want to buy tickets?", -1.0238951444625854]], "metrics": {"clen": 432.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Do you want to buy tickets?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, not now.", "beam_texts": [["USER: No, not now.", -1.9948810338974]], "metrics": {"clen": 183.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, not now."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.00037169820279814303]], "metrics": {"clen": 593.0, "ctrunc": 1.0, "ctrunclen": 81.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Is there anything else I can help you with?", "beam_texts": [["SYSTEM: Is there anything else I can help you with?", -0.9852166771888733]], "metrics": {"clen": 645.0, "ctrunc": 1.0, "ctrunclen": 133.0}, "prefix_stripped_text": "Is there anything else I can help you with?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, thank you.", "beam_texts": [["USER: No, thank you.", -0.6123467683792114]], "metrics": {"clen": 268.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, thank you."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -2.995951217599213e-05]], "metrics": {"clen": 676.0, "ctrunc": 1.0, "ctrunclen": 164.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Have a great day.", "beam_texts": [["SYSTEM: Have a great day.", -0.6601892709732056]], "metrics": {"clen": 689.0, "ctrunc": 1.0, "ctrunclen": 177.0}, "prefix_stripped_text": "Have a great day."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: [DONE]", "beam_texts": [["USER: [DONE]", -2.747716791873245e-07]], "metrics": {"clen": 368.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "[DONE]"}]], "context": [], "metadata_path": "/checkpoint/mpchen/projects/taskoriented/user_generator/google_sgd/mm8_gsgd_bartR3f_protoTestTrainTue_May_11/withApi_conversations.metadata"} +{"dialog": [[{"text": "APIS: api_name = SearchOnewayFlight ; optArg = airlines, number_of_tickets, seating_class ; reqArg = departure_date, destination_airport, origin_airport | api_name = SearchRoundtripFlights ; optArg = airlines, number_of_tickets, seating_class ; reqArg = departure_date, destination_airport, origin_airport, return_date", "id": "GoogleSGD_ApiDescription", "domain": "Flights", "episode_done": false, "prefix_stripped_text": "api_name = SearchOnewayFlight ; optArg = airlines, number_of_tickets, seating_class ; reqArg = departure_date, destination_airport, origin_airport | api_name = SearchRoundtripFlights ; optArg = airlines, number_of_tickets, seating_class ; reqArg = departure_date, destination_airport, origin_airport, return_date"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APIS: ", "beam_texts": [["APIS: ", -1.0784366821781077e-07]], "metrics": {"clen": 115.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "APIS: "}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"text": "GOAL: airlines = American Airlines ; api_name = SearchOnewayFlight ; departure_date = 2019-03-06 ; destination_airport = Los Angeles ; origin_airport = Seattle", "id": "GoogleSGD_Goal", "domain": "Flights", "episode_done": false, "prefix_stripped_text": "airlines = American Airlines ; api_name = SearchOnewayFlight ; departure_date = 2019-03-06 ; destination_airport = Los Angeles ; origin_airport = Seattle"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need a one way flight.", "beam_texts": [["USER: I need a one way flight.", -2.5476834774017334]], "metrics": {"clen": 26.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need a one way flight."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0014316064771264791]], "metrics": {"clen": 247.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Where are you going?", "beam_texts": [["SYSTEM: Where are you going?", -2.7428438663482666]], "metrics": {"clen": 260.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Where are you going?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I am going to LAX.", "beam_texts": [["USER: I am going to LAX.", -2.408816337585449]], "metrics": {"clen": 73.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I am going to LAX."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.002551180776208639]], "metrics": {"clen": 284.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: When are you leaving?", "beam_texts": [["SYSTEM: When are you leaving?", -1.995316982269287]], "metrics": {"clen": 332.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "When are you leaving?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I am leaving on the 6th of this month.", "beam_texts": [["USER: I am leaving on the 6th of this month.", -2.0818636417388916]], "metrics": {"clen": 100.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I am leaving on the 6th of this month."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.09947514533996582]], "metrics": {"clen": 360.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Where are you leaving from?", "beam_texts": [["SYSTEM: Where are you leaving from?", -0.6350602507591248]], "metrics": {"clen": 416.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Where are you leaving from?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I am leaving from Seattle, WA. I prefer American Airlines.", "beam_texts": [["USER: I am leaving from Seattle, WA. I prefer American Airlines.", -1.8830840587615967]], "metrics": {"clen": 100.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I am leaving from Seattle, WA. I prefer American Airlines."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: airlines = American Airlines ; api_name = SearchOnewayFlight ; departure_date = 2019-03-06 ; destination_airport = Los Angeles ; origin_airport = Seattle", "beam_texts": [["APICALL: airlines = American Airlines ; api_name = SearchOnewayFlight ; departure_date = 2019-03-06 ; destination_airport = Los Angeles ; origin_airport = Seattle", -0.004719942342489958]], "metrics": {"clen": 444.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "airlines = American Airlines ; api_name = SearchOnewayFlight ; departure_date = 2019-03-06 ; destination_airport = Los Angeles ; origin_airport = Seattle"}, {"text": "APIRESP: ", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I couldn't find any flights matching your criteria. Can I help with something else?", "beam_texts": [["SYSTEM: I couldn't find any flights matching your criteria. Can I help with something else?", -3.5894200801849365]], "metrics": {"clen": 406.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I couldn't find any flights matching your criteria. Can I help with something else?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, thank you.", "beam_texts": [["USER: No, thank you.", -0.8101848363876343]], "metrics": {"clen": 177.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, thank you."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.00011459175584604964]], "metrics": {"clen": 425.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Have a great day.", "beam_texts": [["SYSTEM: Have a great day.", -0.8406725525856018]], "metrics": {"clen": 500.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Have a great day."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: [DONE]", "beam_texts": [["USER: [DONE]", -1.0990861483151093e-06]], "metrics": {"clen": 244.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "[DONE]"}]], "context": [], "metadata_path": "/checkpoint/mpchen/projects/taskoriented/user_generator/google_sgd/mm8_gsgd_bartR3f_protoTestTrainTue_May_11/withApi_conversations.metadata"} +{"dialog": [[{"text": "APIS: api_name = GetTrainTickets ; optArg = class ; reqArg = date_of_journey, from, journey_start_time, number_of_adults, to, trip_protection | api_name = FindTrains ; optArg = class, number_of_adults ; reqArg = date_of_journey, from, to | api_name = FindEvents ; optArg = date ; reqArg = city, event_type | api_name = BuyEventTickets ; reqArg = city, date, event_name, number_of_tickets | api_name = ShareLocation ; reqArg = contact_name, location", "id": "GoogleSGD_ApiDescription", "domain": "Events, Messaging, Trains", "episode_done": false, "prefix_stripped_text": "api_name = GetTrainTickets ; optArg = class ; reqArg = date_of_journey, from, journey_start_time, number_of_adults, to, trip_protection | api_name = FindTrains ; optArg = class, number_of_adults ; reqArg = date_of_journey, from, to | api_name = FindEvents ; optArg = date ; reqArg = city, event_type | api_name = BuyEventTickets ; reqArg = city, date, event_name, number_of_tickets | api_name = ShareLocation ; reqArg = contact_name, location"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APIS: ", "beam_texts": [["APIS: ", 0.0]], "metrics": {"clen": 143.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "APIS: "}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"text": "GOAL: api_name = FindTrains ; date_of_journey = 2019-03-14 ; from = San Diego ; to = Sacramento", "id": "GoogleSGD_Goal", "domain": "Events, Messaging, Trains", "episode_done": false, "prefix_stripped_text": "api_name = FindTrains ; date_of_journey = 2019-03-14 ; from = San Diego ; to = Sacramento"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need to find a bus.", "beam_texts": [["USER: I need to find a bus.", -2.7568917274475098]], "metrics": {"clen": 41.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need to find a bus."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.00041706167394295335]], "metrics": {"clen": 197.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Where are you going?", "beam_texts": [["SYSTEM: Where are you going?", -2.507762908935547]], "metrics": {"clen": 210.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Where are you going?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I'm going to Sacramento, CA.", "beam_texts": [["USER: I'm going to Sacramento, CA.", -1.6594734191894531]], "metrics": {"clen": 62.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I'm going to Sacramento, CA."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.005608933977782726]], "metrics": {"clen": 229.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: When do you want to leave?", "beam_texts": [["SYSTEM: When do you want to leave?", -1.8860359191894531]], "metrics": {"clen": 209.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "When do you want to leave?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to leave on the 14th of this month.", "beam_texts": [["USER: I want to leave on the 14th of this month.", -1.5564500093460083]], "metrics": {"clen": 173.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to leave on the 14th of this month."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0036852918565273285]], "metrics": {"clen": 237.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Where are you leaving from?", "beam_texts": [["SYSTEM: Where are you leaving from?", -0.7195356488227844]], "metrics": {"clen": 250.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Where are you leaving from?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I'm leaving from San Diego.", "beam_texts": [["USER: I'm leaving from San Diego.", -0.939035177230835]], "metrics": {"clen": 116.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I'm leaving from San Diego."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: api_name = FindTrains ; date_of_journey = 2019-03-14 ; from = San Diego ; to = Sacramento", "beam_texts": [["APICALL: api_name = FindTrains ; date_of_journey = 2019-03-14 ; from = San Diego ; to = Sacramento", -0.007453737314790487]], "metrics": {"clen": 274.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "api_name = FindTrains ; date_of_journey = 2019-03-14 ; from = San Diego ; to = Sacramento"}, {"text": "APIRESP: class = Value ; date_of_journey = 2019-03-14 ; from = San Diego ; from_station = Santa Fe Depot ; journey_start_time = 06:10 ; number_of_adults = 1 ; to = Sacramento ; to_station = Sacramento Valley Station ; total = 145 ; trip_protection = True", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": "class = Value ; date_of_journey = 2019-03-14 ; from = San Diego ; from_station = Santa Fe Depot ; journey_start_time = 06:10 ; number_of_adults = 1 ; to = Sacramento ; to_station = Sacramento Valley Station ; total = 145 ; trip_protection = True"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I found 10 buses. There's a 6:10 am bus for $145. It's a value ticket.", "beam_texts": [["SYSTEM: I found 10 buses. There's a 6:10 am bus for $145. It's a value ticket.", -3.985891103744507]], "metrics": {"clen": 354.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I found 10 buses. There's a 6:10 am bus for $145. It's a value ticket."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: That sounds good.", "beam_texts": [["USER: That sounds good.", -1.7473149299621582]], "metrics": {"clen": 135.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "That sounds good."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0016271225176751614]], "metrics": {"clen": 383.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Do you want to buy tickets?", "beam_texts": [["SYSTEM: Do you want to buy tickets?", -0.8871648907661438]], "metrics": {"clen": 563.0, "ctrunc": 1.0, "ctrunclen": 51.0}, "prefix_stripped_text": "Do you want to buy tickets?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, not now.", "beam_texts": [["USER: No, not now.", -1.4214577674865723]], "metrics": {"clen": 209.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, not now."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0008082308922894299]], "metrics": {"clen": 550.0, "ctrunc": 1.0, "ctrunclen": 38.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Anything else I can help with?", "beam_texts": [["SYSTEM: Anything else I can help with?", -1.0824320316314697]], "metrics": {"clen": 563.0, "ctrunc": 1.0, "ctrunclen": 51.0}, "prefix_stripped_text": "Anything else I can help with?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, thank you.", "beam_texts": [["USER: No, thank you.", -0.8818005323410034]], "metrics": {"clen": 243.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, thank you."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0006076796562410891]], "metrics": {"clen": 519.0, "ctrunc": 1.0, "ctrunclen": 7.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Have a nice day.", "beam_texts": [["SYSTEM: Have a nice day.", -0.5481600761413574]], "metrics": {"clen": 532.0, "ctrunc": 1.0, "ctrunclen": 20.0}, "prefix_stripped_text": "Have a nice day."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: [DONE]", "beam_texts": [["USER: [DONE]", -1.8318112893211946e-07]], "metrics": {"clen": 350.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "[DONE]"}]], "context": [], "metadata_path": "/checkpoint/mpchen/projects/taskoriented/user_generator/google_sgd/mm8_gsgd_bartR3f_protoTestTrainTue_May_11/withApi_conversations.metadata"} +{"dialog": [[{"text": "APIS: api_name = SearchOnewayFlight ; optArg = airlines, number_of_tickets, seating_class ; reqArg = departure_date, destination_airport, origin_airport | api_name = SearchRoundtripFlights ; optArg = airlines, number_of_tickets, seating_class ; reqArg = departure_date, destination_airport, origin_airport, return_date", "id": "GoogleSGD_ApiDescription", "domain": "Flights", "episode_done": false, "prefix_stripped_text": "api_name = SearchOnewayFlight ; optArg = airlines, number_of_tickets, seating_class ; reqArg = departure_date, destination_airport, origin_airport | api_name = SearchRoundtripFlights ; optArg = airlines, number_of_tickets, seating_class ; reqArg = departure_date, destination_airport, origin_airport, return_date"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APIS: ", "beam_texts": [["APIS: ", -1.0784366821781077e-07]], "metrics": {"clen": 142.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "APIS: "}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"text": "GOAL: api_name = SearchRoundtripFlights ; departure_date = 2019-03-04 ; destination_airport = San Francisco ; origin_airport = San Diego ; return_date = 2019-03-09", "id": "GoogleSGD_Goal", "domain": "Flights", "episode_done": false, "prefix_stripped_text": "api_name = SearchRoundtripFlights ; departure_date = 2019-03-04 ; destination_airport = San Francisco ; origin_airport = San Diego ; return_date = 2019-03-09"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need to find a flight.", "beam_texts": [["USER: I need to find a flight.", -3.215803623199463]], "metrics": {"clen": 41.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need to find a flight."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0010359908919781446]], "metrics": {"clen": 299.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Where are you going?", "beam_texts": [["SYSTEM: Where are you going?", -2.9211363792419434]], "metrics": {"clen": 312.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Where are you going?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I'm going to San Francisco.", "beam_texts": [["USER: I'm going to San Francisco.", -2.0261998176574707]], "metrics": {"clen": 101.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I'm going to San Francisco."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0031771883368492126]], "metrics": {"clen": 331.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: When are you leaving?", "beam_texts": [["SYSTEM: When are you leaving?", -2.7696292400360107]], "metrics": {"clen": 297.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "When are you leaving?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I'm leaving on the 4th of this month.", "beam_texts": [["USER: I'm leaving on the 4th of this month.", -1.9584730863571167]], "metrics": {"clen": 105.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I'm leaving on the 4th of this month."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.017428752034902573]], "metrics": {"clen": 324.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Where are you leaving from?", "beam_texts": [["SYSTEM: Where are you leaving from?", -1.4117623567581177]], "metrics": {"clen": 373.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Where are you leaving from?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I'm leaving from San Diego.", "beam_texts": [["USER: I'm leaving from San Diego.", -0.9567034244537354]], "metrics": {"clen": 158.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I'm leaving from San Diego."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.2181233912706375]], "metrics": {"clen": 363.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: When are you returning?", "beam_texts": [["SYSTEM: When are you returning?", -0.7378217577934265]], "metrics": {"clen": 319.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "When are you returning?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to return on the 9th of this month.", "beam_texts": [["USER: I want to return on the 9th of this month.", -1.6836446523666382]], "metrics": {"clen": 201.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to return on the 9th of this month."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: api_name = SearchRoundtripFlights ; departure_date = 2019-03-04 ; destination_airport = San Francisco ; origin_airport = San Diego ; return_date = 2019-03-09", "beam_texts": [["APICALL: api_name = SearchRoundtripFlights ; departure_date = 2019-03-04 ; destination_airport = San Francisco ; origin_airport = San Diego ; return_date = 2019-03-09", -0.0012267155107110739]], "metrics": {"clen": 396.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "api_name = SearchRoundtripFlights ; departure_date = 2019-03-04 ; destination_airport = San Francisco ; origin_airport = San Diego ; return_date = 2019-03-09"}, {"text": "APIRESP: airlines = Alaska Airlines ; departure_date = 2019-03-04 ; destination_airport = San Francisco ; inbound_arrival_time = 14:05 ; inbound_departure_time = 12:25 ; is_nonstop = False ; number_of_tickets = 1 ; origin_airport = San Diego ; outbound_arrival_time = 09:10 ; outbound_departure_time = 06:15 ; price = 220 ; return_date = 2019-03-09 ; seating_class = Economy", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": "airlines = Alaska Airlines ; departure_date = 2019-03-04 ; destination_airport = San Francisco ; inbound_arrival_time = 14:05 ; inbound_departure_time = 12:25 ; is_nonstop = False ; number_of_tickets = 1 ; origin_airport = San Diego ; outbound_arrival_time = 09:10 ; outbound_departure_time = 06:15 ; price = 220 ; return_date = 2019-03-09 ; seating_class = Economy"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I found 4 flights. There is an Alaska Airlines flight that leaves at 6:15 am and returns at 12:25 pm. The cost is $220.", "beam_texts": [["SYSTEM: I found 4 flights. There is an Alaska Airlines flight that leaves at 6:15 am and returns at 12:25 pm. The cost is $220.", -2.7658262252807617]], "metrics": {"clen": 362.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I found 4 flights. There is an Alaska Airlines flight that leaves at 6:15 am and returns at 12:25 pm. The cost is $220."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: That sounds good.", "beam_texts": [["USER: That sounds good.", -1.5201750993728638]], "metrics": {"clen": 185.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "That sounds good."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -4.8251091357087716e-05]], "metrics": {"clen": 582.0, "ctrunc": 1.0, "ctrunclen": 70.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Do you want to buy tickets?", "beam_texts": [["SYSTEM: Do you want to buy tickets?", -1.2585415840148926]], "metrics": {"clen": 595.0, "ctrunc": 1.0, "ctrunclen": 83.0}, "prefix_stripped_text": "Do you want to buy tickets?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, not now.", "beam_texts": [["USER: No, not now.", -1.9538514614105225]], "metrics": {"clen": 246.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, not now."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.00017885112902149558]], "metrics": {"clen": 582.0, "ctrunc": 1.0, "ctrunclen": 70.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Is there anything else I can help you with?", "beam_texts": [["SYSTEM: Is there anything else I can help you with?", -0.9863885045051575]], "metrics": {"clen": 595.0, "ctrunc": 1.0, "ctrunclen": 83.0}, "prefix_stripped_text": "Is there anything else I can help you with?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, thank you.", "beam_texts": [["USER: No, thank you.", -0.591467559337616]], "metrics": {"clen": 221.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, thank you."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -7.712452315900009e-06]], "metrics": {"clen": 617.0, "ctrunc": 1.0, "ctrunclen": 105.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Have a good day.", "beam_texts": [["SYSTEM: Have a good day.", -0.7458349466323853]], "metrics": {"clen": 630.0, "ctrunc": 1.0, "ctrunclen": 118.0}, "prefix_stripped_text": "Have a good day."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: [DONE]", "beam_texts": [["USER: [DONE]", -1.8318112893211946e-07]], "metrics": {"clen": 285.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "[DONE]"}]], "context": [], "metadata_path": "/checkpoint/mpchen/projects/taskoriented/user_generator/google_sgd/mm8_gsgd_bartR3f_protoTestTrainTue_May_11/withApi_conversations.metadata"} +{"dialog": [[{"text": "APIS: api_name = GetCarsAvailable ; optArg = car_type ; reqArg = city, end_date, pickup_time, start_date | api_name = ReserveCar ; reqArg = add_insurance, car_type, end_date, pickup_location, pickup_time, start_date", "id": "GoogleSGD_ApiDescription", "domain": "RentalCars", "episode_done": false, "prefix_stripped_text": "api_name = GetCarsAvailable ; optArg = car_type ; reqArg = city, end_date, pickup_time, start_date | api_name = ReserveCar ; reqArg = add_insurance, car_type, end_date, pickup_location, pickup_time, start_date"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APIS: ", "beam_texts": [["APIS: ", 0.0]], "metrics": {"clen": 168.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "APIS: "}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"text": "GOAL: api_name = GetCarsAvailable ; city = Fresno ; end_date = 2019-03-12 ; pickup_time = 09:30 ; start_date = 2019-03-11", "id": "GoogleSGD_Goal", "domain": "RentalCars", "episode_done": false, "prefix_stripped_text": "api_name = GetCarsAvailable ; city = Fresno ; end_date = 2019-03-12 ; pickup_time = 09:30 ; start_date = 2019-03-11"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need to find a rental car.", "beam_texts": [["USER: I need to find a rental car.", -2.235142469406128]], "metrics": {"clen": 48.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need to find a rental car."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.00040876015555113554]], "metrics": {"clen": 351.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: What date do you want to pick it up?", "beam_texts": [["SYSTEM: What date do you want to pick it up?", -3.3327901363372803]], "metrics": {"clen": 171.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "What date do you want to pick it up?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need it on the 11th of March.", "beam_texts": [["USER: I need it on the 11th of March.", -2.694045066833496]], "metrics": {"clen": 83.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need it on the 11th of March."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.00027602165937423706]], "metrics": {"clen": 206.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: What time do you want to pick it up?", "beam_texts": [["SYSTEM: What time do you want to pick it up?", -1.9576942920684814]], "metrics": {"clen": 219.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "What time do you want to pick it up?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need it at 9:30 in the morning.", "beam_texts": [["USER: I need it at 9:30 in the morning.", -2.4926252365112305]], "metrics": {"clen": 120.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need it at 9:30 in the morning."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0010227571474388242]], "metrics": {"clen": 434.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Where do you want to pick it up?", "beam_texts": [["SYSTEM: Where do you want to pick it up?", -1.4877737760543823]], "metrics": {"clen": 447.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Where do you want to pick it up?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need it in Fresno, CA.", "beam_texts": [["USER: I need it in Fresno, CA.", -1.823162317276001]], "metrics": {"clen": 144.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need it in Fresno, CA."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: api_name = GetCarsAvailable ; city = Fresno ; end_date = 2019-03-11 ; pickup_time = 09:30", "beam_texts": [["APICALL: api_name = GetCarsAvailable ; city = Fresno ; end_date = 2019-03-11 ; pickup_time = 09:30", -0.2495831549167633]], "metrics": {"clen": 476.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "api_name = GetCarsAvailable ; city = Fresno ; end_date = 2019-03-11 ; pickup_time = 09:30"}, {"text": "APIRESP: ", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I couldn't find any cars matching your criteria. Is there anything else I can help you with?", "beam_texts": [["SYSTEM: I couldn't find any cars matching your criteria. Is there anything else I can help you with?", -2.8082520961761475]], "metrics": {"clen": 519.0, "ctrunc": 1.0, "ctrunclen": 7.0}, "prefix_stripped_text": "I couldn't find any cars matching your criteria. Is there anything else I can help you with?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, I'm good.", "beam_texts": [["USER: No, I'm good.", -3.3139195442199707]], "metrics": {"clen": 167.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, I'm good."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0005645004566758871]], "metrics": {"clen": 487.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Have a nice day.", "beam_texts": [["SYSTEM: Have a nice day.", -1.0553348064422607]], "metrics": {"clen": 564.0, "ctrunc": 1.0, "ctrunclen": 52.0}, "prefix_stripped_text": "Have a nice day."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: [DONE]", "beam_texts": [["USER: [DONE]", -2.747716791873245e-07]], "metrics": {"clen": 218.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "[DONE]"}]], "context": [], "metadata_path": "/checkpoint/mpchen/projects/taskoriented/user_generator/google_sgd/mm8_gsgd_bartR3f_protoTestTrainTue_May_11/withApi_conversations.metadata"} +{"dialog": [[{"text": "APIS: api_name = FindHomeByArea ; optArg = has_garage, in_unit_laundry ; reqArg = area, intent, number_of_baths, number_of_beds | api_name = ScheduleVisit ; reqArg = property_name, visit_date", "id": "GoogleSGD_ApiDescription", "domain": "Homes", "episode_done": false, "prefix_stripped_text": "api_name = FindHomeByArea ; optArg = has_garage, in_unit_laundry ; reqArg = area, intent, number_of_baths, number_of_beds | api_name = ScheduleVisit ; reqArg = property_name, visit_date"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APIS: ", "beam_texts": [["APIS: ", -1.0784366821781077e-07]], "metrics": {"clen": 59.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "APIS: "}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"text": "GOAL: api_name = ScheduleVisit ; property_name = Indian Creek ; visit_date = 2019-03-02", "id": "GoogleSGD_Goal", "domain": "Homes", "episode_done": false, "prefix_stripped_text": "api_name = ScheduleVisit ; property_name = Indian Creek ; visit_date = 2019-03-02"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to visit a property.", "beam_texts": [["USER: I want to visit a property.", -3.0068488121032715]], "metrics": {"clen": 22.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to visit a property."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0010161403333768249]], "metrics": {"clen": 142.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: When do you want to visit?", "beam_texts": [["SYSTEM: When do you want to visit?", -1.7299853563308716]], "metrics": {"clen": 178.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "When do you want to visit?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to visit on the 2nd of this month.", "beam_texts": [["USER: I want to visit on the 2nd of this month.", -2.3218560218811035]], "metrics": {"clen": 144.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to visit on the 2nd of this month."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.020279226824641228]], "metrics": {"clen": 200.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: What property do you want to visit?", "beam_texts": [["SYSTEM: What property do you want to visit?", -1.9698238372802734]], "metrics": {"clen": 213.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "What property do you want to visit?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to visit Indian Creek.", "beam_texts": [["USER: I want to visit Indian Creek.", -0.8281252980232239]], "metrics": {"clen": 97.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to visit Indian Creek."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0015185625525191426]], "metrics": {"clen": 245.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Please confirm: You want to visit Indian Creek on March 2nd.", "beam_texts": [["SYSTEM: Please confirm: You want to visit Indian Creek on March 2nd.", -2.028269052505493]], "metrics": {"clen": 287.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Please confirm: You want to visit Indian Creek on March 2nd."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: Yes, that is correct.", "beam_texts": [["USER: Yes, that is correct.", -1.5401211977005005]], "metrics": {"clen": 203.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Yes, that is correct."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: api_name = ScheduleVisit ; property_name = Indian Creek ; visit_date = 2019-03-02", "beam_texts": [["APICALL: api_name = ScheduleVisit ; property_name = Indian Creek ; visit_date = 2019-03-02", -0.0001516681513749063]], "metrics": {"clen": 319.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "api_name = ScheduleVisit ; property_name = Indian Creek ; visit_date = 2019-03-02"}, {"text": "APIRESP: address = 801 Marine Parkway ; has_garage = False ; in_unit_laundry = True ; number_of_baths = 2 ; number_of_beds = 3 ; phone_number = 650-593-4300 ; price = 3500000 ; property_name = Indian Creek ; visit_date = 2019-03-02", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": "address = 801 Marine Parkway ; has_garage = False ; in_unit_laundry = True ; number_of_baths = 2 ; number_of_beds = 3 ; phone_number = 650-593-4300 ; price = 3500000 ; property_name = Indian Creek ; visit_date = 2019-03-02"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Your visit is scheduled.", "beam_texts": [["SYSTEM: Your visit is scheduled.", -0.8813648223876953]], "metrics": {"clen": 369.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Your visit is scheduled."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: Thank you.", "beam_texts": [["USER: Thank you.", -2.2147231101989746]], "metrics": {"clen": 136.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Thank you."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -6.387276516761631e-05]], "metrics": {"clen": 392.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Is there anything else I can help you with?", "beam_texts": [["SYSTEM: Is there anything else I can help you with?", -0.8272175192832947]], "metrics": {"clen": 438.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Is there anything else I can help you with?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, thank you.", "beam_texts": [["USER: No, thank you.", -0.5735105276107788]], "metrics": {"clen": 152.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, thank you."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -5.418340879259631e-05]], "metrics": {"clen": 441.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Have a good day.", "beam_texts": [["SYSTEM: Have a good day.", -0.735770583152771]], "metrics": {"clen": 491.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Have a good day."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: [DONE]", "beam_texts": [["USER: [DONE]", -2.83930285149836e-06]], "metrics": {"clen": 319.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "[DONE]"}]], "context": [], "metadata_path": "/checkpoint/mpchen/projects/taskoriented/user_generator/google_sgd/mm8_gsgd_bartR3f_protoTestTrainTue_May_11/withApi_conversations.metadata"} +{"dialog": [[{"text": "APIS: api_name = GetCarsAvailable ; optArg = car_type ; reqArg = city, end_date, pickup_time, start_date | api_name = ReserveCar ; reqArg = add_insurance, car_type, end_date, pickup_location, pickup_time, start_date", "id": "GoogleSGD_ApiDescription", "domain": "RentalCars", "episode_done": false, "prefix_stripped_text": "api_name = GetCarsAvailable ; optArg = car_type ; reqArg = city, end_date, pickup_time, start_date | api_name = ReserveCar ; reqArg = add_insurance, car_type, end_date, pickup_location, pickup_time, start_date"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APIS: ", "beam_texts": [["APIS: ", 0.0]], "metrics": {"clen": 84.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "APIS: "}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"text": "GOAL: api_name = GetCarsAvailable ; car_type = SUV ; city = Portland ; end_date = 2019-03-09 ; pickup_time = 15:00 ; start_date = 2019-03-05", "id": "GoogleSGD_Goal", "domain": "RentalCars", "episode_done": false, "prefix_stripped_text": "api_name = GetCarsAvailable ; car_type = SUV ; city = Portland ; end_date = 2019-03-09 ; pickup_time = 15:00 ; start_date = 2019-03-05"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need to find a rental car.", "beam_texts": [["USER: I need to find a rental car.", -2.37575101852417]], "metrics": {"clen": 46.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need to find a rental car."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.00040876015555113554]], "metrics": {"clen": 183.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: What date do you want to pick it up?", "beam_texts": [["SYSTEM: What date do you want to pick it up?", -3.3327901363372803]], "metrics": {"clen": 364.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "What date do you want to pick it up?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need it on the 5th of this month.", "beam_texts": [["USER: I need it on the 5th of this month.", -3.266505479812622]], "metrics": {"clen": 111.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need it on the 5th of this month."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.00025378118152730167]], "metrics": {"clen": 385.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: What time do you want to pick it up?", "beam_texts": [["SYSTEM: What time do you want to pick it up?", -1.9800610542297363]], "metrics": {"clen": 409.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "What time do you want to pick it up?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need it at 3 in the afternoon.", "beam_texts": [["USER: I need it at 3 in the afternoon.", -3.0727145671844482]], "metrics": {"clen": 131.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need it at 3 in the afternoon."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0008782641962170601]], "metrics": {"clen": 243.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Where do you want to pick it up?", "beam_texts": [["SYSTEM: Where do you want to pick it up?", -1.5364357233047485]], "metrics": {"clen": 277.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Where do you want to pick it up?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need it in Portland, OR.", "beam_texts": [["USER: I need it in Portland, OR.", -1.7654595375061035]], "metrics": {"clen": 155.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need it in Portland, OR."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.3837651014328003]], "metrics": {"clen": 306.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: What date do you want to pick it up?", "beam_texts": [["SYSTEM: What date do you want to pick it up?", -1.704606294631958]], "metrics": {"clen": 332.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "What date do you want to pick it up?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need it until the 9th of this month. I want an SUV.", "beam_texts": [["USER: I need it until the 9th of this month. I want an SUV.", -2.499929666519165]], "metrics": {"clen": 181.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need it until the 9th of this month. I want an SUV."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: api_name = GetCarsAvailable ; car_type = Full-size ; city = Portland ; end_date = 2019-03-09 ; pickup_time = 15:00", "beam_texts": [["APICALL: api_name = GetCarsAvailable ; car_type = Full-size ; city = Portland ; end_date = 2019-03-09 ; pickup_time = 15:00", -0.31725868582725525]], "metrics": {"clen": 551.0, "ctrunc": 1.0, "ctrunclen": 39.0}, "prefix_stripped_text": "api_name = GetCarsAvailable ; car_type = Full-size ; city = Portland ; end_date = 2019-03-09 ; pickup_time = 15:00"}, {"text": "APIRESP: ", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I couldn't find any cars matching your criteria. Is there anything else I can help you with?", "beam_texts": [["SYSTEM: I couldn't find any cars matching your criteria. Is there anything else I can help you with?", -3.00557804107666]], "metrics": {"clen": 409.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I couldn't find any cars matching your criteria. Is there anything else I can help you with?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, thank you.", "beam_texts": [["USER: No, thank you.", -0.8569039106369019]], "metrics": {"clen": 198.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, thank you."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.01578957960009575]], "metrics": {"clen": 550.0, "ctrunc": 1.0, "ctrunclen": 38.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Have a great day.", "beam_texts": [["SYSTEM: Have a great day.", -0.99875807762146]], "metrics": {"clen": 563.0, "ctrunc": 1.0, "ctrunclen": 51.0}, "prefix_stripped_text": "Have a great day."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: [DONE]", "beam_texts": [["USER: [DONE]", -1.8318112893211946e-07]], "metrics": {"clen": 204.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "[DONE]"}]], "context": [], "metadata_path": "/checkpoint/mpchen/projects/taskoriented/user_generator/google_sgd/mm8_gsgd_bartR3f_protoTestTrainTue_May_11/withApi_conversations.metadata"} +{"dialog": [[{"text": "APIS: api_name = BookAppointment ; reqArg = appointment_date, appointment_time, stylist_name | api_name = FindProvider ; optArg = is_unisex ; reqArg = city | api_name = BuyMovieTickets ; reqArg = location, movie_name, number_of_tickets, show_date, show_time, show_type | api_name = FindMovies ; optArg = genre, show_type, theater_name ; reqArg = location | api_name = GetTimesForMovie ; optArg = show_type, theater_name ; reqArg = location, movie_name, show_date", "id": "GoogleSGD_ApiDescription", "domain": "Movies, Services", "episode_done": false, "prefix_stripped_text": "api_name = BookAppointment ; reqArg = appointment_date, appointment_time, stylist_name | api_name = FindProvider ; optArg = is_unisex ; reqArg = city | api_name = BuyMovieTickets ; reqArg = location, movie_name, number_of_tickets, show_date, show_time, show_type | api_name = FindMovies ; optArg = genre, show_type, theater_name ; reqArg = location | api_name = GetTimesForMovie ; optArg = show_type, theater_name ; reqArg = location, movie_name, show_date"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APIS: ", "beam_texts": [["APIS: ", -1.0784366821781077e-07]], "metrics": {"clen": 70.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "APIS: "}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"text": "GOAL: api_name = FindProvider ; city = Santa Rosa", "id": "GoogleSGD_Goal", "domain": "Movies, Services", "episode_done": false, "prefix_stripped_text": "api_name = FindProvider ; city = Santa Rosa"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need to find a dentist.", "beam_texts": [["USER: I need to find a dentist.", -2.2936532497406006]], "metrics": {"clen": 37.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need to find a dentist."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -4.657026147469878e-05]], "metrics": {"clen": 156.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: In which city?", "beam_texts": [["SYSTEM: In which city?", -1.2308539152145386]], "metrics": {"clen": 169.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "In which city?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need a dentist in Santa Rosa.", "beam_texts": [["USER: I need a dentist in Santa Rosa.", -1.99648118019104]], "metrics": {"clen": 94.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need a dentist in Santa Rosa."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: api_name = FindProvider ; city = Santa Rosa", "beam_texts": [["APICALL: api_name = FindProvider ; city = Santa Rosa", -0.00010167308937525377]], "metrics": {"clen": 196.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "api_name = FindProvider ; city = Santa Rosa"}, {"text": "APIRESP: address = 4812, 1144 Sonoma Avenue ; city = Santa Rosa ; dentist_name = Access Dental - Santa Rosa ; offers_cosmetic_services = False ; phone_number = 866-682-9904", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": "address = 4812, 1144 Sonoma Avenue ; city = Santa Rosa ; dentist_name = Access Dental - Santa Rosa ; offers_cosmetic_services = False ; phone_number = 866-682-9904"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I found 10 dentists. How about Access Dental - Santa Rosa in Santa Rosa?", "beam_texts": [["SYSTEM: I found 10 dentists. How about Access Dental - Santa Rosa in Santa Rosa?", -1.664900541305542]], "metrics": {"clen": 242.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I found 10 dentists. How about Access Dental - Santa Rosa in Santa Rosa?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: Yes, that sounds good.", "beam_texts": [["USER: Yes, that sounds good.", -2.585338592529297]], "metrics": {"clen": 115.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Yes, that sounds good."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.00035133821074850857]], "metrics": {"clen": 265.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Would you like to book an appointment?", "beam_texts": [["SYSTEM: Would you like to book an appointment?", -1.2300727367401123]], "metrics": {"clen": 278.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Would you like to book an appointment?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, not right now.", "beam_texts": [["USER: No, not right now.", -1.776436448097229]], "metrics": {"clen": 162.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, not right now."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0005406867130659521]], "metrics": {"clen": 406.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Is there anything else I can help you with?", "beam_texts": [["SYSTEM: Is there anything else I can help you with?", -0.7443912625312805]], "metrics": {"clen": 311.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Is there anything else I can help you with?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, thank you.", "beam_texts": [["USER: No, thank you.", -0.8158944845199585]], "metrics": {"clen": 179.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, thank you."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -3.9451406337320805e-05]], "metrics": {"clen": 335.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Have a great day.", "beam_texts": [["SYSTEM: Have a great day.", -0.9284720420837402]], "metrics": {"clen": 396.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Have a great day."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: [DONE]", "beam_texts": [["USER: [DONE]", -2.289761141582858e-06]], "metrics": {"clen": 194.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "[DONE]"}]], "context": [], "metadata_path": "/checkpoint/mpchen/projects/taskoriented/user_generator/google_sgd/mm8_gsgd_bartR3f_protoTestTrainTue_May_11/withApi_conversations.metadata"} +{"dialog": [[{"text": "APIS: api_name = FindEvents ; optArg = date ; reqArg = city, event_type | api_name = BuyEventTickets ; reqArg = city, date, event_name, number_of_tickets | api_name = ReserveHotel ; optArg = number_of_rooms ; reqArg = check_in_date, location, place_name, stay_length | api_name = SearchHotel ; optArg = number_of_rooms, smoking_allowed, star_rating ; reqArg = location", "id": "GoogleSGD_ApiDescription", "domain": "Events, Hotels", "episode_done": false, "prefix_stripped_text": "api_name = FindEvents ; optArg = date ; reqArg = city, event_type | api_name = BuyEventTickets ; reqArg = city, date, event_name, number_of_tickets | api_name = ReserveHotel ; optArg = number_of_rooms ; reqArg = check_in_date, location, place_name, stay_length | api_name = SearchHotel ; optArg = number_of_rooms, smoking_allowed, star_rating ; reqArg = location"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APIS: ", "beam_texts": [["APIS: ", 0.0]], "metrics": {"clen": 66.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "APIS: "}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"text": "GOAL: api_name = BuyEventTickets ; city = New York ; date = 2019-03-14 ; event_name = Alpha Blondy ; number_of_tickets = 2", "id": "GoogleSGD_Goal", "domain": "Events, Hotels", "episode_done": false, "prefix_stripped_text": "api_name = BuyEventTickets ; city = New York ; date = 2019-03-14 ; event_name = Alpha Blondy ; number_of_tickets = 2"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to buy tickets to an event.", "beam_texts": [["USER: I want to buy tickets to an event.", -2.4751245975494385]], "metrics": {"clen": 43.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to buy tickets to an event."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -8.394322503590956e-05]], "metrics": {"clen": 147.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: What city should I look in?", "beam_texts": [["SYSTEM: What city should I look in?", -3.1570680141448975]], "metrics": {"clen": 160.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "What city should I look in?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to find something in New York City.", "beam_texts": [["USER: I want to find something in New York City.", -2.8168575763702393]], "metrics": {"clen": 111.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to find something in New York City."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.000529915967490524]], "metrics": {"clen": 186.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: What date would you like to go?", "beam_texts": [["SYSTEM: What date would you like to go?", -2.9322915077209473]], "metrics": {"clen": 400.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "What date would you like to go?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to go on the 14th of this month.", "beam_texts": [["USER: I want to go on the 14th of this month.", -1.2178394794464111]], "metrics": {"clen": 138.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to go on the 14th of this month."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0003908717480953783]], "metrics": {"clen": 431.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: How many tickets do you want?", "beam_texts": [["SYSTEM: How many tickets do you want?", -1.6059118509292603]], "metrics": {"clen": 444.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "How many tickets do you want?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want two tickets.", "beam_texts": [["USER: I want two tickets.", -1.3766371011734009]], "metrics": {"clen": 165.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want two tickets."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.000701539684087038]], "metrics": {"clen": 465.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: What event are you interested in?", "beam_texts": [["SYSTEM: What event are you interested in?", -1.3350071907043457]], "metrics": {"clen": 478.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "What event are you interested in?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to see Alpha Blondy.", "beam_texts": [["USER: I want to see Alpha Blondy.", -1.0148762464523315]], "metrics": {"clen": 123.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to see Alpha Blondy."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -4.133000038564205e-05]], "metrics": {"clen": 500.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Please confirm: 2 tickets to Alpha Blondy on March 14th in New York.", "beam_texts": [["SYSTEM: Please confirm: 2 tickets to Alpha Blondy on March 14th in New York.", -1.1921658515930176]], "metrics": {"clen": 405.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Please confirm: 2 tickets to Alpha Blondy on March 14th in New York."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: Yes, that's great. Where is the venue?", "beam_texts": [["USER: Yes, that's great. Where is the venue?", -2.867255926132202]], "metrics": {"clen": 166.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Yes, that's great. Where is the venue?"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: api_name = BuyEventTickets ; city = New York ; date = 2019-03-14 ; event_name = Alpha Blondy ; number_of_tickets = 2", "beam_texts": [["APICALL: api_name = BuyEventTickets ; city = New York ; date = 2019-03-14 ; event_name = Alpha Blondy ; number_of_tickets = 2", -1.1100361916760448e-05]], "metrics": {"clen": 401.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "api_name = BuyEventTickets ; city = New York ; date = 2019-03-14 ; event_name = Alpha Blondy ; number_of_tickets = 2"}, {"text": "APIRESP: ", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Sorry, I couldn't book the tickets. Can I help with anything else?", "beam_texts": [["SYSTEM: Sorry, I couldn't book the tickets. Can I help with anything else?", -2.606656551361084]], "metrics": {"clen": 414.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Sorry, I couldn't book the tickets. Can I help with anything else?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, thank you.", "beam_texts": [["USER: No, thank you.", -0.8586272597312927]], "metrics": {"clen": 261.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, thank you."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.00046617790940217674]], "metrics": {"clen": 614.0, "ctrunc": 1.0, "ctrunclen": 102.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Have a nice day.", "beam_texts": [["SYSTEM: Have a nice day.", -0.6713320016860962]], "metrics": {"clen": 627.0, "ctrunc": 1.0, "ctrunclen": 115.0}, "prefix_stripped_text": "Have a nice day."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: [DONE]", "beam_texts": [["USER: [DONE]", -3.663622294425295e-07]], "metrics": {"clen": 260.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "[DONE]"}]], "context": [], "metadata_path": "/checkpoint/mpchen/projects/taskoriented/user_generator/google_sgd/mm8_gsgd_bartR3f_protoTestTrainTue_May_11/withApi_conversations.metadata"} +{"dialog": [[{"text": "APIS: api_name = BookAppointment ; reqArg = appointment_date, appointment_time, stylist_name | api_name = FindProvider ; optArg = is_unisex ; reqArg = city | api_name = ShareLocation ; reqArg = contact_name, location | api_name = GetAlarms | api_name = AddAlarm ; optArg = new_alarm_name ; reqArg = new_alarm_time", "id": "GoogleSGD_ApiDescription", "domain": "Alarm, Messaging, Services", "episode_done": false, "prefix_stripped_text": "api_name = BookAppointment ; reqArg = appointment_date, appointment_time, stylist_name | api_name = FindProvider ; optArg = is_unisex ; reqArg = city | api_name = ShareLocation ; reqArg = contact_name, location | api_name = GetAlarms | api_name = AddAlarm ; optArg = new_alarm_name ; reqArg = new_alarm_time"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APIS: ", "beam_texts": [["APIS: ", 0.0]], "metrics": {"clen": 70.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "APIS: "}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"text": "GOAL: api_name = FindProvider ; city = San Jose ; is_unisex = True", "id": "GoogleSGD_Goal", "domain": "Alarm, Messaging, Services", "episode_done": false, "prefix_stripped_text": "api_name = FindProvider ; city = San Jose ; is_unisex = True"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need to find a salon.", "beam_texts": [["USER: I need to find a salon.", -2.4835095405578613]], "metrics": {"clen": 34.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need to find a salon."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -7.652808562852442e-05]], "metrics": {"clen": 156.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: In which city?", "beam_texts": [["SYSTEM: In which city?", -1.3473020792007446]], "metrics": {"clen": 169.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "In which city?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need a unisex salon in San Jose.", "beam_texts": [["USER: I need a unisex salon in San Jose.", -1.449689507484436]], "metrics": {"clen": 88.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need a unisex salon in San Jose."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: api_name = FindProvider ; city = San Jose ; is_unisex = True", "beam_texts": [["APICALL: api_name = FindProvider ; city = San Jose ; is_unisex = True", -3.9753711462253705e-05]], "metrics": {"clen": 197.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "api_name = FindProvider ; city = San Jose ; is_unisex = True"}, {"text": "APIRESP: average_rating = 4.30 ; city = San Jose ; is_unisex = True ; phone_number = 408-293-3494 ; street_address = 1047 Lincoln Highway suite 101 ; stylist_name = Atelier Salon Willow Glen", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": "average_rating = 4.30 ; city = San Jose ; is_unisex = True ; phone_number = 408-293-3494 ; street_address = 1047 Lincoln Highway suite 101 ; stylist_name = Atelier Salon Willow Glen"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I found 10 salons. Atelier Salon Willow Glen is in San Jose.", "beam_texts": [["SYSTEM: I found 10 salons. Atelier Salon Willow Glen is in San Jose.", -1.7336264848709106]], "metrics": {"clen": 199.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I found 10 salons. Atelier Salon Willow Glen is in San Jose."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: That sounds good.", "beam_texts": [["USER: That sounds good.", -2.342952251434326]], "metrics": {"clen": 117.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "That sounds good."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.00011350425484124571]], "metrics": {"clen": 221.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Do you want to book an appointment?", "beam_texts": [["SYSTEM: Do you want to book an appointment?", -1.1533632278442383]], "metrics": {"clen": 234.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Do you want to book an appointment?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, not now.", "beam_texts": [["USER: No, not now.", -1.7858071327209473]], "metrics": {"clen": 147.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, not now."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.001105411909520626]], "metrics": {"clen": 261.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Anything else I can help with?", "beam_texts": [["SYSTEM: Anything else I can help with?", -1.460990309715271]], "metrics": {"clen": 369.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Anything else I can help with?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, thank you.", "beam_texts": [["USER: No, thank you.", -0.6900044679641724]], "metrics": {"clen": 164.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, thank you."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0001660985144553706]], "metrics": {"clen": 384.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Have a nice day.", "beam_texts": [["SYSTEM: Have a nice day.", -0.5248937010765076]], "metrics": {"clen": 369.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Have a nice day."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: [DONE]", "beam_texts": [["USER: [DONE]", -7.327242883548024e-07]], "metrics": {"clen": 153.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "[DONE]"}]], "context": [], "metadata_path": "/checkpoint/mpchen/projects/taskoriented/user_generator/google_sgd/mm8_gsgd_bartR3f_protoTestTrainTue_May_11/withApi_conversations.metadata"} +{"dialog": [[{"text": "APIS: api_name = GetCarsAvailable ; optArg = car_type ; reqArg = city, end_date, pickup_time, start_date | api_name = ReserveCar ; reqArg = add_insurance, car_type, end_date, pickup_location, pickup_time, start_date | api_name = FindHomeByArea ; optArg = has_garage, in_unit_laundry ; reqArg = area, intent, number_of_baths, number_of_beds | api_name = ScheduleVisit ; reqArg = property_name, visit_date | api_name = GetWeather ; optArg = date ; reqArg = city", "id": "GoogleSGD_ApiDescription", "domain": "Homes, RentalCars, Weather", "episode_done": false, "prefix_stripped_text": "api_name = GetCarsAvailable ; optArg = car_type ; reqArg = city, end_date, pickup_time, start_date | api_name = ReserveCar ; reqArg = add_insurance, car_type, end_date, pickup_location, pickup_time, start_date | api_name = FindHomeByArea ; optArg = has_garage, in_unit_laundry ; reqArg = area, intent, number_of_baths, number_of_beds | api_name = ScheduleVisit ; reqArg = property_name, visit_date | api_name = GetWeather ; optArg = date ; reqArg = city"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APIS: ", "beam_texts": [["APIS: ", -1.0784366821781077e-07]], "metrics": {"clen": 91.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "APIS: "}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"text": "GOAL: add_insurance = True ; api_name = ReserveCar ; car_type = Hatchback ; end_date = 2019-03-14 ; pickup_location = PDX International Airport ; pickup_time = 15:30 ; start_date = 2019-03-09", "id": "GoogleSGD_Goal", "domain": "Homes, RentalCars, Weather", "episode_done": false, "prefix_stripped_text": "add_insurance = True ; api_name = ReserveCar ; car_type = Hatchback ; end_date = 2019-03-14 ; pickup_location = PDX International Airport ; pickup_time = 15:30 ; start_date = 2019-03-09"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need to reserve a car.", "beam_texts": [["USER: I need to reserve a car.", -2.8472073078155518]], "metrics": {"clen": 20.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need to reserve a car."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.00011637131683528423]], "metrics": {"clen": 197.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: What time do you want to pick it up?", "beam_texts": [["SYSTEM: What time do you want to pick it up?", -3.0318820476531982]], "metrics": {"clen": 210.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "What time do you want to pick it up?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to pick it up at 3:30 pm.", "beam_texts": [["USER: I want to pick it up at 3:30 pm.", -2.287160634994507]], "metrics": {"clen": 68.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to pick it up at 3:30 pm."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0001661973656155169]], "metrics": {"clen": 229.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: What date do you want to pick it up?", "beam_texts": [["SYSTEM: What date do you want to pick it up?", -2.177884578704834]], "metrics": {"clen": 242.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "What date do you want to pick it up?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to pick it up on the 9th of this month.", "beam_texts": [["USER: I want to pick it up on the 9th of this month.", -1.9553515911102295]], "metrics": {"clen": 109.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to pick it up on the 9th of this month."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.00047072372399270535]], "metrics": {"clen": 265.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Where do you want to pick it up?", "beam_texts": [["SYSTEM: Where do you want to pick it up?", -1.259605050086975]], "metrics": {"clen": 278.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Where do you want to pick it up?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to pick it up at PDX International Airport.", "beam_texts": [["USER: I want to pick it up at PDX International Airport.", -0.8431561589241028]], "metrics": {"clen": 126.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to pick it up at PDX International Airport."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.008902691304683685]], "metrics": {"clen": 298.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: What date do you want to return the car?", "beam_texts": [["SYSTEM: What date do you want to return the car?", -1.4374130964279175]], "metrics": {"clen": 316.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "What date do you want to return the car?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to return it on the 14th of this month.", "beam_texts": [["USER: I want to return it on the 14th of this month.", -1.5039756298065186]], "metrics": {"clen": 175.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to return it on the 14th of this month."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.023951882496476173]], "metrics": {"clen": 349.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Please confirm: You want to pick up a standard car from PDX International Airport on March 9th at 3:30 pm. You will return it on March 14th.", "beam_texts": [["SYSTEM: Please confirm: You want to pick up a standard car from PDX International Airport on March 9th at 3:30 pm. You will return it on March 14th.", -2.0667858123779297]], "metrics": {"clen": 397.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Please confirm: You want to pick up a standard car from PDX International Airport on March 9th at 3:30 pm. You will return it on March 14th."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, I want a Hatchback.", "beam_texts": [["USER: No, I want a Hatchback.", -1.684838056564331]], "metrics": {"clen": 226.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, I want a Hatchback."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -2.9663365239684936e-06]], "metrics": {"clen": 738.0, "ctrunc": 1.0, "ctrunclen": 226.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Please confirm: You want to pick up a Hatchback from PDX International Airport.", "beam_texts": [["SYSTEM: Please confirm: You want to pick up a Hatchback from PDX International Airport.", -0.6164475083351135]], "metrics": {"clen": 751.0, "ctrunc": 1.0, "ctrunclen": 239.0}, "prefix_stripped_text": "Please confirm: You want to pick up a Hatchback from PDX International Airport."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: Yes, that is correct.", "beam_texts": [["USER: Yes, that is correct.", -0.887704610824585]], "metrics": {"clen": 244.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Yes, that is correct."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: api_name = ReserveCar ; car_type = Hatchback ; end_date = 2019-03-14 ; pickup_location = PDX International Airport ; pickup_time = 15:30", "beam_texts": [["APICALL: api_name = ReserveCar ; car_type = Hatchback ; end_date = 2019-03-14 ; pickup_location = PDX International Airport ; pickup_time = 15:30", -0.011755380779504776]], "metrics": {"clen": 770.0, "ctrunc": 1.0, "ctrunclen": 258.0}, "prefix_stripped_text": "api_name = ReserveCar ; car_type = Hatchback ; end_date = 2019-03-14 ; pickup_location = PDX International Airport ; pickup_time = 15:30"}, {"text": "APIRESP: ", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I was unable to reserve the car. Is there anything else I can help you with?", "beam_texts": [["SYSTEM: I was unable to reserve the car. Is there anything else I can help you with?", -1.8006048202514648]], "metrics": {"clen": 783.0, "ctrunc": 1.0, "ctrunclen": 271.0}, "prefix_stripped_text": "I was unable to reserve the car. Is there anything else I can help you with?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, that is all.", "beam_texts": [["USER: No, that is all.", -1.2942805290222168]], "metrics": {"clen": 266.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, that is all."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0007116166525520384]], "metrics": {"clen": 805.0, "ctrunc": 1.0, "ctrunclen": 293.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Have a good day.", "beam_texts": [["SYSTEM: Have a good day.", -0.58097243309021]], "metrics": {"clen": 818.0, "ctrunc": 1.0, "ctrunclen": 306.0}, "prefix_stripped_text": "Have a good day."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: [DONE]", "beam_texts": [["USER: [DONE]", -5.495433015312301e-07]], "metrics": {"clen": 283.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "[DONE]"}]], "context": [], "metadata_path": "/checkpoint/mpchen/projects/taskoriented/user_generator/google_sgd/mm8_gsgd_bartR3f_protoTestTrainTue_May_11/withApi_conversations.metadata"} +{"dialog": [[{"text": "APIS: api_name = BookAppointment ; reqArg = appointment_date, appointment_time, therapist_name | api_name = FindProvider ; reqArg = city, type | api_name = BookAppointment ; reqArg = appointment_date, appointment_time, stylist_name | api_name = FindProvider ; optArg = is_unisex ; reqArg = city", "id": "GoogleSGD_ApiDescription", "domain": "Services, Services", "episode_done": false, "prefix_stripped_text": "api_name = BookAppointment ; reqArg = appointment_date, appointment_time, therapist_name | api_name = FindProvider ; reqArg = city, type | api_name = BookAppointment ; reqArg = appointment_date, appointment_time, stylist_name | api_name = FindProvider ; optArg = is_unisex ; reqArg = city"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APIS: ", "beam_texts": [["APIS: ", -1.0784366821781077e-07]], "metrics": {"clen": 99.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "APIS: "}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"text": "GOAL: api_name = FindProvider ; city = San Jose ; is_unisex = True", "id": "GoogleSGD_Goal", "domain": "Services, Services", "episode_done": false, "prefix_stripped_text": "api_name = FindProvider ; city = San Jose ; is_unisex = True"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need to find a salon.", "beam_texts": [["USER: I need to find a salon.", -2.4835095405578613]], "metrics": {"clen": 22.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need to find a salon."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.00014603004092350602]], "metrics": {"clen": 213.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: In which city?", "beam_texts": [["SYSTEM: In which city?", -1.5344370603561401]], "metrics": {"clen": 226.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "In which city?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I need a unisex salon in San Jose.", "beam_texts": [["USER: I need a unisex salon in San Jose.", -1.449689507484436]], "metrics": {"clen": 87.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I need a unisex salon in San Jose."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: api_name = FindProvider ; city = San Jose ; is_unisex = True", "beam_texts": [["APICALL: api_name = FindProvider ; city = San Jose ; is_unisex = True", -0.0007448627729900181]], "metrics": {"clen": 248.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "api_name = FindProvider ; city = San Jose ; is_unisex = True"}, {"text": "APIRESP: average_rating = 4.30 ; city = San Jose ; is_unisex = True ; phone_number = 408-293-3494 ; street_address = 1047 Lincoln Highway suite 101 ; stylist_name = Atelier Salon Willow Glen", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": "average_rating = 4.30 ; city = San Jose ; is_unisex = True ; phone_number = 408-293-3494 ; street_address = 1047 Lincoln Highway suite 101 ; stylist_name = Atelier Salon Willow Glen"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I found 10 salons. Atelier Salon Willow Glen is in San Jose.", "beam_texts": [["SYSTEM: I found 10 salons. Atelier Salon Willow Glen is in San Jose.", -1.8086758852005005]], "metrics": {"clen": 210.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I found 10 salons. Atelier Salon Willow Glen is in San Jose."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: That sounds good.", "beam_texts": [["USER: That sounds good.", -2.342952251434326]], "metrics": {"clen": 113.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "That sounds good."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.0001382200134685263]], "metrics": {"clen": 236.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Do you want to book an appointment?", "beam_texts": [["SYSTEM: Do you want to book an appointment?", -1.16355562210083]], "metrics": {"clen": 249.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Do you want to book an appointment?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, not now.", "beam_texts": [["USER: No, not now.", -1.7858071327209473]], "metrics": {"clen": 137.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, not now."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.000978908152319491]], "metrics": {"clen": 273.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Anything else I can help you with?", "beam_texts": [["SYSTEM: Anything else I can help you with?", -1.4312176704406738]], "metrics": {"clen": 286.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Anything else I can help you with?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, thank you.", "beam_texts": [["USER: No, thank you.", -0.6700665950775146]], "metrics": {"clen": 177.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, thank you."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -6.713548646075651e-05]], "metrics": {"clen": 319.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Have a nice day.", "beam_texts": [["SYSTEM: Have a nice day.", -0.659913957118988]], "metrics": {"clen": 513.0, "ctrunc": 1.0, "ctrunclen": 1.0}, "prefix_stripped_text": "Have a nice day."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: [DONE]", "beam_texts": [["USER: [DONE]", -7.327242883548024e-07]], "metrics": {"clen": 287.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "[DONE]"}]], "context": [], "metadata_path": "/checkpoint/mpchen/projects/taskoriented/user_generator/google_sgd/mm8_gsgd_bartR3f_protoTestTrainTue_May_11/withApi_conversations.metadata"} +{"dialog": [[{"text": "APIS: api_name = ReserveRestaurant ; optArg = date, number_of_seats ; reqArg = location, restaurant_name, time | api_name = FindRestaurants ; optArg = has_seating_outdoors, has_vegetarian_options, price_range ; reqArg = category, location | api_name = BuyMovieTickets ; reqArg = location, movie_name, number_of_tickets, show_date, show_time, show_type | api_name = FindMovies ; optArg = genre, show_type, theater_name ; reqArg = location | api_name = GetTimesForMovie ; optArg = show_type, theater_name ; reqArg = location, movie_name, show_date", "id": "GoogleSGD_ApiDescription", "domain": "Movies, Restaurants", "episode_done": false, "prefix_stripped_text": "api_name = ReserveRestaurant ; optArg = date, number_of_seats ; reqArg = location, restaurant_name, time | api_name = FindRestaurants ; optArg = has_seating_outdoors, has_vegetarian_options, price_range ; reqArg = category, location | api_name = BuyMovieTickets ; reqArg = location, movie_name, number_of_tickets, show_date, show_time, show_type | api_name = FindMovies ; optArg = genre, show_type, theater_name ; reqArg = location | api_name = GetTimesForMovie ; optArg = show_type, theater_name ; reqArg = location, movie_name, show_date"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APIS: ", "beam_texts": [["APIS: ", -2.1568732222476683e-07]], "metrics": {"clen": 186.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "APIS: "}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"text": "GOAL: api_name = GetTimesForMovie ; location = Castro Valley ; movie_name = IT Chapter Two ; show_date = 2019-03-01", "id": "GoogleSGD_Goal", "domain": "Movies, Restaurants", "episode_done": false, "prefix_stripped_text": "api_name = GetTimesForMovie ; location = Castro Valley ; movie_name = IT Chapter Two ; show_date = 2019-03-01"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to watch a movie.", "beam_texts": [["USER: I want to watch a movie.", -3.0537710189819336]], "metrics": {"clen": 14.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to watch a movie."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.002036166610196233]], "metrics": {"clen": 387.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Where are you located?", "beam_texts": [["SYSTEM: Where are you located?", -0.9682686924934387]], "metrics": {"clen": 400.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Where are you located?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to watch it in Castro Valley.", "beam_texts": [["USER: I want to watch it in Castro Valley.", -1.7995065450668335]], "metrics": {"clen": 105.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to watch it in Castro Valley."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: api_name = FindMovies ; location = Castro Valley", "beam_texts": [["APICALL: api_name = FindMovies ; location = Castro Valley", -0.000328321271808818]], "metrics": {"clen": 420.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "api_name = FindMovies ; location = Castro Valley"}, {"text": "APIRESP: ", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I can't find any movies matching your criteria.", "beam_texts": [["SYSTEM: I can't find any movies matching your criteria.", -3.635852098464966]], "metrics": {"clen": 433.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I can't find any movies matching your criteria."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to watch IT Chapter Two.", "beam_texts": [["USER: I want to watch IT Chapter Two.", -2.210944652557373]], "metrics": {"clen": 130.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to watch IT Chapter Two."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.11667776852846146]], "metrics": {"clen": 456.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: What date do you want to watch it?", "beam_texts": [["SYSTEM: What date do you want to watch it?", -1.8153434991836548]], "metrics": {"clen": 469.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "What date do you want to watch it?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to watch it on the 1st of this month.", "beam_texts": [["USER: I want to watch it on the 1st of this month.", -1.2194650173187256]], "metrics": {"clen": 156.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to watch it on the 1st of this month."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: api_name = GetTimesForMovie ; location = Castro Valley ; movie_name = IT Chapter 2 ; show_date = 2019-03-01", "beam_texts": [["APICALL: api_name = GetTimesForMovie ; location = Castro Valley ; movie_name = IT Chapter 2 ; show_date = 2019-03-01", -0.18431638181209564]], "metrics": {"clen": 304.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "api_name = GetTimesForMovie ; location = Castro Valley ; movie_name = IT Chapter 2 ; show_date = 2019-03-01"}, {"text": "APIRESP: ", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I can't find any times for the movie. Is there anything else I can help you with?", "beam_texts": [["SYSTEM: I can't find any times for the movie. Is there anything else I can help you with?", -2.2572638988494873]], "metrics": {"clen": 503.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I can't find any times for the movie. Is there anything else I can help you with?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, thank you.", "beam_texts": [["USER: No, thank you.", -0.9540809392929077]], "metrics": {"clen": 146.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, thank you."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.000830456439871341]], "metrics": {"clen": 527.0, "ctrunc": 1.0, "ctrunclen": 15.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Have a great day.", "beam_texts": [["SYSTEM: Have a great day.", -0.8139710426330566]], "metrics": {"clen": 695.0, "ctrunc": 1.0, "ctrunclen": 183.0}, "prefix_stripped_text": "Have a great day."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: [DONE]", "beam_texts": [["USER: [DONE]", -3.663622294425295e-07]], "metrics": {"clen": 140.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "[DONE]"}]], "context": [], "metadata_path": "/checkpoint/mpchen/projects/taskoriented/user_generator/google_sgd/mm8_gsgd_bartR3f_protoTestTrainTue_May_11/withApi_conversations.metadata"} +{"dialog": [[{"text": "APIS: api_name = PlayMedia ; optArg = album, artist, device ; reqArg = track | api_name = LookupMusic ; optArg = album, artist, genre, year | api_name = GetWeather ; optArg = date ; reqArg = city", "id": "GoogleSGD_ApiDescription", "domain": "Music, Weather", "episode_done": false, "prefix_stripped_text": "api_name = PlayMedia ; optArg = album, artist, device ; reqArg = track | api_name = LookupMusic ; optArg = album, artist, genre, year | api_name = GetWeather ; optArg = date ; reqArg = city"}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APIS: ", "beam_texts": [["APIS: ", -1.0784366821781077e-07]], "metrics": {"clen": 75.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "APIS: "}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"text": "GOAL: album = Sucker Punch ; api_name = PlayMedia ; device = Kitchen ; track = Strangers", "id": "GoogleSGD_Goal", "domain": "Music, Weather", "episode_done": false, "prefix_stripped_text": "album = Sucker Punch ; api_name = PlayMedia ; device = Kitchen ; track = Strangers"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to listen to the song Strangers from the album Sucker Punch.", "beam_texts": [["USER: I want to listen to the song Strangers from the album Sucker Punch.", -2.6374406814575195]], "metrics": {"clen": 60.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to listen to the song Strangers from the album Sucker Punch."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: album = Sucker Punch ; api_name = PlayMedia ; track = Strangers", "beam_texts": [["APICALL: album = Sucker Punch ; api_name = PlayMedia ; track = Strangers", -0.5099499821662903]], "metrics": {"clen": 165.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "album = Sucker Punch ; api_name = PlayMedia ; track = Strangers"}, {"text": "APIRESP: ", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I couldn't play the song. Is there anything else I can help you with?", "beam_texts": [["SYSTEM: I couldn't play the song. Is there anything else I can help you with?", -2.205378770828247]], "metrics": {"clen": 196.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I couldn't play the song. Is there anything else I can help you with?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: I want to play the song on the kitchen device.", "beam_texts": [["USER: I want to play the song on the kitchen device.", -2.291337013244629]], "metrics": {"clen": 46.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to play the song on the kitchen device."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -0.000290650554234162]], "metrics": {"clen": 359.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I want to play Strangers on the kitchen speaker.", "beam_texts": [["SYSTEM: I want to play Strangers on the kitchen speaker.", -2.0472002029418945]], "metrics": {"clen": 372.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I want to play Strangers on the kitchen speaker."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: Yes, that is correct.", "beam_texts": [["USER: Yes, that is correct.", -1.4469143152236938]], "metrics": {"clen": 151.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Yes, that is correct."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: api_name = PlayMedia ; album = Sucker Punch ; device = Kitchen speaker ; track = Strangers", "beam_texts": [["APICALL: api_name = PlayMedia ; album = Sucker Punch ; device = Kitchen speaker ; track = Strangers", -0.029085755348205566]], "metrics": {"clen": 330.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "api_name = PlayMedia ; album = Sucker Punch ; device = Kitchen speaker ; track = Strangers"}, {"text": "APIRESP: ", "id": "StandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: I couldn't play the song. Is there anything else I can help you with?", "beam_texts": [["SYSTEM: I couldn't play the song. Is there anything else I can help you with?", -0.39661797881126404]], "metrics": {"clen": 343.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "I couldn't play the song. Is there anything else I can help you with?"}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: No, thank you.", "beam_texts": [["USER: No, thank you.", -0.48353710770606995]], "metrics": {"clen": 176.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "No, thank you."}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "APICALL: ", "beam_texts": [["APICALL: ", -8.918321691453457e-05]], "metrics": {"clen": 393.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": ""}, {"text": "APIRESP: ", "id": "TodStandaloneApiAgent", "episode_done": false, "prefix_stripped_text": ""}, {"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "SYSTEM: Have a great day.", "beam_texts": [["SYSTEM: Have a great day.", -0.6688835620880127]], "metrics": {"clen": 376.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "Have a great day."}], [{"id": "R3fFirstTurnHistoryRepeat", "episode_done": false, "text": "USER: [DONE]", "beam_texts": [["USER: [DONE]", -9.159052751783747e-07]], "metrics": {"clen": 232.0, "ctrunc": 0.0, "ctrunclen": 0.0}, "prefix_stripped_text": "[DONE]"}]], "context": [], "metadata_path": "/checkpoint/mpchen/projects/taskoriented/user_generator/google_sgd/mm8_gsgd_bartR3f_protoTestTrainTue_May_11/withApi_conversations.metadata"} From 6cb4b865a080fa33e00e440f02a9d4cd8b0ed60c Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Mon, 15 Nov 2021 15:50:00 -0800 Subject: [PATCH 04/23] [TOD] Core converesation structure, serialization, const tokens --- parlai/core/tod/tod_core.py | 227 ++++++++++++++++++++++++++++++++++++ 1 file changed, 227 insertions(+) create mode 100644 parlai/core/tod/tod_core.py diff --git a/parlai/core/tod/tod_core.py b/parlai/core/tod/tod_core.py new file mode 100644 index 00000000000..76ee8005a74 --- /dev/null +++ b/parlai/core/tod/tod_core.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Task Oriented Dialogue (TOD) enums and base classes. + +This file defines standard tokens, classes for round and conversation structure, and a serialization class to aid in converting between these. + +See `tod_agents.py` for usage of these classes to generate training data and `tod_world_script.py` for usage of these classes in simulated conversations. +""" +from enum import Enum +from typing import List, Dict +from dataclasses import dataclass, field +from collections.abc import Iterable +from parlai.utils.misc import warn_once + +STANDARD_CALL = "APICALL: " +STANDARD_RESP = "APIRESP: " +STANDARD_SYSTEM_UTTERANCE = "SYSTEM: " +STANDARD_USER_UTTERANCE = "USER: " + +STANDARD_GOAL = "GOAL: " +STANDARD_API_SCHEMAS = "APIS: " + +STANDARD_API_NAME_SLOT = "api_name" +STANDARD_REQUIRED_KEY = "reqArg" +STANDARD_OPTIONAL_KEY = "optArg" +STANDARD_DONE = "[DONE]" + +CONST_SILENCE = "__SILENCE__" + + +class TodAgentType(str, Enum): + USER_UTT_AGENT = "user_utt_model" + API_CALL_AGENT = "api_call_model" + API_RESP_AGENT = "api_resp_model" + SYSTEM_UTT_AGENT = "system_utt_model" + API_SCHEMA_GROUNDING_AGENT = "api_schema_grounding_model" + GOAL_GROUNDING_AGENT = "goal_grounding_model" + + +TOD_AGENT_TYPE_TO_PREFIX = { + TodAgentType.USER_UTT_AGENT: STANDARD_USER_UTTERANCE, + TodAgentType.API_CALL_AGENT: STANDARD_CALL, + TodAgentType.API_RESP_AGENT: STANDARD_RESP, + TodAgentType.SYSTEM_UTT_AGENT: STANDARD_SYSTEM_UTTERANCE, + TodAgentType.API_SCHEMA_GROUNDING_AGENT: STANDARD_API_SCHEMAS, + TodAgentType.GOAL_GROUNDING_AGENT: STANDARD_GOAL, +} + + +@dataclass +class TodStructuredRound: + """ + Dataclass for rounds. + """ + + # Variables set by those using this class + user_utt: str = "" + api_call_machine: Dict = field( + default_factory=dict + ) # Hashmap of slot keys and slot values. Note that STANDARD_API_NAME_SLOT (`api_name`) is expected to be one of the keys here when this is nonempty; simulation metrics wonky without + api_resp_machine: Dict = field(default_factory=dict) + sys_utt: str = "" + extras: Dict = field( + default_factory=dict + ) # Grab bag for extra data. Not currently referenced in any TOD core code, but a convenient leaky abstraction for passing dataset-specific data between Parser classes and realized agents/teachers. + + # Variables derived by class + api_call_utt: str = field(init=False) + api_resp_utt: str = field(init=False) + + def __post_init__(self): + self.api_call_utt = SerializationHelpers.api_dict_to_str(self.api_call_machine) + self.api_resp_utt = SerializationHelpers.api_dict_to_str(self.api_resp_machine) + if ( + len(self.api_call_machine) > 0 + and STANDARD_API_NAME_SLOT not in self.api_call_machine + ): + warn_once( + f"{STANDARD_API_NAME_SLOT} missing when API Call present. This may cause issues for simulation metrics." + ) + + +@dataclass +class TodStructuredEpisode: + """ + Dataclass for episode-level data. + """ + + # Variables set by those using this class + delex: bool = False # Set to true and this class will handle delexicalizing call + response utterances based on API calls and responses exposed to this class. + domain: str = "" # self-explanatory + api_schemas_machine: List[Dict[str, List]] = field( + default_factory=list + ) # Expected to be a List of Dicts with the API name, required arguments, and optional arguments (specified by consts at the top of this file) as keys + goal_calls_machine: List[Dict[str, str]] = field( + default_factory=list + ) # Machine-formatted API calls + rounds: List[TodStructuredRound] = field(default_factory=list) # self explanatory + extras: Dict = field( + default_factory=dict + ) # Grab bag for extra data. Not currently referenced in any TOD core code, but a convenient leaky abstraction for passing dataset-specific data between Parser classes and realized agents/teachers. + + # Variables derived by class + api_schemas_utt: str = field(init=False) + goal_calls_utt: str = field(init=False) + + def __post_init__(self): + self.api_schemas_utt = SerializationHelpers.list_of_maps_to_str( + self.api_schemas_machine + ) + self.goal_calls_machine = [ + call for call in self.goal_calls_machine if len(call) > 0 + ] + self.goal_calls_utt = SerializationHelpers.list_of_maps_to_str( + self.goal_calls_machine + ) + # Add a done turn at the end + self.rounds.append(TodStructuredRound(user_utt=STANDARD_DONE)) + if self.delex: + accum_slots = ( + {} + ) # separate since some slot values change as we go. Use this for delex first + cum_slots = self.get_all_slots() + for r in self.rounds: + accum_slots.update(r.api_call_machine) + accum_slots.update(r.api_resp_machine) + r.sys_utt = SerializationHelpers.delex(r.sys_utt, accum_slots) + r.sys_utt = SerializationHelpers.delex(r.sys_utt, cum_slots) + + def get_all_slots(self): + result = {} + for r in self.rounds: + result.update(r.api_call_machine) + result.update(r.api_resp_machine) + return result + + +class SerializationHelpers: + @classmethod + def delex(cls, text, slots): + delex = text + for slot, value in slots.items(): + if isinstance(value, str): + delex = delex.replace(value, f"[{slot}]") + else: + for v in value: + delex = delex.replace(v, f"[{slot}]") + return delex + + @classmethod + def inner_list_join(cls, values): + if isinstance(values, str): + return values + return ", ".join(sorted([v.strip() for v in values])) + + @classmethod + def inner_list_split(cls, s): + return s.split(", ") + + @classmethod + def maybe_inner_list_join(cls, values): + if isinstance(values, str) or isinstance(values, int): + return values + elif isinstance(values, Iterable): + return SerializationHelpers.inner_list_join(values) + else: + raise RuntimeError("invalid type of argument for maybe_inner_list_join") + + @classmethod + def api_dict_to_str(cls, apidict): + """ + Used for API Calls and Responses -> Utterance. + """ + return " ; ".join( + f"{k} = {SerializationHelpers.maybe_inner_list_join(v)}" + for k, v in sorted(apidict.items()) + ) + + @classmethod + def str_to_api_dict(cls, string): + """ + Used for API Call and Response Utterances -> Dict. + """ + slot_strs = string.split(" ; ") + result = {} + for slot_str in slot_strs: + if " = " not in slot_str: + continue + name, value = slot_str.split(" = ", 1) + name = name.strip() + value = value.strip() + result[name] = value + return result + + @classmethod + def outer_list_join(cls, s): + return " | ".join(s) + + @classmethod + def outer_list_split(cls, s): + return s.split(" | ") + + @classmethod + def str_to_list_of_maps(cls, s): + return [ + SerializationHelpers.str_to_api_dict(x) + for x in SerializationHelpers.outer_list_split(s) + ] + + @classmethod + def list_of_maps_to_str(cls, list_of_maps): + return SerializationHelpers.outer_list_join( + [SerializationHelpers.api_dict_to_str(m) for m in list_of_maps] + ) + + @classmethod + def str_to_goals(cls, s): # convenience + return SerializationHelpers.str_to_list_of_maps(s) + + @classmethod + def str_to_api_schemas(cls, s): # convenience + return SerializationHelpers.str_to_list_of_maps(s) From 1480defd9fe6dacfeac287cb64c7b1e3436a236f Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 10:08:40 -0800 Subject: [PATCH 05/23] fix test by adding init folder --- parlai/core/tod/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 parlai/core/tod/__init__.py diff --git a/parlai/core/tod/__init__.py b/parlai/core/tod/__init__.py new file mode 100644 index 00000000000..240697e3247 --- /dev/null +++ b/parlai/core/tod/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. From de84801e3b47d9555277b03552835f75f2c529cd Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Mon, 15 Nov 2021 20:21:58 -0800 Subject: [PATCH 06/23] [Tod] Agents, teacher metrics, and tests for these See documentation block in `tod_agents.py` (I'm not 100% sure if `conftest.py` is a right file to change, though I did notice that `pytest.ini` was necessary to get pytest to run.) --- conftest.py | 1 + parlai/core/tod/teacher_metrics.py | 159 ++++ parlai/core/tod/tod_agents.py | 794 ++++++++++++++++++ parlai/core/tod/tod_test_utils/__init__.py | 5 + parlai/core/tod/tod_test_utils/test_agents.py | 216 +++++ pytest.ini | 1 + tests/tod/__init__.py | 5 + tests/tod/test_tod_agents_and_teachers.py | 327 ++++++++ tests/tod/test_tod_teacher_metrics.py | 74 ++ 9 files changed, 1582 insertions(+) create mode 100644 parlai/core/tod/teacher_metrics.py create mode 100644 parlai/core/tod/tod_agents.py create mode 100644 parlai/core/tod/tod_test_utils/__init__.py create mode 100644 parlai/core/tod/tod_test_utils/test_agents.py create mode 100644 tests/tod/__init__.py create mode 100644 tests/tod/test_tod_agents_and_teachers.py create mode 100644 tests/tod/test_tod_teacher_metrics.py diff --git a/conftest.py b/conftest.py index 7cc1e262461..8970273a460 100644 --- a/conftest.py +++ b/conftest.py @@ -67,6 +67,7 @@ def filter_tests_with_circleci(test_list): ('datatests/', 'data'), ('parlai/tasks/', 'teacher'), ('tasks/', 'tasks'), + ('tod/', 'tod'), ] diff --git a/parlai/core/tod/teacher_metrics.py b/parlai/core/tod/teacher_metrics.py new file mode 100644 index 00000000000..3fc85c7d107 --- /dev/null +++ b/parlai/core/tod/teacher_metrics.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Task Oriented Dialogue (TOD) teacher metrics. +""" +from typing import Optional, List, Dict, Any +from parlai.core.metrics import AverageMetric, BleuMetric, F1Metric, Metric, Metrics + + +class SlotMetrics(Metrics): + """ + Helper container which encapsulates standard slot metrics in task oriented learning + (jga, slot_p, slot_r, etc). + + Due to differences in dialogue representations between tasks, the input is pre- + parsed ground truth and predicted slot dictionaries. + + The 'jga+nlg' metric assumes a balanced set of JGA and NLG scores such that + 2 * Avg(JGA, NLG_BLEU) = Avg(JGA + NLG_BLEU) + The `jga+nlg` metric assumes that `NlgMetrics` is used to calculated the other side. + """ + + def __init__( + self, + teacher_slots: Dict[str, str], + predicted_slots: Dict[str, str], + prefixes: Optional[List] = None, + shared: Dict[str, Any] = None, + avg_jga_nlg_bleu: bool = False, + ) -> None: + super().__init__(shared=shared) + self.prefixes = prefixes if prefixes else [] + # jga and optionally Avg(jga,nlg_bleu) + self.add_with_prefixes("jga", AverageMetric(teacher_slots == predicted_slots)) + if len(teacher_slots) > 0: + self.add_with_prefixes( + "jga_noempty", AverageMetric(teacher_slots == predicted_slots) + ) + else: + self.add_with_prefixes( + "jga_empty", AverageMetric(teacher_slots == predicted_slots) + ) + + if avg_jga_nlg_bleu: + # add one half of Avg(jga,nlg_bleu), NlgMetrics class (below) adds NLG-BLEU + self.add("jga+nlg", AverageMetric(teacher_slots == predicted_slots)) + # precision + for pred_slot_name, pred_value in predicted_slots.items(): + slot_p = AverageMetric(teacher_slots.get(pred_slot_name) == pred_value) + self.add_with_prefixes("slot_p", slot_p) + self.add_with_prefixes("slot_f1", SlotF1Metric(slot_p=slot_p)) + # recall + for teacher_slot_name, teacher_value in teacher_slots.items(): + slot_r = AverageMetric( + predicted_slots.get(teacher_slot_name) == teacher_value + ) + self.add_with_prefixes("slot_r", slot_r) + self.add_with_prefixes("slot_f1", SlotF1Metric(slot_r=slot_r)) + + def add_with_prefixes(self, name, value): + self.add(name, value) + for prefix in self.prefixes: + self.add(f"{prefix}/{name}", value) + + +class NlgMetrics(Metrics): + """ + Helper container for generation version of standard metrics (F1, BLEU, ..). + """ + + def __init__( + self, + guess: str, + labels: Optional[List[str]], + prefixes: Optional[List[str]] = None, + shared: Dict[str, Any] = None, + avg_jga_nlg_bleu: bool = False, + ) -> None: + super().__init__(shared=shared) + self.prefixes = prefixes if prefixes else [] + bleu = BleuMetric.compute(guess, labels) + f1 = F1Metric.compute(guess, labels) + self.add_with_prefixes("nlg_bleu", bleu) + self.add_with_prefixes("nlg_f1", f1) + if avg_jga_nlg_bleu: + # add one half of Avg(jga,nlg_bleu), SlotMetrics class (above) adds JGA + self.add("jga+nlg", bleu) + + def add_with_prefixes(self, name, value): + self.add(name, value) + for prefix in self.prefixes: + self.add(f"{prefix}/{name}", value) + + +AverageType = Optional[AverageMetric] + + +def _average_type_sum_helper(first: AverageType, second: AverageType) -> AverageType: + """ + Helper to deal with Nones. + + We are "clever" in how we aggregate SlotF1Metrics (See SlotMetrics `__init__`) in + that we add precision and recall values separately, but this means we need to handle + None. + """ + if first is None: + return second + if second is None: + return first + return first + second + + +class SlotF1Metric(Metric): + """ + Metric to keep track of slot F1. + + Keeps track of slot precision and slot recall as running metrics. + """ + + __slots__ = ("_slot_p", "_slot_r") + + @property + def macro_average(self) -> bool: + """ + Indicates whether this metric should be macro-averaged when globally reported. + """ + return True + + def __init__(self, slot_p: AverageType = None, slot_r: AverageType = None): + if not isinstance(slot_p, AverageMetric) and slot_p is not None: + slot_p = AverageMetric(slot_p) + if not isinstance(slot_r, AverageMetric) and slot_r is not None: + slot_r = AverageMetric(slot_r) + self._slot_p = slot_p + self._slot_r = slot_r + + def __add__(self, other: Optional["SlotF1Metric"]) -> "SlotF1Metric": + # NOTE: hinting can be cleaned up with "from __future__ import annotations" when + # we drop Python 3.6 + if other is None: + return self + slot_p = _average_type_sum_helper(self._slot_p, other._slot_p) + slot_r = _average_type_sum_helper(self._slot_r, other._slot_r) + return type(self)(slot_p=slot_p, slot_r=slot_r) + + def value(self) -> float: + if self._slot_p is None or self._slot_r is None: + return float("nan") + else: + slot_p = self._slot_p.value() + slot_r = self._slot_r.value() + if slot_p == 0.0 and slot_r == 0.0: + return float("nan") + else: + return 2 * (slot_p * slot_r) / (slot_p + slot_r) diff --git a/parlai/core/tod/tod_agents.py b/parlai/core/tod/tod_agents.py new file mode 100644 index 00000000000..f22a8330760 --- /dev/null +++ b/parlai/core/tod/tod_agents.py @@ -0,0 +1,794 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Agents (used for dumping data) and Teachers (for training models) related to the TOD +conversation setup. + +# Usage + +For a given dataset, extend `TodStructuredDataParser` and implement `generate_episodes()` and `get_id_task_prefix()`. The former of these is expected to do the data processing to convert a dataset to `List[TodStructuredEpisode]`. From here, multiple inheritance can be used to define Agents and Teachers that utilize the data. + +For example, given a `class XX_DataParser(TodStructuredDataParser)`, `class XX_UserSimulatorTeacher(XX_DataParser, TodUserSimulatorTeacher)` would be how one would define a teacher that generates training data for a User Simulator model. + +Once the relevant agents have been created (or relevant models have been fine-tuned), see `parlai.scripts.tod_world_script` for usage in generating simulations. + +As a convention, agents and teachers that are inheritable are prefixed with "Tod" whereas those that can be used as-is are not. Similarly, classes and functions that do not need to be exposed outside of this file are prefixed with a single underscore ('_'). + +## Why we do this +These files aid in consistency between Teachers and Agents for simulation. Rather than having to align multiple different agents to be consistent about assuptions about data formatting, tokens, spacing, etc, we do this once (via converting everything to `TodStructuredEpisode`) and let the code handle the rest. + +# Description of Agents + Teachers useful for Simulation +## Teachers for training (generative) models + * TodSystemTeacher + * TodUserSimulatorTeacher + +## Agents for Grounding +For goal grounding for the User for simulation: + * TodGoalAgent + * TodSingleGoalAgent + +For (optional) API schema grounding for the System: + * TodApiSchemaAgent (must be used with `TodGoalAgent` only) + * TodSingleApiSchemaAgent (must be used with `TodSingleGoalAgent` only) + * EmptyApiSchemaAgent + * Used for simulations where the expectation is `no schema`, ie, evaluation simulations. + +## Agents for mocking APIs: + * StandaloneApiAgent + * Assumed to be provided a .pickle file 'trained' by `TodStandaloneApiTeacher` + +# Agents for dumping data from a ground truth dataset +The following are for extracting TOD World metrics from a ground truth dataset. These are generally used sparingly and only for calculating baselines. + * TodApiCallAndSysUttAgent + * TodApiResponseAgent + * TodUserUttAgent + +For this metrics extraction, `TodGoalAgent` and `TodApiSchemaAgent` should be used. + +# Other agents +There is a `EmptyGoalAgent` for use in human-human conversations where a goal is unnecessary. +""" + +from parlai.core.agents import Agent +from parlai.core.message import Message +from parlai.core.metrics import AverageMetric +from parlai.core.params import ParlaiParser +from parlai.core.opt import Opt +from parlai.core.teachers import DialogTeacher +from parlai.utils.distributed import is_distributed, get_rank, num_workers + +import parlai.core.tod.tod_core as tod +from parlai.core.tod.tod_core import SerializationHelpers +from parlai.core.tod.teacher_metrics import SlotMetrics, NlgMetrics + +from typing import Optional, List +import json +import pickle +import difflib +import random +from math import ceil + + +######### Agents that dump information from a dataset; base classes +class TodStructuredDataParser(Agent): + """ + Base class that specifies intermediate representations for Tod conversations. + """ + + @classmethod + def add_cmdline_args( + cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None + ) -> ParlaiParser: + if hasattr(super(), "add_cmdline_args"): + parser = super().add_cmdline_args(parser, partial_opt) + group = parser.add_argument_group("TOD StructuredData agent") + group.add_argument( + "--episodes-randomization-seed", + type=int, + default=-1, + help="Randomize episodes in a predictable way (eg, for few shot). Set to -1 for no randomization. ", + ) + parser.add_argument( + "--n-shot", + default=-1, + type=int, + help="Number of dialogues to keep for each of train/valid/test. -1 means all. Dialogues of lower numbers are strict subsets of larger numbers. Do not use in conjunction with `--percent-shot`. Use `--episodes-randomization-seed` to change seed. NOTE: Beware of using this flag when multitasking as this will apply to *all* datasets unless the ':' syntax for specifying per-dataset flags is used.", + ) + parser.add_argument( + "--percent-shot", + default=-1, + type=float, + help="Percentage of dialogues to keep for each of train/valid/test. -1 means all. Dialogues of lower numbers are strict subsets of larger numbers. Do not use in conjunction with `--n-shot`. Use `--episodes-randomization-seed` to change seed. NOTE: Beware of using this flag when multitasking as this will apply to *all* datasets unless the ':' syntax for specifying per-dataset flags is used.", + ) + return parser + + def __init__(self, opt: Opt, shared=None): + super().__init__(opt, shared) + self.id = self.get_id_task_prefix() + "_" + self._get_agent_type_suffix() + if shared is None: + self.episodes = self.generate_episodes() + else: + self.episodes = shared["episodes"] + + def share(self): + share = super().share() + share["episodes"] = self.episodes + return share + + def setup_episodes(self, fold: str) -> List[tod.TodStructuredEpisode]: + """ + Fold here is a data fold. + """ + raise NotImplementedError( + "Must have method for generating an episode. Must be set in downstream Parser for a given task" + ) + + def generate_episodes(self) -> List[tod.TodStructuredEpisode]: + if self.opt.get("n_shot", -1) >= 0 and self.opt.get("percent_shot", -1) >= 0: + # Validate before spending a while to load eeverything + raise RuntimeError("Both `--n-shot` and `--percent-shot` in use!") + episodes = list(self.setup_episodes(self.fold)) + if self.opt.get("episodes_randomization_seed", -1) != -1: + random.Random(self.opt["episodes_randomization_seed"]).shuffle(episodes) + if self.opt.get("n_shot", -1) != -1: + episodes = episodes[: self.opt["n_shot"]] + elif self.opt.get("percent_shot", -1) >= 0: + episodes = episodes[: int(len(episodes) * self.opt["percent_shot"])] + return episodes + + def get_id_task_prefix(self) -> str: + """ + Convenience for setting IDs. + """ + raise NotImplementedError( + "Must set ID prefix in downstream task agent. Must be set in downsream Parser for a given task" + ) + + def _get_agent_type_suffix(self) -> str: + """ + Convenience for setting IDs. + """ + raise NotImplementedError( + "Must set in downstream agent within `tod_agents`. If you see this error, something is wrong with TOD Infrastructure" + ) + + +######### Agents that dump information from a dataset as gold (explicitly should *not* be used with teachers) +class _TodDataDumpAgent(TodStructuredDataParser): + """ + For agents which dump data from some dataset, without training/other modifications. + + Implements an "epoch done" + + Member variables assumed to be set in init downstream: + self.fold + """ + + def __init__(self, opt: Opt, shared=None): + super().__init__(opt, shared) + self.epochDone = False + self.batchsize = opt.get("batchsize", 1) + self.max_episodes = len(self.episodes) + if opt.get("num_episodes", 0) > 0: + self.max_episodes = min(self.max_episodes, opt.get("num_episodes")) + self.episode_idx = opt.get("batchindex", 0) + self._setup_next_episode() + self.round_idx = 0 # for some downstream utt + sysUttAndApiCallAgents. + if is_distributed(): # cause gotta manually handle + rank = get_rank() + chunk_size = ceil(self.max_episodes / num_workers()) + self.episode_idx += rank * chunk_size + self.max_episodes = min(self.max_episodes, (rank + 1) * chunk_size) + + def _setup_next_episode(self): + self.epochDone = not self.episode_idx < self.max_episodes + self.episode = None + if not self.epochDone: + self.episode = self.episodes[self.episode_idx] + self.round_idx = ( + 0 # so downstream agents know which round they are in. Update in `act()` + ) + + def epoch_done(self) -> bool: + return self.epochDone + + def episode_done(self) -> bool: + """ + This is not actually "episode_done" so much as "we want to signify to the world + that we have gone past the batch". + + This class should not control whether or not the episode is actually done since + the TodWorld expects that to come from the User agent. + """ + return self.epochDone + + def num_episodes(self) -> int: + return len(self.episodes) + + def reset(self): + self.episode_idx += self.batchsize + self._setup_next_episode() + + +class TodGoalAgent(_TodDataDumpAgent): + """ + Use as a mixin with classes that also extend + implement TodStructuredDataParser. + """ + + def act(self): + return { + "text": f"{tod.STANDARD_GOAL}{self.episode.goal_calls_utt}", + "id": self.id, + "domain": self.episode.domain, + "episode_done": False, + } + + def _get_agent_type_suffix(self): + return "Goal" + + +class TodApiSchemaAgent(_TodDataDumpAgent): + def act(self): + return { + "text": f"{tod.STANDARD_API_SCHEMAS}{self.episode.api_schemas_utt}", + "id": self.id, + "domain": self.episode.domain, + "episode_done": False, + } + + def _get_agent_type_suffix(self): + return "ApiSchema" + + +############# Single Goal + Api Schema Agent +class _EpisodeToSingleGoalProcessor(_TodDataDumpAgent): + """ + Iterate through all of the goals of a dataset, one by one. + + Slightly different logic than the dump agent since how we count + setup examples for + an episode are different + + Used as a mixin in the SingleGoal and SingleApiSchema agents below. + + This class exposes a `filter_goals()` function that can be overridden by downstream agents. + """ + + def __init__(self, opt: Opt, shared=None): + super().__init__(opt, shared) + self.epochDone = False + if shared is None: + self.episodes = self._setup_single_goal_episodes() + else: + # Handled fine in _TodDataDumpAgent + pass + + self.max_episodes = len(self.episodes) + if opt.get("num_episodes", 0) > 0: + self.max_episodes = min(self.max_episodes, opt.get("num_episodes")) + if is_distributed(): # cause gotta manually handle + rank = get_rank() + chunk_size = ceil(self.max_episodes / num_workers()) + self.max_episodes = min(self.max_episodes, (rank + 1) * chunk_size) + + self._setup_next_episode() + + def _setup_single_goal_episodes(self) -> List[tod.TodStructuredEpisode]: + """ + This function assumes that `self.setup_episodes()` has already been called + prior. + + Based on the `__init__` order of this class, it should be done in + `TodStructuredDataParser` by this point. + """ + raw_episodes = self.episodes + result = [] + for raw in raw_episodes: + for call in self.filter_goals(raw.goal_calls_machine): + schema = {} + for cand in raw.api_schemas_machine: + if ( + cand[tod.STANDARD_API_NAME_SLOT] + == call[tod.STANDARD_API_NAME_SLOT] + ): + schema = cand + + result.append( + tod.TodStructuredEpisode( + domain=raw.domain, + api_schemas_machine=[schema], + goal_calls_machine=[call], + rounds=[], + ) + ) + return result + + def filter_goals(self, goals): + """ + Some downstream agents may want to filter the goals. + + Override this if so. + """ + return goals + + +class TodSingleGoalAgent(_EpisodeToSingleGoalProcessor, TodGoalAgent): + """ + Use as a mixin with classes that also extend + implement TodStructuredDataParser. + + NOTE: If an API schema agent is used, this *must* be used with `TodSingleApiSchemaAgent` since it will be nonsensicle otherwise. Additionally, this agent will not function properly with UserUtt + SystemUttAndApiCall agent, since episodes will not align. + """ + + def _get_agent_type_suffix(self): + return "SingleGoal" + + +class TodSingleApiSchemaAgent(_EpisodeToSingleGoalProcessor, TodApiSchemaAgent): + """ + Use as a mixin with classes that also extend + implement TodStructuredDataParser. + + NOTE: Must be used with TodSingleGoalAgent since nonsensicle otherwise. Additionally, this agent will not function properly with UserUtt + SystemUttAndApiCall agent, since episodes will not align. + """ + + def _get_agent_type_suffix(self): + return "SingleApiSchema" + + +###### Agents used for calculating TOD World Metrics based on a dataset. See `tod_world_script` or `parlai/projects/tod_simulator/` for examples. +class TodUserUttAgent(_TodDataDumpAgent): + """ + Agent used to calculate TOD World Metrics on a dataset. Represents the "User" agent. + + This class should only ever be used with the model-model chat world which will stop + upon seeing the '[DONE]' utterance; may go out of bounds otherwise. + """ + + def act(self): + result = { + "text": f"{tod.STANDARD_USER_UTTERANCE}{self.episode.rounds[self.round_idx].user_utt}", + "id": self.id, + "domain": self.episode.domain, + "episode_done": False, + } + self.round_idx += 1 + return result + + def reset(self): + super().reset() # setup next episode + self.round_idx = 0 + + def _get_agent_type_suffix(self): + return "User" + + +class TodApiCallAndSysUttAgent(_TodDataDumpAgent): + """ + Agent used to calculate TOD World Metrics on a dataset. Represents the "System" + agent. + + This class should only ever be used with the model-model chat world which will stop + upon seeing the '[DONE]' utterance; may go out of bounds otherwise. + """ + + def __init__(self, opt: Opt, shared=None): + # This class represents two "agents" so need to make sure we don't increment episode number (reset) twice + self.already_reset = False + self.api_call_turn = True + super().__init__(opt, shared) + + def act(self): + self.already_reset = False + if tod.STANDARD_API_SCHEMAS in self.observation.get("text", ""): + return { + "text": tod.STANDARD_API_SCHEMAS, + "id": self.id, + "domain": self.episode.domain, + "episode_down": False, + } + + if self.api_call_turn: # comes first, don't iterate round # + result = { + "text": f"{tod.STANDARD_CALL}{self.episode.rounds[self.round_idx].api_call_utt}", + "id": self.id, + "domain": self.episode.domain, + "episode_done": False, + } + else: + result = { + "text": f"{tod.STANDARD_SYSTEM_UTTERANCE}{self.episode.rounds[self.round_idx].sys_utt}", + "id": self.id, + "domain": self.episode.domain, + "episode_done": False, + } + self.round_idx += 1 + + self.api_call_turn ^= True + return result + + def reset(self): + if not self.already_reset: + super().reset() # setup next episode + self.api_call_turn = True + self.already_reset = True + + def _get_agent_type_suffix(self): + return "System" + + +class TodApiResponseAgent(_TodDataDumpAgent): + """ + Agent used to calculate TOD World Metrics on a dataset. Represents the API + Simulator. + + This class should only ever be used with the model-model chat world which will stop + upon seeing the '[DONE]' utterance; may go out of bounds otherwise. + """ + + def act(self): + result = { + "text": f"{tod.STANDARD_RESP}{self.episode.rounds[self.round_idx].api_resp_utt}", + "id": self.id, + "domain": self.episode.domain, + "episode_done": False, + } + self.round_idx += 1 + return result + + def reset(self): + super().reset() # setup next episode + self.round_idx = 0 + + def _get_agent_type_suffix(self): + return "ApiResponse" + + +###### Standalone API agent +class StandaloneApiAgent(Agent): + """ + Trainable agent that saves API calls and responses. + + Use `TodStandaloneApiTeacher` to train this class. For example for a MultiWoz V2.2 + standalone API, use ``` parlai train -t multiwoz_v22:StandaloneApiTeacher -m + parlai_fb.agents.tod.agents:StandaloneApiAgent -eps 4 -mf output ``` to generate the + `.pickle` file to use. + """ + + EMPTY_RESP = { + "text": tod.STANDARD_RESP, + "id": "StandaloneApiAgent", + "episode_done": False, + } + + @classmethod + def add_cmdline_args( + cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None + ) -> ParlaiParser: + group = parser.add_argument_group("TOD Standalone API args") + group.add_argument( + "--exact-api-call", + type=bool, + default=True, + help="Validation-time flag. If true, will return '' if exact api call values not found. If false, will pick response from the same intent with similar api parameters (assuming intent is the same when available)", + ) + + group.add_argument( + "--fail-hard", + type=bool, + default=False, + help="Aids in deugging. Will throw exception if API call not found and '--exact-api-call' is set.", + ) + + group.add_argument( + "--standalone-api-file", + type=str, + default=None, + help="Path to file holding `.pickle` of standalone api for validation (will intelligently strip if suffix included). If not set, assumes the `model_file` argument will contain the `.pickle` file. ", + ) + return parser + + def __init__(self, opt, shared=None): + super().__init__(opt, shared) + self.id = "StandaloneApiAgent" + file_key = "model_file" + if self.opt["standalone_api_file"] is not None: + file_key = "standalone_api_file" + self.path_base = self.opt[file_key].replace(".pickle", "") + self.db_path = self.path_base + ".pickle" + self.exact_api_call = self.opt["exact_api_call"] + try: + with (open(self.db_path, "rb")) as openfile: + self.data = pickle.load(openfile) + self.training = True + print("Loaded Standalone API data successfully") + if self.exact_api_call != self.data.get("exact_api_call", True): + raise RuntimeError( + f"Standalone API .pickle file generated with `exact_api_call` of {self.data.get('exact_api_call', False)} but StandaloneApiAgent sets it to {self.exact_api_call}" + ) + except Exception: + print(f"No file at {self.db_path}; ASSUMING WE ARE TRAINING") + self.data = {} + self.data["exact_api_call"] = self.exact_api_call + self.training = True + + def _maybe_filter_prefix(self, text, prefix): + if prefix in text: + return text[len(prefix) :].strip() + return text.strip() + + def act(self): + if not self.observation["text"].startswith(tod.STANDARD_CALL): + return self.EMPTY_RESP + call_text_raw = self.observation["text"] + # decode then reencode the API call so that we get the API calls in a consistent order + call_text = SerializationHelpers.api_dict_to_str( + SerializationHelpers.str_to_api_dict( + call_text_raw[len(tod.STANDARD_CALL) :] + ) + ) + if "labels" in self.observation: + return self._do_train(call_text) + return self._do_fetch(call_text) + + def _do_train(self, call_text): + assert self.training is True + self.data[call_text] = self.observation["labels"][0] + return self.EMPTY_RESP + + def _do_fetch(self, call_text): + if self.exact_api_call: + if self.opt.get("fail_hard", False): + resp = self.data[call_text] + else: + resp = self.data.get(call_text, tod.STANDARD_RESP) + return { + "text": resp, + "id": self.id, + "episode_done": False, + } + + # Not exact case + best_key = difflib.get_close_matches(call_text, self.data.keys(), 1) + if len(best_key) == 0: + return self.EMPTY_RESP + return { + "text": self.data.get(best_key[0], tod.STANDARD_RESP), + "id": self.id, + "episode_done": False, + } + + def shutdown(self): + if self.training: + with (open(self.db_path, "wb")) as openfile: + pickle.dump(self.data, openfile) + print(f"Dumped output to {self.db_path}") + with open(self.path_base + ".opt", "w") as f: + json.dump(self.opt, f) + + +######### Empty agents +class EmptyApiSchemaAgent(Agent): + def __init__(self, opt, shared=None): + super().__init__(opt) + self.id = "EmptyApiSchemaAgent" + + def act(self): + msg = { + "id": self.getID(), + "text": tod.STANDARD_API_SCHEMAS, + "episode_done": False, + } + return Message(msg) + + +class EmptyGoalAgent(Agent): + def __init__(self, opt, shared=None): + super().__init__(opt) + self.id = "EmptyGoalAgent" + + def act(self): + msg = {"id": self.getID(), "text": tod.STANDARD_GOAL, "episode_done": False} + return Message(msg) + + +############# Teachers +class TodSystemTeacher(TodStructuredDataParser, DialogTeacher): + """ + TOD agent teacher which produces both API calls and NLG responses. + + First turn is API Schema grounding, which may be a an empty schema. + Subsequent turns alternate between + 1. User utterance -> API Call + 2. API Response -> System Utterance + """ + + @classmethod + def add_cmdline_args( + cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None + ) -> ParlaiParser: + parser = super().add_cmdline_args(parser, partial_opt) + parser.add_argument( + "--api-schemas", + type="bool", + default=False, + help="Preempt first turn with intents + required/optional parameters as key/value for given domain", + ) + parser.add_argument( + "--api-jga-record", + type=bool, + default=True, + help="Should we save jga information per api schema?", + ) + parser.add_argument( + "--domain-jga-record", + type=bool, + default=False, + help="Should we save jga information per domain?", + ) + parser.add_argument( + "--domain-nlg-record", + type=bool, + default=False, + help="Should we save nlg information per domain?", + ) + return parser + + def custom_evaluation( + self, teacher_action: Message, labels, model_response: Message + ): + resp = model_response.get("text") + if not resp: + return + if teacher_action["type"] == tod.STANDARD_CALL: + if resp.startswith(tod.STANDARD_CALL): + resp = resp[len(tod.STANDARD_CALL) :] + predicted = SerializationHelpers.str_to_api_dict(resp) + domains = ( + [teacher_action["domain"]] if self.opt["domain_jga_record"] else [] + ) + + metrics = SlotMetrics( + teacher_slots=teacher_action["slots"], + predicted_slots=predicted, + avg_jga_nlg_bleu=True, + prefixes=domains, + ).report() + for key, value in metrics.items(): + self.metrics.add(key, value) + + if self.opt["api_jga_record"] and len(teacher_action["slots"]) > 0: + teacher = teacher_action["slots"] + slots = list(teacher.keys()) + slots.remove(tod.STANDARD_API_NAME_SLOT) + api_here = ( + "api-" + + teacher[tod.STANDARD_API_NAME_SLOT] + + "--" + + "-".join(slots) + ) + self.metrics.add(f"{api_here}/jga", AverageMetric(teacher == predicted)) + + elif teacher_action["type"] == tod.STANDARD_SYSTEM_UTTERANCE: + domains = ( + [teacher_action["domain"]] if self.opt["domain_nlg_record"] else [] + ) + metrics = NlgMetrics( + guess=resp, + labels=labels, + prefixes=domains, + avg_jga_nlg_bleu=True, + ).report() + for key, value in metrics.items(): + self.metrics.add(key, value) + + def setup_data(self, fold): + for episode in self.generate_episodes(): + if self.opt.get("api_schemas"): + schemas = episode.api_schemas_utt + else: + schemas = "" + yield { + "text": f"{tod.STANDARD_API_SCHEMAS}{schemas}", + "label": f"{tod.STANDARD_API_SCHEMAS}", + "domain": episode.domain, + "type": tod.STANDARD_API_SCHEMAS, + "slots": {}, + }, True + for r in episode.rounds: + yield { + "text": f"{tod.STANDARD_USER_UTTERANCE}{r.user_utt}", + "label": f"{tod.STANDARD_CALL}{r.api_call_utt}", + "domain": episode.domain, + "type": tod.STANDARD_CALL, + "slots": r.api_call_machine, + }, False + yield { + "text": f"{tod.STANDARD_RESP}{r.api_resp_utt}", + "label": f"{tod.STANDARD_SYSTEM_UTTERANCE}{r.sys_utt}", + "domain": episode.domain, + "slots": r.api_resp_machine, + "type": tod.STANDARD_SYSTEM_UTTERANCE, + }, False + + def _get_agent_type_suffix(self): + return "SystemTeacher" + + +class TodUserSimulatorTeacher(TodStructuredDataParser, DialogTeacher): + """ + Teacher that has `Goal->User Utterance` for its first turn, then `System + Utterance->User Utterance` for all subsequent turns. + """ + + def setup_data(self, fold): + for episode in self.generate_episodes(): + if len(episode.rounds) < 1: + continue + yield { + "text": f"{tod.STANDARD_GOAL}{episode.goal_calls_utt}", + "label": f"{tod.STANDARD_USER_UTTERANCE}{episode.rounds[0].user_utt}", + "domain": episode.domain, + "type": tod.STANDARD_USER_UTTERANCE, + }, True + for i, r in enumerate(episode.rounds): + if i == len(episode.rounds) - 1: + continue + yield { + "text": f"{tod.STANDARD_SYSTEM_UTTERANCE}{r.sys_utt}", + "label": f"{tod.STANDARD_USER_UTTERANCE}{episode.rounds[i+1].user_utt}", + "domain": episode.domain, + "type": tod.STANDARD_USER_UTTERANCE, + "slots": {}, # slots in agent/user turns are meaningless + }, False + + def custom_evaluation( + self, teacher_action: Message, labels, model_response: Message + ): + resp = model_response.get("text") + if not resp: + return + if teacher_action["type"] == tod.STANDARD_RESP: + if resp.startswith(tod.STANDARD_RESP): + resp = resp[len(tod.STANDARD_RESP) :] + predicted = SerializationHelpers.str_to_api_dict(resp) + + metrics = SlotMetrics(teacher_action["slots"], predicted).report() + for key, value in metrics.items(): + self.metrics.add(key, value) + + elif teacher_action["type"] == tod.STANDARD_USER_UTTERANCE: + metrics = NlgMetrics(resp, labels).report() + for key, value in metrics.items(): + self.metrics.add(key, value) + + def _get_agent_type_suffix(self): + return "UserSimulatorTeacher" + + +class TodStandaloneApiTeacher(TodStructuredDataParser, DialogTeacher): + """ + Use this to generate a database for `StandaloneApiAgent`. + + Set this as the teacher with `StandaloneApiAgent` as the agent. Ex for a MultiWoz + V2.2 standalone API, use ``` parlai train -t multiwoz_v22:StandaloneApiTeacher -m + parlai_fb.agents.tod.agents:StandaloneApiAgent -eps 4 -mf output ``` + """ + + def setup_data(self, fold): + # As a default, just put everything in + for fold_overwrite in ["train", "valid", "test"]: + for episode in self.setup_episodes(fold_overwrite): + first = True + for r in episode.rounds: + if len(r.api_call_machine) > 0: + yield { + "text": f"{tod.STANDARD_CALL}{r.api_call_utt}", + "label": f"{tod.STANDARD_RESP}{r.api_resp_utt}", + "id": self.id, + "domain": episode.domain, + }, first + first = False + + def _get_agent_type_suffix(self): + return "StandaloneApiTeacher" diff --git a/parlai/core/tod/tod_test_utils/__init__.py b/parlai/core/tod/tod_test_utils/__init__.py new file mode 100644 index 00000000000..240697e3247 --- /dev/null +++ b/parlai/core/tod/tod_test_utils/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/parlai/core/tod/tod_test_utils/test_agents.py b/parlai/core/tod/tod_test_utils/test_agents.py new file mode 100644 index 00000000000..b1339052764 --- /dev/null +++ b/parlai/core/tod/tod_test_utils/test_agents.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Helpers so we don't need to create agents all over. +""" + +import parlai.core.tod.tod_agents as tod_agents +import parlai.core.tod.tod_core as tod_core + +import os + +API_DATABASE_FILE = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "standalone_api_file.pickle" +) + + +def episode_has_broken_api_turn(episode_idx, max_turns): + return episode_idx % 2 == 1 and max_turns > 0 + + +def use_broken_api_calls_this_turn(round_idx, episode_idx): + return episode_idx % 2 == 1 and round_idx % 3 == 1 + + +def make_api_call_machine(round_idx, episode_idx=0, use_broken_mock_api_calls=False): + if round_idx == 0: + return {} + if use_broken_mock_api_calls: + # Hack as a way to test metrics reporting in tod world script + if use_broken_api_calls_this_turn(round_idx, episode_idx): + round_idx = -1 * round_idx + return {tod_core.STANDARD_API_NAME_SLOT: f"name_{round_idx}", "in": round_idx} + + +def make_api_resp_machine(round_idx): + if round_idx == 0: + return {} + return {"out": round_idx} + + +def make_api_schemas_machine(max_rounds): + return [ + { + tod_core.STANDARD_API_NAME_SLOT: f"name_{round_idx}", + tod_core.STANDARD_REQUIRED_KEY: ["in"], + tod_core.STANDARD_OPTIONAL_KEY: [], + } + for round_idx in range(1, max_rounds) + ] + + +def make_goal_calls_machine(max_rounds): + return [make_api_call_machine(x) for x in range(1, max_rounds)] + + +def get_rounds(episode_idx, max_rounds, use_broken_mock_api_calls=False): + return [ + tod_core.TodStructuredRound( + user_utt=f"user_utt_{episode_idx}_{round_idx}", + api_call_machine=make_api_call_machine( + round_idx, episode_idx, use_broken_mock_api_calls + ), + api_resp_machine=make_api_resp_machine(round_idx), + sys_utt=f"sys_utt_{episode_idx}_{round_idx}", + ) + for round_idx in range(max_rounds) + ] + + +def get_round_utts(episode_idx, max_rounds, filter_utts=None): + if max_rounds < 1: + return [] + utts = [ + [ + f"USER: user_utt_{episode_idx}_0", + "APICALL: ", + "APIRESP: ", + f"SYSTEM: sys_utt_{episode_idx}_0", + ] + ] + for i in range(1, max_rounds): + utts.append( + [ + f"USER: user_utt_{episode_idx}_{i}", + f"APICALL: api_name = name_{i} ; in = {i}", + f"APIRESP: out = {i}", + f"SYSTEM: sys_utt_{episode_idx}_{i}", + ] + ) + utts.append( + [ + "USER: [DONE]", + "APICALL: ", + "APIRESP: ", + "SYSTEM: ", + ] + ) + if filter_utts is not None: + utts = [ + [turn for i, turn in enumerate(round_data) if filter_utts[i]] + for round_data in utts + ] + return utts + + +TEST_NUM_EPISODES_OPT_KEY = "test_num_episodes" +TEST_NUM_ROUNDS_OPT_KEY = "test_num_rounds" + +# No api calls in this setup +EPISODE_SETUP__UTTERANCES_ONLY = { + TEST_NUM_ROUNDS_OPT_KEY: 1, + TEST_NUM_EPISODES_OPT_KEY: 1, +} + +# No one call, one goal, one api desscription in this setup +EPISODE_SETUP__SINGLE_API_CALL = { + TEST_NUM_ROUNDS_OPT_KEY: 2, + TEST_NUM_EPISODES_OPT_KEY: 1, +} +# Will start testing multiple api calls + schemas, multi-round logic +EPISODE_SETUP__MULTI_ROUND = {TEST_NUM_ROUNDS_OPT_KEY: 5, TEST_NUM_EPISODES_OPT_KEY: 1} + +# Test that episode logic is correct +EPISODE_SETUP__MULTI_EPISODE = { + TEST_NUM_ROUNDS_OPT_KEY: 5, + TEST_NUM_EPISODES_OPT_KEY: 8, +} + +# Test that episode + pesky-off-by-one batchinglogic is correct +EPISODE_SETUP__MULTI_EPISODE_BS = { + TEST_NUM_ROUNDS_OPT_KEY: 5, + TEST_NUM_EPISODES_OPT_KEY: 35, +} + + +class TestDataParser(tod_agents.TodStructuredDataParser): + """ + Assume that when we init, we init w/ num of episodes + rounds as opts. + """ + + def __init__(self, opt, shared=None): + opt["datafile"] = "DUMMY" + self.fold = "DUMMY" + # Following lines are only reelvant in training the standalone api teacher + if TEST_NUM_EPISODES_OPT_KEY not in opt: + opt[TEST_NUM_EPISODES_OPT_KEY] = 35 + if TEST_NUM_ROUNDS_OPT_KEY not in opt: + opt[TEST_NUM_ROUNDS_OPT_KEY] = 5 + super().__init__(opt, shared) + + def setup_episodes(self, _): + result = [] + for ep_idx in range(0, self.opt[TEST_NUM_EPISODES_OPT_KEY]): + result.append( + tod_core.TodStructuredEpisode( + goal_calls_machine=[ + make_api_call_machine(x) + for x in range(1, self.opt[TEST_NUM_ROUNDS_OPT_KEY]) + ], + api_schemas_machine=make_api_schemas_machine( + self.opt[TEST_NUM_ROUNDS_OPT_KEY] + ), + rounds=get_rounds( + ep_idx, + self.opt[TEST_NUM_ROUNDS_OPT_KEY], + self.opt.get("use_broken_mock_api_calls", False), + ), + ) + ) + # print(result, self.opt) + return result + + def get_id_task_prefix(self): + return "Test" + + +class SystemTeacher(TestDataParser, tod_agents.TodSystemTeacher): + pass + + +class UserSimulatorTeacher(TestDataParser, tod_agents.TodUserSimulatorTeacher): + pass + + +class StandaloneApiTeacher(TestDataParser, tod_agents.TodStandaloneApiTeacher): + pass + + +class GoalAgent(TestDataParser, tod_agents.TodGoalAgent): + pass + + +class ApiSchemaAgent(TestDataParser, tod_agents.TodApiSchemaAgent): + pass + + +class SingleGoalAgent(TestDataParser, tod_agents.TodSingleGoalAgent): + pass + + +class SingleApiSchemaAgent(TestDataParser, tod_agents.TodSingleApiSchemaAgent): + pass + + +# Tested in tod world code +class UserUttAgent(TestDataParser, tod_agents.TodUserUttAgent): + pass + + +# Tested in tod world code +class ApiCallAndSysUttAgent(TestDataParser, tod_agents.TodApiCallAndSysUttAgent): + pass diff --git a/pytest.ini b/pytest.ini index 9aec92c0a89..d4095288194 100644 --- a/pytest.ini +++ b/pytest.ini @@ -12,3 +12,4 @@ markers = unit internal nofbcode + tod diff --git a/tests/tod/__init__.py b/tests/tod/__init__.py new file mode 100644 index 00000000000..240697e3247 --- /dev/null +++ b/tests/tod/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tests/tod/test_tod_agents_and_teachers.py b/tests/tod/test_tod_agents_and_teachers.py new file mode 100644 index 00000000000..5383c72416d --- /dev/null +++ b/tests/tod/test_tod_agents_and_teachers.py @@ -0,0 +1,327 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests different (more complicated) slot metrics. +""" + +import unittest + +import copy +import parlai.core.tod.tod_core as tod_core +import parlai.core.tod.tod_test_utils.test_agents as test_agents + + +class TestTodAgentsAndTeachersBase(unittest.TestCase): + def setup_agent_or_teacher(self, class_type, round_opt, opt): + full_opts = {**round_opt, **opt} + full_opts["datatype"] = "DUMMY" + full_opts["datafile"] = "DUMMY" + full_opts["episodes_randomization_seed"] = -1 # no random here + return class_type(full_opts) + + def dump_single_utt_per_episode_agent_text(self, class_type, round_opt, opt): + agent = self.setup_agent_or_teacher(class_type, round_opt, opt) + result = [] + while not agent.epoch_done(): + result.append(agent.act()["text"]) + agent.reset() + return result + + def dump_teacher_text(self, class_type, round_opt, opt): + """ + Array where [episode_idx][turn_idx][text=0,label=1] + """ + teacher = self.setup_agent_or_teacher(class_type, round_opt, opt) + data = [] + here = [] + for x, new in teacher.setup_data("dummy"): + if new and len(here) > 0: + data.append(copy.deepcopy(here)) + here = [] + here.append([x["text"], x["label"]]) + if len(here) > 0: + data.append(here) + return data + + def _test_roundDataCorrect(self): + self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__UTTERANCES_ONLY) + self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__SINGLE_API_CALL) + self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__MULTI_ROUND) + self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__MULTI_EPISODE) + + +class TestSystemTeacher(TestTodAgentsAndTeachersBase): + def test_apiSchemas_with_yesApiSchemas(self): + values = self.dump_teacher_text( + test_agents.SystemTeacher, + test_agents.EPISODE_SETUP__SINGLE_API_CALL, + {"api_schemas": True}, + ) + self.assertEqual( + values[0][0][0], + "APIS: " + + tod_core.SerializationHelpers.list_of_maps_to_str( + test_agents.make_api_schemas_machine(2) + ), + ) + + def test_apiSchemas_with_noApiSchemas(self): + values = self.dump_teacher_text( + test_agents.SystemTeacher, + test_agents.EPISODE_SETUP__SINGLE_API_CALL, + {"api_schemas": False}, + ) + self.assertEqual(values[0][0][0], "APIS: ") + + def _test_roundDataCorrect_helper(self, config): + max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] + values = self.dump_teacher_text(test_agents.SystemTeacher, config, {}) + for episode_idx, episode in enumerate(values): + utts = test_agents.get_round_utts(episode_idx, max_rounds) + comp = [] + for utt in utts: + comp.append([utt[0], utt[1]]) + comp.append([utt[2], utt[3]]) + # Skip context turn cause we check it above + self.assertEqual(episode[1:], comp) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TestUserTeacher(TestTodAgentsAndTeachersBase): + def _test_roundDataCorrect_helper(self, config): + max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] + values = self.dump_teacher_text(test_agents.UserSimulatorTeacher, config, {}) + for episode_idx, episode in enumerate(values): + utts = test_agents.get_round_utts(episode_idx, max_rounds) + comp = [] + comp.append( + [ + "GOAL: " + + tod_core.SerializationHelpers.list_of_maps_to_str( + test_agents.make_goal_calls_machine(max_rounds) + ), + utts[0][0], + ] + ) + last_sys = utts[0][3] + for i in range(1, len(utts)): + comp.append([last_sys, utts[i][0]]) + last_sys = utts[i][3] + self.assertEqual(episode, comp) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TestGoalAgent(TestTodAgentsAndTeachersBase): + def _test_roundDataCorrect_helper(self, config): + max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] + max_episodes = config[test_agents.TEST_NUM_EPISODES_OPT_KEY] + values = self.dump_single_utt_per_episode_agent_text( + test_agents.GoalAgent, config, {} + ) + + goal_text = [ + "GOAL: " + + tod_core.SerializationHelpers.list_of_maps_to_str( + test_agents.make_goal_calls_machine(max_rounds) + ) + for _ in range(max_episodes) + ] + + self.assertEqual(values, goal_text) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TestApiSchemaAgent(TestTodAgentsAndTeachersBase): + def _test_roundDataCorrect_helper(self, config): + max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] + max_episodes = config[test_agents.TEST_NUM_EPISODES_OPT_KEY] + values = self.dump_single_utt_per_episode_agent_text( + test_agents.ApiSchemaAgent, config, {} + ) + + apis_texts = [ + "APIS: " + + tod_core.SerializationHelpers.list_of_maps_to_str( + test_agents.make_api_schemas_machine(max_rounds) + ) + for _ in range(max_episodes) + ] + + self.assertEqual(values, apis_texts) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TestSingleGoalAgent(TestTodAgentsAndTeachersBase): + def _test_roundDataCorrect_helper(self, config): + max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] + max_episodes = config[test_agents.TEST_NUM_EPISODES_OPT_KEY] + values = self.dump_single_utt_per_episode_agent_text( + test_agents.SingleGoalAgent, config, {} + ) + + goal_text = [] + for _ in range(max_episodes): + goals = test_agents.make_goal_calls_machine(max_rounds) + for x in goals: + goal_text.append( + "GOAL: " + tod_core.SerializationHelpers.list_of_maps_to_str([x]) + ) + + self.assertEqual(values, goal_text) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TestSingleApiSchemaAgent(TestTodAgentsAndTeachersBase): + def _test_roundDataCorrect_helper(self, config): + max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] + max_episodes = config[test_agents.TEST_NUM_EPISODES_OPT_KEY] + values = self.dump_single_utt_per_episode_agent_text( + test_agents.SingleApiSchemaAgent, config, {} + ) + + apis_text = [] + for _ in range(max_episodes): + apis = test_agents.make_api_schemas_machine(max_rounds) + for x in apis: + apis_text.append( + "APIS: " + tod_core.SerializationHelpers.list_of_maps_to_str([x]) + ) + self.assertEqual(values, apis_text) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TestSingleGoalWithSingleApiSchemaAgent(TestTodAgentsAndTeachersBase): + """ + Make sure the SingleGoal + SingleApiSchema agents correspond. + """ + + def _test_roundDataCorrect_helper(self, config): + goals = self.dump_single_utt_per_episode_agent_text( + test_agents.SingleGoalAgent, config, {} + ) + apis = self.dump_single_utt_per_episode_agent_text( + test_agents.SingleApiSchemaAgent, config, {} + ) + + for i in range(len(goals)): + goal = tod_core.SerializationHelpers.str_to_goals(goals[i][len("GOALS:") :]) + api = tod_core.SerializationHelpers.str_to_api_schemas( + apis[i][len("APIS:") :] + ) + self.assertEqual( + goal[0].get("api_name", None), api[0].get("api_name", None) + ) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TestLowShot(TestTodAgentsAndTeachersBase): + FEW_SHOT_SAMPLES = [0, 1, 5, 15] + PERCENTAGES = [0, 0.1, 0.3, 0.5] + + def setup_agent_or_teacher(self, class_type, round_opt, opt): + full_opts = {**round_opt, **opt} + full_opts["datatype"] = "DUMMY" + full_opts["datafile"] = "DUMMY" + return class_type(full_opts) + + def test_few_shot_lengths_correct(self): + def helper(n_shot): + values = self.dump_teacher_text( + test_agents.SystemTeacher, + test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, + { + "episodes_randomization_seed": 0, + "n_shot": n_shot, + }, + ) + self.assertEqual(len(values), n_shot) + + for i in self.FEW_SHOT_SAMPLES: + helper(i) + + def _test_subsets(self, data_dumps): + for i in range(len(data_dumps) - 1): + small = data_dumps[i] + larger = data_dumps[i + 1] + for i, episode in enumerate(small): + self.assertEqual(episode, larger[i]) + + def test_few_shot_subset(self): + def helper(n_shot, seed): + return self.dump_teacher_text( + test_agents.SystemTeacher, + test_agents.EPISODE_SETUP__MULTI_EPISODE, + { + "episodes_randomization_seed": seed, + "n_shot": n_shot, + }, + ) + + data_dumps_seed_zero = [helper(i, 0) for i in self.FEW_SHOT_SAMPLES] + self._test_subsets(data_dumps_seed_zero) + data_dumps_seed_three = [helper(i, 3) for i in self.FEW_SHOT_SAMPLES] + self._test_subsets(data_dumps_seed_three) + self.assertNotEqual(data_dumps_seed_zero[-1], data_dumps_seed_three[-1]) + + def test_percent_shot_lengths_correct(self): + def helper(percent_shot, correct): + values = self.dump_teacher_text( + test_agents.SystemTeacher, + test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, # 35 episodes + { + "episodes_randomization_seed": 0, + "percent_shot": percent_shot, + }, + ) + self.assertEqual(len(values), correct) + + helper(0, 0) + helper(0.1, 3) + helper(0.3, 10) + + def test_percent_shot_subset(self): + def helper(percent_shot, seed): + return self.dump_teacher_text( + test_agents.SystemTeacher, + test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, # 35 episodes + { + "episodes_randomization_seed": seed, + "percent_shot": percent_shot, + }, + ) + + data_dumps_seed_zero = [helper(i, 0) for i in self.PERCENTAGES] + self._test_subsets(data_dumps_seed_zero) + data_dumps_seed_three = [helper(i, 3) for i in self.PERCENTAGES] + self._test_subsets(data_dumps_seed_three) + + def test_correct_throw_when_both_shots_defined(self): + self.assertRaises( + RuntimeError, + self.dump_teacher_text, + test_agents.SystemTeacher, + test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, # 35 episodes + {"episodes_randomization_seed": 0, "percent_shot": 0.3, "n_shot": 3}, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/tod/test_tod_teacher_metrics.py b/tests/tod/test_tod_teacher_metrics.py new file mode 100644 index 00000000000..aec4aba40f6 --- /dev/null +++ b/tests/tod/test_tod_teacher_metrics.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from math import isnan + +from parlai.core.metrics import AverageMetric +from parlai.core.tod.teacher_metrics import SlotF1Metric, SlotMetrics + + +class TestSlotF1Metric(unittest.TestCase): + """ + Test SlotF1Metric. + """ + + def test_slot_f1_metric_inputs(self): + slots_p_r_and_f1 = [ + (None, None, float("nan")), + (None, AverageMetric(0.0), float("nan")), + (AverageMetric(0.0), AverageMetric(0.0), float("nan")), + (AverageMetric(1), AverageMetric(1), 1.0), + (AverageMetric(1), AverageMetric(0), 0.0), + (AverageMetric(0.25), AverageMetric(0.75), 0.375), + ] + for slot_p, slot_r, slot_f1 in slots_p_r_and_f1: + actual_slot_f1 = SlotF1Metric(slot_p=slot_p, slot_r=slot_r).value() + if isnan(slot_f1): + self.assertTrue(isnan(actual_slot_f1)) + else: + self.assertEqual(slot_f1, actual_slot_f1) + + def test_slot_f1_metric_addition(self): + a = SlotF1Metric(slot_p=1) + b = SlotF1Metric(slot_r=0) + c = SlotF1Metric(slot_p=AverageMetric(numer=2, denom=3), slot_r=1) + d = a + b + c + # Slot P should be 3/4 = 0.75; slot R should be 1/2 = 0.5 + self.assertEqual(0.6, d.value()) + + +empty_slots = {} +basic_slots = {"a": "a_val", "b": "b_val", "c": "c_val"} +partial_slots = {"a": "a_val", "other": "other_val"} + + +class TestSlotMetrics(unittest.TestCase): + def test_base_slot_metrics(self): + cases = [ + (empty_slots, empty_slots, {"jga": 1}), + ( + basic_slots, + basic_slots, + {"jga": 1, "slot_p": 1, "slot_r": 1, "slot_f1": 1}, + ), + ( + basic_slots, + partial_slots, + {"jga": 0, "slot_p": 0.5, "slot_r": float(1.0 / 3), "slot_f1": 0.4}, + ), + ] + for teacher, predicted, result in cases: + metric = SlotMetrics( + teacher_slots=teacher, + predicted_slots=predicted, + ) + for key in result: + self.assertEqual(result[key], metric.report()[key]) + + +if __name__ == "__main__": + unittest.main() From 638eb286ed5bc14a5e627e9650c4cb5d709fb4e4 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Mon, 15 Nov 2021 20:29:54 -0800 Subject: [PATCH 07/23] [TOD] World, world metrics, script, tests See documentation in `tod_world_script.py` for usage. --- parlai/core/tod/README.md | 6 + parlai/core/tod/tod_core.py | 7 +- .../tod_test_utils/standalone_api_file.pickle | Bin 0 -> 232 bytes parlai/core/tod/tod_test_utils/test_agents.py | 1 - parlai/core/tod/tod_world.py | 307 +++++++++++++ parlai/core/tod/world_metrics.py | 116 +++++ parlai/core/tod/world_metrics_handlers.py | 181 ++++++++ .../scripts/distributed_tod_world_script.py | 33 ++ parlai/scripts/tod_world_script.py | 408 +++++++++++++++++ projects/tod_simulator/README.md | 1 - .../world_metrics/extended_world_metrics.py | 411 ++++++++++++++++++ tests/tod/test_tod_world_and_script.py | 225 ++++++++++ tests/tod/test_tod_world_metrics.py | 243 +++++++++++ tests/tod/test_tod_world_metrics_in_script.py | 257 +++++++++++ tests/tod/test_tod_world_script_metrics.py | 154 +++++++ 15 files changed, 2346 insertions(+), 4 deletions(-) create mode 100644 parlai/core/tod/README.md create mode 100644 parlai/core/tod/tod_test_utils/standalone_api_file.pickle create mode 100644 parlai/core/tod/tod_world.py create mode 100644 parlai/core/tod/world_metrics.py create mode 100644 parlai/core/tod/world_metrics_handlers.py create mode 100644 parlai/scripts/distributed_tod_world_script.py create mode 100644 parlai/scripts/tod_world_script.py delete mode 100644 projects/tod_simulator/README.md create mode 100644 projects/tod_simulator/world_metrics/extended_world_metrics.py create mode 100644 tests/tod/test_tod_world_and_script.py create mode 100644 tests/tod/test_tod_world_metrics.py create mode 100644 tests/tod/test_tod_world_metrics_in_script.py create mode 100644 tests/tod/test_tod_world_script_metrics.py diff --git a/parlai/core/tod/README.md b/parlai/core/tod/README.md new file mode 100644 index 00000000000..f1ebc50e4d6 --- /dev/null +++ b/parlai/core/tod/README.md @@ -0,0 +1,6 @@ +# Core classes for Task-Oriented Dialog (TOD) + +For understanding usage of these classes, start with `tod_agents.py` (for understanding how to setup agents such that they work with new datasets) and `parlai/scripts/tod_world_script.py` (for understanding how to run simulations with the TOD conversations format. + +As a convention, files that have elements that are expected to be referenced outside of this directory (and outisde of `parlai/projects/tod_simulator`) are prefixed wth `tod_`. + diff --git a/parlai/core/tod/tod_core.py b/parlai/core/tod/tod_core.py index 76ee8005a74..e53be5c7376 100644 --- a/parlai/core/tod/tod_core.py +++ b/parlai/core/tod/tod_core.py @@ -160,7 +160,10 @@ def inner_list_join(cls, values): @classmethod def inner_list_split(cls, s): - return s.split(", ") + split = s.split(", ") + if len(split) == 1: + return split[0] + return set(split) @classmethod def maybe_inner_list_join(cls, values): @@ -193,7 +196,7 @@ def str_to_api_dict(cls, string): continue name, value = slot_str.split(" = ", 1) name = name.strip() - value = value.strip() + value = SerializationHelpers.inner_list_split(value.strip()) result[name] = value return result diff --git a/parlai/core/tod/tod_test_utils/standalone_api_file.pickle b/parlai/core/tod/tod_test_utils/standalone_api_file.pickle new file mode 100644 index 0000000000000000000000000000000000000000..b27e3b297415f56f1d054caf3b4045eec85505a0 GIT binary patch literal 232 zcmZo*t}SHHh>&7nU`Q;;jL%EVO;xZ}08#OV3f2mlc|e|FA!CF9P=RBBXOL@ffR#di zX$e@E39CLMm_DOIW^DS53R$q~GluCiE@Z`~&$y5ct3DH$K9fRrZ2C+JIZE{a-J&|v literal 0 HcmV?d00001 diff --git a/parlai/core/tod/tod_test_utils/test_agents.py b/parlai/core/tod/tod_test_utils/test_agents.py index b1339052764..b7639efea36 100644 --- a/parlai/core/tod/tod_test_utils/test_agents.py +++ b/parlai/core/tod/tod_test_utils/test_agents.py @@ -171,7 +171,6 @@ def setup_episodes(self, _): ), ) ) - # print(result, self.opt) return result def get_id_task_prefix(self): diff --git a/parlai/core/tod/tod_world.py b/parlai/core/tod/tod_world.py new file mode 100644 index 00000000000..d95a8706478 --- /dev/null +++ b/parlai/core/tod/tod_world.py @@ -0,0 +1,307 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Class for running task-oriented dialogue chats. See long comment on TodWorld for +functionality descriptions. + +Metrics calculated from these simulations are documented in `world_metrics.py` (for +general usage) and `world_metrics_handlers.py` (for specific metric calculations) +""" +from parlai.core.metrics import Metric, LegacyMetric +from parlai.core.message import Message +from parlai.core.opt import Opt +from parlai.core.worlds import World +from parlai.agents.local_human.local_human import LocalHumanAgent +from parlai.utils.misc import display_messages + +import parlai.core.tod.tod_core as tod +import parlai.core.tod.world_metrics as tod_metrics + +import sys +import copy + +# Following needs to be kept consistent with opt settings/tod script +USER_UTT_IDX = 0 +API_CALL_IDX = 1 +API_RESP_IDX = 2 +SYSTEM_UTT_IDX = 3 +API_SCHEMA_GROUNDING_IDX = 4 +GOAL_GROUNDING_IDX = 5 +AGENT_COUNT = 6 + +SPEAKER_TO_NAME = { + USER_UTT_IDX: tod.TodAgentType.USER_UTT_AGENT, + API_CALL_IDX: tod.TodAgentType.API_CALL_AGENT, + API_RESP_IDX: tod.TodAgentType.API_RESP_AGENT, + SYSTEM_UTT_IDX: tod.TodAgentType.SYSTEM_UTT_AGENT, + API_SCHEMA_GROUNDING_IDX: tod.TodAgentType.API_SCHEMA_GROUNDING_AGENT, + GOAL_GROUNDING_IDX: tod.TodAgentType.GOAL_GROUNDING_AGENT, +} + +NAME_TO_IDX = {v: k for k, v in SPEAKER_TO_NAME.items()} + + +class TodWorld(World): + """ + Base world for running TOD model-model chats. Following agents. + + * User utt agent + * API call agent + * Currently assumed to be same as system utt agent in script code, though used as if separate in this world. + * API responder agent + * System utt agent + * API schema groundinger agent (given to api call + response agent) + * Goal groundinger agent (given to user) + + As is standard for ParlAI, these agents may be models or may be standalone classes that extend the "Agent" class. The models for these *are* expected to have their utterances in a standard format. + + Note that we expect these to be passed in via the opt manually, since some assumptions of regular ParlAI Worlds (ex. task = agent[0], model = agent[1]) are broken here since there is no "task agent" and one agent can be two "roles" (ex. system agent also making API calls) + """ + + def __init__(self, opt: Opt, agents=None, shared=None): + super().__init__(opt, agents, shared) + self.batchsize = opt["batchsize"] + self.batch_agents = [] + self.batch_acts = [] + self.batch_goals = [] # for case when num_episodes < batchsize + self.batch_tod_world_metrics = [] + for i in range(self.batchsize): + here_agents = [] + for j, agent in enumerate(agents): + if ( + j == SYSTEM_UTT_IDX + ): # handle separately cause we expect it to be same as API_CALL agent + here_agents.append(here_agents[API_CALL_IDX]) + continue + share = agent.share() + batch_opt = copy.deepcopy(share["opt"]) + batch_opt["batchindex"] = i + here_agents.append(share["class"](batch_opt, share)) + self.batch_agents.append(here_agents) + self.batch_acts.append([Message.padding_example()] * 4) + self.batch_tod_world_metrics.append(tod_metrics.TodMetrics()) + self.end_episode = [False] * self.batchsize + + self.max_turns = self.opt.get("max_turns", 30) + self.turns = 0 + self.need_grounding = True + + def grounding(self): + """ + Preempt with goal and schema-based intent schemas. + + As a logging hack, we stick the schema gronding in as a user utterance, but + manually pass the value in to the relevant API call/resp agent, since passing it + to the API call agent elsewhere is a little awkward. Similarly, we stick the + goal as a system utterance so that it is captured in logging. However, we do not + pass it in manually, since getting the user utterance will be the first turn of + `parley()`. + """ + self._observe_and_act( + SYSTEM_UTT_IDX, # Doesn't matter, empty at this point + USER_UTT_IDX, # Hack in to a place that'll look nice when printing + f"getting API schema grounding. (Must start with `{tod.STANDARD_API_SCHEMAS}`)", + API_SCHEMA_GROUNDING_IDX, + ) + + self._observe_and_act( + USER_UTT_IDX, + API_CALL_IDX, + "responding to api schema grounding (empty enter is usually fine) ", + ) + self._observe_and_act( + USER_UTT_IDX, + API_RESP_IDX, + "responding to api schema grounding (empty enter is usually fine)", + ) + + self._observe_and_act( + SYSTEM_UTT_IDX, # Doesn't matter for the most part, but want something empty + SYSTEM_UTT_IDX, # Hack into a place per comment above + f"getting goal grounding. (Must start with `{tod.STANDARD_GOAL}`)", + GOAL_GROUNDING_IDX, + ) + self.batch_goals = [act[SYSTEM_UTT_IDX] for act in self.batch_acts] + self.turns = 0 + + def parley(self): + if self.need_grounding: + self.grounding() + self.need_grounding = False + + else: + self._observe_and_act(SYSTEM_UTT_IDX, USER_UTT_IDX) + self._observe_and_act(USER_UTT_IDX, API_CALL_IDX) + self._observe_and_act(API_CALL_IDX, API_RESP_IDX) + self._observe_and_act(API_RESP_IDX, SYSTEM_UTT_IDX) + + self.turns += 1 + self.update_counters() + + def _observe_and_act( + self, observe_idx, act_idx, info="for regular parley", override_act_idx=None + ): + act_agent_idx = override_act_idx if override_act_idx else act_idx + act_agent = self.agents[act_agent_idx] + record_output_idx = act_idx + if hasattr(act_agent, "batch_act"): + batch_observations = [] + for i in range(self.batchsize): + if not self.end_episode[i]: + observe = self.batch_acts[i][observe_idx] + observe = self.batch_agents[i][act_agent_idx].observe(observe) + batch_observations.append(Message(observe)) + else: + # We're done with this episode, so just do a pad. + # NOTE: This could cause issues with RL down the line + batch_observations.append(Message.padding_example()) + self.batch_acts[i][record_output_idx] = {"text": "", "id": ""} + batch_actions = act_agent.batch_act(batch_observations) + for i in range(self.batchsize): + if self.end_episode[i]: + continue + self.batch_acts[i][record_output_idx] = batch_actions[i] + self.batch_agents[i][record_output_idx].self_observe(batch_actions[i]) + else: # Run on agents individually + for i in range(self.batchsize): + act_agent = ( + self.batch_agents[i][override_act_idx] + if override_act_idx + else self.batch_agents[i][act_idx] + ) + if hasattr(act_agent, "episode_done") and act_agent.episode_done(): + self.end_episode[i] = True + if self.end_episode[i]: + # Following line exists because: + # 1. Code for writing converseations is not hapy if an "id" does not exists with a sample + # 2. Because of the `self.end_episode` code, no agent will see this example anyway. + self.batch_acts[i][record_output_idx] = {"text": "", "id": ""} + continue + act_agent.observe(self.batch_acts[i][observe_idx]) + if isinstance(act_agent, LocalHumanAgent): + print( + f"Getting message for {SPEAKER_TO_NAME[record_output_idx]} for {info} in batch {i}" + ) + try: + self.batch_acts[i][record_output_idx] = act_agent.act() + except StopIteration: + self.end_episode[i] = True + for i in range(self.batchsize): + if self.end_episode[i]: + continue + self.batch_tod_world_metrics[i].handle_message( + self.batch_acts[i][record_output_idx], SPEAKER_TO_NAME[act_agent_idx] + ) + if tod.STANDARD_DONE in self.batch_acts[i][record_output_idx].get( + "text", "" + ): + # User models trained to output a "DONE" on last turn; same with human agents. + self.end_episode[i] = True + + def report(self): + """ + Report all metrics of all subagents + of this world in aggregate. + """ + metrics_separate = [] + for i in range(self.batchsize): + here_metrics = self.batch_tod_world_metrics[i].report() + for name, agent in [ + (SPEAKER_TO_NAME[j], self.batch_agents[i][j]) + for j in [USER_UTT_IDX, API_CALL_IDX, API_RESP_IDX, SYSTEM_UTT_IDX] + ]: + name_prefix = name[:-6] # strip "_agent" + if hasattr(agent, "report"): + m = agent.report() + if m is None: + continue + for k, v in m.items(): + if not isinstance(v, Metric): + v = LegacyMetric(v) + here_metrics[f"{name_prefix}_{k}"] = v + metrics_separate.append(here_metrics) + metrics = metrics_separate[0] + for i in range(1, self.batchsize): + for k, v in metrics_separate[i].items(): + if k not in metrics: + metrics[k] = v + else: + metrics[k] = metrics[k] + v + return metrics + + def reset(self): + """ + Resets state of world; also sets up episode metrics. + """ + super().reset() + self.need_grounding = True + self.turns = 0 + + self.last_batch_episode_metrics = [] + self.batch_acts = [] + for i in range(self.batchsize): + for agent in self.batch_agents[i]: + agent.reset() + self.batch_acts.append([None] * 4) + + self.batch_tod_world_metrics[i].episode_reset() + metrics = self.batch_tod_world_metrics[i].get_last_episode_metrics() + if metrics: + self.last_batch_episode_metrics.append(metrics) + self.end_episode = [False] * self.batchsize + + def get_last_batch_episode_metrics(self): + return self.last_batch_episode_metrics + + def get_last_batch_goals(self): + return self.batch_goals + + def episode_done(self): + if self.turns >= self.max_turns or all(self.end_episode): + return True + for i in range(self.batchsize): + for j in [USER_UTT_IDX, API_CALL_IDX, API_RESP_IDX, SYSTEM_UTT_IDX]: + if ( + self.batch_acts[i][j] is not None + and tod.STANDARD_DONE in self.batch_acts[i][j].get("text", "") + ) or ( + hasattr(self.batch_agents[i][j], "episode_done") + and self.batch_agents[i][j].episode_done() + ): + self.end_episode[i] = True + return all(self.end_episode) + + def epoch_done(self): + for agent in self.agents: + if agent.epoch_done(): + return True + + def num_episodes(self): + result = sys.maxsize + for agent in self.agents: + if hasattr(agent, "num_episodes") and agent.num_episodes() > 0: + result = min(result, agent.num_episodes()) + if result == sys.maxsize: + return 0 + return result + + def get_batch_acts(self): + return self.batch_acts + + def display(self): + s = "[--batchsize " + str(self.batchsize) + "--]\n" + for i in range(self.batchsize): + s += "[batch " + str(i) + ":]\n" + s += display_messages( + self.batch_acts[i], + ignore_agent_reply=self.opt.get("ignore_agent_reply", False), + add_fields=self.opt.get("display_add_fields", ""), + prettify=self.opt.get("display_prettify", False), + max_len=self.opt.get("max_display_len", 1000), + verbose=self.opt.get("verbose", False), + ) + s += "\n" + s += "[--end of batch--]\n" + return s diff --git a/parlai/core/tod/world_metrics.py b/parlai/core/tod/world_metrics.py new file mode 100644 index 00000000000..4e8ba4555e4 --- /dev/null +++ b/parlai/core/tod/world_metrics.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Wrapper object holding metrics for TODWorld. + +This class is in its own file to prevent circular dependencies + monolithic files. +""" + +from parlai.core.message import Message +from parlai.core.metrics import ( + Metrics, +) +from parlai.core.tod.tod_core import ( + TodAgentType, + TOD_AGENT_TYPE_TO_PREFIX, + SerializationHelpers, + STANDARD_GOAL, +) +from typing import Any, Dict +import parlai.core.tod.world_metrics_handlers as world_metrics_handlers + +# Change the following to define which Metrics Handlers are used in TodWorld. +# The ones used below are from `world_metrics_handlers.py` only. However, See `parlai/projects/tod_simulator/world_metrics/extended_world_metrics.py` for others. + +WORLD_METRIC_HANDLERS = [ + world_metrics_handlers.AllGoalApiCallSuccessMetricsHandler, + world_metrics_handlers.UserGeneratedDoneMetricHandler, +] + + +class TodMetrics(Metrics): + """ + Helper container which encapsulates TOD metrics and does some basic prepocessing to + handlers to calculate said metrics. + + This class should generally not need to be changed; add new metrics handlers to + `WORLD_METRIC_HANDLERS` (or otherwise override `self.handlers` of this class) to + change metrics actively being used. + """ + + def __init__(self, shared: Dict[str, Any] = None) -> None: + super().__init__(shared=shared) + self.handlers = [x() for x in WORLD_METRIC_HANDLERS] + self.convo_started = False + self.last_episode_metrics = Metrics() + + def handle_message(self, message: Message, agent_type: TodAgentType): + if "text" not in message: + return + if agent_type == TodAgentType.GOAL_GROUNDING_AGENT and len( + message["text"] + ) > len(STANDARD_GOAL): + # Only count a conversation as started if there is a goal. + self.convo_started = True + for handler in self.handlers: + metrics = self._handle_message_impl(message, agent_type, handler) + if metrics is not None: + for name, metric in metrics.items(): + if metric is not None: + self.add(name, metric) + + def _handle_message_impl( + self, + message: Message, + agent_type: TodAgentType, + handler: world_metrics_handlers.TodMetricsHandler, + ): + prefix_stripped_text = message["text"].replace( + TOD_AGENT_TYPE_TO_PREFIX[agent_type], "" + ) + if agent_type is TodAgentType.API_SCHEMA_GROUNDING_AGENT: + return handler.handle_api_schemas( + message, + SerializationHelpers.str_to_api_schemas(prefix_stripped_text), + ) + if agent_type is TodAgentType.GOAL_GROUNDING_AGENT: + return handler.handle_goals( + message, SerializationHelpers.str_to_goals(prefix_stripped_text) + ) + if agent_type is TodAgentType.USER_UTT_AGENT: + return handler.handle_user_utt(message, prefix_stripped_text) + if agent_type is TodAgentType.API_CALL_AGENT: + return handler.handle_api_call( + message, SerializationHelpers.str_to_api_dict(prefix_stripped_text) + ) + if agent_type is TodAgentType.API_RESP_AGENT: + return handler.handle_api_resp( + message, SerializationHelpers.str_to_api_dict(prefix_stripped_text) + ) + if agent_type is TodAgentType.SYSTEM_UTT_AGENT: + return handler.handle_sys_utt(message, prefix_stripped_text) + + def get_last_episode_metrics(self): + """ + This is a bit of a hack so that we can report whether or not a convo has + successfully hit all goals and associate this with each episode for the purposes + of doing filtering. + """ + return self.last_episode_metrics + + def episode_reset(self): + self.last_episode_metrics = None + if self.convo_started: + self.last_episode_metrics = Metrics() + for handler in self.handlers: + metrics = handler.get_episode_metrics() + handler.episode_reset() + if metrics is not None: + for name, metric in metrics.items(): + if metric is not None: + self.add(name, metric) + self.last_episode_metrics.add(name, metric) + self.convo_started = False diff --git a/parlai/core/tod/world_metrics_handlers.py b/parlai/core/tod/world_metrics_handlers.py new file mode 100644 index 00000000000..3f4b68477fe --- /dev/null +++ b/parlai/core/tod/world_metrics_handlers.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Metrics handlers - ie, objects that handle generations from Tod World and calculates metrics from them. + +Note that only metrics handler classes in `WORLD_METRIC_HANDLERS` (of `world_metrics.py`) are actively being recorded as metrics. +""" + +from parlai.core.message import Message +from parlai.core.metrics import ( + Metric, + AverageMetric, +) +from parlai.core.tod.tod_core import ( + STANDARD_DONE, +) +from typing import Dict, List, Optional + +METRICS_HANDLER_CLASSES_TEST_REGISTRY = set() # for tests + + +def register_metrics_handler(cls): + METRICS_HANDLER_CLASSES_TEST_REGISTRY.add(cls) + return cls + + +class TodMetricsHandler: + """ + Base class for Tod Metrics handlers. Extend this class then add them to + `WORLD_METRIC_HANDLERS` to use. If you would like the class to be exposed to tests, + add the Metrics Handler to `METRICS_HANDLER_CLASSES_TEST_REGISTRY` via annotating + with `@register_metrics_handler`. + + The `TodMetrics` class will, on this class + 1. call `__init__` (which internally calls `episode_reset()`) to begin with. + 2. call each of the `handle..()` functions as the appropriate turns occur + 3. call `get_episode_metrics()` then `episode_reset()` at the end of the episode + + The `handle..()` should be used to set intermediate state within the class and `episode_reset()` should be used to clear this state. + + The output of the `handle..()` and `get_episode_metrics()` functions are both `Optional[Dict[str, Metric]]`s. Metrics from both of these paths will be aggregated and reported to `TodMetrics`, so which one to use is mostly a matter of preference, though + 1. one should take care to only use one or the other and not both, to avoid double-counting + 2. those from `get_episode_metrics()` will be recorded per-episode and saved to `tod_world_script`'s report as well + + `UserGeneratedDoneMetricHandler` in this file, which collects metrics about frequency of seeing the "[DONE]" token on User utterances and also records conversation length, is a fairly straightforward example of usage. + + Other tried (but not in current active use) Metrics Handers are in `projects/tod_simulator/world_metrics/extended_world_metrics.py`. + """ + + def __init__(self): + self.episode_reset() + + def episode_reset(self): + pass + + def handle_api_schemas( + self, message: Message, api_schemas: List[Dict] + ) -> Optional[Dict[str, Metric]]: + self.api_schemas = api_schemas + + def handle_goals( + self, message: Message, goals: List[Dict] + ) -> Optional[Dict[str, Metric]]: + self.goals = goals + + def handle_user_utt( + self, message: Message, prefix_stripped_text: str + ) -> Optional[Dict[str, Metric]]: + pass + + def handle_api_call( + self, message: Message, api_call: Dict + ) -> Optional[Dict[str, Metric]]: + pass + + def handle_api_resp( + self, message: Message, api_resp: Dict + ) -> Optional[Dict[str, Metric]]: + pass + + def handle_sys_utt( + self, message: Message, prefix_stripped_text: str + ) -> Optional[Dict[str, Metric]]: + pass + + def get_episode_metrics(self) -> Optional[Dict[str, Metric]]: + pass + + +################################ +# Functions and classes associated with calculating statistics between API Calls and Goals. +def goals_hit_helper( + goals: List[Dict], turnDict: List[Dict], permissive=False +) -> (AverageMetric, AverageMetric, AverageMetric): + """ + Helper function that aids in seeing if the API calls the system has attempted to + make manages to meet the goals the conversation has. + + Return values: + * if all goals hit + * # of turns it took to hit all goals (or None) + * fraction of goals hit + """ + goals_left = goals + + def exact_match(goal, turn): # if and only if + return goal == turn + + def permissive_match(goal, turn): # guess is superset + for key in goal: + if turn.get(key, "definitelyNotIn") != goal[key]: + return False + return True + + compare_func = permissive_match if permissive else exact_match + + for i, turn in enumerate(turnDict): + goals_left = [goal for goal in goals_left if not compare_func(goal, turn)] + if len(goals_left) == 0: + return AverageMetric(True), AverageMetric(i + 1), AverageMetric(1) + return ( + AverageMetric(False), + AverageMetric(0), + AverageMetric(len(goals) - len(goals_left), len(goals)), + ) + + +class _ApiCallGoalInteractionHelper(TodMetricsHandler): + """ + Base class for storing details about valid API calls (information about Goals + handled in TodMetricsHandler) + """ + + def episode_reset(self): + self.api_turns = [] + + def handle_api_call( + self, message: Message, api_call: Dict + ) -> Optional[Dict[str, Metric]]: + if len(api_call) > 0: + self.api_turns.append(api_call) + + +@register_metrics_handler +class AllGoalApiCallSuccessMetricsHandler(_ApiCallGoalInteractionHelper): + """ + Calculates synthetic Task Success + related metrics for converseations. + + Test coverage of this class is with `LegacyGoalApiCallInteractionsMetricsHandler` + """ + + def get_episode_metrics(self) -> Optional[Dict[str, Metric]]: + all_goals_hit, _, _ = goals_hit_helper(self.goals, self.api_turns) + call_attempts = len(self.api_turns) + return { + "synthetic_task_success": all_goals_hit, + "api_call_attempts": AverageMetric(call_attempts), + } + + +@register_metrics_handler +class UserGeneratedDoneMetricHandler(TodMetricsHandler): + def episode_reset(self): + self.done_seen = False + self.turn_count = 0 + + def handle_user_utt( + self, message: Message, prefix_stripped_text: str + ) -> Optional[Dict[str, Metric]]: + self.done_seen |= STANDARD_DONE in message["text"] + self.turn_count += 1 + + def get_episode_metrics(self) -> Optional[Dict[str, Metric]]: + result = {"done_seen": AverageMetric(self.done_seen)} + if self.done_seen: + result["round_count_done_seen"] = AverageMetric(self.turn_count) + result["rounds_count_all_conversations"] = AverageMetric(self.turn_count) + return result diff --git a/parlai/scripts/distributed_tod_world_script.py b/parlai/scripts/distributed_tod_world_script.py new file mode 100644 index 00000000000..87c8333ec41 --- /dev/null +++ b/parlai/scripts/distributed_tod_world_script.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Distributed script for running TOD model-model chats. + +Not to be called directly; should be called from SLURM +""" + +from parlai.scripts.tod_world_script import ( + TodWorldScript, +) +from parlai.core.script import ParlaiScript +import parlai.utils.distributed as distributed_utils + + +class DistributedTodWorldScript(ParlaiScript): + @classmethod + def setup_args(cls): + parser = TodWorldScript.setup_args() + parser.add_distributed_training_args() + parser.add_argument("--port", type=int, default=61337, help="TCP port number") + return parser + + def run(self): + with distributed_utils.slurm_distributed_context(self.opt) as opt: + return TodWorldScript(opt).run() + + +if __name__ == "__main__": + DistributedTodWorldScript.main() diff --git a/parlai/scripts/tod_world_script.py b/parlai/scripts/tod_world_script.py new file mode 100644 index 00000000000..af4f963b71f --- /dev/null +++ b/parlai/scripts/tod_world_script.py @@ -0,0 +1,408 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Base script for running TOD model-model chats. +""" + +import json +from copy import deepcopy +from shutil import copyfile +import os + +import parlai.utils.logging as logging +import parlai.core.tod.tod_world as tod_world +import parlai.core.tod.tod_agents as tod_world_agents +from parlai.core.agents import create_agent +from parlai.core.metrics import dict_report, aggregate_unnamed_reports +from parlai.core.opt import Opt +from parlai.core.params import ParlaiParser +from parlai.core.script import ParlaiScript, register_script +from parlai.utils.distributed import ( + is_primary_worker, + all_gather_list, + is_distributed, + get_rank, + sync_object, + num_workers, +) +from parlai.utils.io import PathManager +from parlai.utils.misc import TimeLogger, nice_report +from parlai.utils.world_logging import WorldLogger + + +class TodWorldLogger(WorldLogger): + """ + WorldLogger has most of what we need. + + We could if-class this logic in it directly, but inheritence + override here is + neater. + """ + + def _is_batch_world(self, world): + return True + + def _log_batch(self, world): + batch_acts = world.get_batch_acts() + for i, acts in enumerate(batch_acts): + # filter out for empty + acts = [act for act in acts if act["id"] != "" and act["text"] != ""] + if len(acts) > 0: + self._add_msgs(acts, idx=i) + if world.episode_done(): + self.reset_world(idx=i) + + +class TodWorldParser(ParlaiParser): + def add_extra_args(self, args=None): + super().add_extra_args(args) + parsed = vars(self.parse_known_args(args, nohelp=True)[0]) + # Also load extra args options if a file is given. + if parsed.get("init_opt") is not None: + try: + self._load_known_opts(parsed.get("init_opt"), parsed) + except FileNotFoundError: + # don't die if -o isn't found here. See comment in second call + # later on. + pass + parsed = self._infer_datapath(parsed) + + partial = Opt(parsed) + + for model in [ + "system_model", + "user_model", + "api_schema_grounding_model", + "goal_grounding_model", + "api_resp_model", + ]: + if ( + model in partial + and partial[model] is not None + and len(partial[model]) > 0 + ): + self.add_model_subargs(partial[model], partial) + + for model_file_prefix in ["system", "user"]: + key = model_file_prefix + "_model_file" + if key in partial and partial[key] and len(partial[key]) > 0: + model_name = self._get_model_name_from_model_file(key, partial) + self.add_model_subargs(model_name, partial) + + def _get_model_name_from_model_file(self, key, opt): + """ + Get the model name from either `--model` or `--model-file`. + """ + # try to get model name from model opt file + model_file = opt.get(key, None) + optfile = model_file + ".opt" + new_opt = Opt.load(optfile) + model = new_opt.get("model", None) + return model + + +@register_script("tod_world_script") +class TodWorldScript(ParlaiScript): + @classmethod + def setup_tod_args(cls, parser: ParlaiParser): + tod_args = parser.add_argument_group( + "TOD World Script Agent arguments. NOTE: Agents setup with this path will be able to take command line arguments, whereas those set from `-o` or `--init-opt` will not." + ) + tod_args.add_argument( + "--system-model-file", + default="", + help="Define the system model for the chat. Exactly one of this or system-model must be specified", + ) + + tod_args.add_argument( + "--system-model", + default="", + help="Define the system agent for the chat. Exactly one of this or system-model-file must be specified", + ) + + tod_args.add_argument( + "--user-model-file", + default="", + help="Define the user model for the chat. Exactly one of this user-model must be specified. Currently assumed to be the API Call creation agent as well.", + ) + + tod_args.add_argument( + "--user-model", + default="", + help="Define the user agent for the chat. Exactly one of this or user-model-file must be specified. Currently assumed to be the API Call creation agent as well.", + ) + + tod_args.add_argument( + "--api-resp-model", + default="", + help="Agent used for defining API response values", + ) + + tod_args.add_argument( + "--api-schema-grounding-model", + default="", + help="Agent used in first turn to grounding api call/response agents with api schemas. Will use EmptyApiSchemaAgent if both this and `--api-schemas` not set.", + ) + + tod_args.add_argument( + "--goal-grounding-model", + default="", + help="Agent used in first turn to grounding user agent with goal. Will use EmptyGoalAgent if not set", + ) + + tod_args.add_argument( + "--api-schemas", + default=None, + help="If set and `--api-schema-grounding-model` is empty, will infer `--api-schema-grounding-model` based on this and a regex on `--goal-grounding-model`. If you run into issues with parsing order of opts using this flag, just switch to `--api-schema-grounding-model`.", + ) + + @classmethod + def setup_args(cls): + # Use default parlai args for logging + the like, but don't need model args since we specify those manually via command-line + parser = TodWorldParser( + True, + False, + "World for chatting with the TOD conversation structure", + ) + # Following params are same as the `eval_model` script + parser.add_argument( + "--report-filename", + type=str, + help="Saves a json file of the evaluation report either as an " + 'extension to the model-file (if begins with a ".") or a whole ' + "file path. Set to the empty string to not save at all.", + ) + parser.add_argument( + "--world-logs", + type=str, + help="Saves a jsonl file containing all of the task examples and " + "model replies.", + ) + parser.add_argument( + "--save-format", + type=str, + default="conversations", + choices=["conversations", "parlai"], + ) + parser.add_argument( + "--num-episodes", + type=int, + default=10, + help="Number of episodes to display. Set to -1 for infinity or the number of examples of the first agent with a non-unlimited number of episodes in the world.", + ) + parser.add_argument("-d", "--display-examples", type="bool", default=False) + parser.add_argument("-ltim", "--log-every-n-secs", type=float, default=10) + TodWorldLogger.add_cmdline_args(parser) + + # Following are specific to TOD World + parser.add_argument( + "--max-turns", + type=int, + default=30, + help="The max number of full turns before chat ends, excluding prompting", + ) + TodWorldScript.setup_tod_args(parser) + + return parser + + def _get_file_or_model_specifiable_agent(self, prefix, opt): + if len(opt.get(f"{prefix}_model_file", "")) > 0: + if len(opt.get(f"{prefix}_model", "")) > 0: + raise KeyError( + "Both `--{prefix}-model-file` and `--{prefix}-model` specified. Exactly one should be." + ) + model = self._make_agent( + opt, + f"{prefix}_model_file", + requireModelExists=True, + opt_key="model_file", + ) + elif len(opt.get(f"{prefix}_model", "")) > 0: + model = self._make_agent(opt, f"{prefix}_model", "") + else: + raise KeyError( + f"Both `--{prefix}-model-file` and `--{prefix}-model` specified. Neither currently set." + ) + return model + + def _get_model_or_default_agent(self, opt, key, default_class): + if len(opt.get(key, "")) > 0: + return self._make_agent(opt, key) + return default_class(opt) + + def _get_tod_agents(self, opt: Opt): + agents = [None] * tod_world.AGENT_COUNT + + agents[tod_world.USER_UTT_IDX] = self._get_file_or_model_specifiable_agent( + "user", opt + ) + + # Get system agent, nothing that api call agent currently same as system agent + system_model = self._get_file_or_model_specifiable_agent("system", opt) + agents[tod_world.SYSTEM_UTT_IDX] = system_model + agents[tod_world.API_CALL_IDX] = system_model + + agents[tod_world.API_RESP_IDX] = self._make_agent(opt, "api_resp_model") + agents[tod_world.GOAL_GROUNDING_IDX] = self._get_model_or_default_agent( + opt, "goal_grounding_model", tod_world_agents.EmptyGoalAgent + ) + + if "api_schema_grounding_model" not in opt and "api_schemas" in opt: + opt["api_schema_grounding_model"] = opt.get( + "goal_grounding_model", "" + ).replace("Goal", "ApiSchema") + + agents[tod_world.API_SCHEMA_GROUNDING_IDX] = self._get_model_or_default_agent( + opt, + "api_schema_grounding_model", + tod_world_agents.EmptyApiSchemaAgent, + ) + + return agents + + def _make_agent(self, opt_raw, name, requireModelExists=False, opt_key="model"): + """ + Hack. + + `create_agent` expects opt[`model`] to specify the model type and we're + specifying multiple models from other opt arguments (ex. + `system_model`/`user_model` etc), so this swaps it in. + """ + opt = deepcopy(opt_raw) + opt[opt_key] = opt[name] + print(opt_key, name) + return create_agent(opt, requireModelExists) + + def _run_episode(self, opt, world, world_logger): + while not world.episode_done(): + world.parley() + world_logger.log(world) + + if opt["display_examples"]: + logging.info(world.display()) + + if opt["display_examples"]: + logging.info("-- end of episode --") + + world.reset() + world_logger.reset_world() # flush this episode + return zip(world.get_last_batch_goals(), world.get_last_batch_episode_metrics()) + + def _save_outputs(self, opt, world, logger, episode_metrics): + if is_distributed(): # flatten everything intelligently if need be + world_report = aggregate_unnamed_reports(all_gather_list(world.report())) + episode_metrics_unflattened = all_gather_list(episode_metrics) + flattened = [] + for rank_elem in episode_metrics_unflattened: + for elem in rank_elem: + flattened.append(elem) + episode_metrics = flattened + else: + world_report = world.report() + logging.report("Final report:\n" + nice_report(world_report)) + + report = dict_report(world_report) + + def get_episode_report(goal, episode_metric): + metrics_dict = dict_report(episode_metric.report()) + metrics_dict["goal"] = goal + return metrics_dict + + report["tod_metrics"] = [get_episode_report(g, e) for g, e in episode_metrics] + + if "report_filename" in opt and opt["report_filename"] is not None: + if len(world_report) == 0: + logging.warning("Report is empty; not saving report") + + report_fname = f"{opt['report_filename']}.json" + # Save report + if not is_distributed() or is_primary_worker(): + with PathManager.open(report_fname, "w") as f: + logging.info(f"Saving model report to {report_fname}") + json.dump({"opt": opt, "report": report}, f, indent=4) + f.write("\n") # for jq + + if "world_logs" in opt and opt["world_logs"] is not None: + if is_distributed(): # Save separately, then aggregate together + rank = get_rank() + log_outfile_part = ( + f"{opt['world_logs']}_{opt['save_format']}_{rank}.jsonl" + ) + logger.write(log_outfile_part, world, file_format=opt["save_format"]) + sync_object(None) + if is_primary_worker(): + log_outfile = f"{opt['world_logs']}_{opt['save_format']}.jsonl" + log_outfile_metadata = ( + f"{opt['world_logs']}_{opt['save_format']}.metadata" + ) + with open(log_outfile, "w+") as outfile: + for rank in range(num_workers()): + log_outfile_part = ( + f"{opt['world_logs']}_{opt['save_format']}_{rank}.jsonl" + ) + with open(log_outfile_part) as infile: + for line in infile: + json_blob = json.loads(line.strip()) + if ( + len(json_blob["dialog"]) < 2 + ): # skip when we don't have generation + continue + json_blob["metadata_path"] = log_outfile_metadata + outfile.write(json.dumps(json_blob)) + outfile.write("\n") + log_output_part_metadata = f"{opt['world_logs']}_{opt['save_format']}_{rank}.metadata" + if rank == 0: + copyfile( + log_output_part_metadata, log_outfile_metadata + ), + os.remove(log_outfile_part) + os.remove(log_output_part_metadata) + else: + log_outfile = f"{opt['world_logs']}_{opt['save_format']}.jsonl" + logger.write(log_outfile, world, file_format=opt["save_format"]) + + return report + + def _setup_world(self): + # setup world, manually finaggling necessary opt info as needed + self.opt["task"] = "TodWorld" + world = tod_world.TodWorld(self.opt, agents=self._get_tod_agents(self.opt)) + return world + + def run(self): + opt = self.opt + + world = self._setup_world() + logger = TodWorldLogger(opt) + + # set up logging + log_every_n_secs = opt.get("log_every_n_secs", -1) + if log_every_n_secs <= 0: + log_every_n_secs = float("inf") + log_time = TimeLogger() + + # episode counter + max_episodes = opt.get("num_episodes", -1) + if max_episodes < 0: + max_episodes = float("inf") + world_num_episodes = world.num_episodes() + if world_num_episodes > 0: + max_episodes = min(max_episodes, world_num_episodes) + + ep_count = 0 + episode_metrics = [] + while not world.epoch_done() and ep_count < max_episodes: + episode_metrics.extend(self._run_episode(opt, world, logger)) + ep_count += opt.get("batchsize", 1) + if log_time.time() > log_every_n_secs: + report = world.report() + text, report = log_time.log(ep_count, max_episodes, report) + logging.info(text) + + return self._save_outputs(opt, world, logger, episode_metrics) + + +if __name__ == "__main__": + TodWorldScript.main() diff --git a/projects/tod_simulator/README.md b/projects/tod_simulator/README.md deleted file mode 100644 index 65fbb41753a..00000000000 --- a/projects/tod_simulator/README.md +++ /dev/null @@ -1 +0,0 @@ -Page to be filled. :) diff --git a/projects/tod_simulator/world_metrics/extended_world_metrics.py b/projects/tod_simulator/world_metrics/extended_world_metrics.py new file mode 100644 index 00000000000..7aa4eca89c5 --- /dev/null +++ b/projects/tod_simulator/world_metrics/extended_world_metrics.py @@ -0,0 +1,411 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Metrics handlers - ie, classes that handle generations from Tod World and calculates metrics from them. + +Note that only metrics handler classes in `WORLD_METRIC_HANDLERS` are actively being recorded as metrics. +""" + +from parlai.core.message import Message +from parlai.core.metrics import ( + Metric, + AverageMetric, + normalize_answer, + BleuMetric, + SumMetric, +) +from parlai.core.tod.tod_core import ( + STANDARD_API_NAME_SLOT, + STANDARD_REQUIRED_KEY, + STANDARD_OPTIONAL_KEY, + STANDARD_API_SCHEMAS, +) +from typing import Dict, List, Optional, Tuple +from parlai.core.tod.world_metrics_handlers import ( + TodMetricsHandler, + register_metrics_handler, + _ApiCallGoalInteractionHelper, + goals_hit_helper, +) + +try: + from nltk.translate import bleu_score as nltkbleu +except ImportError: + # User doesn't have nltk installed, so we can't use it for bleu + # We'll just turn off things, but we might want to warn the user + nltkbleu = None + +################################ +# Functions and classes associated with calculating statistics between API Calls and Goals. + + +def get_req_only_goals(goals_list: List[Dict], api_schemas: List[Dict]) -> List[Dict]: + """ + Given a list of goals and a list of api schemas that say if slots are required or + optional, this function filters for the goals to be only the required ones. + + If we have no api schemas or a goal is malformed, we return the empty list. If a + goal is malformed, we print a warning, since this whole req-only goals thing is + experimental at best anyhow. + """ + if len(api_schemas) == 0: + return [] + result = [] + for goal in goals_list: + req_goals = {} + method = goal.get(STANDARD_API_NAME_SLOT, None) + if method is None: + return [] + required = [] + for schema in api_schemas: + if schema.get(STANDARD_API_NAME_SLOT, "") == method: + required = schema.get(STANDARD_REQUIRED_KEY, {}) + print("-".join(required)) + for key in required: + if key not in goal: + print(f"No required key `{key}` in goal `{goal}`") + return [] + req_goals[key] = goal[key] + if len(req_goals) > 0: + req_goals[STANDARD_API_NAME_SLOT] = method # for consistency with all. + result.append(req_goals) + return result + + +def goals_slots_helper( + goals: List[Dict], turnDict: List[Dict] +) -> Tuple[Tuple[int, int], Tuple[int, int]]: + """ + Helper function to see how well the slot keys + slot values match between attempted + API calls and goals. + + Output is precision, recall. + """ + all_call_slots = {k: v for call in turnDict for k, v in call.items()} + all_goal_slots = {k: v for goal in goals for k, v in goal.items()} + goal_in_call = { + k: v + for k, v in all_call_slots.items() + if all_goal_slots.get(k, "definitelyNotInValuexyz") == v + } + call_in_goal = { + k: v + for k, v in all_goal_slots.items() + if all_call_slots.get(k, "definitelyNotInValuexyz") == v + } + + print(goal_in_call, all_call_slots) + + return ( + AverageMetric(len(goal_in_call), len(all_call_slots)), + AverageMetric(len(call_in_goal), len(all_goal_slots)), + ) + + +@register_metrics_handler +class LegacyGoalApiCallInteractionsMetricsHandler(_ApiCallGoalInteractionHelper): + """ + This class was reporting a few too many metrics, but is useful for test purposes, so + we're keeping it around. + + `AllGoalApiCallSuccessMetricsHandler` is the streamlined, less spammy version of + this class. + """ + + def handle_goals( + self, message: Message, goals: List[Dict] + ) -> Optional[Dict[str, Metric]]: + self.goals = goals + self.required_goals = get_req_only_goals(goals, self.api_schemas) + + def get_episode_metrics(self) -> Optional[Dict[str, Metric]]: + all_goals_hit, all_goals_hit_turn_count, all_part_hit = goals_hit_helper( + self.goals, self.api_turns + ) + all_precision, all_recall = goals_slots_helper(self.goals, self.api_turns) + req_goals_hit, req_goals_hit_turn_count, req_part_hit = goals_hit_helper( + self.required_goals, self.api_turns, permissive=True + ) + req_precision, req_recall = goals_slots_helper( + self.required_goals, self.api_turns + ) + call_attempts = len(self.api_turns) + return { + "all_goals_hit": all_goals_hit, + "all_goals_hit_turn_count": all_goals_hit_turn_count, + "all_goals_fractional_hit": all_part_hit, + "all_goals_slot_precision": all_precision, + "all_goals_slot_recall": all_recall, + "req_goals_hit": req_goals_hit, + "req_goals_hit_turn_count": req_goals_hit_turn_count, + "req_goals_fractional_hit": req_part_hit, + "req_goals_slot_precision": req_precision, + "req_goals_slot_recall": req_recall, + "call_attempts": AverageMetric(call_attempts), + } + + +@register_metrics_handler +class UserGoalSlotCoverageMetricHandler(TodMetricsHandler): + """ + How well does our user simulator do at outputting utterances that goes closer to + satisfying relevant groundinged goals? Does it dump out all of the slots at once or + is it more intelligent than that? + + Since this is the user and we don't know the identity of potential slots, we ignore + the short (< 4 chars) goal slots since this tends to be things that are substrings + of other things. (Ex. "2" showing up as # of people in a reservation, but also + showing up as a phone number.) + """ + + def episode_reset(self): + self.mentioned_all_slot_values = set() + self.mentioned_req_slot_values = set() + self.all_goal_slot_values = set() + self.all_req_goal_slot_values = set() + + def handle_goals( + self, message: Message, goals: List[Dict] + ) -> Optional[Dict[str, Metric]]: + """ + Parse out all the slots as a blob, filtering out for short things. + """ + required_goals = get_req_only_goals(goals, self.api_schemas) + + def get_slot_values(goal_list): + result = set() + for goal in goal_list: + for key, value in goal.items(): + if key is not STANDARD_API_NAME_SLOT and len(value) > 3: + result.add(value) + return result + + self.all_goal_slot_values = get_slot_values(goals) + self.all_req_goal_slot_values = get_slot_values(required_goals) + + def handle_user_utt( + self, message: Message, prefix_stripped_text: str + ) -> Optional[Dict[str, Metric]]: + """ + Grab slots out of the user utterance based on an exact match. + """ + utterance = prefix_stripped_text + + def get_slots(utt, options): + results = set() + for option in options: + if option in utt: + results.add(option) + return results + + all_slot_values_here = get_slots(utterance, self.all_goal_slot_values) + req_slot_values_here = get_slots(utterance, self.all_req_goal_slot_values) + + self.mentioned_all_slot_values |= all_slot_values_here + self.mentioned_req_slot_values |= req_slot_values_here + + metrics = {} + metrics["user_utt_avg_any_slot"] = AverageMetric(len(all_slot_values_here)) + metrics["user_utt_avg_req_slot"] = AverageMetric(len(req_slot_values_here)) + return metrics + + def get_episode_metrics(self) -> Optional[Dict[str, Metric]]: + result = { + "user_any_goal_slots_recall": AverageMetric( + len(self.mentioned_all_slot_values), len(self.all_goal_slot_values) + ), + "user_req_goal_slots_recall": AverageMetric( + len(self.mentioned_req_slot_values), len(self.all_req_goal_slot_values) + ), + } + + self.mentioned_all_slot_values = set() + self.mentioned_req_slot_values = set() + return result + + +class _ExactRepeatMetricsHandler(TodMetricsHandler): + """ + Helper class for defining % of episodes where a given agent type has exactly + repeated the same utterance. + """ + + def episode_reset(self): + self.turns = [] + self.repeated = False + + def metric_key(self): + raise NotImplementedError("must implement") + + def handle_message_helper( + self, prefix_stripped_text: str + ) -> Optional[Dict[str, Metric]]: + normalized = normalize_answer(prefix_stripped_text) + if normalized in self.turns: + self.repeated = True + self.turns.append(normalized) + + def get_episode_metrics(self) -> Optional[Dict[str, Metric]]: + repeat = int(self.repeated) + self.repeated = False + self.turns = [] + return {self.metric_key(): AverageMetric(repeat)} + + +@register_metrics_handler +class UserUttRepeatMetricHandler(_ExactRepeatMetricsHandler): + def metric_key(self): + return "user_utt_repeat" + + def handle_user_utt( + self, message: Message, prefix_stripped_text: str + ) -> Optional[Dict[str, Metric]]: + return self.handle_message_helper(prefix_stripped_text) + + +@register_metrics_handler +class SystemUttRepeatMetricHandler(_ExactRepeatMetricsHandler): + def metric_key(self): + return "sys_utt_repeat" + + def handle_sys_utt( + self, message: Message, prefix_stripped_text: str + ) -> Optional[Dict[str, Metric]]: + return self.handle_message_helper(prefix_stripped_text) + + +@register_metrics_handler +class _Bleu3MetricsHandler(TodMetricsHandler): + """ + For a given agent, this calculates the Bleu-3 of a new turn against prior turns. + + This is an alternate metric for repetativeness + """ + + def episode_reset(self): + self.turns = [] + + def metric_key(self): + raise NotImplementedError("must implement") + + def handle_message_helper( + self, prefix_stripped_text: str + ) -> Optional[Dict[str, Metric]]: + here = [normalize_answer(x) for x in prefix_stripped_text.split(" ")] + score = 1 + if len(self.turns) > 0: + score = nltkbleu.corpus_bleu( + [self.turns], + [here], + smoothing_function=nltkbleu.SmoothingFunction(epsilon=1e-12).method1, + weights=[1.0 / 3.0] * 3, + ) + self.turns.append(here) + return {self.metric_key(): BleuMetric(score)} + + +@register_metrics_handler +class UserUttSelfBleu3MetricHandler(_Bleu3MetricsHandler): + def metric_key(self): + return "user_utt_self_bleu3" + + def handle_user_utt( + self, message: Message, prefix_stripped_text: str + ) -> Optional[Dict[str, Metric]]: + return self.handle_message_helper(prefix_stripped_text) + + +@register_metrics_handler +class SystemUttSelfBleu3MetricHandler(_Bleu3MetricsHandler): + def metric_key(self): + return "sys_utt_self_bleu3" + + def handle_sys_utt( + self, message: Message, prefix_stripped_text: str + ) -> Optional[Dict[str, Metric]]: + return self.handle_message_helper(prefix_stripped_text) + + +@register_metrics_handler +class ApiCallMalformedMetricHandler(TodMetricsHandler): + def episode_reset(self): + self.api_schemas = [] + + def handle_api_call( + self, message: Message, api_call: Dict + ) -> Optional[Dict[str, Metric]]: + if STANDARD_API_SCHEMAS in message["text"]: + return # Happens for API call groundingion, so it's fine + if len(api_call) == 0: + return + if STANDARD_API_NAME_SLOT not in api_call: + return { + "apiCall_wellFormed": AverageMetric(0), + "apiCall_hasSlotsButNoApiNameSlot_count": SumMetric(1), + } + method = api_call[STANDARD_API_NAME_SLOT] + + method_found = False + if len(self.api_schemas) > 0: + for schema in self.api_schemas: + if method == schema.get(STANDARD_API_NAME_SLOT, ""): + method_found = True + check = api_call.keys() + required = set(schema.get(STANDARD_REQUIRED_KEY, [])) + required.add(STANDARD_API_NAME_SLOT) + for req in required: + if req not in check: # miissing required + return { + "apiCall_wellFormed": AverageMetric(0), + "apiCall_missingRequiredSlot_count": SumMetric(1), + } + opt_count = 0 + for opt in schema.get(STANDARD_OPTIONAL_KEY, []): + if opt in check: + opt_count += 1 + if opt_count + len(required) != len(check): + # have extra APIs that are not + return { + "apiCall_wellFormed": AverageMetric(0), + "apiCall_hasExtraParams_count": SumMetric(1), + } + break + if method_found: + return { + "apiCall_wellFormed": AverageMetric(1), + "apiCall_wellFormed_count": SumMetric(1), + } + return { + "apiCall_wellFormed": AverageMetric(0), + "apiCall_methodDNE_count": SumMetric(1), + } + + +@register_metrics_handler +class PseudoInformMetricsHandler(TodMetricsHandler): + """ + Pseudo-inform rate. + """ + + def episode_reset(self): + self.api_resp_slots = {} + + def handle_api_resp( + self, message: Message, api_resp: Dict + ) -> Optional[Dict[str, Metric]]: + self.api_resp_slots.update(api_resp) + + def handle_sys_utt( + self, message: Message, prefix_stripped_text: str + ) -> Optional[Dict[str, Metric]]: + count = 0 + for val in self.api_resp_slots.values(): + if val in prefix_stripped_text: + count += 1 + result = {"pseudo_inform_allSysTurns": AverageMetric(count)} + if len(self.api_resp_slots) > 0: + result["pseudo_inform_postApiRespSysTurns"] = AverageMetric(count) + return result diff --git a/tests/tod/test_tod_world_and_script.py b/tests/tod/test_tod_world_and_script.py new file mode 100644 index 00000000000..d3f764bc13b --- /dev/null +++ b/tests/tod/test_tod_world_and_script.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests tod world, notably for batching. +""" + +import copy +import unittest + +import parlai.core.tod.tod_test_utils.test_agents as test_agents +import parlai.core.tod.tod_core as tod_core +import parlai.scripts.tod_world_script as tod_world_script +from parlai.core.tod.tod_agents import StandaloneApiAgent + + +class TestTodWorldScript(tod_world_script.TodWorldScript): + """ + Wrap around it to check its logic; also makes it easier to do things w/ underlying + World. + """ + + def _get_tod_agents(self, opt): + """ + Hack so we can separate out logic of making sure agent parsing is correct. + """ + if hasattr(self, "agents"): + return self.agents + return super()._get_tod_agents(opt) + + def _save_outputs(self, opt, world, logger, episode_metrics): + self.world = world + self.logger = logger + + +class TodWorldInScriptTestBase(unittest.TestCase): + def add_tod_world_opts(self, base_opts): + """ + Convenience since we're initing the opt directly without parlai parser. + """ + opts = copy.deepcopy(base_opts) + opts["datatype"] = "DUMMY" + opts["datafile"] = "DUMMY" + opts["episodes_randomization_seed"] = -1 + opts["standalone_api_file"] = test_agents.API_DATABASE_FILE + opts["exact_api_call"] = True + opts["log_keep_fields"] = "all" + opts["display_examples"] = False + opts[ + "include_api_schemas" + ] = True # do this to test_agents.make sure they're done correctly. + return opts + + def setup_agents(self, added_opts): + full_opts = self.add_tod_world_opts(added_opts) + sys = test_agents.ApiCallAndSysUttAgent(full_opts) + agents = [ + test_agents.UserUttAgent(full_opts), + sys, + StandaloneApiAgent(full_opts), + sys, + test_agents.ApiSchemaAgent(full_opts), + test_agents.GoalAgent(full_opts), + ] + return agents, full_opts + + def _test_roundDataCorrect(self): + self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__SINGLE_API_CALL) + self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__MULTI_ROUND) + self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__MULTI_EPISODE) + self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__MULTI_EPISODE_BS) + + def _check_correctness_from_script_logs( + self, script, opt, process_round_utts=lambda x: x + ): + """ + Last argument is only relevant for the max_turn test. + """ + max_rounds = opt[test_agents.TEST_NUM_ROUNDS_OPT_KEY] + max_episodes = opt[test_agents.TEST_NUM_EPISODES_OPT_KEY] + # there's something funky with logger.get_log() that inserts a space, but not gonna worry about it for now + logs = [x for x in script.logger.get_logs() if len(x) > 0] + for episode_idx in range(max_episodes): + episode_from_world = logs[episode_idx] + # first round is context + context = episode_from_world[0] + self.assertEquals( + context[0]["text"], + "APIS: " + + tod_core.SerializationHelpers.list_of_maps_to_str( + test_agents.make_api_schemas_machine(max_rounds) + ), + ) + self.assertEquals( + context[3]["text"], + "GOAL: " + + tod_core.SerializationHelpers.list_of_maps_to_str( + test_agents.make_goal_calls_machine(max_rounds) + ), + ) + # Check the rest + world_utts = [[x["text"] for x in turn] for turn in episode_from_world[1:]] + # ... ignore the last DONE turn here cause it's not that important + + self.assertEquals( + world_utts[:-1], + process_round_utts( + test_agents.get_round_utts(episode_idx, max_rounds)[:-1] + ), + ) + + +class TodWorldSingleBatchTest(TodWorldInScriptTestBase): + def _test_roundDataCorrect_helper(self, config): + config["batchsize"] = 1 + config["max_turns"] = 10 + agents, opt = self.setup_agents(config) + script = TestTodWorldScript(opt) + script.agents = agents + script.run() + self._check_correctness_from_script_logs(script, opt) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + def test_max_turn(self): + self._test_max_turn_helper(4) + self._test_max_turn_helper(7) + + def _test_max_turn_helper(self, max_turns): + config = {} + config["batchsize"] = 1 + config["max_turns"] = max_turns + config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] = 10 + config[test_agents.TEST_NUM_EPISODES_OPT_KEY] = 2 # cause why not + agents, opt = self.setup_agents(config) + script = TestTodWorldScript(opt) + script.agents = agents + script.run() + + def filter_round_utt(utts): + # tad imprecise, but more important that it does stop. + # subtract 1 for the context turn, then 1 cause there's an off by one somewhere + return utts[: max_turns - 2] + + self._check_correctness_from_script_logs(script, opt, filter_round_utt) + + +class TodWorldNonSingleBatchTest(TodWorldInScriptTestBase): + def _test_roundDataCorrect_helper(self, config): + config["batchsize"] = 4 + config["max_turns"] = 10 + agents, opt = self.setup_agents(config) + script = TestTodWorldScript(opt) + script.agents = agents + script.run() + self._check_correctness_from_script_logs(script, opt) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TodWorldTestSingleDumpAgents(TodWorldInScriptTestBase): + def setup_agents(self, added_opts, api_agent, goal_agent): + full_opts = self.add_tod_world_opts(added_opts) + full_opts["fixed_response"] = "USER: [DONE]" + sys = test_agents.ApiCallAndSysUttAgent(full_opts) + agents = [ + test_agents.UserUttAgent(full_opts), + sys, + StandaloneApiAgent(full_opts), + sys, + api_agent(full_opts), + goal_agent(full_opts), + ] + return agents, full_opts + + def test_SingleGoalApiResp_noBatching(self): + config = {} + config["batchsize"] = 1 + config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] = 10 + config[test_agents.TEST_NUM_EPISODES_OPT_KEY] = 2 # cause why not + single_agents, opt = self.setup_agents( + config, test_agents.SingleApiSchemaAgent, test_agents.SingleGoalAgent + ) + single_script = TestTodWorldScript(opt) + single_script.agents = single_agents + single_script.run() + single_logs = [x for x in single_script.logger.get_logs() if len(x) > 0] + + multi_agents, opt = self.setup_agents( + config, test_agents.ApiSchemaAgent, test_agents.GoalAgent + ) + multi_script = TestTodWorldScript(opt) + multi_script.agents = multi_agents + multi_script.run() + multi_logs = [x for x in single_script.logger.get_logs() if len(x) > 0] + + single_idx = 0 + for multi_log in multi_logs: + context = multi_log[0] + goals = tod_core.SerializationHelpers.str_to_goals( + context[3]["text"][len("GOAL:") :].strip() + ) + for goal in goals: + single_context = single_logs[single_idx][0] + single_goal = tod_core.SerializationHelpers.str_to_goals( + single_context[3]["text"][len("GOAL:") :].strip() + ) + self.assertEqual(len(single_goal), 1) + self.assertEquals(goal, single_goal[0]) + single_des = tod_core.SerializationHelpers.str_to_api_schemas( + single_context[0]["text"][len("APIS:") :].strip() + ) + self.assertEqual(len(single_des), 1) + self.assertEqual(single_goal[0]["api_name"], single_des[0]["api_name"]) + + single_idx += 1 + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/tod/test_tod_world_metrics.py b/tests/tod/test_tod_world_metrics.py new file mode 100644 index 00000000000..0367f655994 --- /dev/null +++ b/tests/tod/test_tod_world_metrics.py @@ -0,0 +1,243 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests different (more complicated) slot metrics. +""" + +import unittest + +from parlai.core.tod.tod_core import ( + STANDARD_API_NAME_SLOT, + STANDARD_REQUIRED_KEY, + STANDARD_OPTIONAL_KEY, + TodStructuredRound, + TodStructuredEpisode, + TodAgentType, + TOD_AGENT_TYPE_TO_PREFIX, +) +from parlai.core.tod.world_metrics import ( + TodMetrics, +) +from parlai.core.tod.world_metrics_handlers import ( + METRICS_HANDLER_CLASSES_TEST_REGISTRY, +) + +# Ignore lint on following line; want to have registered classes show up for tests +import projects.tod_simulator.world_metrics.extended_world_metrics # noqa: F401 + +GOAL__SINGLE_ONE_KEY = [{STANDARD_API_NAME_SLOT: "name", "a": "1"}] +GOAL__SINGLE_THREE_KEYS = [ + {STANDARD_API_NAME_SLOT: "name", "a": "1", "b": "2", "c": "3"} +] +GOAL__HARD = [ + { + STANDARD_API_NAME_SLOT: "otherName", + "w": "1", + "x": "2", + "y": "3", + "z": "will_be_missing", + "diff": "right", + } +] + +API_CALL__NO_API_NAME_SLOT = {"random": "blah"} +API_CALL__API_NAME_DNE = {STANDARD_API_NAME_SLOT: "not_an_api_name"} +API_CALL__VALID_NAME_BUT_EMPTY = {STANDARD_API_NAME_SLOT: "name"} +API_CALL__SINGLE_ONE_KEY = GOAL__SINGLE_ONE_KEY[0] +API_CALL__SINGLE_ONE_KEY_WITH_OPT = {**GOAL__SINGLE_ONE_KEY[0], **{"c": "3"}} +API_CALL__SINGLE_ONE_KEY_WITH_OPT_AND_NONVALID = { + **GOAL__SINGLE_ONE_KEY[0], + **{"c": "3", "nonExistent": "blah"}, +} +API_CALL__FUNKY_AGAINST_HARD = { + STANDARD_API_NAME_SLOT: "otherName", + "w": "1", + "x": "2", + "y": "3", + "diff": "wrong", +} + +API_SCHEMA__ONE_CALL_ONE_REQ_MATCH_ONE_KEY = [ + { + STANDARD_API_NAME_SLOT: "name", + STANDARD_REQUIRED_KEY: ["a"], + STANDARD_OPTIONAL_KEY: [], + } +] + +API_SCHEMA__ONE_CALL_MATCH_THREE_KEYS = [ + { + STANDARD_API_NAME_SLOT: "name", + STANDARD_REQUIRED_KEY: ["a"], + STANDARD_OPTIONAL_KEY: ["b", "c", "d"], + } +] + +API_SCHEMA__ONE_CALL_HARD = [ + { + STANDARD_API_NAME_SLOT: "otherName", + STANDARD_REQUIRED_KEY: ["w", "x"], + STANDARD_OPTIONAL_KEY: ["y", "z", "diff"], + } +] + + +class TodMetricsTestHelper: + def __init__(self, e: TodStructuredEpisode): + self.m = TodMetrics() + self.m.handlers = [ + x() for x in METRICS_HANDLER_CLASSES_TEST_REGISTRY + ] # run on ALL + self.e = e + + def _process(self, t: TodAgentType, text: str): + self.m.handle_message({"text": f"{TOD_AGENT_TYPE_TO_PREFIX[t]}{text}"}, t) + + def run(self): + self._process(TodAgentType.API_SCHEMA_GROUNDING_AGENT, self.e.api_schemas_utt) + self._process(TodAgentType.GOAL_GROUNDING_AGENT, self.e.goal_calls_utt) + + for r in self.e.rounds: + self._process(TodAgentType.USER_UTT_AGENT, r.user_utt) + self._process(TodAgentType.API_CALL_AGENT, r.api_call_utt) + self._process(TodAgentType.API_RESP_AGENT, r.api_resp_utt) + self._process(TodAgentType.SYSTEM_UTT_AGENT, r.sys_utt) + + self.m.episode_reset() + + def report(self): + return self.m.report() + + +class TestApiGoalHitMetricsHandler(unittest.TestCase): + def __helper(self, api_schemas_machine, goal_calls_machine, single_turn_api_call): + e = TodStructuredEpisode( + api_schemas_machine=api_schemas_machine, + goal_calls_machine=goal_calls_machine, + rounds=[TodStructuredRound(api_call_machine=single_turn_api_call)], + ) + helper = TodMetricsTestHelper(e) + helper.run() + result = helper.report() + return result + + def test_one_goal_only_req(self): + result = self.__helper( + api_schemas_machine=API_SCHEMA__ONE_CALL_ONE_REQ_MATCH_ONE_KEY, + goal_calls_machine=GOAL__SINGLE_ONE_KEY, + single_turn_api_call=API_CALL__SINGLE_ONE_KEY, + ) + self.assertAlmostEqual(result["all_goals_hit"], 1) + self.assertAlmostEqual(result["all_goals_hit_turn_count"], 1) + self.assertAlmostEqual(result["all_goals_fractional_hit"], 1) + self.assertAlmostEqual(result["all_goals_slot_precision"], 1) + self.assertAlmostEqual(result["all_goals_slot_recall"], 1) + + self.assertAlmostEqual(result["req_goals_hit"], 1) + self.assertAlmostEqual(result["req_goals_hit_turn_count"], 1) + self.assertAlmostEqual(result["req_goals_fractional_hit"], 1) + self.assertAlmostEqual(result["req_goals_slot_precision"], 1) + self.assertAlmostEqual(result["req_goals_slot_recall"], 1) + + def test_one_goal_api_name_missing_slots(self): + result = self.__helper( + api_schemas_machine=API_SCHEMA__ONE_CALL_ONE_REQ_MATCH_ONE_KEY, + goal_calls_machine=GOAL__SINGLE_ONE_KEY, + single_turn_api_call=API_CALL__VALID_NAME_BUT_EMPTY, + ) + self.assertAlmostEqual(result["all_goals_hit"], 0) + self.assertAlmostEqual(result["all_goals_hit_turn_count"], 0) + self.assertAlmostEqual(result["all_goals_fractional_hit"], 0) + self.assertAlmostEqual(result["all_goals_slot_precision"], 1) # api_name + self.assertAlmostEqual(result["all_goals_slot_recall"], 0.5) + + self.assertAlmostEqual(result["req_goals_hit"], 0) + self.assertAlmostEqual(result["req_goals_hit_turn_count"], 0) + self.assertAlmostEqual(result["req_goals_fractional_hit"], 0) + self.assertAlmostEqual(result["req_goals_slot_precision"], 1) + self.assertAlmostEqual(result["req_goals_slot_recall"], 0.5) + + def test_one_goal_with_opts(self): + result = self.__helper( + api_schemas_machine=API_SCHEMA__ONE_CALL_MATCH_THREE_KEYS, + goal_calls_machine=GOAL__SINGLE_THREE_KEYS, + single_turn_api_call=API_CALL__SINGLE_ONE_KEY, + ) + self.assertAlmostEqual(result["all_goals_hit"], 0) + self.assertAlmostEqual(result["all_goals_hit_turn_count"], 0) + self.assertAlmostEqual(result["all_goals_fractional_hit"], 0) + self.assertAlmostEqual(result["all_goals_slot_precision"], 1) + self.assertAlmostEqual(result["all_goals_slot_recall"], 0.5) + + self.assertAlmostEqual(result["req_goals_hit"], 1) + self.assertAlmostEqual(result["req_goals_hit_turn_count"], 1) + self.assertAlmostEqual(result["req_goals_fractional_hit"], 1) + self.assertAlmostEqual(result["req_goals_slot_precision"], 1) + self.assertAlmostEqual(result["req_goals_slot_recall"], 1) + + def test_hard_case(self): + result = self.__helper( + api_schemas_machine=API_SCHEMA__ONE_CALL_HARD, + goal_calls_machine=GOAL__HARD, + single_turn_api_call=API_CALL__FUNKY_AGAINST_HARD, + ) + self.assertAlmostEqual(result["all_goals_hit"], 0) + self.assertAlmostEqual(result["all_goals_hit_turn_count"], 0) + self.assertAlmostEqual(result["all_goals_fractional_hit"], 0) + self.assertAlmostEqual(result["all_goals_slot_precision"], 0.8) + self.assertAlmostEqual(result["all_goals_slot_recall"], 2.0 / 3.0) + + self.assertAlmostEqual(result["req_goals_hit"], 1) + self.assertAlmostEqual(result["req_goals_hit_turn_count"], 1) + self.assertAlmostEqual(result["req_goals_fractional_hit"], 1) + self.assertAlmostEqual(result["req_goals_slot_precision"], 0.6) + self.assertAlmostEqual(result["req_goals_slot_recall"], 1) + + +class TestApiCallMalformedMetricsHandler(unittest.TestCase): + def __helper(self, single_turn_api_call): + e = TodStructuredEpisode( + api_schemas_machine=API_SCHEMA__ONE_CALL_MATCH_THREE_KEYS, + rounds=[TodStructuredRound(api_call_machine=single_turn_api_call)], + ) + helper = TodMetricsTestHelper(e) + helper.run() + return helper.report() + + def test_no_api_name_slot(self): + result = self.__helper(API_CALL__NO_API_NAME_SLOT) + self.assertEqual(result["apiCall_wellFormed"], 0) + self.assertEqual(result["apiCall_hasSlotsButNoApiNameSlot_count"], 1) + + def test_api_name_DNE(self): + result = self.__helper(API_CALL__API_NAME_DNE) + self.assertEqual(result["apiCall_wellFormed"], 0) + self.assertEqual(result["apiCall_methodDNE_count"], 1) + + def test_missing_required_slot(self): + result = self.__helper(API_CALL__VALID_NAME_BUT_EMPTY) + self.assertEqual(result["apiCall_wellFormed"], 0) + self.assertEqual(result["apiCall_missingRequiredSlot_count"], 1) + + def test_has_single_required_slot(self): + result = self.__helper(API_CALL__SINGLE_ONE_KEY) + self.assertEqual(result["apiCall_wellFormed"], 1) + self.assertEqual(result["apiCall_wellFormed_count"], 1) + + def test_has_valid_optional_slot(self): + result = self.__helper(API_CALL__SINGLE_ONE_KEY_WITH_OPT) + self.assertEqual(result["apiCall_wellFormed"], 1) + self.assertEqual(result["apiCall_wellFormed_count"], 1) + + def test_has_invalid_extra_slots(self): + result = self.__helper(API_CALL__SINGLE_ONE_KEY_WITH_OPT_AND_NONVALID) + self.assertEqual(result["apiCall_wellFormed"], 0) + self.assertEqual(result["apiCall_hasExtraParams_count"], 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/tod/test_tod_world_metrics_in_script.py b/tests/tod/test_tod_world_metrics_in_script.py new file mode 100644 index 00000000000..ffac2a25e12 --- /dev/null +++ b/tests/tod/test_tod_world_metrics_in_script.py @@ -0,0 +1,257 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +sTests tod world, notably for batching. +""" + +import copy +import unittest + +from parlai.core.metrics import dict_report +from parlai.core.opt import Opt +from parlai.core.tod.tod_core import SerializationHelpers +import parlai.core.tod.tod_test_utils.test_agents as test_agents +from parlai.core.tod.world_metrics_handlers import ( + METRICS_HANDLER_CLASSES_TEST_REGISTRY, +) +import parlai.scripts.tod_world_script as tod_world_script + +# Ignore lint on following line; want to have registered classes show up for tests +import projects.tod_simulator.world_metrics.extended_world_metrics # noqa: F401 + +NUM_EPISODES = 35 + +TEST_SETUP = { + "api_schema_grounding_model": "parlai.core.tod.tod_test_utils.test_agents:ApiSchemaAgent", + "goal_grounding_model": "parlai.core.tod.tod_test_utils.test_agents:GoalAgent", + "user_model": "parlai.core.tod.tod_test_utils.test_agents:UserUttAgent", + "system_model": "parlai.core.tod.tod_test_utils.test_agents:ApiCallAndSysUttAgent", + "api_resp_model": "fixed_response", + test_agents.TEST_NUM_EPISODES_OPT_KEY: NUM_EPISODES, +} +TEST_SETUP_BROKEN_USER_SYSTEM = { + "api_schema_grounding_model": "parlai.core.tod.tod_test_utils.test_agents:ApiSchemaAgent", + "goal_grounding_model": "parlai.core.tod.tod_test_utils.test_agents:GoalAgent", + "user_model": "fixed_response", + "system_model": "fixed_response", + "api_resp_model": "fixed_response", + test_agents.TEST_NUM_EPISODES_OPT_KEY: NUM_EPISODES, +} + +TEST_SETUP_EMPTY_APISCHEMA = copy.deepcopy(TEST_SETUP) +TEST_SETUP_EMPTY_APISCHEMA[ + "api_schema_grounding_model" +] = "parlai.core.tod.tod_agents:EmptyApiSchemaAgent" + +TEST_SETUP_BROKEN_USER_SYSTEM_EMPTY_APISCHEMA = copy.deepcopy( + TEST_SETUP_BROKEN_USER_SYSTEM +) +TEST_SETUP_BROKEN_USER_SYSTEM_EMPTY_APISCHEMA[ + "api_schema_grounding_model" +] = "parlai.core.tod.tod_agents:EmptyApiSchemaAgent" + +DATATYPE = "valid" + + +class TestTodWorldScript(tod_world_script.TodWorldScript): + """ + Wrap around it to check its logic; also makes it easier to do things w/ underlying + World. + """ + + def __init__(self, opt: Opt): + opt["datatype"] = DATATYPE + # none of the below matter, but need to set to keep other code happy. + opt["log_keep_fields"] = "all" + opt["display_examples"] = False + + super().__init__(opt) + + def _setup_world(self): + world = super()._setup_world() + for i in range(len(world.batch_tod_world_metrics)): + world.batch_tod_world_metrics[i].handlers = [ + x() for x in METRICS_HANDLER_CLASSES_TEST_REGISTRY + ] + return world + + def _save_outputs(self, opt, world, logger, episode_metrics): + self.world = world + self.logger = logger + self.episode_metrics = episode_metrics + + +class TodMetricsInScriptTests(unittest.TestCase): + def test_all_goals_hit_all_success(self): + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP, batchsize=1, num_episodes=1, target_all_goals_hit=1 + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP, batchsize=1, num_episodes=32, target_all_goals_hit=1 + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP, batchsize=32, num_episodes=8, target_all_goals_hit=1 + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP, batchsize=32, num_episodes=33, target_all_goals_hit=1 + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP, + batchsize=32, + num_episodes=-1, + target_all_goals_hit=1, + target_metrics_length=NUM_EPISODES, + ) + + def test_all_goals_hit_all_fail(self): + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_BROKEN_USER_SYSTEM, + batchsize=1, + num_episodes=1, + target_all_goals_hit=0, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_BROKEN_USER_SYSTEM, + batchsize=1, + num_episodes=32, + target_all_goals_hit=0, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_BROKEN_USER_SYSTEM, + batchsize=32, + num_episodes=32, + target_all_goals_hit=0, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_BROKEN_USER_SYSTEM, + batchsize=32, + num_episodes=33, + target_all_goals_hit=0, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_BROKEN_USER_SYSTEM, + batchsize=32, + num_episodes=-1, + target_all_goals_hit=0, + target_metrics_length=NUM_EPISODES, + ) + + def test_all_goals_hit_all_success_emptySchema(self): + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_EMPTY_APISCHEMA, + batchsize=1, + num_episodes=1, + target_all_goals_hit=1, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_EMPTY_APISCHEMA, + batchsize=1, + num_episodes=32, + target_all_goals_hit=1, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_EMPTY_APISCHEMA, + batchsize=32, + num_episodes=32, + target_all_goals_hit=1, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_EMPTY_APISCHEMA, + batchsize=32, + num_episodes=33, + target_all_goals_hit=1, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_EMPTY_APISCHEMA, + batchsize=32, + num_episodes=-1, + target_all_goals_hit=1, + target_metrics_length=NUM_EPISODES, + ) + + def test_all_goals_hit_all_fail_emptySchema(self): + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_BROKEN_USER_SYSTEM_EMPTY_APISCHEMA, + batchsize=1, + num_episodes=1, + target_all_goals_hit=0, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_BROKEN_USER_SYSTEM_EMPTY_APISCHEMA, + batchsize=1, + num_episodes=32, + target_all_goals_hit=0, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_BROKEN_USER_SYSTEM_EMPTY_APISCHEMA, + batchsize=32, + num_episodes=32, + target_all_goals_hit=0, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_BROKEN_USER_SYSTEM_EMPTY_APISCHEMA, + batchsize=32, + num_episodes=33, + target_all_goals_hit=0, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_BROKEN_USER_SYSTEM_EMPTY_APISCHEMA, + batchsize=32, + num_episodes=-1, + target_all_goals_hit=0, + target_metrics_length=NUM_EPISODES, + ) + + def _check_all_goals_hit_by_opt_and_batchsize( + self, + opt, + batchsize, + num_episodes, + target_all_goals_hit, + target_metrics_length=None, + ): + opt = copy.deepcopy(opt) + opt["batchsize"] = batchsize + opt["num_episodes"] = num_episodes + report, metrics = self._run_opt_get_report(opt) + self.assertEqual(report.get("all_goals_hit"), target_all_goals_hit) + metrics_comp_length = num_episodes + if target_metrics_length: + metrics_comp_length = target_metrics_length + self.assertEqual(len(metrics), metrics_comp_length) + + def _run_opt_get_report(self, opt): + script = TestTodWorldScript(opt) + script.run() + + def get_episode_report(goal, episode_metric): + metrics_dict = dict_report(episode_metric.report()) + metrics_dict["goal"] = goal + return metrics_dict + + return dict_report(script.world.report()), [ + get_episode_report(g, e) for g, e in script.episode_metrics + ] + + def test_apiCallAttempts_usingGold(self): + opt = copy.deepcopy(TEST_SETUP) + opt["batchsize"] = 1 + opt["num_episodes"] = -1 + _, metrics = self._run_opt_get_report(opt) + for metric in metrics: + self.assertEqual( + len( + SerializationHelpers.str_to_goals( + metric["goal"]["text"][len("GOALS: ") :] + ) + ), + metric["call_attempts"], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/tod/test_tod_world_script_metrics.py b/tests/tod/test_tod_world_script_metrics.py new file mode 100644 index 00000000000..9862b3c824f --- /dev/null +++ b/tests/tod/test_tod_world_script_metrics.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests tod world, notably for batching. +""" + +import copy +import unittest + +import parlai.core.tod.tod_test_utils.test_agents as test_agents +import parlai.scripts.tod_world_script as tod_world_script +from parlai.core.tod.tod_agents import StandaloneApiAgent +from parlai.core.tod.world_metrics_handlers import ( + METRICS_HANDLER_CLASSES_TEST_REGISTRY, +) +from parlai.core.metrics import dict_report + +# Ignore lint on following line; want to have registered classes show up for tests +import projects.tod_simulator.world_metrics.extended_world_metrics # noqa: F401 + + +class TestTodWorldScript(tod_world_script.TodWorldScript): + """ + Wrap around it to check its logic; also makes it easier to do things w/ underlying + World. + """ + + def _get_tod_agents(self, opt): + """ + Hack so we can separate out logic of making sure agent parsing is correct. + """ + if hasattr(self, "agents"): + return self.agents + return super()._get_tod_agents(opt) + + def _setup_world(self): + world = super()._setup_world() + for i in range(len(world.batch_tod_world_metrics)): + world.batch_tod_world_metrics[i].handlers = [ + x() for x in METRICS_HANDLER_CLASSES_TEST_REGISTRY + ] + return world + + def _save_outputs(self, opt, world, logger, episode_metrics): + self.world = world + self.episode_metrics = episode_metrics + + +class TodWorldInScriptTestBase(unittest.TestCase): + def add_tod_world_opts(self, base_opts): + """ + Convenience since we're initing the opt directly without parlai parser. + """ + opts = copy.deepcopy(base_opts) + opts["datatype"] = "DUMMY" + opts["datafile"] = "DUMMY" + opts["standalone_api_file"] = test_agents.API_DATABASE_FILE + opts["exact_api_call"] = True + opts["log_keep_fields"] = "all" + opts["display_examples"] = False + opts[ + "include_api_schemas" + ] = True # do this to test_agents.make sure they're done correctly. + return opts + + def setup_agents(self, added_opts): + full_opts = self.add_tod_world_opts(added_opts) + sys = test_agents.ApiCallAndSysUttAgent(full_opts) + agents = [ + test_agents.UserUttAgent(full_opts), + sys, + StandaloneApiAgent(full_opts), + sys, + test_agents.ApiSchemaAgent(full_opts), + test_agents.GoalAgent(full_opts), + ] + return agents, full_opts + + def _run_test(self): + self._run_test_helper(test_agents.EPISODE_SETUP__SINGLE_API_CALL) + self._run_test_helper(test_agents.EPISODE_SETUP__MULTI_ROUND) + self._run_test_helper(test_agents.EPISODE_SETUP__MULTI_EPISODE) + self._run_test_helper(test_agents.EPISODE_SETUP__MULTI_EPISODE_BS) + + def _run_test_helper(self, config_base): + config = copy.deepcopy(config_base) + config["use_broken_mock_api_calls"] = True + add = self.config_args() + for key in add: + config[key] = add[key] + agents, opt = self.setup_agents(config) + script = TestTodWorldScript(opt) + script.agents = agents + script.run() + self._check_metrics_correct(script, opt) + + def _check_metrics_correct(self, script, opt): + """ + Last argument is only relevant for the max_turn test. + """ + max_rounds = opt[test_agents.TEST_NUM_ROUNDS_OPT_KEY] + max_episodes = opt[test_agents.TEST_NUM_EPISODES_OPT_KEY] + episode_metrics = script.episode_metrics + for episode_idx, episode in enumerate(episode_metrics): + # if episode_idx >= max_episodes: + # break + # See how we make broken mock api calls in the test_agents. + goal, episode_metric = episode + episode_metric = dict_report(episode_metric.report()) + self.assertAlmostEqual( + episode_metric["all_goals_hit"], + not test_agents.episode_has_broken_api_turn(episode_idx, max_rounds), + ) + broken_episodes = sum( + [ + test_agents.episode_has_broken_api_turn(i, max_rounds) + for i in range(max_episodes) + ] + ) + report = dict_report(script.world.report()) + self.assertAlmostEqual( + report["all_goals_hit"], + float(max_episodes - broken_episodes) / max_episodes, + ) + + +class TodWorldSingleBatchTest(TodWorldInScriptTestBase): + def config_args(self): + config = {} + config["batchsize"] = 1 + config["max_turns"] = 10 + return config + + def test_metricsCorrect(self): + self._run_test() + + +class TodWorldNonSingleBatchTest(TodWorldInScriptTestBase): + def config_args(self): + config = {} + config["batchsize"] = 4 + config["max_turns"] = 10 + return config + + def test_metricsCorrect(self): + self._run_test() + + +if __name__ == "__main__": + unittest.main() From 0e3f492dfffb4746a17ce367d97e91af895dfd49 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 11:59:27 -0800 Subject: [PATCH 08/23] hmmm... hoping stacks don't bite me. (change that was kept in upper diff in stack, but lost from this one --- parlai/core/tod/tod_core.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/parlai/core/tod/tod_core.py b/parlai/core/tod/tod_core.py index 76ee8005a74..e53be5c7376 100644 --- a/parlai/core/tod/tod_core.py +++ b/parlai/core/tod/tod_core.py @@ -160,7 +160,10 @@ def inner_list_join(cls, values): @classmethod def inner_list_split(cls, s): - return s.split(", ") + split = s.split(", ") + if len(split) == 1: + return split[0] + return set(split) @classmethod def maybe_inner_list_join(cls, values): @@ -193,7 +196,7 @@ def str_to_api_dict(cls, string): continue name, value = slot_str.split(" = ", 1) name = name.strip() - value = value.strip() + value = SerializationHelpers.inner_list_split(value.strip()) result[name] = value return result From 37aced299af86863069b7655ef760033dcd9a4e5 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 12:01:13 -0800 Subject: [PATCH 09/23] minor, remove commented out print --- parlai/core/tod/tod_test_utils/test_agents.py | 1 - 1 file changed, 1 deletion(-) diff --git a/parlai/core/tod/tod_test_utils/test_agents.py b/parlai/core/tod/tod_test_utils/test_agents.py index b1339052764..b7639efea36 100644 --- a/parlai/core/tod/tod_test_utils/test_agents.py +++ b/parlai/core/tod/tod_test_utils/test_agents.py @@ -171,7 +171,6 @@ def setup_episodes(self, _): ), ) ) - # print(result, self.opt) return result def get_id_task_prefix(self): From b05930fa1238faa1f125ae80ee93e212ce1b8ada Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 12:16:33 -0800 Subject: [PATCH 10/23] comment --- parlai/scripts/tod_world_script.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/parlai/scripts/tod_world_script.py b/parlai/scripts/tod_world_script.py index af4f963b71f..81b38b2b620 100644 --- a/parlai/scripts/tod_world_script.py +++ b/parlai/scripts/tod_world_script.py @@ -5,6 +5,13 @@ # LICENSE file in the root directory of this source tree. """ Base script for running TOD model-model chats. + +For example, to extract gold ground truth data from Google SGD, run + +``` +python -u -m parlai.scripts.tod_world_script --api-schema-grounding-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainApiSchemaAgent --goal-grounding-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainGoalAgent --user-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainUserUttAgent --system-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainApiCallAndSysUttAgent --api-resp-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainApiResponseAgent -dt valid --num-episodes -1 --episodes-randomization-seed 42 --world-logs gold-valid +``` + """ import json @@ -108,7 +115,7 @@ class TodWorldScript(ParlaiScript): @classmethod def setup_tod_args(cls, parser: ParlaiParser): tod_args = parser.add_argument_group( - "TOD World Script Agent arguments. NOTE: Agents setup with this path will be able to take command line arguments, whereas those set from `-o` or `--init-opt` will not." + "TOD World Script Agent arguments. NOTE: If there are issues with invoking downstream opts of agents specified here sometimes you will have more luck with `python -u -m parlai.scripts.tod_world_script` than `parlai tod_world_script`." ) tod_args.add_argument( "--system-model-file", From 5086e8507c785c53f1b78bc5060b4e7e393ab430 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 12:19:08 -0800 Subject: [PATCH 11/23] more comment updates (not sure if it actually helps clarity..) --- parlai/scripts/tod_world_script.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/parlai/scripts/tod_world_script.py b/parlai/scripts/tod_world_script.py index 81b38b2b620..42199bc4ff3 100644 --- a/parlai/scripts/tod_world_script.py +++ b/parlai/scripts/tod_world_script.py @@ -6,12 +6,11 @@ """ Base script for running TOD model-model chats. -For example, to extract gold ground truth data from Google SGD, run +For example, to extract gold ground truth data from Google SGD, run ``` python -u -m parlai.scripts.tod_world_script --api-schema-grounding-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainApiSchemaAgent --goal-grounding-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainGoalAgent --user-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainUserUttAgent --system-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainApiCallAndSysUttAgent --api-resp-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainApiResponseAgent -dt valid --num-episodes -1 --episodes-randomization-seed 42 --world-logs gold-valid ``` - """ import json From 9a25fc56b94b7dc47961c8e04b63a3848fc073df Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 12:40:39 -0800 Subject: [PATCH 12/23] [TOD][Dataset][Easy] Google SGD in TOD Conversations format Refactor Google SGD away from old format into TOD Conversations format. Datasets added in this substack: * *Google SGD* * Google SGD Simulation Splits (In-domain, Out-domain) * MetalWoz * MSR_E2E * Multidogo * MultiWoz V2.2 * Taskmaster * Taskmaster2 * Taskmaster3 (TicketTalk) Test plan: Regression test, `parlai dd` of dataset --- parlai/tasks/google_sgd/agents.py | 354 ++++++++++-------- parlai/tasks/google_sgd/build.py | 16 +- parlai/tasks/google_sgd/test.py | 10 +- .../google_sgd_UserSimulatorTeacher_test.yml | 51 +++ .../google_sgd_UserSimulatorTeacher_train.yml | 49 +++ .../google_sgd_UserSimulatorTeacher_valid.yml | 46 +++ .../tasks/google_sgd/test/google_sgd_test.yml | 108 ++---- .../google_sgd/test/google_sgd_train.yml | 85 ++--- .../google_sgd/test/google_sgd_valid.yml | 101 ++--- 9 files changed, 466 insertions(+), 354 deletions(-) create mode 100644 parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_test.yml create mode 100644 parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_train.yml create mode 100644 parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_valid.yml diff --git a/parlai/tasks/google_sgd/agents.py b/parlai/tasks/google_sgd/agents.py index 12e55a5deff..bfcdf02b3dc 100644 --- a/parlai/tasks/google_sgd/agents.py +++ b/parlai/tasks/google_sgd/agents.py @@ -8,192 +8,228 @@ Google The Schema-Guided Dialogue(SGD) Dataset implementation for ParlAI. """ -import os +import glob import json -from parlai.core.opt import Opt -from parlai.core.teachers import DialogTeacher -from parlai.utils.misc import warn_once -from parlai.core.message import Message -from parlai.core.metrics import AverageMetric, BleuMetric -from parlai.utils.io import PathManager +import os +from typing import Optional import parlai.tasks.google_sgd.build as build_ +import parlai.core.tod.tod_core as tod +import parlai.core.tod.tod_agents as tod_agents +from parlai.core.tod.tod_core import SerializationHelpers +from parlai.core.params import ParlaiParser +from parlai.core.opt import Opt +from parlai.utils.io import PathManager -class Text2API2TextTeacher(DialogTeacher): - """ - Teacher which produces both API calls and NLG responses. - """ +class GoogleSGDParser(tod_agents.TodStructuredDataParser): + @classmethod + def add_cmdline_args( + cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None + ) -> ParlaiParser: + parser = super().add_cmdline_args(parser, partial_opt) + parser.add_argument( + "--delex", type="bool", default=False, help="Delexicalize labels" + ) + parser.add_argument( + "--filter-dialogue-by-id", + default="", + type=str, + help="Path to a json file of `dialogue_id`s for which we will filter from. Assumes it will contain a map where the keys are a fold and the value is a list of ids", + ) + return parser def __init__(self, opt: Opt, shared=None): - self.fold = opt['datatype'].split(':')[0] - opt['datafile'] = self.fold - self.dpath = os.path.join(opt['datapath'], 'google_sgd') + self.fold = self.get_fold(opt) + opt["datafile"] = self.fold + self.dpath = os.path.join(opt["datapath"], "google_sgd") if shared is None: - warn_once( - "Google SGD is a beta dataset, and format may significantly change." - ) + # full initialize the teacher as this is not a clone build_.build(opt) super().__init__(opt, shared) + def get_fold(self, opt): + return opt["datatype"].split(":")[0] + def _load_data(self, fold): - dataset_fold = 'dev' if fold == 'valid' else fold + dataset_fold = "dev" if fold == "valid" else fold fold_path = os.path.join(self.dpath, dataset_fold) - schema_file = os.path.join(fold_path, 'schema.json') - with PathManager.open(schema_file, 'r') as f: + schema_file = os.path.join(fold_path, "schema.json") + with PathManager.open(schema_file, "r") as f: schema_lookup = {} for schema in json.load(f): - schema_lookup[schema['service_name']] = schema - - dialogs = [] - for file_id in range(1, build_.fold_size(dataset_fold) + 1): - filename = os.path.join(fold_path, f'dialogues_{file_id:03d}.json') - with PathManager.open(filename, 'r') as f: - dialogs += json.load(f) - return schema_lookup, dialogs - - def _get_api_call_and_results(self, sys_turn, schema_lookup): + schema_lookup[schema["service_name"]] = schema + + dialogues = [] + for filename in glob.glob(f"{fold_path}/dialogues*.json"): + with PathManager.open(filename, "r") as f: + dialogues += json.load(f) + + filter_path = self.opt.get("filter_dialogue_by_id", "") + if len(filter_path) > 0: + filtered = [] + with open(filter_path) as f: + dialogues_to_get = json.load(f)[fold] + for dialogue in dialogues: + if dialogue["dialogue_id"] in dialogues_to_get: + filtered.append(dialogue) + assert len(filtered) == len( + dialogues_to_get + ), f"Different number of dialogues found than requested. Are you sure you've got the right form of Google SGD? Did you filter for dialogue ids correctly? len(filtered) = {len(filtered)}, len(dialogues_to_get) = {len(dialogues_to_get)}" + dialogues = filtered + return schema_lookup, dialogues + + def _get_api_call_and_results(self, sys_turn): api_call = {} api_resp = {} - for frame in sys_turn['frames']: - if 'service_call' in frame: + for frame in sys_turn["frames"]: + if "service_call" in frame: # API CALL - method = frame['service_call']['method'] - for slot_type, slot_value in frame['service_call'][ - 'parameters' + for slot_type, slot_value in frame["service_call"][ + "parameters" ].items(): - api_call[f'{method}.{slot_type}'] = slot_value - assert 'service_results' in frame + if slot_value: + api_call[ + f"{slot_type.strip()}" + ] = SerializationHelpers.inner_list_join(slot_value) + api_call[tod.STANDARD_API_NAME_SLOT] = frame["service_call"]["method"] + assert "service_results" in frame # API Resp - if 'actions' in frame: - for action in frame['actions']: - slot_type = action['slot'] - slot_value = action['canonical_values'] - api_resp[slot_type] = slot_value + if "service_results" in frame: + api_resp = {} + service_results = frame["service_results"] + if len(service_results) > 0: + for key, value in service_results[0].items(): + api_resp[key] = SerializationHelpers.inner_list_join(value) return api_call, api_resp - def custom_evaluation( - self, teacher_action: Message, labels, model_response: Message - ): - resp = model_response.get('text') - if not resp: - return - - if teacher_action['type'] == 'apicall' and resp.startswith('apicall: '): - gold = teacher_action['slots'] - slot_strs = resp[9:].split(' ; ') - parsed = {} - for slot_str in slot_strs: - if ' = ' not in slot_str: - if slot_str != '': - # syntactically invalid generations should count against us - self.metrics.add('slot_p', AverageMetric(0)) - continue - name, value = slot_str.split(' = ') - parsed[name] = value - - # slot precision - for k, v in parsed.items(): - self.metrics.add('slot_p', AverageMetric(v == gold.get(k))) - # slot recall - for k, v in gold.items(): - self.metrics.add('slot_r', AverageMetric(v == parsed.get(k))) - elif teacher_action['type'] == 'apiresp': - delex_resp = self._delex(resp, teacher_action['slots']) - delex_label = self._delex(labels[0], teacher_action['slots']) - self.metrics.add( - 'delex_bleu', BleuMetric.compute(delex_resp, [delex_label]) - ) - - def _delex(self, text, slots): - delex = text - for slot, values in slots.items(): - assert isinstance(values, list) - for value in values: - delex = delex.replace(value, slot) - return delex - - def _api_dict_to_str(self, apidict): - return ' ; '.join(f'{k} = {v}' for k, v in apidict.items()) - - def setup_data(self, fold): - schema_lookup, dialogs = self._load_data(fold) - for dialog in dialogs: - # services = dialog['services'] - turns = dialog['turns'] - num_turns = len(turns) - for turn_id in range(0, num_turns, 2): - is_first_turn = turn_id == 0 - + def _get_apis_in_domain(self, schema, domain): + """ + Google SGD includes extra information with the call, so remove these. + """ + result = {} + for intent in schema[domain].get("intents", {}): + here = {} + if "required_slots" in intent and len(intent["required_slots"]) > 0: + here[tod.STANDARD_REQUIRED_KEY] = intent["required_slots"] + if "optional_slots" in intent and len(intent["optional_slots"]) > 0: + here[tod.STANDARD_OPTIONAL_KEY] = intent["optional_slots"] + if "result_slots" in intent: + here["results"] = intent["result_slots"] + result[intent["name"]] = here + return result + + def _get_intent_groundinging(self, schema, domains): + """ + Returns map where keys are intents and values are names of required/optional + slots. + + We do not care about `result_slots` or default values of optional slots. + """ + result = [] + for domain in domains: + apis = self._get_apis_in_domain(schema, domain) + for intent, params in apis.items(): + here = {} + here[tod.STANDARD_API_NAME_SLOT] = intent + if tod.STANDARD_REQUIRED_KEY in params: + here[tod.STANDARD_REQUIRED_KEY] = params[tod.STANDARD_REQUIRED_KEY] + if ( + tod.STANDARD_OPTIONAL_KEY in params + and len(params[tod.STANDARD_OPTIONAL_KEY]) > 0 + ): + here[tod.STANDARD_OPTIONAL_KEY] = params[ + tod.STANDARD_OPTIONAL_KEY + ].keys() + result.append(here) + return result + + def _get_all_service_calls(self, turns): + """ + Searches through all turns in a dialogue for any service calls, returns these. + """ + results = [] + for turn in turns: + for frame in turn["frames"]: + if "service_call" in frame: + call = frame["service_call"] + item = call["parameters"] + item[tod.STANDARD_API_NAME_SLOT] = call["method"] + results.append(item) + return results + + def setup_episodes(self, fold): + """ + Parses Google SGD episodes into TodStructuredEpisode. + """ + schema_lookup, dialogues = self._load_data(fold) + result = [] + for dialogue in dialogues: + domains = {s.split("_")[0].strip() for s in dialogue["services"]} + turns = dialogue["turns"] + rounds = [] + for turn_id in range(0, len(turns), 2): user_turn = turns[turn_id] sys_turn = turns[turn_id + 1] - api_call, api_results = self._get_api_call_and_results( - sys_turn, schema_lookup + api_call, api_results = self._get_api_call_and_results(sys_turn) + r = tod.TodStructuredRound( + user_utt=user_turn["utterance"], + api_call_machine=api_call, + api_resp_machine=api_results, + sys_utt=sys_turn["utterance"], ) - call_str = self._api_dict_to_str(api_call) - resp_str = self._api_dict_to_str(api_results) - if not api_call and not api_results: - # input: user_turn, output: sys_turn - yield { - 'text': user_turn['utterance'], - 'label': sys_turn['utterance'], - 'type': 'text', - }, is_first_turn - elif not api_call and api_results: - yield { - 'text': f"{user_turn['utterance']} api_resp: {resp_str}", - 'label': sys_turn['utterance'], - 'type': 'apiresp', - 'slots': api_results, - }, is_first_turn - elif api_call and api_results: - # input: user_turn, output: api_call - yield { - 'text': user_turn['utterance'], - 'label': f'apicall: {call_str}', - 'type': 'apicall', - 'slots': api_call, - }, is_first_turn - - # system turn, input : api results, output : assistant turn - yield { - 'text': f"api_resp: {resp_str}", - 'label': sys_turn['utterance'], - 'type': 'apiresp', - 'slots': api_results, - }, False - else: - assert ( - api_call and api_results - ), "API call without API results! Check Dataset!" - - -class Text2TextTeacher(Text2API2TextTeacher): - """ - Text-only teacher (with no API calls or slots) - """ - - def setup_data(self, fold): - schema_lookup, dialogs = self._load_data(fold) - for dialog in dialogs: - turns = dialog['turns'] - num_turns = len(turns) - for turn_id in range(0, num_turns, 2): - if turn_id == 0: - is_first_turn = True - else: - is_first_turn = False + rounds.append(r) + # Now that we've got the rounds, make the episode + episode = tod.TodStructuredEpisode( + domain=SerializationHelpers.inner_list_join(domains), + api_schemas_machine=self._get_intent_groundinging( + schema_lookup, set(dialogue["services"]) + ), + goal_calls_machine=self._get_all_service_calls(turns), + rounds=rounds, + delex=self.opt.get("delex"), + extras={"dialogue_id": dialogue["dialogue_id"]}, + ) + result.append(episode) + # check if the number of episodes should be limited and truncate as required + return result - user_turn = turns[turn_id] - sys_turn = turns[turn_id + 1] - # input: user_turn, output: sys_turn - yield { - 'text': user_turn['utterance'], - 'label': sys_turn['utterance'], - 'type': 'text', - }, is_first_turn + def get_id_task_prefix(self): + return "GoogleSGD" + + +class SystemTeacher(GoogleSGDParser, tod_agents.TodSystemTeacher): + pass + + +class DefaultTeacher(SystemTeacher): + pass + + +class UserSimulatorTeacher(GoogleSGDParser, tod_agents.TodUserSimulatorTeacher): + pass + + +class StandaloneApiTeacher(GoogleSGDParser, tod_agents.TodStandaloneApiTeacher): + pass + + +class SingleGoalAgent(GoogleSGDParser, tod_agents.TodSingleGoalAgent): + pass + + +class GoalAgent(GoogleSGDParser, tod_agents.TodGoalAgent): + pass + + +class ApiSchemaAgent(GoogleSGDParser, tod_agents.TodApiSchemaAgent): + pass + + +class UserUttAgent(GoogleSGDParser, tod_agents.TodUserUttAgent): + pass -class DefaultTeacher(Text2API2TextTeacher): +class ApiCallAndSysUttAgent(GoogleSGDParser, tod_agents.TodApiCallAndSysUttAgent): pass diff --git a/parlai/tasks/google_sgd/build.py b/parlai/tasks/google_sgd/build.py index d0432ed1241..d0bf47adce2 100644 --- a/parlai/tasks/google_sgd/build.py +++ b/parlai/tasks/google_sgd/build.py @@ -8,7 +8,7 @@ import parlai.core.build_data as build_data import os -ROOT_URL = 'https://github.com/google-research-datasets/dstc8-schema-guided-dialogue/raw/master' +ROOT_URL = "https://github.com/google-research-datasets/dstc8-schema-guided-dialogue/raw/master" DATA_LEN = {"train": 127, "dev": 20, "test": 34} @@ -18,13 +18,13 @@ def fold_size(fold): def build(opt): # get path to data directory - dpath = os.path.join(opt['datapath'], 'google_sgd') + dpath = os.path.join(opt["datapath"], "google_sgd") # define version if any version = "1.0" # check if data had been previously built if not build_data.built(dpath, version_string=version): - print('[building data: ' + dpath + ']') + print("[building data: " + dpath + "]") # make a clean directory if needed if build_data.built(dpath): @@ -33,16 +33,16 @@ def build(opt): build_data.make_dir(dpath) # Download the data. - for split_type in ['train', 'dev', 'test']: + for split_type in ["train", "dev", "test"]: outpath = os.path.join(dpath, split_type) - filename = 'schema.json' - url = f'{ROOT_URL}/{split_type}/{filename}' + filename = "schema.json" + url = f"{ROOT_URL}/{split_type}/{filename}" build_data.make_dir(outpath) build_data.download(url, outpath, filename) for file_id in range(1, DATA_LEN[split_type] + 1): - filename = f'dialogues_{file_id:03d}.json' - url = f'{ROOT_URL}/{split_type}/{filename}' + filename = f"dialogues_{file_id:03d}.json" + url = f"{ROOT_URL}/{split_type}/{filename}" build_data.download(url, outpath, filename) # mark the data as built diff --git a/parlai/tasks/google_sgd/test.py b/parlai/tasks/google_sgd/test.py index e348c087c61..5ea8078f74e 100644 --- a/parlai/tasks/google_sgd/test.py +++ b/parlai/tasks/google_sgd/test.py @@ -7,13 +7,9 @@ from parlai.utils.testing import AutoTeacherTest -class DisabledTestDefaultTeacher(AutoTeacherTest): +class TestDefaultTeacher(AutoTeacherTest): task = "google_sgd" -class DisabledTestText2API2TextTeacher(AutoTeacherTest): - task = "google_sgd:text2_a_p_i2_text" - - -class DisabledTestText2TextTeacher(AutoTeacherTest): - task = "google_sgd:text2_text" +class TestUserSimulatorTeacher(AutoTeacherTest): + task = "google_sgd:UserSimulatorTeacher" diff --git a/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_test.yml b/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_test.yml new file mode 100644 index 00000000000..dad04075246 --- /dev/null +++ b/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_test.yml @@ -0,0 +1,51 @@ +acts: +- - domain: Restaurants + episode_done: false + eval_labels: + - 'USER: Hi, could you get me a restaurant booking on the 8th please?' + id: GoogleSGD_UserSimulatorTeacher + text: 'GOAL: api_name = ReserveRestaurant ; date = 2019-03-08 ; location = Corte + Madera ; number_of_seats = 2 ; restaurant_name = P.f. Chang''s ; time = 12:00 + | api_name = ReserveRestaurant ; date = 2019-03-08 ; location = Corte Madera + ; number_of_seats = 2 ; restaurant_name = Benissimo Restaurant & Bar ; time + = 12:00' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'USER: Could you get me a reservation at P.f. Chang''s in Corte Madera at afternoon + 12?' + id: GoogleSGD_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Any preference on the restaurant, location and time?' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'USER: Sure, that is great.' + id: GoogleSGD_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Please confirm your reservation at P.f. Chang''s in Corte Madera + at 12 pm for 2 on March 8th.' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'USER: Could you try booking a table at Benissimo instead?' + id: GoogleSGD_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Sorry, your reservation could not be made. Could I help you with + something else?' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'USER: Sure, may I know if they have vegetarian options and how expensive is + their food?' + id: GoogleSGD_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Sure, please confirm your reservation at Benissimo Restaurant & + Bar in Corte Madera at 12 pm for 2 on March 8th.' + type: 'USER: ' +num_episodes: 4201 +num_examples: 97197 diff --git a/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_train.yml b/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_train.yml new file mode 100644 index 00000000000..fca91a3b09f --- /dev/null +++ b/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_train.yml @@ -0,0 +1,49 @@ +acts: +- - domain: Restaurants + episode_done: false + id: GoogleSGD_UserSimulatorTeacher + labels: + - 'USER: I am feeling hungry so I would like to find a place to eat.' + text: 'GOAL: api_name = FindRestaurants ; city = San Jose ; cuisine = American + | api_name = FindRestaurants ; city = Palo Alto ; cuisine = American ; price_range + = moderate | api_name = ReserveRestaurant ; city = Palo Alto ; date = 2019-03-01 + ; party_size = 2 ; restaurant_name = Bird Dog ; time = 11:30' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + id: GoogleSGD_UserSimulatorTeacher + labels: + - 'USER: I would like for it to be in San Jose.' + slots: {} + text: 'SYSTEM: Do you have a specific which you want the eating place to be located + at?' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + id: GoogleSGD_UserSimulatorTeacher + labels: + - 'USER: I usually like eating the American type of food.' + slots: {} + text: 'SYSTEM: Is there a specific cuisine type you enjoy, such as Mexican, Italian + or something else?' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + id: GoogleSGD_UserSimulatorTeacher + labels: + - 'USER: Can you give me the address of this restaurant.' + slots: {} + text: 'SYSTEM: I see that at 71 Saint Peter there is a good restaurant which is + in San Jose.' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + id: GoogleSGD_UserSimulatorTeacher + labels: + - 'USER: Can you give me the phone number that I can contact them with?' + slots: {} + text: 'SYSTEM: If you want to go to this restaurant you can find it at 71 North + San Pedro Street.' + type: 'USER: ' +num_episodes: 16142 +num_examples: 378390 diff --git a/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_valid.yml b/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_valid.yml new file mode 100644 index 00000000000..e9a511d9519 --- /dev/null +++ b/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_valid.yml @@ -0,0 +1,46 @@ +acts: +- - domain: Restaurants + episode_done: false + eval_labels: + - 'USER: I want to make a restaurant reservation for 2 people at half past 11 + in the morning.' + id: GoogleSGD_UserSimulatorTeacher + text: 'GOAL: api_name = ReserveRestaurant ; date = 2019-03-01 ; location = San + Jose ; number_of_seats = 2 ; restaurant_name = Sino ; time = 11:30' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'USER: Please find restaurants in San Jose. Can you try Sino?' + id: GoogleSGD_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: What city do you want to dine in? Do you have a preferred restaurant?' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'USER: Yes, thanks. What''s their phone number?' + id: GoogleSGD_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Confirming: I will reserve a table for 2 people at Sino in San + Jose. The reservation time is 11:30 am today.' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'USER: What''s their address? Do they have vegetarian options on their menu?' + id: GoogleSGD_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: Your reservation has been made. Their phone number is 408-247-8880.' + type: 'USER: ' +- - domain: Restaurants + episode_done: false + eval_labels: + - 'USER: Thanks very much.' + id: GoogleSGD_UserSimulatorTeacher + slots: {} + text: 'SYSTEM: The street address is 377 Santana Row #1000. They have good vegetarian + options.' + type: 'USER: ' +num_episodes: 2482 +num_examples: 56172 diff --git a/parlai/tasks/google_sgd/test/google_sgd_test.yml b/parlai/tasks/google_sgd/test/google_sgd_test.yml index 1d66bb9f0bb..4d8e4e2578e 100644 --- a/parlai/tasks/google_sgd/test/google_sgd_test.yml +++ b/parlai/tasks/google_sgd/test/google_sgd_test.yml @@ -1,77 +1,45 @@ acts: -- - episode_done: false +- - domain: Restaurants + episode_done: false eval_labels: - - Any preference on the restaurant, location and time? - id: google_sgd - slots: - location: [] - restaurant_name: [] - time: [] - text: 'Hi, could you get me a restaurant booking on the 8th please? api_resp: - time = [] ; restaurant_name = [] ; location = []' - type: apiresp -- - episode_done: false + - 'APIS: ' + id: GoogleSGD_SystemTeacher + slots: {} + text: 'APIS: ' + type: 'APIS: ' +- - domain: Restaurants + episode_done: false eval_labels: - - Please confirm your reservation at P.f. Chang's in Corte Madera at 12 pm for - 2 on March 8th. - id: google_sgd - slots: - date: - - '2019-03-08' - location: - - Corte Madera - number_of_seats: - - '2' - restaurant_name: - - P.f. Chang's - time: - - '12:00' - text: 'Could you get me a reservation at P.f. Chang''s in Corte Madera at afternoon - 12? api_resp: restaurant_name = ["P.f. Chang''s"] ; location = [''Corte Madera''] - ; time = [''12:00''] ; date = [''2019-03-08''] ; number_of_seats = [''2'']' - type: apiresp -- - episode_done: false + - 'APICALL: ' + id: GoogleSGD_SystemTeacher + slots: {} + text: 'USER: Hi, could you get me a restaurant booking on the 8th please?' + type: 'APICALL: ' +- - domain: Restaurants + episode_done: false eval_labels: - - 'apicall: ReserveRestaurant.date = 2019-03-08 ; ReserveRestaurant.location = - Corte Madera ; ReserveRestaurant.number_of_seats = 2 ; ReserveRestaurant.restaurant_name - = P.f. Chang''s ; ReserveRestaurant.time = 12:00' - id: google_sgd - slots: - ReserveRestaurant.date: '2019-03-08' - ReserveRestaurant.location: Corte Madera - ReserveRestaurant.number_of_seats: '2' - ReserveRestaurant.restaurant_name: P.f. Chang's - ReserveRestaurant.time: '12:00' - text: Sure, that is great. - type: apicall -- - episode_done: false + - 'SYSTEM: Any preference on the restaurant, location and time?' + id: GoogleSGD_SystemTeacher + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +- - domain: Restaurants + episode_done: false eval_labels: - - Sorry, your reservation could not be made. Could I help you with something else? - id: google_sgd - slots: - ? '' - : [] - text: 'api_resp: = []' - type: apiresp -- - episode_done: false + - 'APICALL: ' + id: GoogleSGD_SystemTeacher + slots: {} + text: 'USER: Could you get me a reservation at P.f. Chang''s in Corte Madera at + afternoon 12?' + type: 'APICALL: ' +- - domain: Restaurants + episode_done: false eval_labels: - - Sure, please confirm your reservation at Benissimo Restaurant & Bar in Corte - Madera at 12 pm for 2 on March 8th. - id: google_sgd - slots: - date: - - '2019-03-08' - location: - - Corte Madera - number_of_seats: - - '2' - restaurant_name: - - Benissimo Restaurant & Bar - time: - - '12:00' - text: 'Could you try booking a table at Benissimo instead? api_resp: restaurant_name - = [''Benissimo Restaurant & Bar''] ; location = [''Corte Madera''] ; time = - [''12:00''] ; date = [''2019-03-08''] ; number_of_seats = [''2'']' - type: apiresp + - 'SYSTEM: Please confirm your reservation at P.f. Chang''s in Corte Madera at + 12 pm for 2 on March 8th.' + id: GoogleSGD_SystemTeacher + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' num_episodes: 4201 -num_examples: 54970 +num_examples: 97197 diff --git a/parlai/tasks/google_sgd/test/google_sgd_train.yml b/parlai/tasks/google_sgd/test/google_sgd_train.yml index cd2b3a48e30..596025f82da 100644 --- a/parlai/tasks/google_sgd/test/google_sgd_train.yml +++ b/parlai/tasks/google_sgd/test/google_sgd_train.yml @@ -1,54 +1,45 @@ acts: -- - episode_done: false - id: google_sgd +- - domain: Restaurants + episode_done: false + id: GoogleSGD_SystemTeacher labels: - - Do you have a specific which you want the eating place to be located at? - slots: - city: [] - text: 'I am feeling hungry so I would like to find a place to eat. api_resp: city - = []' - type: apiresp -- - episode_done: false - id: google_sgd + - 'APIS: ' + slots: {} + text: 'APIS: ' + type: 'APIS: ' +- - domain: Restaurants + episode_done: false + id: GoogleSGD_SystemTeacher labels: - - Is there a specific cuisine type you enjoy, such as Mexican, Italian or something - else? - slots: - cuisine: - - Mexican - - Italian - text: 'I would like for it to be in San Jose. api_resp: cuisine = [''Mexican'', - ''Italian'']' - type: apiresp -- - episode_done: false - id: google_sgd + - 'APICALL: ' + slots: {} + text: 'USER: I am feeling hungry so I would like to find a place to eat.' + type: 'APICALL: ' +- - domain: Restaurants + episode_done: false + id: GoogleSGD_SystemTeacher labels: - - 'apicall: FindRestaurants.city = San Jose ; FindRestaurants.cuisine = American' - slots: - FindRestaurants.city: San Jose - FindRestaurants.cuisine: American - text: I usually like eating the American type of food. - type: apicall -- - episode_done: false - id: google_sgd + - 'SYSTEM: Do you have a specific which you want the eating place to be located + at?' + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +- - domain: Restaurants + episode_done: false + id: GoogleSGD_SystemTeacher labels: - - I see that at 71 Saint Peter there is a good restaurant which is in San Jose. - slots: - city: - - San Jose - restaurant_name: - - 71 Saint Peter - text: 'api_resp: restaurant_name = [''71 Saint Peter''] ; city = [''San Jose'']' - type: apiresp -- - episode_done: false - id: google_sgd + - 'APICALL: ' + slots: {} + text: 'USER: I would like for it to be in San Jose.' + type: 'APICALL: ' +- - domain: Restaurants + episode_done: false + id: GoogleSGD_SystemTeacher labels: - - If you want to go to this restaurant you can find it at 71 North San Pedro Street. - slots: - street_address: - - 71 North San Pedro Street - text: 'Can you give me the address of this restaurant. api_resp: street_address - = [''71 North San Pedro Street'']' - type: apiresp + - 'SYSTEM: Is there a specific cuisine type you enjoy, such as Mexican, Italian + or something else?' + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' num_episodes: 16142 -num_examples: 215128 +num_examples: 378390 diff --git a/parlai/tasks/google_sgd/test/google_sgd_valid.yml b/parlai/tasks/google_sgd/test/google_sgd_valid.yml index 8c95cc85c06..2cd09f89ce7 100644 --- a/parlai/tasks/google_sgd/test/google_sgd_valid.yml +++ b/parlai/tasks/google_sgd/test/google_sgd_valid.yml @@ -1,70 +1,45 @@ acts: -- - episode_done: false +- - domain: Restaurants + episode_done: false eval_labels: - - What city do you want to dine in? Do you have a preferred restaurant? - id: google_sgd - slots: - location: [] - restaurant_name: [] - text: 'I want to make a restaurant reservation for 2 people at half past 11 in - the morning. api_resp: restaurant_name = [] ; location = []' - type: apiresp -- - episode_done: false + - 'APIS: ' + id: GoogleSGD_SystemTeacher + slots: {} + text: 'APIS: ' + type: 'APIS: ' +- - domain: Restaurants + episode_done: false eval_labels: - - 'Confirming: I will reserve a table for 2 people at Sino in San Jose. The reservation - time is 11:30 am today.' - id: google_sgd - slots: - date: - - '2019-03-01' - location: - - San Jose - number_of_seats: - - '2' - restaurant_name: - - Sino - time: - - '11:30' - text: 'Please find restaurants in San Jose. Can you try Sino? api_resp: restaurant_name - = [''Sino''] ; location = [''San Jose''] ; time = [''11:30''] ; number_of_seats - = [''2''] ; date = [''2019-03-01'']' - type: apiresp -- - episode_done: false + - 'APICALL: ' + id: GoogleSGD_SystemTeacher + slots: {} + text: 'USER: I want to make a restaurant reservation for 2 people at half past + 11 in the morning.' + type: 'APICALL: ' +- - domain: Restaurants + episode_done: false eval_labels: - - 'apicall: ReserveRestaurant.date = 2019-03-01 ; ReserveRestaurant.location = - San Jose ; ReserveRestaurant.number_of_seats = 2 ; ReserveRestaurant.restaurant_name - = Sino ; ReserveRestaurant.time = 11:30' - id: google_sgd - slots: - ReserveRestaurant.date: '2019-03-01' - ReserveRestaurant.location: San Jose - ReserveRestaurant.number_of_seats: '2' - ReserveRestaurant.restaurant_name: Sino - ReserveRestaurant.time: '11:30' - text: Yes, thanks. What's their phone number? - type: apicall -- - episode_done: false + - 'SYSTEM: What city do you want to dine in? Do you have a preferred restaurant?' + id: GoogleSGD_SystemTeacher + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' +- - domain: Restaurants + episode_done: false eval_labels: - - Your reservation has been made. Their phone number is 408-247-8880. - id: google_sgd - slots: - ? '' - : [] - phone_number: - - 408-247-8880 - text: 'api_resp: phone_number = [''408-247-8880''] ; = []' - type: apiresp -- - episode_done: false + - 'APICALL: ' + id: GoogleSGD_SystemTeacher + slots: {} + text: 'USER: Please find restaurants in San Jose. Can you try Sino?' + type: 'APICALL: ' +- - domain: Restaurants + episode_done: false eval_labels: - - 'The street address is 377 Santana Row #1000. They have good vegetarian options.' - id: google_sgd - slots: - address: - - '377 Santana Row #1000' - has_vegetarian_options: - - 'True' - text: 'What''s their address? Do they have vegetarian options on their menu? api_resp: - has_vegetarian_options = [''True''] ; address = [''377 Santana Row #1000'']' - type: apiresp + - 'SYSTEM: Confirming: I will reserve a table for 2 people at Sino in San Jose. + The reservation time is 11:30 am today.' + id: GoogleSGD_SystemTeacher + slots: {} + text: 'APIRESP: ' + type: 'SYSTEM: ' num_episodes: 2482 -num_examples: 31825 +num_examples: 56172 From 3675781fb5c7f5f9adba8c84f997a063f6c123cb Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 14:55:47 -0800 Subject: [PATCH 13/23] use same version of black as in the pre-commit hook --- parlai/core/tod/tod_agents.py | 11 ++-------- parlai/core/tod/tod_test_utils/test_agents.py | 9 +-------- tests/tod/test_tod_agents_and_teachers.py | 20 ++++--------------- tests/tod/test_tod_teacher_metrics.py | 5 +---- 4 files changed, 8 insertions(+), 37 deletions(-) diff --git a/parlai/core/tod/tod_agents.py b/parlai/core/tod/tod_agents.py index f22a8330760..ee06dd6766a 100644 --- a/parlai/core/tod/tod_agents.py +++ b/parlai/core/tod/tod_agents.py @@ -542,11 +542,7 @@ def _do_fetch(self, call_text): resp = self.data[call_text] else: resp = self.data.get(call_text, tod.STANDARD_RESP) - return { - "text": resp, - "id": self.id, - "episode_done": False, - } + return {"text": resp, "id": self.id, "episode_done": False} # Not exact case best_key = difflib.get_close_matches(call_text, self.data.keys(), 1) @@ -674,10 +670,7 @@ def custom_evaluation( [teacher_action["domain"]] if self.opt["domain_nlg_record"] else [] ) metrics = NlgMetrics( - guess=resp, - labels=labels, - prefixes=domains, - avg_jga_nlg_bleu=True, + guess=resp, labels=labels, prefixes=domains, avg_jga_nlg_bleu=True ).report() for key, value in metrics.items(): self.metrics.add(key, value) diff --git a/parlai/core/tod/tod_test_utils/test_agents.py b/parlai/core/tod/tod_test_utils/test_agents.py index b7639efea36..8d749fa40c7 100644 --- a/parlai/core/tod/tod_test_utils/test_agents.py +++ b/parlai/core/tod/tod_test_utils/test_agents.py @@ -91,14 +91,7 @@ def get_round_utts(episode_idx, max_rounds, filter_utts=None): f"SYSTEM: sys_utt_{episode_idx}_{i}", ] ) - utts.append( - [ - "USER: [DONE]", - "APICALL: ", - "APIRESP: ", - "SYSTEM: ", - ] - ) + utts.append(["USER: [DONE]", "APICALL: ", "APIRESP: ", "SYSTEM: "]) if filter_utts is not None: utts = [ [turn for i, turn in enumerate(round_data) if filter_utts[i]] diff --git a/tests/tod/test_tod_agents_and_teachers.py b/tests/tod/test_tod_agents_and_teachers.py index 5383c72416d..2914cb927b3 100644 --- a/tests/tod/test_tod_agents_and_teachers.py +++ b/tests/tod/test_tod_agents_and_teachers.py @@ -247,10 +247,7 @@ def helper(n_shot): values = self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, - { - "episodes_randomization_seed": 0, - "n_shot": n_shot, - }, + {"episodes_randomization_seed": 0, "n_shot": n_shot}, ) self.assertEqual(len(values), n_shot) @@ -269,10 +266,7 @@ def helper(n_shot, seed): return self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE, - { - "episodes_randomization_seed": seed, - "n_shot": n_shot, - }, + {"episodes_randomization_seed": seed, "n_shot": n_shot}, ) data_dumps_seed_zero = [helper(i, 0) for i in self.FEW_SHOT_SAMPLES] @@ -286,10 +280,7 @@ def helper(percent_shot, correct): values = self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, # 35 episodes - { - "episodes_randomization_seed": 0, - "percent_shot": percent_shot, - }, + {"episodes_randomization_seed": 0, "percent_shot": percent_shot}, ) self.assertEqual(len(values), correct) @@ -302,10 +293,7 @@ def helper(percent_shot, seed): return self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__MULTI_EPISODE_BS, # 35 episodes - { - "episodes_randomization_seed": seed, - "percent_shot": percent_shot, - }, + {"episodes_randomization_seed": seed, "percent_shot": percent_shot}, ) data_dumps_seed_zero = [helper(i, 0) for i in self.PERCENTAGES] diff --git a/tests/tod/test_tod_teacher_metrics.py b/tests/tod/test_tod_teacher_metrics.py index aec4aba40f6..419210fa18a 100644 --- a/tests/tod/test_tod_teacher_metrics.py +++ b/tests/tod/test_tod_teacher_metrics.py @@ -62,10 +62,7 @@ def test_base_slot_metrics(self): ), ] for teacher, predicted, result in cases: - metric = SlotMetrics( - teacher_slots=teacher, - predicted_slots=predicted, - ) + metric = SlotMetrics(teacher_slots=teacher, predicted_slots=predicted) for key in result: self.assertEqual(result[key], metric.report()[key]) From 0bc961e66e56250703b27cccb8e9c5f2e6ac2528 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 14:56:22 -0800 Subject: [PATCH 14/23] use same version of black as in the pre-commit hook --- parlai/core/tod/world_metrics.py | 7 ++----- parlai/core/tod/world_metrics_handlers.py | 9 ++------- parlai/scripts/distributed_tod_world_script.py | 4 +--- parlai/scripts/tod_world_script.py | 8 ++------ tests/tod/test_tod_world_metrics.py | 8 ++------ tests/tod/test_tod_world_metrics_in_script.py | 11 +++++------ tests/tod/test_tod_world_script_metrics.py | 4 +--- 7 files changed, 15 insertions(+), 36 deletions(-) diff --git a/parlai/core/tod/world_metrics.py b/parlai/core/tod/world_metrics.py index 4e8ba4555e4..17d96b8698f 100644 --- a/parlai/core/tod/world_metrics.py +++ b/parlai/core/tod/world_metrics.py @@ -10,9 +10,7 @@ """ from parlai.core.message import Message -from parlai.core.metrics import ( - Metrics, -) +from parlai.core.metrics import Metrics from parlai.core.tod.tod_core import ( TodAgentType, TOD_AGENT_TYPE_TO_PREFIX, @@ -73,8 +71,7 @@ def _handle_message_impl( ) if agent_type is TodAgentType.API_SCHEMA_GROUNDING_AGENT: return handler.handle_api_schemas( - message, - SerializationHelpers.str_to_api_schemas(prefix_stripped_text), + message, SerializationHelpers.str_to_api_schemas(prefix_stripped_text) ) if agent_type is TodAgentType.GOAL_GROUNDING_AGENT: return handler.handle_goals( diff --git a/parlai/core/tod/world_metrics_handlers.py b/parlai/core/tod/world_metrics_handlers.py index 3f4b68477fe..ffa7343f114 100644 --- a/parlai/core/tod/world_metrics_handlers.py +++ b/parlai/core/tod/world_metrics_handlers.py @@ -10,13 +10,8 @@ """ from parlai.core.message import Message -from parlai.core.metrics import ( - Metric, - AverageMetric, -) -from parlai.core.tod.tod_core import ( - STANDARD_DONE, -) +from parlai.core.metrics import Metric, AverageMetric +from parlai.core.tod.tod_core import STANDARD_DONE from typing import Dict, List, Optional METRICS_HANDLER_CLASSES_TEST_REGISTRY = set() # for tests diff --git a/parlai/scripts/distributed_tod_world_script.py b/parlai/scripts/distributed_tod_world_script.py index 87c8333ec41..8dc36a6feb7 100644 --- a/parlai/scripts/distributed_tod_world_script.py +++ b/parlai/scripts/distributed_tod_world_script.py @@ -9,9 +9,7 @@ Not to be called directly; should be called from SLURM """ -from parlai.scripts.tod_world_script import ( - TodWorldScript, -) +from parlai.scripts.tod_world_script import TodWorldScript from parlai.core.script import ParlaiScript import parlai.utils.distributed as distributed_utils diff --git a/parlai/scripts/tod_world_script.py b/parlai/scripts/tod_world_script.py index 42199bc4ff3..c360882e61c 100644 --- a/parlai/scripts/tod_world_script.py +++ b/parlai/scripts/tod_world_script.py @@ -168,9 +168,7 @@ def setup_tod_args(cls, parser: ParlaiParser): def setup_args(cls): # Use default parlai args for logging + the like, but don't need model args since we specify those manually via command-line parser = TodWorldParser( - True, - False, - "World for chatting with the TOD conversation structure", + True, False, "World for chatting with the TOD conversation structure" ) # Following params are same as the `eval_model` script parser.add_argument( @@ -261,9 +259,7 @@ def _get_tod_agents(self, opt: Opt): ).replace("Goal", "ApiSchema") agents[tod_world.API_SCHEMA_GROUNDING_IDX] = self._get_model_or_default_agent( - opt, - "api_schema_grounding_model", - tod_world_agents.EmptyApiSchemaAgent, + opt, "api_schema_grounding_model", tod_world_agents.EmptyApiSchemaAgent ) return agents diff --git a/tests/tod/test_tod_world_metrics.py b/tests/tod/test_tod_world_metrics.py index 0367f655994..1399050b3ca 100644 --- a/tests/tod/test_tod_world_metrics.py +++ b/tests/tod/test_tod_world_metrics.py @@ -19,12 +19,8 @@ TodAgentType, TOD_AGENT_TYPE_TO_PREFIX, ) -from parlai.core.tod.world_metrics import ( - TodMetrics, -) -from parlai.core.tod.world_metrics_handlers import ( - METRICS_HANDLER_CLASSES_TEST_REGISTRY, -) +from parlai.core.tod.world_metrics import TodMetrics +from parlai.core.tod.world_metrics_handlers import METRICS_HANDLER_CLASSES_TEST_REGISTRY # Ignore lint on following line; want to have registered classes show up for tests import projects.tod_simulator.world_metrics.extended_world_metrics # noqa: F401 diff --git a/tests/tod/test_tod_world_metrics_in_script.py b/tests/tod/test_tod_world_metrics_in_script.py index ffac2a25e12..459458af9d2 100644 --- a/tests/tod/test_tod_world_metrics_in_script.py +++ b/tests/tod/test_tod_world_metrics_in_script.py @@ -15,9 +15,7 @@ from parlai.core.opt import Opt from parlai.core.tod.tod_core import SerializationHelpers import parlai.core.tod.tod_test_utils.test_agents as test_agents -from parlai.core.tod.world_metrics_handlers import ( - METRICS_HANDLER_CLASSES_TEST_REGISTRY, -) +from parlai.core.tod.world_metrics_handlers import METRICS_HANDLER_CLASSES_TEST_REGISTRY import parlai.scripts.tod_world_script as tod_world_script # Ignore lint on following line; want to have registered classes show up for tests @@ -233,9 +231,10 @@ def get_episode_report(goal, episode_metric): metrics_dict["goal"] = goal return metrics_dict - return dict_report(script.world.report()), [ - get_episode_report(g, e) for g, e in script.episode_metrics - ] + return ( + dict_report(script.world.report()), + [get_episode_report(g, e) for g, e in script.episode_metrics], + ) def test_apiCallAttempts_usingGold(self): opt = copy.deepcopy(TEST_SETUP) diff --git a/tests/tod/test_tod_world_script_metrics.py b/tests/tod/test_tod_world_script_metrics.py index 9862b3c824f..44f060acb84 100644 --- a/tests/tod/test_tod_world_script_metrics.py +++ b/tests/tod/test_tod_world_script_metrics.py @@ -14,9 +14,7 @@ import parlai.core.tod.tod_test_utils.test_agents as test_agents import parlai.scripts.tod_world_script as tod_world_script from parlai.core.tod.tod_agents import StandaloneApiAgent -from parlai.core.tod.world_metrics_handlers import ( - METRICS_HANDLER_CLASSES_TEST_REGISTRY, -) +from parlai.core.tod.world_metrics_handlers import METRICS_HANDLER_CLASSES_TEST_REGISTRY from parlai.core.metrics import dict_report # Ignore lint on following line; want to have registered classes show up for tests From 24ee8984e6b655b52c849579e28efa7f498afee4 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 15:06:19 -0800 Subject: [PATCH 15/23] black with version from pre-commit hook --- parlai/tasks/tod_json/agents.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/parlai/tasks/tod_json/agents.py b/parlai/tasks/tod_json/agents.py index 55e3514de6c..3e994befacf 100644 --- a/parlai/tasks/tod_json/agents.py +++ b/parlai/tasks/tod_json/agents.py @@ -73,10 +73,7 @@ def add_cmdline_args( help="Use all data or split into 8:1:1 fold", ) agent.add_argument( - "--split-folds-seed", - type=int, - default=42, - help="Seed for the fold split", + "--split-folds-seed", type=int, default=42, help="Seed for the fold split" ) return parser From 3145e0edcc09123b8e757fc1bd150f31bc039fd2 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 16 Nov 2021 15:11:13 -0800 Subject: [PATCH 16/23] Shouldn't worry about tod_json being in task_list --- tests/test_zootasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_zootasks.py b/tests/test_zootasks.py index e5dc7023560..57ac70d90b5 100644 --- a/tests/test_zootasks.py +++ b/tests/test_zootasks.py @@ -76,7 +76,7 @@ def test_tasklist(self): task_list, "parlai/tasks", "task", - ignore=['fromfile', 'interactive', 'jsonfile', 'wrapper'], + ignore=['fromfile', 'interactive', 'jsonfile', 'tod_json', 'wrapper'], ) @pytest.mark.nofbcode From 2f1544806dbf6daf2ea60c0bf196352cce3e3e4c Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 30 Nov 2021 08:45:59 -0800 Subject: [PATCH 17/23] address eric comments; add new readme + more documentation --- parlai/core/tod/README.md | 81 +++++++++++++++++++++++ parlai/core/tod/teacher_metrics.py | 11 --- parlai/core/tod/tod_agents.py | 55 ++------------- tests/tod/test_tod_agents_and_teachers.py | 65 +++++++++++++++++- tests/tod/test_tod_teacher_metrics.py | 4 ++ 5 files changed, 154 insertions(+), 62 deletions(-) create mode 100644 parlai/core/tod/README.md diff --git a/parlai/core/tod/README.md b/parlai/core/tod/README.md new file mode 100644 index 00000000000..d2ffbe33570 --- /dev/null +++ b/parlai/core/tod/README.md @@ -0,0 +1,81 @@ +# Tod Core README + +For the quickest getting-to-use of core classes, start with the "Teachers + Agents Usage" section below (for understanding how to setup agents such that they work with new datasets) and `parlai/scripts/tod_world_script.py` (for understanding how to run simulations with the TOD conversations format). + +See `projects/tod_simulator/README` for a higher-level usage-focused README. This document also describes the structure of the contents of this directory. + +As a convention, files referenced externally to this directory are prefixed with `tod` whereas those only referenced by other files within the directory are not. + +--- + +# Teachers + Agents Usage + +See `tod_agents.py` for the classes. + +For a given dataset, extend `TodStructuredDataParser` and implement `generate_episodes()` and `get_id_task_prefix()`. The former of these is expected to do the data processing to convert a dataset to `List[TodStructuredEpisode]`. From here, multiple inheritance can be used to define Agents and Teachers that utilize the data. + +For example, given a `class XX_DataParser(TodStructuredDataParser)`, `class XX_UserSimulatorTeacher(XX_DataParser, TodUserSimulatorTeacher)` would be how one would define a teacher that generates training data for a User Simulator model. + +Once the relevant agents have been created (or relevant models have been fine-tuned), see `parlai.scripts.tod_world_script` for generating the simulations themselves. + +## Why we do this +These files aid in consistency between Teachers and Agents for simulation. Rather than having to align multiple different agents to be consistent about assuptions about data formatting, tokens, spacing, etc, we do this once (via converting everything to `TodStructuredEpisode`) and let the code handle the rest. + +# Description of Agents + Teachers useful for Simulation +## Teachers for training (generative) models + * TodSystemTeacher + * TodUserSimulatorTeacher + +## Agents for Grounding +For goal grounding for the User for simulation: + * TodGoalAgent + * Dumps goals as is from the dataset, possibly multiple per episode + * TodSingleGoalAgent + * Flattens goals such that a single one is used to seed a conversation. For datasets that include multiple goals per conversation, each individual goal is used as a seed. + +For (optional) API schema grounding for the System: + * TodApiSchemaAgent (must be used with `TodGoalAgent` only) + * TodSingleApiSchemaAgent (must be used with `TodSingleGoalAgent` only) + * EmptyApiSchemaAgent + * Used for simulations where the expectation is `no schema`, ie, evaluation simulations. + +## Agents for mocking APIs: + * StandaloneApiAgent + * Assumed to be provided a .pickle file 'trained' by `TodStandaloneApiTeacher`. (See comments in-line on classes for train command example) + +# Agents for dumping data from a ground truth dataset +The following are for extracting TOD World metrics from a ground truth dataset. These are generally used sparingly and only for calculating baselines. + * TodApiCallAndSysUttAgent + * TodApiResponseAgent + * TodUserUttAgent + +For this metrics extraction, `TodGoalAgent` and `TodApiSchemaAgent` should be used. + +# Other agents +There is a `EmptyGoalAgent` for use in human-human conversations where a goal is unnecessary. + +--- + +# Directory contents + +This directory is split into 3 main components: files to support agents + teachers, files to support the simulation world, and files to store functionality common to both of these. We describe the common functionality first then go to the other two. + +Tests for all files in this directory are stored in `tests/tod` + +## Files for common functionality +`tod_core.py` defines consts and enums used across TOD agents, teachers, and world. It also defines dataclasses for storing the intermediate data format used when parsing a dataset to the TOD structure as well as a `SerializationHelper` from going from machine structured data (ex. API Calls) to flattened versions used by the models. + + +## Files for agents and teachers +Usage of `tod_agents.py` is described above. It references `teacher_metrics.py` which stores Metrics objects. + +## Files for simulation world +Description of usage of the simulation world is primarily stored in the script running the world, stored in `parlai/scripts/tod_world_script.py`. The script is responsible for running multiple episodes of simulation and saving simulation output data. + +The world itself is stored in `tod_world.py`. The world follows the same intermediate dataformats for episodes as described in `tod_core.py` and does the correct calling of different agents to support this. It is generally recommended that this file not be touched. + +A general class for collecting metrics out of `TODWorld` is stored within `world_metrics.py` with individual 'metric handlers' responsible for calculating a given metric stored in `world_metric_handlers.py`. + + + + diff --git a/parlai/core/tod/teacher_metrics.py b/parlai/core/tod/teacher_metrics.py index 3fc85c7d107..83c140a3ea5 100644 --- a/parlai/core/tod/teacher_metrics.py +++ b/parlai/core/tod/teacher_metrics.py @@ -18,10 +18,6 @@ class SlotMetrics(Metrics): Due to differences in dialogue representations between tasks, the input is pre- parsed ground truth and predicted slot dictionaries. - - The 'jga+nlg' metric assumes a balanced set of JGA and NLG scores such that - 2 * Avg(JGA, NLG_BLEU) = Avg(JGA + NLG_BLEU) - The `jga+nlg` metric assumes that `NlgMetrics` is used to calculated the other side. """ def __init__( @@ -30,7 +26,6 @@ def __init__( predicted_slots: Dict[str, str], prefixes: Optional[List] = None, shared: Dict[str, Any] = None, - avg_jga_nlg_bleu: bool = False, ) -> None: super().__init__(shared=shared) self.prefixes = prefixes if prefixes else [] @@ -45,9 +40,6 @@ def __init__( "jga_empty", AverageMetric(teacher_slots == predicted_slots) ) - if avg_jga_nlg_bleu: - # add one half of Avg(jga,nlg_bleu), NlgMetrics class (below) adds NLG-BLEU - self.add("jga+nlg", AverageMetric(teacher_slots == predicted_slots)) # precision for pred_slot_name, pred_value in predicted_slots.items(): slot_p = AverageMetric(teacher_slots.get(pred_slot_name) == pred_value) @@ -86,9 +78,6 @@ def __init__( f1 = F1Metric.compute(guess, labels) self.add_with_prefixes("nlg_bleu", bleu) self.add_with_prefixes("nlg_f1", f1) - if avg_jga_nlg_bleu: - # add one half of Avg(jga,nlg_bleu), SlotMetrics class (above) adds JGA - self.add("jga+nlg", bleu) def add_with_prefixes(self, name, value): self.add(name, value) diff --git a/parlai/core/tod/tod_agents.py b/parlai/core/tod/tod_agents.py index ee06dd6766a..80d9012a1f2 100644 --- a/parlai/core/tod/tod_agents.py +++ b/parlai/core/tod/tod_agents.py @@ -7,49 +7,9 @@ Agents (used for dumping data) and Teachers (for training models) related to the TOD conversation setup. -# Usage - -For a given dataset, extend `TodStructuredDataParser` and implement `generate_episodes()` and `get_id_task_prefix()`. The former of these is expected to do the data processing to convert a dataset to `List[TodStructuredEpisode]`. From here, multiple inheritance can be used to define Agents and Teachers that utilize the data. - -For example, given a `class XX_DataParser(TodStructuredDataParser)`, `class XX_UserSimulatorTeacher(XX_DataParser, TodUserSimulatorTeacher)` would be how one would define a teacher that generates training data for a User Simulator model. - -Once the relevant agents have been created (or relevant models have been fine-tuned), see `parlai.scripts.tod_world_script` for usage in generating simulations. - -As a convention, agents and teachers that are inheritable are prefixed with "Tod" whereas those that can be used as-is are not. Similarly, classes and functions that do not need to be exposed outside of this file are prefixed with a single underscore ('_'). - -## Why we do this -These files aid in consistency between Teachers and Agents for simulation. Rather than having to align multiple different agents to be consistent about assuptions about data formatting, tokens, spacing, etc, we do this once (via converting everything to `TodStructuredEpisode`) and let the code handle the rest. - -# Description of Agents + Teachers useful for Simulation -## Teachers for training (generative) models - * TodSystemTeacher - * TodUserSimulatorTeacher - -## Agents for Grounding -For goal grounding for the User for simulation: - * TodGoalAgent - * TodSingleGoalAgent - -For (optional) API schema grounding for the System: - * TodApiSchemaAgent (must be used with `TodGoalAgent` only) - * TodSingleApiSchemaAgent (must be used with `TodSingleGoalAgent` only) - * EmptyApiSchemaAgent - * Used for simulations where the expectation is `no schema`, ie, evaluation simulations. - -## Agents for mocking APIs: - * StandaloneApiAgent - * Assumed to be provided a .pickle file 'trained' by `TodStandaloneApiTeacher` - -# Agents for dumping data from a ground truth dataset -The following are for extracting TOD World metrics from a ground truth dataset. These are generally used sparingly and only for calculating baselines. - * TodApiCallAndSysUttAgent - * TodApiResponseAgent - * TodUserUttAgent - -For this metrics extraction, `TodGoalAgent` and `TodApiSchemaAgent` should be used. - -# Other agents -There is a `EmptyGoalAgent` for use in human-human conversations where a goal is unnecessary. +As a convention, agents and teachers that are inheritable are prefixed with "Tod" +whereas those that can be used as-is are not. Similarly, classes and functions that do +not need to be exposed outside of this file are prefixed with a single underscore ('_') """ from parlai.core.agents import Agent @@ -451,7 +411,7 @@ class StandaloneApiAgent(Agent): Use `TodStandaloneApiTeacher` to train this class. For example for a MultiWoz V2.2 standalone API, use ``` parlai train -t multiwoz_v22:StandaloneApiTeacher -m - parlai_fb.agents.tod.agents:StandaloneApiAgent -eps 4 -mf output ``` to generate the + parlai.core.tod.tod_agents:StandaloneApiAgent -eps 4 -mf output ``` to generate the `.pickle` file to use. """ @@ -647,7 +607,6 @@ def custom_evaluation( metrics = SlotMetrics( teacher_slots=teacher_action["slots"], predicted_slots=predicted, - avg_jga_nlg_bleu=True, prefixes=domains, ).report() for key, value in metrics.items(): @@ -669,9 +628,7 @@ def custom_evaluation( domains = ( [teacher_action["domain"]] if self.opt["domain_nlg_record"] else [] ) - metrics = NlgMetrics( - guess=resp, labels=labels, prefixes=domains, avg_jga_nlg_bleu=True - ).report() + metrics = NlgMetrics(guess=resp, labels=labels, prefixes=domains).report() for key, value in metrics.items(): self.metrics.add(key, value) @@ -765,7 +722,7 @@ class TodStandaloneApiTeacher(TodStructuredDataParser, DialogTeacher): Set this as the teacher with `StandaloneApiAgent` as the agent. Ex for a MultiWoz V2.2 standalone API, use ``` parlai train -t multiwoz_v22:StandaloneApiTeacher -m - parlai_fb.agents.tod.agents:StandaloneApiAgent -eps 4 -mf output ``` + parlai.core.tod.tod_agents:StandaloneApiAgent -eps 4 -mf output ``` """ def setup_data(self, fold): diff --git a/tests/tod/test_tod_agents_and_teachers.py b/tests/tod/test_tod_agents_and_teachers.py index 2914cb927b3..991d165f278 100644 --- a/tests/tod/test_tod_agents_and_teachers.py +++ b/tests/tod/test_tod_agents_and_teachers.py @@ -5,7 +5,11 @@ # LICENSE file in the root directory of this source tree. """ -Tests different (more complicated) slot metrics. +Tests teachers + agent implementations, assuming parser to conversations format has +already been done and teachers/agents already created. + +`test_agents.py` includes functions for generating the raw data used in this file as +well as the data parser. """ import unittest @@ -16,6 +20,10 @@ class TestTodAgentsAndTeachersBase(unittest.TestCase): + """ + Base class with convenience functions for setting up agents, dumping text, etc. + """ + def setup_agent_or_teacher(self, class_type, round_opt, opt): full_opts = {**round_opt, **opt} full_opts["datatype"] = "DUMMY" @@ -24,6 +32,9 @@ def setup_agent_or_teacher(self, class_type, round_opt, opt): return class_type(full_opts) def dump_single_utt_per_episode_agent_text(self, class_type, round_opt, opt): + """ + Continuously dumps data from an agent until it's done. + """ agent = self.setup_agent_or_teacher(class_type, round_opt, opt) result = [] while not agent.epoch_done(): @@ -48,14 +59,30 @@ def dump_teacher_text(self, class_type, round_opt, opt): return data def _test_roundDataCorrect(self): + """ + Convenience function that runs on different episode setups. + + Prefix with `_` since not all tests necessarily need this + """ self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__UTTERANCES_ONLY) self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__SINGLE_API_CALL) self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__MULTI_ROUND) self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__MULTI_EPISODE) + def _test_roundDataCorrect_helper(self, config): + """ + Implement this in downstream classes to define what is "correct" for a round (Ie + checking serialization data for a given class vs only checking utterances) + """ + raise RuntimeError("Not implemented") + class TestSystemTeacher(TestTodAgentsAndTeachersBase): def test_apiSchemas_with_yesApiSchemas(self): + """ + Tests to make sure that data from first turn is correct when we include API + Schemas. + """ values = self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__SINGLE_API_CALL, @@ -70,6 +97,10 @@ def test_apiSchemas_with_yesApiSchemas(self): ) def test_apiSchemas_with_noApiSchemas(self): + """ + Tests to make sure that data from first turn is correct when we do not include + API Schemas. + """ values = self.dump_teacher_text( test_agents.SystemTeacher, test_agents.EPISODE_SETUP__SINGLE_API_CALL, @@ -86,7 +117,7 @@ def _test_roundDataCorrect_helper(self, config): for utt in utts: comp.append([utt[0], utt[1]]) comp.append([utt[2], utt[3]]) - # Skip context turn cause we check it above + # Skip grounding turn cause we check it in the other teachers self.assertEqual(episode[1:], comp) def test_roundDataCorrect(self): @@ -95,6 +126,10 @@ def test_roundDataCorrect(self): class TestUserTeacher(TestTodAgentsAndTeachersBase): def _test_roundDataCorrect_helper(self, config): + """ + Make sure that all of the User teacher data is correct relative to ground truth, + including grounding turn. + """ max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] values = self.dump_teacher_text(test_agents.UserSimulatorTeacher, config, {}) for episode_idx, episode in enumerate(values): @@ -121,6 +156,9 @@ def test_roundDataCorrect(self): class TestGoalAgent(TestTodAgentsAndTeachersBase): def _test_roundDataCorrect_helper(self, config): + """ + Make sure goal agent data is correct with (possibly) multiple goals. + """ max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] max_episodes = config[test_agents.TEST_NUM_EPISODES_OPT_KEY] values = self.dump_single_utt_per_episode_agent_text( @@ -143,6 +181,9 @@ def test_roundDataCorrect(self): class TestApiSchemaAgent(TestTodAgentsAndTeachersBase): def _test_roundDataCorrect_helper(self, config): + """ + Make sure api schema information is correct with (possibly) multiple goals. + """ max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] max_episodes = config[test_agents.TEST_NUM_EPISODES_OPT_KEY] values = self.dump_single_utt_per_episode_agent_text( @@ -165,6 +206,10 @@ def test_roundDataCorrect(self): class TestSingleGoalAgent(TestTodAgentsAndTeachersBase): def _test_roundDataCorrect_helper(self, config): + """ + Make sure single goal agent correctly splits conversations with multiple goals + into single goals for the agent. + """ max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] max_episodes = config[test_agents.TEST_NUM_EPISODES_OPT_KEY] values = self.dump_single_utt_per_episode_agent_text( @@ -187,6 +232,10 @@ def test_roundDataCorrect(self): class TestSingleApiSchemaAgent(TestTodAgentsAndTeachersBase): def _test_roundDataCorrect_helper(self, config): + """ + Make sure single api schema agent correctly splits conversations with multiple + goals into single goals for the agent. + """ max_rounds = config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] max_episodes = config[test_agents.TEST_NUM_EPISODES_OPT_KEY] values = self.dump_single_utt_per_episode_agent_text( @@ -262,6 +311,10 @@ def _test_subsets(self, data_dumps): self.assertEqual(episode, larger[i]) def test_few_shot_subset(self): + """ + Make sure specifying few-shot by n-shot works correctly. + """ + def helper(n_shot, seed): return self.dump_teacher_text( test_agents.SystemTeacher, @@ -276,6 +329,10 @@ def helper(n_shot, seed): self.assertNotEqual(data_dumps_seed_zero[-1], data_dumps_seed_three[-1]) def test_percent_shot_lengths_correct(self): + """ + Make sure specifying few-shot by percentages works correctly. + """ + def helper(percent_shot, correct): values = self.dump_teacher_text( test_agents.SystemTeacher, @@ -289,6 +346,10 @@ def helper(percent_shot, correct): helper(0.3, 10) def test_percent_shot_subset(self): + """ + Make sure specifying few-shot by percentages works correctly. + """ + def helper(percent_shot, seed): return self.dump_teacher_text( test_agents.SystemTeacher, diff --git a/tests/tod/test_tod_teacher_metrics.py b/tests/tod/test_tod_teacher_metrics.py index 419210fa18a..60bbb68c5d9 100644 --- a/tests/tod/test_tod_teacher_metrics.py +++ b/tests/tod/test_tod_teacher_metrics.py @@ -4,6 +4,10 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +""" +File that includes tests for teacher metrics. +""" + import unittest from math import isnan From 5d0197df5276aa7f6213e782328330ebb9ca97a1 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 30 Nov 2021 09:06:14 -0800 Subject: [PATCH 18/23] minor wording change --- parlai/core/tod/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/parlai/core/tod/README.md b/parlai/core/tod/README.md index d2ffbe33570..22e337f0b0f 100644 --- a/parlai/core/tod/README.md +++ b/parlai/core/tod/README.md @@ -1,6 +1,6 @@ -# Tod Core README +# Task-Oriented Dialog (TOD) Core README -For the quickest getting-to-use of core classes, start with the "Teachers + Agents Usage" section below (for understanding how to setup agents such that they work with new datasets) and `parlai/scripts/tod_world_script.py` (for understanding how to run simulations with the TOD conversations format). +For the quickest getting-to-use of TOD classes, start with the "Teachers + Agents Usage" section below (for understanding how to setup agents such that they work with new datasets) and `parlai/scripts/tod_world_script.py` (for understanding how to run simulations with the TOD conversations format). See `projects/tod_simulator/README` for a higher-level usage-focused README. This document also describes the structure of the contents of this directory. From 76bfa898f53badd0b072efe08ffc269a53ec545a Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 30 Nov 2021 09:34:13 -0800 Subject: [PATCH 19/23] add more documtnation to world tests (following comment on teacher tests) --- tests/tod/test_tod_world_and_script.py | 36 ++++++++++++++++--- tests/tod/test_tod_world_metrics.py | 7 +++- tests/tod/test_tod_world_metrics_in_script.py | 25 ++++++++++++- tests/tod/test_tod_world_script_metrics.py | 9 ++--- 4 files changed, 66 insertions(+), 11 deletions(-) diff --git a/tests/tod/test_tod_world_and_script.py b/tests/tod/test_tod_world_and_script.py index d3f764bc13b..2d8e8944024 100644 --- a/tests/tod/test_tod_world_and_script.py +++ b/tests/tod/test_tod_world_and_script.py @@ -5,7 +5,10 @@ # LICENSE file in the root directory of this source tree. """ -Tests tod world, notably for batching. +Tests tod world + script, notably for batching, by comparing saved script logs to the +data that should have been generated. + +Metrics are handled in separate files. """ import copy @@ -114,6 +117,10 @@ def _check_correctness_from_script_logs( class TodWorldSingleBatchTest(TodWorldInScriptTestBase): + """ + Checks that saved data is correct with a single batch. + """ + def _test_roundDataCorrect_helper(self, config): config["batchsize"] = 1 config["max_turns"] = 10 @@ -135,7 +142,7 @@ def _test_max_turn_helper(self, max_turns): config["batchsize"] = 1 config["max_turns"] = max_turns config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] = 10 - config[test_agents.TEST_NUM_EPISODES_OPT_KEY] = 2 # cause why not + config[test_agents.TEST_NUM_EPISODES_OPT_KEY] = 5 # cause why not agents, opt = self.setup_agents(config) script = TestTodWorldScript(opt) script.agents = agents @@ -150,6 +157,10 @@ def filter_round_utt(utts): class TodWorldNonSingleBatchTest(TodWorldInScriptTestBase): + """ + Checks saved data is correct with larger batchsizes. + """ + def _test_roundDataCorrect_helper(self, config): config["batchsize"] = 4 config["max_turns"] = 10 @@ -164,6 +175,13 @@ def test_roundDataCorrect(self): class TodWorldTestSingleDumpAgents(TodWorldInScriptTestBase): + """ + Just to be safe, make sure that the agents with "single" versions (ex goal + api + schema) are correctly aligned. + + (Already tested in the agents test file as well, but to be safe.) + """ + def setup_agents(self, added_opts, api_agent, goal_agent): full_opts = self.add_tod_world_opts(added_opts) full_opts["fixed_response"] = "USER: [DONE]" @@ -178,11 +196,11 @@ def setup_agents(self, added_opts, api_agent, goal_agent): ] return agents, full_opts - def test_SingleGoalApiResp_noBatching(self): + def _test_SingleGoalApiResp_helper(self, batchsize, num_episodes): config = {} - config["batchsize"] = 1 + config["batchsize"] = batchsize config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] = 10 - config[test_agents.TEST_NUM_EPISODES_OPT_KEY] = 2 # cause why not + config[test_agents.TEST_NUM_EPISODES_OPT_KEY] = num_episodes single_agents, opt = self.setup_agents( config, test_agents.SingleApiSchemaAgent, test_agents.SingleGoalAgent ) @@ -220,6 +238,14 @@ def test_SingleGoalApiResp_noBatching(self): single_idx += 1 + def test_SingleGoalApiResp_helper_singleBatch(self): + self._test_SingleGoalApiResp_helper(1, 2) + self._test_SingleGoalApiResp_helper(1, 5) + + def test_SingleGoalApiResp_helper_multiBatch(self): + self._test_SingleGoalApiResp_helper(4, 8) + self._test_SingleGoalApiResp_helper(4, 11) + if __name__ == "__main__": unittest.main() diff --git a/tests/tod/test_tod_world_metrics.py b/tests/tod/test_tod_world_metrics.py index 1399050b3ca..293c677834f 100644 --- a/tests/tod/test_tod_world_metrics.py +++ b/tests/tod/test_tod_world_metrics.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. """ -Tests different (more complicated) slot metrics. +Test world metrics + world metrics handlers against dummy conversations. """ import unittest @@ -83,6 +83,11 @@ class TodMetricsTestHelper: + """ + Given a synthetic intermediate converesation, calculates the metrics for said + conversation. + """ + def __init__(self, e: TodStructuredEpisode): self.m = TodMetrics() self.m.handlers = [ diff --git a/tests/tod/test_tod_world_metrics_in_script.py b/tests/tod/test_tod_world_metrics_in_script.py index 459458af9d2..19cc0cd20af 100644 --- a/tests/tod/test_tod_world_metrics_in_script.py +++ b/tests/tod/test_tod_world_metrics_in_script.py @@ -5,7 +5,11 @@ # LICENSE file in the root directory of this source tree. """ -sTests tod world, notably for batching. +Tests tod world metrics in the full script, *including* making the script properly set +up the agents on its own. + +Use a few of the API Call + goal hit metrics as the metric handlers to test proper +functionality. """ import copy @@ -85,6 +89,9 @@ def _save_outputs(self, opt, world, logger, episode_metrics): class TodMetricsInScriptTests(unittest.TestCase): def test_all_goals_hit_all_success(self): + """ + For a setup where all the goals should be successfully hit, is it? + """ self._check_all_goals_hit_by_opt_and_batchsize( TEST_SETUP, batchsize=1, num_episodes=1, target_all_goals_hit=1 ) @@ -106,6 +113,9 @@ def test_all_goals_hit_all_success(self): ) def test_all_goals_hit_all_fail(self): + """ + For a setup where all the goals should *not* be successfully hit, do they fail? + """ self._check_all_goals_hit_by_opt_and_batchsize( TEST_SETUP_BROKEN_USER_SYSTEM, batchsize=1, @@ -139,6 +149,12 @@ def test_all_goals_hit_all_fail(self): ) def test_all_goals_hit_all_success_emptySchema(self): + """ + Check to make sure empty API schema doesn't have any impact on goal (Necessary + cause original, more exhaustive implementation of goal success would separate + between required + optional opts using the schema; make sure it doesn't impact + anything broader) + """ self._check_all_goals_hit_by_opt_and_batchsize( TEST_SETUP_EMPTY_APISCHEMA, batchsize=1, @@ -172,6 +188,13 @@ def test_all_goals_hit_all_success_emptySchema(self): ) def test_all_goals_hit_all_fail_emptySchema(self): + """ + Make sure empty schema has no impact on goal success. + + (Necessary cause original, more exhaustive implementation of goal success would + separate between required + optional opts using the schema; make sure it doesn't + impact anything broader) + """ self._check_all_goals_hit_by_opt_and_batchsize( TEST_SETUP_BROKEN_USER_SYSTEM_EMPTY_APISCHEMA, batchsize=1, diff --git a/tests/tod/test_tod_world_script_metrics.py b/tests/tod/test_tod_world_script_metrics.py index 44f060acb84..5152a7bfa23 100644 --- a/tests/tod/test_tod_world_script_metrics.py +++ b/tests/tod/test_tod_world_script_metrics.py @@ -5,7 +5,11 @@ # LICENSE file in the root directory of this source tree. """ -Tests tod world, notably for batching. +Tests tod world metrics in the full script, *without* making the script properly set up +the agents on its own. + +Use a few of the API Call + goal hit metrics as the metric handlers to test proper +functionality. """ import copy @@ -104,9 +108,6 @@ def _check_metrics_correct(self, script, opt): max_episodes = opt[test_agents.TEST_NUM_EPISODES_OPT_KEY] episode_metrics = script.episode_metrics for episode_idx, episode in enumerate(episode_metrics): - # if episode_idx >= max_episodes: - # break - # See how we make broken mock api calls in the test_agents. goal, episode_metric = episode episode_metric = dict_report(episode_metric.report()) self.assertAlmostEqual( From 73c5c7a58f15eac6fe27320ed565c453e2d04e60 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 30 Nov 2021 09:45:00 -0800 Subject: [PATCH 20/23] minor comment update --- parlai/scripts/tod_world_script.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parlai/scripts/tod_world_script.py b/parlai/scripts/tod_world_script.py index c360882e61c..e2d50a1f445 100644 --- a/parlai/scripts/tod_world_script.py +++ b/parlai/scripts/tod_world_script.py @@ -6,7 +6,7 @@ """ Base script for running TOD model-model chats. -For example, to extract gold ground truth data from Google SGD, run +For example, to extract gold ground truth data from the holdout version of Google SGD, run ``` python -u -m parlai.scripts.tod_world_script --api-schema-grounding-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainApiSchemaAgent --goal-grounding-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainGoalAgent --user-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainUserUttAgent --system-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainApiCallAndSysUttAgent --api-resp-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainApiResponseAgent -dt valid --num-episodes -1 --episodes-randomization-seed 42 --world-logs gold-valid From 7ab9d70476be5831aca05322ab1d2954aac0f698 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Wed, 1 Dec 2021 12:23:53 -0800 Subject: [PATCH 21/23] update to respect actual count of episodes (I think this might have implications for mutators too?) --- parlai/core/teachers.py | 2 ++ parlai/core/tod/tod_agents.py | 11 +++++++++++ 2 files changed, 13 insertions(+) diff --git a/parlai/core/teachers.py b/parlai/core/teachers.py index 2b3c1d71741..1f8829324db 100644 --- a/parlai/core/teachers.py +++ b/parlai/core/teachers.py @@ -726,6 +726,8 @@ def num_episodes(self) -> int: """ Return the number of episodes in the data. """ + if hasattr(self, "_num_episodes_cache"): + return self._num_episodes_cache try: return self.data.num_episodes() except AttributeError: diff --git a/parlai/core/tod/tod_agents.py b/parlai/core/tod/tod_agents.py index 80d9012a1f2..0dd26cd270e 100644 --- a/parlai/core/tod/tod_agents.py +++ b/parlai/core/tod/tod_agents.py @@ -590,6 +590,11 @@ def add_cmdline_args( ) return parser + def __init__(self, opt, shared=None): + super().__init__(opt, shared) + self._num_examples_cache = sum([len(x.rounds) * 2 for x in self.episodes]) + self._num_episodes_cache = len(self.episodes) + def custom_evaluation( self, teacher_action: Message, labels, model_response: Message ): @@ -671,6 +676,12 @@ class TodUserSimulatorTeacher(TodStructuredDataParser, DialogTeacher): Utterance->User Utterance` for all subsequent turns. """ + def __init__(self, opt, shared=None): + super().__init__(opt, shared) + # Manually set number of examples + number of episodes + self._num_examples_cache = sum([len(x.rounds) for x in self.episodes]) + self._num_episodes_cache = len(self.episodes) + def setup_data(self, fold): for episode in self.generate_episodes(): if len(episode.rounds) < 1: From 94661448616a94dcfbe5f659eeaebbd7f6000c34 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Wed, 1 Dec 2021 22:30:53 -0800 Subject: [PATCH 22/23] regen after changing tod teacher logic to respect episode/examples length --- .../google_sgd/test/google_sgd_UserSimulatorTeacher_test.yml | 2 +- .../google_sgd/test/google_sgd_UserSimulatorTeacher_train.yml | 2 +- .../google_sgd/test/google_sgd_UserSimulatorTeacher_valid.yml | 2 +- parlai/tasks/google_sgd/test/google_sgd_test.yml | 2 +- parlai/tasks/google_sgd/test/google_sgd_train.yml | 2 +- parlai/tasks/google_sgd/test/google_sgd_valid.yml | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_test.yml b/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_test.yml index dad04075246..8ea369e9cad 100644 --- a/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_test.yml +++ b/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_test.yml @@ -48,4 +48,4 @@ acts: Bar in Corte Madera at 12 pm for 2 on March 8th.' type: 'USER: ' num_episodes: 4201 -num_examples: 97197 +num_examples: 46498 diff --git a/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_train.yml b/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_train.yml index fca91a3b09f..d4ea92d7d1b 100644 --- a/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_train.yml +++ b/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_train.yml @@ -46,4 +46,4 @@ acts: San Pedro Street.' type: 'USER: ' num_episodes: 16142 -num_examples: 378390 +num_examples: 181124 diff --git a/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_valid.yml b/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_valid.yml index e9a511d9519..59d81a949c9 100644 --- a/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_valid.yml +++ b/parlai/tasks/google_sgd/test/google_sgd_UserSimulatorTeacher_valid.yml @@ -43,4 +43,4 @@ acts: options.' type: 'USER: ' num_episodes: 2482 -num_examples: 56172 +num_examples: 26845 diff --git a/parlai/tasks/google_sgd/test/google_sgd_test.yml b/parlai/tasks/google_sgd/test/google_sgd_test.yml index 4d8e4e2578e..d55849a814b 100644 --- a/parlai/tasks/google_sgd/test/google_sgd_test.yml +++ b/parlai/tasks/google_sgd/test/google_sgd_test.yml @@ -42,4 +42,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 4201 -num_examples: 97197 +num_examples: 92996 diff --git a/parlai/tasks/google_sgd/test/google_sgd_train.yml b/parlai/tasks/google_sgd/test/google_sgd_train.yml index 596025f82da..584b6a56561 100644 --- a/parlai/tasks/google_sgd/test/google_sgd_train.yml +++ b/parlai/tasks/google_sgd/test/google_sgd_train.yml @@ -42,4 +42,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 16142 -num_examples: 378390 +num_examples: 362248 diff --git a/parlai/tasks/google_sgd/test/google_sgd_valid.yml b/parlai/tasks/google_sgd/test/google_sgd_valid.yml index 2cd09f89ce7..3d58a2b919d 100644 --- a/parlai/tasks/google_sgd/test/google_sgd_valid.yml +++ b/parlai/tasks/google_sgd/test/google_sgd_valid.yml @@ -42,4 +42,4 @@ acts: text: 'APIRESP: ' type: 'SYSTEM: ' num_episodes: 2482 -num_examples: 56172 +num_examples: 53690 From acd6ffe40aa8c1da18a310a43e81214ac4d48058 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Wed, 22 Dec 2021 09:49:14 -0800 Subject: [PATCH 23/23] not sure why this comment keeps not being merged correctly ugh... --- parlai/core/tod/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parlai/core/tod/README.md b/parlai/core/tod/README.md index 11abaa1ad92..f29e5d2778a 100644 --- a/parlai/core/tod/README.md +++ b/parlai/core/tod/README.md @@ -10,7 +10,7 @@ As a convention, files referenced externally to this directory are prefixed with # Teachers + Agents Usage -tl;dr Extend `TodStructuredDataParser` for your particular dataset and implement `generate_episodes()` that converts the dataset into a list of episodes (`List[TodStructuredEpisode]`). Use multiple inheritence to generate teachers for training models. See files like `parlai/tasks/multiwoz_v22/agents.py` for an example. +tl;dr Extend `TodStructuredDataParser` for your particular dataset and implement `setup_episodes()` that converts the dataset into a list of episodes (`List[TodStructuredEpisode]`). Use multiple inheritence to generate teachers for training models. See files like `parlai/tasks/multiwoz_v22/agents.py` for an example. See `tod_agents.py` for the classes.