Support torchtext dataloaders #49
Could you provide more details about what exactly the issue is? I had a quick look at the dataloaders and datasets from that link and didn't notice anything that would make them unsupported.
Yeah, sure. I'd do something like this:

```python
import torch
import torch.nn as nn
from torchtext import data  # matches the traceback: torchtext.data.iterator.BucketIterator
from torch_lr_finder import LRFinder

# train, dev, test, model, batch_size and device are defined earlier
train_iter, dev_iter, test_iter = data.BucketIterator.splits(
    (train, dev, test),
    batch_size=batch_size,
    # batch_sizes=(100, 100, 100),
    sort_key=lambda x: len(x.text),
    sort_within_batch=True,
    device=device)

criterion = nn.BCELoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-7, weight_decay=1e-2)
lr_finder = LRFinder(model, optimizer, criterion, device="cpu")
lr_finder.range_test(train_iter, end_lr=100, num_iter=100)
```
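For reference, `range_test(..., end_lr=100, num_iter=100)` sweeps the learning rate upward. Assuming `start_lr` is left unset (so the sweep starts from the optimizer's `lr=1e-7`) and the default exponential `step_mode` interpolates geometrically, the sweep looks roughly like the sketch below. `exp_lr_schedule` is a hypothetical helper, not part of torch_lr_finder, and the library's exact step indexing may differ.

```python
from typing import List


def exp_lr_schedule(start_lr: float, end_lr: float, num_iter: int) -> List[float]:
    """Geometric learning-rate sweep from start_lr to end_lr (hypothetical helper)."""
    if num_iter < 2:
        raise ValueError("num_iter must be at least 2")
    ratio = end_lr / start_lr
    # Step i uses start_lr * ratio ** (i / (num_iter - 1)), so the first
    # step is start_lr and the last is end_lr.
    return [start_lr * ratio ** (i / (num_iter - 1)) for i in range(num_iter)]


# Values from the snippet above: optimizer lr=1e-7, end_lr=100, num_iter=100.
lrs = exp_lr_schedule(1e-7, 100, 100)
step_factor = lrs[1] / lrs[0]  # constant multiplier between steps, about 1.23
```

Covering nine orders of magnitude in 100 steps means each batch is trained at a learning rate roughly 23% higher than the previous one.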
```
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-4-07322a04bd5f> in <module>
3 optimizer = torch.optim.Adam(model.parameters(), lr=1e-7, weight_decay=1e-2)
4 lr_finder = LRFinder(model, optimizer, criterion, device="cpu")
----> 5 lr_finder.range_test(train_iter, end_lr=100, num_iter=100)
6 lr_finder.plot() # to inspect the loss-learning rate graph
7 lr_finder.reset() # to reset the model and optimizer to their initial state
.../python3.7/site-packages/torch_lr_finder/lr_finder.py in range_test(self, train_loader, val_loader, start_lr, end_lr, num_iter, step_mode, smooth_f, diverge_th, accumulation_steps, non_blocking_transfer)
264 "`train_loader` has unsupported type: {}."
265 "Expected types are `torch.utils.data.DataLoader`"
--> 266 "or child of `TrainDataLoaderIter`.".format(type(train_loader))
267 )
268
ValueError: `train_loader` has unsupported type: <class 'torchtext.data.iterator.BucketIterator'>.Expected types are `torch.utils.data.DataLoader`or child of `TrainDataLoaderIter`.
```

It seems the `BucketIterator` is rejected by a type check. Thanks for looking into this!
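The error message says `range_test` accepts a `torch.utils.data.DataLoader` or a child of `TrainDataLoaderIter`. If I read the library correctly, a subclass of `TrainDataLoaderIter` that overrides `inputs_labels_from_batch` is the intended hook for custom batch formats. The dependency-free sketch below is not the library's code; it only shows what such an adapter has to do for a `BucketIterator`: turn each batch into an `(inputs, labels)` pair and restart iteration when an epoch runs out. The class name and the `text`/`label` field names are hypothetical.

```python
from types import SimpleNamespace
from typing import Any, Iterable, Tuple


class AutoResetBatchIter:
    """Dependency-free stand-in for a TrainDataLoaderIter subclass.

    Yields (inputs, labels) pairs from any iterable of batches and restarts
    the iterable when it runs out, so num_iter may exceed one epoch.
    """

    def __init__(self, loader: Iterable[Any], auto_reset: bool = True) -> None:
        self.loader = loader
        self.auto_reset = auto_reset
        self._iterator = iter(loader)

    def inputs_labels_from_batch(self, batch: Any) -> Tuple[Any, Any]:
        # Assumption: torchtext Batch objects expose each field as an
        # attribute; `text` and `label` are hypothetical field names.
        return batch.text, batch.label

    def __iter__(self) -> "AutoResetBatchIter":
        return self

    def __next__(self) -> Tuple[Any, Any]:
        try:
            batch = next(self._iterator)
        except StopIteration:
            if not self.auto_reset:
                raise
            # Start a new pass over the data instead of stopping.
            self._iterator = iter(self.loader)
            batch = next(self._iterator)
        return self.inputs_labels_from_batch(batch)


# Usage with fake batches standing in for BucketIterator output.
fake_batches = [SimpleNamespace(text=[1, 2], label=0),
                SimpleNamespace(text=[3], label=1)]
it = AutoResetBatchIter(fake_batches)
pairs = [next(it) for _ in range(3)]  # third call wraps around to the first batch
```

With the real package, the equivalent would presumably be a `TrainDataLoaderIter` subclass with the same `inputs_labels_from_batch` body, wrapped around `train_iter` and passed to `range_test` in its place; check the pytorch-lr-finder README for the exact import path.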
@phiweger No. However, there are a few possible solutions for this problem.
Thanks @NaleRaphael. I ended up using a custom bucket iterator that plays nicely with the `DataLoader` class.
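One way to get bucketing while still producing a real `DataLoader` is a batch sampler that groups indices of similar length. The sketch below is a hypothetical `BucketBatchSampler` using a simplified scheme (sort all indices by length, chunk them, shuffle the chunk order); it is not torchtext's pooled bucketing algorithm and not necessarily what was used above. Within each batch it orders examples longest first, which is what `sort_within_batch=True` is typically used for with packed sequences.

```python
import random
from typing import Iterator, List, Sequence


class BucketBatchSampler:
    """Yields lists of dataset indices whose examples have similar lengths."""

    def __init__(self, lengths: Sequence[int], batch_size: int,
                 shuffle: bool = True, seed: int = 0) -> None:
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1")
        self.lengths = list(lengths)
        self.batch_size = batch_size
        self.shuffle = shuffle
        # The RNG state advances, so each epoch gets a different batch order.
        self.rng = random.Random(seed)

    def __iter__(self) -> Iterator[List[int]]:
        # Sort indices by example length, then cut into fixed-size chunks.
        order = sorted(range(len(self.lengths)), key=lambda i: self.lengths[i])
        batches = [order[i:i + self.batch_size]
                   for i in range(0, len(order), self.batch_size)]
        if self.shuffle:
            # Shuffle the order of batches, not their contents, so each
            # batch still holds examples of similar length.
            self.rng.shuffle(batches)
        for batch in batches:
            # Longest example first within each batch.
            yield sorted(batch, key=lambda i: self.lengths[i], reverse=True)

    def __len__(self) -> int:
        return (len(self.lengths) + self.batch_size - 1) // self.batch_size
```

Passing an instance as `batch_sampler=` to `torch.utils.data.DataLoader` (leaving `batch_size`, `shuffle` and `drop_last` at their defaults) together with a padding `collate_fn` gives a loader that is an actual `DataLoader` and so passes the type check; computing `lengths` and writing the collate function are left to the dataset at hand.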
See, e.g., https://github.com/pytorch/text.
Because these classes inherit from `DataLoader`, I think they should work out of the box, were it not for the type check.
Kind regards, and thanks for the awesome package!