Skip to content
This repository was archived by the owner on Nov 3, 2023. It is now read-only.

Add HuggingFace datasets support #3570

Merged
merged 24 commits into from
Apr 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions parlai/tasks/glue/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Task: Glue
===============
Description: GLUE, the General Language Understanding Evaluation benchmark, is a collection of resources for training, evaluating, and analyzing natural language understanding systems. This task uses the `AbstractHuggingFaceTeacher` to load the dataset.

Websites: https://huggingface.co/ and https://gluebenchmark.com/
5 changes: 5 additions & 0 deletions parlai/tasks/glue/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
19 changes: 19 additions & 0 deletions parlai/tasks/glue/agents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from parlai.tasks.huggingface.agents import AbstractHuggingFaceTeacher


class ColaTeacher(AbstractHuggingFaceTeacher):
    """
    Teacher for CoLA (Corpus of Linguistic Acceptability), one subtask of the
    GLUE benchmark, loaded through the HuggingFace `datasets` library.
    """

    hf_path = 'glue'
    hf_name = 'cola'
    hf_text_fields = ['sentence']
    hf_label_field = 'label'
    hf_splits_mapping = dict(train='train', valid='validation', test='test')


class DefaultTeacher(ColaTeacher):
    """
    Default teacher for the glue task; currently an alias for CoLA.
    """

    # TODO: add teachers for the remaining GLUE subtasks in a follow-up PR.
    pass
6 changes: 6 additions & 0 deletions parlai/tasks/huggingface/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Task: HuggingFace
===============
Description: Provides an abstract teacher that can load datasets from the HuggingFace `datasets` library.

Website: https://huggingface.co/

5 changes: 5 additions & 0 deletions parlai/tasks/huggingface/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
131 changes: 131 additions & 0 deletions parlai/tasks/huggingface/agents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from parlai.core.teachers import DialogTeacher
from parlai.utils.data import DatatypeHelper
from typing import Dict, Iterable, List, Optional, Tuple
from typing_extensions import TypedDict
import os

# huggingface imports
from datasets import load_dataset


# Maps the ParlAI folds ('train', 'valid', 'test') to HuggingFace split names.
SplitsMappingDict = TypedDict(
    'SplitsMappingDict', {'train': str, 'valid': str, 'test': str}
)


class AbstractHuggingFaceTeacher(DialogTeacher):
    """
    Abstract parent class for HuggingFace teachers. Extend this class and specify the
    class attributes below to use a different dataset.

    hf_path = path parameter passed into the hugging face load_dataset function
    hf_name = name parameter passed into the hugging face load_dataset function
    hf_text_fields = list of names of the data fields from the dataset to be included
        in the text/query
    hf_label_field = name of the data field from the hf dataset that specifies the
        label of the episode
    hf_splits_mapping = dictionary with the keys 'train', 'valid', and 'test' that map
        to the names of the splits of the hf dataset
    """

    def __init__(self, opt, shared=None):
        # The ParlAI fold ('train'/'valid'/'test') and its HF split name must be
        # resolved before super().__init__, since _path and setup_data depend on them.
        self.fold = DatatypeHelper.fold(opt['datatype'])
        self.hf_split = self.hf_splits_mapping[self.fold]
        self.data_path = self._path(opt)
        opt['datafile'] = self.data_path

        self.id = "huggingface"
        super().__init__(opt, shared)

    def _path(self, opt) -> str:
        """
        Return the local datapath for this dataset/fold; hf_name is optional and
        only included when set.
        """
        parts = [opt['datapath'], 'huggingface', self.hf_path]
        if self.hf_name:
            parts.append(self.hf_name)
        parts.append(self.fold)
        return os.path.join(*parts)

    @property
    def hf_path(self) -> str:
        raise NotImplementedError

    @property
    def hf_name(self) -> Optional[str]:
        # Not every HF dataset has a config name; default to None.
        return None

    @property
    def hf_text_fields(self) -> List[str]:
        raise NotImplementedError

    @property
    def hf_label_field(self) -> str:
        raise NotImplementedError

    @property
    def hf_splits_mapping(self) -> SplitsMappingDict:
        raise NotImplementedError

    def _get_text_value(self, row) -> Tuple[str, Dict[str, str]]:
        """
        Return the constructed text query and a dict mapping text field names to
        their values.

        :raises KeyError: if any field in hf_text_fields is missing from (or None
            in) the row.
        """
        text_dict = {}
        for col in self.hf_text_fields:
            text_part = row.get(col)
            if text_part is None:
                raise KeyError(f'Feature "{col}" not found in data.')
            text_dict[col] = text_part
        # dicts preserve insertion order, so fields appear in hf_text_fields order
        return '\n'.join(text_dict.values()), text_dict

    def _get_label_value(self, row):
        """
        Return the raw label value from the data row.
        """
        return row[self.hf_label_field]

    def _get_label_candidates(self, row, label) -> Tuple[str, List[str]]:
        """
        Return the true label's text value and the list of label candidates.

        NOTE(review): assumes the label feature is a ClassLabel exposing `.names`
        -- confirm for datasets whose label field is not a ClassLabel.
        """
        pre_candidates = self.dataset.features[self.hf_label_field].names
        # integer labels index into the ClassLabel names
        if isinstance(label, int):
            return pre_candidates[label], pre_candidates
        # the label may itself be the name of a row field holding the label text
        if label in row:
            return row[label], [row[l] for l in pre_candidates]
        return label, pre_candidates

    def setup_data(self, path: str) -> Iterable[tuple]:
        """
        Default implementation of setup_data.

        Manually override if needed.
        """
        # load dataset from HuggingFace; hf_split selects this fold's split
        self.dataset = load_dataset(
            path=self.hf_path, name=self.hf_name, split=self.hf_split
        )

        for row in self.dataset:
            query, text_dict = self._get_text_value(row)
            label = self._get_label_value(row)
            label, candidates = self._get_label_candidates(row, label)

            # start from the raw text fields, then add the ParlAI message fields
            episode_dict = dict(text_dict)
            episode_dict['text'] = query
            episode_dict['label'] = label
            episode_dict['label_candidates'] = candidates
            yield episode_dict, True


class DefaultTeacher:
    """
    Intentionally unusable default.

    The huggingface task only ships the abstract teacher; a concrete subclass
    (e.g. one of the glue task's teachers) must be used instead.
    """

    def __init__(self, opt):
        msg = "There is no default teacher for HuggingFace datasets. Please use a specific one."
        raise NotImplementedError(msg)
21 changes: 21 additions & 0 deletions parlai/tasks/task_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,19 @@
),
"links": {"arXiv": "https://arxiv.org/abs/1706.05125"},
},
{
"id": "Glue",
"display_name": "Glue",
"task": "glue",
"tags": [],
"description": (
"GLUE, the General Language Understanding Evaluation benchmark is a collection of resources for training, evaluating, and analyzing natural language understanding systems."
),
"links": {
"website": "https://gluebenchmark.com/",
"website2": "https://huggingface.co/datasets/glue",
},
},
{
"id": "HotpotQA",
"display_name": "HotpotQA",
Expand All @@ -303,6 +316,14 @@
),
"links": {"arXiv": "https://arxiv.org/abs/1809.09600"},
},
{
"id": "HuggingFace",
"display_name": "HuggingFace",
"task": "huggingface",
"tags": [],
"description": ("HuggingFace datasets"),
"links": {"website": "https://huggingface.co/"},
},
{
"id": "LIGHT-Dialogue",
"display_name": "LIGHT-Dialogue",
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
boto3==1.9.246
botocore==1.12.246
coloredlogs==14.0
datasets==1.4.1
docutils<0.16,>=0.14
emoji==0.5.4
docformatter==1.3.0
Expand Down Expand Up @@ -42,7 +43,7 @@ tensorboardX==2.1
tokenizers>=0.8.0
torchtext>=0.5.0
tornado==6.0.4
tqdm==4.36.1
tqdm~=4.36.1
typing-extensions==3.7.4.1
Unidecode==1.1.1
urllib3~=1.25.9
Expand Down