Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

3402 Add support to set other pickle related args #3412

Merged
merged 14 commits on Nov 29, 2021
43 changes: 37 additions & 6 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ def __init__(
transform: Union[Sequence[Callable], Callable],
cache_dir: Optional[Union[Path, str]],
hash_func: Callable[..., bytes] = pickle_hashing,
pickle_module=pickle,
pickle_protocol=pickle.DEFAULT_PROTOCOL,
) -> None:
"""
Args:
Expand All @@ -167,13 +169,21 @@ def __init__(
If `cache_dir` is `None`, there is effectively no caching.
hash_func: a callable to compute hash from data items to be cached.
defaults to `monai.data.utils.pickle_hashing`.
pickle_module: module used for pickling metadata and objects, default to `pickle`.
this arg is used by `torch.save`, for more details, please check:
https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
pickle_protocol: can be specified to override the default protocol, default to `2`.
this arg is used by `torch.save`, for more details, please check:
https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.

"""
if not isinstance(transform, Compose):
transform = Compose(transform)
super().__init__(data=data, transform=transform)
self.cache_dir = Path(cache_dir) if cache_dir is not None else None
self.hash_func = hash_func
self.pickle_module = pickle_module
self.pickle_protocol = pickle_protocol
if self.cache_dir is not None:
if not self.cache_dir.exists():
self.cache_dir.mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -274,7 +284,12 @@ def _cachecheck(self, item_transformed):
# which may leave partially written cache files in an incomplete state
with tempfile.TemporaryDirectory() as tmpdirname:
temp_hash_file = Path(tmpdirname) / hashfile.name
torch.save(_item_transformed, temp_hash_file)
torch.save(
obj=_item_transformed,
f=temp_hash_file,
pickle_module=self.pickle_module,
pickle_protocol=self.pickle_protocol,
)
if temp_hash_file.is_file() and not hashfile.is_file():
# On Unix, if target exists and is a file, it will be replaced silently if the user has permission.
# for more details: https://docs.python.org/3/library/shutil.html#shutil.move.
Expand Down Expand Up @@ -302,6 +317,8 @@ def __init__(
cache_n_trans: int,
cache_dir: Optional[Union[Path, str]],
hash_func: Callable[..., bytes] = pickle_hashing,
pickle_module=pickle,
pickle_protocol=pickle.DEFAULT_PROTOCOL,
) -> None:
"""
Args:
Expand All @@ -318,9 +335,22 @@ def __init__(
If `cache_dir` is `None`, there is effectively no caching.
hash_func: a callable to compute hash from data items to be cached.
defaults to `monai.data.utils.pickle_hashing`.

"""
super().__init__(data=data, transform=transform, cache_dir=cache_dir, hash_func=hash_func)
pickle_module: module used for pickling metadata and objects, default to `pickle`.
this arg is used by `torch.save`, for more details, please check:
https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
pickle_protocol: can be specified to override the default protocol, default to `2`.
this arg is used by `torch.save`, for more details, please check:
https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.

"""
super().__init__(
data=data,
transform=transform,
cache_dir=cache_dir,
hash_func=hash_func,
pickle_module=pickle_module,
pickle_protocol=pickle_protocol,
)
self.cache_n_trans = cache_n_trans

def _pre_transform(self, item_transformed):
Expand Down Expand Up @@ -407,12 +437,13 @@ def __init__(
lmdb_kwargs: additional keyword arguments to the lmdb environment.
for more details please visit: https://lmdb.readthedocs.io/en/release/#environment-class
"""
super().__init__(data=data, transform=transform, cache_dir=cache_dir, hash_func=hash_func)
super().__init__(
data=data, transform=transform, cache_dir=cache_dir, hash_func=hash_func, pickle_protocol=pickle_protocol
)
self.progress = progress
if not self.cache_dir:
raise ValueError("cache_dir must be specified.")
self.db_file = self.cache_dir / f"{db_name}.lmdb"
self.pickle_protocol = pickle_protocol
self.lmdb_kwargs = lmdb_kwargs or {}
if not self.lmdb_kwargs.get("map_size", 0):
self.lmdb_kwargs["map_size"] = 1024 ** 4 # default map_size
Expand Down
6 changes: 3 additions & 3 deletions monai/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,16 +197,16 @@ class Transform(ABC):
subsequent transforms in a composed transform.
#. storing too much information in ``data`` may cause some memory issue or IPC sync issue,
especially in the multi-processing environment of PyTorch DataLoader.
#. transforms should add data types to the `backend` list if they are capable of performing a transform
without modifying the input type. For example, [\"torch.Tensor\", \"np.ndarray\"] means that
no copies of the data are required if the input is either \"torch.Tensor\" or \"np.ndarray\".

See Also

:py:class:`monai.transforms.Compose`
"""

backend: List[TransformBackends] = []
"""Transforms should add data types to this list if they are capable of performing a transform without
Nic-Ma marked this conversation as resolved.
Show resolved Hide resolved
modifying the input type. For example, [\"torch.Tensor\", \"np.ndarray\"] means that no copies of the data
are required if the input is either \"torch.Tensor\" or \"np.ndarray\"."""

@abstractmethod
def __call__(self, data: Any):
Expand Down
43 changes: 39 additions & 4 deletions monai/utils/state_cacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import copy
import os
import pickle
import tempfile
from typing import Dict, Optional

Expand All @@ -36,7 +37,14 @@ class StateCacher:
>>> model.load_state_dict(state_cacher.retrieve("model"))
"""

def __init__(self, in_memory: bool, cache_dir: Optional[PathLike] = None, allow_overwrite: bool = True) -> None:
def __init__(
self,
in_memory: bool,
cache_dir: Optional[PathLike] = None,
allow_overwrite: bool = True,
pickle_module=pickle,
pickle_protocol=pickle.DEFAULT_PROTOCOL,
) -> None:
"""Constructor.

Args:
Expand All @@ -48,10 +56,19 @@ def __init__(self, in_memory: bool, cache_dir: Optional[PathLike] = None, allow_
allow_overwrite: allow the cache to be overwritten. If set to `False`, an
error will be thrown if a matching already exists in the list of cached
objects.
pickle_module: module used for pickling metadata and objects, default to `pickle`.
this arg is used by `torch.save`, for more details, please check:
https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
pickle_protocol: can be specified to override the default protocol, default to `2`.
this arg is used by `torch.save`, for more details, please check:
https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.

"""
self.in_memory = in_memory
self.cache_dir = cache_dir
self.allow_overwrite = allow_overwrite
self.pickle_module = pickle_module
self.pickle_protocol = pickle_protocol

if self.cache_dir is None:
self.cache_dir = tempfile.gettempdir()
Expand All @@ -60,16 +77,34 @@ def __init__(self, in_memory: bool, cache_dir: Optional[PathLike] = None, allow_

self.cached: Dict[str, str] = {}

def store(self, key, data_obj):
"""Store a given object with the given key name."""
def store(self, key, data_obj, pickle_module=None, pickle_protocol=None):
"""
Store a given object with the given key name.

Args:
key: key of the data object to store.
data_obj: data object to store.
pickle_module: module used for pickling metadata and objects, default to `self.pickle_module`.
this arg is used by `torch.save`, for more details, please check:
https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
pickle_protocol: can be specified to override the default protocol, default to `self.pickle_protocol`.
this arg is used by `torch.save`, for more details, please check:
https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.

"""
if key in self.cached and not self.allow_overwrite:
raise RuntimeError("Cached key already exists and overwriting is disabled.")
if self.in_memory:
self.cached.update({key: {"obj": copy.deepcopy(data_obj)}})
else:
fn = os.path.join(self.cache_dir, f"state_{key}_{id(self)}.pt")
self.cached.update({key: {"obj": fn}})
torch.save(data_obj, fn)
torch.save(
obj=data_obj,
f=fn,
pickle_module=self.pickle_module if pickle_module is None else pickle_module,
pickle_protocol=self.pickle_protocol if pickle_protocol is None else pickle_protocol,
)
# store object's device if relevant
if hasattr(data_obj, "device"):
self.cached[key]["device"] = data_obj.device
Expand Down
9 changes: 8 additions & 1 deletion tests/test_persistentdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# limitations under the License.

import os
import pickle
import tempfile
import unittest

Expand Down Expand Up @@ -56,7 +57,13 @@ def test_cache(self):
items = [[list(range(i))] for i in range(5)]

with tempfile.TemporaryDirectory() as tempdir:
ds = PersistentDataset(items, transform=_InplaceXform(), cache_dir=tempdir)
ds = PersistentDataset(
data=items,
transform=_InplaceXform(),
cache_dir=tempdir,
pickle_module=pickle,
pickle_protocol=pickle.HIGHEST_PROTOCOL,
)
self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]])
ds1 = PersistentDataset(items, transform=_InplaceXform(), cache_dir=tempdir)
self.assertEqual(list(ds1), list(ds))
Expand Down
2 changes: 1 addition & 1 deletion tests/test_scale_intensity_range_percentiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_relative_scaling(self):

for p in TEST_NDARRAYS:
result = scaler(p(img))
assert_allclose(result, p(expected_img), rtol=1e-4)
assert_allclose(result, p(expected_img), rtol=1e-3)

scaler = ScaleIntensityRangePercentiles(
lower=lower, upper=upper, b_min=b_min, b_max=b_max, relative=True, clip=True
Expand Down
8 changes: 6 additions & 2 deletions tests/test_state_cacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pickle
import unittest
from os.path import exists, join
from pathlib import Path
Expand All @@ -22,7 +23,10 @@
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

TEST_CASE_0 = [torch.Tensor([1]).to(DEVICE), {"in_memory": True}]
TEST_CASE_1 = [torch.Tensor([1]).to(DEVICE), {"in_memory": False, "cache_dir": gettempdir()}]
TEST_CASE_1 = [
torch.Tensor([1]).to(DEVICE),
{"in_memory": False, "cache_dir": gettempdir(), "pickle_module": None, "pickle_protocol": pickle.HIGHEST_PROTOCOL},
]
TEST_CASE_2 = [torch.Tensor([1]).to(DEVICE), {"in_memory": False, "allow_overwrite": False}]
TEST_CASE_3 = [torch.Tensor([1]).to(DEVICE), {"in_memory": False, "cache_dir": Path(gettempdir())}]

Expand All @@ -37,7 +41,7 @@ def test_state_cacher(self, data_obj, params):

state_cacher = StateCacher(**params)
# store it
state_cacher.store(key, data_obj)
state_cacher.store(key, data_obj, pickle_module=pickle)
# create clone then modify original
data_obj_orig = data_obj.clone()
data_obj += 1
Expand Down