diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index e358cd4b99..df8f09bf4b 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -133,7 +133,7 @@ def get_nnunet_trainer( cudnn.benchmark = True if pretrained_model is not None: - state_dict = torch.load(pretrained_model) + state_dict = torch.load(pretrained_model, weights_only=True) if "network_weights" in state_dict: nnunet_trainer.network._orig_mod.load_state_dict(state_dict["network_weights"]) return nnunet_trainer @@ -182,7 +182,9 @@ def __init__(self, predictor: object, model_folder: Union[str, Path], model_name parameters = [] checkpoint = torch.load( - join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), map_location=torch.device("cpu") + join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), + map_location=torch.device("cpu"), + weights_only=True, ) trainer_name = checkpoint["trainer_name"] configuration_name = checkpoint["init_args"]["configuration"] @@ -192,7 +194,9 @@ def __init__(self, predictor: object, model_folder: Union[str, Path], model_name else None ) if Path(model_training_output_dir).joinpath(model_name).is_file(): - monai_checkpoint = torch.load(join(model_training_output_dir, model_name), map_location=torch.device("cpu")) + monai_checkpoint = torch.load( + join(model_training_output_dir, model_name), map_location=torch.device("cpu"), weights_only=True + ) if "network_weights" in monai_checkpoint.keys(): parameters.append(monai_checkpoint["network_weights"]) else: @@ -383,8 +387,12 @@ def convert_nnunet_to_monai_bundle(nnunet_config: dict, bundle_root_folder: str, dataset_name, f"{nnunet_trainer}__{nnunet_plans}__{nnunet_configuration}" ) - nnunet_checkpoint_final = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth")) - nnunet_checkpoint_best = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth")) + nnunet_checkpoint_final = torch.load( + Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth"), weights_only=True + ) + nnunet_checkpoint_best = torch.load( + Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth"), weights_only=True + ) nnunet_checkpoint = {} nnunet_checkpoint["inference_allowed_mirroring_axes"] = nnunet_checkpoint_final["inference_allowed_mirroring_axes"] @@ -470,7 +478,7 @@ def get_network_from_nnunet_plans( if model_ckpt is None: return network else: - state_dict = torch.load(model_ckpt) + state_dict = torch.load(model_ckpt, weights_only=True) network.load_state_dict(state_dict[model_key_in_ckpt]) return network @@ -534,7 +542,7 @@ def subfiles( Path(nnunet_model_folder).joinpath(f"fold_{fold}").mkdir(parents=True, exist_ok=True) - nnunet_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth") + nnunet_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth", weights_only=True) latest_checkpoints: list[str] = subfiles( Path(bundle_root_folder).joinpath("models", f"fold_{fold}"), prefix="checkpoint_epoch", sort=True ) @@ -545,7 +553,7 @@ def subfiles( epochs.sort() final_epoch: int = epochs[-1] monai_last_checkpoint: dict = torch.load( - f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt" + f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt", weights_only=True ) best_checkpoints: list[str] = subfiles( @@ -558,7 +566,7 @@ def subfiles( key_metrics.sort() best_key_metric: str = key_metrics[-1] monai_best_checkpoint: dict = torch.load( - f"{bundle_root_folder}/models/fold_{fold}/checkpoint_key_metric={best_key_metric}.pt" + f"{bundle_root_folder}/models/fold_{fold}/checkpoint_key_metric={best_key_metric}.pt", weights_only=True ) nnunet_checkpoint["optimizer_state"] = monai_last_checkpoint["optimizer_state"] diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 340c5eb8fa..5e367cc297 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -78,7 +78,6 @@ from .thread_buffer import ThreadBuffer, ThreadDataLoader from .torchscript_utils import load_net_with_metadata, save_net_with_metadata from .utils import ( - PICKLE_KEY_SUFFIX, affine_to_spacing, compute_importance_map, compute_shape_offset, diff --git a/monai/data/dataset.py b/monai/data/dataset.py index e5842bfa7a..d63ff32293 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -13,7 +13,6 @@ import collections.abc import math -import pickle import shutil import sys import tempfile @@ -22,9 +21,11 @@ import warnings from collections.abc import Callable, Sequence from copy import copy, deepcopy +from io import BytesIO from multiprocessing.managers import ListProxy from multiprocessing.pool import ThreadPool from pathlib import Path +from pickle import UnpicklingError from typing import IO, TYPE_CHECKING, Any, cast import numpy as np @@ -207,6 +208,11 @@ class PersistentDataset(Dataset): not guaranteed, so caution should be used when modifying transforms to avoid unexpected errors. If in doubt, it is advisable to clear the cache directory. + Cached data is expected to be tensors, primitives, or dictionaries keying to these values. Numpy arrays will + be converted to tensors, however any other object type returned by transforms will not be loadable since + `torch.load` will be used with `weights_only=True` to prevent loading of potentially malicious objects. + Legacy cache files may not be loadable and may need to be recomputed. + Lazy Resampling: If you make use of the lazy resampling feature of `monai.transforms.Compose`, please refer to its documentation to familiarize yourself with the interaction between `PersistentDataset` and @@ -248,8 +254,8 @@ def __init__( this arg is used by `torch.save`, for more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save, and ``monai.data.utils.SUPPORTED_PICKLE_MOD``. - pickle_protocol: can be specified to override the default protocol, default to `2`. - this arg is used by `torch.save`, for more details, please check: + pickle_protocol: specifies pickle protocol when saving, with `torch.save`. + Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. hash_transform: a callable to compute hash from the transform information when caching. This may reduce errors due to transforms changing during experiments. Default to None (no hash). @@ -371,12 +377,12 @@ def _cachecheck(self, item_transformed): if hashfile is not None and hashfile.is_file(): # cache hit try: - return torch.load(hashfile, weights_only=False) + return torch.load(hashfile, weights_only=True) except PermissionError as e: if sys.platform != "win32": raise e - except RuntimeError as e: - if "Invalid magic number; corrupt file" in str(e): + except (UnpicklingError, RuntimeError) as e: # corrupt or unloadable cached files are recomputed + if "Invalid magic number; corrupt file" in str(e) or isinstance(e, UnpicklingError): warnings.warn(f"Corrupt cache file detected: {hashfile}. Deleting and recomputing.") hashfile.unlink() else: @@ -392,7 +398,7 @@ def _cachecheck(self, item_transformed): with tempfile.TemporaryDirectory() as tmpdirname: temp_hash_file = Path(tmpdirname) / hashfile.name torch.save( - obj=_item_transformed, + obj=convert_to_tensor(_item_transformed, convert_numeric=False), f=temp_hash_file, pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD), pickle_protocol=self.pickle_protocol, @@ -455,8 +461,8 @@ def __init__( this arg is used by `torch.save`, for more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save, and ``monai.data.utils.SUPPORTED_PICKLE_MOD``. - pickle_protocol: can be specified to override the default protocol, default to `2`. - this arg is used by `torch.save`, for more details, please check: + pickle_protocol: specifies pickle protocol when saving, with `torch.save`. + Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. hash_transform: a callable to compute hash from the transform information when caching. This may reduce errors due to transforms changing during experiments. Default to None (no hash). @@ -531,7 +537,7 @@ def __init__( hash_func: Callable[..., bytes] = pickle_hashing, db_name: str = "monai_cache", progress: bool = True, - pickle_protocol=pickle.HIGHEST_PROTOCOL, + pickle_protocol=DEFAULT_PROTOCOL, hash_transform: Callable[..., bytes] | None = None, reset_ops_id: bool = True, lmdb_kwargs: dict | None = None, @@ -551,8 +557,9 @@ def __init__( defaults to `monai.data.utils.pickle_hashing`. db_name: lmdb database file name. Defaults to "monai_cache". progress: whether to display a progress bar. - pickle_protocol: pickle protocol version. Defaults to pickle.HIGHEST_PROTOCOL. - https://docs.python.org/3/library/pickle.html#pickle-protocols + pickle_protocol: specifies pickle protocol when saving, with `torch.save`. + Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check: + https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. hash_transform: a callable to compute hash from the transform information when caching. This may reduce errors due to transforms changing during experiments. Default to None (no hash). Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`. @@ -594,6 +601,15 @@ def set_data(self, data: Sequence): super().set_data(data=data) self._read_env = self._fill_cache_start_reader(show_progress=self.progress) + def _safe_serialize(self, val): + out = BytesIO() + torch.save(convert_to_tensor(val), out, pickle_protocol=self.pickle_protocol) + out.seek(0) + return out.read() + + def _safe_deserialize(self, val): + return torch.load(BytesIO(val), map_location="cpu", weights_only=True) + def _fill_cache_start_reader(self, show_progress=True): """ Check the LMDB cache and write the cache if needed. py-lmdb doesn't have a good support for concurrent write. @@ -619,7 +635,8 @@ def _fill_cache_start_reader(self, show_progress=True): continue if val is None: val = self._pre_transform(deepcopy(item)) # keep the original hashed - val = pickle.dumps(val, protocol=self.pickle_protocol) + # val = pickle.dumps(val, protocol=self.pickle_protocol) + val = self._safe_serialize(val) with env.begin(write=True) as txn: txn.put(key, val) done = True @@ -664,7 +681,8 @@ def _cachecheck(self, item_transformed): warnings.warn("LMDBDataset: cache key not found, running fallback caching.") return super()._cachecheck(item_transformed) try: - return pickle.loads(data) + # return pickle.loads(data) + return self._safe_deserialize(data) except Exception as err: raise RuntimeError("Invalid cache value, corrupted lmdb file?") from err @@ -1650,7 +1668,7 @@ def _create_new_cache(self, data, data_hashfile, meta_hash_file_name): meta_hash_file = self.cache_dir / meta_hash_file_name temp_hash_file = Path(tmpdirname) / meta_hash_file_name torch.save( - obj=self._meta_cache[meta_hash_file_name], + obj=convert_to_tensor(self._meta_cache[meta_hash_file_name], convert_numeric=False), f=temp_hash_file, pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD), pickle_protocol=self.pickle_protocol, @@ -1670,4 +1688,4 @@ def _load_meta_cache(self, meta_hash_file_name): if meta_hash_file_name in self._meta_cache: return self._meta_cache[meta_hash_file_name] else: - return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False) + return torch.load(self.cache_dir / meta_hash_file_name, weights_only=True) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 6425bc0a4f..12bd76ba60 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -611,4 +611,4 @@ def print_verbose(self) -> None: # needed in later versions of Pytorch to indicate the class is safe for serialisation if hasattr(torch.serialization, "add_safe_globals"): - torch.serialization.add_safe_globals([MetaTensor]) + torch.serialization.add_safe_globals([MetaObj, MetaTensor, MetaKeys, SpaceKeys]) diff --git a/monai/data/utils.py b/monai/data/utils.py index 14217e9103..ca7d5c9d9e 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -30,7 +30,6 @@ import torch from torch.utils.data._utils.collate import default_collate -from monai import config from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike from monai.data.meta_obj import MetaObj from monai.utils import ( @@ -93,7 +92,6 @@ "remove_keys", "remove_extra_metadata", "get_extra_metadata_keys", - "PICKLE_KEY_SUFFIX", "is_no_channel", ] @@ -418,32 +416,6 @@ def dev_collate(batch, level: int = 1, logger_name: str = "dev_collate"): return -PICKLE_KEY_SUFFIX = TraceKeys.KEY_SUFFIX - - -def pickle_operations(data, key=PICKLE_KEY_SUFFIX, is_encode: bool = True): - """ - Applied_operations are dictionaries with varying sizes, this method converts them to bytes so that we can (de-)collate. - - Args: - data: a list or dictionary with substructures to be pickled/unpickled. - key: the key suffix for the target substructures, defaults to "_transforms" (`data.utils.PICKLE_KEY_SUFFIX`). - is_encode: whether it's encoding using pickle.dumps (True) or decoding using pickle.loads (False). - """ - if isinstance(data, Mapping): - data = dict(data) - for k in data: - if f"{k}".endswith(key): - if is_encode and not isinstance(data[k], bytes): - data[k] = pickle.dumps(data[k], 0) - if not is_encode and isinstance(data[k], bytes): - data[k] = pickle.loads(data[k]) - return {k: pickle_operations(v, key=key, is_encode=is_encode) for k, v in data.items()} - elif isinstance(data, (list, tuple)): - return [pickle_operations(item, key=key, is_encode=is_encode) for item in data] - return data - - def collate_meta_tensor_fn(batch, *, collate_fn_map=None): """ Collate a sequence of meta tensor into a single batched metatensor. This is called by `collage_meta_tensor` @@ -500,8 +472,8 @@ def list_data_collate(batch: Sequence): key = None collate_fn = default_collate try: - if config.USE_META_DICT: - data = pickle_operations(data) # bc 0.9.0 + # if config.USE_META_DICT: + # data = pickle_operations(data) # bc 0.9.0 if isinstance(elem, Mapping): ret = {} for k in elem: @@ -654,15 +626,17 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None): if isinstance(deco, Mapping): _gen = zip_longest(*deco.values(), fillvalue=fill_value) if pad else zip(*deco.values()) ret = [dict(zip(deco, item)) for item in _gen] - if not config.USE_META_DICT: - return ret - return pickle_operations(ret, is_encode=False) # bc 0.9.0 + # if not config.USE_META_DICT: + # return ret + # return pickle_operations(ret, is_encode=False) # bc 0.9.0 + return ret if isinstance(deco, Iterable): _gen = zip_longest(*deco, fillvalue=fill_value) if pad else zip(*deco) ret_list = [list(item) for item in _gen] - if not config.USE_META_DICT: - return ret_list - return pickle_operations(ret_list, is_encode=False) # bc 0.9.0 + # if not config.USE_META_DICT: + # return ret_list + # return pickle_operations(ret_list, is_encode=False) # bc 0.9.0 + return ret_list raise NotImplementedError(f"Unable to de-collate: {batch}, type: {type(batch)}.") diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py index 16cb875d03..105b4f3a79 100644 --- a/monai/handlers/checkpoint_loader.py +++ b/monai/handlers/checkpoint_loader.py @@ -122,7 +122,7 @@ def __call__(self, engine: Engine) -> None: Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - checkpoint = torch.load(self.load_path, map_location=self.map_location, weights_only=False) + checkpoint = torch.load(self.load_path, map_location=self.map_location, weights_only=True) k, _ = list(self.load_dict.items())[0] # single object and checkpoint is directly a state_dict diff --git a/monai/utils/state_cacher.py b/monai/utils/state_cacher.py index c59436525c..e0eeef6001 100644 --- a/monai/utils/state_cacher.py +++ b/monai/utils/state_cacher.py @@ -64,8 +64,8 @@ def __init__( pickle_module: module used for pickling metadata and objects, default to `pickle`. this arg is used by `torch.save`, for more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. - pickle_protocol: can be specified to override the default protocol, default to `2`. - this arg is used by `torch.save`, for more details, please check: + pickle_protocol: specifies pickle protocol when saving, with `torch.save`. + Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. """ @@ -124,7 +124,7 @@ def retrieve(self, key: Hashable) -> Any: fn = self.cached[key]["obj"] # pytype: disable=attribute-error if not os.path.exists(fn): # pytype: disable=wrong-arg-types raise RuntimeError(f"Failed to load state in {fn}. File doesn't exist anymore.") - data_obj = torch.load(fn, map_location=lambda storage, location: storage, weights_only=False) + data_obj = torch.load(fn, map_location=lambda storage, location: storage, weights_only=True) # copy back to device if necessary if "device" in self.cached[key]: data_obj = data_obj.to(self.cached[key]["device"]) diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index 420e935b33..b5dfb580c5 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -117,6 +117,7 @@ def convert_to_tensor( wrap_sequence: bool = False, track_meta: bool = False, safe: bool = False, + convert_numeric: bool = True, ) -> Any: """ Utility to convert the input data to a PyTorch Tensor, if `track_meta` is True, the output will be a `MetaTensor`, @@ -136,6 +137,7 @@ def convert_to_tensor( safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`. E.g., `[256, -12]` -> `[tensor(0), tensor(244)]`. If `True`, then `[256, -12]` -> `[tensor(255), tensor(0)]`. + convert_numeric: if `True`, convert numeric Python values to tensors. """ @@ -156,6 +158,7 @@ def _convert_tensor(tensor: Any, **kwargs: Any) -> Any: if safe: data = safe_dtype_range(data, dtype) dtype = get_equivalent_dtype(dtype, torch.Tensor) + if isinstance(data, torch.Tensor): return _convert_tensor(data).to(dtype=dtype, device=device, memory_format=torch.contiguous_format) if isinstance(data, np.ndarray): @@ -167,16 +170,25 @@ def _convert_tensor(tensor: Any, **kwargs: Any) -> Any: if data.ndim > 0: data = np.ascontiguousarray(data) return _convert_tensor(data, dtype=dtype, device=device) - elif (has_cp and isinstance(data, cp_ndarray)) or isinstance(data, (float, int, bool)): + elif (has_cp and isinstance(data, cp_ndarray)) or (convert_numeric and isinstance(data, (float, int, bool))): return _convert_tensor(data, dtype=dtype, device=device) elif isinstance(data, list): - list_ret = [convert_to_tensor(i, dtype=dtype, device=device, track_meta=track_meta) for i in data] + list_ret = [ + convert_to_tensor(i, dtype=dtype, device=device, track_meta=track_meta, convert_numeric=convert_numeric) + for i in data + ] return _convert_tensor(list_ret, dtype=dtype, device=device) if wrap_sequence else list_ret elif isinstance(data, tuple): - tuple_ret = tuple(convert_to_tensor(i, dtype=dtype, device=device, track_meta=track_meta) for i in data) + tuple_ret = tuple( + convert_to_tensor(i, dtype=dtype, device=device, track_meta=track_meta, convert_numeric=convert_numeric) + for i in data + ) return _convert_tensor(tuple_ret, dtype=dtype, device=device) if wrap_sequence else tuple_ret elif isinstance(data, dict): - return {k: convert_to_tensor(v, dtype=dtype, device=device, track_meta=track_meta) for k, v in data.items()} + return { + k: convert_to_tensor(v, dtype=dtype, device=device, track_meta=track_meta, convert_numeric=convert_numeric) + for k, v in data.items() + } return data diff --git a/tests/data/meta_tensor/test_meta_tensor.py b/tests/data/meta_tensor/test_meta_tensor.py index c0e53fd24c..427902f784 100644 --- a/tests/data/meta_tensor/test_meta_tensor.py +++ b/tests/data/meta_tensor/test_meta_tensor.py @@ -245,7 +245,7 @@ def test_pickling(self): with tempfile.TemporaryDirectory() as tmp_dir: fname = os.path.join(tmp_dir, "im.pt") torch.save(m, fname) - m2 = torch.load(fname, weights_only=False) + m2 = torch.load(fname, weights_only=True) self.check(m2, m, ids=False) @skip_if_no_cuda diff --git a/tests/data/test_gdsdataset.py b/tests/data/test_gdsdataset.py index b4acb3bf55..aa802249bc 100644 --- a/tests/data/test_gdsdataset.py +++ b/tests/data/test_gdsdataset.py @@ -12,7 +12,6 @@ from __future__ import annotations import os -import pickle import tempfile import unittest @@ -86,7 +85,8 @@ def test_cache(self): cache_dir=tempdir, device=0, pickle_module="pickle", - pickle_protocol=pickle.HIGHEST_PROTOCOL, + # TODO: was pickle.HIGHEST_PROTOCOL but this wasn't compatible with torch.load, need to improve compatibility + pickle_protocol=torch.serialization.DEFAULT_PROTOCOL, ) assert_allclose(items[0], p(np.arange(0, np.prod(shape)).reshape(shape))) ds1 = GDSDataset(items, transform=_InplaceXform(), cache_dir=tempdir, device=0) diff --git a/tests/data/test_persistentdataset.py b/tests/data/test_persistentdataset.py index 7c4969e283..7bf1245592 100644 --- a/tests/data/test_persistentdataset.py +++ b/tests/data/test_persistentdataset.py @@ -12,12 +12,12 @@ from __future__ import annotations import os -import pickle import tempfile import unittest import nibabel as nib import numpy as np +import torch from parameterized import parameterized from monai.data import PersistentDataset, json_hashing @@ -66,7 +66,8 @@ def test_cache(self): transform=_InplaceXform(), cache_dir=tempdir, pickle_module="pickle", - pickle_protocol=pickle.HIGHEST_PROTOCOL, + # TODO: was pickle.HIGHEST_PROTOCOL but this wasn't compatible with torch.load, need to improve compatibility + pickle_protocol=torch.serialization.DEFAULT_PROTOCOL, ) self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]]) ds1 = PersistentDataset(items, transform=_InplaceXform(), cache_dir=tempdir) diff --git a/tests/utils/test_state_cacher.py b/tests/utils/test_state_cacher.py index 22c2836239..6e6eabf03d 100644 --- a/tests/utils/test_state_cacher.py +++ b/tests/utils/test_state_cacher.py @@ -27,7 +27,13 @@ TEST_CASE_0 = [torch.Tensor([1]).to(DEVICE), {"in_memory": True}] TEST_CASE_1 = [ torch.Tensor([1]).to(DEVICE), - {"in_memory": False, "cache_dir": gettempdir(), "pickle_module": None, "pickle_protocol": pickle.HIGHEST_PROTOCOL}, + { + "in_memory": False, + "cache_dir": gettempdir(), + "pickle_module": None, + # TODO: was pickle.HIGHEST_PROTOCOL but this wasn't compatible with torch.load, need to improve compatibility + "pickle_protocol": torch.serialization.DEFAULT_PROTOCOL, + }, ] TEST_CASE_2 = [torch.Tensor([1]).to(DEVICE), {"in_memory": False, "allow_overwrite": False}] TEST_CASE_3 = [torch.Tensor([1]).to(DEVICE), {"in_memory": False, "cache_dir": Path(gettempdir())}]