Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Node[<nodename>] to load a Node #254

Merged
merged 2 commits into from
Mar 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions tests/integration_tests/test_single_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,3 +296,19 @@ def run(self):
def test_OutsNotWritten(proj_path):
with pytest.raises(DVCProcessError):
OutsNotWritten().write_graph(run=True)


def test_load_named_nodes(proj_path):
ExampleNode01(name="Node01", inputs=42).write_graph(run=True)
ExampleNode01(name="Node02", inputs=3.1415).write_graph(run=True)

assert ExampleNode01["Node01"].outputs == 42
assert ExampleNode01["Node02"].outputs == 3.1415

# this will run load with name=Node01, lazy=True/False
assert ExampleNode01[("Node01", True)].outputs == 42
assert ExampleNode01[("Node01", False)].outputs == 42

# this will run load with name=Node01, lazy=True/False
assert ExampleNode01[{"name": "Node01", "lazy": True}].outputs == 42
assert ExampleNode01[{"name": "Node01", "lazy": False}].outputs == 42
23 changes: 22 additions & 1 deletion zntrack/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,28 @@ def update_dependency_options(value):
value.update_options()


class Node(GraphWriter):
class LoadViaGetItem(type):
"""Metaclass for adding getitem support to load"""

def __getitem__(self: Node, item) -> Node:
"""Allow Node[<nodename>] to access an instance of the Node

Attributes
----------
item: str|tuple|dict
Can be a string, for load(name=item)
Can be a tuple for load(*item) | e.g. ("nodename", True)
Can be a dict for load(**item) | e.g. {name:"nodename", lazy:True}

"""
if isinstance(item, tuple):
return self.load(*item)
elif isinstance(item, dict):
return self.load(**item)
return self.load(name=item)


class Node(GraphWriter, metaclass=LoadViaGetItem):
"""Main parent class for all ZnTrack Node

The methods implemented in this class are primarily loading and saving parameters.
Expand Down
2 changes: 2 additions & 0 deletions zntrack/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,8 @@ def get_python_interpreter() -> str:

except subprocess.CalledProcessError:
log.debug(f"{interpreter} is not working!")
except FileNotFoundError as err:
log.debug(err)
raise ValueError(
"Could not find a working python interpreter to work with subprocesses!"
)
Expand Down