Data is duplicated when setting num_workers > 1 with streaming data #3423

Closed
cloudyuyuyu opened this issue Dec 13, 2021 · 14 comments
Labels: bug, streaming

Comments

cloudyuyuyu commented Dec 13, 2021

Describe the bug

The data is repeated num_workers times when we call load_dataset with streaming=True and set num_workers > 1 when constructing the DataLoader.

Steps to reproduce the bug

# Sample code to reproduce the bug
import pandas as pd
import numpy as np
import os

from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import shutil

NUM_OF_USER = 1000000
NUM_OF_ACTION = 50000
NUM_OF_SEQUENCE = 10000
NUM_OF_FILES = 32
NUM_OF_WORKERS = 16

if __name__ == "__main__":
    shutil.rmtree("./dataset")
    for i in range(NUM_OF_FILES):
        sequence_data = pd.DataFrame(
            {
                "imei": np.random.randint(1, NUM_OF_USER, size=NUM_OF_SEQUENCE),
                "sequence": np.random.randint(1, NUM_OF_ACTION, size=NUM_OF_SEQUENCE)
            }
        )

        if not os.path.exists("./dataset"):
            os.makedirs("./dataset")

        sequence_data.to_csv(f"./dataset/sequence_data_{i}.csv", index=False)

    dataset = load_dataset("csv",
                           data_files=[os.path.join("./dataset",file) for file in os.listdir("./dataset") if file.endswith(".csv")],
                           split="train",
                           streaming=True).with_format("torch")
    data_loader = DataLoader(dataset,
                             batch_size=1024,
                             num_workers=NUM_OF_WORKERS)
    
    result = pd.DataFrame()
    for i, batch in tqdm(enumerate(data_loader)):
        result = pd.concat([result, 
                           pd.DataFrame(batch)],
                           axis=0)
    result.to_csv(f"num_work_{NUM_OF_WORKERS}.csv", index=False)

Expected results

The data should not be duplicated.

Actual results

The data is duplicated with NUM_OF_WORKERS = 16: each example appears 16 times in the output CSV.
(screenshot of the duplicated output omitted)

Environment info

  • datasets version: 1.14.0
  • transformers version: 4.11.3
  • Python version: 3.8
  • PyArrow version: (not given)
cloudyuyuyu added the bug label Dec 13, 2021
lhoestq (Member) commented Dec 13, 2021

Hi! Thanks for reporting :)

When using a PyTorch DataLoader with num_workers > 1 and an iterable dataset, each worker streams the exact same data by default, resulting in duplicate data when iterating with the data loader.

We can probably fix this in datasets by checking torch.utils.data.get_worker_info(), which gives the worker id when the dataset is iterated inside a worker process.
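As a rough illustration of that idea (not the actual datasets implementation; the ShardedStream class below is made up for this sketch), an iterable dataset can read get_worker_info() inside __iter__ and keep only every num_workers-th example, so each DataLoader worker yields a disjoint shard instead of the full stream:

import torch
from torch.utils.data import IterableDataset, DataLoader


class ShardedStream(IterableDataset):
    """Toy iterable dataset that shards its stream across DataLoader workers."""

    def __init__(self, n_examples=32):
        self.n_examples = n_examples

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process loading: this copy owns the whole stream
            worker_id, num_workers = 0, 1
        else:
            worker_id, num_workers = worker_info.id, worker_info.num_workers
        # Each worker keeps every num_workers-th example, offset by its own id
        for idx in range(self.n_examples):
            if idx % num_workers == worker_id:
                yield idx


if __name__ == "__main__":
    loader = DataLoader(ShardedStream(), batch_size=4, num_workers=2)
    seen = sorted(int(x) for batch in loader for x in batch)
    assert seen == list(range(32))  # every example exactly once, no duplicates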

cloudyuyuyu (Author) commented:

Hi! Thanks for the reply.

Do you have any plans to fix this problem?

thomwolf (Member) commented:

Isn't that somehow a bug on PyTorch's side? (Just asking, because this behavior seems quite general and maybe not what was intended.)

lhoestq (Member) commented Dec 14, 2021

From PyTorch's documentation here:

When using an IterableDataset with multi-process data loading. The same dataset object is replicated on each worker process, and thus the replicas must be configured differently to avoid duplicated data. See IterableDataset documentations for how to achieve this.

It looks like intended behavior on PyTorch's side.

As suggested in the docstring of the IterableDataset class, we could pass a worker_init_fn to the DataLoader to fix this. It could be called streaming_worker_init_fn for example.

However, while this solution works, I'm worried that many users simply don't know about this parameter and just start their training with duplicate data without knowing it. That's why I'm more in favor of integrating the check on the worker id directly in datasets in our implementation of IterableDataset.__iter__.
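For reference, here is a minimal sketch of that worker_init_fn approach (RangeStream and its attributes are invented for this example, and this is not what datasets ships): the init function runs inside each worker process and configures that worker's copy of the dataset before iteration starts.

import torch
from torch.utils.data import IterableDataset, DataLoader


class RangeStream(IterableDataset):
    """Toy stream that can be told which slice of the data it owns."""

    def __init__(self, n_examples=32):
        self.n_examples = n_examples
        self.worker_id = 0      # reconfigured by the worker_init_fn below
        self.num_workers = 1

    def __iter__(self):
        for idx in range(self.n_examples):
            if idx % self.num_workers == self.worker_id:
                yield idx


def streaming_worker_init_fn(worker_id):
    # Runs inside each worker process; worker_info.dataset is that worker's copy
    info = torch.utils.data.get_worker_info()
    info.dataset.worker_id = info.id
    info.dataset.num_workers = info.num_workers


if __name__ == "__main__":
    loader = DataLoader(RangeStream(), batch_size=4, num_workers=2,
                        worker_init_fn=streaming_worker_init_fn)
    seen = sorted(int(x) for batch in loader for x in batch)
    assert seen == list(range(32))  # no duplicates, but only if users pass the init fn

This works, but it relies on every user remembering to pass the init function, which is exactly the concern raised above.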

lhoestq (Member) commented Jul 6, 2022

Fixed by #4375

cloudyuyuyu (Author) commented:

Thanks!

Ethan-yt commented:

Hi there @lhoestq @cloudyuyuyu
I ran into this problem recently, and #4375 was really useful because it helped me realize I was training on duplicate data.
However, in multi-GPU training with DDP and an IterableDataset, each process still yields duplicate data, and this is dangerous because users may not realize it is happening.

lhoestq (Member) commented Oct 18, 2022

If the worker_info.id is unique per process, it should work fine. Could you check that they're unique?

The code to get the worker_info in each worker is torch.utils.data.get_worker_info()

Ethan-yt commented Oct 28, 2022

test.py

import torch
from torch.utils.data import IterableDataset, DataLoader


class MyIterableDataset(IterableDataset):
    def __iter__(self):
        # Print this worker's info so it can be compared across DDP processes
        worker_info = torch.utils.data.get_worker_info()
        print(worker_info)
        return iter(range(3))


if __name__ == '__main__':
    dataset = MyIterableDataset()
    dataloader = DataLoader(dataset, num_workers=1)
    for i in dataloader:
        print(i)
$ python3 -m torch.distributed.launch \
  --nproc_per_node=2 test.py
WorkerInfo(id=0, num_workers=1, seed=5545685212307804959, dataset=<__main__.MyIterableDataset object at 0x7f92648cf6a0>)
WorkerInfo(id=0, num_workers=1, seed=3174108029709729025, dataset=<__main__.MyIterableDataset object at 0x7f19ab961670>)
tensor([0])
tensor([1])
tensor([2])
tensor([0])
tensor([1])
tensor([2])

@lhoestq they are not unique

lhoestq (Member) commented Oct 28, 2022

It looks like a bug on PyTorch's side, no? How can we know which data should go to which process when using DDP?

I guess we need to check torch.distributed.get_world_size() and torch.distributed.get_rank() as well. I'm not a fan of the design here, to be honest, but that's how it is.
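For illustration only (this is not how it was eventually addressed in datasets), a sketch of combining both checks so that every (rank, worker) pair owns a disjoint shard; DistributedStream is a made-up toy class:

import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset


class DistributedStream(IterableDataset):
    """Toy stream sharded over both DDP ranks and DataLoader workers."""

    def __init__(self, n_examples=64):
        self.n_examples = n_examples

    def __iter__(self):
        # DDP process info (rank / world size), if torch.distributed is initialized
        if dist.is_available() and dist.is_initialized():
            rank, world_size = dist.get_rank(), dist.get_world_size()
        else:
            rank, world_size = 0, 1
        # DataLoader worker info within this process
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info is not None else 0
        num_workers = worker_info.num_workers if worker_info is not None else 1
        # Global shard index over all (rank, worker) pairs
        shard_id = rank * num_workers + worker_id
        num_shards = world_size * num_workers
        for idx in range(self.n_examples):
            if idx % num_shards == shard_id:
                yield idx


if __name__ == "__main__":
    # Without torch.distributed initialized this prints all 64 indices;
    # under DDP each rank would print only its own disjoint subset.
    print(list(DistributedStream()))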

Ethan-yt commented:

Never mind. After reading the code, I found that IterableDatasetShard already solves this problem.
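For context, IterableDatasetShard lives in transformers.trainer_pt_utils and wraps an iterable dataset so that each process yields only its own slice of every global batch. A rough usage sketch follows; the constructor arguments shown (batch_size, num_processes, process_index) are assumed from the transformers 4.x code and should be checked against the installed version:

import torch.distributed as dist
from torch.utils.data import IterableDataset
from transformers.trainer_pt_utils import IterableDatasetShard  # assumed import path


class ToyStream(IterableDataset):
    def __iter__(self):
        return iter(range(8))


# Figure out this process's position in the DDP group (defaults for a single process)
if dist.is_available() and dist.is_initialized():
    rank, world_size = dist.get_rank(), dist.get_world_size()
else:
    rank, world_size = 0, 1

# Each process iterates only its own batch_size-sized slice of every global batch
sharded = IterableDatasetShard(
    ToyStream(),
    batch_size=2,
    num_processes=world_size,
    process_index=rank,
)
print(list(sharded))  # with a single process this prints all 8 examples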

lhoestq (Member) commented Dec 14, 2022

I'm re-opening this one since I think it should be supported by datasets natively.

lhoestq reopened this Dec 14, 2022
lhoestq (Member) commented Dec 14, 2022

Hmm, actually let me open a new issue for DDP, since the original post was about a single node.
