Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add __next__ method to RawTextIterableDataset #1141

Merged
merged 11 commits into from
Feb 9, 2021
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 32 additions & 19 deletions test/data/test_builtin_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ def test_wikitext2(self):

# Add test for the subset of the standard datasets
train_iter, valid_iter, test_iter = torchtext.experimental.datasets.raw.WikiText2(split=('train', 'valid', 'test'))
self._helper_test_func(len(train_iter), 36718, next(iter(train_iter)), ' \n')
self._helper_test_func(len(valid_iter), 3760, next(iter(valid_iter)), ' \n')
self._helper_test_func(len(test_iter), 4358, next(iter(test_iter)), ' \n')
self._helper_test_func(len(train_iter), 36718, next(train_iter), ' \n')
self._helper_test_func(len(valid_iter), 3760, next(valid_iter), ' \n')
self._helper_test_func(len(test_iter), 4358, next(test_iter), ' \n')
del train_iter, valid_iter, test_iter
train_dataset, test_dataset = WikiText2(split=('train', 'test'))
train_data = torch.cat(tuple(filter(lambda t: t.numel() > 0, train_dataset)))
Expand Down Expand Up @@ -113,8 +113,8 @@ def test_penntreebank(self):
self._helper_test_func(len(test_data), 82114, test_data[30:35],
[397, 93, 4, 16, 7])
train_iter, test_iter = torchtext.experimental.datasets.raw.PennTreebank(split=('train', 'test'))
self._helper_test_func(len(train_iter), 42068, next(iter(train_iter))[:15], ' aer banknote b')
self._helper_test_func(len(test_iter), 3761, next(iter(test_iter))[:25], " no it was n't black mond")
self._helper_test_func(len(train_iter), 42068, next(train_iter)[:15], ' aer banknote b')
self._helper_test_func(len(test_iter), 3761, next(test_iter)[:25], " no it was n't black mond")
del train_iter, test_iter

def test_text_classification(self):
Expand All @@ -134,8 +134,8 @@ def test_text_classification(self):
self._helper_test_func(len(train_dataset), 120000, train_dataset[-1][1][:10],
[2155, 223, 2405, 30, 3010, 2204, 54, 3603, 4930, 2405])
train_iter, test_iter = torchtext.experimental.datasets.raw.AG_NEWS()
self._helper_test_func(len(train_iter), 120000, next(iter(train_iter))[1][:25], 'Wall St. Bears Claw Back ')
self._helper_test_func(len(test_iter), 7600, next(iter(test_iter))[1][:25], 'Fears for T N pension aft')
self._helper_test_func(len(train_iter), 120000, next(train_iter)[1][:25], 'Wall St. Bears Claw Back ')
self._helper_test_func(len(test_iter), 7600, next(test_iter)[1][:25], 'Fears for T N pension aft')
del train_iter, test_iter

def test_num_lines_of_setup_iter_dataset(self):
Expand All @@ -144,6 +144,19 @@ def test_num_lines_of_setup_iter_dataset(self):
_data = [item for item in train_iter]
self.assertEqual(len(_data), 100)

def test_next_method_dataset(self):
    """Check that ``for`` iteration and ``next()`` share one cursor on a raw dataset.

    The loop pulls one item through the ``for`` protocol and then one more
    through an explicit ``next()`` call on the same object. The assertion
    expects each path to have produced 60000 items from the AG_NEWS train
    split by the time the loop finishes.
    """
    train_iter, _ = torchtext.experimental.datasets.raw.AG_NEWS()
    for_count = 0
    next_count = 0
    for _ in train_iter:
        for_count += 1
        try:
            next(train_iter)
            next_count += 1
        except StopIteration:
            # Only iterator exhaustion should end the loop early; any other
            # exception is a real failure and must surface in the test run.
            break
    self.assertEqual((for_count, next_count), (60000, 60000))

def test_imdb(self):
from torchtext.experimental.datasets import IMDB
from torchtext.vocab import Vocab
Expand All @@ -164,8 +177,8 @@ def test_imdb(self):
self._helper_test_func(len(train_dataset), 25000, train_dataset[0][1][:10],
[13, 1568, 13, 246, 35468, 43, 64, 398, 1135, 92])
train_iter, test_iter = torchtext.experimental.datasets.raw.IMDB()
self._helper_test_func(len(train_iter), 25000, next(iter(train_iter))[1][:25], 'I rented I AM CURIOUS-YEL')
self._helper_test_func(len(test_iter), 25000, next(iter(test_iter))[1][:25], 'I love sci-fi and am will')
self._helper_test_func(len(train_iter), 25000, next(train_iter)[1][:25], 'I rented I AM CURIOUS-YEL')
self._helper_test_func(len(test_iter), 25000, next(test_iter)[1][:25], 'I love sci-fi and am will')
del train_iter, test_iter

def test_iwslt(self):
Expand Down Expand Up @@ -241,10 +254,10 @@ def test_multi30k(self):

# Add test for the subset of the standard datasets
train_iter, valid_iter = torchtext.experimental.datasets.raw.Multi30k(split=('train', 'valid'))
self._helper_test_func(len(train_iter), 29000, ' '.join(next(iter(train_iter))),
self._helper_test_func(len(train_iter), 29000, ' '.join(next(train_iter)),
' '.join(['Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.\n',
'Two young, White males are outside near many bushes.\n']))
self._helper_test_func(len(valid_iter), 1014, ' '.join(next(iter(valid_iter))),
self._helper_test_func(len(valid_iter), 1014, ' '.join(next(valid_iter)),
' '.join(['Eine Gruppe von Männern lädt Baumwolle auf einen Lastwagen\n',
'A group of men are loading cotton onto a truck\n']))
del train_iter, valid_iter
Expand Down Expand Up @@ -316,9 +329,9 @@ def test_udpos_sequence_tagging(self):
([262, 16, 5728, 45, 289, 701, 1160, 4436, 10660, 585],
[6, 20, 8, 10, 8, 8, 24, 13, 8, 15]))
train_iter, valid_iter = torchtext.experimental.datasets.raw.UDPOS(split=('train', 'valid'))
self._helper_test_func(len(train_iter), 12543, ' '.join(next(iter(train_iter))[0][:5]),
self._helper_test_func(len(train_iter), 12543, ' '.join(next(train_iter)[0][:5]),
' '.join(['Al', '-', 'Zaman', ':', 'American']))
self._helper_test_func(len(valid_iter), 2002, ' '.join(next(iter(valid_iter))[0][:5]),
self._helper_test_func(len(valid_iter), 2002, ' '.join(next(valid_iter)[0][:5]),
' '.join(['From', 'the', 'AP', 'comes', 'this']))
del train_iter, valid_iter

Expand Down Expand Up @@ -369,9 +382,9 @@ def test_conll_sequence_tagging(self):
[18, 17, 12, 19, 10, 6, 3, 3, 4, 4],
[3, 5, 7, 7, 3, 2, 6, 6, 3, 2]))
train_iter, test_iter = torchtext.experimental.datasets.raw.CoNLL2000Chunking()
self._helper_test_func(len(train_iter), 8936, ' '.join(next(iter(train_iter))[0][:5]),
self._helper_test_func(len(train_iter), 8936, ' '.join(next(train_iter)[0][:5]),
' '.join(['Confidence', 'in', 'the', 'pound', 'is']))
self._helper_test_func(len(test_iter), 2012, ' '.join(next(iter(test_iter))[0][:5]),
self._helper_test_func(len(test_iter), 2012, ' '.join(next(test_iter)[0][:5]),
' '.join(['Rockwell', 'International', 'Corp.', "'s", 'Tulsa']))
del train_iter, test_iter

Expand All @@ -398,9 +411,9 @@ def test_squad1(self):
self._helper_test_func(len(train_dataset), 87599, (question[:5], ans_pos[0]),
([7, 24, 86, 52, 2], [72, 72]))
train_iter, dev_iter = torchtext.experimental.datasets.raw.SQuAD1()
self._helper_test_func(len(train_iter), 87599, next(iter(train_iter))[0][:50],
self._helper_test_func(len(train_iter), 87599, next(train_iter)[0][:50],
'Architecturally, the school has a Catholic charact')
self._helper_test_func(len(dev_iter), 10570, next(iter(dev_iter))[0][:50],
self._helper_test_func(len(dev_iter), 10570, next(dev_iter)[0][:50],
'Super Bowl 50 was an American football game to det')
del train_iter, dev_iter

Expand All @@ -427,8 +440,8 @@ def test_squad2(self):
self._helper_test_func(len(train_dataset), 130319, (question[:5], ans_pos[0]),
([84, 50, 1421, 12, 5439], [9, 9]))
train_iter, dev_iter = torchtext.experimental.datasets.raw.SQuAD2()
self._helper_test_func(len(train_iter), 130319, next(iter(train_iter))[0][:50],
self._helper_test_func(len(train_iter), 130319, next(train_iter)[0][:50],
'Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-Y')
self._helper_test_func(len(dev_iter), 11873, next(iter(dev_iter))[0][:50],
self._helper_test_func(len(dev_iter), 11873, next(dev_iter)[0][:50],
'The Normans (Norman: Nourmands; French: Normands; ')
del train_iter, dev_iter
4 changes: 4 additions & 0 deletions torchtext/experimental/datasets/raw/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ def __iter__(self):
break
yield item

def __next__(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One concern I have with this approach is the following:

What happens if you do this:

for line in dataset:
    print(line)
    print(next(dataset))

The counter on line 40, `for i, item in enumerate(self._iterator):`, is checking self.start and self.num_lines, but self._iterator will have been advanced by the call to next. I expect that the above loop will run twice as much as expected.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, so this holds true then. next is meant to deplete the iterator, but instead the loop runs for the entirety of the dataset.

>>> lst = [1, 3, 2, 4, 5]
>>> lst_iter = iter(lst)
>>> next(lst_iter)
1
>>> next(lst_iter)
3
>>> next(lst_iter)
2
>>> next(lst_iter)
4
>>> next(lst_iter)
5
>>> next(lst_iter)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
StopIteration

Copy link
Contributor Author

@zhangguanheng66 zhangguanheng66 Feb 9, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I tried a similar thing (and also updated the test case):

In [8]: from torchtext.experimental.datasets.raw import AG_NEWS
   ...: train_iter, test_iter = AG_NEWS(split=('train', 'test'))
   ...: for i, item in enumerate(train_iter):
   ...:     print("iter ->", item)
   ...:     print("next ->", next(train_iter))
   ...:     if i > 1:
   ...:       break
   ...:
iter -> (3, "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.")
next -> (3, 'Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\\which has a reputation for making well-timed and occasionally\\controversial plays in the defense industry, has quietly placed\\its bets on another part of the market.')
iter -> (3, "Oil and Economy Cloud Stocks' Outlook (Reuters) Reuters - Soaring crude prices plus worries\\about the economy and the outlook for earnings are expected to\\hang over the stock market next week during the depth of the\\summer doldrums.")
next -> (3, 'Iraq Halts Oil Exports from Main Southern Pipeline (Reuters) Reuters - Authorities have halted oil export\\flows from the main pipeline in southern Iraq after\\intelligence showed a rebel militia could strike\\infrastructure, an oil official said on Saturday.')
iter -> (3, 'Oil prices soar to all-time record, posing new menace to US economy (AFP) AFP - Tearaway world oil prices, toppling records and straining wallets, present a new economic menace barely three months before the US presidential elections.')
next -> (3, 'Stocks End Up, But Near Year Lows (Reuters) Reuters - Stocks ended slightly higher on Friday\\but stayed near lows for the year as oil prices surged past  #36;46\\a barrel, offsetting a positive outlook from computer maker\\Dell Inc. (DELL.O)')

The original text file is

     1 "3","Wall St. Bears Claw Back Into the Black (Reuters)","Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics,        are seeing green again."
     2 "3","Carlyle Looks Toward Commercial Aerospace (Reuters)","Reuters - Private investment firm Carlyle Group,\which has a reputati       on for making well-timed and occasionally\controversial plays in the defense industry, has quietly placed\its bets on another pa       rt of the market."
     3 "3","Oil and Economy Cloud Stocks' Outlook (Reuters)","Reuters - Soaring crude prices plus worries\about the economy and the out       look for earnings are expected to\hang over the stock market next week during the depth of the\summer doldrums."
     4 "3","Iraq Halts Oil Exports from Main Southern Pipeline (Reuters)","Reuters - Authorities have halted oil export\flows from the        main pipeline in southern Iraq after\intelligence showed a rebel militia could strike\infrastructure, an oil official said on Sa       turday."
     5 "3","Oil prices soar to all-time record, posing new menace to US economy (AFP)","AFP - Tearaway world oil prices, toppling recor       ds and straining wallets, present a new economic menace barely three months before the US presidential elections."
     6 "3","Stocks End Up, But Near Year Lows (Reuters)","Reuters - Stocks ended slightly higher on Friday\but stayed near lows for the        year as oil prices surged past  #36;46\a barrel, offsetting a positive outlook from computer maker\Dell Inc. (DELL.O)"

item = self._iterator.__next__()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can just use next to call this.

return item

def __len__(self):
if self.has_setup:
return self.num_lines
Expand Down