This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Add HuggingFace datasets support #3570

Merged Apr 26, 2021
Commits (24)
a4abc8c
add huggingface to tasks list
meganung Mar 23, 2021
82ae161
huggingface task for glue dataset support
meganung Mar 23, 2021
69ce097
handle splits
meganung Mar 23, 2021
e5dc0c7
readme
meganung Mar 24, 2021
d08d6cb
string labels instead of numerical labels
meganung Mar 24, 2021
b22c3a1
identifying text, labels, cands for setup data
meganung Mar 25, 2021
fbd5657
update test:
meganung Mar 25, 2021
d04fc23
add huggingface datasets module to requirements
meganung Mar 25, 2021
504a055
abstract class
meganung Apr 5, 2021
27881b9
abstract class and add args to specify what features for query and la…
meganung Apr 5, 2021
ac08446
add glue teacher
meganung Apr 5, 2021
682c43d
requirements tqdm edit
meganung Apr 5, 2021
d699aa5
formatting
meganung Apr 7, 2021
3488049
update docutils for cleaninstall
meganung Apr 7, 2021
23ea556
update docutils for cleaninstall
meganung Apr 7, 2021
a192819
abstracthuggingfaceteacher, instead of taking in info through args, m…
meganung Apr 16, 2021
ca9a3af
add hugging face splits mapping property
meganung Apr 16, 2021
3e77caf
remove unused _path and moved load_dataset to setup_data
meganung Apr 20, 2021
53aaeb5
Merge branch 'master' into add-hf
meganung Apr 20, 2021
efadb6d
add render_text_field
meganung Apr 20, 2021
475b0f6
Merge branch 'add-hf' of github.com:facebookresearch/ParlAI into add-hf
meganung Apr 20, 2021
c3cc2c4
add properties to abstract class and methods
meganung Apr 22, 2021
d762620
create glue task
meganung Apr 23, 2021
d5d373b
update the hf agent and tasklist
meganung Apr 26, 2021
5 changes: 5 additions & 0 deletions parlai/tasks/glue/README.md
@@ -0,0 +1,5 @@
Task: Glue
===============
Description: GLUE, the General Language Understanding Evaluation benchmark, is a collection of resources for training, evaluating, and analyzing natural language understanding systems. This task uses the `AbstractHuggingFaceTeacher` to load the dataset.

Websites: https://huggingface.co/ and https://gluebenchmark.com/
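
A quick sanity check of the new task from Python might look like the sketch below (a sketch only, assuming a standard ParlAI install; DisplayData and its task/num_examples arguments are ParlAI's usual display_data entry points):

# Sketch: print a few CoLA examples through the teacher added in this PR.
from parlai.scripts.display_data import DisplayData

if __name__ == '__main__':
    # 'glue:cola' selects the ColaTeacher defined in parlai/tasks/glue/agents.py
    DisplayData.main(task='glue:cola', num_examples=3)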
5 changes: 5 additions & 0 deletions parlai/tasks/glue/__init__.py
@@ -0,0 +1,5 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
19 changes: 19 additions & 0 deletions parlai/tasks/glue/agents.py
@@ -0,0 +1,19 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from parlai.tasks.huggingface.agents import AbstractHuggingFaceTeacher


class ColaTeacher(AbstractHuggingFaceTeacher):
    hf_path = 'glue'
    hf_name = 'cola'
    hf_text_fields = ['sentence']
    hf_label_field = 'label'
    hf_splits_mapping = {'train': 'train', 'valid': 'validation', 'test': 'test'}


class DefaultTeacher(ColaTeacher):
    pass

Reviewer comment (Contributor): Would like us to add the rest of GLUE, but that can be the next PR if you want
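
In the same spirit as the comment above, a sketch of what one more GLUE subtask teacher could look like (hypothetical and not part of this PR; the sentence1/sentence2 field names are assumptions about the HuggingFace glue/mrpc schema):

# Hypothetical sketch only: an MRPC teacher following the same pattern as ColaTeacher.
class MrpcTeacher(AbstractHuggingFaceTeacher):
    hf_path = 'glue'
    hf_name = 'mrpc'
    # both sentences of the pair are joined into the text/query
    hf_text_fields = ['sentence1', 'sentence2']
    hf_label_field = 'label'
    hf_splits_mapping = {'train': 'train', 'valid': 'validation', 'test': 'test'}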
11 changes: 11 additions & 0 deletions parlai/tasks/glue/test.py
@@ -0,0 +1,11 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from parlai.utils.testing import AutoTeacherTest


class TestDefaultTeacher(AutoTeacherTest):
    task = 'glue:cola'
6 changes: 6 additions & 0 deletions parlai/tasks/huggingface/README.md
@@ -0,0 +1,6 @@
Task: HuggingFace
===============
Description: Provides an abstract teacher (`AbstractHuggingFaceTeacher`) for loading datasets from the HuggingFace datasets library.

Website: https://huggingface.co/
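
To illustrate the claim above, a minimal sketch of wiring up a non-GLUE dataset (hypothetical and not part of this PR; the rotten_tomatoes field and split names are assumptions about that dataset's HuggingFace schema):

# Hypothetical sketch: a teacher for the HuggingFace rotten_tomatoes dataset.
from parlai.tasks.huggingface.agents import AbstractHuggingFaceTeacher


class RottenTomatoesTeacher(AbstractHuggingFaceTeacher):
    hf_path = 'rotten_tomatoes'
    # no hf_name needed: this dataset has no config name, so the default of None applies
    hf_text_fields = ['text']
    hf_label_field = 'label'
    hf_splits_mapping = {'train': 'train', 'valid': 'validation', 'test': 'test'}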

5 changes: 5 additions & 0 deletions parlai/tasks/huggingface/__init__.py
@@ -0,0 +1,5 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
128 changes: 128 additions & 0 deletions parlai/tasks/huggingface/agents.py
@@ -0,0 +1,128 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from parlai.core.teachers import DialogTeacher
from parlai.utils.data import DatatypeHelper
from typing import Dict, Iterable, List, Optional, Tuple
from typing_extensions import TypedDict
import os

# huggingface imports
from datasets import load_dataset


class SplitsMappingDict(TypedDict):
    train: str
    valid: str
    test: str


class AbstractHuggingFaceTeacher(DialogTeacher):
    """
    Abstract parent class for HuggingFace teachers. Extend this class and specify the
    attributes below to use a different dataset.

    hf_path = path parameter passed into the HuggingFace load_dataset function
    hf_name = name parameter passed into the HuggingFace load_dataset function
    hf_text_fields = list of names of the data fields from the dataset to be included in the text/query
    hf_label_field = name of the data field from the HF dataset that specifies the label of the episode
    hf_splits_mapping = dictionary with the keys 'train', 'valid', and 'test', mapping each fold to the
        name of the corresponding split of the HF dataset
    render_text_field = bool; if True, the text field name is included in the query (e.g. "sentence: <sentence>")
    """

    def __init__(self, opt, shared=None):
        self.fold = DatatypeHelper.fold(opt['datatype'])
        self.hf_split = self.hf_splits_mapping[self.fold]
        self.data_path = self._path(opt)
        opt['datafile'] = self.data_path

        self.id = "huggingface"
        super().__init__(opt, shared)
        self.reset()

    def _path(self, opt):
        if self.hf_name:
            return os.path.join(
                opt['datapath'], 'huggingface', self.hf_path, self.hf_name, self.fold
            )
        return os.path.join(opt['datapath'], 'huggingface', self.hf_path, self.fold)

    @property
    def hf_path(self) -> str:
        raise NotImplementedError

    @property
    def hf_name(self) -> Optional[str]:
        return None

    @property
    def hf_text_fields(self) -> List[str]:
        raise NotImplementedError

    @property
    def hf_label_field(self) -> str:
        raise NotImplementedError

    @property
    def hf_splits_mapping(self) -> SplitsMappingDict:
        raise NotImplementedError

Reviewer comment (Contributor), on the nested helper functions below: move these into the class so they're overridable

    def setup_data(self, path: str) -> Iterable[tuple]:
        """
        Default implementation of setup_data.

        Manually override if needed.
        """

        def _get_text_value(row) -> Tuple[str, Dict[str, str]]:
            """
            Return the constructed text query and a dict mapping text field names to
            values.
            """
            # construct the text query from the hf_text_fields specified
            text_dict = {}
            for col in self.hf_text_fields:
                text_part = row.get(col)
                if text_part is None:
                    raise KeyError(f'Feature "{col}" not found in data.')
                text_dict[col] = text_part
            return '\n'.join(text_dict.values()), text_dict

        def _get_label_value(row):
            return row[self.hf_label_field]

        def _get_label_candidates(row, label) -> Tuple[str, List[str]]:
            pre_candidates = dataset.features[self.hf_label_field].names
            # construct label and candidates
            if type(label) is int:
                return pre_candidates[label], pre_candidates
            if label in row:
                return row[label], [row[l] for l in pre_candidates]
            return label, pre_candidates

        # load dataset from HuggingFace
        dataset = load_dataset(
            path=self.hf_path, name=self.hf_name, split=self.hf_split
        )

        for row in dataset:
            query, text_dict = _get_text_value(row)
            label = _get_label_value(row)
            label, candidates = _get_label_candidates(row, label)

            episode_dict = text_dict
            episode_dict['text'] = query
            episode_dict['label'] = label
            episode_dict['label_candidates'] = candidates
            yield episode_dict, True
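
For concreteness, a sketch of roughly what one yielded episode could look like for a CoLA-style row (illustrative values only; the exact label strings depend on the HuggingFace dataset's ClassLabel names):

# Illustrative only, not produced by this code verbatim: the shape of one yielded episode.
episode_dict = {
    'sentence': 'The cat sat on the mat.',               # raw text field from the row
    'text': 'The cat sat on the mat.',                    # query built from hf_text_fields
    'label': 'acceptable',                                # int label mapped to its string name
    'label_candidates': ['unacceptable', 'acceptable'],   # from dataset.features[...].names
}
# setup_data yields (episode_dict, True), where True marks the start of a new episode.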


Reviewer comment (Contributor): Carve out a new GlueColaTeacher that fills in the particular values for that task

class DefaultTeacher(AbstractHuggingFaceTeacher):
    def __init__(self, opt, shared=None):
        raise NotImplementedError(
            "There is no default teacher for HuggingFace datasets. Please use a specific one."
        )

Reviewer comment (Contributor), on the constructor: Use a standard constructor
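
To make the fold-to-split mapping concrete, here is a small sketch of how a ParlAI datatype resolves to a HuggingFace split under the mapping used by the teachers above (a sketch; it simply follows the splits mapping dict):

# Sketch: how a ParlAI datatype resolves to a HuggingFace split name.
from parlai.utils.data import DatatypeHelper

mapping = {'train': 'train', 'valid': 'validation', 'test': 'test'}
fold = DatatypeHelper.fold('valid')   # normalizes the datatype down to 'valid'
hf_split = mapping[fold]              # -> 'validation', the split name passed to load_dataset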
13 changes: 13 additions & 0 deletions parlai/tasks/task_list.py
@@ -288,6 +288,19 @@
        ),
        "links": {"arXiv": "https://arxiv.org/abs/1706.05125"},
    },
    {
        "id": "Glue",
        "display_name": "Glue",
        "task": "glue",
        "tags": [],
        "description": (
            "GLUE, the General Language Understanding Evaluation benchmark, is a collection of resources for training, evaluating, and analyzing natural language understanding systems."
        ),
        "links": {
            "website": "https://gluebenchmark.com/",
            "website2": "https://huggingface.co/datasets/glue",
        },
    },
    {
        "id": "HotpotQA",
        "display_name": "HotpotQA",
3 changes: 2 additions & 1 deletion requirements.txt
@@ -1,6 +1,7 @@
boto3==1.9.246
botocore==1.12.246
coloredlogs==14.0
datasets==1.4.1
docutils<0.16,>=0.14
emoji==0.5.4
docformatter==1.3.0
@@ -42,7 +43,7 @@ tensorboardX==2.1
tokenizers>=0.8.0
torchtext>=0.5.0
tornado==6.0.4
tqdm==4.36.1
tqdm~=4.36.1
typing-extensions==3.7.4.1
Unidecode==1.1.1
urllib3~=1.25.9