Text classification datasets with new torchtext dataset abstraction #701
Conversation
Commits ba26975 to 9d8aa61 (Compare): "Tokenizer + vocab"
    return data

def build_vocab(dataset, transform):
Should this take a kwargs argument that passes arguments through to the vocab constructor inside build_vocab_from_iterator? This would allow us to do something like:
train, test = AG_NEWS()
transform1 = TokenizerTransform('basic_english')
train, valid = torch.utils.data.random_split(train, [90_000, 10_000])  # not exact numbers
vocab = build_vocab(train, transform1, max_size=25_000, min_freq=2)
This would also mean build_vocab_from_iterator needs to be modified to accept kwargs.
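A minimal sketch of what that forwarding could look like, assuming the legacy Vocab constructor's max_size/min_freq keywords; this is not the PR's actual implementation:

from collections import Counter
from torchtext.vocab import Vocab

def build_vocab_from_iterator(iterator, **vocab_kwargs):
    counter = Counter()
    for tokens in iterator:
        counter.update(tokens)
    # max_size, min_freq, etc. land in the Vocab constructor here.
    return Vocab(counter, **vocab_kwargs)

def build_vocab(dataset, transform, **vocab_kwargs):
    # Apply the transform (e.g. a tokenizer) to each item before counting.
    return build_vocab_from_iterator(
        (transform(item) for item in dataset), **vocab_kwargs)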
This could be added. However, the wrapper to build the vocab is pretty simple, and users now have the flexibility to do that themselves:
import torchtext
from torchtext.experimental.datasets import AG_NEWS
from torchtext.experimental.transforms import TokenizerTransform, VocabTransform, Compose
from torchtext.experimental.datasets.new_text_classification import build_vocab
# Import raw text strings
train, test = AG_NEWS()
# Build tokenizer transform
transform1 = TokenizerTransform('basic_english')
class PROCESSED(AG_NEWS):  # Wrap as a map-like dataset
    def __getitem__(self, n):
        item = super().__getitem__(n)
        return transform1(item)

    def __next__(self):
        item = super().__next__()
        return transform1(item)
transformed_train = PROCESSED()
# Build vocab transform
vocab = VocabTransform(transformed_train)
# A new dataset with raw text strings + label/string transforms
from torchtext.experimental.datasets.new_text_classification import TextClassificationDataset
new_train = TextClassificationDataset(train.data, [int, transform1, vocab])  # transforms can be wrapped similarly
How about:
import torchtext
from torchtext.experimental.datasets import AG_NEWS
from torchtext.experimental.transforms import TokenizerTransform, VocabTransform, Compose
from torchtext.experimental.datasets.new_text_classification import build_vocab
# Import raw text strings
train, _ = AG_NEWS()
def tokenizer(raw_text):
    splits = raw_text.split()
    label, row = int(splits[0]), splits[1:]
    return label, TokenizerTransform('basic_english')(row)  # assuming this has 0 init cost
tokenized_train = map(tokenizer, train) # For lack of a better function name
# Build vocab transform
# EDIT: needs a helper so that only the text part of tokenized_train is used
vocab = build_vocab(map(lambda item: item[1], tokenized_train))
# Assuming datasets don't need to be reset after being consumed
# EDIT: need to take care of the label
new_train = map(lambda item: (item[0], vocab(item[1])), tokenized_train)
Some offline discussions:
@zhangguanheng66 the new APIs look great! I have a few questions:
import io
from torchtext.utils import download_from_url, extract_archive, unicode_csv_reader

URLS = {
In torchaudio there have recently been issues with corrupted downloads. Either within this PR or at least as a follow-up item, we should look into md5 verification of downloads to make sure the user is actually getting the correct data.
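A sketch of what such verification could look like using hashlib; the MD5 table and function name here are hypothetical, not part of this PR:

import hashlib

# Hypothetical checksum table; real values would be recorded next to URLS.
MD5 = {
    'train.csv': '<expected-md5>',
}

def validate_file(path, expected_md5, chunk_size=1024 * 1024):
    # Hash the file in chunks so large archives don't blow up memory.
    md5 = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)
    if md5.hexdigest() != expected_md5:
        raise RuntimeError(f'Corrupted download: {path}')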
    train_data = _create_data_from_csv(train_csv_path)
    test_data = _create_data_from_csv(test_csv_path)
    return (RawTextDataset(train_data),
We should check whether the returned objects here actually carry the doc string of the calling class. I'll look into this a bit.
Just check that object.__doc__ carries the doc string of the calling class (i.e. RawTextDataset in this case).
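For example, a quick interactive check (assuming the raw AG_NEWS loader from this PR):

from torchtext.experimental.datasets.raw import AG_NEWS

train, test = AG_NEWS()
# Should print RawTextDataset's docstring if it is carried through.
print(train.__doc__)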
    return data_set

def IMDB(root='.data'):
This is the only function that accepts an explicit "root" argument.
In general, let's think about how this could be something more generic than a path. In the future, users might want to pass other file-like objects as a source for dataset construction.
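One possible direction, sketched here as an assumption rather than this PR's design, is to accept either a path or an already-open file-like object:

import io

def _open_source(source, encoding='utf8'):
    # Treat strings as paths; pass through anything that already reads.
    if isinstance(source, str):
        return io.open(source, encoding=encoding)
    return source  # assumed to be a file-like object with read()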
    return data

class RawTextDataset(torch.utils.data.Dataset):
In light of multiprocessing, data parallelism, etc., I wonder if it's worthwhile looking into making this an IterableDataset after all, and introducing a convenience function that creates a map-style dataset by simply exhausting the iterator and writing out the result. After all, this is the raw dataset, so that'd provide an even more general interface. Let's also look into JIT-ability of that.
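A sketch of the convenience function described here; the name to_map_style_dataset is illustrative:

import torch

def to_map_style_dataset(iterable_dataset):
    # Exhaust the iterator once and keep the results in memory.
    data = list(iterable_dataset)

    class _MapStyleDataset(torch.utils.data.Dataset):
        def __getitem__(self, idx):
            return data[idx]

        def __len__(self):
            return len(data)

    return _MapStyleDataset()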
Switched to IterableDataset for the raw text dataset.
        self.start = 0
        self.num_lines = None

    def setup_iter(self, start=0, num_lines=None):
What's the expected user interface for this? Why not make this part of the constructor or a factory function?
The default behavior of the raw datasets is to load all the text strings and labels (no complicated APIs). If users would like to cache a chunk of the data (for example, the workers in DataLoader), they have to explicitly call the setup_iter function and set up the iterator before caching. IMO, this maintains a very simple and clean API at the raw dataset level.
                 tokenizer=None, data_select=('train', 'test')):
    text_transform = []
    if not tokenizer:
"if tokenizer is None" is more precise. This will also succeed for empty lists and bool values, which isn't the default.
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.vocab import Vocab
from torchtext.datasets import TextClassificationDataset
from torchtext.experimental.datasets.raw import AG_NEWS as RawAG_NEWS
Instead of this, we should be able to just import "raw" and then index into it using the name of the dataset this is being called from.
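Something along these lines (a sketch, assuming the dataset name matches the attribute in raw):

from torchtext.experimental.datasets import raw

def _load_raw(dataset_name):
    # Look the raw dataset up by name instead of importing each one explicitly.
    return getattr(raw, dataset_name)()

train, test = _load_raw('AG_NEWS')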
I think this is good to go.
fix #690
I'm wondering if I can use the new dataset abstraction. I see it inherits from IterableDataset, while TabularDataset inherits from torchtext.data.dataset.Dataset -> torch.utils.data.Dataset. So currently, since I'm reading my data from a CSV, I have to go the old route of creating text and label fields (for a binary classification task). I'm wondering if I can replace that with these new datasets.
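A sketch of what that could look like with the new abstraction, assuming a two-column CSV of (label, text) rows and the transform1/vocab transforms from earlier in this thread:

import io
from torchtext.utils import unicode_csv_reader
from torchtext.experimental.datasets.new_text_classification import TextClassificationDataset

def csv_to_data(path):
    # Read (label, text) pairs straight from the CSV.
    with io.open(path, encoding='utf8') as f:
        return [(row[0], row[1]) for row in unicode_csv_reader(f)]

data = csv_to_data('my_binary_task.csv')  # hypothetical file
new_train = TextClassificationDataset(data, [int, transform1, vocab])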
A new dataset abstraction that decouples data and vocab/tokenizer.
To load a raw text dataset (sketched below with AG_NEWS, using the raw import path from this PR):
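from torchtext.experimental.datasets.raw import AG_NEWS
train, test = AG_NEWS()  # yields raw (label, text) strings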
Or wrap up everything above and load the processed dataset with one command (again sketched with AG_NEWS):
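from torchtext.experimental.datasets import AG_NEWS
train, test = AG_NEWS()  # tokenizer and vocab are built internally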