feat(BucketIterator): add num_workers #708
Comments
We are adding the support of …
@zhangguanheng66 I'd say it would be a welcome contribution. I'm just not sure how worthwhile it is to implement a (larger) feature if the legacy part of torchtext will be deprecated. @harpone could you elaborate where?
uhh sorry about that, I think I commented on the wrong issue 😬 (first time that's happened)
@mttk If this is to be deprecated, I assume that means there's a new way to create data loaders in torchtext?

Background information that might be relevant to my confusion: with this toy dataset:

```python
class ToyDataset(Dataset):
    def __init__(self):
        self.data = [{"src": [0, 1, 2, ...], "tgt": ["hello", "world"]}, {...}, ...]
        self.fields = {
            "src": RawField(),
            "tgt": Field(init_token="<sos>", eos_token="<eos>", is_target=True)
        }

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Some augmentation happening here that takes some time
        return self.data[idx]
```

Here is what I have running on v0.5.0:

```python
sort_key = lambda x: len(x["src"])
train_iter = BucketIterator(train_dataset, batch_size=2, sort_key=sort_key, shuffle=True)
```

And the …

If instead, I use:

```python
train_iter = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
```

I get an error (for the …), and if I remove the …
Perfect! That is very helpful, thank you. So all I needed to do to make this work was write a collator:

```python
def text_data_collator(dataset: Dataset):
    def collate(data):
        batch = defaultdict(list)

        for datum in data:
            for name, field in dataset.fields.items():
                batch[name].append(field.preprocess(getattr(datum, name)))

        batch = {name: field.process(batch[name]) for name, field in dataset.fields.items()}

        return batch

    return collate
```

and replace the iterator with a `DataLoader`:

```python
collate = text_data_collator(train_dataset)
num_workers = multiprocessing.cpu_count()
train_iter = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate, num_workers=num_workers, shuffle=True)
```

Loading before an epoch went down from 17 seconds to 1-3 (on a 40 CPU server), so I'll consider this as fixed :)
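A rough sketch of how such a loader might then be consumed — `num_epochs` and the training step are placeholders, and the `"src"`/`"tgt"` keys simply follow the toy dataset above:

```python
for epoch in range(num_epochs):  # num_epochs is a placeholder
    for batch in train_iter:
        src = batch["src"]  # list of preprocessed sequences (RawField leaves them as-is)
        tgt = batch["tgt"]  # padded, numericalized tensor produced by Field.process
        # model forward/backward would go here
```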
🚀 Feature
Add `num_workers` to `BucketIterator`.
Motivation
While `torchtext` is designed for text, it is also the best thing to use for sequence data. I have sequence data in the form of sign language poses, which falls under the category of language, and I want to batch it the same way I would batch text - sorted by length.

For sign language, the dataset needs to handle data augmentation (my specific use case), and current data augmentation libraries like `imgaug` and `albumentations` are slow (see issues aleju/imgaug#635 and albumentations-team/albumentations#554). Therefore, using `num_workers` to augment a batch or many batches from the iterator would be a great help (instead of waiting 10 minutes, I would wait 15 seconds, with 40 CPUs).

Pitch
Add `num_workers` to `BucketIterator`.

Use `num_workers` to distribute the `Dataset.__getitem__` calls across all workers when iterating the `BucketIterator`.
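If this were added, the call site could look roughly like the line below — a hypothetical signature, since `BucketIterator` does not currently accept this argument:

```python
# Hypothetical: num_workers is the proposed argument, not an existing one.
train_iter = BucketIterator(train_dataset, batch_size=2, sort_key=sort_key,
                            shuffle=True, num_workers=multiprocessing.cpu_count())
```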
Alternatives
Writing my own implementation of a bucket iterator extending `DataLoader` (see the sketch after this list).

OR

Use a `DataLoader` with a batch size of 1.
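The first alternative could look roughly like the sketch below: a custom batch sampler that shuffles indices, sorts each pool by length, and hands length-homogeneous batches to a plain `DataLoader`, so `num_workers` parallelizes `Dataset.__getitem__`. Everything here (the class name, `bucket_size_multiplier`, reusing `text_data_collator` from the comments above) is an illustrative assumption, not existing torchtext code:

```python
import random

from torch.utils.data import Sampler


class BucketBatchSampler(Sampler):
    """Yields batches of indices whose examples have similar lengths."""

    def __init__(self, lengths, batch_size, bucket_size_multiplier=100):
        self.lengths = lengths
        self.batch_size = batch_size
        self.pool_size = batch_size * bucket_size_multiplier

    def __iter__(self):
        indices = list(range(len(self.lengths)))
        random.shuffle(indices)
        batches = []
        # Sort by length within each pool, then cut the pool into batches.
        for start in range(0, len(indices), self.pool_size):
            pool = sorted(indices[start:start + self.pool_size],
                          key=lambda i: self.lengths[i])
            for b in range(0, len(pool), self.batch_size):
                batches.append(pool[b:b + self.batch_size])
        random.shuffle(batches)  # randomize batch order, like BucketIterator
        yield from batches

    def __len__(self):
        return (len(self.lengths) + self.batch_size - 1) // self.batch_size


# Assumed usage, reusing names from the comments above (train_dataset,
# text_data_collator, batch_size):
# lengths = [len(example["src"]) for example in train_dataset.data]
# train_iter = DataLoader(train_dataset,
#                         batch_sampler=BucketBatchSampler(lengths, batch_size),
#                         collate_fn=text_data_collator(train_dataset),
#                         num_workers=multiprocessing.cpu_count())
```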
Other Info

Confirmation that `BucketIterator` doesn't support `num_workers`: #437