
Support torchtext dataloaders #49

Closed
phiweger opened this issue Jun 12, 2020 · 4 comments

Comments

@phiweger

See e.g. https://github.com/pytorch/text.

Because these classes inherit from DataLoader, I think they should work out of the box were it not for the type check.

Kind regards, and thanks for the awesome package!

@davidtvs
Owner

Could you provide more details as to what exactly the issue is? I had a quick look into the dataloaders and datasets from that link and didn't notice anything that would make them unsupported.

@phiweger
Author

yeah sure, so I'd do something like

# (imports added for context; model, train/dev/test, batch_size and device are defined elsewhere)
import torch
import torch.nn as nn
from torchtext import data

train_iter, dev_iter, test_iter = data.BucketIterator.splits(
    (train, dev, test),
    batch_size=batch_size,
    # batch_sizes=(100, 100, 100),
    sort_key=lambda x: len(x.text),
    sort_within_batch=True,
    device=device)

criterion = nn.BCELoss().to(device)

from torch_lr_finder import LRFinder

optimizer = torch.optim.Adam(model.parameters(), lr=1e-7, weight_decay=1e-2)
lr_finder = LRFinder(model, optimizer, criterion, device="cpu")
lr_finder.range_test(train_iter, end_lr=100, num_iter=100)

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-4-07322a04bd5f> in <module>
      3 optimizer = torch.optim.Adam(model.parameters(), lr=1e-7, weight_decay=1e-2)
      4 lr_finder = LRFinder(model, optimizer, criterion, device="cpu")
----> 5 lr_finder.range_test(train_iter, end_lr=100, num_iter=100)
      6 lr_finder.plot() # to inspect the loss-learning rate graph
      7 lr_finder.reset() # to reset the model and optimizer to their initial state

.../python3.7/site-packages/torch_lr_finder/lr_finder.py in range_test(self, train_loader, val_loader, start_lr, end_lr, num_iter, step_mode, smooth_f, diverge_th, accumulation_steps, non_blocking_transfer)
    264                 "`train_loader` has unsupported type: {}."
    265                 "Expected types are `torch.utils.data.DataLoader`"
--> 266                 "or child of `TrainDataLoaderIter`.".format(type(train_loader))
    267             )
    268

ValueError: `train_loader` has unsupported type: <class 'torchtext.data.iterator.BucketIterator'>.Expected types are `torch.utils.data.DataLoader`or child of `TrainDataLoaderIter`.

It seems like the BucketIterator fails due to some type check.

Thanks for looking into this!

@NaleRaphael
Contributor

NaleRaphael commented Jun 14, 2020

@phiweger No, torchtext.data.iterator.BucketIterator is not a subclass of DataLoader; it inherits from the custom class torchtext.data.Iterator instead. And unfortunately, I found this issue stating that some custom modules (e.g. Iterator, Batch, ...) in torchtext < v0.5.0 are incompatible with the PyTorch core library. That means we cannot simply wrap the values returned by data.BucketIterator.splits() in a DataLoader. It's also not possible to make LRFinder work by just adding torchtext.data.iterator.BucketIterator to the type-checking statement, because the two architectures are different.
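If you want to double-check the class hierarchy yourself, a quick illustrative snippet (assuming torchtext < 0.5.0 is installed) would be:

from torch.utils.data import DataLoader
from torchtext.data.iterator import BucketIterator

print(issubclass(BucketIterator, DataLoader))  # False
print(BucketIterator.__mro__)                  # goes through torchtext.data.iterator.Iterator, not DataLoader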

However, there are a few possible solutions for this problem.

  1. As also stated in torchtext issue #664, there is a new dataset API which is compatible with DataLoader. There are still some details that need to be handled (e.g. the definition of collate_fn, ...); this example (text_classification/train.py) may help you.

    from torch.utils.data import DataLoader
    from torchtext.data.utils import get_tokenizer
    # IMDB here is the new-style dataset mentioned in torchtext issue #664
    # (its import path depends on the torchtext version)

    tokenizer = get_tokenizer("spacy")
    train_dataset, test_dataset = IMDB(tokenizer=tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=..., collate_fn=...)
  2. Use torchtext.datasets.[DATASET_NAME].splits() to split the dataset instead, then wrap the split datasets in a DataLoader. As in the previous solution, some configuration (e.g. a collate_fn) is still needed; see the sketch after this list.

    import torchtext as tt
    from torch.utils.data import Dataset, DataLoader
    from torchtext.vocab import GloVe

    f_text = tt.data.Field(lower=True, include_lengths=True, batch_first=True)
    f_label = tt.data.Field(sequential=False)

    # split dataset
    train_set, test_set = tt.datasets.IMDB.splits(f_text, f_label, root='./data')

    f_text.build_vocab(train_set, vectors=GloVe(name='6B', dim=300))
    f_label.build_vocab(train_set)

    # build data loaders
    train_loader = DataLoader(train_set, batch_size=..., collate_fn=...)
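Both snippets above leave collate_fn as a placeholder. Just as a rough, untested sketch for the second solution (it assumes the f_text / f_label fields built above and the default '<pad>' token; the function name collate_batch is made up here), a padding collate_fn could look something like this:

import torch

def collate_batch(batch):
    # after IMDB.splits(), each example has .text (a list of tokens) and .label (a string)
    texts = [torch.tensor([f_text.vocab.stoi[tok] for tok in ex.text]) for ex in batch]
    labels = torch.tensor([f_label.vocab.stoi[ex.label] for ex in batch])
    # pad all token sequences in the batch to the same length
    padded = torch.nn.utils.rnn.pad_sequence(
        texts, batch_first=True, padding_value=f_text.vocab.stoi['<pad>'])
    return padded, labels

train_loader = DataLoader(train_set, batch_size=32, collate_fn=collate_batch)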

@phiweger
Author

Thanks @NaleRaphael -- I ended up using a custom bucket iterator which plays nicely with the DataLoader class.
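That custom iterator isn't shown in the thread, but for reference, one DataLoader-compatible way to get bucket-style batching is a custom batch_sampler that groups examples of similar length. Everything below (class name, the lengths list, my_dataset, my_collate) is illustrative, not phiweger's actual code:

import random
from torch.utils.data import DataLoader

class BucketBatchSampler:
    # Groups examples of similar length into batches (rough sketch).
    def __init__(self, lengths, batch_size):
        self.lengths = lengths
        self.batch_size = batch_size

    def __iter__(self):
        # sort indices by length, cut into consecutive batches, then shuffle the batches
        order = sorted(range(len(self.lengths)), key=lambda i: self.lengths[i])
        batches = [order[i:i + self.batch_size]
                   for i in range(0, len(order), self.batch_size)]
        random.shuffle(batches)
        yield from batches

    def __len__(self):
        return (len(self.lengths) + self.batch_size - 1) // self.batch_size

# usage (hypothetical dataset and collate_fn):
# loader = DataLoader(my_dataset,
#                     batch_sampler=BucketBatchSampler(lengths, batch_size=32),
#                     collate_fn=my_collate)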
