Skip to content

Commit

Permalink
add get_origin (#291)
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ authored May 9, 2022
1 parent 2276dbf commit 7d388cd
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 0 deletions.
27 changes: 27 additions & 0 deletions tests/integration_tests/test_getdeps.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from zntrack import getdeps, utils, zn
from zntrack.core import ZnTrackOption
from zntrack.core.base import Node
from zntrack.zn.dependencies import NodeAttribute, get_origin


@pytest.fixture()
Expand Down Expand Up @@ -161,3 +162,29 @@ def test_stacked_name_getdeps_2(proj_path, steps):

for step in range(1, steps):
assert ModifyNumber[f"rld_{step}"].outputs == 1

node_attr = get_origin(ModifyNumber[f"rld_{step}"], "inputs")
assert isinstance(node_attr, NodeAttribute)
assert node_attr.name == f"rld_{step - 1}"


def test_get_origin(proj_path):
sd = SeedNumber(inputs=20)
sd.write_graph()
ModifyNumber(inputs=getdeps(sd, "number")).write_graph()

node_attr = get_origin(ModifyNumber.load(), "inputs")
assert isinstance(node_attr, NodeAttribute)
assert node_attr.name == "SeedNumber"


def test_err_get_origin(proj_path):
sd = SeedNumber(inputs=20)
sd.write_graph()
ModifyNumber(inputs=getdeps(sd, "number")).write_graph()

with pytest.raises(AttributeError):
get_origin(ModifyNumber.load(), "outputs")

with pytest.raises(AttributeError):
get_origin(SeedNumber, "inputs")
40 changes: 40 additions & 0 deletions zntrack/zn/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import pathlib
from typing import TYPE_CHECKING, List, Union

import znjson

from zntrack.utils import utils

if TYPE_CHECKING:
Expand All @@ -19,6 +21,25 @@ class NodeAttribute:
affected_files: List[pathlib.Path]


class RawNodeAttributeConverter(znjson.ConverterBase):
"""Serializer for Node Attributes
Instead of returning the actual attribute this returns the NodeAttribute cls.
"""

instance = NodeAttribute
representation = "NodeAttribute"
level = 999

def _encode(self, obj: NodeAttribute) -> dict:
"""Convert NodeAttribute to serializable dict"""
return dataclasses.asdict(obj)

def _decode(self, value: dict) -> NodeAttribute:
"""return serialized Node attribute"""
return NodeAttribute(**value)


def getdeps(node: Union[Node, type(Node)], attribute: str) -> NodeAttribute:
"""Allow for Node attributes as dependencies
Expand All @@ -41,3 +62,22 @@ def getdeps(node: Union[Node, type(Node)], attribute: str) -> NodeAttribute:
attribute=attribute,
affected_files=list(node.affected_files),
)


def get_origin(node: Union[Node, type(Node)], attribute: str) -> NodeAttribute:
"""Get the NodeAttribute from a zn.deps
Typically, when using zn.deps there is no way to access the original Node where
the data comes from. This function allows you to get the underlying
NodeAttribute object to access e.g. the name of the original Node.
"""
znjson.register(RawNodeAttributeConverter)
new_node = node.load(name=node.node_name)
try:
value = getattr(new_node, attribute)
except AttributeError as err:
raise AttributeError("Can only use get_origin with zn.deps") from err
znjson.deregister(RawNodeAttributeConverter)
if not isinstance(value, NodeAttribute):
raise AttributeError("Can only use get_origin with zn.deps using getdeps.")
return value

0 comments on commit 7d388cd

Please sign in to comment.