Skip to content

Commit

Permalink
376 make default type zntrackoption (#383)
Browse files Browse the repository at this point in the history
* add default and more tests

* update black
  • Loading branch information
PythonFZ authored Sep 14, 2022
1 parent 346fb4c commit 9ae3161
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
- name: Black Check
uses: psf/black@stable
with:
version: "22.6.0"
version: "22.8.0"

isort:
runs-on: ubuntu-latest
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ fail_fast: true

repos:
- repo: https://github.com/psf/black
rev: 22.6.0
rev: 22.8.0
hooks:
- id: black

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Sphinx = "^5.0.2"
numpy = "^1.23"
matplotlib = "^3.5.2"
ase = "^3.22.1"
black = "^22.6.0"
black = "^22.8.0"
isort = "^5.10.1"
flake8 = "^5.0.2"
pre-commit = "^2.20.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ def test_example_func_dry_run(proj_path):
"params.yaml:example_func",
"--outs",
"test.txt",
f'{utils.get_python_interpreter()} -c "from test_single_function import '
'example_func; example_func(exec_func=True)" ',
(
f'{utils.get_python_interpreter()} -c "from test_single_function import '
'example_func; example_func(exec_func=True)" '
),
]
)

Expand Down
42 changes: 38 additions & 4 deletions tests/unit_tests/core/test_dvcgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ class ExampleDVCOutsNode(GraphWriter):
outs = dvc.outs(pathlib.Path("example.dat"))


class ExampleDVCOutsParams(GraphWriter):
is_loaded = False
outs = dvc.outs(pathlib.Path("example.dat"))
param1 = zn.params(5)


def test_get_dvc_arguments():
dvc_options = DVCRunOptions(
force=True,
Expand Down Expand Up @@ -163,8 +169,10 @@ def test_prepare_dvc_script():
"--force",
"--deps",
"file.txt",
f'{utils.get_python_interpreter()} -c "from src.file import MyNode;'
' MyNode.load().run_and_save()" ',
(
f'{utils.get_python_interpreter()} -c "from src.file import MyNode;'
' MyNode.load().run_and_save()" '
),
]

script = prepare_dvc_script(
Expand All @@ -190,8 +198,10 @@ def test_prepare_dvc_script():
"file.txt",
"--deps",
"src/file.py",
f'{utils.get_python_interpreter()} -c "from src.file import MyNode;'
' MyNode.load().run_and_save()" ',
(
f'{utils.get_python_interpreter()} -c "from src.file import MyNode;'
' MyNode.load().run_and_save()" '
),
]


Expand All @@ -211,6 +221,30 @@ def test_ZnTrackInfo_collect():

assert example.zntrack.collect(zn.params) == {"param1": 1, "param2": 2}

# show all
assert example.zntrack.collect() == {"param1": 1, "param2": 2}

# no zn.outs available
assert example.zntrack.collect(zn.outs) == {}

example_with_outs = ExampleDVCOutsNode()
assert example_with_outs.zntrack.collect(dvc.outs) == {
"outs": pathlib.Path("example.dat")
}
assert example_with_outs.zntrack.collect() == {"outs": pathlib.Path("example.dat")}
assert example_with_outs.zntrack.collect(zn.params) == {}

example_outs_params = ExampleDVCOutsParams()

assert example_outs_params.zntrack.collect(dvc.outs) == {
"outs": pathlib.Path("example.dat")
}
assert example_outs_params.zntrack.collect(zn.params) == {"param1": 5}
assert example_outs_params.zntrack.collect() == {
"outs": pathlib.Path("example.dat"),
"param1": 5,
}


@pytest.mark.parametrize(
("param1", "param2"),
Expand Down
7 changes: 5 additions & 2 deletions zntrack/core/dvcgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,13 +217,16 @@ class ZnTrackInfo:
def __init__(self, parent):
self._parent = parent

def collect(self, zntrackoption: typing.Type[descriptor.BaseDescriptorType]) -> dict:
def collect(
self, zntrackoption: typing.Type[descriptor.BaseDescriptorType] = ZnTrackOption
) -> dict:
"""Collect the values of all ZnTrackOptions of the passed type
Parameters
----------
zntrackoption:
Any cls of a ZnTrackOption such as zn.params
Any cls of a ZnTrackOption such as zn.params.
By default, collect all ZnTrackOptions
Returns
-------
Expand Down

0 comments on commit 9ae3161

Please sign in to comment.