diff --git a/torchdata/datapipes/iter/transform/callable.py b/torchdata/datapipes/iter/transform/callable.py index c1187c451..3762cd5a1 100644 --- a/torchdata/datapipes/iter/transform/callable.py +++ b/torchdata/datapipes/iter/transform/callable.py @@ -7,7 +7,7 @@ from typing import Callable, Iterator, List, TypeVar from torch.utils.data import functional_datapipe, IterDataPipe -from torch.utils.data.datapipes.utils.common import _check_lambda_fn +from torch.utils.data.datapipes.utils.common import _check_unpickable_fn T_co = TypeVar("T_co", covariant=True) @@ -59,7 +59,7 @@ def __init__( ) -> None: self.datapipe = datapipe - _check_lambda_fn(fn) + _check_unpickable_fn(fn) self.fn = fn # type: ignore[assignment] assert batch_size > 0, "Batch size is required to be larger than 0!" @@ -118,7 +118,7 @@ class FlatMapperIterDataPipe(IterDataPipe[T_co]): def __init__(self, datapipe: IterDataPipe, fn: Callable, input_col=None) -> None: self.datapipe = datapipe - _check_lambda_fn(fn) + _check_unpickable_fn(fn) self.fn = fn # type: ignore[assignment] self.input_col = input_col diff --git a/torchdata/datapipes/iter/util/cacheholder.py b/torchdata/datapipes/iter/util/cacheholder.py index ce78e70b2..b3723a368 100644 --- a/torchdata/datapipes/iter/util/cacheholder.py +++ b/torchdata/datapipes/iter/util/cacheholder.py @@ -27,7 +27,7 @@ raise -from torch.utils.data.datapipes.utils.common import _check_lambda_fn, DILL_AVAILABLE +from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, DILL_AVAILABLE from torch.utils.data.graph import traverse from torchdata.datapipes import functional_datapipe @@ -184,7 +184,8 @@ def __init__( ): self.source_datapipe = source_datapipe - _check_lambda_fn(filepath_fn) + if filepath_fn is not None: + _check_unpickable_fn(filepath_fn) filepath_fn = _generator_to_list(filepath_fn) if inspect.isgeneratorfunction(filepath_fn) else filepath_fn if hash_dict is not None and hash_type not in ("sha256", "md5"): diff --git a/torchdata/datapipes/iter/util/combining.py b/torchdata/datapipes/iter/util/combining.py index 7c094bc57..a7148d26f 100644 --- a/torchdata/datapipes/iter/util/combining.py +++ b/torchdata/datapipes/iter/util/combining.py @@ -9,7 +9,7 @@ from typing import Callable, Iterator, Optional, TypeVar from torch.utils.data import functional_datapipe, IterDataPipe, MapDataPipe -from torch.utils.data.datapipes.utils.common import _check_lambda_fn +from torch.utils.data.datapipes.utils.common import _check_unpickable_fn T_co = TypeVar("T_co", covariant=True) @@ -64,14 +64,14 @@ def __init__( raise TypeError(f"ref_datapipe must be a IterDataPipe, but its type is {type(ref_datapipe)} instead.") self.source_datapipe = source_datapipe self.ref_datapipe = ref_datapipe - _check_lambda_fn(key_fn) + _check_unpickable_fn(key_fn) self.key_fn = key_fn if ref_key_fn is not None: - _check_lambda_fn(ref_key_fn) + _check_unpickable_fn(ref_key_fn) self.ref_key_fn = key_fn if ref_key_fn is None else ref_key_fn self.keep_key = keep_key if merge_fn is not None: - _check_lambda_fn(merge_fn) + _check_unpickable_fn(merge_fn) self.merge_fn = merge_fn if buffer_size is not None and buffer_size <= 0: raise ValueError("'buffer_size' is required to be either None or a positive integer.") @@ -185,10 +185,10 @@ def __init__( raise TypeError(f"map_datapipe must be a MapDataPipe, but its type is {type(map_datapipe)} instead.") self.source_iterdatapipe: IterDataPipe = source_iterdatapipe self.map_datapipe: MapDataPipe = map_datapipe - _check_lambda_fn(key_fn) + _check_unpickable_fn(key_fn) self.key_fn: Callable = key_fn if merge_fn is not None: - _check_lambda_fn(merge_fn) + _check_unpickable_fn(merge_fn) self.merge_fn: Optional[Callable] = merge_fn self.length: int = -1 diff --git a/torchdata/datapipes/iter/util/converter.py b/torchdata/datapipes/iter/util/converter.py index 7c6d62751..e4adaafdb 100644 --- a/torchdata/datapipes/iter/util/converter.py +++ b/torchdata/datapipes/iter/util/converter.py @@ -9,7 +9,7 @@ from typing import Callable, Dict, Optional from torch.utils.data import IterDataPipe, MapDataPipe -from torch.utils.data.datapipes.utils.common import _check_lambda_fn, DILL_AVAILABLE +from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, DILL_AVAILABLE if DILL_AVAILABLE: import dill @@ -52,7 +52,8 @@ def __init__(self, datapipe: IterDataPipe, key_value_fn: Optional[Callable] = No if not isinstance(datapipe, IterDataPipe): raise TypeError(f"IterToMapConverter can only apply on IterDataPipe, but found {type(datapipe)}") self.datapipe = datapipe - _check_lambda_fn(key_value_fn) + if key_value_fn is not None: + _check_unpickable_fn(key_value_fn) self.key_value_fn = key_value_fn # type: ignore[assignment] self._map = None self._length = -1 diff --git a/torchdata/datapipes/iter/util/paragraphaggregator.py b/torchdata/datapipes/iter/util/paragraphaggregator.py index 696ba33fa..be7c21daf 100644 --- a/torchdata/datapipes/iter/util/paragraphaggregator.py +++ b/torchdata/datapipes/iter/util/paragraphaggregator.py @@ -6,7 +6,7 @@ from typing import Callable, Iterator, List, Tuple, TypeVar -from torch.utils.data.datapipes.utils.common import _check_lambda_fn +from torch.utils.data.datapipes.utils.common import _check_unpickable_fn from torchdata.datapipes import functional_datapipe from torchdata.datapipes.iter import IterDataPipe @@ -44,7 +44,7 @@ class ParagraphAggregatorIterDataPipe(IterDataPipe[Tuple[str, str]]): def __init__(self, source_datapipe: IterDataPipe[Tuple[str, T_co]], joiner: Callable = _default_line_join) -> None: self.source_datapipe: IterDataPipe[Tuple[str, T_co]] = source_datapipe - _check_lambda_fn(joiner) + _check_unpickable_fn(joiner) self.joiner: Callable = joiner self.buffer: List = []