Add offset arg in the raw text dataset (#1145)
zhangguanheng66 authored Feb 9, 2021
1 parent 3de3fcf commit 4e295e4
Showing 7 changed files with 85 additions and 72 deletions.
15 changes: 11 additions & 4 deletions test/data/test_builtin_datasets.py
@@ -138,11 +138,18 @@ def test_text_classification(self):
         self._helper_test_func(len(test_iter), 7600, next(iter(test_iter))[1][:25], 'Fears for T N pension aft')
         del train_iter, test_iter
 
-    def test_num_lines_of_setup_iter_dataset(self):
-        train_iter, test_iter = torchtext.experimental.datasets.raw.AG_NEWS()
-        train_iter.setup_iter(start=10, num_lines=100)
+    def test_num_lines_of_dataset(self):
+        train_iter, test_iter = torchtext.experimental.datasets.raw.AG_NEWS(offset=10)
         _data = [item for item in train_iter]
-        self.assertEqual(len(_data), 100)
+        self.assertEqual(len(_data), 119990)
+
+    def test_offset_dataset(self):
+        train_iter, test_iter = torchtext.experimental.datasets.raw.AG_NEWS(split=('train', 'test'),
+                                                                            offset=10)
+        container = [text[:20] for idx, (label, text) in enumerate(train_iter) if idx < 5]
+        self.assertEqual(container, ['Oil and Economy Clou', 'No Need for OPEC to ',
+                                     'Non-OPEC Nations Sho', 'Google IPO Auction O',
+                                     'Dollar Falls Broadly'])
 
     def test_imdb(self):
         from torchtext.experimental.datasets import IMDB
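For illustration (not part of this commit): the updated tests rely on the AG_NEWS train split having 120000 lines, so offset=10 leaves 119990 items and the first yielded item is the 11th line of the raw file. A minimal usage sketch, assuming the AG_NEWS download succeeds under the default '.data' root:

from torchtext.experimental.datasets.raw import AG_NEWS

# Skip the first 10 lines of each split; len() reports the remaining count.
train_iter, test_iter = AG_NEWS(offset=10)
print(len(train_iter))                # 119990 == 120000 - 10
label, text = next(iter(train_iter))  # the 11th line of the raw train file
print(text[:20])                      # 'Oil and Economy Clou', matching the test above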
22 changes: 4 additions & 18 deletions torchtext/experimental/datasets/raw/common.py
@@ -14,29 +14,17 @@ class RawTextIterableDataset(torch.utils.data.IterableDataset):
     """Defines an abstraction for raw text iterable datasets.
     """
 
-    def __init__(self, name, full_num_lines, iterator):
+    def __init__(self, name, full_num_lines, iterator, offset=0):
         """Initiate text-classification dataset.
         """
         super(RawTextIterableDataset, self).__init__()
         self.name = name
         self.full_num_lines = full_num_lines
         self._iterator = iterator
-        self.has_setup = False
-        self.start = 0
-        self.num_lines = None
-
-    def setup_iter(self, start=0, num_lines=None):
-        self.start = start
-        self.num_lines = num_lines
-        if num_lines and self.start + self.num_lines > self.full_num_lines:
-            raise ValueError("Requested start {} and num_lines {} exceeds available number of lines {}".format(
-                self.start, self.num_lines, self.full_num_lines))
-        self.has_setup = True
+        self.start = offset
+        self.num_lines = full_num_lines - offset
 
     def __iter__(self):
-        if not self.has_setup:
-            self.setup_iter()
-
         for i, item in enumerate(self._iterator):
             if i < self.start:
                 continue
@@ -45,9 +33,7 @@ def __iter__(self):
             yield item
 
     def __len__(self):
-        if self.has_setup:
-            return self.num_lines
-        return self.full_num_lines
+        return self.num_lines
 
     def get_iterator(self):
         return self._iterator
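For illustration (not part of this commit): a minimal sketch of the new iterator semantics, using a hypothetical toy iterator rather than a real dataset. The wrapper now skips the first offset items and reports the remaining count as its length.

from torchtext.experimental.datasets.raw.common import RawTextIterableDataset

lines = iter(['line {}'.format(i) for i in range(5)])
dataset = RawTextIterableDataset('toy', full_num_lines=5, iterator=lines, offset=2)

print(len(dataset))   # 3, i.e. full_num_lines - offset
print(list(dataset))  # ['line 2', 'line 3', 'line 4'] -- the first two items are skipped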
24 changes: 14 additions & 10 deletions torchtext/experimental/datasets/raw/language_modeling.py
@@ -17,7 +17,7 @@
 }
 
 
-def _setup_datasets(dataset_name, root, split, year, language):
+def _setup_datasets(dataset_name, root, split, year, language, offset):
     split = check_default_set(split, ('train', 'test', 'valid'))
     if isinstance(split, str):
         split = [split]
@@ -55,10 +55,10 @@ def _setup_datasets(dataset_name, root, split, year, language):
         data[item] = iter(io.open(_path[item], encoding="utf8"))
 
     return tuple(RawTextIterableDataset(dataset_name,
-                                        NUM_LINES[dataset_name][item], data[item]) for item in split)
+                                        NUM_LINES[dataset_name][item], data[item], offset=offset) for item in split)
 
 
-def WikiText2(root='.data', split=('train', 'valid', 'test')):
+def WikiText2(root='.data', split=('train', 'valid', 'test'), offset=0):
     """ Defines WikiText2 datasets.
 
     Create language modeling dataset: WikiText2
@@ -72,6 +72,7 @@ def WikiText2(root='.data', split=('train', 'valid', 'test')):
            just a string 'train'. If 'train' is not in the tuple or string, a vocab
            object should be provided which will be used to process valid and/or test
            data.
+        offset: the number of the starting line. Default: 0
 
     Examples:
         >>> from torchtext.experimental.raw.datasets import WikiText2
@@ -80,10 +81,10 @@ def WikiText2(root='.data', split=('train', 'valid', 'test')):
     """
 
-    return _setup_datasets("WikiText2", root, split, None, None)
+    return _setup_datasets("WikiText2", root, split, None, None, offset)
 
 
-def WikiText103(root='.data', split=('train', 'valid', 'test')):
+def WikiText103(root='.data', split=('train', 'valid', 'test'), offset=0):
     """ Defines WikiText103 datasets.
 
     Create language modeling dataset: WikiText103
@@ -96,17 +97,18 @@ def WikiText103(root='.data', split=('train', 'valid', 'test')):
            could also choose any one or two of them, for example ('train', 'test').
            If 'train' is not in the tuple, an vocab object should be provided which will
            be used to process valid and/or test data.
+        offset: the number of the starting line. Default: 0
 
     Examples:
         >>> from torchtext.experimental.datasets.raw import WikiText103
         >>> train_dataset, valid_dataset, test_dataset = WikiText103()
         >>> valid_dataset, = WikiText103(split='valid')
     """
 
-    return _setup_datasets("WikiText103", root, split, None, None)
+    return _setup_datasets("WikiText103", root, split, None, None, offset)
 
 
-def PennTreebank(root='.data', split=('train', 'valid', 'test')):
+def PennTreebank(root='.data', split=('train', 'valid', 'test'), offset=0):
     """ Defines PennTreebank datasets.
 
     Create language modeling dataset: PennTreebank
@@ -121,6 +123,7 @@ def PennTreebank(root='.data', split=('train', 'valid', 'test')):
            just a string 'train'. If 'train' is not in the tuple or string, a vocab
            object should be provided which will be used to process valid and/or test
            data.
+        offset: the number of the starting line. Default: 0
 
     Examples:
         >>> from torchtext.experimental.datasets.raw import PennTreebank
@@ -129,10 +132,10 @@ def PennTreebank(root='.data', split=('train', 'valid', 'test')):
     """
 
-    return _setup_datasets("PennTreebank", root, split, None, None)
+    return _setup_datasets("PennTreebank", root, split, None, None, offset)
 
 
-def WMTNewsCrawl(root='.data', split=('train'), year=2010, language='en'):
+def WMTNewsCrawl(root='.data', split=('train'), year=2010, language='en', offset=0):
     """ Defines WMT News Crawl.
 
     Create language modeling dataset: WMTNewsCrawl
@@ -143,11 +146,12 @@ def WMTNewsCrawl(root='.data', split=('train'), year=2010, language='en'):
            (Default: 'train')
        year: the year of the dataset (Default: 2010)
        language: the language of the dataset (Default: 'en')
+        offset: the number of the starting line. Default: 0
 
     Note: WMTNewsCrawl provides datasets based on the year and language instead of train/valid/test.
     """
 
-    return _setup_datasets("WMTNewsCrawl", root, split, year, language)
+    return _setup_datasets("WMTNewsCrawl", root, split, year, language, offset)
 
 
 DATASETS = {
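For illustration (not part of this commit): a hedged usage sketch of the new argument, assuming the WikiText2 download succeeds under the default '.data' root. The same offset is applied to every split that is returned.

from torchtext.experimental.datasets.raw import WikiText2

# Start every returned split at line 100 of its raw file.
train_iter, valid_iter, test_iter = WikiText2(offset=100)
print(len(train_iter))   # number of remaining lines in the train split

# offset defaults to 0, so existing calls are unaffected.
valid_iter, = WikiText2(split='valid')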
14 changes: 8 additions & 6 deletions torchtext/experimental/datasets/raw/question_answer.py
@@ -29,15 +29,15 @@ def _create_data_from_json(data_path):
             yield (_context, _question, _answers, _answer_start)
 
 
-def _setup_datasets(dataset_name, root, split):
+def _setup_datasets(dataset_name, root, split, offset):
     split = check_default_set(split, ('train', 'dev'))
     extracted_files = {key: download_from_url(URLS[dataset_name][key], root=root,
                                               hash_value=MD5[dataset_name][key], hash_type='md5') for key in split}
     return tuple(RawTextIterableDataset(dataset_name, NUM_LINES[dataset_name][item],
-                                        _create_data_from_json(extracted_files[item])) for item in split)
+                                        _create_data_from_json(extracted_files[item]), offset=offset) for item in split)
 
 
-def SQuAD1(root='.data', split=('train', 'dev')):
+def SQuAD1(root='.data', split=('train', 'dev'), offset=0):
     """ A dataset iterator yields the data of Stanford Question Answering dataset - SQuAD1.0.
     The iterator yields a tuple of (raw context, raw question, a list of raw answer,
     a list of answer positions in the raw context).
@@ -51,17 +51,18 @@ def SQuAD1(root='.data', split=('train', 'dev')):
        split: a string or tuple for the returned datasets (Default: ('train', 'dev'))
            By default, both datasets (train, dev) are generated. Users could also choose any one or two of them,
            for example ('train', 'dev') or just a string 'train'.
+        offset: the number of the starting line. Default: 0
 
     Examples:
         >>> train_dataset, dev_dataset = torchtext.experimental.datasets.raw.SQuAD1()
         >>> for idx, (context, question, answer, ans_pos) in enumerate(train_dataset):
         >>>     print(idx, (context, question, answer, ans_pos))
     """
 
-    return _setup_datasets("SQuAD1", root, split)
+    return _setup_datasets("SQuAD1", root, split, offset)
 
 
-def SQuAD2(root='.data', split=('train', 'dev')):
+def SQuAD2(root='.data', split=('train', 'dev'), offset=0):
     """ A dataset iterator yields the data of Stanford Question Answering dataset - SQuAD2.0.
     The iterator yields a tuple of (raw context, raw question, a list of raw answer,
     a list of answer positions in the raw context).
@@ -75,14 +76,15 @@ def SQuAD2(root='.data', split=('train', 'dev')):
        split: a string or tuple for the returned datasets (Default: ('train', 'dev'))
            By default, both datasets (train, dev) are generated. Users could also choose any one or two of them,
            for example ('train', 'dev') or just a string 'train'.
+        offset: the number of the starting line. Default: 0
 
     Examples:
         >>> train_dataset, dev_dataset = torchtext.experimental.datasets.raw.SQuAD2()
         >>> for idx, (context, question, answer, ans_pos) in enumerate(train_dataset):
         >>>     print(idx, (context, question, answer, ans_pos))
     """
 
-    return _setup_datasets("SQuAD2", root, split)
+    return _setup_datasets("SQuAD2", root, split, offset)
 
 
 DATASETS = {
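For illustration (not part of this commit): a sketch mirroring the docstring examples, assuming the SQuAD1 download succeeds. Here the offset skips the first yielded (context, question, answers, answer positions) records rather than raw file lines.

from torchtext.experimental.datasets.raw import SQuAD1

# Start both splits at the 101st question record.
train_iter, dev_iter = SQuAD1(offset=100)
for idx, (context, question, answers, ans_pos) in enumerate(train_iter):
    print(idx, question)
    if idx == 2:   # peek at a few records only
        break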
14 changes: 8 additions & 6 deletions torchtext/experimental/datasets/raw/sequence_tagging.py
@@ -39,7 +39,7 @@ def _construct_filepath(paths, file_suffix):
     return None
 
 
-def _setup_datasets(dataset_name, separator, root, split):
+def _setup_datasets(dataset_name, separator, root, split, offset):
     split = check_default_set(split, target_select=('train', 'valid', 'test'))
     extracted_files = []
     if isinstance(URLS[dataset_name], dict):
@@ -60,11 +60,11 @@ def _setup_datasets(dataset_name, separator, root, split):
         "test": _construct_filepath(extracted_files, "test.txt")
     }
     return tuple(RawTextIterableDataset(dataset_name, NUM_LINES[dataset_name][item],
-                                        _create_data_from_iob(data_filenames[item], separator))
+                                        _create_data_from_iob(data_filenames[item], separator), offset=offset)
                  if data_filenames[item] is not None else None for item in split)
 
 
-def UDPOS(root=".data", split=('train', 'valid', 'test')):
+def UDPOS(root=".data", split=('train', 'valid', 'test'), offset=0):
     """ Universal Dependencies English Web Treebank
 
     Separately returns the training and test dataset
@@ -75,15 +75,16 @@ def UDPOS(root=".data", split=('train', 'valid', 'test')):
            By default, all the datasets (train, valid, test) are generated.
            Users could also choose any one or two of them,
            for example ('train', 'valid', 'test') or just a string 'train'.
+        offset: the number of the starting line. Default: 0
 
     Examples:
         >>> from torchtext.experimental.datasets.raw import UDPOS
         >>> train_dataset, valid_dataset, test_dataset = UDPOS()
     """
-    return _setup_datasets("UDPOS", "\t", root, split)
+    return _setup_datasets("UDPOS", "\t", root, split, offset)
 
 
-def CoNLL2000Chunking(root=".data", split=('train', 'test')):
+def CoNLL2000Chunking(root=".data", split=('train', 'test'), offset=0):
     """ CoNLL 2000 Chunking Dataset
 
     Separately returns the training and test dataset
@@ -93,12 +94,13 @@ def CoNLL2000Chunking(root=".data", split=('train', 'test')):
        split: a string or tuple for the returned datasets (Default: ('train', 'test'))
            By default, both datasets (train, test) are generated. Users could also choose any one or two of them,
            for example ('train', 'test') or just a string 'train'.
+        offset: the number of the starting line. Default: 0
 
     Examples:
         >>> from torchtext.experimental.datasets.raw import CoNLL2000Chunking
         >>> train_dataset, test_dataset = CoNLL2000Chunking()
     """
-    return _setup_datasets("CoNLL2000Chunking", " ", root, split)
+    return _setup_datasets("CoNLL2000Chunking", " ", root, split, offset)
 
 
 DATASETS = {
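For illustration (not part of this commit): the analogous sketch for the sequence tagging datasets, assuming the UDPOS download succeeds. As with SQuAD, the offset appears to skip the first examples yielded by _create_data_from_iob rather than raw file lines.

from torchtext.experimental.datasets.raw import UDPOS

# Skip the first 50 examples of each requested split.
train_iter, valid_iter, test_iter = UDPOS(offset=50)
print(len(train_iter))           # remaining number of examples in the train split
print(next(iter(train_iter)))    # the 51st example of the raw train data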