From 7693daa8f8554a9b4eff9552774d768eccbd0147 Mon Sep 17 00:00:00 2001 From: Fabian Zills <46721498+PythonFZ@users.noreply.github.com> Date: Fri, 17 Mar 2023 16:24:57 +0100 Subject: [PATCH] Small fix (#532) * add jobs argument * fix #526 * use ruff instead of isort * fix #523 * fix #528 * bugfix --- .github/workflows/lint.yaml | 3 --- .pre-commit-config.yaml | 5 ----- pyproject.toml | 2 +- tests/integration/test_file_changes.py | 16 ++++++++-------- tests/integration/test_node_nwd.py | 21 +++++++++++++++++++++ zntrack/fields/zn/__init__.py | 4 ++-- zntrack/project/zntrack_project.py | 4 ++-- zntrack/utils/node_wd.py | 3 ++- 8 files changed, 36 insertions(+), 22 deletions(-) diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 068c172a..5775c8ed 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -23,9 +23,6 @@ jobs: - name: black if: always() run: black --check . - - name: isort - if: always() - run: isort --check-only . - name: ruff if: always() run: ruff . diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 01a85011..c630dd57 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,11 +8,6 @@ repos: hooks: - id: black - - repo: https://github.com/timothycrosley/isort - rev: 5.12.0 - hooks: - - id: isort - - repo: https://github.com/charliermarsh/ruff-pre-commit rev: v0.0.238 hooks: diff --git a/pyproject.toml b/pyproject.toml index e529628e..f07b44df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,7 +105,7 @@ disable = [ [tool.ruff] line-length = 90 -select = ["E", "F", "D", "N", "C"] #, "ANN"] +select = ["E", "F", "D", "N", "C", "I"] #, "ANN"] extend-ignore = [ "D213", "D203", "D401", diff --git a/tests/integration/test_file_changes.py b/tests/integration/test_file_changes.py index bb99aa99..ff908064 100644 --- a/tests/integration/test_file_changes.py +++ b/tests/integration/test_file_changes.py @@ -41,11 +41,11 @@ def run(self): self.outs = "correct outs" -# def test_WriteToOutsOutsideRun(proj_path): -# node = WriteToOutsOutsideRun(outs="correct outs") -# node.run() -# node.save() -# -# assert WriteToOutsOutsideRun.from_rev().outs == "correct outs" -# # WriteToOutsOutsideRun(outs="incorrect outs").save() -# # assert WriteToOutsOutsideRun.from_rev().outs == "correct outs" +def test_WriteToOutsOutsideRun(proj_path): + node = WriteToOutsOutsideRun(outs="correct outs") + node.run() + node.save() + + assert WriteToOutsOutsideRun.from_rev().outs == "correct outs" + WriteToOutsOutsideRun(outs="incorrect outs").save(results=False) + assert WriteToOutsOutsideRun.from_rev().outs == "correct outs" diff --git a/tests/integration/test_node_nwd.py b/tests/integration/test_node_nwd.py index 75da5c1c..1cce8ca2 100644 --- a/tests/integration/test_node_nwd.py +++ b/tests/integration/test_node_nwd.py @@ -14,6 +14,14 @@ def run(self): self.file[0].write_text(self.text) +class OutsAsNWD(zntrack.Node): + text = zntrack.zn.params() + outs: pathlib.Path = zntrack.dvc.outs(zntrack.nwd) + + def run(self): + (self.outs / "test.txt").write_text(self.text) + + @pytest.mark.parametrize("eager", [True, False]) def test_WriteToNWD(proj_path, eager): with zntrack.Project() as project: @@ -25,3 +33,16 @@ def test_WriteToNWD(proj_path, eager): if not eager: write_to_nwd.load() assert write_to_nwd.__dict__["file"] == [pathlib.Path("$nwd$", "test.txt")] + + +@pytest.mark.parametrize("eager", [True, False]) +def test_OutAsNWD(proj_path, eager): + with zntrack.Project() as project: + outs_as_nwd = OutsAsNWD(text="Hello World") + + project.run(eager=eager) + assert (outs_as_nwd.outs / "test.txt").read_text() == "Hello World" + assert outs_as_nwd.outs == pathlib.Path("nodes", "OutsAsNWD") + if not eager: + outs_as_nwd.load() + assert outs_as_nwd.__dict__["outs"] == zntrack.nwd diff --git a/zntrack/fields/zn/__init__.py b/zntrack/fields/zn/__init__.py index 6748398e..0f2867b4 100644 --- a/zntrack/fields/zn/__init__.py +++ b/zntrack/fields/zn/__init__.py @@ -338,9 +338,9 @@ def get_stage_add_argument(self, instance) -> typing.List[tuple]: class _SaveNodes(znflow.utils.IterableHandler): def default(self, value, **kwargs): name = kwargs["name"] - if hasattr(value, "save"): + if isinstance(value, znflow.Node): value.name = name - value.save() + value.save(results=False) return value diff --git a/zntrack/project/zntrack_project.py b/zntrack/project/zntrack_project.py index 8c32f3dd..aaa0cb6f 100644 --- a/zntrack/project/zntrack_project.py +++ b/zntrack/project/zntrack_project.py @@ -137,9 +137,9 @@ def create_experiment(self, name: str = None, queue: bool = True) -> Experiment: cmd.extend(["--name", name]) exp.name = capture_run_dvc_cmd(cmd).split("'")[1] - def run_exp(self) -> None: + def run_exp(self, jobs: int = 1) -> None: """Run all queued experiments.""" - run_dvc_cmd(["exp", "run", "--run-all"]) + run_dvc_cmd(["exp", "run", "--run-all", "--jobs", str(jobs)]) @property def branches(self): diff --git a/zntrack/utils/node_wd.py b/zntrack/utils/node_wd.py index c7d28475..d829ee59 100644 --- a/zntrack/utils/node_wd.py +++ b/zntrack/utils/node_wd.py @@ -35,7 +35,8 @@ def default(self, value, **kwargs): """Replace the nwd placeholder with the actual nwd.""" if isinstance(value, str): if value == nwd: - return nwd + # nwd is of type str but will be converted to pathlib.Path + return pathlib.Path(kwargs["nwd"]) return value.replace(nwd, pathlib.Path(kwargs["nwd"]).as_posix()) elif isinstance(value, pathlib.Path): return pathlib.Path(