Skip to content

Commit

Permalink
fix hash with znjson serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ committed May 23, 2022
1 parent 7ad2c49 commit 2ddfc91
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 1 deletion.
46 changes: 46 additions & 0 deletions tests/integration_tests/test_zn_nodes2.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import dataclasses
import os
import shutil
import subprocess

import pytest
import znjson

from zntrack import zn
from zntrack.core.base import Node
Expand Down Expand Up @@ -120,3 +122,47 @@ def test_NodeWithOuts(proj_path):
node_1.write_graph(run=True)

assert SingleExampleNode.load().params1.factor == 2


@dataclasses.dataclass
class Parameter:
value: int = 0


class NodeWithParameter(Node):
parameter = zn.params(Parameter())
_hash = zn.Hash()


class MoreNode(Node):
node: NodeWithParameter = zn.Nodes()


class ParameterConverter(znjson.ConverterBase):
level = 100
representation = "parameter"
instance = Parameter

def _encode(self, obj: Parameter) -> dict:
return dataclasses.asdict(obj)

def _decode(self, value: dict) -> Parameter:
return Parameter(**value)


def test_DataclassNode(proj_path):
znjson.register(ParameterConverter)

node_w_params = NodeWithParameter(parameter=Parameter(value=42))
node_w_params.write_graph()

node = MoreNode(node=NodeWithParameter(parameter=Parameter(value=10)))
node.write_graph()

node_w_params = node_w_params.load()
assert node_w_params.parameter.value == 42

node = node.load()
assert node.node.parameter.value == 10

znjson.deregister(ParameterConverter)
4 changes: 3 additions & 1 deletion zntrack/core/dvcgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import pathlib
import typing

import znjson

from zntrack import descriptor, utils
from zntrack.core.jupyter import jupyter_class_to_file
from zntrack.core.zntrackoption import ZnTrackOption
Expand Down Expand Up @@ -273,7 +275,7 @@ def __hash__(self):
params_dict = self.zntrack.collect(zn_params)
params_dict["node_name"] = self.node_name

return hash(json.dumps(params_dict, sort_keys=True))
return hash(json.dumps(params_dict, sort_keys=True, cls=znjson.ZnEncoder))

@property
def _descriptor_list(self) -> typing.List[BaseDescriptorType]:
Expand Down

0 comments on commit 2ddfc91

Please sign in to comment.