Add recommendations regarding use of datapipes for multi-processing, shuffling, DDP, etc. #1755

Merged
merged 9 commits on Jun 2, 2022
55 changes: 55 additions & 0 deletions docs/source/datasets.rst
@@ -3,6 +3,61 @@ torchtext.datasets

.. currentmodule:: torchtext.datasets


.. _datapipes_warnings:

.. warning::

The datasets supported by torchtext are datapipes from the `torchdata
project <https://pytorch.org/data/beta/index.html>`_, which is still in Beta
status. This means that the API is subject to change without deprecation
cycles. In particular, we expect a lot of the current idioms to change with
the eventual release of ``DataLoaderV2`` from ``torchdata``.

Here are a few recommendations regarding the use of datapipes:

- For shuffling the datapipe, do that in the DataLoader: ``DataLoader(dp, shuffle=True)``.
You do not need to call ``dp.shuffle()``, because ``torchtext`` has
already done that for you. Note however that the datapipe won't be
shuffled unless you explicitly pass ``shuffle=True`` to the DataLoader.

- When using multi-processing (``num_workers=N``), use the builtin ``worker_init_fn``::

from torch.utils.data.backward_compatibility import worker_init_fn
DataLoader(dp, num_workers=4, worker_init_fn=worker_init_fn, drop_last=True)

This will ensure that data isn't duplicated across workers.

- We also recommend using ``drop_last=True``. Without this, the batch sizes
at the end of an epoch may be very small in some cases (smaller than with
other map-style datasets). This might greatly affect accuracy, especially
when batch norm is used. ``drop_last=True`` ensures that all batch sizes
are equal.

- Distributed training with ``DistributedDataParallel`` is not yet entirely
stable / supported, and we don't recommend it at this point. It will be
better supported in DataLoaderV2. If you still wish to use DDP, make sure
that:

- All workers (DDP workers *and* DataLoader workers) see a different part
of the data. The datasets are already wrapped inside `ShardingFilter
<https://pytorch.org/data/main/generated/torchdata.datapipes.iter.ShardingFilter.html>`_
and you may need to call ``dp.apply_sharding(num_shards, shard_id)`` in order to shard the
data across ranks (DDP workers) and DataLoader workers. One way to do this
is to create a ``worker_init_fn`` that calls ``apply_sharding`` with the appropriate
number of shards (DDP workers * DataLoader workers) and shard id (inferred through the rank
and the worker ID of the corresponding DataLoader within the rank); a sketch of such a
``worker_init_fn`` appears after this list. Note, however, that this assumes an
equal number of DataLoader workers across all ranks.
Contributor

@NicolasHug, @VitalyFedyunin, @ejguan: I made the description rather high level while leaving implementation details to the users. My thinking here is to make users aware of the pitfalls without necessarily providing specific recommendations on how to implement a particular solution.

- All DDP workers work on the same number of batches. One way to do this
is to limit the size of the datapipe within each worker to
``len(datapipe) // num_ddp_workers``, but this might not suit all
use-cases.
- The shuffling seed is the same across all workers. You might need to
call ``torch.utils.data.graph_settings.apply_shuffle_seed(dp, rng)``.
- The shuffling seed is different across epochs.
- The rest of the RNG (typically used for transformations) is
**different** across workers, for maximal entropy and optimal accuracy.
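
A minimal sketch of such a combined ``worker_init_fn`` follows. ``ddp_worker_init_fn`` is a
hypothetical helper written for illustration only; it assumes ``torch.distributed`` is already
initialized, that every rank uses the same number of DataLoader workers, and that the datapipe
exposes ``apply_sharding`` as described above::

    import torch.distributed as dist
    from torch.utils.data import DataLoader, get_worker_info

    def ddp_worker_init_fn(worker_id):
        # Shard across (DDP ranks x DataLoader workers) so that every worker
        # in every rank sees a disjoint part of the data.
        info = get_worker_info()
        num_dl_workers = info.num_workers
        num_shards = dist.get_world_size() * num_dl_workers
        shard_id = dist.get_rank() * num_dl_workers + worker_id
        # info.dataset is this worker's copy of the datapipe.
        info.dataset.apply_sharding(num_shards, shard_id)

    dl = DataLoader(dp, num_workers=4, worker_init_fn=ddp_worker_init_fn,
                    shuffle=True, drop_last=True)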

Contributor

If the number of batches fetched across different ranks is different, it could potentially stall the training, right? So we somehow need to ensure that the number of samples seen by each rank is the same. I wonder what the workarounds for this are?

Member Author

You're right, it gets more complex real fast :)

In torchvision we have this custom Taker datapipe that limits the size of the dp so that it's consistent across DDP workers: https://github.com/pytorch/vision/blob/main/torchvision/prototype/datasets/utils/_internal.py#L144-L145

I'll mention this as well, but I think we'd rather not directly link to this torchvision snippet, as this is definitely not a recommended practice yet. WDYT?
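
For readers following this thread, here is a rough, hypothetical sketch of what such a size-limiting datapipe could look like (written for illustration only; this is not the torchvision implementation):

from torchdata.datapipes.iter import IterDataPipe

class Taker(IterDataPipe):
    # Hypothetical sketch: yield at most `num_take` items from the source
    # datapipe, so that every DDP rank iterates over the same number of samples.
    def __init__(self, source_dp, num_take):
        self.source_dp = source_dp
        self.num_take = num_take

    def __iter__(self):
        for i, item in enumerate(self.source_dp):
            if i >= self.num_take:
                break
            yield item

    def __len__(self):
        return self.num_take

# Hypothetical usage: cap every rank at the same number of samples.
# num_take = len(dp) // torch.distributed.get_world_size()
# dp = Taker(dp, num_take)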

Member

It's not unusual to have a different number of samples in each rank. For example, if the datapipe expects the same number of tokens/characters for each batch, and each sentence has a different number of tokens/characters, then the batch size will be different. As long as the loss value is normalized in a correct way, the training should be fine.

Member Author

Hm... Interesting. My current understanding is that having datapipes with different sizes in DDP will lead to either crashes or infinite hanging.

Here's a snippet to illustrate this:

import torch
import torch.utils.data as data
import torch.distributed as dist

def replace_print():
    # Wrap the builtin print so that output is prefixed with the DDP rank
    # and the ranks print one after another (the barrier keeps them in sync).
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        for rank in range(dist.get_world_size()):
            if rank == dist.get_rank():
                builtin_print(f"[DDP worker with rank={rank}]", *args, **kwargs)
            dist.barrier()

    __builtin__.print = print


class MyIterableDS(data.IterableDataset):

    def __init__(self, size=100):
        self.size = size

    def __iter__(self):
        # Naive sharding: each sample index is assigned to exactly one
        # (DDP rank, DataLoader worker) pair.
        worker_info = data.get_worker_info()
        num_dl_workers = worker_info.num_workers
        dl_worker_id = worker_info.id

        num_ddp_workers = dist.get_world_size()
        ddp_worker_id = dist.get_rank()

        for i, s in enumerate(range(self.size)):
            if i % num_ddp_workers == ddp_worker_id:
                if i % num_dl_workers == dl_worker_id:
                    yield s

        # EXTRA SAMPLE
        # Uncomment this and you'll get an error: rank 0 then yields one extra
        # sample, so the ranks no longer see the same number of batches.
        # if ddp_worker_id == 0:
        #     yield 100

    def __len__(self):
        return self.size


dist.init_process_group(backend="gloo")
replace_print()
dist.barrier()

ds = MyIterableDS()
dl = torch.utils.data.DataLoader(ds, batch_size=10, num_workers=4)

for i, batch in enumerate(dl):
    print(batch)

You can run this locally (without GPUs) with e.g. torchrun --nproc_per_node=4 script.py.
Uncommenting the EXTRA SAMPLE part, I get the following error:

Traceback (most recent call last):
  File "/home/nicolashug/dev/vision/lol.py", line 55, in <module>
    print(batch)
  File "/home/nicolashug/dev/vision/lol.py", line 12, in print
    dist.barrier()
  File "/home/nicolashug/.miniconda3/envs/pt/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 2785, in barrier
    work.wait()
RuntimeError: [/opt/conda/conda-bld/pytorch_1649142626512/work/third_party/gloo/gloo/transport/tcp/pair.cc:598] Connection closed by peer [fe80::f9d3:3540:64c7:287b]:34704
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 241787) of binary: /home/nicolashug/.miniconda3/envs/pt/bin/python

Did you observe similar behaviour, @nateanl?

Contributor

IIRC, the total number of batches should be the same, but the batch size can vary across ranks.

Member

@NicolasHug Oh, you are referring to the number of batches in each rank. That's correct, DDP will hang forever if the number of batches is not the same across all ranks. I was talking about the batch_size, and it can be different :)

Member Author

Thanks all, I added a section about all this just above

  • All DDP workers work on the same number of batches. One way to do this is to limit the size of the datapipe within each worker to len(datapipe) // num_ddp_workers, but this might not suit all use-cases.

Contributor

yupp, SGTM!

Contributor

In torchvision we have this custom Taker datapipe that limits the size of the dp so that it's consistent across DDP workers: https://github.com/pytorch/vision/blob/main/torchvision/prototype/datasets/utils/_internal.py#L144-L145

Thanks for the code pointer! Sounds interesting :). I guess the user would need to know the length of the dataset in advance in order to ensure that num_take doesn't exceed the size.

I'll mention this as well, but I think we'd rather not directly link to this torchvision snippet, as this is definitely not a recommended practice yet. WDYT?

I think I agree! We should probably just make sure users are aware of this issue, which you have already addressed here: #1755 (comment)

General use cases are as follows: ::
Contributor

nit: Would it be useful to have a short code snippet using a dataset with DataLoader (with shuffle=True, worker_init_fn)?

I do think the current examples are likely sufficient but I wonder if some users need more hand-holding/guidance.

Contributor

Yupp, I guess it would be good to add usage examples with DataLoader. I guess we can remove worker_init_fn from the code examples, as @VitalyFedyunin is working on a PR that would automatically apply this to all datapipes-based datasets :).
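
As a reference for such an example, here is a minimal, hypothetical sketch using AG_NEWS with the current worker_init_fn idiom (the batch size is arbitrary; tokenization and custom collation are omitted):

from torch.utils.data import DataLoader
from torch.utils.data.backward_compatibility import worker_init_fn
from torchtext.datasets import AG_NEWS

# AG_NEWS yields (label, text) pairs; the default collate turns the labels
# into a tensor and leaves the raw text lines as a list of strings.
train_dp = AG_NEWS(split="train")
train_loader = DataLoader(
    train_dp,
    batch_size=8,
    shuffle=True,                   # enables the shuffle already built into the datapipe
    num_workers=4,
    worker_init_fn=worker_init_fn,  # avoids duplicating data across DataLoader workers
    drop_last=True,                 # keeps batch sizes equal at the end of an epoch
)

for labels, lines in train_loader:
    ...  # tokenize/transform here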



6 changes: 5 additions & 1 deletion examples/tutorials/sst2_classification_non_distributed.py
@@ -85,7 +85,11 @@
# and transforms. Below, we demonstrate how to use text and label processing transforms to pre-process the
# SST-2 dataset.
#
#
# .. note::
# Using datapipes is still currently subject to a few caveats. If you wish
# to extend this example to include shuffling, multi-processing, or
# distributed learning, please see :ref:`this note <datapipes_warnings>`
# for further instructions.

from torchtext.datasets import SST2

7 changes: 7 additions & 0 deletions torchtext/datasets/ag_news.py
@@ -43,6 +43,13 @@ def _modify_res(t):
def AG_NEWS(root: str, split: Union[Tuple[str], str]):
"""AG_NEWS Dataset

.. warning::

Using datapipes is still currently subject to a few caveats. If you wish
to use this dataset with shuffling, multi-processing, or distributed
learning, please see :ref:`this note <datapipes_warnings>` for further
instructions.

For additional details refer to https://paperswithcode.com/dataset/ag-news

Number of lines per split:
7 changes: 7 additions & 0 deletions torchtext/datasets/amazonreviewfull.py
@@ -57,6 +57,13 @@ def _modify_res(t):
def AmazonReviewFull(root: str, split: Union[Tuple[str], str]):
"""AmazonReviewFull Dataset

.. warning::

Using datapipes is still currently subject to a few caveats. If you wish
to use this dataset with shuffling, multi-processing, or distributed
learning, please see :ref:`this note <datapipes_warnings>` for further
instructions.

For additional details refer to https://arxiv.org/abs/1509.01626

Number of lines per split:
7 changes: 7 additions & 0 deletions torchtext/datasets/amazonreviewpolarity.py
@@ -53,6 +53,13 @@ def _modify_res(t):
def AmazonReviewPolarity(root: str, split: Union[Tuple[str], str]):
"""AmazonReviewPolarity Dataset

.. warning::

Using datapipes is still currently subject to a few caveats. If you wish
to use this dataset with shuffling, multi-processing, or distributed
learning, please see :ref:`this note <datapipes_warnings>` for further
instructions.

For additional details refer to https://arxiv.org/abs/1509.01626

Number of lines per split:
7 changes: 7 additions & 0 deletions torchtext/datasets/cc100.py
@@ -152,6 +152,13 @@ def _modify_res(language_code, x):
def CC100(root: str, language_code: str = "en"):
"""CC100 Dataset

.. warning::

Using datapipes is still currently subject to a few caveats. If you wish
to use this dataset with shuffling, multi-processing, or distributed
learning, please see :ref:`this note <datapipes_warnings>` for further
instructions.

For additional details refer to https://data.statmt.org/cc-100/

Args:
7 changes: 7 additions & 0 deletions torchtext/datasets/cola.py
@@ -52,6 +52,13 @@ def _filter_res(x):
def CoLA(root: str, split: Union[Tuple[str], str]):
"""CoLA dataset

.. warning::

Using datapipes is still currently subject to a few caveats. If you wish
to use this dataset with shuffling, multi-processing, or distributed
learning, please see :ref:`this note <datapipes_warnings>` for further
instructions.

For additional details refer to https://nyu-mll.github.io/CoLA/

Number of lines per split:
7 changes: 7 additions & 0 deletions torchtext/datasets/conll2000chunking.py
@@ -45,6 +45,13 @@ def _extracted_filepath_fn(root, split, _=None):
def CoNLL2000Chunking(root: str, split: Union[Tuple[str], str]):
"""CoNLL2000Chunking Dataset

.. warning::

Using datapipes is still currently subject to a few caveats. If you wish
to use this dataset with shuffling, multi-processing, or distributed
learning, please see :ref:`this note <datapipes_warnings>` for further
instructions.

For additional details refer to https://www.clips.uantwerpen.be/conll2000/chunking/

Number of lines per split:
7 changes: 7 additions & 0 deletions torchtext/datasets/dbpedia.py
@@ -52,6 +52,13 @@ def _modify_res(t):
def DBpedia(root: str, split: Union[Tuple[str], str]):
"""DBpedia Dataset

.. warning::

Using datapipes is still currently subject to a few caveats. If you wish
to use this dataset with shuffling, multi-processing, or distributed
learning, please see :ref:`this note <datapipes_warnings>` for further
instructions.

For additional details refer to https://www.dbpedia.org/resources/latest-core/

Number of lines per split:
7 changes: 7 additions & 0 deletions torchtext/datasets/enwik9.py
@@ -31,6 +31,13 @@ def _extracted_filepath_fn(root, _=None):
def EnWik9(root: str):
"""EnWik9 dataset

.. warning::

Using datapipes is still currently subject to a few caveats. If you wish
to use this dataset with shuffling, multi-processing, or distributed
learning, please see :ref:`this note <datapipes_warnings>` for further
instructions.

For additional details refer to http://mattmahoney.net/dc/textdata.html

Number of lines in dataset: 13147026
7 changes: 7 additions & 0 deletions torchtext/datasets/imdb.py
@@ -65,6 +65,13 @@ def filter_imdb_data(key, fname):
def IMDB(root: str, split: Union[Tuple[str], str]):
"""IMDB Dataset

.. warning::

Using datapipes is still currently subject to a few caveats. If you wish
to use this dataset with shuffling, multi-processing, or distributed
learning, please see :ref:`this note <datapipes_warnings>` for further
instructions.

For additional details refer to http://ai.stanford.edu/~amaas/data/sentiment/

Number of lines per split:
7 changes: 7 additions & 0 deletions torchtext/datasets/iwslt2016.py
@@ -171,6 +171,13 @@ def IWSLT2016(
):
"""IWSLT2016 dataset

.. warning::

Using datapipes is still currently subject to a few caveats. If you wish
to use this dataset with shuffling, multi-processing, or distributed
learning, please see :ref:`this note <datapipes_warnings>` for further
instructions.

For additional details refer to https://wit3.fbk.eu/2016-01

The available datasets include following:
7 changes: 7 additions & 0 deletions torchtext/datasets/iwslt2017.py
@@ -140,6 +140,13 @@ def _inner_iwslt_tar_filepath_fn(inner_iwslt_tar, _=None):
def IWSLT2017(root=".data", split=("train", "valid", "test"), language_pair=("de", "en")):
"""IWSLT2017 dataset

.. warning::

Using datapipes is still currently subject to a few caveats. If you wish
to use this dataset with shuffling, multi-processing, or distributed
learning, please see :ref:`this note <datapipes_warnings>` for further
instructions.

For additional details refer to https://wit3.fbk.eu/2017-01

The available datasets include following:
7 changes: 7 additions & 0 deletions torchtext/datasets/mnli.py
@@ -65,6 +65,13 @@ def _modify_res(x):
def MNLI(root, split):
"""MNLI Dataset

.. warning::

Using datapipes is still currently subject to a few caveats. If you wish
to use this dataset with shuffling, multi-processing, or distributed
learning, please see :ref:`this note <datapipes_warnings>` for further
instructions.

For additional details refer to https://cims.nyu.edu/~sbowman/multinli/

Number of lines per split:
7 changes: 7 additions & 0 deletions torchtext/datasets/mrpc.py
@@ -45,6 +45,13 @@ def _modify_res(x):
def MRPC(root: str, split: Union[Tuple[str], str]):
"""MRPC Dataset

.. warning::

Using datapipes is still currently subject to a few caveats. If you wish
to use this dataset with shuffling, multi-processing, or distributed
learning, please see :ref:`this note <datapipes_warnings>` for further
instructions.

For additional details refer to https://www.microsoft.com/en-us/download/details.aspx?id=52398

Number of lines per split:
7 changes: 7 additions & 0 deletions torchtext/datasets/multi30k.py
@@ -56,6 +56,13 @@ def _filter_fn(split, language_pair, i, x):
def Multi30k(root: str, split: Union[Tuple[str], str], language_pair: Tuple[str] = ("de", "en")):
"""Multi30k dataset

.. warning::

Using datapipes is still currently subject to a few caveats. If you wish
to use this dataset with shuffling, multi-processing, or distributed
learning, please see :ref:`this note <datapipes_warnings>` for further
instructions.

For additional details refer to https://www.statmt.org/wmt16/multimodal-task.html#task1

Number of lines per split:
7 changes: 7 additions & 0 deletions torchtext/datasets/penntreebank.py
@@ -46,6 +46,13 @@ def _modify_res(t):
def PennTreebank(root, split: Union[Tuple[str], str]):
"""PennTreebank Dataset

.. warning::

Using datapipes is still currently subject to a few caveats. If you wish
to use this dataset with shuffling, multi-processing, or distributed
learning, please see :ref:`this note <datapipes_warnings>` for further
instructions.

For additional details refer to https://catalog.ldc.upenn.edu/docs/LDC95T7/cl93.html

Number of lines per split:
8 changes: 8 additions & 0 deletions torchtext/datasets/qqp.py
@@ -30,6 +30,14 @@ def _modify_res(x):
@_create_dataset_directory(dataset_name=DATASET_NAME)
def QQP(root: str):
"""QQP dataset

.. warning::

Using datapipes is still currently subject to a few caveats. If you wish
to use this dataset with shuffling, multi-processing, or distributed
learning, please see :ref:`this note <datapipes_warnings>` for further
instructions.

For additional details refer to https://quoradata.quora.com/First-Quora-Dataset-Release-Question-Pairs

Args:
7 changes: 7 additions & 0 deletions torchtext/datasets/sogounews.py
@@ -57,6 +57,13 @@ def _modify_res(t):
def SogouNews(root: str, split: Union[Tuple[str], str]):
"""SogouNews Dataset

.. warning::

Using datapipes is still currently subject to a few caveats. If you wish
to use this dataset with shuffling, multi-processing, or distributed
learning, please see :ref:`this note <datapipes_warnings>` for further
instructions.

For additional details refer to https://arxiv.org/abs/1509.01626

Number of lines per split:
7 changes: 7 additions & 0 deletions torchtext/datasets/squad1.py
@@ -40,6 +40,13 @@ def _filepath_fn(root, split, _=None):
def SQuAD1(root: str, split: Union[Tuple[str], str]):
"""SQuAD1 Dataset

.. warning::

Using datapipes is still currently subject to a few caveats. If you wish
to use this dataset with shuffling, multi-processing, or distributed
learning, please see :ref:`this note <datapipes_warnings>` for further
instructions.

For additional details refer to https://rajpurkar.github.io/SQuAD-explorer/

Number of lines per split:
7 changes: 7 additions & 0 deletions torchtext/datasets/squad2.py
@@ -40,6 +40,13 @@ def _filepath_fn(root, split, _=None):
def SQuAD2(root: str, split: Union[Tuple[str], str]):
"""SQuAD2 Dataset

.. warning::

Using datapipes is still currently subject to a few caveats. If you wish
to use this dataset with shuffling, multi-processing, or distributed
learning, please see :ref:`this note <datapipes_warnings>` for further
instructions.

For additional details refer to https://rajpurkar.github.io/SQuAD-explorer/

Number of lines per split:
7 changes: 7 additions & 0 deletions torchtext/datasets/sst2.py
@@ -62,6 +62,13 @@ def _modify_res(t):
def SST2(root, split):
"""SST2 Dataset

.. warning::

Using datapipes is still currently subject to a few caveats. If you wish
to use this dataset with shuffling, multi-processing, or distributed
learning, please see :ref:`this note <datapipes_warnings>` for further
instructions.

For additional details refer to https://nlp.stanford.edu/sentiment/

Number of lines per split:
7 changes: 7 additions & 0 deletions torchtext/datasets/stsb.py
@@ -58,6 +58,13 @@ def _modify_res(x):
def STSB(root, split):
"""STSB Dataset

.. warning::

Using datapipes is still currently subject to a few caveats. If you wish
to use this dataset with shuffling, multi-processing, or distributed
learning, please see :ref:`this note <datapipes_warnings>` for further
instructions.

For additional details refer to https://ixa2.si.ehu.eus/stswiki/index.php/STSbenchmark

Number of lines per split:
7 changes: 7 additions & 0 deletions torchtext/datasets/udpos.py
@@ -45,6 +45,13 @@ def _filter_fn(split, x):
def UDPOS(root: str, split: Union[Tuple[str], str]):
"""UDPOS Dataset

.. warning::

Using datapipes is still currently subject to a few caveats. If you wish
to use this dataset with shuffling, multi-processing, or distributed
learning, please see :ref:`this note <datapipes_warnings>` for further
instructions.

Number of lines per split:
- train: 12543
- valid: 2002
7 changes: 7 additions & 0 deletions torchtext/datasets/wikitext103.py
@@ -48,6 +48,13 @@ def _filter_fn(split, x):
def WikiText103(root: str, split: Union[Tuple[str], str]):
"""WikiText103 Dataset

.. warning::

Using datapipes is still currently subject to a few caveats. If you wish
to use this dataset with shuffling, multi-processing, or distributed
learning, please see :ref:`this note <datapipes_warnings>` for further
instructions.

For additional details refer to https://blog.salesforceairesearch.com/the-wikitext-long-term-dependency-language-modeling-dataset/

Number of lines per split: