diff --git a/tests/integration_tests/test_zn_nodes2.py b/tests/integration_tests/test_zn_nodes2.py index cacf4984..8aa30884 100644 --- a/tests/integration_tests/test_zn_nodes2.py +++ b/tests/integration_tests/test_zn_nodes2.py @@ -1,8 +1,10 @@ +import dataclasses import os import shutil import subprocess import pytest +import znjson from zntrack import zn from zntrack.core.base import Node @@ -120,3 +122,47 @@ def test_NodeWithOuts(proj_path): node_1.write_graph(run=True) assert SingleExampleNode.load().params1.factor == 2 + + +@dataclasses.dataclass +class Parameter: + value: int = 0 + + +class NodeWithParameter(Node): + parameter = zn.params(Parameter()) + _hash = zn.Hash() + + +class MoreNode(Node): + node: NodeWithParameter = zn.Nodes() + + +class ParameterConverter(znjson.ConverterBase): + level = 100 + representation = "parameter" + instance = Parameter + + def _encode(self, obj: Parameter) -> dict: + return dataclasses.asdict(obj) + + def _decode(self, value: dict) -> Parameter: + return Parameter(**value) + + +def test_DataclassNode(proj_path): + znjson.register(ParameterConverter) + + node_w_params = NodeWithParameter(parameter=Parameter(value=42)) + node_w_params.write_graph() + + node = MoreNode(node=NodeWithParameter(parameter=Parameter(value=10))) + node.write_graph() + + node_w_params = node_w_params.load() + assert node_w_params.parameter.value == 42 + + node = node.load() + assert node.node.parameter.value == 10 + + znjson.deregister(ParameterConverter) diff --git a/zntrack/core/dvcgraph.py b/zntrack/core/dvcgraph.py index 03d46d17..062ec974 100644 --- a/zntrack/core/dvcgraph.py +++ b/zntrack/core/dvcgraph.py @@ -6,6 +6,8 @@ import pathlib import typing +import znjson + from zntrack import descriptor, utils from zntrack.core.jupyter import jupyter_class_to_file from zntrack.core.zntrackoption import ZnTrackOption @@ -273,7 +275,7 @@ def __hash__(self): params_dict = self.zntrack.collect(zn_params) params_dict["node_name"] = self.node_name - return hash(json.dumps(params_dict, sort_keys=True)) + return hash(json.dumps(params_dict, sort_keys=True, cls=znjson.ZnEncoder)) @property def _descriptor_list(self) -> typing.List[BaseDescriptorType]: