diff --git a/examples/docs/02_Inheritance.ipynb b/examples/docs/02_Inheritance.ipynb index 54c6b59d..966670c5 100644 --- a/examples/docs/02_Inheritance.ipynb +++ b/examples/docs/02_Inheritance.ipynb @@ -2,63 +2,73 @@ "cells": [ { "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "# Inheritance\n", "\n", "ZnTrack allows inheritance from a Node base class.\n", "This can e.g. be useful if you want to test out different methods of the same kind.\n", "In the following example, we will show this by using different functions in the run method with the same inputs and outputs." - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } - } + ] }, { "cell_type": "code", "execution_count": 1, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "from zntrack import config\n", "\n", "config.nb_name = \"02_Inheritance.ipynb\"" - ], + ] + }, + { + "cell_type": "code", + "execution_count": 2, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } - } - }, - { - "cell_type": "code", - "execution_count": 2, + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Initialized empty Git repository in /tmp/tmp8jpf744o/.git/\r\n", - "Initialized DVC repository.\r\n", - "\r\n", - "You can now commit the changes to git.\r\n", - "\r\n", - "\u001B[31m+---------------------------------------------------------------------+\r\n", - "\u001B[0m\u001B[31m|\u001B[0m \u001B[31m|\u001B[0m\r\n", - "\u001B[31m|\u001B[0m DVC has enabled anonymous aggregate usage analytics. \u001B[31m|\u001B[0m\r\n", - "\u001B[31m|\u001B[0m Read the analytics documentation (and how to opt-out) here: \u001B[31m|\u001B[0m\r\n", - "\u001B[31m|\u001B[0m <\u001B[36mhttps://dvc.org/doc/user-guide/analytics\u001B[39m> \u001B[31m|\u001B[0m\r\n", - "\u001B[31m|\u001B[0m \u001B[31m|\u001B[0m\r\n", - "\u001B[31m+---------------------------------------------------------------------+\r\n", - "\u001B[0m\r\n", - "\u001B[33mWhat's next?\u001B[39m\r\n", - "\u001B[33m------------\u001B[39m\r\n", - "- Check out the documentation: <\u001B[36mhttps://dvc.org/doc\u001B[39m>\r\n", - "- Get help and share ideas: <\u001B[36mhttps://dvc.org/chat\u001B[39m>\r\n", - "- Star us on GitHub: <\u001B[36mhttps://github.com/iterative/dvc\u001B[39m>\r\n", - "\u001B[0m" + "Initialized empty Git repository in C:/Users/fabia/AppData/Local/Temp/tmpc5a1k84s/.git/\n", + "Initialized DVC repository.\n", + "\n", + "You can now commit the changes to git.\n", + "\n", + "+---------------------------------------------------------------------+\n", + "| |\n", + "| DVC has enabled anonymous aggregate usage analytics. |\n", + "| Read the analytics documentation (and how to opt-out) here: |\n", + "| |\n", + "| |\n", + "+---------------------------------------------------------------------+\n", + "\n", + "What's next?\n", + "------------\n", + "- Check out the documentation: \n", + "- Get help and share ideas: \n", + "- Star us on GitHub: \n" ] } ], @@ -69,31 +79,37 @@ "\n", "!git init\n", "!dvc init" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } + ] }, { "cell_type": "code", "execution_count": 3, - "outputs": [], - "source": [ - "from zntrack import Node, zn" - ], "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } - } + }, + "outputs": [], + "source": [ + "from zntrack import Node, zn" + ] }, { "cell_type": "code", "execution_count": 4, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "class NodeBase(Node):\n", @@ -101,17 +117,20 @@ "\n", " inputs: float = zn.params()\n", " output: float = zn.outs()" - ], + ] + }, + { + "cell_type": "code", + "execution_count": 5, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } - } - }, - { - "cell_type": "code", - "execution_count": 5, + }, "outputs": [], "source": [ "class AddNumber(NodeBase):\n", @@ -127,41 +146,44 @@ "\n", " def run(self):\n", " self.output = self.inputs * self.factor" - ], + ] + }, + { + "cell_type": "code", + "execution_count": 6, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } - } - }, - { - "cell_type": "code", - "execution_count": 6, + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "2022-05-17 09:23:32,127 (WARNING): Jupyter support is an experimental feature! Please save your notebook before running this command!\n", + "2022-09-16 16:16:46,480 (WARNING): Jupyter support is an experimental feature! Please save your notebook before running this command!\n", "Submit issues to https://github.com/zincware/ZnTrack.\n", - "2022-05-17 09:23:36,025 (WARNING): Running DVC command: 'dvc run -n basic_number ...'\n" + "2022-09-16 16:16:52,075 (WARNING): Running DVC command: 'dvc stage add -n basic_number ...'\n", + "2022-09-16 16:16:54,439 (WARNING): Running DVC command: 'dvc repro basic_number'\n" ] } ], "source": [ "add_number = AddNumber(inputs=10.0, offset=15.0)\n", "add_number.write_graph(run=True)" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } + ] }, { "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "Because the Nodes inherit from each other and we defined the `node_name` in the parent class, we can use all classes to load the outputs (as long as they are shared).\n", "This is important to keep in mind when working with inheritance, that the output might not necessarily be created by the Node it was loaded by.\n", @@ -169,21 +191,26 @@ "A subsequent Node can e.g. depend on the parent Node and does not need to know where the values actually come from.\n", "I.e. an ML Model might implement a predict function in the parent node but can have an entirely different structure.\n", "An evaluation node might only need the predict method and can therefore be used with all children of the model class." - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } - } + ] }, { "cell_type": "code", "execution_count": 7, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [ { "data": { - "text/plain": "25.0" + "text/plain": [ + "25.0" + ] }, "execution_count": 7, "metadata": {}, @@ -192,69 +219,80 @@ ], "source": [ "NodeBase.load().output" - ], + ] + }, + { + "cell_type": "code", + "execution_count": 8, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } - } - }, - { - "cell_type": "code", - "execution_count": 8, + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "+--------------+ \r\n", - "| basic_number | \r\n", - "+--------------+ \r\n", - "\u001B[0m" + "+--------------+ \n", + "| basic_number | \n", + "+--------------+ \n" ] } ], "source": [ "!dvc dag" - ], + ] + }, + { + "cell_type": "code", + "execution_count": 9, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } - } - }, - { - "cell_type": "code", - "execution_count": 9, + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "2022-05-17 09:23:46,032 (WARNING): Running DVC command: 'dvc run -n basic_number ...'\n" + "2022-09-16 16:17:04,047 (WARNING): Running DVC command: 'dvc stage add -n basic_number ...'\n", + "2022-09-16 16:17:05,813 (WARNING): Running DVC command: 'dvc repro basic_number'\n" ] } ], "source": [ "multiply_number = MultiplyNumber(inputs=6.0, factor=6.0)\n", "multiply_number.write_graph(run=True)" - ], + ] + }, + { + "cell_type": "code", + "execution_count": 10, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } - } - }, - { - "cell_type": "code", - "execution_count": 10, + }, "outputs": [ { "data": { - "text/plain": "36.0" + "text/plain": [ + "36.0" + ] }, "execution_count": 10, "metadata": {}, @@ -263,53 +301,53 @@ ], "source": [ "NodeBase.load().output" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } + ] }, { "cell_type": "markdown", - "source": [ - "As expected the node name remains the same and therefore, the Node is replaced with the new one." - ], "metadata": { - "collapsed": false, "pycharm": { "name": "#%% md\n" } - } + }, + "source": [ + "As expected the node name remains the same and therefore, the Node is replaced with the new one." + ] }, { "cell_type": "code", "execution_count": 11, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "+--------------+ \r\n", - "| basic_number | \r\n", - "+--------------+ \r\n", - "\u001B[0m" + "+--------------+ \n", + "| basic_number | \n", + "+--------------+ \n" ] } ], "source": [ "!dvc dag" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } + ] }, { "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "## Nodes as parameters\n", "\n", @@ -322,17 +360,20 @@ "To keep the DAG working, a `_hash = zn.Hash()` is introduced.\n", "This value is computed from the parameters as well as the current timestamp and only serves as a file dependency for DVC.\n", "Adding `zn.Hash()` to any Node will add an output file but won't have any additional effect." - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } - } + ] }, { "cell_type": "code", "execution_count": 12, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "class DivideNumber(NodeBase):\n", @@ -362,77 +403,94 @@ " self.output = self.value_handler.output\n", " # polynomials\n", " self.output = self.polynomial.a0 + self.polynomial.a1 * self.output" - ], + ] + }, + { + "cell_type": "code", + "execution_count": 13, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } - } - }, - { - "cell_type": "code", - "execution_count": 13, + }, "outputs": [], "source": [ "manipulate_number = ManipulateNumber(\n", " inputs=10.0,\n", - " value_handler=DivideNumber(divider=3.0),\n", + " value_handler=DivideNumber(divider=3.0, inputs=None),\n", " polynomial=Polynomial(a0=60.0, a1=10.0),\n", ")" - ], + ] + }, + { + "cell_type": "code", + "execution_count": 14, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } - } - }, - { - "cell_type": "code", - "execution_count": 14, + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "2022-05-17 09:23:55,467 (WARNING): Running DVC command: 'dvc run -n ManipulateNumber-polynomial ...'\n", - "2022-05-17 09:24:02,692 (WARNING): Running DVC command: 'dvc run -n ManipulateNumber-value_handler ...'\n", - "2022-05-17 09:24:09,679 (WARNING): Running DVC command: 'dvc run -n ManipulateNumber ...'\n" + "2022-09-16 16:17:16,981 (WARNING): Running DVC command: 'dvc stage add -n ManipulateNumber-polynomial ...'\n", + "2022-09-16 16:17:18,814 (WARNING): Running DVC command: 'dvc repro ManipulateNumber-polynomial'\n", + "2022-09-16 16:17:26,646 (WARNING): Running DVC command: 'dvc stage add -n ManipulateNumber-value_handler ...'\n", + "2022-09-16 16:17:28,403 (WARNING): Running DVC command: 'dvc repro ManipulateNumber-value_handler'\n", + "2022-09-16 16:17:36,843 (WARNING): Running DVC command: 'dvc stage add -n ManipulateNumber ...'\n", + "2022-09-16 16:17:38,701 (WARNING): Running DVC command: 'dvc repro ManipulateNumber'\n" ] } ], "source": [ "manipulate_number.write_graph(run=True)" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } + ] }, { "cell_type": "code", "execution_count": 15, - "outputs": [], - "source": [ - "manipulate_number = manipulate_number.load()" - ], "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } - } + }, + "outputs": [], + "source": [ + "manipulate_number = manipulate_number.load()" + ] }, { "cell_type": "code", "execution_count": 16, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [ { "data": { - "text/plain": "360.0" + "text/plain": [ + "360.0" + ] }, "execution_count": 16, "metadata": {}, @@ -441,82 +499,81 @@ ], "source": [ "manipulate_number.output" - ], + ] + }, + { + "cell_type": "code", + "execution_count": 17, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } - } - }, - { - "cell_type": "code", - "execution_count": 17, + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "+--------------+ \r\n", - "| basic_number | \r\n", - "+--------------+ \r\n", - "+-----------------------------+ +--------------------------------+\r\n", - "| ManipulateNumber-polynomial | | ManipulateNumber-value_handler |\r\n", - "+-----------------------------+ +--------------------------------+\r\n", - " **** ***** \r\n", - " **** **** \r\n", - " *** *** \r\n", - " +------------------+ \r\n", - " | ManipulateNumber | \r\n", - " +------------------+ \r\n", - "\u001B[0m" + "+--------------+ \n", + "| basic_number | \n", + "+--------------+ \n", + "+-----------------------------+ +--------------------------------+\n", + "| ManipulateNumber-polynomial | | ManipulateNumber-value_handler |\n", + "+-----------------------------+ +--------------------------------+\n", + " **** ***** \n", + " **** **** \n", + " *** *** \n", + " +------------------+ \n", + " | ManipulateNumber | \n", + " +------------------+ \n" ] } ], "source": [ "!dvc dag" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } + ] }, { "cell_type": "code", - "execution_count": 18, - "outputs": [], - "source": [ - "temp_dir.cleanup()" - ], + "execution_count": null, "metadata": { "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, "pycharm": { "name": "#%%\n" } - } + }, + "outputs": [], + "source": [ + "temp_dir.cleanup()" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", - "version": 2 + "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" + "pygments_lexer": "ipython3", + "version": "3.9.13" } }, "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file + "nbformat_minor": 4 +} diff --git a/tests/integration_tests/test_inheritance.py b/tests/integration_tests/test_inheritance.py index cb50b03f..591daac9 100644 --- a/tests/integration_tests/test_inheritance.py +++ b/tests/integration_tests/test_inheritance.py @@ -15,8 +15,8 @@ def run(self): class WriteDataWithInit(InOuts): def __init__(self, inputs=None, **kwargs): - super().__init__(**kwargs) - self.inputs = inputs + super().__init__(inputs=inputs, **kwargs) + # this calls the auto_init of the subclass which demands the inputs argument! def run(self): self.outputs = self.inputs diff --git a/tests/integration_tests/test_zn_nodes2.py b/tests/integration_tests/test_zn_nodes2.py index ef72bedb..c4ffeae6 100644 --- a/tests/integration_tests/test_zn_nodes2.py +++ b/tests/integration_tests/test_zn_nodes2.py @@ -49,7 +49,7 @@ def run(self): def test_SingleExampleNode(proj_path): - SingleExampleNode().write_graph(run=True) + SingleExampleNode(params1=None).write_graph(run=True) assert SingleExampleNode.load().outs == "Lorem Ipsum" @@ -79,7 +79,7 @@ def test_depth_graph(proj_path): node_3 = NodeNodeParams(deps=node_1, node=node_2, name="Node3") - node_4 = ExampleNode2(params1=node_3) + node_4 = ExampleNode2(params1=node_3, params2=None) node_4.write_graph(run=True) @@ -104,7 +104,7 @@ def run(self): def test_NodeWithOuts(proj_path): - node_1 = SingleExampleNode(params1=NodeWithOuts(factor=2)) + node_1 = SingleExampleNode(params1=NodeWithOuts(factor=2, input=None)) node_1.write_graph(run=True) assert SingleExampleNode.load().params1.factor == 2 diff --git a/tests/unit_tests/core/test_core_base.py b/tests/unit_tests/core/test_core_base.py index 7a2572db..2bf68b7f 100644 --- a/tests/unit_tests/core/test_core_base.py +++ b/tests/unit_tests/core/test_core_base.py @@ -135,17 +135,17 @@ def test_load(): default_correct_node = CorrectNode.load() assert default_correct_node.node_name == CorrectNode.__name__ - default_incorrect_node = InCorrectNode.load() - assert default_incorrect_node.node_name == InCorrectNode.__name__ + with pytest.raises(TypeError): + # can not load a Node that misses a correct super().__init__(**kwargs) + _ = InCorrectNode.load() + + with pytest.raises(TypeError): + _ = InCorrectNode["Test"] correct_node = CorrectNode.load(name="Test") assert correct_node.node_name == "Test" assert correct_node.test_name == correct_node.node_name - incorrect_node = InCorrectNode.load(name="Test") - assert incorrect_node.node_name == "Test" - assert incorrect_node.test_name != incorrect_node.node_name - class RunTestNode(Node): outs = zn.outs() @@ -231,17 +231,12 @@ class CollectionChild(ZnTrackOptionCollection): @pytest.mark.parametrize("cls", (ZnTrackOptionCollection, CollectionChild)) def test_get_auto_init_signature(cls): - zn_option_names, signature_params = get_auto_init_signature(cls) - - assert zn_option_names == ["out1", "out2", "out3", "param1", "param2", "param3"] + kwargs_no_default, kwargs_with_default, signature_params = get_auto_init_signature( + cls + ) - # assert signature_params[0].name == "out1" - # - # assert signature_params[3].name == "param1" - # assert signature_params[3].annotation == dict - # - # assert signature_params[5].name == "param3" - # assert signature_params[5].annotation is None + assert kwargs_no_default == ["out1", "out2", "out3", "param1", "param3"] + assert kwargs_with_default == {"param2": [1, 2]} class NodeMock(metaclass=LoadViaGetItem): diff --git a/tests/unit_tests/utils/test_utils.py b/tests/unit_tests/utils/test_utils.py index 107986a7..4a6ada74 100644 --- a/tests/unit_tests/utils/test_utils.py +++ b/tests/unit_tests/utils/test_utils.py @@ -27,41 +27,88 @@ def test_decode_dict_path(): assert utils.decode_dict(None) is None -class Test: +class EmptyCls: pass -class TestWithPostInit: +class ClsWithPostInit: def post_init(self): self.post_init = True self.text = f"{self.foo} {self.bar}" def test_get_auto_init(): + _ = EmptyCls() + with pytest.raises(TypeError): - Test(foo="foo") + # has no init + EmptyCls(foo="foo") + + def set_init(lst, dct): + mock = MagicMock() + setattr( + EmptyCls, + "__init__", + utils.get_auto_init( + kwargs_no_default=lst, kwargs_with_default=dct, super_init=mock + ), + ) + return mock + + # only none-default values + mock = set_init(["foo", "bar"], {}) - mock = MagicMock() - setattr(Test, "__init__", utils.get_auto_init(fields=["foo", "bar"], super_init=mock)) - test = Test(foo="foo", bar="bar") + with pytest.raises(TypeError): + # type error after setting the init + _ = EmptyCls() + + test = EmptyCls(foo="foo", bar="bar") + assert test.foo == "foo" + assert test.bar == "bar" + mock.assert_called() + + # only default values + mock = set_init([], {"foo": None, "bar": 10}) + test = EmptyCls() + assert test.foo is None + assert test.bar == 10 + mock.assert_called() + test = EmptyCls(foo="foo", bar="bar") assert test.foo == "foo" assert test.bar == "bar" + # mixed case + mock = set_init(["foo"], {"bar": 10}) + with pytest.raises(TypeError): + _ = EmptyCls() + + with pytest.raises(TypeError): + _ = EmptyCls(bar=20) + + test = EmptyCls(foo="foo") + assert test.foo == "foo" + assert test.bar == 10 + + test = EmptyCls(foo="foo", bar="bar") + assert test.foo == "foo" + assert test.bar == "bar" mock.assert_called() def test_get_post_init(): with pytest.raises(TypeError): - TestWithPostInit(foo="foo") + ClsWithPostInit(foo="foo") mock = MagicMock() setattr( - TestWithPostInit, + ClsWithPostInit, "__init__", - utils.get_auto_init(fields=["foo", "bar"], super_init=mock), + utils.get_auto_init( + kwargs_no_default=["foo", "bar"], kwargs_with_default={}, super_init=mock + ), ) - test = TestWithPostInit(foo="foo", bar="bar") + test = ClsWithPostInit(foo="foo", bar="bar") assert test.foo == "foo" assert test.bar == "bar" diff --git a/tests/unit_tests/zn/test_zn_nodes.py b/tests/unit_tests/zn/test_zn_nodes.py index 31920e04..c66a63f1 100644 --- a/tests/unit_tests/zn/test_zn_nodes.py +++ b/tests/unit_tests/zn/test_zn_nodes.py @@ -43,4 +43,3 @@ def test_require_hash(): _ = ExampleNode(example=(ParamsNodeWithHash(),)) _ = ExampleNode(example=None) # allow None type - _ = ExampleNode() diff --git a/zntrack/core/base.py b/zntrack/core/base.py index 1ef93985..8a3fe140 100644 --- a/zntrack/core/base.py +++ b/zntrack/core/base.py @@ -13,10 +13,20 @@ log = logging.getLogger(__name__) -def get_auto_init_signature(cls) -> (list, list): +def get_auto_init_signature(cls) -> (list, dict, list): """Iterate over ZnTrackOptions in the __dict__ and save the option name - and create a signature Parameter""" - zn_option_names, signature_params = [], [] + and create a signature Parameter + + Returns: + kwargs_no_default: list + a list of names that will be converted to kwargs + kwargs_with_default: dict + a dict of {name: default_value} that will be converted to kwargs + signature_params: inspect.Parameter + """ + signature_params = [] + kwargs_no_default = [] + kwargs_with_default = {} _ = cls.__annotations__ # fix for https://bugs.python.org/issue46930 descriptors = get_descriptors(ZnTrackOption, cls=cls) for descriptor in descriptors: @@ -24,17 +34,21 @@ def get_auto_init_signature(cls) -> (list, list): # exclude zn.outs / metrics / plots / ... options continue # For the new __init__ - zn_option_names.append(descriptor.name) + if descriptor.default_value is None: + kwargs_no_default.append(descriptor.name) + else: + kwargs_with_default[descriptor.name] = descriptor.default_value # For the new __signature__ signature_params.append( inspect.Parameter( + # default=... name=descriptor.name, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=cls.__annotations__.get(descriptor.name), ) ) - return zn_option_names, signature_params + return kwargs_no_default, kwargs_with_default, signature_params def update_dependency_options(value): @@ -176,13 +190,19 @@ def __init_subclass__(cls, **kwargs): return cls # attach an automatically generated __init__ if None is provided - zn_option_names, signature_params = get_auto_init_signature(cls) + ( + kwargs_no_default, + kwargs_with_default, + signature_params, + ) = get_auto_init_signature(cls) # Add new __init__ to the subclass setattr( cls, "__init__", - utils.get_auto_init(fields=zn_option_names, super_init=Node.__init__), + utils.get_auto_init( + kwargs_no_default, kwargs_with_default, super_init=Node.__init__ + ), ) # Add new __signature__ to the subclass @@ -295,24 +315,27 @@ def __init__(self, **kwargs): try: instance = cls(name=name, is_loaded=True) except TypeError as type_error: - try: - instance = cls() - if name not in (None, cls.__name__): - instance.node_name = name - log.warning( - "Can not pass to the super.__init__ and trying workaround!" - " This can lead to unexpected behaviour and can be avoided by" - " passing ( **kwargs) to the super().__init__(**kwargs) - Received" - f" '{type_error}'" - ) - except TypeError as err: + if getattr(cls.__init__, "_uses_auto_init", False): + # using new + init from Node class to circumvent required + # arguments in the automatic init + instance = object.__new__(cls) + Node.__init__(instance, name=name, is_loaded=True) + else: + # when not using the automatic init all arguments must have a + # default value and the super call is required. It would still + # be raise TypeError( f"Unable to create a new instance of {cls}. Check that all arguments" " default to None. It must be possible to instantiate the class via" - f" {cls}() without passing any arguments. See the ZnTrack" - " documentation for more information." - ) from err - + f" {cls}() without passing any arguments. Furthermore, the" + " '**kwargs' must be passed to the 'super().__init__(**kwargs)'" + "See the ZnTrack documentation for more information." + ) from type_error + + assert instance.node_name is not None, ( + "The name of the Node is not set. Probably missing" + " 'super().__init__(**kwargs)' inside the custom '__init__'." + ) instance._update_options(lazy=lazy) if utils.config.nb_name is not None: diff --git a/zntrack/core/zntrackoption.py b/zntrack/core/zntrackoption.py index 7037c7a6..1c3c1503 100644 --- a/zntrack/core/zntrackoption.py +++ b/zntrack/core/zntrackoption.py @@ -194,7 +194,11 @@ def mkdir(self, instance): file = self.get_filename(instance) file.parent.mkdir(exist_ok=True, parents=True) - def _raise_loading_errors(self, instance, err): + def _get_loading_errors( + self, instance + ) -> typing.Union[ + utils.exceptions.DataNotAvailableError, utils.exceptions.GraphNotAvailableError + ]: """Raise specific errors when reading ZnTrackOptions Raises @@ -205,13 +209,13 @@ def _raise_loading_errors(self, instance, err): if the graph does not exist in dvc.yaml. This has higher priority """ if instance._graph_entry_exists: - raise utils.exceptions.DataNotAvailableError( + return utils.exceptions.DataNotAvailableError( f"Could not load data for '{self.name}' from file." - ) from err - raise utils.exceptions.GraphNotAvailableError( + ) + return utils.exceptions.GraphNotAvailableError( f"Could not find the graph configuration for '{instance.node_name}' in" f" {utils.Files.dvc}." - ) from err + ) def get_data_from_files(self, instance): """Load the value/s for the given instance from the file/s @@ -232,7 +236,7 @@ def get_data_from_files(self, instance): try: file_content = utils.file_io.read_file(file) except FileNotFoundError as err: - self._raise_loading_errors(instance, err) + raise self._get_loading_errors(instance) from err # The problem here is, that I can not / don't want to load all Nodes but # only the ones, that are in [self.node_name][self.name] for deserializing try: @@ -241,6 +245,6 @@ def get_data_from_files(self, instance): else: values = utils.decode_dict(file_content[self.name]) except KeyError as err: - self._raise_loading_errors(instance, err) + raise self._get_loading_errors(instance) from err log.debug(f"Loading {instance.node_name} from {file}: ({values})") return values diff --git a/zntrack/utils/utils.py b/zntrack/utils/utils.py index 5901befc..ded39d3f 100644 --- a/zntrack/utils/utils.py +++ b/zntrack/utils/utils.py @@ -72,24 +72,72 @@ def encode_dict(value) -> dict: return json.loads(json.dumps(value, cls=znjson.ZnEncoder)) -def get_auto_init(fields: typing.List[str], super_init: typing.Callable): +def get_init_type_error(required_keys, uses_super: bool = False) -> TypeError: + """Get a TypeError similar to a wrong __init__""" + if len(required_keys) == 1: + if uses_super: + return TypeError( + f"__init__() missing 1 required positional argument: '{required_keys[0]}'" + ) + return TypeError( + "super().__init__() missing 1 required positional argument:" + f" '{required_keys[0]}'" + ) + if len(required_keys) > 1: + if uses_super: + return TypeError( + f"__init__() missing {len(required_keys)} required positional arguments:" + f""" '{"', '".join(required_keys[:-1])}' and '{required_keys[-1]}'""" + ) + return TypeError( + f"super().__init__() missing {len(required_keys)} required positional" + " arguments:" + f""" '{"', '".join(required_keys[:-1])}' and '{required_keys[-1]}'""" + ) + + +def get_auto_init( + kwargs_no_default: typing.List[str], + kwargs_with_default: dict, + super_init: typing.Callable, +): """Automatically create an __init__ based on fields Parameters ---------- - fields: list[str] - A list of strings that will be used in the __init__, e.g. for [foo, bar] - it will create __init__(self, foo=None, bar=None) using **kwargs + kwargs_no_default: list[str] + A list that strings (required kwarg without default value) that will be used in + the __init__, e.g. for [foo, bar] will create __init__(self, foo, bar) + kwargs_with_default: dict[str, any] + A dict that contains the name of the kwarg as key and the default value + (kwargs with default value) that will be used in + the __init__, e.g. for {foo: None, bar: 10} will create + __init__(self, foo=None, bar=10) super_init: Callable typically this is Node.__init__ """ + kwargs_no_default = [] if kwargs_no_default is None else kwargs_no_default + kwargs_with_default = {} if kwargs_with_default is None else kwargs_with_default + def auto_init(self, **kwargs): """Wrapper for the __init__""" init_kwargs = {} - for field in fields: + required_keys = [] + self_uses_auto_init = getattr(self.__init__, "_uses_auto_init", False) + log.debug(f"The '__init__' uses auto_init: {self_uses_auto_init}") + for kwarg_name in kwargs_no_default: + try: + init_kwargs[kwarg_name] = kwargs.pop(kwarg_name) + except KeyError: + required_keys.append(kwarg_name) + + if len(required_keys) > 0: + raise get_init_type_error(required_keys, self_uses_auto_init) + + for kwarg_name, kwarg_value in kwargs_with_default.items(): try: - init_kwargs[field] = kwargs.pop(field) + init_kwargs[kwarg_name] = kwargs.pop(kwarg_name, kwarg_value) except KeyError: pass super_init(self, **kwargs) # call the super_init explicitly instead of super