Skip to content

Commit

Permalink
zn.nodes updates (#529)
Browse files Browse the repository at this point in the history
* add outs to parammeter Node

* use --force

* support lists in zn.nodes
  • Loading branch information
PythonFZ authored Mar 17, 2023
1 parent a61a560 commit a7bea46
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 29 deletions.
44 changes: 44 additions & 0 deletions tests/integration/test_zn_nodes2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

class NodeViaParams(Node):
param1 = zn.params()
outs = zn.outs()

def run(self):
raise NotImplementedError
Expand All @@ -20,6 +21,14 @@ def run(self):
self.outs = self.params1.param1 + self.params2.param1


class ExampleNodeLst(Node):
params: list = zn.nodes()
outs = zn.outs()

def run(self):
self.outs = sum([p.param1 for p in self.params])


@pytest.mark.parametrize("eager", [True, False])
def test_ExampleNode(proj_path, eager):
project = Project()
Expand All @@ -36,9 +45,44 @@ def test_ExampleNode(proj_path, eager):
assert node.params2.param1 == 10
assert node.outs == 11

assert node.params1.name == "ExampleNode_params1"
assert node.params2.name == "ExampleNode_params2"

if not eager:
# Check new instance also works
node = node.from_rev()
assert node.params1.param1 == 1
assert node.params2.param1 == 10
assert node.outs == 11

assert node.params1.name == "ExampleNode_params1"
assert node.params2.name == "ExampleNode_params2"


@pytest.mark.parametrize("eager", [True, False])
def test_ExampleNodeLst(proj_path, eager):
project = Project()
parameter_1 = NodeViaParams(param1=1)
parameter_2 = NodeViaParams(param1=10)

with project:
node = ExampleNodeLst(params=[parameter_1, parameter_2])

project.run(eager=eager)
if not eager:
node.load()
assert node.params[0].param1 == 1
assert node.params[1].param1 == 10
assert node.outs == 11

assert node.params[0].name == "ExampleNodeLst_params_0"
assert node.params[1].name == "ExampleNodeLst_params_1"

if not eager:
# Check new instance also works
node = node.from_rev()
assert node.params[0].param1 == 1
assert node.params[1].param1 == 10
assert node.outs == 11
assert node.params[0].name == "ExampleNodeLst_params_0"
assert node.params[1].name == "ExampleNodeLst_params_1"
2 changes: 1 addition & 1 deletion zntrack/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def get_dvc_cmd(
field_cmds = []
for attr in zninit.get_descriptors(Field, self=node):
field_cmds += attr.get_stage_add_argument(node)
optionals.append(attr.get_optional_dvc_cmd(node))
optionals += attr.get_optional_dvc_cmd(node)
for field_cmd in set(field_cmds):
cmd += list(field_cmd)

Expand Down
2 changes: 1 addition & 1 deletion zntrack/fields/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def get_stage_add_argument(self, instance: "Node") -> typing.List[tuple]:
for x in self.get_files(instance)
]

def get_optional_dvc_cmd(self, instance: "Node") -> typing.List[str]:
def get_optional_dvc_cmd(self, instance: "Node") -> typing.List[typing.List[str]]:
"""Get optional dvc commands that will be executed beside the main dvc command.
This could be 'plots modify ...' or 'stage add --name node_helper'
Expand Down
77 changes: 50 additions & 27 deletions zntrack/fields/zn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def default(self, value, **kwargs):
return value


class NodeFiled(Dependency):
class NodeField(Dependency):
"""Add another Node as a field.
The other Node will provide its parameters and methods to be used.
Expand All @@ -355,45 +355,68 @@ class NodeFiled(Dependency):

def __set__(self, instance, value):
"""Disbale the _graph_ in the value 'Node'."""
if hasattr(value, "_graph_"):
value._graph_ = None
else:
raise TypeError(f"The value must be a Node and not {value}.")
for entry in value if isinstance(value, (list, tuple)) else [value]:
if hasattr(entry, "_graph_"):
entry._graph_ = None
else:
raise TypeError(f"The value must be a Node and not {entry}.")
return super().__set__(instance, value)

def get_node_name(self, instance) -> str:
def get_node_names(self, instance) -> list:
"""Get the name of the other Node."""
return f"{instance.name}_{self.name}"
value = instance.__dict__[self.name]
if isinstance(value, (list, tuple)):
return [f"{instance.name}_{self.name}_{idx}" for idx in range(len(value))]
return [f"{instance.name}_{self.name}"]

def save(self, instance: "Node"):
"""Save the Node parameters to disk."""
value = instance.__dict__[self.name]
if value is LazyOption:
return
_SaveNodes()(value, name=self.get_node_name(instance))
if not isinstance(value, (list, tuple)):
value = [value]

for node, name in zip(value, self.get_node_names(instance)):
_SaveNodes()(node, name=name)
super().save(instance)

def get_optional_dvc_cmd(self, instance: "Node") -> typing.List[tuple]:
def get_optional_dvc_cmd(self, instance: "Node") -> typing.List[list]:
"""Get the dvc command for this field."""
name = self.get_node_name(instance)
node = instance.__dict__[self.name]
if not isinstance(node, znflow.Node):
raise TypeError(f"The value must be a Node and not {node}.")
module = module_handler(node.__class__)
return [
"stage",
"add",
"--name",
name,
"--outs",
f"nodes/{name}/hash",
f"zntrack run {module}.{node.__class__.__name__} --name {name} --hash-only",
]
names = self.get_node_names(instance)
nodes = instance.__dict__[self.name]
if not isinstance(nodes, (list, tuple)):
nodes = [nodes]

cmd = []

for node, name in zip(nodes, names):
if not isinstance(node, znflow.Node):
raise TypeError(f"The value must be a Node and not {node}.")
module = module_handler(node.__class__)
cmd.append(
[
"stage",
"add",
"--name",
name,
"--force",
"--outs",
f"nodes/{name}/hash",
(
f"zntrack run {module}.{node.__class__.__name__} --name"
f" {name} --hash-only"
),
]
)

return cmd

def get_files(self, instance: "Node") -> list:
"""Get the files affected by this field."""
name = self.get_node_name(instance)
return [pathlib.Path(f"nodes/{name}/hash")]
return [
pathlib.Path(f"nodes/{name}/hash") for name in self.get_node_names(instance)
]


def params(*args, **kwargs) -> Params:
Expand Down Expand Up @@ -421,6 +444,6 @@ def plots(*args, **kwargs) -> Plots:
return Plots(*args, **kwargs)


def nodes(*args, **kwargs) -> NodeFiled:
def nodes(*args, **kwargs) -> NodeField:
"""Create a node field."""
return NodeFiled(*args, **kwargs)
return NodeField(*args, **kwargs)

0 comments on commit a7bea46

Please sign in to comment.