Skip to content

Commit

Permalink
Define and use _T_co locally instead of importing from torch.utils.da…
Browse files Browse the repository at this point in the history
…ta DataLoader (#1282)
  • Loading branch information
gokulavasan authored Jul 4, 2024
1 parent b421e86 commit f827b0d
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions torchdata/stateful_dataloader/stateful_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import queue
import threading

from typing import Any, Dict, Iterable, List, Optional, Union
from typing import Any, Dict, Iterable, List, Optional, TypeVar, Union

import torch
import torch.multiprocessing as multiprocessing
Expand Down Expand Up @@ -76,9 +76,10 @@
default_collate,
default_convert,
get_worker_info,
T_co,
)

_T_co = TypeVar("_T_co", covariant=True)

logger = logging.getLogger(__name__)

_INDEX_SAMPLER_STATE = "_index_sampler_state"
Expand All @@ -89,7 +90,7 @@
_ITERATOR_FINISHED = "_iterator_finished"


class StatefulDataLoader(DataLoader[T_co]):
class StatefulDataLoader(DataLoader[_T_co]):
r"""
This is a drop in replacement for :class:`~torch.utils.data.DataLoader`
that implements state_dict and load_state_dict methods, enabling mid-epoch
Expand Down Expand Up @@ -183,7 +184,7 @@ class StatefulDataLoader(DataLoader[T_co]):

def __init__(
self,
dataset: Dataset[T_co],
dataset: Dataset[_T_co],
batch_size: Optional[int] = 1,
shuffle: Optional[bool] = None,
sampler: Union[Sampler, Iterable, None] = None,
Expand Down

0 comments on commit f827b0d

Please sign in to comment.