This repository has been archived by the owner on Nov 3, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add HuggingFace datasets support (#3570)
* add huggingface to tasks list * huggingface task for glue dataset support * handle splits * readme * string labels instead of numerical labels * identifying text, labels, cands for setup data * update test: * add huggingface datasets module to requirements * abstract class * abstract class and add args to specify what features for query and label in parlai * add glue teacher * requirements tqdm edit * formatting * update docutils for cleaninstall * update docutils for cleaninstall * abstracthuggingfaceteacher, instead of taking in info through args, move this to the development side * add hugging face splits mapping property * remove unused _path and moved load_dataset to setup_data * add render_text_field * add properties to abstract class and methods * create glue task * update the hf agent and tasklist
- Loading branch information
Showing
8 changed files
with
194 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
Task: Glue | ||
=============== | ||
Description: GLUE, the General Language Understanding Evaluation benchmark is a collection of resources for training, evaluating, and analyzing natural language understanding systems. This task uses the `AbstractHuggingFaceTeacher` to load the dataset. | ||
|
||
Websites: https://huggingface.co/ and https://gluebenchmark.com/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from parlai.tasks.huggingface.agents import AbstractHuggingFaceTeacher | ||
|
||
|
||
class ColaTeacher(AbstractHuggingFaceTeacher): | ||
hf_path = 'glue' | ||
hf_name = 'cola' | ||
hf_text_fields = ['sentence'] | ||
hf_label_field = 'label' | ||
hf_splits_mapping = {'train': 'train', 'valid': 'validation', 'test': 'test'} | ||
|
||
|
||
class DefaultTeacher(ColaTeacher): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
Task: HuggingFace | ||
=============== | ||
Description: Can load HuggingFace datasets. | ||
|
||
Website: https://huggingface.co/ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from parlai.core.teachers import DialogTeacher | ||
from parlai.utils.data import DatatypeHelper | ||
from typing import Dict, Iterable, List, Optional, Tuple | ||
from typing_extensions import TypedDict | ||
import os | ||
|
||
# huggingface imports | ||
from datasets import load_dataset | ||
|
||
|
||
class SplitsMappingDict(TypedDict): | ||
train: str | ||
valid: str | ||
test: str | ||
|
||
|
||
class AbstractHuggingFaceTeacher(DialogTeacher): | ||
""" | ||
Abstract parent class for HuggingFace teachers. Extend this class and specify the | ||
attributes below to use a different dataset. | ||
hf_path = path parameter passed into hugging face load_dataset function | ||
hf_name = name parameter passed into hugging face load_dataset function | ||
hf_text_fields = list of names of the data fields from the dataset to be included in the text/query | ||
hf_label_field = name of the data field from the hf dataset that specifies the label of the episode | ||
hf_splits_mapping = dictionary mapping with the keys 'train', 'valid', and 'test', that map to the | ||
names of the splits of the hf dataset. | ||
render_text_field = bool where if True, will include the text field name in the query (e.g. "sentence: <sentence>") | ||
""" | ||
|
||
def __init__(self, opt, shared=None): | ||
self.fold = DatatypeHelper.fold(opt['datatype']) | ||
self.hf_split = self.hf_splits_mapping[self.fold] | ||
self.data_path = self._path(opt) | ||
opt['datafile'] = self.data_path | ||
|
||
self.id = "huggingface" | ||
super().__init__(opt, shared) | ||
|
||
def _path(self, opt): | ||
if self.hf_name: | ||
return os.path.join( | ||
opt['datapath'], 'huggingface', self.hf_path, self.hf_name, self.fold | ||
) | ||
return os.path.join(opt['datapath'], 'huggingface', self.hf_path, self.fold) | ||
|
||
@property | ||
def hf_path(self) -> str: | ||
raise NotImplementedError | ||
|
||
@property | ||
def hf_name(self) -> Optional[str]: | ||
return None | ||
|
||
@property | ||
def hf_text_fields(self) -> List[str]: | ||
raise NotImplementedError | ||
|
||
@property | ||
def hf_label_field(self) -> str: | ||
raise NotImplementedError | ||
|
||
@property | ||
def hf_splits_mapping(self) -> SplitsMappingDict: | ||
raise NotImplementedError | ||
|
||
def _get_text_value(self, row) -> Tuple[str, Dict[str, str]]: | ||
""" | ||
return the constructed text query and dict mapping text field names to values. | ||
""" | ||
# construct text query from the hf_text_fields specified | ||
text_dict = {} | ||
for col in self.hf_text_fields: | ||
text_part = row.get(col) | ||
if text_part is None: | ||
raise KeyError(f'Feature "{col}" not found in data.') | ||
text_dict[col] = text_part | ||
return '\n'.join(text_dict.values()), text_dict | ||
|
||
def _get_label_value(self, row): | ||
""" | ||
return the label value from the data row. | ||
""" | ||
return row[self.hf_label_field] | ||
|
||
def _get_label_candidates(self, row, label) -> str: | ||
""" | ||
try to return the true label text value from the row and the candidates. | ||
""" | ||
pre_candidates = self.dataset.features[self.hf_label_field].names | ||
# construct label and candidates | ||
if type(label) is int: | ||
return pre_candidates[label], pre_candidates | ||
if label in row: | ||
return row[label], [row[l] for l in pre_candidates] | ||
return label, pre_candidates | ||
|
||
def setup_data(self, path: str) -> Iterable[tuple]: | ||
""" | ||
Default implementation of setup_data. | ||
Manually override if needed. | ||
""" | ||
# load dataset from HuggingFace | ||
self.dataset = load_dataset( | ||
path=self.hf_path, name=self.hf_name, split=self.hf_split | ||
) | ||
|
||
for row in self.dataset: | ||
query, text_dict = self._get_text_value(row) | ||
label = self._get_label_value(row) | ||
label, candidates = self._get_label_candidates(row, label) | ||
|
||
episode_dict = text_dict | ||
episode_dict['text'] = query | ||
episode_dict['label'] = label | ||
episode_dict['label_candidates'] = candidates | ||
yield episode_dict, True | ||
|
||
|
||
class DefaultTeacher: | ||
def __init__(self, opt): | ||
raise NotImplementedError( | ||
"There is no default teacher for HuggingFace datasets. Please use a specific one." | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters