From 351d4d951adb92500429168172acb5e89032c0c1 Mon Sep 17 00:00:00 2001
From: yusuke kunimitsu <>
Date: Tue, 11 Nov 2025 12:23:55 +0900
Subject: [PATCH] Add safetensors support for PyTorch

---
 pyproject.toml                            |   3 +
 .../pytorch_module_materializer.py        | 249 +++++++++++++++---
 .../test_pytorch_module_materializer.py   | 103 ++++++++
 3 files changed, 324 insertions(+), 31 deletions(-)
 create mode 100644 tests/unit/integrations/pytorch/materializers/test_pytorch_module_materializer.py

diff --git a/pyproject.toml b/pyproject.toml
index 087a125b6b0..7c398a2254b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -54,6 +54,9 @@ Issues = "https://github.com/zenml-io/zenml/issues"
 zenml = "zenml.cli.cli:cli"
 
 [project.optional-dependencies]
+safetensors = [
+    "safetensors>=0.4.0"
+]
 local = [
     "alembic>=1.8.1,<=1.15.2",
     "bcrypt==4.0.1",
diff --git a/src/zenml/integrations/pytorch/materializers/pytorch_module_materializer.py b/src/zenml/integrations/pytorch/materializers/pytorch_module_materializer.py
index 86e36572c8d..e89d47df032 100644
--- a/src/zenml/integrations/pytorch/materializers/pytorch_module_materializer.py
+++ b/src/zenml/integrations/pytorch/materializers/pytorch_module_materializer.py
@@ -11,12 +11,26 @@
 #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 #  or implied. See the License for the specific language governing
 #  permissions and limitations under the License.
-"""Implementation of the PyTorch Module materializer."""
+"""Implementation of the PyTorch Module materializer (with SafeTensors)."""
+
+from __future__ import annotations
+
+import importlib
+import json
 import os
-from typing import TYPE_CHECKING, Any, ClassVar, Dict, Tuple, Type
+import tempfile
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    ClassVar,
+    Dict,
+    Optional,
+    Tuple,
+    Type,
+    cast,
+)
 
-import cloudpickle
+import cloudpickle  # for legacy entire_model.pt compatibility
 import torch
 from torch.nn import Module
 
@@ -25,54 +39,227 @@
     BasePyTorchMaterializer,
 )
 from zenml.integrations.pytorch.utils import count_module_params
-from zenml.io import fileio
+from zenml.logger import get_logger
+from zenml.utils.io_utils import copy_dir
 
 if TYPE_CHECKING:
     from zenml.metadata.metadata_types import MetadataType
 
+logger = get_logger(__name__)
+
+# Legacy names kept for compatibility (we no longer write these in Phase 1).
 DEFAULT_FILENAME = "entire_model.pt"
 CHECKPOINT_FILENAME = "checkpoint.pt"
+# New filenames for Phase 1
+WEIGHTS_SAFE = "weights.safetensors"
+WEIGHTS_PT = "weights.pt"
+META_FILE = "metadata.json"
 
-class PyTorchModuleMaterializer(BasePyTorchMaterializer):
-    """Materializer to read/write Pytorch models.
-
-    Inspired by the guide:
-    https://pytorch.org/tutorials/beginner/saving_loading_models.html
+
+def _import_from_path(class_path: str) -> Type[Any]:
+    """Import a class from 'pkg.module:Class' or 'pkg.module.Class'.
+
+    Args:
+        class_path: Fully qualified class path.
+
+    Returns:
+        The imported class.
""" + if ":" in class_path: + mod, cls = class_path.split(":") + else: + parts = class_path.split(".") + mod, cls = ".".join(parts[:-1]), parts[-1] + module = importlib.import_module(mod) + return cast(Type[Any], getattr(module, cls)) + + +def _has_safetensors() -> bool: + try: + import safetensors.torch # noqa: F401 + + return True + except Exception: + return False + + +class PyTorchModuleMaterializer(BasePyTorchMaterializer): + """Materializer to read/write PyTorch models.""" ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (Module,) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.MODEL + # NOTE: FILENAME is unused in Phase 1 (we don't write entire_model.pt) FILENAME: ClassVar[str] = DEFAULT_FILENAME + # --------------------------- + # Save + # --------------------------- def save(self, model: Module) -> None: - """Writes a PyTorch model, as a model and a checkpoint. + """Save a PyTorch model as state_dict + metadata. - Args: - model: A torch.nn.Module or a dict to pass into model.save + Why: prefer non-pickle format (safetensors) for security and speed. + If safetensors is missing, fall back to .pt (pickle) with a warning. + Use a local temp dir then copy to artifact store to satisfy PathLike I/O. """ - # Save entire model to artifact directory, This is the default behavior - # for loading model in development phase (training, evaluation) - super().save(model) - - # Also save model checkpoint to artifact directory, - # This is the default behavior for loading model in production phase (inference) - if isinstance(model, Module): - with fileio.open( - os.path.join(self.uri, CHECKPOINT_FILENAME), "wb" - ) as f: - # NOTE (security): The `torch.save` function uses `cloudpickle` as - # the default unpickler, which is NOT secure. This materializer - # is intended for use with trusted data sources. - torch.save(model.state_dict(), f, pickle_module=cloudpickle) # nosec + # Build a CPU-mapped state_dict to ensure portability across devices. + state_dict = { + k: v.detach().cpu() for k, v in model.state_dict().items() + } - def extract_metadata(self, model: Module) -> Dict[str, "MetadataType"]: - """Extract metadata from the given `Model` object. + with tempfile.TemporaryDirectory() as tmp: + tmp_dir = os.path.abspath(tmp) + + # Write weights + if _has_safetensors(): + from safetensors.torch import save_file + + weights_path = os.path.join(tmp_dir, WEIGHTS_SAFE) + save_file(state_dict, weights_path) + fmt = "safetensors" + else: + logger.warning( + "safetensors not installed; falling back to pickle (.pt). " + "Install with: pip install 'zenml[safetensors]'" + ) + weights_path = os.path.join(tmp_dir, WEIGHTS_PT) + # NOTE: torch.save uses pickle; fallback intended for trusted environments. + torch.save(state_dict, weights_path, pickle_module=cloudpickle) # nosec + fmt = "pickle" + + # Write minimal metadata (extensible in future phases) + meta: Dict[str, Any] = { + "class_path": f"{model.__class__.__module__}.{model.__class__.__name__}", + "serialization_format": fmt, + # Reserved (Phase 2+) + "init_args": [], + "init_kwargs": {}, + "factory_path": None, + } + with open(os.path.join(tmp_dir, META_FILE), "w") as f: + json.dump(meta, f, indent=2) + + # Upload directory to artifact store (remote-safe) + copy_dir(tmp_dir, self.uri) - Args: - model: The `Model` object to extract metadata from. + # IMPORTANT: + # We intentionally do NOT call super().save(model) nor write CHECKPOINT_FILENAME + # to avoid duplicate / insecure pickle artifacts in Phase 1. 
-
-        Returns:
-            The extracted metadata as a dictionary.
+
+    # ---------------------------
+    # Load
+    # ---------------------------
+    def load(self, data_type: Optional[Type[Module]] = None) -> Module:
+        """Load a PyTorch model and always return an nn.Module.
+
+        Rules:
+        - Phase 1: require a zero-arg constructor when metadata is present.
+        - Legacy (.pt only, no metadata): require `data_type` to be provided.
+        - Raise clear exceptions when reconstruction is not possible.
+
+        Args:
+            data_type: Optional model class; required only for legacy
+                artifacts that lack metadata.json.
+
+        Returns:
+            The reconstructed torch.nn.Module, in eval mode.
         """
+        with tempfile.TemporaryDirectory() as tmp:
+            tmp_dir = os.path.abspath(tmp)
+
+            # Mirror the artifact dir locally for PathLike-only APIs
+            copy_dir(self.uri, tmp_dir)
+
+            meta_path = os.path.join(tmp_dir, META_FILE)
+            has_meta = os.path.exists(meta_path)
+
+            if has_meta:
+                # New format (metadata-driven)
+                with open(meta_path, "r") as f:
+                    meta = json.load(f)
+
+                class_path = meta.get("class_path")
+                if not class_path:
+                    raise ValueError("metadata.json is missing 'class_path'.")
+
+                try:
+                    model_class = cast(
+                        Type[Module], _import_from_path(class_path)
+                    )
+                except (ImportError, AttributeError) as e:
+                    raise ImportError(
+                        f"Cannot import model class '{class_path}': {e}. "
+                        "Ensure the module/package is available in PYTHONPATH."
+                    ) from e
+
+                fmt = meta.get("serialization_format", "pickle")
+            else:
+                legacy_entire_model = os.path.join(tmp_dir, DEFAULT_FILENAME)
+                if os.path.exists(legacy_entire_model):
+                    # NOTE (security): legacy pickle-based load; only for
+                    # trusted artifact sources.
+                    model = torch.load(
+                        legacy_entire_model,
+                        map_location="cpu",
+                        weights_only=False,
+                    )  # nosec
+                    if not isinstance(model, Module):
+                        raise RuntimeError(
+                            f"Legacy file {DEFAULT_FILENAME} did not contain "
+                            "a torch.nn.Module."
+                        )
+                    model.eval()
+                    return model
+
+                logger.warning(
+                    "No metadata.json found. Loading legacy artifact: "
+                    "using the `data_type` parameter as the model class."
+                )
+                if data_type is None:
+                    raise FileNotFoundError(
+                        "Legacy artifact without metadata.json requires "
+                        "`data_type` (the model class) to reconstruct."
+                    )
+                model_class = data_type
+                fmt = "pickle"  # legacy assumption
+
+            # Load weights or legacy entire model
+            if fmt == "safetensors":
+                safetensors_path = os.path.join(tmp_dir, WEIGHTS_SAFE)
+                if not os.path.exists(safetensors_path):
+                    raise FileNotFoundError(
+                        f"Expected weights file not found: {safetensors_path}"
+                    )
+                try:
+                    from safetensors.torch import load_file
+                except ImportError as e:
+                    raise ImportError(
+                        "This artifact was saved with 'safetensors', but the "
+                        "optional dependency is not installed. Install via:\n"
+                        "  pip install 'zenml[safetensors]'"
+                    ) from e
+                state_dict = load_file(safetensors_path)
+            else:
+                # pickle (.pt) state_dict checkpoints
+                pt_candidates = [
+                    os.path.join(tmp_dir, WEIGHTS_PT),
+                    os.path.join(tmp_dir, CHECKPOINT_FILENAME),
+                ]
+                pt_path = next(
+                    (p for p in pt_candidates if os.path.exists(p)), None
+                )
+                if not pt_path:
+                    raise FileNotFoundError(
+                        f"No weights file found; expected one of {pt_candidates}."
+                    )
+                state_dict = torch.load(pt_path, map_location="cpu")
+
+            # Reconstruct model (Phase 1: zero-arg)
+            try:
+                model = model_class()
+            except TypeError as e:
+                raise RuntimeError(
+                    f"Failed to instantiate {model_class.__name__}: {e}. "
+                    "Phase 1 supports only a zero-argument __init__(). "
+                    "Use a factory or wait for the Phase 2 enhancement."
+                ) from e
+
+            try:
+                model.load_state_dict(state_dict, strict=True)
+            except Exception as e:
+                raise RuntimeError(f"Failed to load state_dict: {e}") from e
+
+            model.eval()
+            return model
+
+    # ---------------------------
+    # Metadata extraction (unchanged)
+    # ---------------------------
+    def extract_metadata(self, model: Module) -> Dict[str, "MetadataType"]:
+        """Extract metadata from the given `Model` object."""
         return {**count_module_params(model)}
diff --git a/tests/unit/integrations/pytorch/materializers/test_pytorch_module_materializer.py b/tests/unit/integrations/pytorch/materializers/test_pytorch_module_materializer.py
new file mode 100644
index 00000000000..423aeb1f0b1
--- /dev/null
+++ b/tests/unit/integrations/pytorch/materializers/test_pytorch_module_materializer.py
@@ -0,0 +1,103 @@
+"""Tests for the PyTorch module materializer."""
+
+from pathlib import Path
+
+import pytest
+import torch
+from torch import nn
+
+from zenml.integrations.pytorch.materializers.pytorch_module_materializer import (
+    DEFAULT_FILENAME,
+    META_FILE,
+    WEIGHTS_PT,
+    WEIGHTS_SAFE,
+    PyTorchModuleMaterializer,
+)
+
+
+def _has_safetensors() -> bool:
+    """Return True if the optional safetensors dependency is installed."""
+    try:
+        import safetensors.torch  # noqa
+
+        return True
+    except Exception:
+        return False
+
+
+class Tiny(nn.Module):
+    """Single linear layer used to validate serialization round-trips."""
+
+    def __init__(self) -> None:
+        """Initialize the linear layer used for tests."""
+        super().__init__()
+        self.fc = nn.Linear(4, 3)
+
+
+@pytest.mark.skipif(not _has_safetensors(), reason="safetensors is optional")
+def test_round_trip_with_safetensors(tmp_path: Path) -> None:
+    """End-to-end save/load using safetensors artifacts."""
+    m = Tiny()
+    mat = PyTorchModuleMaterializer(uri=str(tmp_path))
+    mat.save(m)
+
+    # new artifacts exist
+    assert (tmp_path / META_FILE).exists()
+    assert (tmp_path / WEIGHTS_SAFE).exists()
+    # no legacy outputs
+    assert not (tmp_path / "checkpoint.pt").exists()
+    assert not (tmp_path / "entire_model.pt").exists()
+    assert not (tmp_path / WEIGHTS_PT).exists()
+
+    loaded = mat.load()
+    assert isinstance(loaded, nn.Module)
+    assert isinstance(loaded, Tiny)
+
+
+def test_legacy_weights_pt_without_metadata(tmp_path: Path) -> None:
+    """Legacy state_dict artifacts require passing data_type."""
+    # simulate an old artifact: weights.pt only, no metadata
+    state_dict = Tiny().state_dict()
+    torch.save(state_dict, tmp_path / WEIGHTS_PT)
+    mat = PyTorchModuleMaterializer(uri=str(tmp_path))
+
+    # must fail without data_type
+    with pytest.raises(FileNotFoundError):
+        mat.load()
+
+    # works with data_type
+    loaded = mat.load(data_type=Tiny)
+    assert isinstance(loaded, Tiny)
+
+
+def test_legacy_entire_model_pt_direct_load(tmp_path: Path) -> None:
+    """Very old pickle artifacts still deserialize to nn.Module."""
+    # simulate a very old artifact: entire_model.pt (pickled Module)
+    torch.save(Tiny(), tmp_path / DEFAULT_FILENAME)
+    mat = PyTorchModuleMaterializer(uri=str(tmp_path))
+    loaded = mat.load()  # metadata-less path returns the Module directly
+    assert isinstance(loaded, Tiny)
+
+
+def test_missing_safetensors_dependency_raises_clear_error(
+    tmp_path: Path, monkeypatch
+) -> None:
+    """A missing dependency should produce an actionable ImportError."""
+    if not _has_safetensors():
+        pytest.skip(
+            "safetensors not installed; scenario covered by other tests"
+        )
+
+    # create a safetensors artifact
+    m = Tiny()
+    mat = PyTorchModuleMaterializer(uri=str(tmp_path))
+    mat.save(m)
+
+    # simulate a missing dependency at load time
+    import sys
+
+    monkeypatch.setitem(sys.modules, "safetensors", None)
+    monkeypatch.setitem(sys.modules, "safetensors.torch", None)
+    with pytest.raises(ImportError) as e:
+        mat.load()
+    assert "pip install 'zenml[safetensors]'" in str(e.value)
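
Usage sketch (not part of the patch): a minimal round trip driving the materializer directly, mirroring what the unit tests above exercise. The `Net` class and the local URI are illustrative assumptions; in real pipelines ZenML supplies the artifact URI.

    import os

    from torch import nn

    from zenml.integrations.pytorch.materializers.pytorch_module_materializer import (
        PyTorchModuleMaterializer,
    )


    class Net(nn.Module):
        """Toy model; the zero-argument __init__ is what Phase 1 requires."""

        def __init__(self) -> None:
            super().__init__()
            self.fc = nn.Linear(4, 3)


    uri = "/tmp/example_artifact"  # hypothetical path; ZenML normally provides this
    os.makedirs(uri, exist_ok=True)

    materializer = PyTorchModuleMaterializer(uri=uri)
    materializer.save(Net())        # writes weights.safetensors + metadata.json
    restored = materializer.load()  # re-imports Net via metadata, loads weights
    assert isinstance(restored, Net)
    assert not restored.training    # load() returns the model in eval mode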
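
On-disk contract sketch (also not part of the patch): what a downstream consumer sees inside a saved artifact directory, assuming the file names and metadata keys introduced by the materializer above. The directory path is hypothetical.

    import json
    import os

    artifact_dir = "/tmp/example_artifact"  # hypothetical artifact directory

    # metadata.json records the model class and how the weights were written.
    with open(os.path.join(artifact_dir, "metadata.json")) as f:
        meta = json.load(f)
    print(meta["class_path"])            # e.g. "__main__.Net"
    print(meta["serialization_format"])  # "safetensors" or "pickle"

    if meta["serialization_format"] == "safetensors":
        from safetensors.torch import load_file

        state_dict = load_file(os.path.join(artifact_dir, "weights.safetensors"))
    else:
        import torch

        state_dict = torch.load(
            os.path.join(artifact_dir, "weights.pt"), map_location="cpu"
        )
    print(sorted(state_dict))  # parameter names, e.g. ["fc.bias", "fc.weight"]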