
Text classification datasets with new torchtext dataset abstraction #701

Merged: 48 commits merged into pytorch:master on Apr 21, 2020

Conversation

@zhangguanheng66 (Contributor) commented Feb 28, 2020

A new dataset abstraction that decouples data and vocab/tokenizer.

To load the raw text dataset:

from torchtext.experimental.datasets import RawAG_NEWS
train, test = RawAG_NEWS()

# Process text data
from torchtext.experimental.datasets.text_classification import build_vocab
from torchtext.experimental.transforms import TokenizerTransform, VocabTransform, ToTensor
from torchtext.data.utils import get_tokenizer
from torch.nn import Sequential

vocab = build_vocab(train, TokenizerTransform(get_tokenizer('basic_english')))
text_transform = Sequential(TokenizerTransform(get_tokenizer('basic_english')), VocabTransform(vocab), ToTensor())
label_transform = ToTensor()
for (label, txt) in train[:10]:
    print(label_transform(label), text_transform(txt))

Or wrap everything above up and load the processed dataset with one command:

from torchtext.experimental.datasets import AG_NEWS
train, test = AG_NEWS()

@zhangguanheng66 (Contributor, Author)

@fmassa @cpuhrsch @vincentqb

@zhangguanheng66 (Contributor, Author) commented Mar 3, 2020

VocabTransform is scriptable for a Dict vocab.

import torch
import torchtext
vocab = {'here': 1, 'we': 2, 'are': 3}
vocab_transform = torchtext.experimental.datasets.text_classification.VocabTransform(vocab)
jit_method = torch.jit.script(vocab_transform)
print(vocab_transform(['here', 'we', 'are']) == jit_method(['here', 'we', 'are']))

TokenizerTransform is scriptable for a split tokenizer.

import torch
import torchtext
from torchtext.data.utils import get_tokenizer
token_transform = torchtext.experimental.datasets.text_classification.TokenizerTransform(get_tokenizer(None))
token_transform('here we are')
jit_method = torch.jit.script(token_transform)
print(token_transform('here we are') == jit_method('here we are'))

Tokenizer + vocab

text_transform = torchtext.experimental.datasets.text_classification.TextSequential(token_transform, vocab_transform)
text_transform('here we are')
jit_method = torch.jit.script(text_transform)
print(text_transform('here we are') == jit_method('here we are'))

return data


def build_vocab(dataset, transform):
Contributor:

Should this take a kwargs argument that will pass arguments to the vocab constructor inside build_vocab_from_iterator? This will allow us to do something like:

train, test = AG_NEWS()
transform1 = TokenizerTransform('basic_english')
train, valid = torch.utils.data.random_split(train, [90_000, 10_000]) #not exact numbers
vocab = build_vocab(train, transform1, max_size = 25_000, min_freq = 2)

It would also mean build_vocab_from_iterator needs to be modified to accept kwargs.

@zhangguanheng66 (Contributor, Author) Mar 4, 2020:

This could be added. However, the wrapper around building the vocab is pretty simple, and users now have the flexibility to do that themselves (see the sketch below).
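
A minimal sketch of how a user could do that themselves, assuming the counter-based torchtext.vocab.Vocab constructor; the helper name build_vocab_with_kwargs is hypothetical and not part of this PR:

from collections import Counter
from torchtext.vocab import Vocab

def build_vocab_with_kwargs(dataset, transform, **vocab_kwargs):
    # Count tokens over the raw (label, text) pairs and forward any extra
    # arguments (e.g. max_size, min_freq) to the Vocab constructor.
    counter = Counter()
    for (label, txt) in dataset:
        counter.update(transform(txt))
    return Vocab(counter, **vocab_kwargs)

# vocab = build_vocab_with_kwargs(train, transform1, max_size=25_000, min_freq=2)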

@vincentqb (Contributor) commented Mar 11, 2020

import torchtext
from torchtext.experimental.datasets import AG_NEWS
from torchtext.experimental.transforms import TokenizerTransform, VocabTransform, Compose
from torchtext.experimental.datasets.new_text_classification import build_vocab

# Import raw text strings
train, test = AG_NEWS()

# Build tokenizer transform
transform1 = TokenizerTransform('basic_english')

class PROCESSED(AG_NEWS):  # Wrap as map-like function

    def __getitem__(self, n):
        item = super().__getitem__(n)
        return transform1(item)

    def __next__(self):
        item = super().__next__()
        return transform1(item)

transformed_train = PROCESSED()

# Build vocab transform
vocab = VocabTransform(transformed_train)

# A new dataset with raw text strings + label/string transforms
from torchtext.experimental.datasets.new_text_classification import TextClassificationDataset
new_train = TextClassificationDataset(train.data, [int, transform1, vocab])  # transforms can be wrapped similarly

@cpuhrsch (Contributor) commented Mar 11, 2020

How about

import torchtext
from torchtext.experimental.datasets import AG_NEWS
from torchtext.experimental.transforms import TokenizerTransform, VocabTransform, Compose
from torchtext.experimental.datasets.new_text_classification import build_vocab

# Import raw text strings
train, _ = AG_NEWS()

def tokenizer(raw_text):
    splits = raw_text.split()
    label, row = int(splits[0]), splits[1:]
    return label, TokenizerTransform('basic_english')(row) #Assuming this has 0 init cost

tokenized_train = map(tokenizer, train) # For lack of a better function name

# Build vocab transform
# EDIT: Needs function to have tokenized_train to only return text part
vocab = build_vocab(map(lambda item: item[1], tokenized_train))

# Assuming datasets don't need to be reset after consumed
# EDIT: Need to take care of label
new_train = map(lambda item: (item[0], vocab(item[1])), tokenized_train)

@vincentqb (Contributor) commented Mar 11, 2020

@fmassa points to tf and lua

@zhangguanheng66 (Contributor, Author)

Some offline discussions:

  • Transforms are in general callable objects used to map inputs according to a "contract". For example, a tokenizer transform defines the "contract" of converting a string to a list of tokens (see the sketch after this list).
  • Vocab transform contains a "dictionary" object which maps tokens to ids.
  • A separate folder is created to save the raw text datasets. For the existing datasets in the torchtext library, some standard transforms and the raw text datasets are wrapped together to support "one-command" data loading.
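
As an illustration of that contract (this is not the library's TokenizerTransform implementation, just a hypothetical stand-in):

import torch.nn as nn

class WhitespaceTokenizerTransform(nn.Module):
    # Illustrative transform: the "contract" is a string in, a list of tokens out.
    def forward(self, line: str):
        return line.split()

tokens = WhitespaceTokenizerTransform()('here we are')  # ['here', 'we', 'are']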

@hudeven (Contributor) commented Apr 6, 2020

@zhangguanheng66 the new APIs look great! I have a few questions:

  1. How to handle dense/categorical features with the new API?
  2. The data for a language model could be too large to fit in memory, so we might have to use IterableDataset. Could you also provide a demo for IterableDataset?
  3. How to support custom batching logic and custom sampling logic for IterableDataset?

import io
from torchtext.utils import download_from_url, extract_archive, unicode_csv_reader

URLS = {
Contributor:

In torchaudio there have recently been issues with corrupted downloads. Either within this PR or at least as a follow-up item, we should look into md5 verification of downloads to make sure the user is actually getting the correct data.
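
A minimal sketch of such a check, assuming the expected md5 digests would be stored alongside the URLs (the helper name is illustrative):

import hashlib

def verify_md5(path, expected_md5, chunk_size=1024 * 1024):
    # Hash the downloaded archive in chunks and compare against the expected digest.
    md5 = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)
    return md5.hexdigest() == expected_md5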


train_data = _create_data_from_csv(train_csv_path)
test_data = _create_data_from_csv(test_csv_path)
return (RawTextDataset(train_data),
Contributor:

We should check whether the returned objects here actually carry the doc string of the calling class. I'll look into this a bit.

Contributor (Author):

Just check that object.__doc__ carries the doc string of the calling class (a.k.a. RawTextDataset in this case).
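
For example, a quick check along these lines (using the raw AG_NEWS dataset from this PR):

from torchtext.experimental.datasets.raw import AG_NEWS as RawAG_NEWS

train, test = RawAG_NEWS()
print(train.__doc__)  # should print the doc string of the returned dataset class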

return data_set


def IMDB(root='.data'):
Contributor:

This is the only function that accepts an explicit "root" argument.

In general, let's think about how this could be something more generic than a path. In the future, users might want to pass other file-like objects as a source for dataset construction (see the sketch below).
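
A hedged sketch of one way to accept either a path or a file-like object; the helper name is illustrative and not part of this PR:

import io

def _open_source(source, encoding='utf8'):
    # Accept anything that already behaves like a readable file;
    # otherwise treat it as a path and open it.
    if hasattr(source, 'read'):
        return source
    return io.open(source, encoding=encoding)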

return data


class RawTextDataset(torch.utils.data.Dataset):
Contributor:

In light of multiprocessing, data parallelism, etc., I wonder if it's worthwhile looking into making this an IterableDataset after all and introducing a convenience function that creates a map-style dataset by simply exhausting the iterator and writing out the result. After all, this is the raw dataset, so that would provide an even more general interface. Let's also look into JIT-ability of that.
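
A minimal sketch of such a convenience function, assuming the raw iterable dataset yields (label, text) pairs; the function name is illustrative:

import torch

def to_map_style(iterable_dataset):
    # Exhaust the iterator once and materialize the rows in memory.
    rows = list(iterable_dataset)

    class _MapStyleDataset(torch.utils.data.Dataset):
        def __len__(self):
            return len(rows)

        def __getitem__(self, idx):
            return rows[idx]

    return _MapStyleDataset()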

Contributor (Author):

Switched to IterableDataset for the raw text dataset.

self.start = 0
self.num_lines = None

def setup_iter(self, start=0, num_lines=None):
@cpuhrsch (Contributor) Apr 14, 2020:

What's the expected user interface for this? Why not make this part of the constructor or a factory function?

Contributor (Author):

The default behavior of the raw datasets is to load all the text strings and labels (no complicated APIs). If users would like to cache a chunk of the data (for example, in the DataLoader workers), they have to explicitly call the setup_iter function and set up the iterator before caching. IMO, this keeps the API at the raw dataset level very simple and clean.
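
A hedged sketch of what that could look like for DataLoader workers, assuming the total number of lines is known up front (the splitting logic here is illustrative):

import torch

def make_worker_init_fn(total_num_lines, num_workers):
    # Each DataLoader worker restricts its iterator to its own chunk via setup_iter.
    def worker_init_fn(worker_id):
        info = torch.utils.data.get_worker_info()
        per_worker = total_num_lines // num_workers
        info.dataset.setup_iter(start=worker_id * per_worker, num_lines=per_worker)
    return worker_init_fn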

data_select=('train', 'test')):
tokenizer=None, data_select=('train', 'test')):
text_transform = []
if not tokenizer:
Contributor:

"if tokenizer is None" is more precise. This will also succeed for empty lists and bool values, which isn't the default.

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.vocab import Vocab
from torchtext.datasets import TextClassificationDataset
from torchtext.experimental.datasets.raw import AG_NEWS as RawAG_NEWS
Contributor:

Instead of this, we should be able to just import "raw" and then index into it using the name of the dataset you're calling this from.
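
A hedged sketch of that pattern, assuming the raw datasets stay accessible as attributes of the raw module:

from torchtext.experimental.datasets import raw

DATASET_NAME = 'AG_NEWS'  # e.g. derived from the name of the calling dataset
train, test = getattr(raw, DATASET_NAME)()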

@cpuhrsch (Contributor) left a comment

I think this is good to go.

@zhangguanheng66 merged commit 041f5a5 into pytorch:master on Apr 21, 2020
@zhangguanheng66 (Contributor, Author)

fix #690

@thvasilo

I'm wondering if I can use the RawTextIterableDataset with a CSV input.

I see it inherits from IterableDataset, while TabularDataset inherits from torchtext.data.dataset.Dataset -> torch.utils.data.Dataset.

So currently, since I'm reading my data from a CSV, I have to go the old route of creating text and label fields (for a binary classification task).

I'm wondering if I can replace that with the new RawTextIterableDataset, and whether you have any examples of how to do that.

@zhangguanheng66 (Contributor, Author)

RawTextIterableDataset actually sets up the iterator to read a CSV file. Then, TextIterableDataset caches the iterator for labels and text strings. If you combine those two steps, it will be very similar to https://github.com/pytorch/text/blob/master/torchtext/datasets/text_classification.py.
In addition to the raw text strings, you have to set up the tokenizer/vocab transforms. With the new dataset abstraction, we hope to get rid of the old utils (like fields).
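
As a rough sketch of the first step for a custom CSV, mirroring the pattern in torchtext/datasets/text_classification.py (the helper name and the (label, text) row format are assumptions):

import io
from torchtext.utils import unicode_csv_reader

def my_csv_iterator(csv_path):
    # Yield (label, text) pairs from a CSV where column 0 is the label
    # and the remaining columns hold the text.
    with io.open(csv_path, encoding='utf8') as f:
        for row in unicode_csv_reader(f):
            yield int(row[0]), ' '.join(row[1:])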
