diff --git a/examples/docs/01_Intro.ipynb b/examples/docs/01_Intro.ipynb index dfc68c2c..db76324b 100644 --- a/examples/docs/01_Intro.ipynb +++ b/examples/docs/01_Intro.ipynb @@ -71,24 +71,25 @@ "name": "stdout", "output_type": "stream", "text": [ - "Initialized empty Git repository in C:/Users/fabia/AppData/Local/Temp/tmpean3t5ot/.git/\n", - "Initialized DVC repository.\n", - "\n", - "You can now commit the changes to git.\n", - "\n", - "+---------------------------------------------------------------------+\n", - "| |\n", - "| DVC has enabled anonymous aggregate usage analytics. |\n", - "| Read the analytics documentation (and how to opt-out) here: |\n", - "| |\n", - "| |\n", - "+---------------------------------------------------------------------+\n", - "\n", - "What's next?\n", - "------------\n", - "- Check out the documentation: \n", - "- Get help and share ideas: \n", - "- Star us on GitHub: \n" + "Initialized empty Git repository in /tmp/tmptn0i8j7r/.git/\r\n", + "Initialized DVC repository.\r\n", + "\r\n", + "You can now commit the changes to git.\r\n", + "\r\n", + "\u001B[31m+---------------------------------------------------------------------+\r\n", + "\u001B[0m\u001B[31m|\u001B[0m \u001B[31m|\u001B[0m\r\n", + "\u001B[31m|\u001B[0m DVC has enabled anonymous aggregate usage analytics. \u001B[31m|\u001B[0m\r\n", + "\u001B[31m|\u001B[0m Read the analytics documentation (and how to opt-out) here: \u001B[31m|\u001B[0m\r\n", + "\u001B[31m|\u001B[0m <\u001B[36mhttps://dvc.org/doc/user-guide/analytics\u001B[39m> \u001B[31m|\u001B[0m\r\n", + "\u001B[31m|\u001B[0m \u001B[31m|\u001B[0m\r\n", + "\u001B[31m+---------------------------------------------------------------------+\r\n", + "\u001B[0m\r\n", + "\u001B[33mWhat's next?\u001B[39m\r\n", + "\u001B[33m------------\u001B[39m\r\n", + "- Check out the documentation: <\u001B[36mhttps://dvc.org/doc\u001B[39m>\r\n", + "- Get help and share ideas: <\u001B[36mhttps://dvc.org/chat\u001B[39m>\r\n", + "- Star us on GitHub: <\u001B[36mhttps://github.com/iterative/dvc\u001B[39m>\r\n", + "\u001B[0m" ] } ], @@ -125,10 +126,6 @@ " number = zn.outs()\n", " maximum = zn.params()\n", "\n", - " def __init__(self, maximum=None, **kwargs):\n", - " super().__init__(**kwargs)\n", - " self.maximum = maximum\n", - "\n", " def run(self):\n", " self.number = randrange(self.maximum)" ] @@ -137,7 +134,16 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We can now call `run_and_save()` to create our random number" + "Notice that ZnTrack will automatically generate an `__init__` for all `zn.params`.\n", + "When writing a custom `__init__` it is important to add `super().__init__(*args, **kwargs)` for ZnTrack to work.\n", + "```python\n", + " def __init__(self, maximum=None, *args, **kwargs):\n", + " super().__init__(*args, **kwargs)\n", + " self.maximum = maximum\n", + "```\n", + "\n", + "\n", + "We can now call `run_and_save()` to create our random number." ] }, { @@ -167,21 +173,51 @@ "name": "stdout", "output_type": "stream", "text": [ - "2022-01-14 17:16:17,378 (WARNING): Jupyter support is an experimental feature! Please save your notebook before running this command!\n", + "2022-01-21 16:35:09,327 (WARNING): Jupyter support is an experimental feature! Please save your notebook before running this command!\n", "Submit issues to https://github.com/zincware/ZnTrack.\n", - "2022-01-14 17:16:17,380 (WARNING): Converting 01_Intro.ipynb to file RandomNumber.py\n", - "2022-01-14 17:16:19,850 (WARNING): --- Writing new DVC file! ---\n", - "2022-01-14 17:16:19,850 (WARNING): You will not be able to see the stdout/stderr of the process in real time!\n", - "2022-01-14 17:16:22,159 (INFO): Running stage 'RandomNumber':\r\n", - "> python -c \"from src.RandomNumber import RandomNumber; RandomNumber.load(name='RandomNumber').run_and_save()\" \r\n", - "Creating 'dvc.yaml'\r\n", - "Adding stage 'RandomNumber' in 'dvc.yaml'\r\n", - "Generating lock file 'dvc.lock'\r\n", - "Updating lock file 'dvc.lock'\r\n", - "\r\n", - "To track the changes with git, run:\r\n", - "\r\n", - "\tgit add dvc.yaml dvc.lock 'nodes\\RandomNumber\\.gitignore'\r\n", + "2022-01-21 16:35:09,328 (WARNING): Converting 01_Intro.ipynb to file RandomNumber.py\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[NbConvertApp] Converting notebook 01_Intro.ipynb to script\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2022-01-21 16:35:12,397 (WARNING): --- Writing new DVC file! ---\n", + "2022-01-21 16:35:12,398 (WARNING): You will not be able to see the stdout/stderr of the process in real time!\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[NbConvertApp] Writing 3241 bytes to 01_Intro.py\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2022-01-21 16:35:15,953 (INFO): Running stage 'RandomNumber':\n", + "> python3 -c \"from src.RandomNumber import RandomNumber; RandomNumber.load(name='RandomNumber').run_and_save()\" \n", + "Creating 'dvc.yaml'\n", + "Adding stage 'RandomNumber' in 'dvc.yaml'\n", + "Generating lock file 'dvc.lock'\n", + "Updating lock file 'dvc.lock'\n", + "\n", + "To track the changes with git, run:\n", + "\n", + " git add dvc.lock dvc.yaml nodes/RandomNumber/.gitignore\n", + "\n", + "To enable auto staging, run:\n", + "\n", + "\tdvc config core.autostage true\n", "\n" ] } @@ -204,7 +240,7 @@ "outputs": [ { "data": { - "text/plain": "473" + "text/plain": "125" }, "execution_count": 7, "metadata": {}, @@ -217,13 +253,10 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": { "nbsphinx": "hidden", - "tags": [], - "pycharm": { - "is_executing": true - } + "tags": [] }, "outputs": [], "source": [ diff --git a/tests/integration_tests/test_single_node.py b/tests/integration_tests/test_single_node.py index c45900e3..b1f00e2c 100644 --- a/tests/integration_tests/test_single_node.py +++ b/tests/integration_tests/test_single_node.py @@ -259,3 +259,21 @@ def test_dvc_deps_node(proj_path): assert DVCDepsNode.load(name="simple_test").dependency_file == pathlib.Path( "test_traj.txt" ) + + +class SingleNodeNoInit(Node): + param1 = zn.params() + param2 = zn.params() + + result = zn.outs() + + def run(self): + self.result = self.param1 + self.param2 + + +def test_auto_init(proj_path): + SingleNodeNoInit(param1=25, param2=42).write_graph(no_exec=False) + + assert SingleNodeNoInit.load().param1 == 25 + assert SingleNodeNoInit.load().param2 == 42 + assert SingleNodeNoInit.load().result == 25 + 42 diff --git a/tests/unit_tests/utlis/test_utils.py b/tests/unit_tests/utlis/test_utils.py index 5b54ca45..541fd7e9 100644 --- a/tests/unit_tests/utlis/test_utils.py +++ b/tests/unit_tests/utlis/test_utils.py @@ -12,9 +12,10 @@ import os import pathlib +import pytest import znjson -from zntrack.utils import cwd_temp_dir, decode_dict, is_jsonable +from zntrack.utils import utils def test_is_jsonable(): @@ -22,12 +23,12 @@ def test_is_jsonable(): Test is performed for a serializable dictionary and a non-serializable function. """ - assert is_jsonable({"a": 1}) is True - assert is_jsonable({"a": is_jsonable}) is False + assert utils.is_jsonable({"a": 1}) is True + assert utils.is_jsonable({"a": utils.is_jsonable}) is False def test_cwd_temp_dir(): - new_dir = cwd_temp_dir(required_files=[__file__]) + new_dir = utils.cwd_temp_dir(required_files=[__file__]) assert pathlib.Path(new_dir.name) == pathlib.Path(os.getcwd()) assert next(pathlib.Path(new_dir.name).glob("*.py")).name == "test_utils.py" os.chdir("..") @@ -39,5 +40,21 @@ def test_decode_dict_path(): dict_string = json.dumps(path, cls=znjson.ZnEncoder) loaded_dict = json.loads(dict_string) assert loaded_dict == {"_type": "pathlib.Path", "value": "test.txt"} - assert decode_dict(loaded_dict) == path - assert decode_dict(None) is None + assert utils.decode_dict(loaded_dict) == path + assert utils.decode_dict(None) is None + + +class Test: + pass + + +def test_get_auto_init(): + with pytest.raises(TypeError): + Test(foo="foo") + + setattr(Test, "__init__", utils.get_auto_init(fields=["foo", "bar"])) + + test = Test(foo="foo", bar="bar") + + assert test.foo == "foo" + assert test.bar == "bar" diff --git a/zntrack/core/base.py b/zntrack/core/base.py index d1e2a059..11eabc59 100644 --- a/zntrack/core/base.py +++ b/zntrack/core/base.py @@ -14,7 +14,8 @@ from zntrack.core.dvcgraph import GraphWriter from zntrack.utils.config import config -from zntrack.utils.utils import deprecated +from zntrack.utils.utils import deprecated, get_auto_init +from zntrack.zn import params log = logging.getLogger(__name__) @@ -49,6 +50,21 @@ def __call__(self, *args, **kwargs): """Still here for a depreciation warning for migrating to class based ZnTrack""" pass + def __init_subclass__(cls, **kwargs): + """Add a dataclass-like init if None is provided""" + + # User provides an __init__ + if cls.__dict__.get("__init__") is not None: + return cls + + # attach an automatically generated __init__ if None is provided + zn_option_fiels = [] + for name, item in cls.__dict__.items(): + if isinstance(item, params): + zn_option_fiels.append(name) + + setattr(cls, "__init__", get_auto_init(fields=zn_option_fiels)) + def save(self, results: bool = False): """Save Class state to files diff --git a/zntrack/utils/utils.py b/zntrack/utils/utils.py index 31772b7b..7715844e 100644 --- a/zntrack/utils/utils.py +++ b/zntrack/utils/utils.py @@ -14,6 +14,7 @@ import os import shutil import tempfile +import typing import znjson @@ -92,3 +93,24 @@ def wrapper(*args, **kwargs): def decode_dict(value): """Decode dict that was loaded without znjson""" return json.loads(json.dumps(value), cls=znjson.ZnDecoder) + + +def get_auto_init(fields: typing.List[str]): + """Automatically create a __init__ based on fields + Parameters + ---------- + fields: list[str] + A list of strings that will be used in the __init__, e.g. for [foo, bar] + it will create __init__(self, foo=None, bar=None) using **kwargs + """ + + def auto_init(self, **kwargs): + """Wrapper for the __init__""" + for field in fields: + try: + setattr(self, field, kwargs.pop(field)) + except KeyError: + pass + super(type(self), self).__init__(**kwargs) + + return auto_init