Skip to content

Commit

Permalink
fix params for zn.nodes (#531)
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ authored Mar 17, 2023
1 parent a7bea46 commit aca814f
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 20 deletions.
28 changes: 24 additions & 4 deletions tests/integration/test_zn_nodes2.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,31 @@ def test_ExampleNodeLst(proj_path, eager):
assert node.params[0].name == "ExampleNodeLst_params_0"
assert node.params[1].name == "ExampleNodeLst_params_1"

if not eager:
# Check new instance also works
nodex = node.from_rev()
assert nodex.params[0].param1 == 1
assert nodex.params[1].param1 == 10
assert nodex.outs == 11
assert nodex.params[0].name == "ExampleNodeLst_params_0"
assert nodex.params[1].name == "ExampleNodeLst_params_1"

parameter_1.param1 = 2 # Change parameter
assert isinstance(parameter_1, NodeViaParams)
with project:
# # # node = ExampleNodeLst(params=[parameter_1, parameter_2])
node.params = [parameter_1, parameter_2]
project.run(eager=eager)

if not eager:
node.load()
assert node.params[0].param1 == 2
assert node.params[1].param1 == 10
assert node.outs == 12

if not eager:
# Check new instance also works
node = node.from_rev()
assert node.params[0].param1 == 1
assert node.params[0].param1 == 2
assert node.params[1].param1 == 10
assert node.outs == 11
assert node.params[0].name == "ExampleNodeLst_params_0"
assert node.params[1].name == "ExampleNodeLst_params_1"
assert node.outs == 12
44 changes: 28 additions & 16 deletions zntrack/fields/zn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,10 @@ def get_node_names(self, instance) -> list:

def save(self, instance: "Node"):
"""Save the Node parameters to disk."""
value = instance.__dict__[self.name]
try:
value = getattr(instance, self.name)
except AttributeError:
return
if value is LazyOption:
return
if not isinstance(value, (list, tuple)):
Expand All @@ -393,22 +396,31 @@ def get_optional_dvc_cmd(self, instance: "Node") -> typing.List[list]:
for node, name in zip(nodes, names):
if not isinstance(node, znflow.Node):
raise TypeError(f"The value must be a Node and not {node}.")
node.name = name
module = module_handler(node.__class__)
cmd.append(
[
"stage",
"add",
"--name",
name,
"--force",
"--outs",
f"nodes/{name}/hash",
(
f"zntrack run {module}.{node.__class__.__name__} --name"
f" {name} --hash-only"
),
]
)

_cmd = [
"stage",
"add",
"--name",
name,
"--force",
"--outs",
f"nodes/{name}/hash",
]
field_cmds = []
for attr in zninit.get_descriptors(Params, self=node):
field_cmds += attr.get_stage_add_argument(node)

for field_cmd in set(field_cmds):
_cmd += list(field_cmd)

_cmd += [
f"zntrack run {module}.{node.__class__.__name__} --name"
f" {name} --hash-only"
]

cmd.append(_cmd)

return cmd

Expand Down

0 comments on commit aca814f

Please sign in to comment.