Skip to content

Commit

Permalink
Undo call cargs change (#402)
Browse files Browse the repository at this point in the history
* reverse change in load vs getitem

* bump znjson to "^0.2.1"

* add unit test for write_graph
  • Loading branch information
PythonFZ authored Oct 6, 2022
1 parent 1803cb5 commit f3b97ff
Show file tree
Hide file tree
Showing 7 changed files with 406 additions and 441 deletions.
793 changes: 365 additions & 428 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ readme = "README.md"
python = ">=3.8,<4.0.0"
dvc = "^2.12.0"
pyyaml = "^6.0"
znjson = "^0.1.2"
znjson = "^0.2.1"
dot4dict = "^0.1.1"
tqdm = "^4.64.0"
pandas = "^1.4.3"
Expand Down
28 changes: 28 additions & 0 deletions tests/unit_tests/core/test_core_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,3 +304,31 @@ def test_matmul_not_supported():
RunTestNode() @ "outs"

assert mock.call_count == 3


def test_write_graph():
example = ExampleDVCOutsNode()

with patch.object(ExampleDVCOutsNode, "save") as save_mock, patch.object(
ExampleDVCOutsNode, "_handle_nodes_as_methods"
) as handle_znnodes_mock:
# Patch the methods that write to disk
script = example.write_graph(dry_run=True)

assert save_mock.called
assert handle_znnodes_mock.called

assert script == [
"dvc",
"stage",
"add",
"-n",
"ExampleDVCOutsNode",
"--force",
"--outs",
"example.dat",
(
'python3 -c "from test_core_base import ExampleDVCOutsNode; '
"ExampleDVCOutsNode.load(name='ExampleDVCOutsNode').run_and_save()\" "
),
]
2 changes: 1 addition & 1 deletion zntrack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from zntrack.zn.dependencies import getdeps

# register converters
znjson.register([ZnTrackTypeConverter, MethodConverter, NodeAttributeConverter])
znjson.config.register([ZnTrackTypeConverter, MethodConverter, NodeAttributeConverter])

__all__ = [
Node.__name__,
Expand Down
2 changes: 1 addition & 1 deletion zntrack/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ def write_graph(
custom_args += pair

if call_args is None:
call_args = f"['{self.node_name}'].run_and_save()"
call_args = f".load(name='{self.node_name}').run_and_save()"

script = prepare_dvc_script(
node_name=self.node_name,
Expand Down
12 changes: 6 additions & 6 deletions zntrack/utils/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class ZnTrackTypeConverter(znjson.ConverterBase):
representation = "ZnTrackType"
level = 10

def _encode(self, obj: Node) -> dict:
def encode(self, obj: Node) -> dict:
"""Convert Node to serializable dict"""
return dataclasses.asdict(
SerializedNode(
Expand All @@ -77,7 +77,7 @@ def _encode(self, obj: Node) -> dict:
)
)

def _decode(self, value: dict) -> Node:
def decode(self, value: dict) -> Node:
"""return serialized Node"""

serialized_node = SerializedNode(**value)
Expand All @@ -94,11 +94,11 @@ class NodeAttributeConverter(znjson.ConverterBase):
representation = "NodeAttribute"
level = 10

def _encode(self, obj: NodeAttribute) -> dict:
def encode(self, obj: NodeAttribute) -> dict:
"""Convert NodeAttribute to serializable dict"""
return dataclasses.asdict(obj)

def _decode(self, value: dict):
def decode(self, value: dict):
"""return serialized Node attribute"""

node_attribute = NodeAttribute(**value)
Expand All @@ -115,7 +115,7 @@ class MethodConverter(znjson.ConverterBase):
representation = "zn.method"
level = 10

def _encode(self, obj):
def encode(self, obj):
"""Serialize the object"""

serialized_method = SerializedMethod(
Expand Down Expand Up @@ -147,7 +147,7 @@ def _encode(self, obj):

return dataclasses.asdict(serialized_method)

def _decode(self, value: dict):
def decode(self, value: dict):
"""Deserialize the object"""

if "name" in value:
Expand Down
8 changes: 4 additions & 4 deletions zntrack/zn/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ class RawNodeAttributeConverter(znjson.ConverterBase):
representation = "NodeAttribute"
level = 999

def _encode(self, obj: NodeAttribute) -> dict:
def encode(self, obj: NodeAttribute) -> dict:
"""Convert NodeAttribute to serializable dict"""
return dataclasses.asdict(obj)

def _decode(self, value: dict) -> NodeAttribute:
def decode(self, value: dict) -> NodeAttribute:
"""return serialized Node attribute"""
return NodeAttribute(**value)

Expand Down Expand Up @@ -77,11 +77,11 @@ def get_origin(
------
AttributeError: if the attribute is not of type zn.deps
"""
znjson.register(RawNodeAttributeConverter)
znjson.config.register(RawNodeAttributeConverter)
new_node = node.load(name=node.node_name)
value = getattr(new_node, attribute)

znjson.deregister(RawNodeAttributeConverter)
znjson.config.deregister(RawNodeAttributeConverter)

def not_zn_deps_err() -> AttributeError:
"""Evaluate error message when raising the error"""
Expand Down

0 comments on commit f3b97ff

Please sign in to comment.