diff --git a/tests/integration/test_post_load.py b/tests/integration/test_post_load.py new file mode 100644 index 00000000..421d2189 --- /dev/null +++ b/tests/integration/test_post_load.py @@ -0,0 +1,64 @@ +import zntrack + + +class NodeWithPostLoad(zntrack.Node): + params: float = zntrack.params() + outs: float = zntrack.outs() + + def _post_load_(self): + self.params += 1 + + def run(self): + self.outs = self.params + + +class DepsOnPostLoadNode(zntrack.Node): + deps: float = zntrack.deps() + outs: float = zntrack.outs() + + def run(self): + self.outs = self.deps + + +def test_post_load(proj_path): + project = zntrack.Project() + with project: + node = NodeWithPostLoad(params=1) + + assert node.params == 1 + project.repro() + + n = node.from_rev() + assert n.params == 2 # modified by _post_load_ + assert n.outs == 1 + + +def test_post_load_deps(proj_path): + project = zntrack.Project() + + with project: + node = NodeWithPostLoad(params=1) + dep_node = DepsOnPostLoadNode(deps=node.outs) + dep_node_2 = DepsOnPostLoadNode(deps=node.params) + + project.repro() + + # post load has not been called + assert dep_node.outs == 1 + assert node.outs == 1 + assert node.params == 1 + # assert dep_node_2.deps == 2 # has not been resolved from a connection + assert dep_node_2.outs == 2 + + # post load has been called + n = dep_node.from_rev() + assert n.outs == 1 + assert n.deps == 1 + + # _post_load_ will also be called if the node is resolved from a connection + n2 = dep_node_2.from_rev(name=dep_node_2.name) + assert n2.outs == 2 + assert n2.deps == 2 + + assert node.from_rev().outs == 1 + assert node.from_rev().params == 2 diff --git a/zntrack/node.py b/zntrack/node.py index 5e5f84e3..518deec7 100644 --- a/zntrack/node.py +++ b/zntrack/node.py @@ -54,10 +54,8 @@ def __post_init__(self): if not znflow.get_graph() is not znflow.empty_graph: self.name = self.__class__.__name__ - @te.deprecated( - "The _post_load_ method was removed. Use __post_init__ in combination with `self.state` instead." - ) def _post_load_(self): + """Called after `from_rev` is called.""" raise NotImplementedError def run(self): @@ -146,6 +144,9 @@ def from_rev( _ = getattr(instance, field.name) instance._external_ = True + if not running and hasattr(instance, "_post_load_"): + with contextlib.suppress(NotImplementedError): + instance._post_load_() return instance