Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Inheritance Issue #267

Merged
merged 22 commits into from
Mar 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 154 additions & 49 deletions examples/docs/02_PassingClasses.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -54,41 +54,147 @@
},
{
"cell_type": "markdown",
"id": "3d4248c9-df8d-477f-bef4-fee69324ec60",
"source": [
"## Inheriting from Base-Nodes\n",
"\n",
"The next part of the documentation will show how you can pass a Python class to a Node to enable different methods.\n",
"Whilst this can be very useful it is often easier to create a Base-Node and define custom methods as subclass of this Base.\n"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 4,
"id": "183551fb-2f96-49de-8276-bf4ea5afc53d",
"metadata": {},
"outputs": [],
"source": [
"from zntrack import Node, zn"
]
},
{
"cell_type": "code",
"execution_count": 5,
"outputs": [],
"source": [
"class NumberManipulationBase(Node):\n",
" node_name = \"NumberManipulationBase\"\n",
" # define the node_name for all child classes. Otherwise, child classes can coexist.\n",
" input_number = zn.params()\n",
" output_number = zn.outs()\n",
"\n",
"\n",
"class MultiplyNumber(NumberManipulationBase):\n",
" factor = zn.params()\n",
"\n",
" def run(self):\n",
" self.output_number = self.input_number * self.factor\n",
"\n",
"\n",
"class DivideNumber(NumberManipulationBase):\n",
" divider = zn.params()\n",
"\n",
" def run(self):\n",
" self.output_number = self.input_number / self.divider"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 6,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2022-03-26 18:47:08,865 (WARNING): Jupyter support is an experimental feature! Please save your notebook before running this command!\n",
"Submit issues to https://github.com/zincware/ZnTrack.\n",
"2022-03-26 18:47:11,353 (WARNING): Running DVC command: 'dvc run -n NumberManipulationBase ...'\n",
"30\n",
"2022-03-26 18:47:15,765 (WARNING): Running DVC command: 'dvc run -n NumberManipulationBase ...'\n",
"5.0\n"
]
}
],
"source": [
"MultiplyNumber(input_number=10, factor=3).write_graph(run=True)\n",
"print(MultiplyNumber.load().output_number)\n",
"\n",
"DivideNumber(input_number=10, divider=2).write_graph(run=True)\n",
"print(DivideNumber.load().output_number)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"Due to lazy-loading you might be able to access the output of `DivideNumber` also through `NumberManipulationBase` and `MultiplyNumber`.\n",
"This is only possible for shared ZnTrackOptions between the Nodes.\n",
"If you try to access e.g. the `factor` you will get an Error because `factor` is not an attribute of `DivideNumber`."
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## Creating Operations\n",
"\n",
"Best practice for adding different custom operations or methods is to inherit from a common parent with a method that does the computation."
]
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 4,
"id": "dd6ab631-01e4-4912-959e-006fad21e8b1",
"metadata": {},
"execution_count": 7,
"outputs": [],
"source": [
"class Base:\n",
" def compute(self, inp):\n",
" raise NotImplementedError"
]
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"id": "23828631-62dd-44a7-b91b-2cc0a1cf0a87",
"metadata": {},
"source": [
"For simplicity reasons we will look at some very simple functions but they can be of arbitrary complexity.\n",
"We apply the `check_signature` decorator which is an optional check that the tests that the keyword arguments are identical to the class attribute names.\n",
"This is mandatory for ZnTrack to work in the anticipated way."
]
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 5,
"id": "de1daf1f-f82e-4731-b1cb-19d0149c2916",
"metadata": {},
"execution_count": 8,
"outputs": [],
"source": [
"from zntrack.utils.decorators import check_signature\n",
Expand All @@ -110,29 +216,26 @@
"\n",
" def compute(self, inp):\n",
" return inp * self.factor"
]
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"id": "de03e1ef-635b-40e8-bf26-630025f588d9",
"metadata": {},
"source": [
"The actual Node makes use of the typical ZnTrack functionality beeing extended by `zn.Method()`."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "183551fb-2f96-49de-8276-bf4ea5afc53d",
"metadata": {},
"outputs": [],
"source": [
"from zntrack import Node, zn"
]
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 9,
"id": "38d22553-641c-42b4-a99b-9d27de2ac41d",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -162,17 +265,15 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 10,
"id": "35205014-bc4a-449e-8b48-e1936e35d985",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2022-02-22 13:36:04,657 (WARNING): Jupyter support is an experimental feature! Please save your notebook before running this command!\n",
"Submit issues to https://github.com/zincware/ZnTrack.\n",
"2022-02-22 13:36:08,305 (WARNING): Running DVC command: 'dvc run -n Calculator ...'\n"
"2022-03-26 18:47:20,381 (WARNING): Running DVC command: 'dvc run -n Calculator ...'\n"
]
}
],
Expand All @@ -190,15 +291,15 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 11,
"id": "9065e7e7-2d3a-4e39-8126-36a426eb0c2b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": "15"
},
"execution_count": 9,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -217,15 +318,15 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 12,
"id": "5986d32f-c6be-4d67-9b8d-e0ef6dd0bef8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2022-02-22 13:36:15,534 (WARNING): Running DVC command: 'dvc run -n Calculator ...'\n"
"2022-03-26 18:47:25,126 (WARNING): Running DVC command: 'dvc run -n Calculator ...'\n"
]
}
],
Expand All @@ -235,15 +336,15 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 13,
"id": "de686f58-74ae-4740-b624-561baa540389",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": "20"
},
"execution_count": 11,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -262,7 +363,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 14,
"id": "621c9cf0-1f69-4596-bfa5-e52754f77fd4",
"metadata": {},
"outputs": [],
Expand All @@ -279,15 +380,15 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 15,
"id": "53052562-b1ce-4865-89f3-da5740208660",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2022-02-22 13:36:22,570 (WARNING): Running DVC command: 'dvc run -n Calculator ...'\n"
"2022-03-26 18:47:29,923 (WARNING): Running DVC command: 'dvc run -n Calculator ...'\n"
]
}
],
Expand All @@ -299,15 +400,15 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 16,
"id": "47670ed7-e235-4b1a-a7a4-37d5a48d79d7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": "25"
},
"execution_count": 14,
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -326,7 +427,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 17,
"id": "f2a10182-0129-41f6-bbd0-0e21648364ec",
"metadata": {},
"outputs": [],
Expand All @@ -350,15 +451,15 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 18,
"id": "aec74308-8466-4509-a84c-b6dcd1dec8ca",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2022-02-22 13:36:29,476 (WARNING): Running DVC command: 'dvc run -n CombinedCalculator ...'\n"
"2022-03-26 18:47:34,537 (WARNING): Running DVC command: 'dvc run -n CombinedCalculator ...'\n"
]
}
],
Expand All @@ -370,15 +471,15 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 19,
"id": "7e1b3b2b-79d3-4d79-88e8-26653c717520",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": "25"
},
"execution_count": 17,
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -389,11 +490,15 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": null,
"id": "6fc36519-35fd-4a3a-a55b-88618dace01b",
"metadata": {
"nbsphinx": "hidden",
"tags": []
"tags": [],
"pycharm": {
"name": "#%%\n",
"is_executing": true
}
},
"outputs": [],
"source": [
Expand Down
Loading