Skip to content

Commit

Permalink
change serializer (#231)
Browse files Browse the repository at this point in the history
* change serializer

* move definition out of __init__

* fix typo

* keep `value` key in `zntrack.json`

* add check for backwards compatibility

* use fixture

* review comments
  • Loading branch information
PythonFZ authored Feb 16, 2022
1 parent 58e96f4 commit 802bb60
Show file tree
Hide file tree
Showing 8 changed files with 443 additions and 176 deletions.
73 changes: 73 additions & 0 deletions tests/integration_tests/test_zn_deps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import json
import os
import pathlib
import shutil
import subprocess

import pytest

from zntrack import zn
from zntrack.core.base import Node


@pytest.fixture
def proj_path(tmp_path):
shutil.copy(__file__, tmp_path)
os.chdir(tmp_path)
subprocess.check_call(["git", "init"])
subprocess.check_call(["dvc", "init"])

return tmp_path


class FirstNode(Node):
outs = zn.outs()

def run(self):
self.outs = 42


class LastNode(Node):
first_node: FirstNode = zn.deps(FirstNode.load())
outs = zn.outs()

def run(self):
self.outs = self.first_node.outs / 2


def test_base_run(proj_path):
FirstNode().write_graph(run=True)
LastNode().write_graph(run=True)

assert LastNode.load().outs == 21


@pytest.fixture()
def zntrack_dict() -> dict:
return {
"LastNode": {
"first_node": {
"_type": "ZnTrackType",
"value": {
"cls": "FirstNode",
"module": "test_zn_deps",
"name": "FirstNode",
},
}
}
}


def test_assert_write_file(proj_path, zntrack_dict):
FirstNode().write_graph()
LastNode().write_graph()

zntrack_dict_loaded = json.loads(pathlib.Path("zntrack.json").read_text())

assert zntrack_dict_loaded == zntrack_dict


def test_assert_read_file(proj_path, zntrack_dict):
pathlib.Path("zntrack.json").write_text(json.dumps(zntrack_dict))

assert isinstance(LastNode.load().first_node, FirstNode)
72 changes: 67 additions & 5 deletions tests/integration_tests/test_zn_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,13 @@ def test_created_files(proj_path):

assert zntrack_dict["SingleNode"]["data_class"] == {
"_type": "zn.method",
"value": {
"module": "test_zn_methods",
"cls": "ExampleMethod",
},
"value": {"module": "test_zn_methods"},
}
assert params_dict["SingleNode"]["data_class"] == {
"param1": 1,
"param2": 2,
"_cls": "ExampleMethod",
}
assert params_dict["SingleNode"]["data_class"] == {"param1": 1, "param2": 2}


class SingleNodeNoParams(Node):
Expand Down Expand Up @@ -162,3 +163,64 @@ def test_write_params_no_kwargs(proj_path):

dvc_dict = yaml.safe_load(pathlib.Path("dvc.yaml").read_text())
assert dvc_dict["stages"]["SingleNodeNoParams"]["params"] == ["SingleNodeNoParams"]


@pytest.fixture()
def zntrack_params_dict() -> (dict, dict):
zntrack_dict = {
"SingleNodeNoParams": {
"data_class": {"_type": "zn.method", "value": {"module": "test_zn_methods"}}
}
}
params_dict = {
"SingleNodeNoParams": {
"data_class": {"_cls": "ExampleMethod", "param1": 1, "param2": 2}
}
}
return zntrack_dict, params_dict


def test_assert_write_files(proj_path, zntrack_params_dict):
"""Test the written files (without mocking pathlibs write_text)"""
SingleNodeNoParams(data_class=ExampleMethod(1, 2)).write_graph()

zntrack_dict = json.loads(pathlib.Path("zntrack.json").read_text())
params_dict = yaml.safe_load(pathlib.Path("params.yaml").read_text())

assert zntrack_dict == zntrack_params_dict[0]
assert params_dict == zntrack_params_dict[1]


def test_assert_read_files(proj_path, zntrack_params_dict):
"""Test the written files (without mocking pathlibs write_text)"""
zntrack_dict, params_dict = zntrack_params_dict

pathlib.Path("zntrack.json").write_text(json.dumps(zntrack_dict))
pathlib.Path("params.yaml").write_text(yaml.safe_dump(params_dict))

node = SingleNodeNoParams.load()
assert node.data_class.param1 == 1
assert node.data_class.param2 == 2


def test_assert_read_files_old1(proj_path):
"""Test the written files (without mocking pathlibs write_text)
Test for versions before https://github.com/zincware/ZnTrack/pull/231
"""
zntrack_dict = {
"SingleNodeNoParams": {
"data_class": {
"_type": "zn.method",
"value": {"cls": "ExampleMethod", "module": "test_zn_methods"},
}
}
}
params_dict = {"SingleNodeNoParams": {"data_class": {"param1": 1, "param2": 2}}}

pathlib.Path("zntrack.json").write_text(json.dumps(zntrack_dict))
pathlib.Path("params.yaml").write_text(yaml.safe_dump(params_dict))

node = SingleNodeNoParams.load()
assert node.data_class.param1 == 1
assert node.data_class.param2 == 2
78 changes: 78 additions & 0 deletions tests/unit_tests/zn/test_zn_options.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
import dataclasses
import json
import pathlib

import znjson

from zntrack import zn
from zntrack.zn.split_option import combine_values, split_value


class ExampleClass:
Expand Down Expand Up @@ -30,3 +37,74 @@ def test_zn_plots():
# test save and load if there is nothing to save or load
assert ExamplePlots.plots.save(example) is None
assert ExamplePlots.plots.load(example) is None


@dataclasses.dataclass
class ExampleDataClass:
a: int = 5
b: int = 7

# make it a zn.Method
znjson_zn_method = True

def __eq__(self, other):
return (other.a == self.a) and (other.b == self.b)


def test_split_value():
serialized_value = json.loads(json.dumps(ExampleDataClass(), cls=znjson.ZnEncoder))

params_data, zntrack_data = split_value(serialized_value)
assert zntrack_data == {"_type": "zn.method", "value": {"module": "test_zn_options"}}
assert params_data == {"_cls": "ExampleDataClass", "a": 5, "b": 7}

# and now test the same thing but serialize a list
serialized_value = json.loads(json.dumps([ExampleDataClass()], cls=znjson.ZnEncoder))
params_data, zntrack_data = split_value(serialized_value)
assert zntrack_data == [
{"_type": "zn.method", "value": {"module": "test_zn_options"}}
]
assert params_data == ({"_cls": "ExampleDataClass", "a": 5, "b": 7},)


def test_combine_values():
zntrack_data = {"_type": "zn.method", "value": {"module": "test_zn_options"}}
params_data = {"_cls": "ExampleDataClass", "a": 5, "b": 7}

assert combine_values(zntrack_data, params_data) == ExampleDataClass()

# try older data structure
zntrack_data = {
"_type": "zn.method",
"value": {
"module": "test_zn_options",
"cls": "ExampleDataClass",
},
}
params_data = {"a": 5, "b": 7}
assert combine_values(zntrack_data, params_data) == ExampleDataClass()

# try older data structure
zntrack_data = {
"_type": "zn.method",
"value": {
"module": "test_zn_options",
"name": "ExampleDataClass",
},
}
params_data = {"a": 5, "b": 7}
assert combine_values(zntrack_data, params_data) == ExampleDataClass()


def test_split_value_path():
path = pathlib.Path("my_path")
serialized_value = json.loads(json.dumps(path, cls=znjson.ZnEncoder))

params_data, zntrack_data = split_value(serialized_value)

assert params_data == "my_path"
assert zntrack_data == {"_type": "pathlib.Path"}

new_path = combine_values(zntrack_data, params_data)
# TODO change order to be consistent with split_values
assert new_path == path
2 changes: 2 additions & 0 deletions zntrack/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
cwd_temp_dir,
decode_dict,
deprecated,
encode_dict,
get_python_interpreter,
module_handler,
module_to_path,
Expand All @@ -27,6 +28,7 @@
"config",
"cwd_temp_dir",
"decode_dict",
"encode_dict",
"module_handler",
"update_nb_name",
"module_to_path",
Expand Down
5 changes: 5 additions & 0 deletions zntrack/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ def decode_dict(value):
return json.loads(json.dumps(value), cls=znjson.ZnDecoder)


def encode_dict(value) -> dict:
"""Encode value into a dict serialized with ZnJson"""
return json.loads(json.dumps(value, cls=znjson.ZnEncoder))


def get_auto_init(fields: typing.List[str]):
"""Automatically create a __init__ based on fields
Parameters
Expand Down
Loading

0 comments on commit 802bb60

Please sign in to comment.