This repository was archived by the owner on Nov 3, 2023. It is now read-only.
Add HuggingFace datasets support #3570
Merged
24 commits
a4abc8c
add huggingface to tasks list
82ae161
huggingface task for glue dataset support
69ce097
handle splits
e5dc0c7
readme
d08d6cb
string labels instead of numerical labels
b22c3a1
identifying text, labels, cands for setup data
fbd5657
update test:
d04fc23
add huggingface datasets module to requirements
504a055
abstract class
27881b9
abstract class and add args to specify what features for query and la…
ac08446
add glue teacher
682c43d
requirements tqdm edit
d699aa5
formatting
3488049
update docutils for cleaninstall
23ea556
update docutils for cleaninstall
a192819
abstracthuggingfaceteacher, instead of taking in info through args, m…
ca9a3af
add hugging face splits mapping property
3e77caf
remove unused _path and moved load_dataset to setup_data
53aaeb5
Merge branch 'master' into add-hf
meganung efadb6d
add render_text_field
475b0f6
Merge branch 'add-hf' of github.com:facebookresearch/ParlAI into add-hf
c3cc2c4
add properties to abstract class and methods
d762620
create glue task
d5d373b
update the hf agent and tasklist
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
Task: Glue | ||
=============== | ||
Description: GLUE, the General Language Understanding Evaluation benchmark, is a collection of resources for training, evaluating, and analyzing natural language understanding systems. This task uses the `AbstractHuggingFaceTeacher` to load the dataset. | ||
|
||
Websites: https://huggingface.co/ and https://gluebenchmark.com/ |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from parlai.tasks.huggingface.agents import AbstractHuggingFaceTeacher | ||
|
||
|
||
class ColaTeacher(AbstractHuggingFaceTeacher): | ||
hf_path = 'glue' | ||
hf_name = 'cola' | ||
hf_text_fields = ['sentence'] | ||
hf_label_field = 'label' | ||
hf_splits_mapping = {'train': 'train', 'valid': 'validation', 'test': 'test'} | ||
|
||
|
||
class DefaultTeacher(ColaTeacher): | ||
pass |
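
The teacher above configures itself purely through class attributes, which shadow the base class properties of the same name. The sketch below illustrates that pattern in isolation. `AbstractConfigured`, `describe`, `NoSubset`, and `Unconfigured` are hypothetical stand-ins invented for this example and are not part of ParlAI; only the attribute names `hf_path` and `hf_name` come from the diff.

```python
from typing import Optional


class AbstractConfigured:
    """Hypothetical stand-in for the abstract teacher; not ParlAI code."""

    @property
    def hf_path(self) -> str:
        # Required setting: subclasses must provide it.
        raise NotImplementedError

    @property
    def hf_name(self) -> Optional[str]:
        # Optional setting: defaults to None when a dataset has no subset.
        return None

    def describe(self) -> str:
        # Mirrors the branching on hf_name used when building the data path.
        if self.hf_name:
            return f'{self.hf_path}/{self.hf_name}'
        return self.hf_path


class Cola(AbstractConfigured):
    # Plain class attributes are found before the base class properties
    # during attribute lookup, so they replace them.
    hf_path = 'glue'
    hf_name = 'cola'


class NoSubset(AbstractConfigured):
    hf_path = 'some_dataset'  # placeholder dataset name


class Unconfigured(AbstractConfigured):
    pass


print(Cola().describe())      # glue/cola
print(NoSubset().describe())  # some_dataset
try:
    Unconfigured().describe()
except NotImplementedError:
    print('hf_path must be set by the subclass')
```

One consequence of this design is that a missing required attribute only surfaces when it is first read, not when the subclass is defined.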
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
Task: HuggingFace | ||
=============== | ||
Description: Loads datasets from the HuggingFace `datasets` library. | ||
|
||
Website: https://huggingface.co/ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from parlai.core.teachers import DialogTeacher | ||
from parlai.utils.data import DatatypeHelper | ||
from typing import Dict, Iterable, List, Optional, Tuple | ||
from typing_extensions import TypedDict | ||
import os | ||
|
||
# huggingface imports | ||
from datasets import load_dataset | ||
|
||
|
||
class SplitsMappingDict(TypedDict): | ||
train: str | ||
valid: str | ||
test: str | ||
|
||
|
||
class AbstractHuggingFaceTeacher(DialogTeacher): | ||
""" | ||
Abstract parent class for HuggingFace teachers. Extend this class and specify the | ||
attributes below to use a different dataset. | ||
|
||
hf_path = path parameter passed into the HuggingFace load_dataset function | ||
hf_name = name parameter passed into the HuggingFace load_dataset function | ||
hf_text_fields = list of names of the data fields from the dataset to be included in the text/query | ||
hf_label_field = name of the data field from the HuggingFace dataset that specifies the label of the episode | ||
hf_splits_mapping = dictionary with the keys 'train', 'valid', and 'test' that map to the | ||
names of the splits of the HuggingFace dataset. | ||
render_text_field = bool; if True, the text field name is included in the query (e.g. "sentence: <sentence>") | ||
""" | ||
|
||
def __init__(self, opt, shared=None): | ||
self.fold = DatatypeHelper.fold(opt['datatype']) | ||
self.hf_split = self.hf_splits_mapping[self.fold] | ||
self.data_path = self._path(opt) | ||
opt['datafile'] = self.data_path | ||
|
||
|
||
self.id = "huggingface" | ||
super().__init__(opt, shared) | ||
|
||
def _path(self, opt): | ||
if self.hf_name: | ||
return os.path.join( | ||
opt['datapath'], 'huggingface', self.hf_path, self.hf_name, self.fold | ||
) | ||
return os.path.join(opt['datapath'], 'huggingface', self.hf_path, self.fold) | ||
|
||
@property | ||
def hf_path(self) -> str: | ||
raise NotImplementedError | ||
|
||
@property | ||
def hf_name(self) -> Optional[str]: | ||
return None | ||
|
||
@property | ||
def hf_text_fields(self) -> List[str]: | ||
raise NotImplementedError | ||
|
||
@property | ||
def hf_label_field(self) -> str: | ||
raise NotImplementedError | ||
|
||
@property | ||
def hf_splits_mapping(self) -> SplitsMappingDict: | ||
raise NotImplementedError | ||
|
||
def _get_text_value(self, row) -> Tuple[str, Dict[str, str]]: | ||
""" | ||
return the constructed text query and dict mapping text field names to values. | ||
""" | ||
# construct text query from the hf_text_fields specified | ||
text_dict = {} | ||
for col in self.hf_text_fields: | ||
text_part = row.get(col) | ||
if text_part is None: | ||
raise KeyError(f'Feature "{col}" not found in data.') | ||
text_dict[col] = text_part | ||
return '\n'.join(text_dict.values()), text_dict | ||
|
||
def _get_label_value(self, row): | ||
""" | ||
return the label value from the data row. | ||
""" | ||
return row[self.hf_label_field] | ||
|
||
def _get_label_candidates(self, row, label) -> Tuple[str, List[str]]: | ||
""" | ||
Return the label as text, along with the list of label candidates. | ||
""" | ||
pre_candidates = self.dataset.features[self.hf_label_field].names | ||
# construct label and candidates | ||
if isinstance(label, int): | ||
# integer labels index into the class names | ||
return pre_candidates[label], pre_candidates | ||
if label in row: | ||
# the label names a column in the row, so read the candidates from the row too | ||
return row[label], [row[name] for name in pre_candidates] | ||
return label, pre_candidates | ||
|
||
def setup_data(self, path: str) -> Iterable[tuple]: | ||
""" | ||
Default implementation of setup_data. | ||
|
||
Manually override if needed. | ||
""" | ||
# load dataset from HuggingFace | ||
self.dataset = load_dataset( | ||
path=self.hf_path, name=self.hf_name, split=self.hf_split | ||
) | ||
|
||
for row in self.dataset: | ||
query, text_dict = self._get_text_value(row) | ||
label = self._get_label_value(row) | ||
label, candidates = self._get_label_candidates(row, label) | ||
|
||
episode_dict = text_dict | ||
episode_dict['text'] = query | ||
episode_dict['label'] = label | ||
episode_dict['label_candidates'] = candidates | ||
yield episode_dict, True | ||
|
||
|
||
class DefaultTeacher: | ||
def __init__(self, opt): | ||
raise NotImplementedError( | ||
"There is no default teacher for HuggingFace datasets. Please use a specific one." | ||
) |
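
To make the row-to-episode conversion in `setup_data` easier to follow, here is a minimal sketch that runs without ParlAI or `datasets` installed. It is an illustration under stated assumptions, not the merged implementation: `LABEL_NAMES` stands in for `dataset.features['label'].names` (the two CoLA class names are assumed from the HuggingFace dataset card), and `build_query`, `resolve_label`, and `row_to_episode` are hypothetical helper names. Unlike the diff, the sketch copies the text dict before adding keys and casts field values to `str`.

```python
from typing import Any, Dict, List, Tuple

# Assumption: stand-in for `dataset.features['label'].names`.
LABEL_NAMES: List[str] = ['unacceptable', 'acceptable']
TEXT_FIELDS: List[str] = ['sentence']
LABEL_FIELD = 'label'


def build_query(row: Dict[str, Any]) -> Tuple[str, Dict[str, str]]:
    """Join the configured text fields into one newline-separated query."""
    text_dict: Dict[str, str] = {}
    for col in TEXT_FIELDS:
        if col not in row:
            raise KeyError(f'Feature "{col}" not found in data.')
        text_dict[col] = str(row[col])
    return '\n'.join(text_dict.values()), text_dict


def resolve_label(label: Any) -> Tuple[str, List[str]]:
    """Map an integer class index to its name; pass other labels through."""
    if isinstance(label, int):
        return LABEL_NAMES[label], LABEL_NAMES
    return str(label), LABEL_NAMES


def row_to_episode(row: Dict[str, Any]) -> Tuple[Dict[str, Any], bool]:
    """Build one single-turn episode; True marks the start of a new episode."""
    query, text_dict = build_query(row)
    label, candidates = resolve_label(row[LABEL_FIELD])
    episode: Dict[str, Any] = dict(text_dict)  # copy, so text_dict is untouched
    episode['text'] = query
    episode['label'] = label
    episode['label_candidates'] = candidates
    return episode, True


row = {'sentence': 'The book was read by the student.', 'label': 1, 'idx': 0}
episode, new_episode = row_to_episode(row)
print(episode['text'], '->', episode['label'])
```

Because every row yields `True` as the second element, each example is treated as its own one-turn episode.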
Would like us to add the rest of GLUE, but that can be the next PR if you want
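
As a rough sketch of what that follow-up could look like, the snippet below generates teacher-style classes for a few more GLUE subsets from a small table. Everything here is an assumption for illustration: `StandInTeacher` is a hypothetical placeholder for `AbstractHuggingFaceTeacher` so the example runs without ParlAI, `glue_attrs` and `make_teacher` are invented helper names, and the text field names are taken from the HuggingFace GLUE dataset card rather than from this PR. MNLI and STS-B are left out on purpose, since MNLI uses matched and mismatched validation splits and STS-B has a float label, so both would likely need extra handling.

```python
from typing import Any, Dict, List

# Assumption: text columns per GLUE subset, as listed on the dataset card.
GLUE_TEXT_FIELDS: Dict[str, List[str]] = {
    'sst2': ['sentence'],
    'mrpc': ['sentence1', 'sentence2'],
    'qnli': ['question', 'sentence'],
    'rte': ['sentence1', 'sentence2'],
}


class StandInTeacher:
    """Hypothetical placeholder for AbstractHuggingFaceTeacher."""


def glue_attrs(name: str) -> Dict[str, Any]:
    """Class attributes in the same shape as ColaTeacher's."""
    return {
        'hf_path': 'glue',
        'hf_name': name,
        'hf_text_fields': GLUE_TEXT_FIELDS[name],
        'hf_label_field': 'label',
        'hf_splits_mapping': {
            'train': 'train',
            'valid': 'validation',
            'test': 'test',
        },
    }


def make_teacher(name: str) -> type:
    """Create a class such as Sst2Teacher with the attributes filled in."""
    class_name = name.capitalize() + 'Teacher'
    return type(class_name, (StandInTeacher,), glue_attrs(name))


for subset in GLUE_TEXT_FIELDS:
    teacher_cls = make_teacher(subset)
    print(teacher_cls.__name__, teacher_cls.hf_text_fields)
```

Writing each class out by hand, as the PR does for `ColaTeacher`, is equally valid and easier to search for; the table-driven form only saves repetition.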