Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise warning for unpicklable local function #547

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions test/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,15 +374,13 @@ def test_serializable_with_dill(self):
else:
dp_no_attribute_error = (iterdp.OnDiskCacheHolder,)
try:
with warnings.catch_warnings(record=True) as wa:
with self.assertWarnsRegex(UserWarning, r"^Local function is not supported by pickle"):
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
self.assertEqual(len(wa), 1)
self.assertRegex(str(wa[0].message), r"^Lambda function is not supported for pickle")
if isinstance(datapipe, dp_no_attribute_error):
if isinstance(datapipe, dp_no_attribute_error):
_ = pickle.dumps(datapipe)
else:
with self.assertRaises(AttributeError):
_ = pickle.dumps(datapipe)
else:
with self.assertRaises(AttributeError):
_ = pickle.dumps(datapipe)
except Exception as e:
print(f"{dpipe} is failing.")
raise e
Expand Down
6 changes: 3 additions & 3 deletions torchdata/datapipes/iter/transform/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Callable, Iterator, List, TypeVar

from torch.utils.data import functional_datapipe, IterDataPipe
from torch.utils.data.datapipes.utils.common import _check_lambda_fn
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn

T_co = TypeVar("T_co", covariant=True)

Expand Down Expand Up @@ -59,7 +59,7 @@ def __init__(
) -> None:
self.datapipe = datapipe

_check_lambda_fn(fn)
_check_unpickable_fn(fn)
self.fn = fn # type: ignore[assignment]

assert batch_size > 0, "Batch size is required to be larger than 0!"
Expand Down Expand Up @@ -118,7 +118,7 @@ class FlatMapperIterDataPipe(IterDataPipe[T_co]):
def __init__(self, datapipe: IterDataPipe, fn: Callable, input_col=None) -> None:
self.datapipe = datapipe

_check_lambda_fn(fn)
_check_unpickable_fn(fn)
self.fn = fn # type: ignore[assignment]
self.input_col = input_col

Expand Down
5 changes: 3 additions & 2 deletions torchdata/datapipes/iter/util/cacheholder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
raise


from torch.utils.data.datapipes.utils.common import _check_lambda_fn, DILL_AVAILABLE
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, DILL_AVAILABLE

from torch.utils.data.graph import traverse
from torchdata.datapipes import functional_datapipe
Expand Down Expand Up @@ -184,7 +184,8 @@ def __init__(
):
self.source_datapipe = source_datapipe

_check_lambda_fn(filepath_fn)
if filepath_fn is not None:
_check_unpickable_fn(filepath_fn)
filepath_fn = _generator_to_list(filepath_fn) if inspect.isgeneratorfunction(filepath_fn) else filepath_fn

if hash_dict is not None and hash_type not in ("sha256", "md5"):
Expand Down
12 changes: 6 additions & 6 deletions torchdata/datapipes/iter/util/combining.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import Callable, Iterator, Optional, TypeVar

from torch.utils.data import functional_datapipe, IterDataPipe, MapDataPipe
from torch.utils.data.datapipes.utils.common import _check_lambda_fn
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn

T_co = TypeVar("T_co", covariant=True)

Expand Down Expand Up @@ -64,14 +64,14 @@ def __init__(
raise TypeError(f"ref_datapipe must be a IterDataPipe, but its type is {type(ref_datapipe)} instead.")
self.source_datapipe = source_datapipe
self.ref_datapipe = ref_datapipe
_check_lambda_fn(key_fn)
_check_unpickable_fn(key_fn)
self.key_fn = key_fn
if ref_key_fn is not None:
_check_lambda_fn(ref_key_fn)
_check_unpickable_fn(ref_key_fn)
self.ref_key_fn = key_fn if ref_key_fn is None else ref_key_fn
self.keep_key = keep_key
if merge_fn is not None:
_check_lambda_fn(merge_fn)
_check_unpickable_fn(merge_fn)
self.merge_fn = merge_fn
if buffer_size is not None and buffer_size <= 0:
raise ValueError("'buffer_size' is required to be either None or a positive integer.")
Expand Down Expand Up @@ -185,10 +185,10 @@ def __init__(
raise TypeError(f"map_datapipe must be a MapDataPipe, but its type is {type(map_datapipe)} instead.")
self.source_iterdatapipe: IterDataPipe = source_iterdatapipe
self.map_datapipe: MapDataPipe = map_datapipe
_check_lambda_fn(key_fn)
_check_unpickable_fn(key_fn)
self.key_fn: Callable = key_fn
if merge_fn is not None:
_check_lambda_fn(merge_fn)
_check_unpickable_fn(merge_fn)
self.merge_fn: Optional[Callable] = merge_fn
self.length: int = -1

Expand Down
5 changes: 3 additions & 2 deletions torchdata/datapipes/iter/util/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import Callable, Dict, Optional

from torch.utils.data import IterDataPipe, MapDataPipe
from torch.utils.data.datapipes.utils.common import _check_lambda_fn, DILL_AVAILABLE
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, DILL_AVAILABLE

if DILL_AVAILABLE:
import dill
Expand Down Expand Up @@ -52,7 +52,8 @@ def __init__(self, datapipe: IterDataPipe, key_value_fn: Optional[Callable] = No
if not isinstance(datapipe, IterDataPipe):
raise TypeError(f"IterToMapConverter can only apply on IterDataPipe, but found {type(datapipe)}")
self.datapipe = datapipe
_check_lambda_fn(key_value_fn)
if key_value_fn is not None:
_check_unpickable_fn(key_value_fn)
self.key_value_fn = key_value_fn # type: ignore[assignment]
self._map = None
self._length = -1
Expand Down
4 changes: 2 additions & 2 deletions torchdata/datapipes/iter/util/paragraphaggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from typing import Callable, Iterator, List, Tuple, TypeVar

from torch.utils.data.datapipes.utils.common import _check_lambda_fn
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn

from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe
Expand Down Expand Up @@ -44,7 +44,7 @@ class ParagraphAggregatorIterDataPipe(IterDataPipe[Tuple[str, str]]):

def __init__(self, source_datapipe: IterDataPipe[Tuple[str, T_co]], joiner: Callable = _default_line_join) -> None:
self.source_datapipe: IterDataPipe[Tuple[str, T_co]] = source_datapipe
_check_lambda_fn(joiner)
_check_unpickable_fn(joiner)
self.joiner: Callable = joiner
self.buffer: List = []

Expand Down