Add HuggingFace datasets support (#3570)
* add huggingface to tasks list

* huggingface task for glue dataset support

* handle splits

* readme

* string labels instead of numerical labels

* identifying text, labels, cands for setup data

* update tests

* add huggingface datasets module to requirements

* abstract class

* abstract class and add args to specify what features for query and label in parlai

* add glue teacher

* requirements tqdm edit

* formatting

* update docutils for clean install

* update docutils for clean install

* AbstractHuggingFaceTeacher: instead of taking in info through args, move this to the development side

* add hugging face splits mapping property

* remove unused _path and moved load_dataset to setup_data

* add render_text_field

* add properties to abstract class and methods

* create glue task

* update the hf agent and tasklist
meganung authored Apr 26, 2021
1 parent ef68eb0 commit 6120596
Showing 8 changed files with 194 additions and 1 deletion.
5 changes: 5 additions & 0 deletions parlai/tasks/glue/README.md
@@ -0,0 +1,5 @@
Task: GLUE
===============
Description: GLUE, the General Language Understanding Evaluation benchmark, is a collection of resources for training, evaluating, and analyzing natural language understanding systems. This task uses the `AbstractHuggingFaceTeacher` to load the dataset.

Websites: https://huggingface.co/ and https://gluebenchmark.com/
5 changes: 5 additions & 0 deletions parlai/tasks/glue/__init__.py
@@ -0,0 +1,5 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
19 changes: 19 additions & 0 deletions parlai/tasks/glue/agents.py
@@ -0,0 +1,19 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from parlai.tasks.huggingface.agents import AbstractHuggingFaceTeacher


class ColaTeacher(AbstractHuggingFaceTeacher):
    hf_path = 'glue'
    hf_name = 'cola'
    hf_text_fields = ['sentence']
    hf_label_field = 'label'
    hf_splits_mapping = {'train': 'train', 'valid': 'validation', 'test': 'test'}


class DefaultTeacher(ColaTeacher):
    pass
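Other GLUE subsets can be added the same way by pointing a subclass at a different configuration. Below is a hedged sketch for SST-2; the `sst2` config name and its `sentence`/`label` field names are assumptions about the HuggingFace dataset, not part of this commit:

from parlai.tasks.huggingface.agents import AbstractHuggingFaceTeacher


class Sst2Teacher(AbstractHuggingFaceTeacher):
    # 'glue' / 'sst2' identify the dataset on the HuggingFace hub (config name assumed)
    hf_path = 'glue'
    hf_name = 'sst2'
    # SST-2 is assumed to expose a single 'sentence' text field and a class 'label'
    hf_text_fields = ['sentence']
    hf_label_field = 'label'
    hf_splits_mapping = {'train': 'train', 'valid': 'validation', 'test': 'test'}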
6 changes: 6 additions & 0 deletions parlai/tasks/huggingface/README.md
@@ -0,0 +1,6 @@
Task: HuggingFace
===============
Description: Loads HuggingFace datasets via the `AbstractHuggingFaceTeacher` base class defined in this task.

Website: https://huggingface.co/
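When wiring up a new dataset, one way to find the values for `hf_path`, `hf_name`, `hf_text_fields`, `hf_label_field`, and `hf_splits_mapping` is to inspect the dataset directly with the `datasets` library. A minimal sketch, using the GLUE/CoLA dataset that the new `glue` task relies on:

from datasets import load_dataset

# Load one split of a HuggingFace dataset (here GLUE/CoLA, as used by ColaTeacher)
dataset = load_dataset(path='glue', name='cola', split='train')

# The feature schema shows which columns can serve as hf_text_fields / hf_label_field
print(dataset.features)
# The ClassLabel names become the label candidates in the teacher
print(dataset.features['label'].names)
# Print a single raw row to see the actual field values
print(dataset[0])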

5 changes: 5 additions & 0 deletions parlai/tasks/huggingface/__init__.py
@@ -0,0 +1,5 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
131 changes: 131 additions & 0 deletions parlai/tasks/huggingface/agents.py
@@ -0,0 +1,131 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from parlai.core.teachers import DialogTeacher
from parlai.utils.data import DatatypeHelper
from typing import Dict, Iterable, List, Optional, Tuple
from typing_extensions import TypedDict
import os

# huggingface imports
from datasets import load_dataset


class SplitsMappingDict(TypedDict):
    train: str
    valid: str
    test: str


class AbstractHuggingFaceTeacher(DialogTeacher):
    """
    Abstract parent class for HuggingFace teachers. Extend this class and specify the
    attributes below to use a different dataset.

    hf_path = path parameter passed into the HuggingFace load_dataset function
    hf_name = name parameter passed into the HuggingFace load_dataset function
    hf_text_fields = list of names of the data fields from the dataset to be included in the text/query
    hf_label_field = name of the data field from the hf dataset that specifies the label of the episode
    hf_splits_mapping = dictionary with the keys 'train', 'valid', and 'test' that maps each fold to the
        name of the corresponding split of the hf dataset
    render_text_field = bool; if True, the text field name is included in the query (e.g. "sentence: <sentence>")
    """

    def __init__(self, opt, shared=None):
        self.fold = DatatypeHelper.fold(opt['datatype'])
        self.hf_split = self.hf_splits_mapping[self.fold]
        self.data_path = self._path(opt)
        opt['datafile'] = self.data_path

        self.id = "huggingface"
        super().__init__(opt, shared)

    def _path(self, opt):
        if self.hf_name:
            return os.path.join(
                opt['datapath'], 'huggingface', self.hf_path, self.hf_name, self.fold
            )
        return os.path.join(opt['datapath'], 'huggingface', self.hf_path, self.fold)

    @property
    def hf_path(self) -> str:
        raise NotImplementedError

    @property
    def hf_name(self) -> Optional[str]:
        return None

    @property
    def hf_text_fields(self) -> List[str]:
        raise NotImplementedError

    @property
    def hf_label_field(self) -> str:
        raise NotImplementedError

    @property
    def hf_splits_mapping(self) -> SplitsMappingDict:
        raise NotImplementedError

    def _get_text_value(self, row) -> Tuple[str, Dict[str, str]]:
        """
        return the constructed text query and dict mapping text field names to values.
        """
        # construct text query from the hf_text_fields specified
        text_dict = {}
        for col in self.hf_text_fields:
            text_part = row.get(col)
            if text_part is None:
                raise KeyError(f'Feature "{col}" not found in data.')
            text_dict[col] = text_part
        return '\n'.join(text_dict.values()), text_dict

    def _get_label_value(self, row):
        """
        return the label value from the data row.
        """
        return row[self.hf_label_field]

    def _get_label_candidates(self, row, label) -> Tuple[str, List[str]]:
        """
        try to return the true label text value from the row and the candidates.
        """
        pre_candidates = self.dataset.features[self.hf_label_field].names
        # construct label and candidates
        if type(label) is int:
            return pre_candidates[label], pre_candidates
        if label in row:
            return row[label], [row[l] for l in pre_candidates]
        return label, pre_candidates

    def setup_data(self, path: str) -> Iterable[tuple]:
        """
        Default implementation of setup_data.
        Manually override if needed.
        """
        # load dataset from HuggingFace
        self.dataset = load_dataset(
            path=self.hf_path, name=self.hf_name, split=self.hf_split
        )

        for row in self.dataset:
            query, text_dict = self._get_text_value(row)
            label = self._get_label_value(row)
            label, candidates = self._get_label_candidates(row, label)

            episode_dict = text_dict
            episode_dict['text'] = query
            episode_dict['label'] = label
            episode_dict['label_candidates'] = candidates
            yield episode_dict, True


class DefaultTeacher:
    def __init__(self, opt):
        raise NotImplementedError(
            "There is no default teacher for HuggingFace datasets. Please use a specific one."
        )
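For reference, each item yielded by `setup_data` is an (episode dict, new-episode flag) pair. A hypothetical example for a single CoLA row (the sentence is invented; the label names are those exposed by the dataset's ClassLabel feature):

# Illustrative only: shape of one yielded item, with assumed field values
example_item = (
    {
        'sentence': 'The book was read by the student.',  # original hf_text_field value
        'text': 'The book was read by the student.',      # query built from hf_text_fields
        'label': 'acceptable',                            # int label mapped to its class name
        'label_candidates': ['unacceptable', 'acceptable'],
    },
    True,  # True marks the start of a new episode
)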
21 changes: 21 additions & 0 deletions parlai/tasks/task_list.py
@@ -288,6 +288,19 @@
        ),
        "links": {"arXiv": "https://arxiv.org/abs/1706.05125"},
    },
    {
        "id": "Glue",
        "display_name": "Glue",
        "task": "glue",
        "tags": [],
        "description": (
            "GLUE, the General Language Understanding Evaluation benchmark, is a collection of resources for training, evaluating, and analyzing natural language understanding systems."
        ),
        "links": {
            "website": "https://gluebenchmark.com/",
            "website2": "https://huggingface.co/datasets/glue",
        },
    },
    {
        "id": "HotpotQA",
        "display_name": "HotpotQA",
@@ -303,6 +316,14 @@
        ),
        "links": {"arXiv": "https://arxiv.org/abs/1809.09600"},
    },
    {
        "id": "HuggingFace",
        "display_name": "HuggingFace",
        "task": "huggingface",
        "tags": [],
        "description": ("HuggingFace datasets"),
        "links": {"website": "https://huggingface.co/"},
    },
    {
        "id": "LIGHT-Dialogue",
        "display_name": "LIGHT-Dialogue",
3 changes: 2 additions & 1 deletion requirements.txt
@@ -1,6 +1,7 @@
boto3==1.9.246
botocore==1.12.246
coloredlogs==14.0
datasets==1.4.1
docutils<0.16,>=0.14
emoji==0.5.4
docformatter==1.3.0
@@ -42,7 +43,7 @@ tensorboardX==2.1
tokenizers>=0.8.0
torchtext>=0.5.0
tornado==6.0.4
-tqdm==4.36.1
+tqdm~=4.36.1
typing-extensions==3.7.4.1
Unidecode==1.1.1
urllib3~=1.25.9
