feat(BucketIterator): add num_workers #708

Closed

AmitMY opened this issue Mar 15, 2020 · 6 comments

@AmitMY

AmitMY commented Mar 15, 2020

🚀 Feature

Add num_workers to BucketIterator

Motivation

While torchtext is designed for text, it is also the best thing to use for sequence data in general.
I have sequence data in the form of sign language poses, which falls under the category of language, and I want to batch it the same way I would batch text - sorted by length.

For sign language, the dataset needs to handle data augmentation (my specific use case), and current data augmentation libraries like imgaug and albumentations are slow (see aleju/imgaug#635 and albumentations-team/albumentations#554).
Therefore, being able to use num_workers to augment one or more batches from the iterator in parallel would be a great help (instead of waiting 10 minutes, I would wait about 15 seconds with 40 CPUs).

Pitch

Add num_workers to BucketIterator.
Use num_workers to distribute the Dataset.__getitem__ calls on all workers when iterating the BucketIterator.
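
As a rough illustration, a minimal sketch of what the call could look like (the num_workers argument is the hypothetical part; everything else matches the current BucketIterator signature):

from torchtext.data import BucketIterator

# Hypothetical usage -- num_workers does not exist on BucketIterator today.
train_iter = BucketIterator(
    train_dataset,
    batch_size=32,
    sort_key=lambda x: len(x["src"]),
    shuffle=True,
    num_workers=4,  # proposed: run Dataset.__getitem__ in 4 worker processes
)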

Alternatives

Writing my own implementation of a bucket iterator extending DataLoader.

OR

Use a DataLoader with batch size of 1

Other Info

Confirmation that BucketIterator doesn't support num_workers #437

@zhangguanheng66
Contributor

We are adding support for torch.utils.data to torchtext, and torch.utils.data.DataLoader comes with num_workers.
@bentrevett @mttk Do you guys think it's worth a PR for this?

@mttk
Contributor

mttk commented Mar 24, 2020

@zhangguanheng66 I'd say it would be a welcome contribution. I'm just not sure how worthwhile it is to implement a (larger) feature if the legacy part of torchtext is going to be deprecated.

@harpone could you elaborate where?

@harpone

harpone commented Mar 24, 2020

uhh sorry about that, I think I commented on the wrong issue 😬 (first time that's happened)

@AmitMY
Author

AmitMY commented Apr 13, 2020

@mttk If this is to be deprecated, I assume that means there is now a new way to create data loaders in torchtext?


Background information that might be relevant to my confusion:

With this toy dataset:

from torch.utils.data import Dataset
from torchtext.data import Field, RawField

class ToyDataset(Dataset):
    def __init__(self):
        self.data = [{"src": [0,1,2...], "tgt": ["hello", "world"]}, {....}, ...]

        self.fields = {
            "src": RawField(),
            "tgt": Field(init_token="<sos>", eos_token="<eos>", is_target=True)
        }

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Some augmentation happening here that takes some time
        return self.data[idx]

Here is what I have running on v0.5.0:

sort_key = lambda x: len(x["src"])
train_iter = BucketIterator(train_dataset, batch_size=2, sort_key=sort_key, shuffle=True)

And the tgt I get in the batch is a 4x2 tensor: one sentence with 2 tokens + eos + pad, and one sentence with 3 tokens + eos (all as indices into the vocabulary).

If instead, I use:

train_iter = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

I get an error (for the src):

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 28 and 26 in dimension 1

and if I remove src and keep only tgt, I get a different error (the tokens are not converted to vocabulary indices).
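
For context on that first error: DataLoader's default collate_fn simply stacks the per-example tensors, so sequences whose lengths differ in dimension 1 cannot be combined. A minimal sketch of a padding collate_fn (assuming, hypothetically, that each example's src is already a sequence of token indices) would be:

import torch
from torch.nn.utils.rnn import pad_sequence

def pad_collate(examples):
    # Each example is assumed to be a dict with a variable-length "src" sequence.
    srcs = [torch.as_tensor(ex["src"]) for ex in examples]
    # pad_sequence pads every sequence to the longest one -> shape (max_len, batch_size)
    return pad_sequence(srcs, padding_value=0)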

@zhangguanheng66
Contributor

@AmitMY The new dataset abstractions in #664 and #701 are compatible with DataLoader, including its multiprocessing support.

@AmitMY
Author

AmitMY commented Apr 13, 2020

Perfect! That is very helpful, thank you.

So all I needed to do to make this work was write a collator:

from collections import defaultdict

from torch.utils.data import Dataset


def text_data_collator(dataset: Dataset):
    def collate(data):
        batch = defaultdict(list)

        # Preprocess each example field by field (tokenization etc.)
        for datum in data:
            for name, field in dataset.fields.items():
                batch[name].append(field.preprocess(getattr(datum, name)))

        # field.process pads/numericalizes the whole batch for that field
        batch = {name: field.process(batch[name]) for name, field in dataset.fields.items()}

        return batch

    return collate

and replace the iterator with a DataLoader:

import multiprocessing
from torch.utils.data import DataLoader

collate = text_data_collator(train_dataset)
num_workers = multiprocessing.cpu_count()
train_iter = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate, num_workers=num_workers, shuffle=True)
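
A minimal sketch of consuming the resulting iterator (field names src/tgt taken from the toy dataset above):

# Each batch is the dict produced by the collator above.
for batch in train_iter:
    src, tgt = batch["src"], batch["tgt"]  # tgt: (max_len, batch_size) tensor of vocab indices
    # ... model forward pass, loss, backward, etc.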

Loading before an epoch went down from 17 seconds to 1-3 seconds (on a 40-CPU server), so I'll consider this fixed :)

AmitMY closed this as completed Apr 13, 2020