From 802bb60da5d150fd8948ef4875fd9e740e8162ab Mon Sep 17 00:00:00 2001 From: Fabian Zills <46721498+PythonFZ@users.noreply.github.com> Date: Wed, 16 Feb 2022 14:12:33 +0100 Subject: [PATCH] change serializer (#231) * change serializer * move definition out of __init__ * fix typo * keep `value` key in `zntrack.json` * add check for backwards compatibility * use fixture * review comments --- tests/integration_tests/test_zn_deps.py | 73 +++++++++ tests/integration_tests/test_zn_methods.py | 72 ++++++++- tests/unit_tests/zn/test_zn_options.py | 78 +++++++++ zntrack/utils/__init__.py | 2 + zntrack/utils/utils.py | 5 + zntrack/zn/__init__.py | 179 +-------------------- zntrack/zn/method.py | 57 +++++++ zntrack/zn/split_option.py | 153 ++++++++++++++++++ 8 files changed, 443 insertions(+), 176 deletions(-) create mode 100644 tests/integration_tests/test_zn_deps.py create mode 100644 zntrack/zn/method.py create mode 100644 zntrack/zn/split_option.py diff --git a/tests/integration_tests/test_zn_deps.py b/tests/integration_tests/test_zn_deps.py new file mode 100644 index 00000000..9e90e082 --- /dev/null +++ b/tests/integration_tests/test_zn_deps.py @@ -0,0 +1,73 @@ +import json +import os +import pathlib +import shutil +import subprocess + +import pytest + +from zntrack import zn +from zntrack.core.base import Node + + +@pytest.fixture +def proj_path(tmp_path): + shutil.copy(__file__, tmp_path) + os.chdir(tmp_path) + subprocess.check_call(["git", "init"]) + subprocess.check_call(["dvc", "init"]) + + return tmp_path + + +class FirstNode(Node): + outs = zn.outs() + + def run(self): + self.outs = 42 + + +class LastNode(Node): + first_node: FirstNode = zn.deps(FirstNode.load()) + outs = zn.outs() + + def run(self): + self.outs = self.first_node.outs / 2 + + +def test_base_run(proj_path): + FirstNode().write_graph(run=True) + LastNode().write_graph(run=True) + + assert LastNode.load().outs == 21 + + +@pytest.fixture() +def zntrack_dict() -> dict: + return { + "LastNode": { + "first_node": { + "_type": "ZnTrackType", + "value": { + "cls": "FirstNode", + "module": "test_zn_deps", + "name": "FirstNode", + }, + } + } + } + + +def test_assert_write_file(proj_path, zntrack_dict): + FirstNode().write_graph() + LastNode().write_graph() + + zntrack_dict_loaded = json.loads(pathlib.Path("zntrack.json").read_text()) + + assert zntrack_dict_loaded == zntrack_dict + + +def test_assert_read_file(proj_path, zntrack_dict): + pathlib.Path("zntrack.json").write_text(json.dumps(zntrack_dict)) + + assert isinstance(LastNode.load().first_node, FirstNode) diff --git a/tests/integration_tests/test_zn_methods.py b/tests/integration_tests/test_zn_methods.py index 81f06e74..9fd90126 100644 --- a/tests/integration_tests/test_zn_methods.py +++ b/tests/integration_tests/test_zn_methods.py @@ -125,12 +125,13 @@ def test_created_files(proj_path): assert zntrack_dict["SingleNode"]["data_class"] == { "_type": "zn.method", - "value": { - "module": "test_zn_methods", - "cls": "ExampleMethod", - }, + "value": {"module": "test_zn_methods"}, + } + assert params_dict["SingleNode"]["data_class"] == { + "param1": 1, + "param2": 2, + "_cls": "ExampleMethod", } - assert params_dict["SingleNode"]["data_class"] == {"param1": 1, "param2": 2} class SingleNodeNoParams(Node): @@ -162,3 +163,64 @@ def test_write_params_no_kwargs(proj_path): dvc_dict = yaml.safe_load(pathlib.Path("dvc.yaml").read_text()) assert dvc_dict["stages"]["SingleNodeNoParams"]["params"] == ["SingleNodeNoParams"] + + +@pytest.fixture() +def zntrack_params_dict() -> (dict, dict): + zntrack_dict = { + "SingleNodeNoParams": { + "data_class": {"_type": "zn.method", "value": {"module": "test_zn_methods"}} + } + } + params_dict = { + "SingleNodeNoParams": { + "data_class": {"_cls": "ExampleMethod", "param1": 1, "param2": 2} + } + } + return zntrack_dict, params_dict + + +def test_assert_write_files(proj_path, zntrack_params_dict): + """Test the written files (without mocking pathlibs write_text)""" + SingleNodeNoParams(data_class=ExampleMethod(1, 2)).write_graph() + + zntrack_dict = json.loads(pathlib.Path("zntrack.json").read_text()) + params_dict = yaml.safe_load(pathlib.Path("params.yaml").read_text()) + + assert zntrack_dict == zntrack_params_dict[0] + assert params_dict == zntrack_params_dict[1] + + +def test_assert_read_files(proj_path, zntrack_params_dict): + """Test the written files (without mocking pathlibs write_text)""" + zntrack_dict, params_dict = zntrack_params_dict + + pathlib.Path("zntrack.json").write_text(json.dumps(zntrack_dict)) + pathlib.Path("params.yaml").write_text(yaml.safe_dump(params_dict)) + + node = SingleNodeNoParams.load() + assert node.data_class.param1 == 1 + assert node.data_class.param2 == 2 + + +def test_assert_read_files_old1(proj_path): + """Test the written files (without mocking pathlibs write_text) + + Test for versions before https://github.com/zincware/ZnTrack/pull/231 + """ + zntrack_dict = { + "SingleNodeNoParams": { + "data_class": { + "_type": "zn.method", + "value": {"cls": "ExampleMethod", "module": "test_zn_methods"}, + } + } + } + params_dict = {"SingleNodeNoParams": {"data_class": {"param1": 1, "param2": 2}}} + + pathlib.Path("zntrack.json").write_text(json.dumps(zntrack_dict)) + pathlib.Path("params.yaml").write_text(yaml.safe_dump(params_dict)) + + node = SingleNodeNoParams.load() + assert node.data_class.param1 == 1 + assert node.data_class.param2 == 2 diff --git a/tests/unit_tests/zn/test_zn_options.py b/tests/unit_tests/zn/test_zn_options.py index bbc0561d..ce813fcf 100644 --- a/tests/unit_tests/zn/test_zn_options.py +++ b/tests/unit_tests/zn/test_zn_options.py @@ -1,4 +1,11 @@ +import dataclasses +import json +import pathlib + +import znjson + from zntrack import zn +from zntrack.zn.split_option import combine_values, split_value class ExampleClass: @@ -30,3 +37,74 @@ def test_zn_plots(): # test save and load if there is nothing to save or load assert ExamplePlots.plots.save(example) is None assert ExamplePlots.plots.load(example) is None + + +@dataclasses.dataclass +class ExampleDataClass: + a: int = 5 + b: int = 7 + + # make it a zn.Method + znjson_zn_method = True + + def __eq__(self, other): + return (other.a == self.a) and (other.b == self.b) + + +def test_split_value(): + serialized_value = json.loads(json.dumps(ExampleDataClass(), cls=znjson.ZnEncoder)) + + params_data, zntrack_data = split_value(serialized_value) + assert zntrack_data == {"_type": "zn.method", "value": {"module": "test_zn_options"}} + assert params_data == {"_cls": "ExampleDataClass", "a": 5, "b": 7} + + # and now test the same thing but serialize a list + serialized_value = json.loads(json.dumps([ExampleDataClass()], cls=znjson.ZnEncoder)) + params_data, zntrack_data = split_value(serialized_value) + assert zntrack_data == [ + {"_type": "zn.method", "value": {"module": "test_zn_options"}} + ] + assert params_data == ({"_cls": "ExampleDataClass", "a": 5, "b": 7},) + + +def test_combine_values(): + zntrack_data = {"_type": "zn.method", "value": {"module": "test_zn_options"}} + params_data = {"_cls": "ExampleDataClass", "a": 5, "b": 7} + + assert combine_values(zntrack_data, params_data) == ExampleDataClass() + + # try older data structure + zntrack_data = { + "_type": "zn.method", + "value": { + "module": "test_zn_options", + "cls": "ExampleDataClass", + }, + } + params_data = {"a": 5, "b": 7} + assert combine_values(zntrack_data, params_data) == ExampleDataClass() + + # try older data structure + zntrack_data = { + "_type": "zn.method", + "value": { + "module": "test_zn_options", + "name": "ExampleDataClass", + }, + } + params_data = {"a": 5, "b": 7} + assert combine_values(zntrack_data, params_data) == ExampleDataClass() + + +def test_split_value_path(): + path = pathlib.Path("my_path") + serialized_value = json.loads(json.dumps(path, cls=znjson.ZnEncoder)) + + params_data, zntrack_data = split_value(serialized_value) + + assert params_data == "my_path" + assert zntrack_data == {"_type": "pathlib.Path"} + + new_path = combine_values(zntrack_data, params_data) + # TODO change order to be consistent with split_values + assert new_path == path diff --git a/zntrack/utils/__init__.py b/zntrack/utils/__init__.py index c69f493d..383fafb4 100644 --- a/zntrack/utils/__init__.py +++ b/zntrack/utils/__init__.py @@ -16,6 +16,7 @@ cwd_temp_dir, decode_dict, deprecated, + encode_dict, get_python_interpreter, module_handler, module_to_path, @@ -27,6 +28,7 @@ "config", "cwd_temp_dir", "decode_dict", + "encode_dict", "module_handler", "update_nb_name", "module_to_path", diff --git a/zntrack/utils/utils.py b/zntrack/utils/utils.py index 407139ad..3d637377 100644 --- a/zntrack/utils/utils.py +++ b/zntrack/utils/utils.py @@ -78,6 +78,11 @@ def decode_dict(value): return json.loads(json.dumps(value), cls=znjson.ZnDecoder) +def encode_dict(value) -> dict: + """Encode value into a dict serialized with ZnJson""" + return json.loads(json.dumps(value, cls=znjson.ZnEncoder)) + + def get_auto_init(fields: typing.List[str]): """Automatically create a __init__ based on fields Parameters diff --git a/zntrack/zn/__init__.py b/zntrack/zn/__init__.py index b7038ecb..a14127bf 100644 --- a/zntrack/zn/__init__.py +++ b/zntrack/zn/__init__.py @@ -13,21 +13,22 @@ see https://dvc.org/doc/command-reference/run#options """ -import json import logging -import znjson - from zntrack import utils -from zntrack.core.parameter import File, ZnTrackOption +from zntrack.core.parameter import ZnTrackOption from zntrack.descriptor import Metadata +from zntrack.zn.method import Method +from zntrack.zn.split_option import SplitZnTrackOption log = logging.getLogger(__name__) +__all__ = [Method.__name__] + try: from .plots import plots - __all__ = [plots.__name__] + __all__ += [plots.__name__] except ImportError: pass @@ -38,121 +39,6 @@ # for direct file references use dvc.