This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[TOD][Datasets][Easy] MetalWoz into ParlAI (User + System utterances) #4183

Merged
merged 73 commits into from
Dec 23, 2021
Changes from 62 commits
e365e48
[TOD] Core converesation structure, serialization, const tokens
Nov 15, 2021
c939174
[Tod] Agents, teacher metrics, and tests for these
Nov 16, 2021
3bf655f
[TOD] Tod json structure to teacher task
Nov 16, 2021
6cb4b86
[TOD] Core converesation structure, serialization, const tokens
Nov 15, 2021
1480def
fix test by adding init folder
Nov 16, 2021
de84801
[Tod] Agents, teacher metrics, and tests for these
Nov 16, 2021
638eb28
[TOD] World, world metrics, script, tests
Nov 16, 2021
0e3f492
hmmm... hoping stacks don't bite me. (change that was kept in upper d…
Nov 16, 2021
0643a62
Merge branch 'simpler_tod_1_core_only' into simpler_tod_2_agents_teac…
Nov 16, 2021
37aced2
minor, remove commented out print
Nov 16, 2021
4f91279
Merge branch 'simpler_tod_2_agents_teachers' into simpler_tod_3_world
Nov 16, 2021
b05930f
comment
Nov 16, 2021
5086e85
more comment updates (not sure if it actually helps clarity..)
Nov 16, 2021
1e30035
Merge branch 'simpler_tod_3_world' into simpler_tod_4_tod_json
Nov 16, 2021
9a25fc5
[TOD][Dataset][Easy] Google SGD in TOD Conversations format
Nov 16, 2021
faa2356
[TOD][Dataset][Easyish] Google Simulation Splits
Nov 16, 2021
9426997
[TOD][Datasets][Easy] MetalWoz
Nov 16, 2021
51ed1a9
Merge branch 'main' into simpler_tod_1_core_only
Nov 16, 2021
a6508be
Merge branch 'simpler_tod_1_core_only' into simpler_tod_2_agents_teac…
Nov 16, 2021
eebc36b
Merge branch 'simpler_tod_2_agents_teachers' into simpler_tod_3_world
Nov 16, 2021
3675781
use same version of black as in the pre-commit hook
Nov 16, 2021
086c91c
Merge branch 'simpler_tod_2_agents_teachers' into simpler_tod_3_world
Nov 16, 2021
0bc961e
use same version of black as in the pre-commit hook
Nov 16, 2021
ed26407
Merge branch 'simpler_tod_3_world' into simpler_tod_4_tod_json
Nov 16, 2021
677df09
Merge branch 'simpler_tod_4_tod_json' into simpler_tod_5a_google_sgd
Nov 16, 2021
24ee898
black with version from pre-commit hook
Nov 16, 2021
3ca7ae3
Merge branch 'simpler_tod_4_tod_json' into simpler_tod_5a_google_sgd
Nov 16, 2021
3145e0e
Shouldn't worry about tod_json being in task_list
Nov 16, 2021
1b2a3fb
Merge branch 'simpler_tod_4_tod_json' into simpler_tod_5a_google_sgd
Nov 16, 2021
f44b17b
add to task list; run lint with right version of black
Nov 16, 2021
43474c4
Merge branch 'simpler_tod_5a_google_sgd' into simpler_tod_5b_google_s…
Nov 16, 2021
d290ecd
Merge branch 'simpler_tod_5b_google_sgd_sim_splits' into simpler_tod_…
Nov 17, 2021
7c3ccf5
lint with right version
Nov 17, 2021
dfc4989
Merge branch 'main' into simpler_tod_2_agents_teachers
Nov 29, 2021
2f15448
address eric comments; add new readme + more documentation
Nov 30, 2021
abd1c7e
Merge branch 'simpler_tod_2_agents_teachers' into simpler_tod_3_world
Nov 30, 2021
5d0197d
minor wording change
Nov 30, 2021
39792a8
Merge branch 'simpler_tod_2_agents_teachers' into simpler_tod_3_world
Nov 30, 2021
76bfa89
add more documtnation to world tests (following comment on teacher te…
Nov 30, 2021
73c5c7a
minor comment update
Nov 30, 2021
f6acccb
Merge branch 'simpler_tod_3_world' into simpler_tod_4_tod_json
Nov 30, 2021
dc4b70e
Merge branch 'simpler_tod_4_tod_json' into simpler_tod_5a_google_sgd
Nov 30, 2021
1299b68
Merge branch 'simpler_tod_5a_google_sgd' into simpler_tod_5b_google_s…
Nov 30, 2021
58965d3
Merge branch 'simpler_tod_4_tod_json' into simpler_tod_5c_metalwoz
Nov 30, 2021
55aa3ca
Merge branch 'simpler_tod_5b_google_sgd_sim_splits' into simpler_tod_…
Nov 30, 2021
7ab9d70
update to respect actual count of episodes (I think this might have i…
Dec 1, 2021
c6c728d
Merge branch 'main' into simpler_tod_2_agents_teachers
Dec 1, 2021
b3283d0
Merge branch 'simpler_tod_2_agents_teachers' into simpler_tod_3_world
Dec 1, 2021
85ab0fd
Merge branch 'simpler_tod_3_world' into simpler_tod_4_tod_json
Dec 1, 2021
0969aa1
Merge branch 'simpler_tod_4_tod_json' into simpler_tod_5a_google_sgd
Dec 1, 2021
1869cee
Merge branch 'simpler_tod_5a_google_sgd' into simpler_tod_5b_google_s…
Dec 1, 2021
609f930
Merge branch 'simpler_tod_5b_google_sgd_sim_splits' into simpler_tod_…
Dec 1, 2021
0580ff0
Merge branch 'main' into simpler_tod_2_agents_teachers
Dec 2, 2021
e00accf
Merge branch 'simpler_tod_2_agents_teachers' into simpler_tod_3_world
Dec 2, 2021
701da8d
Merge branch 'simpler_tod_3_world' into simpler_tod_4_tod_json
Dec 2, 2021
d519dc2
Merge branch 'simpler_tod_4_tod_json' into simpler_tod_5a_google_sgd
Dec 2, 2021
c7c1c64
Merge branch 'simpler_tod_5a_google_sgd' into simpler_tod_5b_google_s…
Dec 2, 2021
828f44f
Merge branch 'simpler_tod_5b_google_sgd_sim_splits' into simpler_tod_…
Dec 2, 2021
9466144
regen after changing tod teacher logic to respect episode/examples le…
Dec 2, 2021
1392d99
regen after changing tod teacher logic to respect episode/examples le…
Dec 2, 2021
71b5af8
Merge branch 'simpler_tod_5a_google_sgd' into simpler_tod_5b_google_s…
Dec 2, 2021
9da65a6
Merge branch 'simpler_tod_5b_google_sgd_sim_splits' into simpler_tod_…
Dec 2, 2021
7b24acf
Merge branch 'main' into simpler_tod_3_world
Dec 18, 2021
e3fa063
Merge branch 'simpler_tod_3_world' into simpler_tod_4_tod_json
Dec 18, 2021
2384563
Merge branch 'simpler_tod_4_tod_json' into simpler_tod_5a_google_sgd
Dec 18, 2021
d9ba7e4
Merge branch 'main' into simpler_tod_5a_google_sgd
Dec 22, 2021
acd6ffe
not sure why this comment keeps not being merged correctly ugh...
Dec 22, 2021
a753a6d
Merge branch 'simpler_tod_5a_google_sgd' into simpler_tod_5b_google_s…
Dec 22, 2021
66d8bf8
Merge branch 'simpler_tod_5b_google_sgd_sim_splits' into simpler_tod_…
Dec 22, 2021
0f49cb5
noticed a different in episode lengths between old version of this da…
Dec 22, 2021
0fb3ecb
Merge branch 'main' into simpler_tod_5b_google_sgd_sim_splits
Dec 22, 2021
66e09ee
Merge branch 'simpler_tod_5b_google_sgd_sim_splits' into simpler_tod_…
Dec 22, 2021
00ae154
Merge branch 'main' into simpler_tod_5c_metalwoz
Dec 22, 2021
1 change: 1 addition & 0 deletions conftest.py
@@ -67,6 +67,7 @@ def filter_tests_with_circleci(test_list):
    ('datatests/', 'data'),
    ('parlai/tasks/', 'teacher'),
    ('tasks/', 'tasks'),
    ('tod/', 'tod'),
]


2 changes: 2 additions & 0 deletions parlai/core/teachers.py
@@ -726,6 +726,8 @@ def num_episodes(self) -> int:
        """
        Return the number of episodes in the data.
        """
        if hasattr(self, "_num_episodes_cache"):
            return self._num_episodes_cache
        try:
            return self.data.num_episodes()
        except AttributeError:
79 changes: 79 additions & 0 deletions parlai/core/tod/README.md
@@ -0,0 +1,79 @@
# Task-Oriented Dialog (TOD) Core README

To get started with the TOD classes quickly, begin with the "Teachers + Agents Usage" section below (to understand how to set up agents so that they work with new datasets) and `parlai/scripts/tod_world_script.py` (to understand how to run simulations with the TOD conversations format).

See `projects/tod_simulator/README` for a higher-level usage-focused README. This document also describes the structure of the contents of this directory.

As a convention, files referenced externally to this directory are prefixed with `tod` whereas those only referenced by other files within the directory are not.

---

# Teachers + Agents Usage

See `tod_agents.py` for the classes.

For a given dataset, extend `TodStructuredDataParser` and implement `generate_episodes()` and `get_id_task_prefix()`. The former of these is expected to do the data processing to convert a dataset to `List[TodStructuredEpisode]`. From here, multiple inheritance can be used to define Agents and Teachers that utilize the data.

For example, given a `class XX_DataParser(TodStructuredDataParser)`, `class XX_UserSimulatorTeacher(XX_DataParser, TodUserSimulatorTeacher)` would be how one would define a teacher that generates training data for a User Simulator model.
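
The multiple-inheritance pattern described above can be sketched with stand-in classes. This is a minimal illustration and not ParlAI code: `StubParser`, `StubUserSimulatorTeacher`, and the dict-based episodes are hypothetical simplifications standing in for `TodStructuredDataParser`, `TodUserSimulatorTeacher`, and `TodStructuredEpisode`. Only the method names `generate_episodes()` and `get_id_task_prefix()` come from this README.

```python
from typing import Dict, List


class StubParser:
    """Hypothetical stand-in for TodStructuredDataParser."""

    def generate_episodes(self) -> List[Dict[str, str]]:
        raise NotImplementedError

    def get_id_task_prefix(self) -> str:
        raise NotImplementedError


class StubUserSimulatorTeacher(StubParser):
    """Hypothetical stand-in for TodUserSimulatorTeacher.

    It only consumes whatever the parser half of the class provides.
    """

    def examples(self) -> List[str]:
        prefix = self.get_id_task_prefix()
        return [
            f"{prefix}: {ep['goal']} -> {ep['user_utt']}"
            for ep in self.generate_episodes()
        ]


class XX_DataParser(StubParser):
    """Dataset-specific piece: the only class that knows the raw data."""

    def generate_episodes(self) -> List[Dict[str, str]]:
        # Made-up data; a real parser would read and convert the dataset here.
        return [{"goal": "book a table", "user_utt": "I need a table for two."}]

    def get_id_task_prefix(self) -> str:
        return "XX"


class XX_UserSimulatorTeacher(XX_DataParser, StubUserSimulatorTeacher):
    """Parser listed first, so its implementations win in the MRO."""


teacher = XX_UserSimulatorTeacher()
print(teacher.examples())
```

Listing the parser first in the base-class list means its `generate_episodes()` and `get_id_task_prefix()` take precedence over the abstract versions, so the teacher logic never needs to know which dataset it is serving.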

Once the relevant agents have been created (or relevant models have been fine-tuned), see `parlai.scripts.tod_world_script` for generating the simulations themselves.

## Why we do this
These files aid in consistency between Teachers and Agents for simulation. Rather than having to align multiple different agents to be consistent about assumptions about data formatting, tokens, spacing, etc., we do this once (via converting everything to `TodStructuredEpisode`) and let the code handle the rest.

# Description of Agents + Teachers useful for Simulation
## Teachers for training (generative) models
* TodSystemTeacher
* TodUserSimulatorTeacher

## Agents for Grounding
For goal grounding for the User for simulation:
* TodGoalAgent
    * Dumps goals as-is from the dataset, possibly multiple per episode
* TodSingleGoalAgent
    * Flattens goals such that a single one is used to seed a conversation. For datasets that include multiple goals per conversation, each individual goal is used as a seed.

For (optional) API schema grounding for the System:
* TodApiSchemaAgent (must be used with `TodGoalAgent` only)
* TodSingleApiSchemaAgent (must be used with `TodSingleGoalAgent` only)
* EmptyApiSchemaAgent
    * Used for simulations where the expectation is `no schema`, i.e., evaluation simulations.

## Agents for mocking APIs:
* StandaloneApiAgent
    * Assumed to be provided a .pickle file 'trained' by `TodStandaloneApiTeacher`. (See the in-line comments on the classes for an example train command.)

# Agents for dumping data from a ground truth dataset
The following are for extracting TOD World metrics from a ground truth dataset. These are generally used sparingly and only for calculating baselines.
* TodApiCallAndSysUttAgent
* TodApiResponseAgent
* TodUserUttAgent

For this metrics extraction, `TodGoalAgent` and `TodApiSchemaAgent` should be used.

# Other agents
There is an `EmptyGoalAgent` for use in human-human conversations where a goal is unnecessary.

---

# Directory contents

This directory is split into 3 main components: files to support agents + teachers, files to support the simulation world, and files to store functionality common to both of these. We describe the common functionality first then go to the other two.

Tests for all files in this directory are stored in `tests/tod`

## Files for common functionality
`tod_core.py` defines consts and enums used across TOD agents, teachers, and world. It also defines dataclasses for storing the intermediate data format used when parsing a dataset to the TOD structure, as well as a `SerializationHelper` for going from machine-structured data (e.g., API calls) to the flattened versions used by the models.
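
As a rough illustration of what "flattening" machine-structured data means here, the sketch below converts a slot dictionary to a single string and back. The function names and the `key = value ; key = value` layout are assumptions made for this example; they are not taken from `SerializationHelper`, whose actual format is defined in `tod_core.py`.

```python
from typing import Dict


def flatten_api_call(call: Dict[str, str]) -> str:
    """Serialize a slot dict into one line (hypothetical format)."""
    # Sorting keys keeps the output deterministic across runs.
    return " ; ".join(f"{key} = {value}" for key, value in sorted(call.items()))


def unflatten_api_call(text: str) -> Dict[str, str]:
    """Inverse of flatten_api_call for the same hypothetical format."""
    if not text.strip():
        return {}
    result: Dict[str, str] = {}
    for pair in text.split(" ; "):
        key, _, value = pair.partition(" = ")
        result[key.strip()] = value.strip()
    return result


call = {"city": "Seattle", "api_name": "find_restaurant"}
flat = flatten_api_call(call)
print(flat)
print(unflatten_api_call(flat) == call)
```

Doing this conversion in one shared helper is what lets teachers, agents, and the world agree on tokens and spacing without each re-implementing the format.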


## Files for agents and teachers
Usage of `tod_agents.py` is described above. It references `teacher_metrics.py` which stores Metrics objects.

## Files for simulation world
Usage of the simulation world is primarily documented in the script that runs the world, `parlai/scripts/tod_world_script.py`. The script is responsible for running multiple episodes of simulation and saving simulation output data.

The world itself is stored in `tod_world.py`. The world follows the same intermediate data formats for episodes as described in `tod_core.py` and handles calling the different agents to support this. It is generally recommended that this file not be touched.

A general class for collecting metrics out of `TODWorld` is stored within `world_metrics.py` with individual 'metric handlers' responsible for calculating a given metric stored in `world_metric_handlers.py`.


148 changes: 148 additions & 0 deletions parlai/core/tod/teacher_metrics.py
@@ -0,0 +1,148 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Task Oriented Dialogue (TOD) teacher metrics.
"""
from typing import Optional, List, Dict, Any
from parlai.core.metrics import AverageMetric, BleuMetric, F1Metric, Metric, Metrics


class SlotMetrics(Metrics):
    """
    Helper container which encapsulates standard slot metrics in task oriented learning
    (jga, slot_p, slot_r, etc).

    Due to differences in dialogue representations between tasks, the input is pre-
    parsed ground truth and predicted slot dictionaries.
    """

    def __init__(
        self,
        teacher_slots: Dict[str, str],
        predicted_slots: Dict[str, str],
        prefixes: Optional[List] = None,
        shared: Dict[str, Any] = None,
    ) -> None:
        super().__init__(shared=shared)
        self.prefixes = prefixes if prefixes else []
        # jga and optionally Avg(jga,nlg_bleu)
        self.add_with_prefixes("jga", AverageMetric(teacher_slots == predicted_slots))
        if len(teacher_slots) > 0:
            self.add_with_prefixes(
                "jga_noempty", AverageMetric(teacher_slots == predicted_slots)
            )
        else:
            self.add_with_prefixes(
                "jga_empty", AverageMetric(teacher_slots == predicted_slots)
            )

        # precision
        for pred_slot_name, pred_value in predicted_slots.items():
            slot_p = AverageMetric(teacher_slots.get(pred_slot_name) == pred_value)
            self.add_with_prefixes("slot_p", slot_p)
            self.add_with_prefixes("slot_f1", SlotF1Metric(slot_p=slot_p))
        # recall
        for teacher_slot_name, teacher_value in teacher_slots.items():
            slot_r = AverageMetric(
                predicted_slots.get(teacher_slot_name) == teacher_value
            )
            self.add_with_prefixes("slot_r", slot_r)
            self.add_with_prefixes("slot_f1", SlotF1Metric(slot_r=slot_r))

    def add_with_prefixes(self, name, value):
        self.add(name, value)
        for prefix in self.prefixes:
            self.add(f"{prefix}/{name}", value)


class NlgMetrics(Metrics):
    """
    Helper container for generation version of standard metrics (F1, BLEU, ..).
    """

    def __init__(
        self,
        guess: str,
        labels: Optional[List[str]],
        prefixes: Optional[List[str]] = None,
        shared: Dict[str, Any] = None,
        avg_jga_nlg_bleu: bool = False,
    ) -> None:
        super().__init__(shared=shared)
        self.prefixes = prefixes if prefixes else []
        bleu = BleuMetric.compute(guess, labels)
        f1 = F1Metric.compute(guess, labels)
        self.add_with_prefixes("nlg_bleu", bleu)
        self.add_with_prefixes("nlg_f1", f1)

    def add_with_prefixes(self, name, value):
        self.add(name, value)
        for prefix in self.prefixes:
            self.add(f"{prefix}/{name}", value)


AverageType = Optional[AverageMetric]


def _average_type_sum_helper(first: AverageType, second: AverageType) -> AverageType:
    """
    Helper to deal with Nones.

    We are "clever" in how we aggregate SlotF1Metrics (See SlotMetrics `__init__`) in
    that we add precision and recall values separately, but this means we need to handle
    None.
    """
    if first is None:
        return second
    if second is None:
        return first
    return first + second


class SlotF1Metric(Metric):
    """
    Metric to keep track of slot F1.

    Keeps track of slot precision and slot recall as running metrics.
    """

    __slots__ = ("_slot_p", "_slot_r")

    @property
    def macro_average(self) -> bool:
        """
        Indicates whether this metric should be macro-averaged when globally reported.
        """
        return True

    def __init__(self, slot_p: AverageType = None, slot_r: AverageType = None):
        if not isinstance(slot_p, AverageMetric) and slot_p is not None:
            slot_p = AverageMetric(slot_p)
        if not isinstance(slot_r, AverageMetric) and slot_r is not None:
            slot_r = AverageMetric(slot_r)
        self._slot_p = slot_p
        self._slot_r = slot_r

    def __add__(self, other: Optional["SlotF1Metric"]) -> "SlotF1Metric":
        # NOTE: hinting can be cleaned up with "from __future__ import annotations" when
        # we drop Python 3.6
        if other is None:
            return self
        slot_p = _average_type_sum_helper(self._slot_p, other._slot_p)
        slot_r = _average_type_sum_helper(self._slot_r, other._slot_r)
        return type(self)(slot_p=slot_p, slot_r=slot_r)

    def value(self) -> float:
        if self._slot_p is None or self._slot_r is None:
            return float("nan")
        else:
            slot_p = self._slot_p.value()
            slot_r = self._slot_r.value()
            if slot_p == 0.0 and slot_r == 0.0:
                return float("nan")
            else:
                return 2 * (slot_p * slot_r) / (slot_p + slot_r)
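
The per-slot comparisons in `SlotMetrics` and the F1 formula in `SlotF1Metric.value()` can be checked by hand with a plain-Python sketch. The function `slot_scores` below is hypothetical and does not use ParlAI's `Metric` classes. It only mirrors the arithmetic for a single example: precision averages over predicted slots, recall averages over teacher slots, and F1 is NaN when either side is missing or both are zero.

```python
import math
from typing import Dict, Tuple


def slot_scores(
    teacher: Dict[str, str], predicted: Dict[str, str]
) -> Tuple[float, float, float]:
    """Return (precision, recall, f1) for one example (illustrative only)."""
    nan = float("nan")
    # Precision: fraction of predicted slots that match the teacher's value.
    p_hits = [teacher.get(name) == value for name, value in predicted.items()]
    # Recall: fraction of teacher slots that the prediction got right.
    r_hits = [predicted.get(name) == value for name, value in teacher.items()]
    precision = sum(p_hits) / len(p_hits) if p_hits else nan
    recall = sum(r_hits) / len(r_hits) if r_hits else nan
    if math.isnan(precision) or math.isnan(recall) or precision + recall == 0:
        return precision, recall, nan
    return precision, recall, 2 * precision * recall / (precision + recall)


teacher_slots = {"city": "Seattle", "cuisine": "thai"}
predicted_slots = {"city": "Seattle", "cuisine": "sushi", "time": "7pm"}
p, r, f1 = slot_scores(teacher_slots, predicted_slots)
print(p, r, f1)
```

Here one of three predicted slots is correct (precision 1/3) and one of two teacher slots is recovered (recall 1/2), giving an F1 of 0.4.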