From 3710d2f589a6b2891266a81d8f35b7fc4edc1f3f Mon Sep 17 00:00:00 2001 From: NeelKondapalli <107438832+NeelKondapalli@users.noreply.github.com> Date: Sat, 24 Aug 2024 08:57:51 -0700 Subject: [PATCH] Use `weights_only=True` from PyTorch 2.4 (#423) Fixes #422 by adding the `weights_only = True` argument to `torch.load` in the file `io.py`. This protects agains the arbitrary data warning. The types `stype` and `StatType` were added to the safe globals list. By: Neel Kondapalli (neel2h06@gmail.com) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Akihiro Nitta --- CHANGELOG.md | 2 ++ torch_frame/__init__.py | 16 +++++++++++++++- torch_frame/typing.py | 4 ++++ torch_frame/utils/io.py | 7 ++++--- 4 files changed, 25 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9bbd9232..707ba8d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Set `weights_only=True` in `torch_frame.load` from PyTorch 2.4 ([#423](https://github.com/pyg-team/pytorch-frame/pull/423)) + ### Deprecated ### Removed diff --git a/torch_frame/__init__.py b/torch_frame/__init__.py index 7aa3f990..7161acc3 100644 --- a/torch_frame/__init__.py +++ b/torch_frame/__init__.py @@ -12,13 +12,27 @@ embedding, ) from .data import TensorFrame -from .typing import TaskType, Metric, DataFrame, NAStrategy +from .typing import ( + TaskType, + Metric, + DataFrame, + NAStrategy, + WITH_PT24, +) from torch_frame.utils import save, load, cat # noqa import torch_frame.data # noqa import torch_frame.datasets # noqa import torch_frame.nn # noqa import torch_frame.gbdt # noqa +if WITH_PT24: + import torch + + torch.serialization.add_safe_globals([ + stype, + torch_frame.data.stats.StatType, + ]) + __version__ = '0.2.3' __all__ = [ diff --git a/torch_frame/typing.py b/torch_frame/typing.py index a2e49159..c7aede63 100644 --- a/torch_frame/typing.py +++ b/torch_frame/typing.py @@ -4,11 +4,15 @@ from typing import Dict, List, Mapping, Union import pandas as pd +import torch from torch import Tensor from torch_frame.data.multi_embedding_tensor import MultiEmbeddingTensor from torch_frame.data.multi_nested_tensor import MultiNestedTensor +WITH_PT20 = int(torch.__version__.split('.')[0]) >= 2 +WITH_PT24 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 4 + class Metric(Enum): r"""The metric. diff --git a/torch_frame/utils/io.py b/torch_frame/utils/io.py index 5a22f3e8..50732d4a 100644 --- a/torch_frame/utils/io.py +++ b/torch_frame/utils/io.py @@ -13,7 +13,7 @@ ) from torch_frame.data.multi_tensor import _MultiTensor from torch_frame.data.stats import StatType -from torch_frame.typing import TensorData +from torch_frame.typing import WITH_PT24, TensorData def serialize_feat_dict( @@ -80,7 +80,8 @@ def save(tensor_frame: TensorFrame, def load( - path: str, device: torch.device | None = None + path: str, + device: torch.device | None = None, ) -> tuple[TensorFrame, dict[str, dict[StatType, Any]] | None]: r"""Load saved :class:`TensorFrame` object and optional :obj:`col_stats` from a specified path. @@ -95,7 +96,7 @@ def load( tuple: A tuple of loaded :class:`TensorFrame` object and optional :obj:`col_stats`. """ - tf_dict, col_stats = torch.load(path) + tf_dict, col_stats = torch.load(path, weights_only=WITH_PT24) tf_dict['feat_dict'] = deserialize_feat_dict( tf_dict.pop('feat_serialized_dict')) tensor_frame = TensorFrame(**tf_dict)