From 7d388cddc9d5eef6a39a8a426357621ed75997a9 Mon Sep 17 00:00:00 2001 From: Fabian Zills <46721498+PythonFZ@users.noreply.github.com> Date: Mon, 9 May 2022 10:24:44 +0200 Subject: [PATCH] add `get_origin` (#291) --- tests/integration_tests/test_getdeps.py | 27 +++++++++++++++++ zntrack/zn/dependencies.py | 40 +++++++++++++++++++++++++ 2 files changed, 67 insertions(+) diff --git a/tests/integration_tests/test_getdeps.py b/tests/integration_tests/test_getdeps.py index 95e4b1c3..f76a93ae 100644 --- a/tests/integration_tests/test_getdeps.py +++ b/tests/integration_tests/test_getdeps.py @@ -10,6 +10,7 @@ from zntrack import getdeps, utils, zn from zntrack.core import ZnTrackOption from zntrack.core.base import Node +from zntrack.zn.dependencies import NodeAttribute, get_origin @pytest.fixture() @@ -161,3 +162,29 @@ def test_stacked_name_getdeps_2(proj_path, steps): for step in range(1, steps): assert ModifyNumber[f"rld_{step}"].outputs == 1 + + node_attr = get_origin(ModifyNumber[f"rld_{step}"], "inputs") + assert isinstance(node_attr, NodeAttribute) + assert node_attr.name == f"rld_{step - 1}" + + +def test_get_origin(proj_path): + sd = SeedNumber(inputs=20) + sd.write_graph() + ModifyNumber(inputs=getdeps(sd, "number")).write_graph() + + node_attr = get_origin(ModifyNumber.load(), "inputs") + assert isinstance(node_attr, NodeAttribute) + assert node_attr.name == "SeedNumber" + + +def test_err_get_origin(proj_path): + sd = SeedNumber(inputs=20) + sd.write_graph() + ModifyNumber(inputs=getdeps(sd, "number")).write_graph() + + with pytest.raises(AttributeError): + get_origin(ModifyNumber.load(), "outputs") + + with pytest.raises(AttributeError): + get_origin(SeedNumber, "inputs") diff --git a/zntrack/zn/dependencies.py b/zntrack/zn/dependencies.py index 81e35018..b5600ff9 100644 --- a/zntrack/zn/dependencies.py +++ b/zntrack/zn/dependencies.py @@ -4,6 +4,8 @@ import pathlib from typing import TYPE_CHECKING, List, Union +import znjson + from zntrack.utils import utils if TYPE_CHECKING: @@ -19,6 +21,25 @@ class NodeAttribute: affected_files: List[pathlib.Path] +class RawNodeAttributeConverter(znjson.ConverterBase): + """Serializer for Node Attributes + + Instead of returning the actual attribute this returns the NodeAttribute cls. + """ + + instance = NodeAttribute + representation = "NodeAttribute" + level = 999 + + def _encode(self, obj: NodeAttribute) -> dict: + """Convert NodeAttribute to serializable dict""" + return dataclasses.asdict(obj) + + def _decode(self, value: dict) -> NodeAttribute: + """return serialized Node attribute""" + return NodeAttribute(**value) + + def getdeps(node: Union[Node, type(Node)], attribute: str) -> NodeAttribute: """Allow for Node attributes as dependencies @@ -41,3 +62,22 @@ def getdeps(node: Union[Node, type(Node)], attribute: str) -> NodeAttribute: attribute=attribute, affected_files=list(node.affected_files), ) + + +def get_origin(node: Union[Node, type(Node)], attribute: str) -> NodeAttribute: + """Get the NodeAttribute from a zn.deps + + Typically, when using zn.deps there is no way to access the original Node where + the data comes from. This function allows you to get the underlying + NodeAttribute object to access e.g. the name of the original Node. + """ + znjson.register(RawNodeAttributeConverter) + new_node = node.load(name=node.node_name) + try: + value = getattr(new_node, attribute) + except AttributeError as err: + raise AttributeError("Can only use get_origin with zn.deps") from err + znjson.deregister(RawNodeAttributeConverter) + if not isinstance(value, NodeAttribute): + raise AttributeError("Can only use get_origin with zn.deps using getdeps.") + return value