Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Add type annotation to get_worker_info #87017

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions torch/utils/data/_utils/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
import queue
from dataclasses import dataclass
from torch._utils import ExceptionWrapper
from typing import Optional, Union
from typing import Optional, Union, TYPE_CHECKING
from . import signal_handling, MP_STATUS_CHECK_INTERVAL, IS_WINDOWS, HAS_NUMPY
if TYPE_CHECKING:
from torch.utils.data import Dataset

if IS_WINDOWS:
import ctypes
Expand Down Expand Up @@ -60,6 +62,10 @@ def is_alive(self):


class WorkerInfo(object):
id: int
num_workers: int
seed: int
dataset: 'Dataset'
__initialized = False

def __init__(self, **kwargs):
Expand All @@ -80,7 +86,7 @@ def __repr__(self):
return '{}({})'.format(self.__class__.__name__, ', '.join(items))


def get_worker_info():
def get_worker_info() -> Optional[WorkerInfo]:
r"""Returns the information about the current
:class:`~torch.utils.data.DataLoader` iterator worker process.

Expand Down
2 changes: 2 additions & 0 deletions torch/utils/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,10 @@ def _get_distributed_settings():
def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id):
global_worker_id = worker_id
info = torch.utils.data.get_worker_info()
assert info is not None
total_workers = info.num_workers
datapipe = info.dataset
assert isinstance(datapipe, IterDataPipe)
# To distribute elements evenly across distributed processes, we should shard data across the
# distributed processes first and then shard across the worker processes
total_workers *= world_size
Expand Down