Overview of issues in torchtext and the plan for revamping #664
Comments
The new dataset API looks like a good start - I think it's a good idea to move away from using Field. I'm currently experimenting with the new text classification datasets, and to change the vocabulary arguments I ended up writing a small helper:

```python
import torch
import torchtext

def create_dataset(dataset, dataset_args, vocab_args):
    # assert dataset exists in text_classification
    assert dataset in torchtext.datasets.text_classification.DATASETS
    # get dataset
    train_data, _ = getattr(torchtext.datasets.text_classification, dataset)(**dataset_args)
    # get unfiltered vocab (default args)
    old_vocab = train_data.get_vocab()
    # create new filtered vocabulary (with desired args)
    new_vocab = torchtext.vocab.Vocab(counter=old_vocab.freqs,
                                      **vocab_args)
    # return dataset with new vocabulary
    return getattr(torchtext.datasets.text_classification, dataset)(vocab=new_vocab,
                                                                    **dataset_args)
```

An example of use:

```python
train_data, test_data = create_dataset('AG_NEWS',
                                       dataset_args={'ngrams': 1},
                                       vocab_args={'max_size': 25_000})
```

With 0.5, we'll be able to pass a tokenizer to the dataset - which is much needed - but as we can't pass arguments to the initialization of the vocabulary, a workaround like the helper above is still needed. The default pad and unk tokens should also be configurable. Also, a better default collate function would help, e.g.:

```python
def collator(batch, pad_idx):
    labels, sequences = zip(*batch)
    # FloatTensor for binary classification, LongTensor for multiclass classification
    labels = torch.FloatTensor(labels)
    sequences = torch.nn.utils.rnn.pad_sequence(sequences,
                                                padding_value=pad_idx)
    return labels, sequences
```

Then this can be passed to the iterator with:

```python
import functools

pad_idx = train_data.get_vocab()['<pad>']
batch_size = 64
collate_fn = functools.partial(collator,
                               pad_idx=pad_idx)
train_iterator = torch.utils.data.DataLoader(train_data,
                                             shuffle=True,
                                             batch_size=batch_size,
                                             collate_fn=collate_fn)
```

The collate function above can also be expanded to provide sequence lengths and masks:

```python
def collator(batch, pad_idx):
    labels, sequences = zip(*batch)
    labels = torch.FloatTensor(labels)
    lengths = torch.LongTensor([len(sequence) for sequence in sequences])
    sequences = torch.nn.utils.rnn.pad_sequence(sequences,
                                                padding_value=pad_idx)
    masks = (sequences != pad_idx)
    return labels, sequences, lengths, masks
```

This gets back the missing include_lengths functionality from the old Field. Apart from those small things - I think you've done a good job! |
@bentrevett Thanks for your comments and great feedback. For padding, yes, |
Batch/iterator issue with RawField #266 |
translation datasets are too slow with |
Maybe it's worth using |
Yeap. We had discussions about this. Since all the other datasets are using |
Hi everyone! I am quite new to torchtext (but a PyTorch user for 2+ years). My first impression of the new dataset API is great; however, there is one thing I cannot accomplish in a clean manner. I am using the Penn Treebank dataset from torchtext.experimental.datasets:

```python
# query training set to build vocab
vocab = PennTreebank(data_select="train")[0].get_vocab()
# create new filtered vocabulary (with word-level embeddings)
self.vocab = torchtext.vocab.Vocab(counter=vocab.freqs,
                                   vectors=torchtext.vocab.GloVe(name='6B', dim=self.embedding_dim))
# create datasets
train_dataset, valid_dataset, test_dataset = PennTreebank(vocab=self.vocab)
```

The problem is that the returned dataset cannot provide fixed-length sequences. My solution would be to write a custom dataset and copy the data from the one returned by PennTreebank. In pseudocode, I need something like this:

```python
train_dataset, valid_dataset, test_dataset = PennTreebank(vocab=self.vocab, seq_len=32)
print(train_dataset[0].shape)  # should be 32 x 1
```
|
Yes. That's the problem for all three WLM (word language modeling) datasets. One way to wrap up the dataset is here |
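For what it's worth, a minimal sketch of such a wrapper is shown below. The class name, the seq_len argument, and the assumption that the processed dataset exposes a flat 1-D tensor of token ids (here referred to as train_dataset.data) are illustrative, not part of torchtext:

```python
import torch
from torch.utils.data import Dataset

class FixedLengthDataset(Dataset):
    """Wrap a flat 1-D tensor of token ids into examples of shape (seq_len,)."""

    def __init__(self, token_ids, seq_len):
        # drop the tail so the data splits evenly into seq_len-sized chunks
        n_chunks = len(token_ids) // seq_len
        self.data = token_ids[:n_chunks * seq_len].view(n_chunks, seq_len)

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, idx):
        return self.data[idx]

# usage sketch, assuming `train_dataset.data` holds the flat tensor of token ids:
# train_data = FixedLengthDataset(train_dataset.data, seq_len=32)
# print(train_data[0].shape)  # torch.Size([32])
```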
@zhangguanheng66: Thanks for the help! |
Is it possible to use these new-style datasets when loading from csv using a TabularDataset? |
Is this something you are looking for? #701 |
@zhangguanheng66 I think that could work, moving the conversation to #701 for clarification. |
Does anyone know how to define |
Could you open a separate issue for this and attach a code snippet? This issue is to introduce the new dataset abstraction. |
Is there an example somewhere of an NLP Dataset using the new scheme which adds '<sos>' and '<eos>' tokens to each sequence? Where would you suggest that should be done, in the dataset transformer pipeline or in the collate function? |
You can do both with the new "scheme". But I guess adding the '<sos>' and '<eos>' token ids in the |
Thanks @zhangguanheng66 - for what it's worth I ended up doing the '<sos>' and '<eos>' tokens in the transform. Then I used @bentrevett's collate_fn for the padding. My main comment is that I think you need to update build_vocab so that additional arguments can be passed through to Vocab.
The alternative is a bit hacky - constructing a new Vocab from the return of build_vocab() but using different **kwargs. |
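To illustrate the "do the specials in the transform" approach mentioned above, here is a rough sketch. The add_specials helper, the toy texts, and the assumption that the special tokens are added to the vocab by hand are all mine; only sequential_transforms, vocab_func, totensor and build_vocab_from_iterator come from the experimental API discussed later in this thread:

```python
import torch
from torchtext.data.utils import get_tokenizer
from torchtext.experimental.vocab import build_vocab_from_iterator
from torchtext.experimental.functional import sequential_transforms, vocab_func, totensor

my_tokenizer = get_tokenizer('basic_english')
my_texts = ['this film is great', 'this film is bad']

# make sure the special tokens end up in the vocabulary
my_vocab = build_vocab_from_iterator(
    [['<sos>', '<eos>']] + [my_tokenizer(text) for text in my_texts])

def add_specials(sos_idx, eos_idx):
    # list[int] -> list[int], wrapped with <sos>/<eos> ids
    def func(ids):
        return [sos_idx] + ids + [eos_idx]
    return func

# str -> list[str] -> list[int] -> <sos> + ids + <eos> -> LongTensor
text_transform = sequential_transforms(
    my_tokenizer,
    vocab_func(my_vocab),
    add_specials(my_vocab['<sos>'], my_vocab['<eos>']),
    totensor(torch.long),
)
```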
Yup. We can accept a PR for that. |
Could also be written as:
That way the only required change is forwarding the arguments from build_vocab to the constructor of Vocab, which is pretty much just adding more functionality to this factory function. I'd not want to see us merge the index and transforms portion of the suggested addition to build_vocab. |
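The snippet this comment referred to is not preserved above. As a rough sketch of the kind of factory-function change being discussed (the signature and body are a guess, not the actual proposal), forwarding extra keyword arguments from build_vocab to Vocab could look like:

```python
from collections import Counter
from torchtext.vocab import Vocab

def build_vocab(data, transforms, **vocab_kwargs):
    # count the tokens produced by the transform pipeline
    counter = Counter()
    for line in data:
        counter.update(transforms(line))
    # forward any extra arguments (e.g. max_size, min_freq, specials) to Vocab
    return Vocab(counter, **vocab_kwargs)
```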
@cpuhrsch that makes sense, but I think you left the src_tokenizer args in
Also I actually combined 'build_vocab' from Multi30k with It seems that
|
I think and also what I mean above, |
Here's some feedback after playing around with the new experimental datasets. As for the feedback: |
Happy to discuss all of these and help with any pull requests if needed. |
Yes this is what I would propose - change Line 545 in 0302ea9
I had to change the usage example slightly - as you say build_vocab_from_iterator takes an iterator of tokens (arguably
|
Also one other thing I noticed. |
|
@bentrevett Thanks Ben for your valuable feedback. We will address those comments in separate issues and cc you there. |
Yes, it's not a big deal. But instead of passing a list of strings and transforms, you can also pass a list of tensors and None (i.e. pre-process the data). I found applying transforms at train time considerably slowed down my training in the case where I was using a custom (slow) tokeniser. So I applied the transforms as the file was loaded, but then I had to create my own copy of |
You don't need to use this dataset during training if you want to avoid the overhead. Starting with the raw text iterator, you can use the transforms to process the dataset and save them as a list of tensors. Then, pass the list of tensors to DataLoader. In this case, you can avoid the overhead. Another alternative is to check out the dataset and save the processed data as a list of tensors. |
Oh right I wasn't aware that DataLoader takes a list of tuples - thanks for that tip, I'll give it a go. |
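As a concrete sketch of the pre-processing approach described above: the in-memory raw_data, the label mapping, and padding_value=0 are stand-ins for a real raw dataset iterator and vocabulary, not part of the torchtext API:

```python
import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torchtext.data.utils import get_tokenizer
from torchtext.experimental.vocab import build_vocab_from_iterator

# stand-in for a raw dataset iterator yielding (label, text) pairs
raw_data = [('pos', 'this film is great'), ('neg', 'this film is bad')]

tokenizer = get_tokenizer('basic_english')
vocab = build_vocab_from_iterator([tokenizer(text) for _, text in raw_data])

# apply the transforms once, up front, and keep only tensors
processed = [(torch.tensor(1. if label == 'pos' else 0.),
              torch.tensor([vocab[token] for token in tokenizer(text)], dtype=torch.long))
             for label, text in raw_data]

def collate(batch):
    labels, sequences = zip(*batch)
    # padding_value=0 is an arbitrary choice for this sketch
    return torch.stack(labels), pad_sequence(sequences, padding_value=0)

# a plain list of (label tensor, text tensor) tuples works as a map-style dataset
loader = DataLoader(processed, batch_size=2, shuffle=True, collate_fn=collate)
```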
Hi all, I have been playing around with the new api for a while. I am wondering is there a way to add custom 'text_transform' to input. For example, let's say I want to transform all str to lowercase or truncate the text str to a certain length. In my opinion, I think that should be passed as an argument so we can append them to the 'text_transform'. Also, I am wondering why we are still using the old torchtext.vocab instead of the new experimental vocab in the examples? Anyway, I think it's an interesting change and I am wondering is there anything I can contribute? |
Thanks for the comment. For your first question, you should check out the raw text data iterator and build a text transform pipeline; this will give you more flexibility. For your second comment, we will switch to the new vocabulary once we are done with some cleanup. |
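A rough sketch of what such a custom transform pipeline could look like for the lowercase/truncate use case mentioned above; the lowercase and truncate helpers and the maximum length of 256 are made up for illustration:

```python
from torchtext.data.utils import get_tokenizer
from torchtext.experimental.functional import sequential_transforms

tokenizer = get_tokenizer('basic_english')

def lowercase(text):
    # str -> str
    return text.lower()

def truncate(max_len):
    # list[str] -> list[str], keeping at most max_len tokens
    def func(tokens):
        return tokens[:max_len]
    return func

# str -> lowercased str -> list[str] -> truncated list[str]
text_transform = sequential_transforms(lowercase, tokenizer, truncate(256))
```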
Thank you for the reply. I wonder whether the API supports writing a Dataset object for a custom dataset? It seems hard to do with the new API. In the text classification datasets, for example, the vocab building and the transform pipeline are written inside the _setup_datasets function, which is not accessible to us if we were to build a custom text classification dataset. |
@zhangguanheng66 I believe @KwanWaiChung's comment is very relevant, and we should make it very easy for users to understand how to write their own dataset. |
Here's a minimal example of how to use your own data - here given as a very small list - to create a TextClassificationDataset:

```python
import torch
from torchtext.experimental.datasets.text_classification import TextClassificationDataset
from torchtext.experimental.vocab import build_vocab_from_iterator
from torchtext.data.utils import get_tokenizer
from torchtext.experimental.functional import sequential_transforms, vocab_func, totensor

# load data from whatever format it's saved in to an iterable of (label, text)
my_data = [('pos', 'this film is great'), ('neg', 'this film is bad'), ('neg', 'this film is awful')]

# tokenizer can be any callable function that goes from str -> list[str]
my_tokenizer = get_tokenizer('basic_english')

# build vocabulary from data
my_vocab = build_vocab_from_iterator([my_tokenizer(text) for label, text in my_data])

# how should the label be transformed?
# str -> int -> LongTensor
label_transforms = sequential_transforms(lambda x: 1 if x == 'pos' else 0, totensor(torch.long))

# how should the text be transformed?
# str -> list[str] -> list[int] -> LongTensor
text_transforms = sequential_transforms(my_tokenizer, vocab_func(my_vocab), totensor(torch.long))

# tuple the transforms
my_transforms = (label_transforms, text_transforms)

# create TextClassificationDataset with data, vocabulary and transforms
dataset = TextClassificationDataset(my_data, my_vocab, my_transforms)
```

The only missing steps to apply this to actual data would be to add some code that loads your data into the list of (label, text) tuples. Any pre-processing desired can be handled by writing your own custom tokenizer function or any other functions that will fit within the transform pipeline. |
Thanks @bentrevett for the comment, which explains the process very well. I will use the language modeling dataset as an example and explain again how it works. For the experimental datasets in torchtext, there are two kinds: the raw datasets and the processed (non-raw) datasets.
Who should use the non-raw datasets? Those who want to load the processed data with a single command should use the LM datasets (see torchtext/experimental/datasets/language_modeling.py, lines 133 to 166 at b733bb1).
If you want more flexibility, like "truncate the text str to a certain length" requested by @KwanWaiChung, you have to use the raw dataset with a custom text transform pipeline. So how do you do that? You can treat the raw dataset as an iterator over the raw text strings (see language_modeling.py, lines 63 to 64 at b733bb1).
So what exactly do the transforms do? See language_modeling.py, lines 81 to 82 at b733bb1.
In order to have the transform pipeline, you have to obtain a tokenizer (https://github.com/pytorch/text/blob/master/torchtext/experimental/datasets/language_modeling.py#L65-L66) and generate a vocabulary (see language_modeling.py, lines 75 to 79 at b733bb1).
At the end, you pass the data and transforms to the language modeling abstraction to have a map-style dataset, which works with DataLoader (see language_modeling.py, lines 89 to 90 at b733bb1). |
I'm working on a more hands-on tutorial to show how to build a dataset with the ideas of building blocks. |
Thanks for the detailed example, that's really clear! I am wondering whether we can currently load pretrained word vectors like before, or is that functionality planned to be added later? Just asking because there are some comments above talking about the issues of the experimental Vocab class, but loading pretrained vectors is not mentioned. |
The way I've been using pre-trained vectors is by loading them and then "aligning" them with the vocabulary to create a pre-trained embedding tensor:

```python
import torch
from torchtext.experimental.vectors import GloVe

# define desired embedding dim
emb_dim = 100

# get pretrained glove vectors
glove = GloVe(name='6B',
              dim=emb_dim)

# create a tensor used for holding the pre-trained vectors for each element of the vocab
pretrained_embedding = torch.zeros(len(my_vocab), emb_dim)

# get the pretrained vector's vocab, Dict[str, int]
pretrained_vocab = glove.vectors.get_stoi()

# iterate over your vocab's `itos` attribute, a list of tokens within the vocab
# if the token is in the pre-trained vocab, i.e. if it has a pre-trained vector
# then replace its row in the pre-trained embedding tensor with the pre-trained vector
# if the token is NOT in the pre-trained vocab, we leave it initialized to zero
for idx, token in enumerate(my_vocab.get_itos()):
    if token in pretrained_vocab:
        pretrained_vector = glove[token]  # pretrained_vector is a FloatTensor pre-trained vector for `token`
        pretrained_embedding[idx] = pretrained_vector  # update the appropriate row in pretrained_embedding

# at this point we have the aligned pre-trained vectors, but we need to actually use them in our model
# later on, when you've defined your model with an nn.Embedding layer called `embedding`
# replace the randomly initialized embedding with your pre-trained embedding
model.embedding.weight.data.copy_(pretrained_embedding)
```
|
I've been trying to use the collate function from earlier in this thread:

```python
def collator(batch, pad_idx):
    labels, sequences = zip(*batch)
    labels = torch.FloatTensor(labels)
    lengths = torch.LongTensor([len(sequence) for sequence in sequences])
    sequences = torch.nn.utils.rnn.pad_sequence(sequences,
                                                padding_value=pad_idx)
    masks = (sequences != pad_idx)
    return labels, sequences, lengths, masks
```

but this results in an error.
I'm not sure what would be a good way to go about doing this, since there's been a lot of changes in the API. |
@satyajitghana To avoid this error you must pass |
Motivation and summary of the current issues in torchtext

Based on the feedback from users, there are several issues existing in torchtext, including:

- The Field class couples tokenizer, vocabulary, split, batching and sampling, padding, and numericalization together. The current Field class works as a "black box", and users are confused about what's going on within the class. Instead, those components should be divided into several basic building blocks. This is more consistent with the PyTorch core library, which grants users the freedom to build models and pipelines with orthogonal components.
- The batching and iteration utilities (Iterator, Batch, splits) should be replaced by the corresponding modules in torch.utils.data.

New datasets in torchtext.experimental.datasets

We have re-written several datasets in torchtext.experimental.datasets using the new abstractions. The old versions of the datasets are still available in torchtext.datasets, and the new datasets are opt-in.

Case study for IMDB dataset
API for new datasets
To load the new datasets, simply call the dataset API, as follows:
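The code block that accompanied this sentence was lost in formatting; a minimal sketch of the call, assuming the experimental IMDB dataset, looks like:

```python
from torchtext.experimental.datasets import IMDB

# download/process IMDB and get map-style train/test datasets
train_dataset, test_dataset = IMDB()
```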
To specify a tokenizer:
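Again as a sketch (the tokenizer keyword follows the usage discussed earlier in this thread and requires the SpaCy English model; treat the exact argument name as illustrative):

```python
from torchtext.data.utils import get_tokenizer
from torchtext.experimental.datasets import IMDB

tokenizer = get_tokenizer("spacy")
train_dataset, test_dataset = IMDB(tokenizer=tokenizer)
```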
If you just need the test set (must pass a Vocab object!):
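A sketch of that call; the data_select keyword mirrors its use with PennTreebank earlier in this thread, and the vocab is assumed to have been built on the training split:

```python
from torchtext.experimental.datasets import IMDB

train_dataset = IMDB(data_select='train')[0]
vocab = train_dataset.get_vocab()

# build only the test split, reusing the vocab built on the training data
test_dataset = IMDB(data_select='test', vocab=vocab)[0]
```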
Legacy code

The old IMDB dataset is still available in the folder torchtext.datasets. You can use the legacy datasets as follows:
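A sketch of the legacy pattern for contrast (standard Field/BucketIterator usage; the specific field options and batch size are illustrative, not taken from the issue):

```python
import torch
from torchtext import data, datasets

TEXT = data.Field(tokenize='spacy')
LABEL = data.LabelField(dtype=torch.float)

train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)

TEXT.build_vocab(train_data, max_size=25_000)
LABEL.build_vocab(train_data)

train_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, test_data), batch_size=64)
```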
Difference

With the old pattern, users have to create a Field object including a specific tokenizer. In the new dataset API, users can pass a custom tokenizer directly to the dataset constructor. A custom tokenizer defines the method to convert a string to a list of tokens.

In the old dataset, the vocab object is associated with the Field class, which is not flexible enough to accept a pre-trained vocab object. In the new dataset, the vocab object can be obtained with get_vocab() (as in the sketch above) and applied to generate other new datasets.
The datasets with the new pattern return a tensor of token IDs, instead of tokens in the old pattern. If users would like to retrieve the tokens, simply use the following command:
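The command that followed is not preserved; a sketch of recovering tokens from the returned tensor of ids, assuming the dataset items are (label, tensor) pairs and the vocab exposes an itos list, could be:

```python
from torchtext.experimental.datasets import IMDB

train_dataset, test_dataset = IMDB()
vocab = train_dataset.get_vocab()

label, token_ids = train_dataset[0]                     # token_ids is a LongTensor of ids
tokens = [vocab.itos[int(idx)] for idx in token_ids]    # map ids back to tokens
```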
Unlike the old pattern using BucketIterator.splits, users are encouraged to use torch.utils.data.DataLoader to generate batches of data. You can specify how to batch and pad the samples with a custom function passed to collate_fn. Here is an example to pad sequences with similar lengths and load data through DataLoader (see the sketch below). To generate random samples, turn on the shuffle flag in DataLoader; otherwise, a sequential sampler will be automatically constructed. To randomly split a dataset into non-overlapping new datasets of given lengths, torch.utils.data.random_split can be used (also sketched below).
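The code that accompanied this paragraph did not survive the formatting. A sketch of both pieces (pad-and-batch via collate_fn, plus random_split), assuming the vocab contains a '<pad>' token as in the collate example earlier in the thread and using illustrative sizes:

```python
import torch
from torch.utils.data import DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence
from torchtext.experimental.datasets import IMDB

train_dataset, test_dataset = IMDB()
pad_idx = train_dataset.get_vocab()['<pad>']

def pad_batch(batch):
    labels, sequences = zip(*batch)
    labels = torch.tensor([int(label) for label in labels])
    # pad every sequence in the batch to the length of the longest one
    sequences = pad_sequence(sequences, padding_value=pad_idx)
    return labels, sequences

train_loader = DataLoader(train_dataset, batch_size=64,
                          shuffle=True, collate_fn=pad_batch)

# randomly split the training data into non-overlapping train/validation subsets
n_train = int(len(train_dataset) * 0.9)
train_subset, valid_subset = random_split(
    train_dataset, [n_train, len(train_dataset) - n_train])
```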
Reference:
A few recent issues from OSS users:
- split function is confusing (Using Google 1-billion benchmark data on PyTorch #644)
- vocab object based on a subset of text file (KeyError in vocab if the vocab is built on a subset of the initially read data #642)
- vocab object to build a dataset (TEXT.build_vocab for two datasets #648)
- torch.utils.data.DataLoader (How to prefetch data? #660)