Skip to content

Commit

Permalink
Merge branch 'main' into pre-commit-ci-update-config
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ authored Dec 10, 2024
2 parents 7be5838 + 8b5029a commit 477abdf
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 3 deletions.
64 changes: 64 additions & 0 deletions tests/integration/test_post_load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import zntrack


class NodeWithPostLoad(zntrack.Node):
params: float = zntrack.params()
outs: float = zntrack.outs()

def _post_load_(self):
self.params += 1

def run(self):
self.outs = self.params


class DepsOnPostLoadNode(zntrack.Node):
deps: float = zntrack.deps()
outs: float = zntrack.outs()

def run(self):
self.outs = self.deps


def test_post_load(proj_path):
project = zntrack.Project()
with project:
node = NodeWithPostLoad(params=1)

assert node.params == 1
project.repro()

n = node.from_rev()
assert n.params == 2 # modified by _post_load_
assert n.outs == 1


def test_post_load_deps(proj_path):
project = zntrack.Project()

with project:
node = NodeWithPostLoad(params=1)
dep_node = DepsOnPostLoadNode(deps=node.outs)
dep_node_2 = DepsOnPostLoadNode(deps=node.params)

project.repro()

# post load has not been called
assert dep_node.outs == 1
assert node.outs == 1
assert node.params == 1
# assert dep_node_2.deps == 2 # has not been resolved from a connection
assert dep_node_2.outs == 2

# post load has been called
n = dep_node.from_rev()
assert n.outs == 1
assert n.deps == 1

# _post_load_ will also be called if the node is resolved from a connection
n2 = dep_node_2.from_rev(name=dep_node_2.name)
assert n2.outs == 2
assert n2.deps == 2

assert node.from_rev().outs == 1
assert node.from_rev().params == 2
7 changes: 4 additions & 3 deletions zntrack/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,8 @@ def __post_init__(self):
if not znflow.get_graph() is not znflow.empty_graph:
self.name = self.__class__.__name__

@te.deprecated(
"The _post_load_ method was removed. Use __post_init__ in combination with `self.state` instead."
)
def _post_load_(self):
"""Called after `from_rev` is called."""
raise NotImplementedError

def run(self):
Expand Down Expand Up @@ -146,6 +144,9 @@ def from_rev(
_ = getattr(instance, field.name)

instance._external_ = True
if not running and hasattr(instance, "_post_load_"):
with contextlib.suppress(NotImplementedError):
instance._post_load_()

return instance

Expand Down

0 comments on commit 477abdf

Please sign in to comment.