This repository has been archived by the owner on Nov 3, 2023. It is now read-only.
Add HuggingFace datasets support #3570
Merged
Changes from 23 commits
Commits (24 total):
a4abc8c add huggingface to tasks list (meganung)
82ae161 huggingface task for glue dataset support (meganung)
69ce097 handle splits (meganung)
e5dc0c7 readme (meganung)
d08d6cb string labels instead of numerical labels (meganung)
b22c3a1 identifying text, labels, cands for setup data (meganung)
fbd5657 update test: (meganung)
d04fc23 add huggingface datasets module to requirements (meganung)
504a055 abstract class (meganung)
27881b9 abstract class and add args to specify what features for query and la… (meganung)
ac08446 add glue teacher (meganung)
682c43d requirements tqdm edit (meganung)
d699aa5 formatting (meganung)
3488049 update docutils for cleaninstall (meganung)
23ea556 update docutils for cleaninstall (meganung)
a192819 abstracthuggingfaceteacher, instead of taking in info through args, m… (meganung)
ca9a3af add hugging face splits mapping property (meganung)
3e77caf remove unused _path and moved load_dataset to setup_data (meganung)
53aaeb5 Merge branch 'master' into add-hf (meganung)
efadb6d add render_text_field (meganung)
475b0f6 Merge branch 'add-hf' of github.com:facebookresearch/ParlAI into add-hf (meganung)
c3cc2c4 add properties to abstract class and methods (meganung)
d762620 create glue task (meganung)
d5d373b update the hf agent and tasklist (meganung)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
Task: Glue | ||
=============== | ||
Description: GLUE, the General Language Understanding Evaluation benchmark, is a collection of resources for training, evaluating, and analyzing natural language understanding systems. This task uses the `AbstractHuggingFaceTeacher` to load the dataset. | |
|
||
Websites: https://huggingface.co/ and https://gluebenchmark.com/ |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from parlai.tasks.huggingface.agents import AbstractHuggingFaceTeacher | ||
|
||
|
||
class ColaTeacher(AbstractHuggingFaceTeacher): | |
    hf_path = 'glue' | |
    hf_name = 'cola' | |
    hf_text_fields = ['sentence'] | |
    hf_label_field = 'label' | |
    hf_splits_mapping = {'train': 'train', 'valid': 'validation', 'test': 'test'} | |
|
||
|
||
class DefaultTeacher(ColaTeacher): | |
    pass |
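A note on why the plain class attributes in `ColaTeacher` work: the abstract teacher in this PR declares `hf_path`, `hf_splits_mapping`, and the other settings as properties that raise `NotImplementedError`, and a subclass attribute of the same name shadows the inherited property during attribute lookup. The sketch below illustrates that pattern without ParlAI installed. `AbstractTeacherSketch` and `ColaSketch` are hypothetical stand-ins, not classes from this PR, and the `fold` argument is passed directly instead of being derived from `opt['datatype']`.

```python
from typing import Dict, Optional


class AbstractTeacherSketch:
    """Hypothetical stand-in for AbstractHuggingFaceTeacher (no ParlAI needed)."""

    def __init__(self, fold: str):
        # fold is one of 'train', 'valid', 'test' (assumed to be given directly here)
        self.fold = fold
        # Translate the ParlAI-style fold into the HuggingFace split name.
        self.hf_split = self.hf_splits_mapping[fold]

    @property
    def hf_path(self) -> str:
        raise NotImplementedError

    @property
    def hf_name(self) -> Optional[str]:
        return None

    @property
    def hf_splits_mapping(self) -> Dict[str, str]:
        raise NotImplementedError


class ColaSketch(AbstractTeacherSketch):
    # Plain class attributes are found first in the MRO, so they shadow
    # the properties defined on the base class.
    hf_path = 'glue'
    hf_name = 'cola'
    hf_splits_mapping = {'train': 'train', 'valid': 'validation', 'test': 'test'}


teacher = ColaSketch('valid')
print(teacher.hf_path, teacher.hf_name, teacher.hf_split)
```

Instantiating the base class directly fails with `NotImplementedError` as soon as `hf_splits_mapping` is read, which is what makes the attributes effectively mandatory for subclasses.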
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from parlai.utils.testing import AutoTeacherTest | ||
|
||
|
||
class TestDefaultTeacher(AutoTeacherTest): | |
    task = 'glue:cola' |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
Task: HuggingFace | ||
=============== | ||
Description: Provides `AbstractHuggingFaceTeacher`, an abstract teacher for loading HuggingFace datasets. | |
|
||
Website: https://huggingface.co/ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from parlai.core.teachers import DialogTeacher | ||
from parlai.utils.data import DatatypeHelper | ||
from typing import Dict, Iterable, List, Optional, Tuple | ||
from typing_extensions import TypedDict | ||
import os | ||
|
||
# huggingface imports | ||
from datasets import load_dataset | ||
|
||
|
||
class SplitsMappingDict(TypedDict): | |
    train: str | |
    valid: str | |
    test: str | |
|
||
|
||
class AbstractHuggingFaceTeacher(DialogTeacher): | |
    """ | |
    Abstract parent class for HuggingFace teachers. Extend this class and specify the | |
    attributes below to use a different dataset. | |
|
||
    hf_path = path parameter passed to the HuggingFace load_dataset function | |
    hf_name = name parameter passed to the HuggingFace load_dataset function | |
    hf_text_fields = list of names of the dataset fields to include in the text/query | |
    hf_label_field = name of the dataset field that holds the label of the episode | |
    hf_splits_mapping = dictionary with the keys 'train', 'valid', and 'test', each mapping to the | |
    name of the corresponding split of the HuggingFace dataset. | |
    render_text_field = bool; if True, the text field name is included in the query (e.g. "sentence: <sentence>") | |
    """ | |
|
||
    def __init__(self, opt, shared=None): | |
        self.fold = DatatypeHelper.fold(opt['datatype']) | |
        self.hf_split = self.hf_splits_mapping[self.fold] | |
        self.data_path = self._path(opt) | |
        opt['datafile'] = self.data_path | |
|
||
|
||
        self.id = "huggingface" | |
        super().__init__(opt, shared) | |
        self.reset() | |
|
||
    def _path(self, opt): | |
        if self.hf_name: | |
            return os.path.join( | |
                opt['datapath'], 'huggingface', self.hf_path, self.hf_name, self.fold | |
            ) | |
        return os.path.join(opt['datapath'], 'huggingface', self.hf_path, self.fold) | |
|
||
    @property | |
    def hf_path(self) -> str: | |
        raise NotImplementedError | |
|
||
    @property | |
    def hf_name(self) -> Optional[str]: | |
        return None | |
|
||
    @property | |
    def hf_text_fields(self) -> List[str]: | |
        raise NotImplementedError | |
|
||
    @property | |
    def hf_label_field(self) -> str: | |
        raise NotImplementedError | |
|
||
    @property | |
    def hf_splits_mapping(self) -> SplitsMappingDict: | |
        raise NotImplementedError | |
|
||
    def setup_data(self, path: str) -> Iterable[tuple]: | |
        """ | |
        Default implementation of setup_data. | |
|
||
        Manually override if needed. | |
        """ | |
|
||
        def _get_text_value(row) -> Tuple[str, Dict[str, str]]: | |
Review comment: move these into the class so they're overridable |
||
            """ | |
            return the constructed text query and dict mapping text field names to | |
            values. | |
            """ | |
            # construct text query from the hf_text_fields specified | |
            text_dict = {} | |
            for col in self.hf_text_fields: | |
                text_part = row.get(col) | |
                if text_part is None: | |
                    raise KeyError(f'Feature "{col}" not found in data.') | |
                text_dict[col] = text_part | |
            return '\n'.join(text_dict.values()), text_dict | |
|
||
        def _get_label_value(row): | |
            return row[self.hf_label_field] | |
|
||
        def _get_label_candidates(row, label) -> Tuple[str, List[str]]: | |
            pre_candidates = dataset.features[self.hf_label_field].names | |
            # construct label and candidates | |
            if isinstance(label, int): | |
                return pre_candidates[label], pre_candidates | |
            if label in row: | |
                return row[label], [row[cand] for cand in pre_candidates] | |
            return label, pre_candidates | |
|
||
        # load dataset from HuggingFace | |
        dataset = load_dataset( | |
            path=self.hf_path, name=self.hf_name, split=self.hf_split | |
        ) | |
|
||
        for row in dataset: | |
            query, text_dict = _get_text_value(row) | |
            label = _get_label_value(row) | |
            label, candidates = _get_label_candidates(row, label) | |
|
||
            episode_dict = text_dict | |
            episode_dict['text'] = query | |
            episode_dict['label'] = label | |
            episode_dict['label_candidates'] = candidates | |
            yield episode_dict, True | |
|
||
|
||
Review comment: Carve out a new |
||
class DefaultTeacher(AbstractHuggingFaceTeacher): | |
    def __init__(): | |
Review comment: Use a standard constructor |
||
        raise NotImplementedError( | |
            "There is no default teacher for HuggingFace datasets. Please use a specific one." | |
        ) |
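To make the `setup_data` flow above easier to follow, here is a minimal standalone sketch of how one dataset row could become one episode dict: the text fields are joined into the query, and an integer label is translated into its string name with the full name list used as candidates. It works on plain dicts rather than a HuggingFace `Dataset`; the helper names (`build_text`, `resolve_label`, `row_to_episode`), the example sentence, and the label names `['unacceptable', 'acceptable']` are illustrative assumptions, not part of the PR. It also omits the PR's `label in row` branch for brevity.

```python
from typing import Any, Dict, List, Tuple


def build_text(row: Dict[str, Any], text_fields: List[str]) -> Tuple[str, Dict[str, str]]:
    """Join the requested text fields into one query, one field per line."""
    text_dict = {}
    for col in text_fields:
        text_part = row.get(col)
        if text_part is None:
            raise KeyError(f'Feature "{col}" not found in data.')
        text_dict[col] = text_part
    return '\n'.join(text_dict.values()), text_dict


def resolve_label(label: Any, label_names: List[str]) -> Tuple[str, List[str]]:
    """Map an integer class index to its string name; pass strings through."""
    if isinstance(label, int):
        return label_names[label], label_names
    return label, label_names


def row_to_episode(
    row: Dict[str, Any],
    text_fields: List[str],
    label_field: str,
    label_names: List[str],
) -> Dict[str, Any]:
    """Build the message dict for a single-turn episode from one row."""
    query, text_dict = build_text(row, text_fields)
    label, candidates = resolve_label(row[label_field], label_names)
    episode = dict(text_dict)  # copy, so the per-field dict is not mutated
    episode['text'] = query
    episode['label'] = label
    episode['label_candidates'] = candidates
    return episode


# Illustrative row shaped like a CoLA example (values are made up).
example_row = {'sentence': 'The cat sat on the mat.', 'label': 1, 'idx': 0}
episode = row_to_episode(
    example_row, ['sentence'], 'label', ['unacceptable', 'acceptable']
)
print(episode)
```

In the PR itself each such dict is yielded together with `True`, marking every row as the start of a new single-turn episode.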
Review comment: Would like us to add the rest of GLUE, but that can be the next PR if you want
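As a rough idea of what such a follow-up might look like, the sketch below declares two more GLUE configurations in the same attribute style as `ColaTeacher`. It is not taken from this PR or any later one: `GlueTeacherSketch` is a hypothetical stand-in base class so the snippet runs without ParlAI, and the config names (`sst2`, `mrpc`) and field names are assumptions based on the HuggingFace `glue` dataset that should be checked against its dataset card before use.

```python
class GlueTeacherSketch:
    """Hypothetical shared base holding the settings common to GLUE tasks."""

    hf_path = 'glue'
    hf_label_field = 'label'
    hf_splits_mapping = {'train': 'train', 'valid': 'validation', 'test': 'test'}


class Sst2Sketch(GlueTeacherSketch):
    # Single-sentence task (assumed field name: 'sentence').
    hf_name = 'sst2'
    hf_text_fields = ['sentence']


class MrpcSketch(GlueTeacherSketch):
    # Sentence-pair task (assumed field names: 'sentence1', 'sentence2').
    hf_name = 'mrpc'
    hf_text_fields = ['sentence1', 'sentence2']


for cls in (Sst2Sketch, MrpcSketch):
    print(cls.hf_path, cls.hf_name, cls.hf_text_fields)
```

Tasks whose validation data is split into several parts would need their own `hf_splits_mapping` rather than the shared one assumed here.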