diff --git a/examples/docs/09_lazy.ipynb b/examples/docs/09_lazy.ipynb index 18a7ce85..58400126 100644 --- a/examples/docs/09_lazy.ipynb +++ b/examples/docs/09_lazy.ipynb @@ -58,13 +58,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Initialized empty Git repository in /tmp/tmp9xwenc20/.git/\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ + "Initialized empty Git repository in /tmp/tmp45orcu_7/.git/\n", "Initialized DVC repository.\n", "\n", "You can now commit the changes to git.\n", @@ -132,7 +126,10 @@ }, "outputs": [], "source": [ - "class PrintOption(zntrack.zn.Output):\n", + "from zntrack.fields.zn.options import Output\n", + "\n", + "\n", + "class PrintOption(Output):\n", " def __init__(self):\n", " super().__init__(dvc_option=\"outs\", use_repr=False)\n", " # zntrack will try dvc --PrintOption outs.json\n", @@ -155,7 +152,18 @@ "name": "#%%\n" } }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_3870159/2524714966.py:2: DeprecationWarning: Use 'zntrack.params' instead.\n", + " start = zntrack.zn.params()\n", + "/tmp/ipykernel_3870159/2524714966.py:3: DeprecationWarning: Use 'zntrack.params' instead.\n", + " stop = zntrack.zn.params()\n" + ] + } + ], "source": [ "class RandomNumber(zntrack.Node):\n", " start = zntrack.zn.params()\n", @@ -211,20 +219,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Running DVC command: 'repro'\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u0000" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ + "Running DVC command: 'repro'\n", "\u0000" ] } @@ -247,7 +242,15 @@ "name": "#%%\n" } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u0000" + ] + } + ], "source": [ "random_number.load(lazy=False)" ] @@ -275,7 +278,7 @@ { "data": { "text/plain": [ - "710" + "598" ] }, "execution_count": 9, @@ -351,7 +354,7 @@ { "data": { "text/plain": [ - "710" + "598" ] }, "execution_count": 11, @@ -386,7 +389,16 @@ "name": "#%%\n" } }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_3870159/790841409.py:2: DeprecationWarning: Use 'zntrack.deps' instead.\n", + " deps = zntrack.zn.deps()\n" + ] + } + ], "source": [ "class AddOne(zntrack.Node):\n", " deps = zntrack.zn.deps()\n", @@ -797,7 +809,7 @@ { "data": { "text/plain": [ - "36" + "71" ] }, "execution_count": 17, @@ -825,7 +837,7 @@ { "data": { "text/plain": [ - "36" + "71" ] }, "execution_count": 18, @@ -852,7 +864,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 19, "metadata": { "collapsed": false, "jupyter": { @@ -882,8 +894,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Running DVC command: 'repro'\n", - "\u0000Running DVC command: 'stage add --name RandomNumber --force ...'\n" + "Running DVC command: 'repro'\n" ] }, { @@ -897,6 +908,7 @@ "name": "stderr", "output_type": "stream", "text": [ + "\u0000Running DVC command: 'stage add --name RandomNumber --force ...'\n", "\u0000" ] }, @@ -968,7 +980,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 20, "metadata": { "collapsed": false, "jupyter": { @@ -983,7 +995,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "585 == 585 == 585\n" + "3008 == 3008 == 3008\n" ] } ], @@ -999,7 +1011,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 21, "metadata": { "collapsed": false, "jupyter": { @@ -1014,7 +1026,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "1365 != 1783 != 585\n" + "1647 != 849 != 3008\n" ] } ], @@ -1035,7 +1047,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 22, "metadata": { "collapsed": false, "jupyter": { @@ -1049,13 +1061,6 @@ "source": [ "temp_dir.cleanup()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/poetry.lock b/poetry.lock index 21805821..3a5d6e1e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5508,13 +5508,13 @@ dask = ["bokeh (>=2.4.2,<3.0.0)", "dask (>=2022.12.1,<2023.0.0)", "dask-jobqueue [[package]] name = "zninit" -version = "0.1.10" +version = "0.1.11" description = "Descriptor based dataclass implementation" optional = false python-versions = ">=3.8,<4.0" files = [ - {file = "zninit-0.1.10-py3-none-any.whl", hash = "sha256:a019f47b33bb27822694f4c076dcb679f8e90c55712602508c42d99d84de0b37"}, - {file = "zninit-0.1.10.tar.gz", hash = "sha256:c00366581901f3f73573c50d7396fadc781e94d70e4450e489e85939d4903c24"}, + {file = "zninit-0.1.11-py3-none-any.whl", hash = "sha256:78e5c7e1d0d50c131cf8efab58aad4a213446a1c38f38360731bf11b778e29e1"}, + {file = "zninit-0.1.11.tar.gz", hash = "sha256:fc14a55bd85a38f8f1411bd319cac541c2e6e16ab402f3f11a94e182b0681965"}, ] [package.extras] @@ -5534,4 +5534,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = ">=3.8,<4.0.0" -content-hash = "4c1d2a547d8848de688353b32ce86d53043cab1d33c79cfa33a3aba51022b20c" +content-hash = "094010c5f8f6ae2f7fa142031fa5e11f4527d7cb6062a397fb2057bf488a9608" diff --git a/pyproject.toml b/pyproject.toml index 3b9bbbed..c2572fb3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "ZnTrack" -version = "0.6.3" +version = "0.7.0" description = "Create, Run and Benchmark DVC Pipelines in Python" authors = ["zincwarecode "] license = "Apache-2.0" @@ -17,9 +17,10 @@ pandas = "^2" typer = "^0.7.0" dot4dict = "^0.1.1" -zninit = "^0.1.9" +zninit = "^0.1.11" znjson = "^0.2.2" znflow = "^0.1.14" +typing-extensions = "^4.8.0" [tool.poetry.urls] diff --git a/tests/integration/test_combine_lists.py b/tests/integration/test_combine_lists.py index 9026c841..3e6e5873 100644 --- a/tests/integration/test_combine_lists.py +++ b/tests/integration/test_combine_lists.py @@ -12,7 +12,7 @@ def run(self): class AddOneToList(zntrack.Node): - data = zntrack.zn.deps() + data = zntrack.deps() outs = zntrack.zn.outs() def run(self) -> None: @@ -20,7 +20,7 @@ def run(self) -> None: class AddOneToDict(zntrack.Node): - data = zntrack.zn.deps() + data = zntrack.deps() outs = zntrack.zn.outs() def run(self) -> None: diff --git a/tests/integration/test_example_02.py b/tests/integration/test_example_02.py index 230286f3..14947147 100644 --- a/tests/integration/test_example_02.py +++ b/tests/integration/test_example_02.py @@ -43,8 +43,8 @@ def run(self): class ComputeAB(zntrack.Node): """Node stage AB, depending on A&B""" - a: ComputeA = zntrack.zn.deps() - b: ComputeB = zntrack.zn.deps() + a: ComputeA = zntrack.deps() + b: ComputeB = zntrack.deps() out = zntrack.zn.outs() param = zntrack.zn.params("default") diff --git a/tests/integration/test_example_03.py b/tests/integration/test_example_03.py index 1a9ac55a..0202c5eb 100644 --- a/tests/integration/test_example_03.py +++ b/tests/integration/test_example_03.py @@ -9,7 +9,7 @@ def run(self): class AddOne(zntrack.Node): - inp = zntrack.zn.deps() + inp = zntrack.deps() number = zntrack.zn.outs() def run(self): @@ -17,7 +17,7 @@ def run(self): class SubtractOne(zntrack.Node): - inp = zntrack.zn.deps() + inp = zntrack.deps() number = zntrack.zn.outs() def run(self): @@ -27,7 +27,7 @@ def run(self): class Summation(zntrack.Node): """Stage that is actually tested, containing the multiple dependencies""" - inp = zntrack.zn.deps() + inp = zntrack.deps() number = zntrack.zn.outs() def run(self): @@ -40,7 +40,7 @@ class SummationTuple(zntrack.Node): Additionally testing for tuple conversion here! """ - inp = zntrack.zn.deps() + inp = zntrack.deps() number = zntrack.zn.outs() def run(self): diff --git a/tests/integration/test_node_node_getitem.py b/tests/integration/test_node_node_getitem.py index eac10688..e5d80de6 100644 --- a/tests/integration/test_node_node_getitem.py +++ b/tests/integration/test_node_node_getitem.py @@ -14,8 +14,8 @@ def run(self): class AddList(zntrack.Node): - a = zntrack.zn.deps() - b = zntrack.zn.deps() + a = zntrack.deps() + b = zntrack.deps() output = zntrack.zn.outs() diff --git a/tests/integration/test_node_nwd.py b/tests/integration/test_node_nwd.py index c910dfaa..abd51a4d 100644 --- a/tests/integration/test_node_nwd.py +++ b/tests/integration/test_node_nwd.py @@ -16,7 +16,7 @@ def run(self): class FileToOuts(zntrack.Node): # although, this is a file path, it has to be zn.deps - file = zntrack.zn.deps() + file = zntrack.deps() text = zntrack.zn.outs() def run(self): diff --git a/tests/integration/test_none_values.py b/tests/integration/test_none_values.py index 9df10ca2..f2555ed2 100644 --- a/tests/integration/test_none_values.py +++ b/tests/integration/test_none_values.py @@ -6,7 +6,7 @@ class LoadFromDeps(zntrack.Node): - data = zntrack.zn.deps() + data = zntrack.deps() file: pathlib.Path = zntrack.dvc.deps() result = zntrack.zn.outs() diff --git a/tests/integration/test_project.py b/tests/integration/test_project.py index 611378e0..f155fe85 100644 --- a/tests/integration/test_project.py +++ b/tests/integration/test_project.py @@ -12,7 +12,7 @@ class ZnNodesNode(zntrack.Node): """Used zn.nodes""" - node = zntrack.zn.nodes() + node = zntrack.deps() result = zntrack.zn.outs() def run(self) -> None: @@ -330,7 +330,7 @@ def test_groups_nwd(tmp_path_2): ) -def test_groups_nwd_zn_nodes(tmp_path_2): +def test_groups_nwd_zn_nodes_a(tmp_path_2): node = zntrack.examples.ParamsToOuts(params="Lorem Ipsum") with zntrack.Project(automatic_node_names=True) as project: node_1 = ZnNodesNode(node=node) @@ -339,14 +339,21 @@ def test_groups_nwd_zn_nodes(tmp_path_2): with project.group("CustomGroup") as group_2: node_3 = ZnNodesNode(node=node) + assert node_1.name == "ZnNodesNode" + assert node_1.node.name == "ZnNodesNode+node" + + assert node_2.name == "Group1_ZnNodesNode" + assert node_2.node.name == "Group1_ZnNodesNode+node" + + assert node_3.name == "CustomGroup_ZnNodesNode" + assert node_3.node.name == "CustomGroup_ZnNodesNode+node" + project.run() - assert zntrack.from_rev(node_1).node.nwd == pathlib.Path("nodes/ZnNodesNode_node") - assert zntrack.from_rev(node_2).node.nwd == pathlib.Path( - "nodes", "Group1", "ZnNodesNode_1_node" - ) - assert zntrack.from_rev(node_3).node.nwd == pathlib.Path( - "nodes", "CustomGroup", "ZnNodesNode_1_node" + assert zntrack.from_rev(node_1).nwd == pathlib.Path("nodes/ZnNodesNode") + assert zntrack.from_rev(node_2).nwd == pathlib.Path("nodes", "Group1", "ZnNodesNode") + assert zntrack.from_rev(node_3).nwd == pathlib.Path( + "nodes", "CustomGroup", "ZnNodesNode" ) project.load() @@ -355,7 +362,7 @@ def test_groups_nwd_zn_nodes(tmp_path_2): assert node_3.result == "Lorem Ipsum" -def test_groups_nwd_zn_nodes(tmp_path_2): +def test_groups_nwd_zn_nodes_b(tmp_path_2): node = zntrack.examples.ParamsToOuts(params="Lorem Ipsum") with zntrack.Project(automatic_node_names=True) as project: with project.group() as group_1: @@ -365,11 +372,9 @@ def test_groups_nwd_zn_nodes(tmp_path_2): project.run() - assert zntrack.from_rev(node_2).node.nwd == pathlib.Path( - "nodes", "Group1", "ZnNodesNode_node" - ) - assert zntrack.from_rev(node_3).node.nwd == pathlib.Path( - "nodes", "CustomGroup", "ZnNodesNode_node" + assert zntrack.from_rev(node_2).nwd == pathlib.Path("nodes", "Group1", "ZnNodesNode") + assert zntrack.from_rev(node_3).nwd == pathlib.Path( + "nodes", "CustomGroup", "ZnNodesNode" ) project.load() diff --git a/tests/integration/test_zn_nodes2.py b/tests/integration/test_zn_nodes2.py index 83ab20b0..d50801f4 100644 --- a/tests/integration/test_zn_nodes2.py +++ b/tests/integration/test_zn_nodes2.py @@ -2,6 +2,7 @@ import pytest +import zntrack from zntrack import Node, Project, exceptions, zn @@ -14,8 +15,8 @@ def run(self): class ExampleNode(Node): - params1: NodeViaParams = zn.nodes() - params2: NodeViaParams = zn.nodes() + params1: NodeViaParams = zntrack.deps() + params2: NodeViaParams = zntrack.deps() outs = zn.outs() @@ -24,7 +25,7 @@ def run(self): class ExampleNodeLst(Node): - params: list = zn.nodes() + params: list = zntrack.deps() outs = zn.outs() def run(self): @@ -47,8 +48,8 @@ def test_ExampleNode(proj_path, eager): assert node.params2.param1 == 10 assert node.outs == 11 - assert node.params1.name == "ExampleNode_params1" - assert node.params2.name == "ExampleNode_params2" + assert node.params1.name == "ExampleNode+params1" + assert node.params2.name == "ExampleNode+params2" if not eager: # Check new instance also works @@ -57,8 +58,8 @@ def test_ExampleNode(proj_path, eager): assert node.params2.param1 == 10 assert node.outs == 11 - assert node.params1.name == "ExampleNode_params1" - assert node.params2.name == "ExampleNode_params2" + assert node.params1.name == "ExampleNode+params1" + assert node.params2.name == "ExampleNode+params2" @pytest.mark.parametrize("git_only_repo", [True, False]) @@ -78,8 +79,8 @@ def test_ExampleNodeLst(proj_path, eager, git_only_repo): assert node.params[1].param1 == 10 assert node.outs == 11 - assert node.params[0].name == "ExampleNodeLst_params_0" - assert node.params[1].name == "ExampleNodeLst_params_1" + assert node.params[0].name == "ExampleNodeLst+params+0" + assert node.params[1].name == "ExampleNodeLst+params+1" if not eager: # Check new instance also works @@ -87,8 +88,8 @@ def test_ExampleNodeLst(proj_path, eager, git_only_repo): assert nodex.params[0].param1 == 1 assert nodex.params[1].param1 == 10 assert nodex.outs == 11 - assert nodex.params[0].name == "ExampleNodeLst_params_0" - assert nodex.params[1].name == "ExampleNodeLst_params_1" + assert nodex.params[0].name == "ExampleNodeLst+params+0" + assert nodex.params[1].name == "ExampleNodeLst+params+1" parameter_1.param1 = 2 # Change parameter assert isinstance(parameter_1, NodeViaParams) @@ -111,26 +112,17 @@ def test_ExampleNodeLst(proj_path, eager, git_only_repo): assert node.outs == 12 -def test_znodes_on_graph(proj_path): - project = Project(force=True) - with project: - with pytest.raises(exceptions.ZnNodesOnGraphError): - _ = ExampleNodeLst(params=NodeViaParams(param1=1)) - - with project: - with pytest.raises(exceptions.ZnNodesOnGraphError): - _ = ExampleNodeLst(params=[NodeViaParams(param1=1)]) - - class RandomNumberGen(Node): def get_rnd(self): import random + random.seed(42) + return random.random() class ExampleNodeWithRandomNumberGen(Node): - rnd: RandomNumberGen = zn.nodes() + rnd: RandomNumberGen = zntrack.deps() outs = zn.outs() diff --git a/tests/integration/test_zntrack_deps.py b/tests/integration/test_zntrack_deps.py new file mode 100644 index 00000000..1a6c09fc --- /dev/null +++ b/tests/integration/test_zntrack_deps.py @@ -0,0 +1,232 @@ +"""Tests for 'zntrack.deps'-field which can be used as both `zntrack.zn.deps` and `zntrack.zn.nodes`.""" + +import zntrack.examples + +# TODO: change the parameters, rerun and see if it updates! + + +def test_as_deps(proj_path): + """Test for 'zntrack.deps' acting as `zntrack.zn.deps`-like field.""" + project = zntrack.Project(automatic_node_names=True) + + with project: + a = zntrack.examples.ComputeRandomNumber(params_file="a.json") + b = zntrack.examples.ComputeRandomNumber(params_file="b.json") + c = zntrack.examples.SumRandomNumbers([a, b]) + + a.write_params(min=1, max=5, seed=42) + b.write_params(min=5, max=10, seed=42) + + project.run() + + a.load() + b.load() + c.load() + + assert a.number == 1 + assert b.number == 10 + assert c.result == 11 + + a.write_params(min=1, max=5, seed=31415) + # b.write_params(min=5, max=10, seed=31415) # only change one of the two parameters + + project.repro() + + a.load() + b.load() + c.load() + + assert a.number == 5 + assert b.number == 10 + assert c.result == 15 + + +def test_as_nodes(proj_path): + """Test for 'zntrack.deps' acting as `zntrack.zn.nodes`-like field.""" + project = zntrack.Project(automatic_node_names=True) + + a = zntrack.examples.ComputeRandomNumber(params_file="a.json") + b = zntrack.examples.ComputeRandomNumber(params_file="b.json") + + with project: + c = zntrack.examples.SumRandomNumbers([a, b]) + + a.write_params(min=1, max=5, seed=42) + b.write_params(min=5, max=10, seed=42) + + project.run() + + # TODO: good error messages when someone tries to load a node that is not on the graph + # a.load() + # b.load() + # assert a.number == 1 + # assert b.number == 10 + + c.load() + assert c.result == 11 + + a.write_params(min=1, max=5, seed=31415) + + project.repro() + + c.load() + assert c.result == 15 + + +def test_mixed(proj_path): + project = zntrack.Project(automatic_node_names=True) + + a = zntrack.examples.ComputeRandomNumber(params_file="a.json") + + with project: + b = zntrack.examples.ComputeRandomNumber(params_file="b.json") + c = zntrack.examples.SumRandomNumbers([a, b]) + + a.write_params(min=1, max=5, seed=42) + b.write_params(min=5, max=10, seed=42) + + project.run() + + b.load() + c.load() + + assert b.number == 10 + assert c.result == 11 + + a.write_params(min=1, max=5, seed=31415) + + project.repro() + + c.load() + assert c.result == 15 + + b.write_params(min=5, max=10, seed=31415) + + project.repro() + + b.load() + c.load() + + assert b.number == 9 + assert c.result == 14 + + +def test_named_parent(proj_path): + project = zntrack.Project(automatic_node_names=True) + + a = zntrack.examples.ComputeRandomNumber(params_file="a.json") + b = zntrack.examples.ComputeRandomNumber(params_file="b.json") + + with project: + c = zntrack.examples.SumRandomNumbersNamed([a, b], name="c") + + a.write_params(min=1, max=5, seed=42) + b.write_params(min=5, max=10, seed=42) + + project.run() + + c.load() + assert c.name == "c" + assert c.result == 11 + + +def test_one_to_many(proj_path): + project = zntrack.Project(automatic_node_names=True) + + a = zntrack.examples.ComputeRandomNumber(params_file="a.json") + + with project: + b = zntrack.examples.SumRandomNumbers([a]) + c = zntrack.examples.SumRandomNumbers([a]) + + a.write_params(min=1, max=5, seed=42) + # here we have one parameter file for both b and c + # so a change in 'a.json' will affect both 'b' and 'c + + project.run() + + b.load() + c.load() + + assert b.result == 1 + assert c.result == 1 + + a.write_params(min=1, max=5, seed=31415) + + project.repro() + + b.load() + c.load() + + assert b.result == 5 + assert c.result == 5 + + +def test_one_to_many_params(proj_path): + project = zntrack.Project(automatic_node_names=True) + + a = zntrack.examples.ComputeRandomNumberWithParams(min=1, max=5, seed=42) + + with project: + b = zntrack.examples.SumRandomNumbers([a]) + c = zntrack.examples.SumRandomNumbers([a]) + + # here we create a deepcopy of 'a' for both 'b' and 'c' + # so a change in 'a' will not affect 'b' and 'c' + # and we can change the parameters in b.numbers[0] and c.numbers[0] independently + + project.run() + + b.load() + c.load() + + assert b.result == 1 + assert c.result == 1 + + assert b.name == "SumRandomNumbers" + assert c.name == "SumRandomNumbers_1" + + assert b.numbers[0].name == f"{b.name}+numbers+0" + assert c.numbers[0].name == f"{c.name}+numbers+0" + + b.numbers[0].min = 5 + b.numbers[0].max = 10 + b.numbers[0].seed = 42 + + project.run() + + b.load() + c.load() + + assert b.result == 10 + assert c.result == 1 + + c.numbers[0].min = 5 + c.numbers[0].max = 10 + c.numbers[0].seed = 42 + + project.run() + + b.load() + c.load() + + assert b.result == 10 + assert c.result == 10 + + +# Currently this is neither tested nor supported. +# A Node not on Graph error is raised. # TODO better error messages +# def test_one_to_many_params_property(proj_path): +# project = zntrack.Project(automatic_node_names=True) + +# a = zntrack.examples.AddNumbersProperty(a=1, b=2) +# b = zntrack.examples.AddNumbersProperty(a=3, b=4) + +# with project: +# c = zntrack.examples.AddNodeAttributes(a.c, b.c) + +# project.run() + +# c.load() + +# assert c.c == 10 diff --git a/tests/test_zntrack.py b/tests/test_zntrack.py index 2a4edeb9..a09aac2f 100644 --- a/tests/test_zntrack.py +++ b/tests/test_zntrack.py @@ -4,4 +4,4 @@ def test_version(): """Test 'ZnTrack' version.""" - assert __version__ == "0.6.3" + assert __version__ == "0.7.0" diff --git a/tests/unit_tests/test_zn_deps.py b/tests/unit_tests/test_zn_deps.py index 8cf9b108..3d6a13a2 100644 --- a/tests/unit_tests/test_zn_deps.py +++ b/tests/unit_tests/test_zn_deps.py @@ -12,7 +12,7 @@ def run(self): class DependentNode(zntrack.Node): - deps = zntrack.zn.deps() + deps = zntrack.deps() def run(self): pass diff --git a/zntrack/core/load.py b/zntrack/core/load.py index 2d52490a..f83bd7dd 100644 --- a/zntrack/core/load.py +++ b/zntrack/core/load.py @@ -3,6 +3,7 @@ import contextlib import importlib import importlib.util +import json import pathlib import sys import tempfile @@ -14,6 +15,7 @@ import dvc.stage from zntrack.core.node import Node +from zntrack.utils import config T = typing.TypeVar("T", bound=Node) @@ -93,14 +95,40 @@ def from_rev(name, remote=".", rev=None, **kwargs) -> T: """ if isinstance(name, Node): name = name.name - stage = _get_stage(name, remote, rev) - - cmd = stage.cmd - run_str = cmd.split()[2] - name = cmd.split()[4] - - package_and_module, cls_name = run_str.rsplit(".", 1) - module = None + if "+" in name: + fs = dvc.api.DVCFileSystem(url=remote, rev=rev) + + components = name.split("+") + + if len(components) == 3: + parent, attribute, key = components + else: + parent, attribute = components + key = None + + with fs.open(config.files.zntrack) as fs: + zntrack_config = json.load(fs) + data = zntrack_config[parent][attribute] + if key is not None: + try: + data = data[int(key)] + except (ValueError, KeyError): + data = data[key] + assert ( + data["_type"] == "zntrack.Node" + ), f"Expected zntrack.Node, got {data['_type']}" + package_and_module = data["value"]["module"] + cls_name = data["value"]["cls"] + module = None + else: + stage = _get_stage(name, remote, rev) + + cmd = stage.cmd + run_str = cmd.split()[2] + name = cmd.split()[4] + + package_and_module, cls_name = run_str.rsplit(".", 1) + module = None try: module = importlib.import_module(package_and_module) except ModuleNotFoundError: diff --git a/zntrack/core/node.py b/zntrack/core/node.py index 61f3ce47..09c0308a 100644 --- a/zntrack/core/node.py +++ b/zntrack/core/node.py @@ -142,6 +142,7 @@ class Node(zninit.ZnInit, znflow.Node): _name_ = None _protected_ = znflow.Node._protected_ + ["name"] + _priority_kwargs_ = ["name"] @property def _use_repr_(self) -> bool: @@ -173,10 +174,13 @@ def convert_notebook(cls, nb_name: str = None): @property def _init_descriptors_(self): from zntrack import fields + from zntrack.fields.dependency import Dependency + from zntrack.fields.zn import options as zn_options return [ - fields.zn.Params, - fields.zn.Dependency, + zn_options.Params, + zn_options.Dependency, + Dependency, fields.meta.Text, fields.meta.Environment, fields.dvc.DVCOption, @@ -377,7 +381,7 @@ def get_dvc_cmd( @dataclasses.dataclass class NodeIdentifier: - """All information that uniquly identifies a node.""" + """All information that uniquely identifies a node.""" module: str cls: str @@ -388,7 +392,7 @@ class NodeIdentifier: @classmethod def from_node(cls, node: Node): """Create a _NodeIdentifier from a Node object.""" - # TODO module and cls are not needed (from_rev can handle name, rev, remote only) + # TODO module and cls are only required for `zn.nodes` return cls( module=module_handler(node), cls=node.__class__.__name__, diff --git a/zntrack/examples/__init__.py b/zntrack/examples/__init__.py index dab32723..722581b2 100644 --- a/zntrack/examples/__init__.py +++ b/zntrack/examples/__init__.py @@ -2,6 +2,11 @@ These nodes are primarily used for testing and demonstration purposes. """ +import json +import pathlib +import random +import typing as t + import pandas as pd import zntrack @@ -53,6 +58,18 @@ def run(self): self.c = self.a + self.b +class AddNumbersProperty(zntrack.Node): + """Add two numbers.""" + + a = zntrack.params() + b = zntrack.params() + + @property + def c(self): + """Add two numbers.""" + return self.a + self.b + + class AddNodes(zntrack.Node): """Add two nodes.""" @@ -120,3 +137,70 @@ class WriteDVCOuts(zntrack.Node): def run(self): """Write an output file.""" self.outs.write_text(str(self.params)) + + +class ComputeRandomNumber(zntrack.Node): + """Compute a random number.""" + + params_file = zntrack.params_path() + + number = zntrack.outs() + + def _post_init_(self): + self.params_file = pathlib.Path(self.params_file) + + def run(self): + """Compute a random number.""" + self.number = self.get_random_number() + + def get_random_number(self): + """Compute a random number.""" + params = json.loads(self.params_file.read_text()) + random.seed(params["seed"]) + return random.randint(params["min"], params["max"]) + + def write_params(self, min, max, seed): + """Write params to file.""" + self.params_file.write_text(json.dumps({"min": min, "max": max, "seed": seed})) + + +class ComputeRandomNumberWithParams(zntrack.Node): + """Compute a random number.""" + + min: int = zntrack.params() + max: int = zntrack.params() + seed: int = zntrack.params() + + number = zntrack.outs() + + def run(self): + """Compute a random number.""" + self.number = self.get_random_number() + + def get_random_number(self): + """Compute a random number.""" + random.seed(self.seed) + return random.randint(self.min, self.max) + + +class ComputeRandomNumberNamed(ComputeRandomNumber): + """Same as ComputeRandomNumber but with a custom name.""" + + _name_ = "custom_ComputeRandomNumber" + + +class SumRandomNumbers(zntrack.Node): + """Sum a list of random numbers.""" + + numbers: t.List[ComputeRandomNumber] = zntrack.deps() + result: int = zntrack.outs() + + def run(self): + """Sum a list of random numbers.""" + self.result = sum(x.get_random_number() for x in self.numbers) + + +class SumRandomNumbersNamed(SumRandomNumbers): + """Same as SumRandomNumbers but with a custom name.""" + + _name_ = "custom_SumRandomNumbers" diff --git a/zntrack/fields/dependency.py b/zntrack/fields/dependency.py new file mode 100644 index 00000000..58736613 --- /dev/null +++ b/zntrack/fields/dependency.py @@ -0,0 +1,304 @@ +"""Dependency field.""" + +import copy +import json +import logging +import pathlib +import typing as t + +import znflow +import zninit +import znjson +from znflow import handler + +from zntrack.fields.field import DataIsLazyError, Field, FieldGroup, LazyField +from zntrack.fields.zn.options import ( + CombinedConnectionsConverter, + ConnectionConverter, + _default, + _get_all_connections_and_instances, +) +from zntrack.utils import config, update_key_val + +log = logging.getLogger(__name__) + +if t.TYPE_CHECKING: + from zntrack import Node + + +class Dependency(LazyField): + """A dependency field.""" + + dvc_option = "deps" + group = FieldGroup.PARAMETER + + def __init__(self, default=_default): + """Create a new dependency field. + + A `zn.deps` does not support default values. + To build a dependency graph, the values must be passed at runtime. + """ + if default is _default: + super().__init__() + elif default is None: + super().__init__(default=default) + else: + raise ValueError( + "A dependency field does not support default dependencies. You can only" + " use 'None' to declare this an optional dependency" + f"and not {default}." + ) + + def __set__(self, instance, value): + """Disable the _graph_ in the value 'Node'.""" + if value is None: + return super().__set__(instance, value) + + # We need to update the node names, if they are not on the graph. + # TODO: raise error if '+' in name + + graph = instance._graph_ + if isinstance(graph, znflow.DiGraph): + with znflow.disable_graph(): + if isinstance(value, dict): + new_entries = { + key: self._update_node_name(entry, instance, graph, key=key) + for key, entry in value.items() + } + value = new_entries + elif isinstance(value, (list, tuple)): + new_entries = [ + self._update_node_name(entry, instance, graph, key=idx) + for idx, entry in enumerate(value) + ] + value = new_entries + else: + value = self._update_node_name(value, instance, graph) + + return super().__set__(instance, value) + + def _get_nodes_on_off_graph(self, instance) -> t.Tuple[list, list]: + """Get the nodes that are on the graph and off the graph. + + Get the values of this descriptor and split them into + nodes that are on the graph and off the graph. + These represent `zn.deps` and `zn.nodes` respectively. + + + Attributes + ---------- + instance : Node + The Node instance. + + Returns + ------- + on_graph : list + The nodes that are on the graph. + off_graph : list + The nodes that are off the graph. + """ + values = getattr(instance, self.name) + # TODO use IterableHandler? + + if isinstance(values, dict): + values = list(values.values()) + + if isinstance(values, tuple): + values = list(values) + + if not isinstance(values, list): + values = [values] + + nodes = [] + for entry in values: + if isinstance(entry, (znflow.CombinedConnections, znflow.Connection)): + nodes.extend(_get_all_connections_and_instances(entry)) + else: + nodes.append(entry) + + on_graph = [] + off_graph = [] + for entry in nodes: + try: + if "+" in entry.name: + # currently there is no other way to check if a node is on the graph + # a node which is not on the graph will have a node name containing a + # colon, which is not allowed in node names on the graph by DVC. + off_graph.append(entry) + else: + on_graph.append(entry) + except AttributeError: + # in eager mode the attribute does not have a name. + pass + return on_graph, off_graph + + def get_files(self, instance) -> list: + """Get the affected files of the respective Nodes.""" + files = [] + + value, _ = self._get_nodes_on_off_graph(instance) + + for node in value: + node: Node + if node is None: + continue + if node._external_: + from zntrack.utils import run_dvc_cmd + + # TODO save these files in a specific directory called `external` + # TODO the `dvc import cmd` should not run here but rather be a stage? + + deps_file = pathlib.Path("external", f"{node.uuid}.json") + deps_file.parent.mkdir(exist_ok=True, parents=True) + + # zntrack run node.name --external \ + # --remote node.state.remote --rev node.state.rev + + # when combining with zn.nodes this should be used + # dvc stage add --params params.yaml: + # --outs nodes//node-meta.json zntrack run --external + + cmd = [ + "import", + node.state.remote if node.state.remote is not None else ".", + (node.nwd / "node-meta.json").as_posix(), + "-o", + deps_file.as_posix(), + ] + if node.state.rev is not None: + cmd.extend(["--rev", node.state.rev]) + # TODO how can we test, that the loaded file truly is the correct one? + if not deps_file.exists(): + run_dvc_cmd(cmd) + files.append(deps_file.as_posix()) + # dvc import node-meta.json + add as dependency file + continue + # if node.state.rev is not None or node.state.remote is not None: + # # TODO if the Node has a `rev` or `remote` attribute, we need to + # # get the UUID file of the respective Node through node.state.fs.open + # # save that somewhere (can't use NWD, because we can now have multiple + # # nodes with the same name...) + # # and make the uuid a dependency of the node. + # continue + files.append(node.nwd / "node-meta.json") + for field in zninit.get_descriptors(Field, self=node): + if field.dvc_option in ["params", "deps"]: + # We do not want to depend on parameter files or + # recursively on dependencies. + continue + files.extend(field.get_files(node)) + log.debug(f"Found field {field} and extended files to {files}") + return files + + def save(self, instance: "Node"): + """Save the field to disk.""" + try: + value = self.get_value_except_lazy(instance) + except DataIsLazyError: + return + + _, off_graph = self._get_nodes_on_off_graph(instance) + + for node in off_graph: + node.save(results=False) + + self._write_value_to_config( + value, + instance, + encoder=znjson.ZnEncoder.from_converters( + [ConnectionConverter, CombinedConnectionsConverter], add_default=True + ), + ) + + def get_data(self, instance: "Node") -> any: + """Get the value of the field from the file.""" + zntrack_dict = json.loads( + instance.state.fs.read_text(config.files.zntrack), + ) + value = zntrack_dict[instance.name][self.name] + + value = update_key_val(value, instance=instance) + + value = json.loads( + json.dumps(value), + cls=znjson.ZnDecoder.from_converters( + [ConnectionConverter, CombinedConnectionsConverter], add_default=True + ), + ) + + # Up until here we have connection objects. Now we need + # to resolve them to Nodes. The Nodes, as in 'connection.instance' + # are already loaded by the ZnDecoder. + return handler.UpdateConnectors()(value) + + def get_stage_add_argument(self, instance) -> t.List[tuple]: + """Get the dvc command for this field.""" + cmd = [ + (f"--{self.dvc_option}", pathlib.Path(file).as_posix()) + for file in self.get_files(instance) + ] + + _, off_graph = self._get_nodes_on_off_graph(instance) + + # TODO this is only for parameters via `zn.params` + # we need to also handle parameters via `dvc.params` + + from zntrack.fields.zn.options import Params + + # NO: we have to do this for each value and for instance + + for node in off_graph: + for field in zninit.get_descriptors(Field, self=node): + if isinstance(field, Params): + # cmd += [("--params", f"{config.files.params}:{node.name}:")] + cmd += [("--params", f"{config.files.params}:{node.name}")] + elif field.dvc_option == "params": + files = field.get_files(node) + for file in files: + cmd.append(("--params", f"{file}:")) + return cmd + + def _update_node_name(self, entry, instance, graph, key=None): + """Update the node name if it is used as 'zn.nodes'. + + Attributes + ---------- + self : Dependency + The Dependency field, used to gather the attribute name. + entry : list[nodes]|dict[str, nodes]|nodes + The entries to update. + instance : Node + The parent Node instance the 'zn.nodes' is connected to + graph : znflow.DiGraph + The active graph. + key : str|int + The key or index of the entry. + + Returns + ------- + entry : list[nodes]|dict[str, nodes]|nodes + A deepcopy of the entries with updated names. + + """ + if isinstance(entry, (znflow.CombinedConnections, znflow.Connection)): + # we currently do not support CombinedConnections or Connection + return entry + + if hasattr(entry, "_graph_"): + if ( + entry.state.rev is not None + or entry.state.remote is not None + or entry._external_ + ): + # This indicates a loaded node which we do not want to change. + return entry + + if entry.uuid not in graph: + entry._graph_ = None + entry = copy.deepcopy(entry) + entry_name = f"{instance.name}+{self.name}" + if key is not None: + entry_name += f"+{key}" + entry.name = entry_name + + return entry diff --git a/zntrack/fields/dvc/__init__.py b/zntrack/fields/dvc/__init__.py index 20ee07f3..ff9724be 100644 --- a/zntrack/fields/dvc/__init__.py +++ b/zntrack/fields/dvc/__init__.py @@ -1,185 +1,64 @@ -"""DVC fields without serialization of data / for file paths.""" -import json -import pathlib -import typing - -import znjson - -from zntrack.fields.field import Field, FieldGroup, PlotsMixin -from zntrack.utils import node_wd - -if typing.TYPE_CHECKING: - from zntrack import Node - - -class DVCOption(Field): - """A field that is used as a dvc option. - - The DVCOption field is designed for paths only. - """ - - group = FieldGroup.PARAMETER - - def __init__(self, *args, **kwargs): - """Create a DVCOption field.""" - if node_wd.nwd in args or node_wd.nwd in kwargs.values(): - raise ValueError( - "Can not set `zntrack.nwd` as value for {self}. Please use" - " `zntrack.nwd/...` to create a path relative to the node working" - " directory." - ) - self.dvc_option = kwargs.pop("dvc_option") - super().__init__(*args, **kwargs) - - def get_files(self, instance: "Node") -> list: - """Get the files affected by this field. - - Parameters - ---------- - instance : Node - The node instance to get the files for. - - Returns - ------- - list of str - A list of file paths affected by this field. - - """ - value = getattr(instance, self.name) - if not isinstance(value, list): - value = [value] - return [pathlib.Path(file).as_posix() for file in value if file is not None] - - def get_stage_add_argument(self, instance: "Node") -> typing.List[tuple]: - """Get the dvc command for this field. - - Parameters - ---------- - instance : Node - The node instance to get the command for. - - Returns - ------- - list of tuple of str - A list of command-line arguments to use when adding - this field to the DVC stage. - - """ - if self.dvc_option == "params": - return [ - (f"--{self.dvc_option}", f"{file}:") for file in self.get_files(instance) - ] - else: - return [(f"--{self.dvc_option}", file) for file in self.get_files(instance)] - - def get_data(self, instance: "Node") -> any: - """Get the value of the field from the configuration file. - - Parameters - ---------- - instance : Node - The Node instance to get the field value for. - decoder : Any, optional - The decoder to use when parsing the configuration file, by default None. - - Returns - ------- - any - The value of the field from the configuration file. - """ - zntrack_dict = json.loads( - instance.state.fs.read_text("zntrack.json"), - ) - return json.loads( - json.dumps(zntrack_dict[instance.name][self.name]), cls=znjson.ZnDecoder - ) - - def save(self, instance: "Node"): - """Save the field to config file. - - Parameters - ---------- - instance : Node - The node instance to save the field for. - - """ - try: - value = instance.__dict__[self.name] - except KeyError: - try: - # default value is not stored in __dict__ - # TODO: not sure if I like this - value = getattr(instance, self.name) - except AttributeError: - return - self._write_value_to_config(value, instance, encoder=znjson.ZnEncoder) - - def __get__(self, instance: "Node", owner=None): - """Add replacement of the nwd to the get method. - - Parameters - ---------- - instance : Node - The node instance to get the value for. - owner : type, optional - The owner class of the descriptor, by default None - - Returns - ------- - Any - The value of the attribute. - - """ - if instance is None: - return self - value = super().__get__(instance, owner) - return node_wd.ReplaceNWD()(value, nwd=instance.nwd) - - -class PlotsOption(PlotsMixin, DVCOption): - """Field with DVC plots kwargs.""" +"""Deprecated module for 'zntrack.dvc.