pyproject.toml (3 additions, 0 deletions)
@@ -54,6 +54,9 @@ Issues = "https://github.com/zenml-io/zenml/issues"
zenml = "zenml.cli.cli:cli"

[project.optional-dependencies]
safetensors = [
"safetensors>=0.4.0"
]
local = [
"alembic>=1.8.1,<=1.15.2",
"bcrypt==4.0.1",
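For context (editor's sketch, not part of the PR): the extra declared above is consumed at runtime through an import probe (see `_has_safetensors` in the materializer below). A minimal check of the same condition:

# Editor's sketch: probe for the optional dependency, mirroring the
# _has_safetensors() helper introduced in the materializer below.
# Install the extra with: pip install 'zenml[safetensors]'
import importlib.util

if importlib.util.find_spec("safetensors") is None:
    print("safetensors missing: saves will fall back to pickle (.pt)")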
zenml/integrations/pytorch/materializers/pytorch_module_materializer.py
@@ -11,12 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Implementation of the PyTorch Module materializer."""
"""Implementation of the PyTorch Module materializer (with SafeTensors)."""

from __future__ import annotations

import importlib
import json
import os
import tempfile
from typing import (
    TYPE_CHECKING,
    Any,
    ClassVar,
    Dict,
    Optional,
    Tuple,
    Type,
    cast,
)

import cloudpickle  # for legacy entire_model.pt compatibility
import torch
from torch.nn import Module

@@ -25,54 +25,227 @@
    BasePyTorchMaterializer,
)
from zenml.integrations.pytorch.utils import count_module_params
from zenml.io import fileio
from zenml.logger import get_logger
from zenml.utils.io_utils import copy_dir

if TYPE_CHECKING:
    from zenml.metadata.metadata_types import MetadataType

logger = get_logger(__name__)

# Legacy names kept for compatibility (we no longer write these in Phase 1).
DEFAULT_FILENAME = "entire_model.pt"
CHECKPOINT_FILENAME = "checkpoint.pt"

# New filenames for Phase 1
WEIGHTS_SAFE = "weights.safetensors"
WEIGHTS_PT = "weights.pt"
META_FILE = "metadata.json"

def _import_from_path(class_path: str) -> Type[Any]:
    """Import a class from 'pkg.module:Class' or 'pkg.module.Class'.

    Args:
        class_path: Fully qualified class path.

    Returns:
        The imported class.
    """
    if ":" in class_path:
        mod, cls = class_path.split(":")
    else:
        parts = class_path.split(".")
        mod, cls = ".".join(parts[:-1]), parts[-1]
    module = importlib.import_module(mod)
    return cast(Type[Any], getattr(module, cls))


def _has_safetensors() -> bool:
    """Return True if the optional safetensors dependency is importable."""
    try:
        import safetensors.torch  # noqa: F401

        return True
    except Exception:
        return False


class PyTorchModuleMaterializer(BasePyTorchMaterializer):
    """Materializer to read/write PyTorch models."""

    ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (Module,)
    ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.MODEL
    # NOTE: FILENAME is unused in Phase 1 (we no longer write entire_model.pt).
    FILENAME: ClassVar[str] = DEFAULT_FILENAME

    # ---------------------------
    # Save
    # ---------------------------
    def save(self, model: Module) -> None:
        """Save a PyTorch model as a state_dict plus JSON metadata.

        Why: prefer the non-pickle safetensors format for security and
        speed. If safetensors is missing, fall back to pickle (.pt) with
        a warning. Weights are written to a local temp dir and then
        copied to the artifact store, so PathLike-only I/O keeps working
        with remote stores.

        Args:
            model: The torch.nn.Module to serialize.
        """
        # Build a CPU-mapped state_dict to ensure portability across devices.
        state_dict = {
            k: v.detach().cpu() for k, v in model.state_dict().items()
        }

        with tempfile.TemporaryDirectory() as tmp:
            tmp_dir = os.path.abspath(tmp)

            # Write weights
            if _has_safetensors():
                from safetensors.torch import save_file

                weights_path = os.path.join(tmp_dir, WEIGHTS_SAFE)
                save_file(state_dict, weights_path)
                fmt = "safetensors"
            else:
                logger.warning(
                    "safetensors not installed; falling back to pickle (.pt). "
                    "Install with: pip install 'zenml[safetensors]'"
                )
                weights_path = os.path.join(tmp_dir, WEIGHTS_PT)
                # NOTE: torch.save uses pickle; this fallback is intended for
                # trusted environments only.
                torch.save(state_dict, weights_path, pickle_module=cloudpickle)  # nosec
                fmt = "pickle"

            # Write minimal metadata (extensible in future phases)
            meta: Dict[str, Any] = {
                "class_path": f"{model.__class__.__module__}.{model.__class__.__name__}",
                "serialization_format": fmt,
                # Reserved for Phase 2+
                "init_args": [],
                "init_kwargs": {},
                "factory_path": None,
            }
            with open(os.path.join(tmp_dir, META_FILE), "w") as f:
                json.dump(meta, f, indent=2)

            # Upload the directory to the artifact store (remote-safe)
            copy_dir(tmp_dir, self.uri)

        # IMPORTANT:
        # We intentionally do NOT call super().save(model) nor write
        # CHECKPOINT_FILENAME, to avoid duplicate / insecure pickle
        # artifacts in Phase 1.

    # ---------------------------
    # Load
    # ---------------------------
    def load(self, data_type: Optional[Type[Module]] = None) -> Module:
        """Load a PyTorch model and always return an nn.Module.

        Rules:
            - Phase 1: require a zero-arg constructor when metadata is present.
            - Legacy (.pt only, no metadata): require `data_type` to be provided.
            - Raise clear exceptions when reconstruction is not possible.

        Args:
            data_type: The model class; only required for legacy artifacts
                saved without metadata.json.

        Returns:
            The reconstructed model, in eval mode.
        """
        with tempfile.TemporaryDirectory() as tmp:
            tmp_dir = os.path.abspath(tmp)

            # Mirror the artifact dir locally for PathLike-only APIs
            copy_dir(self.uri, tmp_dir)

            meta_path = os.path.join(tmp_dir, META_FILE)
            has_meta = os.path.exists(meta_path)

            if has_meta:
                # New format (metadata-driven)
                with open(meta_path, "r") as f:
                    meta = json.load(f)

                class_path = meta.get("class_path")
                if not class_path:
                    raise ValueError("metadata.json is missing 'class_path'")

                try:
                    model_class = cast(
                        Type[Module], _import_from_path(class_path)
                    )
                except (ImportError, AttributeError) as e:
                    raise ImportError(
                        f"Cannot import model class '{class_path}': {e}. "
                        "Ensure the module/package is available on PYTHONPATH."
                    ) from e

                fmt = meta.get("serialization_format", "pickle")
            else:
                legacy_entire_model = os.path.join(tmp_dir, DEFAULT_FILENAME)
                if os.path.exists(legacy_entire_model):
                    # NOTE: weights_only=False unpickles arbitrary objects;
                    # only use with trusted artifact sources.
                    model = torch.load(
                        legacy_entire_model,
                        map_location="cpu",
                        weights_only=False,
                    )
                    if not isinstance(model, Module):
                        raise RuntimeError(
                            f"Legacy file {DEFAULT_FILENAME} did not contain "
                            "a torch.nn.Module."
                        )
                    model.eval()
                    return model

                logger.warning(
                    "No metadata.json found. Loading a legacy artifact: "
                    "using the `data_type` parameter for the model class."
                )
                if data_type is None:
                    raise FileNotFoundError(
                        "A legacy artifact without metadata.json requires "
                        "`data_type` (the model class) to reconstruct."
                    )
                model_class = data_type
                fmt = "pickle"  # legacy assumption

            # Load weights
            if fmt == "safetensors":
                safetensors_path = os.path.join(tmp_dir, WEIGHTS_SAFE)
                if not os.path.exists(safetensors_path):
                    raise FileNotFoundError(
                        f"Expected {safetensors_path} not found."
                    )
                try:
                    from safetensors.torch import load_file
                except ImportError as e:
                    raise ImportError(
                        "This artifact was saved with 'safetensors', but the "
                        "optional dependency is not installed. Install via:\n"
                        "  pip install 'zenml[safetensors]'"
                    ) from e
                state_dict = load_file(safetensors_path)
            else:
                # pickle/.pt state_dict checkpoints
                pt_candidates = [
                    os.path.join(tmp_dir, WEIGHTS_PT),
                    os.path.join(tmp_dir, CHECKPOINT_FILENAME),
                ]
                pt_path = next(
                    (p for p in pt_candidates if os.path.exists(p)), None
                )
                if not pt_path:
                    raise FileNotFoundError(
                        f"Expected one of {pt_candidates} not found."
                    )
                state_dict = torch.load(pt_path, map_location="cpu")

            # Reconstruct the model (Phase 1: zero-arg constructor)
            try:
                model = model_class()
            except TypeError as e:
                raise RuntimeError(
                    f"Failed to instantiate {model_class.__name__}: {e}. "
                    "Phase 1 supports only a zero-argument __init__(). "
                    "Use a factory or wait for the Phase 2 enhancement."
                ) from e

            try:
                model.load_state_dict(state_dict, strict=True)
            except Exception as e:
                raise RuntimeError(f"Failed to load state_dict: {e}") from e

            model.eval()
            return model

    # ---------------------------
    # Metadata extraction (unchanged)
    # ---------------------------
    def extract_metadata(self, model: Module) -> Dict[str, "MetadataType"]:
        """Extract metadata from the given `Model` object."""
        return {**count_module_params(model)}
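For orientation (editor's sketch, not part of the PR): the snippet below exercises the new save path and inspects the resulting artifact layout. It assumes the safetensors extra is installed; `Demo` is a hypothetical stand-in model, and note that a class defined in `__main__` records a `class_path` that would not re-import cleanly from another process.

# Editor's sketch: what save() produces on disk when safetensors is available.
import json
import os
import tempfile

from torch import nn

from zenml.integrations.pytorch.materializers.pytorch_module_materializer import (
    PyTorchModuleMaterializer,
)


class Demo(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(4, 3)


with tempfile.TemporaryDirectory() as uri:
    PyTorchModuleMaterializer(uri=uri).save(Demo())
    print(sorted(os.listdir(uri)))  # ['metadata.json', 'weights.safetensors']
    with open(os.path.join(uri, "metadata.json")) as f:
        print(json.load(f)["serialization_format"])  # 'safetensors'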
@@ -0,0 +1,103 @@
"""Tests for the PyTorch module materializer."""

from pathlib import Path

import pytest
import torch
from torch import nn

from zenml.integrations.pytorch.materializers.pytorch_module_materializer import (
    DEFAULT_FILENAME,
    META_FILE,
    WEIGHTS_PT,
    WEIGHTS_SAFE,
    PyTorchModuleMaterializer,
)


def _has_safetensors() -> bool:
    """Return True if the optional safetensors dependency is installed."""
    try:
        import safetensors.torch  # noqa

        return True
    except Exception:
        return False


class Tiny(nn.Module):
    """Single linear layer used to validate serialization round-trips."""

    def __init__(self) -> None:
        """Initialize the linear layer used for tests."""
        super().__init__()
        self.fc = nn.Linear(4, 3)


@pytest.mark.skipif(not _has_safetensors(), reason="safetensors is optional")
def test_round_trip_with_safetensors(tmp_path: Path) -> None:
    """End-to-end save/load using safetensors artifacts."""
    m = Tiny()
    mat = PyTorchModuleMaterializer(uri=str(tmp_path))
    mat.save(m)

    # new artifacts exist
    assert (tmp_path / META_FILE).exists()
    assert (tmp_path / WEIGHTS_SAFE).exists()
    # no legacy outputs
    assert not (tmp_path / "checkpoint.pt").exists()
    assert not (tmp_path / "entire_model.pt").exists()
    assert not (tmp_path / WEIGHTS_PT).exists()

    loaded = mat.load()
    assert isinstance(loaded, nn.Module)
    assert isinstance(loaded, Tiny)


def test_legacy_weights_pt_without_metadata(tmp_path: Path) -> None:
    """Legacy state_dict artifacts require passing data_type."""
    # simulate an old artifact: weights.pt only, no metadata
    state_dict = Tiny().state_dict()
    torch.save(state_dict, tmp_path / WEIGHTS_PT)
    mat = PyTorchModuleMaterializer(uri=str(tmp_path))

    # must fail without data_type
    with pytest.raises(FileNotFoundError):
        mat.load()

    # works with data_type
    loaded = mat.load(data_type=Tiny)
    assert isinstance(loaded, Tiny)


def test_legacy_entire_model_pt_direct_load(tmp_path: Path) -> None:
    """Very old pickle artifacts still deserialize to nn.Module."""
    # simulate a very old artifact: entire_model.pt (a pickled Module)
    torch.save(Tiny(), tmp_path / DEFAULT_FILENAME)
    mat = PyTorchModuleMaterializer(uri=str(tmp_path))
    loaded = mat.load()  # metadata-less path returns the Module directly
    assert isinstance(loaded, Tiny)


def test_missing_safetensors_dependency_raises_clear_error(
    tmp_path: Path, monkeypatch
) -> None:
    """Missing dependency should produce an actionable ImportError."""
    if not _has_safetensors():
        pytest.skip(
            "safetensors not installed; scenario covered by other tests"
        )

    # create a safetensors artifact
    m = Tiny()
    mat = PyTorchModuleMaterializer(uri=str(tmp_path))
    mat.save(m)

    # simulate a missing dependency at load time
    import sys

    monkeypatch.setitem(sys.modules, "safetensors", None)
    monkeypatch.setitem(sys.modules, "safetensors.torch", None)
    with pytest.raises(ImportError) as e:
        mat.load()
    assert "pip install 'zenml[safetensors]'" in str(e.value)
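Finally, a quick illustration (editor's sketch, not part of the PR) of the `_import_from_path` helper from the materializer above: both of its accepted spellings resolve to the same class.

# Editor's sketch: the colon and dotted class-path spellings are equivalent.
from torch import nn

from zenml.integrations.pytorch.materializers.pytorch_module_materializer import (
    _import_from_path,
)

assert _import_from_path("torch.nn:Linear") is nn.Linear
assert _import_from_path("torch.nn.Linear") is nn.Linear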