Skip to content

Commit

Permalink
Update znflow (#535)
Browse files Browse the repository at this point in the history
* update znflow

* bump version

* fix znflow imports

* fix zn.nodes(None)
  • Loading branch information
PythonFZ authored Mar 20, 2023
1 parent 5a269cc commit 16f895e
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 122 deletions.
224 changes: 112 additions & 112 deletions poetry.lock

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "ZnTrack"
version = "0.6.0a1"
version = "0.6.0a2"
description = "Create, Run and Benchmark DVC Pipelines in Python"
authors = ["zincwarecode <zincwarecode@gmail.com>"]
license = "Apache-2.0"
Expand All @@ -10,15 +10,15 @@ readme = "README.md"

[tool.poetry.dependencies]
python = ">=3.8,<4.0.0"
dvc = "^2.47.0"
dvc = "^2.50.0"
pyyaml = "^6.0"
tqdm = "^4.64.0"
pandas = "^1.4.3"
typer = "^0.7.0"

dot4dict = "^0.1.1"
zninit = "^0.1.9"
znflow = "^0.1.3"
znflow = "^0.1.4"
znjson = "^0.2.2"


Expand Down
12 changes: 12 additions & 0 deletions tests/integration/test_none_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,15 @@ def test_from_dvc_deps(proj_path, eager):
node.load()

assert node.result == "Hello World"


class EmptyNodesNode(zntrack.Node):
nodes = zntrack.zn.nodes(None)

def run(self):
pass


def test_EmptyNode(proj_path):
with zntrack.Project() as project:
node = EmptyNodesNode()
2 changes: 1 addition & 1 deletion tests/integration/test_zn_nodes2.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class ExampleNodeLst(Node):
outs = zn.outs()

def run(self):
self.outs = sum([p.param1 for p in self.params])
self.outs = sum(p.param1 for p in self.params)


@pytest.mark.parametrize("eager", [True, False])
Expand Down
2 changes: 1 addition & 1 deletion tests/test_zntrack.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@

def test_version():
"""Test 'ZnTrack' version."""
assert __version__ == "0.6.0a1"
assert __version__ == "0.6.0a2"
13 changes: 10 additions & 3 deletions zntrack/fields/zn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import znflow.utils
import zninit
import znjson
from znflow import handler

from zntrack.fields.field import Field, FieldGroup, LazyField
from zntrack.utils import LazyOption, module_handler, update_key_val
Expand Down Expand Up @@ -325,7 +326,7 @@ def get_data(self, instance: "Node") -> any:
# Up until here we have connection objects. Now we need
# to resolve them to Nodes. The Nodes, as in 'connection.instance'
# are already loaded by the ZnDecoder.
return znflow.graph._UpdateConnectors()(value)
return handler.UpdateConnectors()(value)

def get_stage_add_argument(self, instance) -> typing.List[tuple]:
"""Get the dvc command for this field."""
Expand Down Expand Up @@ -355,6 +356,9 @@ class NodeField(Dependency):

def __set__(self, instance, value):
"""Disbale the _graph_ in the value 'Node'."""
if value is None:
return super().__set__(instance, value)

for entry in value if isinstance(value, (list, tuple)) else [value]:
if hasattr(entry, "_graph_"):
entry._graph_ = None
Expand All @@ -375,7 +379,7 @@ def save(self, instance: "Node"):
value = getattr(instance, self.name)
except AttributeError:
return
if value is LazyOption:
if value in [LazyOption, None]:
return
if not isinstance(value, (list, tuple)):
value = [value]
Expand All @@ -386,8 +390,11 @@ def save(self, instance: "Node"):

def get_optional_dvc_cmd(self, instance: "Node") -> typing.List[list]:
"""Get the dvc command for this field."""
names = self.get_node_names(instance)
nodes = instance.__dict__[self.name]
if nodes is None:
return []

names = self.get_node_names(instance)
if not isinstance(nodes, (list, tuple)):
nodes = [nodes]

Expand Down
4 changes: 2 additions & 2 deletions zntrack/project/zntrack_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import git
import yaml
import znflow
from znflow.graph import _UpdateConnectors
from znflow.handler import UpdateConnectors

from zntrack.core.node import Node, get_dvc_cmd
from zntrack.utils import capture_run_dvc_cmd, run_dvc_cmd
Expand Down Expand Up @@ -48,7 +48,7 @@ def run(
if eager:
# update connectors
log.info(f"Running node {node}")
self._update_node_attributes(node, _UpdateConnectors())
self._update_node_attributes(node, UpdateConnectors())
node.run()
if save:
node.save()
Expand Down

0 comments on commit 16f895e

Please sign in to comment.