Skip to content

Commit

Permalink
fix merge conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ committed Jan 31, 2022
1 parent 166fad9 commit c338d88
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 50 deletions.
119 changes: 76 additions & 43 deletions examples/docs/01_Intro.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -71,24 +71,25 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Initialized empty Git repository in C:/Users/fabia/AppData/Local/Temp/tmpean3t5ot/.git/\n",
"Initialized DVC repository.\n",
"\n",
"You can now commit the changes to git.\n",
"\n",
"+---------------------------------------------------------------------+\n",
"| |\n",
"| DVC has enabled anonymous aggregate usage analytics. |\n",
"| Read the analytics documentation (and how to opt-out) here: |\n",
"| <https://dvc.org/doc/user-guide/analytics> |\n",
"| |\n",
"+---------------------------------------------------------------------+\n",
"\n",
"What's next?\n",
"------------\n",
"- Check out the documentation: <https://dvc.org/doc>\n",
"- Get help and share ideas: <https://dvc.org/chat>\n",
"- Star us on GitHub: <https://github.com/iterative/dvc>\n"
"Initialized empty Git repository in /tmp/tmptn0i8j7r/.git/\r\n",
"Initialized DVC repository.\r\n",
"\r\n",
"You can now commit the changes to git.\r\n",
"\r\n",
"\u001B[31m+---------------------------------------------------------------------+\r\n",
"\u001B[0m\u001B[31m|\u001B[0m \u001B[31m|\u001B[0m\r\n",
"\u001B[31m|\u001B[0m DVC has enabled anonymous aggregate usage analytics. \u001B[31m|\u001B[0m\r\n",
"\u001B[31m|\u001B[0m Read the analytics documentation (and how to opt-out) here: \u001B[31m|\u001B[0m\r\n",
"\u001B[31m|\u001B[0m <\u001B[36mhttps://dvc.org/doc/user-guide/analytics\u001B[39m> \u001B[31m|\u001B[0m\r\n",
"\u001B[31m|\u001B[0m \u001B[31m|\u001B[0m\r\n",
"\u001B[31m+---------------------------------------------------------------------+\r\n",
"\u001B[0m\r\n",
"\u001B[33mWhat's next?\u001B[39m\r\n",
"\u001B[33m------------\u001B[39m\r\n",
"- Check out the documentation: <\u001B[36mhttps://dvc.org/doc\u001B[39m>\r\n",
"- Get help and share ideas: <\u001B[36mhttps://dvc.org/chat\u001B[39m>\r\n",
"- Star us on GitHub: <\u001B[36mhttps://github.com/iterative/dvc\u001B[39m>\r\n",
"\u001B[0m"
]
}
],
Expand Down Expand Up @@ -125,10 +126,6 @@
" number = zn.outs()\n",
" maximum = zn.params()\n",
"\n",
" def __init__(self, maximum=None, **kwargs):\n",
" super().__init__(**kwargs)\n",
" self.maximum = maximum\n",
"\n",
" def run(self):\n",
" self.number = randrange(self.maximum)"
]
Expand All @@ -137,7 +134,16 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now call `run_and_save()` to create our random number"
"Notice that ZnTrack will automatically generate an `__init__` for all `zn.params`.\n",
"When writing a custom `__init__` it is important to add `super().__init__(*args, **kwargs)` for ZnTrack to work.\n",
"```python\n",
" def __init__(self, maximum=None, *args, **kwargs):\n",
" super().__init__(*args, **kwargs)\n",
" self.maximum = maximum\n",
"```\n",
"\n",
"\n",
"We can now call `run_and_save()` to create our random number."
]
},
{
Expand Down Expand Up @@ -167,21 +173,51 @@
"name": "stdout",
"output_type": "stream",
"text": [
"2022-01-14 17:16:17,378 (WARNING): Jupyter support is an experimental feature! Please save your notebook before running this command!\n",
"2022-01-21 16:35:09,327 (WARNING): Jupyter support is an experimental feature! Please save your notebook before running this command!\n",
"Submit issues to https://github.com/zincware/ZnTrack.\n",
"2022-01-14 17:16:17,380 (WARNING): Converting 01_Intro.ipynb to file RandomNumber.py\n",
"2022-01-14 17:16:19,850 (WARNING): --- Writing new DVC file! ---\n",
"2022-01-14 17:16:19,850 (WARNING): You will not be able to see the stdout/stderr of the process in real time!\n",
"2022-01-14 17:16:22,159 (INFO): Running stage 'RandomNumber':\r\n",
"> python -c \"from src.RandomNumber import RandomNumber; RandomNumber.load(name='RandomNumber').run_and_save()\" \r\n",
"Creating 'dvc.yaml'\r\n",
"Adding stage 'RandomNumber' in 'dvc.yaml'\r\n",
"Generating lock file 'dvc.lock'\r\n",
"Updating lock file 'dvc.lock'\r\n",
"\r\n",
"To track the changes with git, run:\r\n",
"\r\n",
"\tgit add dvc.yaml dvc.lock 'nodes\\RandomNumber\\.gitignore'\r\n",
"2022-01-21 16:35:09,328 (WARNING): Converting 01_Intro.ipynb to file RandomNumber.py\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[NbConvertApp] Converting notebook 01_Intro.ipynb to script\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2022-01-21 16:35:12,397 (WARNING): --- Writing new DVC file! ---\n",
"2022-01-21 16:35:12,398 (WARNING): You will not be able to see the stdout/stderr of the process in real time!\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[NbConvertApp] Writing 3241 bytes to 01_Intro.py\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2022-01-21 16:35:15,953 (INFO): Running stage 'RandomNumber':\n",
"> python3 -c \"from src.RandomNumber import RandomNumber; RandomNumber.load(name='RandomNumber').run_and_save()\" \n",
"Creating 'dvc.yaml'\n",
"Adding stage 'RandomNumber' in 'dvc.yaml'\n",
"Generating lock file 'dvc.lock'\n",
"Updating lock file 'dvc.lock'\n",
"\n",
"To track the changes with git, run:\n",
"\n",
" git add dvc.lock dvc.yaml nodes/RandomNumber/.gitignore\n",
"\n",
"To enable auto staging, run:\n",
"\n",
"\tdvc config core.autostage true\n",
"\n"
]
}
Expand All @@ -204,7 +240,7 @@
"outputs": [
{
"data": {
"text/plain": "473"
"text/plain": "125"
},
"execution_count": 7,
"metadata": {},
Expand All @@ -217,13 +253,10 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {
"nbsphinx": "hidden",
"tags": [],
"pycharm": {
"is_executing": true
}
"tags": []
},
"outputs": [],
"source": [
Expand Down
18 changes: 18 additions & 0 deletions tests/integration_tests/test_single_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,21 @@ def test_dvc_deps_node(proj_path):
assert DVCDepsNode.load(name="simple_test").dependency_file == pathlib.Path(
"test_traj.txt"
)


class SingleNodeNoInit(Node):
param1 = zn.params()
param2 = zn.params()

result = zn.outs()

def run(self):
self.result = self.param1 + self.param2


def test_auto_init(proj_path):
SingleNodeNoInit(param1=25, param2=42).write_graph(no_exec=False)

assert SingleNodeNoInit.load().param1 == 25
assert SingleNodeNoInit.load().param2 == 42
assert SingleNodeNoInit.load().result == 25 + 42
29 changes: 23 additions & 6 deletions tests/unit_tests/utlis/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,23 @@
import os
import pathlib

import pytest
import znjson

from zntrack.utils import cwd_temp_dir, decode_dict, is_jsonable
from zntrack.utils import utils


def test_is_jsonable():
"""Test for is_jsonable
Test is performed for a serializable dictionary and a non-serializable function.
"""
assert is_jsonable({"a": 1}) is True
assert is_jsonable({"a": is_jsonable}) is False
assert utils.is_jsonable({"a": 1}) is True
assert utils.is_jsonable({"a": utils.is_jsonable}) is False


def test_cwd_temp_dir():
new_dir = cwd_temp_dir(required_files=[__file__])
new_dir = utils.cwd_temp_dir(required_files=[__file__])
assert pathlib.Path(new_dir.name) == pathlib.Path(os.getcwd())
assert next(pathlib.Path(new_dir.name).glob("*.py")).name == "test_utils.py"
os.chdir("..")
Expand All @@ -39,5 +40,21 @@ def test_decode_dict_path():
dict_string = json.dumps(path, cls=znjson.ZnEncoder)
loaded_dict = json.loads(dict_string)
assert loaded_dict == {"_type": "pathlib.Path", "value": "test.txt"}
assert decode_dict(loaded_dict) == path
assert decode_dict(None) is None
assert utils.decode_dict(loaded_dict) == path
assert utils.decode_dict(None) is None


class Test:
pass


def test_get_auto_init():
with pytest.raises(TypeError):
Test(foo="foo")

setattr(Test, "__init__", utils.get_auto_init(fields=["foo", "bar"]))

test = Test(foo="foo", bar="bar")

assert test.foo == "foo"
assert test.bar == "bar"
18 changes: 17 additions & 1 deletion zntrack/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

from zntrack.core.dvcgraph import GraphWriter
from zntrack.utils.config import config
from zntrack.utils.utils import deprecated
from zntrack.utils.utils import deprecated, get_auto_init
from zntrack.zn import params

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -49,6 +50,21 @@ def __call__(self, *args, **kwargs):
"""Still here for a depreciation warning for migrating to class based ZnTrack"""
pass

def __init_subclass__(cls, **kwargs):
"""Add a dataclass-like init if None is provided"""

# User provides an __init__
if cls.__dict__.get("__init__") is not None:
return cls

# attach an automatically generated __init__ if None is provided
zn_option_fiels = []
for name, item in cls.__dict__.items():
if isinstance(item, params):
zn_option_fiels.append(name)

setattr(cls, "__init__", get_auto_init(fields=zn_option_fiels))

def save(self, results: bool = False):
"""Save Class state to files
Expand Down
22 changes: 22 additions & 0 deletions zntrack/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import os
import shutil
import tempfile
import typing

import znjson

Expand Down Expand Up @@ -92,3 +93,24 @@ def wrapper(*args, **kwargs):
def decode_dict(value):
"""Decode dict that was loaded without znjson"""
return json.loads(json.dumps(value), cls=znjson.ZnDecoder)


def get_auto_init(fields: typing.List[str]):
"""Automatically create a __init__ based on fields
Parameters
----------
fields: list[str]
A list of strings that will be used in the __init__, e.g. for [foo, bar]
it will create __init__(self, foo=None, bar=None) using **kwargs
"""

def auto_init(self, **kwargs):
"""Wrapper for the __init__"""
for field in fields:
try:
setattr(self, field, kwargs.pop(field))
except KeyError:
pass
super(type(self), self).__init__(**kwargs)

return auto_init

0 comments on commit c338d88

Please sign in to comment.