Commit: [BC-breaking] Split raw sequence tagging datasets into individual files (#1176)

Showing 4 changed files with 157 additions and 168 deletions.
torchtext/experimental/datasets/raw/__init__.py
```diff
@@ -1,86 +1,56 @@
 import importlib
 from .ag_news import AG_NEWS
-from .sogounews import SogouNews
-from .dbpedia import DBpedia
-from .yelpreviewpolarity import YelpReviewPolarity
-from .yelpreviewfull import YelpReviewFull
-from .yahooanswers import YahooAnswers
-from .amazonreviewpolarity import AmazonReviewPolarity
 from .amazonreviewfull import AmazonReviewFull
+from .amazonreviewpolarity import AmazonReviewPolarity
+from .conll2000chunking import CoNLL2000Chunking
+from .dbpedia import DBpedia
 from .imdb import IMDB
-
-from .wikitext2 import WikiText2
-from .wikitext103 import WikiText103
+from .iwslt import IWSLT
+from .multi30k import Multi30k
 from .penntreebank import PennTreebank
-from .wmtnewscrawl import WMTNewsCrawl
-
+from .sogounews import SogouNews
 from .squad1 import SQuAD1
 from .squad2 import SQuAD2
-
-from .sequence_tagging import UDPOS, CoNLL2000Chunking
-
-from .multi30k import Multi30k
-from .iwslt import IWSLT
+from .udpos import UDPOS
+from .wikitext103 import WikiText103
+from .wikitext2 import WikiText2
 from .wmt14 import WMT14
+from .wmtnewscrawl import WMTNewsCrawl
+from .yahooanswers import YahooAnswers
+from .yelpreviewfull import YelpReviewFull
+from .yelpreviewpolarity import YelpReviewPolarity
 
-DATASETS = {'IMDB': IMDB,
-            'AG_NEWS': AG_NEWS,
-            'SogouNews': SogouNews,
-            'DBpedia': DBpedia,
-            'YelpReviewPolarity': YelpReviewPolarity,
-            'YelpReviewFull': YelpReviewFull,
-            'YahooAnswers': YahooAnswers,
-            'AmazonReviewPolarity': AmazonReviewPolarity,
-            'AmazonReviewFull': AmazonReviewFull,
-            'UDPOS': UDPOS,
-            'CoNLL2000Chunking': CoNLL2000Chunking,
-            'Multi30k': Multi30k,
-            'IWSLT': IWSLT,
-            'WMT14': WMT14,
-            'WikiText2': WikiText2,
-            'WikiText103': WikiText103,
-            'PennTreebank': PennTreebank,
-            'WMTNewsCrawl': WMTNewsCrawl,
-            'SQuAD1': SQuAD1,
-            'SQuAD2': SQuAD2}
+DATASETS = {
+    'AG_NEWS': AG_NEWS,
+    'AmazonReviewFull': AmazonReviewFull,
+    'AmazonReviewPolarity': AmazonReviewPolarity,
+    'CoNLL2000Chunking': CoNLL2000Chunking,
+    'DBpedia': DBpedia,
+    'IMDB': IMDB,
+    'IWSLT': IWSLT,
+    'Multi30k': Multi30k,
+    'PennTreebank': PennTreebank,
+    'SQuAD1': SQuAD1,
+    'SQuAD2': SQuAD2,
+    'SogouNews': SogouNews,
+    'UDPOS': UDPOS,
+    'WMT14': WMT14,
+    'WMTNewsCrawl': WMTNewsCrawl,
+    'WikiText103': WikiText103,
+    'WikiText2': WikiText2,
+    'YahooAnswers': YahooAnswers,
+    'YelpReviewFull': YelpReviewFull,
+    'YelpReviewPolarity': YelpReviewPolarity
+}
 
 URLS = {}
 NUM_LINES = {}
 MD5 = {}
-for dataset in ["AG_NEWS",
-                "SogouNews",
-                "DBpedia",
-                "YelpReviewPolarity",
-                "YelpReviewFull",
-                "YahooAnswers",
-                "AmazonReviewPolarity",
-                "AmazonReviewFull",
-                "IMDB",
-                "WikiText2",
-                "WikiText103",
-                "PennTreebank",
-                "WMTNewsCrawl",
-                "SQuAD1",
-                "Multi30k",
-                "IWSLT",
-                "WMT14",
-                "SQuAD2"]:
+for dataset in DATASETS:
     dataset_module_path = "torchtext.experimental.datasets.raw." + dataset.lower()
     dataset_module = importlib.import_module(dataset_module_path)
     URLS[dataset] = dataset_module.URL
     NUM_LINES[dataset] = dataset_module.NUM_LINES
     MD5[dataset] = dataset_module.MD5
 
-from .sequence_tagging import URLS as sequence_tagging_URLS
-
-URLS.update(sequence_tagging_URLS)
-
-from .sequence_tagging import NUM_LINES as sequence_tagging_NUM_LINES
-
-NUM_LINES.update(sequence_tagging_NUM_LINES)
-
-from .sequence_tagging import MD5 as sequence_tagging_MD5
-
-MD5.update(sequence_tagging_MD5)
-
 __all__ = sorted(list(map(str, DATASETS.keys())))
```
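With the per-dataset modules in place, `URLS`, `NUM_LINES`, and `MD5` are populated uniformly by iterating `DATASETS` instead of a hand-maintained list, so the special-cased `sequence_tagging` updates at the bottom of the old file are no longer needed. A minimal sketch of how downstream code might consume the rebuilt registry (the usage below is an assumption, not part of this diff):

```python
from torchtext.experimental.datasets.raw import DATASETS, URLS, NUM_LINES, MD5

# The import-time loop guarantees one metadata entry per registered dataset,
# taken from each dataset's own module (its URL, NUM_LINES, MD5 attributes).
for name in DATASETS:
    assert name in URLS and name in NUM_LINES and name in MD5

# Split-based datasets expose their metadata as dicts keyed by split name.
print(URLS['CoNLL2000Chunking']['train'])  # .../conll2000/chunking/train.txt.gz
print(NUM_LINES['UDPOS']['valid'])         # 2002
```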
torchtext/experimental/datasets/raw/conll2000chunking.py (new file, 64 additions)
```python
from torchtext.utils import download_from_url, extract_archive
from torchtext.experimental.datasets.raw.common import RawTextIterableDataset
from torchtext.experimental.datasets.raw.common import wrap_split_argument
from torchtext.experimental.datasets.raw.common import add_docstring_header

URL = {
    'train': "https://www.clips.uantwerpen.be/conll2000/chunking/train.txt.gz",
    'test': "https://www.clips.uantwerpen.be/conll2000/chunking/test.txt.gz",
}

MD5 = {
    'train': "6969c2903a1f19a83569db643e43dcc8",
    'test': "a916e1c2d83eb3004b38fc6fcd628939",
}

NUM_LINES = {
    'train': 8936,
    'test': 2012,
}


def _create_data_from_iob(data_path, separator="\t"):
    with open(data_path, encoding="utf-8") as input_file:
        columns = []
        for line in input_file:
            line = line.strip()
            if line == "":
                if columns:
                    yield columns
                columns = []
            else:
                for i, column in enumerate(line.split(separator)):
                    if len(columns) < i + 1:
                        columns.append([])
                    columns[i].append(column)
        if len(columns) > 0:
            yield columns


def _construct_filepath(paths, file_suffix):
    if file_suffix:
        path = None
        for p in paths:
            path = p if p.endswith(file_suffix) else path
        return path
    return None


@wrap_split_argument
@add_docstring_header()
def CoNLL2000Chunking(root='.data', split=('train', 'test'), offset=0):
    extracted_files = []
    for name, item in URL.items():
        dataset_tar = download_from_url(item, root=root, hash_value=MD5[name], hash_type='md5')
        extracted_files.extend(extract_archive(dataset_tar))

    data_filenames = {
        "train": _construct_filepath(extracted_files, "train.txt"),
        "valid": _construct_filepath(extracted_files, "dev.txt"),
        "test": _construct_filepath(extracted_files, "test.txt")
    }
    return [RawTextIterableDataset("CoNLL2000Chunking", NUM_LINES[item],
                                   _create_data_from_iob(data_filenames[item], " "), offset=offset)
            if data_filenames[item] is not None else None for item in split]
```
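A hedged usage sketch for the new module (not part of the diff, and assuming `wrap_split_argument` also accepts a single split name such as `split='train'`): each element yielded by the raw dataset is the list of columns that `_create_data_from_iob` builds for one sentence, and for the space-separated CoNLL-2000 files those columns are tokens, POS tags, and chunk tags.

```python
from torchtext.experimental.datasets.raw import CoNLL2000Chunking

# Request both splits; they come back in the order listed.
train_iter, test_iter = CoNLL2000Chunking(split=('train', 'test'))

# One item per sentence: a list of equal-length columns.
tokens, pos_tags, chunk_tags = next(iter(train_iter))
print(tokens[:3], pos_tags[:3], chunk_tags[:3])
```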
torchtext/experimental/datasets/raw/sequence_tagging.py (deleted, 0 additions and 102 deletions)

This file was deleted; its contents are not shown in the diff. Its two datasets now live in the individual modules conll2000chunking.py and udpos.py.
torchtext/experimental/datasets/raw/udpos.py (new file, 57 additions)
```python
from torchtext.utils import download_from_url, extract_archive
from torchtext.experimental.datasets.raw.common import RawTextIterableDataset
from torchtext.experimental.datasets.raw.common import wrap_split_argument
from torchtext.experimental.datasets.raw.common import add_docstring_header

URL = 'https://bitbucket.org/sivareddyg/public/downloads/en-ud-v2.zip'

MD5 = 'bdcac7c52d934656bae1699541424545'

NUM_LINES = {
    'train': 12543,
    'valid': 2002,
    'test': 2077,
}


def _create_data_from_iob(data_path, separator="\t"):
    with open(data_path, encoding="utf-8") as input_file:
        columns = []
        for line in input_file:
            line = line.strip()
            if line == "":
                if columns:
                    yield columns
                columns = []
            else:
                for i, column in enumerate(line.split(separator)):
                    if len(columns) < i + 1:
                        columns.append([])
                    columns[i].append(column)
        if len(columns) > 0:
            yield columns


def _construct_filepath(paths, file_suffix):
    if file_suffix:
        path = None
        for p in paths:
            path = p if p.endswith(file_suffix) else path
        return path
    return None


@wrap_split_argument
@add_docstring_header()
def UDPOS(root='.data', split=('train', 'valid', 'test'), offset=0):
    dataset_tar = download_from_url(URL, root=root, hash_value=MD5, hash_type='md5')
    extracted_files = extract_archive(dataset_tar)

    data_filenames = {
        "train": _construct_filepath(extracted_files, "train.txt"),
        "valid": _construct_filepath(extracted_files, "dev.txt"),
        "test": _construct_filepath(extracted_files, "test.txt")
    }
    return [RawTextIterableDataset("UDPOS", NUM_LINES[item],
                                   _create_data_from_iob(data_filenames[item]), offset=offset)
            if data_filenames[item] is not None else None for item in split]
```
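Both new files carry an identical copy of `_create_data_from_iob`. A self-contained sketch of the same grouping logic on an invented two-column sample (the sentence and tags are illustrative only) shows the per-sentence structure both datasets yield:

```python
import io

def create_data_from_iob(lines, separator="\t"):
    # Same algorithm as the duplicated helper: a blank line closes the
    # current sentence; each non-blank line appends one value per column.
    columns = []
    for line in lines:
        line = line.strip()
        if line == "":
            if columns:
                yield columns
            columns = []
        else:
            for i, column in enumerate(line.split(separator)):
                if len(columns) < i + 1:
                    columns.append([])
                columns[i].append(column)
    if columns:
        yield columns

sample = "The\tDET\ndog\tNOUN\nbarks\tVERB\n\n"
print(next(create_data_from_iob(io.StringIO(sample))))
# [['The', 'dog', 'barks'], ['DET', 'NOUN', 'VERB']]
```

Since the helper is byte-for-byte identical in the two modules, hoisting it into torchtext.experimental.datasets.raw.common would be a natural follow-up to avoid the duplication.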