Skip to content

Commit

Permalink
Pylint perflint (#256)
Browse files Browse the repository at this point in the history
* add perflint

* fix some small linting issues

* add tests

* run pytest weekly

* change znjson level attribute

* remove pathlib converter because it is already enabled

* lint

* Add ValueError when Node is passed to zn.Method

* add `helpers.isnode`

* bugfix
  • Loading branch information
PythonFZ authored Mar 16, 2022
1 parent bdc0dee commit 3fd818b
Show file tree
Hide file tree
Showing 21 changed files with 148 additions and 40 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/lint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ jobs:
pip install .
- name: Install pylint
run: |
pip install pylint
pip install pylint perflint
- name: run pylint
continue-on-error: true
run: |
pylint zntrack
pylint zntrack --load-plugins=perflint
2 changes: 2 additions & 0 deletions .github/workflows/pytest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ name: pytest
on:
push:
pull_request:
schedule:
- cron: '14 3 * * 1' # at 03:14 on Monday.

jobs:
test:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ disable = [
"logging-fstring-interpolation",
"too-many-arguments",
"too-many-instance-attributes",
"dotted-import-in-loop",
# seems to fail for some cases
"no-else-return",
# allow for open TODOs
Expand Down
4 changes: 0 additions & 4 deletions tests/integration_tests/test_single_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,6 @@ def test_load_named_nodes(proj_path):
assert ExampleNode01["Node01"].outputs == 42
assert ExampleNode01["Node02"].outputs == 3.1415

# this will run load with name=Node01, lazy=True/False
assert ExampleNode01[("Node01", True)].outputs == 42
assert ExampleNode01[("Node01", False)].outputs == 42

# this will run load with name=Node01, lazy=True/False
assert ExampleNode01[{"name": "Node01", "lazy": True}].outputs == 42
assert ExampleNode01[{"name": "Node01", "lazy": False}].outputs == 42
23 changes: 23 additions & 0 deletions tests/integration_tests/test_zn_deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,26 @@ def test_assert_read_file(proj_path, zntrack_dict):
pathlib.Path("zntrack.json").write_text(json.dumps(zntrack_dict))

assert isinstance(LastNode.load().first_node, FirstNode)


class FirstNodeParams(Node):
number: int = zn.params()


class SecondNodeParams(Node):
first_node_params: FirstNodeParams = zn.deps(FirstNodeParams)

negative_number: int = zn.params()

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.negative_number = -1 * self.first_node_params.number


@pytest.mark.parametrize("number", (5, -5))
def test_ParamsFromNodeNoLoad(proj_path, number):
FirstNodeParams(number=number).write_graph()
SecondNodeParams().write_graph()

assert FirstNodeParams.load().number == number
assert SecondNodeParams.load().negative_number == -1 * number
23 changes: 23 additions & 0 deletions tests/integration_tests/test_zn_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ def run(self):
self.result = self.data_class.param1 + self.data_class.param2


class HelloWorld(Node):
param = zn.params(42)

def run(self):
pass


def test_run_twice_diff_params(proj_path):
SingleNode(data_class=ExampleMethod(1, 1)).write_graph(no_exec=False)
assert SingleNode.load().result == 2
Expand Down Expand Up @@ -224,3 +231,19 @@ def test_assert_read_files_old1(proj_path):
node = SingleNodeNoParams.load()
assert node.data_class.param1 == 1
assert node.data_class.param2 == 2


def test_err_node_as_method(proj_path):
HelloWorld().write_graph()

with pytest.raises(ValueError):
SingleNode(data_class=HelloWorld.load())

with pytest.raises(ValueError):
SingleNode(data_class=[HelloWorld.load(), HelloWorld.load()])

with pytest.raises(ValueError):
SingleNode(data_class=HelloWorld)

with pytest.raises(ValueError):
SingleNode(data_class=[HelloWorld, HelloWorld])
27 changes: 26 additions & 1 deletion tests/unit_tests/core/test_core_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
import yaml

from zntrack import dvc, zn
from zntrack.core.base import Node, get_auto_init_signature, update_dependency_options
from zntrack.core.base import (
LoadViaGetItem,
Node,
get_auto_init_signature,
update_dependency_options,
)


class ExampleDVCOutsNode(Node):
Expand Down Expand Up @@ -222,3 +227,23 @@ def test_get_auto_init_signature():
assert signature_params[2].annotation is None

assert signature_params[-1].name == "out3"


class NodeMock(metaclass=LoadViaGetItem):
def load(self):
pass


@pytest.mark.parametrize(
("key", "called"),
[
("node", {"name": "node"}),
({"name": "node"}, {"name": "node"}),
({"name": "node", "lazy": True}, {"name": "node", "lazy": True}),
],
)
def test_LoadViaGetItem(key, called):
with patch.object(NodeMock, "load") as mock:
NodeMock[key]

mock.assert_called_with(**called)
24 changes: 24 additions & 0 deletions tests/unit_tests/utlis/test_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pytest

from zntrack import Node
from zntrack.utils import helpers


class MyNode(Node):
def run(self):
pass


@pytest.mark.parametrize(
("node", "true_val"),
[
(MyNode(), True),
(MyNode, True),
([MyNode, MyNode], True),
([MyNode(), MyNode()], True),
("Node", False),
(["Node", MyNode()], True),
],
)
def test_isnode(node, true_val):
assert helpers.isnode(node) == true_val
11 changes: 1 addition & 10 deletions zntrack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,8 @@
from zntrack.utils.serializer import MethodConverter, ZnTrackTypeConverter

# register converters
znjson.config.ACTIVE_CONVERTER = [
ZnTrackTypeConverter,
znjson.PathlibConverter,
MethodConverter,
]
try:
znjson.register([znjson.NumpyConverter, znjson.SmallNumpyConverter])
except AttributeError:
pass
znjson.register([ZnTrackTypeConverter, MethodConverter])

#
__all__ = [
Node.__name__,
ZnTrackProject.__name__,
Expand Down
16 changes: 7 additions & 9 deletions zntrack/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,30 +52,28 @@ def update_dependency_options(value):
the default_value Nodes, so we must to this manually here and call update_options.
"""
if isinstance(value, (list, tuple)):
[update_dependency_options(x) for x in value]
for item in value:
update_dependency_options(item)
if isinstance(value, Node):
value.update_options()


class LoadViaGetItem(type):
"""Metaclass for adding getitem support to load"""

def __getitem__(self: Node, item) -> Node:
def __getitem__(cls: Node, item) -> Node:
"""Allow Node[<nodename>] to access an instance of the Node
Attributes
----------
item: str|tuple|dict
item: str|dict
Can be a string, for load(name=item)
Can be a tuple for load(*item) | e.g. ("nodename", True)
Can be a dict for load(**item) | e.g. {name:"nodename", lazy:True}
"""
if isinstance(item, tuple):
return self.load(*item)
elif isinstance(item, dict):
return self.load(**item)
return self.load(name=item)
if isinstance(item, dict):
return cls.load(**item)
return cls.load(name=item)


class Node(GraphWriter, metaclass=LoadViaGetItem):
Expand Down
5 changes: 3 additions & 2 deletions zntrack/core/jupyter.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import logging

log = logging.getLogger(__name__)
import pathlib
import re
import subprocess
from functools import lru_cache

from zntrack.utils.config import config

log = logging.getLogger(__name__)


@lru_cache(None)
def log_jupyter_warning():
Expand All @@ -30,6 +30,7 @@ def jupyter_class_to_file(nb_name, module_name):
subprocess.run(
["jupyter", "nbconvert", "--to", "script", nb_name],
capture_output=config.log_level > logging.INFO,
check=True,
)

reading_class = False
Expand Down
3 changes: 2 additions & 1 deletion zntrack/core/zntrackoption.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def __init__(self, default_value=None, **kwargs):
super().__init__(default_value=default_value, **kwargs)

@property
def dvc_args(self):
def dvc_args(self) -> str:
"""replace python variables '_' with '-' for dvc"""
return self.dvc_option.replace("_", "-")

def __repr__(self):
Expand Down
2 changes: 1 addition & 1 deletion zntrack/dvc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class deps(ZnTrackOption):

def __get__(self, instance, owner):
"""Use load_node_dependency before returning the value"""
value = super(deps, self).__get__(instance, owner)
value = super().__get__(instance, owner)
value = utils.utils.load_node_dependency(value, log_warning=True)
setattr(instance, self.name, value)
return value
Expand Down
3 changes: 1 addition & 2 deletions zntrack/metadata/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import re
from abc import ABC, abstractmethod
from functools import partial
from typing import Callable

from zntrack import utils
Expand Down Expand Up @@ -60,8 +61,6 @@ def __get__(self, instance, owner):
https://stackoverflow.com/questions/30104047/how-can-i-decorate-an-instance-method-with-a-decorator-class
"""

from functools import partial

return partial(self.__call__, instance)

def save_metadata(self, cls, value):
Expand Down
2 changes: 0 additions & 2 deletions zntrack/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,3 @@ class DescriptorMissing(Exception):

class DVCProcessError(Exception):
"""DVC specific message for CalledProcessError"""

pass
14 changes: 14 additions & 0 deletions zntrack/utils/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from zntrack.core.base import Node


def isnode(node) -> bool:
"""Check if node contains a Node instance or class"""
if isinstance(node, (list, tuple)):
return any([isnode(x) for x in node])
else:
try:
if isinstance(node, Node) or issubclass(node, Node):
return True
except TypeError:
pass
return False
2 changes: 2 additions & 0 deletions zntrack/utils/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class ZnTrackTypeConverter(znjson.ConverterBase):

instance = Node
representation = "ZnTrackType"
level = 10

def _encode(self, obj: Node) -> dict:
"""Convert Node to serializable dict"""
Expand All @@ -98,6 +99,7 @@ class MethodConverter(znjson.ConverterBase):
"""ZnJSON Converter for zn.method attributes"""

representation = "zn.method"
level = 10

def _encode(self, obj):
"""Serialize the object"""
Expand Down
6 changes: 2 additions & 4 deletions zntrack/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,12 @@ def check_type(
for value in obj:
if check_type(value, types, allow_iterable, allow_none, allow_dict):
continue
else:
return False
return False
elif isinstance(obj, dict) and allow_dict:
for value in obj.values():
if check_type(value, types, allow_iterable, allow_none, allow_dict):
continue
else:
return False
return False
else:
if allow_none and obj is None:
return True
Expand Down
2 changes: 1 addition & 1 deletion zntrack/zn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class deps(ZnTrackOption):

def __get__(self, instance, owner):
"""Use load_node_dependency before returning the value"""
value = super(deps, self).__get__(instance, owner)
value = super().__get__(instance, owner)
value = utils.utils.load_node_dependency(value)
setattr(instance, self.name, value)
return value
Expand Down
12 changes: 12 additions & 0 deletions zntrack/zn/method.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging

from zntrack import utils
from zntrack.utils import helpers
from zntrack.zn.split_option import SplitZnTrackOption

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -32,6 +33,17 @@ def get_filename(self, instance):
"""Does not have a single file but params.yaml and zntrack.json"""
return utils.Files.params, utils.Files.zntrack

def __set__(self, instance, value):
"""Include type check for better error reporting"""
# TODO raise error on default values,
# make compatible types an attribute of descriptor
if helpers.isnode(value):
raise ValueError(
f"zn.Method() does not support type <Node> ({value})."
" Consider using zn.deps() instead"
)
super().__set__(instance, value)

def __get__(self, instance, owner):
"""Add some custom attributes to the instance to identify it in znjson"""
if instance is None:
Expand Down
2 changes: 1 addition & 1 deletion zntrack/zn/split_option.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def split_value(input_val) -> (typing.Union[dict, list], typing.Union[dict, list
"""
if isinstance(input_val, (list, tuple)):
data = [split_value(x) for x in input_val]
params_data, zntrack_data = zip(*data)
params_data, _ = zip(*data)
else:
if input_val["_type"] in ["zn.method"]:
params_data = input_val["value"].pop("kwargs")
Expand Down

0 comments on commit 3fd818b

Please sign in to comment.