Skip to content

Commit

Permalink
Raise warning for unpickable local function (#80232)
Browse files Browse the repository at this point in the history
Summary:
X-link: pytorch/pytorch#80232

Pull Request resolved: pytorch#547

Fixes pytorch#538
- Improve the validation function to raise warning about unpickable function when either lambda or local function is provided to DataPipe.
- The inner function from functools.partial object is extracted as well for validation
- Mimic the behavior of pickle module for local lambda function: It would only raise Error for the local function rather than lambda function. So, we will raise warning about local function not lambda function.
```py

>>> import pickle
>>> def fn():
...     lf = lambda x: x
...     pickle.dumps(lf)
>>> pickle.dumps(fn)
AttributeError: Can't pickle local object 'fn.<locals>.<lambda>'
```

This Diff also fixes the Error introduced by pytorch/pytorch#79344

Differential Revision: D37417556

fbshipit-source-id: 7213ee84b34092e0c2cf293ff8bf1dc56659fc83
  • Loading branch information
ejguan authored and facebook-github-bot committed Jun 24, 2022
1 parent 75f31dc commit 0cc4106
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 22 deletions.
12 changes: 5 additions & 7 deletions test/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,15 +374,13 @@ def test_serializable_with_dill(self):
else:
dp_no_attribute_error = (iterdp.OnDiskCacheHolder,)
try:
with warnings.catch_warnings(record=True) as wa:
with self.assertWarnsRegex(UserWarning, r"^Local function is not supported by pickle"):
datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg]
self.assertEqual(len(wa), 1)
self.assertRegex(str(wa[0].message), r"^Lambda function is not supported for pickle")
if isinstance(datapipe, dp_no_attribute_error):
if isinstance(datapipe, dp_no_attribute_error):
_ = pickle.dumps(datapipe)
else:
with self.assertRaises(AttributeError):
_ = pickle.dumps(datapipe)
else:
with self.assertRaises(AttributeError):
_ = pickle.dumps(datapipe)
except Exception as e:
print(f"{dpipe} is failing.")
raise e
Expand Down
6 changes: 3 additions & 3 deletions torchdata/datapipes/iter/transform/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Callable, Iterator, List, TypeVar

from torch.utils.data import functional_datapipe, IterDataPipe
from torch.utils.data.datapipes.utils.common import _check_lambda_fn
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn

T_co = TypeVar("T_co", covariant=True)

Expand Down Expand Up @@ -59,7 +59,7 @@ def __init__(
) -> None:
self.datapipe = datapipe

_check_lambda_fn(fn)
_check_unpickable_fn(fn)
self.fn = fn # type: ignore[assignment]

assert batch_size > 0, "Batch size is required to be larger than 0!"
Expand Down Expand Up @@ -118,7 +118,7 @@ class FlatMapperIterDataPipe(IterDataPipe[T_co]):
def __init__(self, datapipe: IterDataPipe, fn: Callable, input_col=None) -> None:
self.datapipe = datapipe

_check_lambda_fn(fn)
_check_unpickable_fn(fn)
self.fn = fn # type: ignore[assignment]
self.input_col = input_col

Expand Down
5 changes: 3 additions & 2 deletions torchdata/datapipes/iter/util/cacheholder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
raise


from torch.utils.data.datapipes.utils.common import _check_lambda_fn, DILL_AVAILABLE
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, DILL_AVAILABLE

from torch.utils.data.graph import traverse
from torchdata.datapipes import functional_datapipe
Expand Down Expand Up @@ -184,7 +184,8 @@ def __init__(
):
self.source_datapipe = source_datapipe

_check_lambda_fn(filepath_fn)
if filepath_fn is not None:
_check_unpickable_fn(filepath_fn)
filepath_fn = _generator_to_list(filepath_fn) if inspect.isgeneratorfunction(filepath_fn) else filepath_fn

if hash_dict is not None and hash_type not in ("sha256", "md5"):
Expand Down
12 changes: 6 additions & 6 deletions torchdata/datapipes/iter/util/combining.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import Callable, Iterator, Optional, TypeVar

from torch.utils.data import functional_datapipe, IterDataPipe, MapDataPipe
from torch.utils.data.datapipes.utils.common import _check_lambda_fn
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn

T_co = TypeVar("T_co", covariant=True)

Expand Down Expand Up @@ -64,14 +64,14 @@ def __init__(
raise TypeError(f"ref_datapipe must be a IterDataPipe, but its type is {type(ref_datapipe)} instead.")
self.source_datapipe = source_datapipe
self.ref_datapipe = ref_datapipe
_check_lambda_fn(key_fn)
_check_unpickable_fn(key_fn)
self.key_fn = key_fn
if ref_key_fn is not None:
_check_lambda_fn(ref_key_fn)
_check_unpickable_fn(ref_key_fn)
self.ref_key_fn = key_fn if ref_key_fn is None else ref_key_fn
self.keep_key = keep_key
if merge_fn is not None:
_check_lambda_fn(merge_fn)
_check_unpickable_fn(merge_fn)
self.merge_fn = merge_fn
if buffer_size is not None and buffer_size <= 0:
raise ValueError("'buffer_size' is required to be either None or a positive integer.")
Expand Down Expand Up @@ -185,10 +185,10 @@ def __init__(
raise TypeError(f"map_datapipe must be a MapDataPipe, but its type is {type(map_datapipe)} instead.")
self.source_iterdatapipe: IterDataPipe = source_iterdatapipe
self.map_datapipe: MapDataPipe = map_datapipe
_check_lambda_fn(key_fn)
_check_unpickable_fn(key_fn)
self.key_fn: Callable = key_fn
if merge_fn is not None:
_check_lambda_fn(merge_fn)
_check_unpickable_fn(merge_fn)
self.merge_fn: Optional[Callable] = merge_fn
self.length: int = -1

Expand Down
5 changes: 3 additions & 2 deletions torchdata/datapipes/iter/util/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import Callable, Dict, Optional

from torch.utils.data import IterDataPipe, MapDataPipe
from torch.utils.data.datapipes.utils.common import _check_lambda_fn, DILL_AVAILABLE
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, DILL_AVAILABLE

if DILL_AVAILABLE:
import dill
Expand Down Expand Up @@ -52,7 +52,8 @@ def __init__(self, datapipe: IterDataPipe, key_value_fn: Optional[Callable] = No
if not isinstance(datapipe, IterDataPipe):
raise TypeError(f"IterToMapConverter can only apply on IterDataPipe, but found {type(datapipe)}")
self.datapipe = datapipe
_check_lambda_fn(key_value_fn)
if key_value_fn is not None:
_check_unpickable_fn(key_value_fn)
self.key_value_fn = key_value_fn # type: ignore[assignment]
self._map = None
self._length = -1
Expand Down
4 changes: 2 additions & 2 deletions torchdata/datapipes/iter/util/paragraphaggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from typing import Callable, Iterator, List, Tuple, TypeVar

from torch.utils.data.datapipes.utils.common import _check_lambda_fn
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn

from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe
Expand Down Expand Up @@ -44,7 +44,7 @@ class ParagraphAggregatorIterDataPipe(IterDataPipe[Tuple[str, str]]):

def __init__(self, source_datapipe: IterDataPipe[Tuple[str, T_co]], joiner: Callable = _default_line_join) -> None:
self.source_datapipe: IterDataPipe[Tuple[str, T_co]] = source_datapipe
_check_lambda_fn(joiner)
_check_unpickable_fn(joiner)
self.joiner: Callable = joiner
self.buffer: List = []

Expand Down

0 comments on commit 0cc4106

Please sign in to comment.