Skip to content

Commit

Permalink
fix pylint warnings (#393)
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ authored Sep 20, 2022
1 parent efa44bf commit ef90816
Show file tree
Hide file tree
Showing 15 changed files with 65 additions and 225 deletions.
82 changes: 0 additions & 82 deletions tests/unit_tests/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import sys
from unittest.mock import MagicMock, patch

import pytest
import znjson

from zntrack.utils import utils
Expand Down Expand Up @@ -37,87 +36,6 @@ def post_init(self):
self.text = f"{self.foo} {self.bar}"


def test_get_auto_init():
_ = EmptyCls()

with pytest.raises(TypeError):
# has no init
EmptyCls(foo="foo")

def set_init(lst, dct):
mock = MagicMock()
setattr(
EmptyCls,
"__init__",
utils.get_auto_init(
kwargs_no_default=lst, kwargs_with_default=dct, super_init=mock
),
)
return mock

# only none-default values
mock = set_init(["foo", "bar"], {})

with pytest.raises(TypeError):
# type error after setting the init
_ = EmptyCls()

test = EmptyCls(foo="foo", bar="bar")
assert test.foo == "foo"
assert test.bar == "bar"
mock.assert_called()

# only default values
mock = set_init([], {"foo": None, "bar": 10})
test = EmptyCls()
assert test.foo is None
assert test.bar == 10
mock.assert_called()

test = EmptyCls(foo="foo", bar="bar")
assert test.foo == "foo"
assert test.bar == "bar"

# mixed case
mock = set_init(["foo"], {"bar": 10})
with pytest.raises(TypeError):
_ = EmptyCls()

with pytest.raises(TypeError):
_ = EmptyCls(bar=20)

test = EmptyCls(foo="foo")
assert test.foo == "foo"
assert test.bar == 10

test = EmptyCls(foo="foo", bar="bar")
assert test.foo == "foo"
assert test.bar == "bar"
mock.assert_called()


def test_get_post_init():
with pytest.raises(TypeError):
ClsWithPostInit(foo="foo")

mock = MagicMock()
setattr(
ClsWithPostInit,
"__init__",
utils.get_auto_init(
kwargs_no_default=["foo", "bar"], kwargs_with_default={}, super_init=mock
),
)
test = ClsWithPostInit(foo="foo", bar="bar")

assert test.foo == "foo"
assert test.bar == "bar"
assert test.post_init
assert test.text == "foo bar"

mock.assert_called()


def test_module_handler():
my_mock = MagicMock
my_mock.__module__ = "custom_module"
Expand Down
4 changes: 4 additions & 0 deletions zntrack/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""ZnTrack - Create, visualize, run & benchmark DVC pipelines in Python
GitHub: https://github.com/zincware/ZnTrack
"""
import importlib.metadata
import logging
import sys
Expand Down
12 changes: 7 additions & 5 deletions zntrack/core/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""ZnTrack Node class module"""
from __future__ import annotations

import json
Expand Down Expand Up @@ -72,13 +73,13 @@ def handle_deps(value) -> typing.List[str]:
return deps_files


BaseNodeType = typing.TypeVar("BaseNodeType", bound="Node")
BaseNodeTypeT = typing.TypeVar("BaseNodeTypeT", bound="Node")


class LoadViaGetItem(type):
"""Metaclass for adding getitem support to load"""

def __getitem__(cls: Node, item: typing.Union[str, dict]) -> BaseNodeType:
def __getitem__(cls: Node, item: typing.Union[str, dict]) -> BaseNodeTypeT:
"""Allow Node[<nodename>] to access an instance of the Node
Attributes
Expand All @@ -100,7 +101,7 @@ def __getitem__(cls: Node, item: typing.Union[str, dict]) -> BaseNodeType:
return cls.load(**item)
return cls.load(name=item)

def __matmul__(self, other: str) -> typing.Union[NodeAttribute, typing.Any]:
def __matmul__(cls, other: str) -> typing.Union[NodeAttribute, typing.Any]:
"""Shorthand for: getdeps(Node, other)
Parameters
Expand All @@ -116,7 +117,7 @@ def __matmul__(self, other: str) -> typing.Union[NodeAttribute, typing.Any]:
raise ValueError(
f"Can not compute 'Node @ {type(other)}'. Expected 'Node @ str'."
)
return getdeps(self, other)
return getdeps(cls, other)


class NodeBase(zninit.ZnInit):
Expand Down Expand Up @@ -175,7 +176,7 @@ class Node(NodeBase, metaclass=LoadViaGetItem):
"""

init_subclass_basecls = NodeBase
init_descriptors = [zn.params, zn.deps, zn.Method, zn.Nodes] + dvc.__all__
init_descriptors = [zn.params, zn.deps, zn.Method, zn.Nodes] + dvc.options

@utils.deprecated(
reason=(
Expand Down Expand Up @@ -430,6 +431,7 @@ def _handle_nodes_as_methods(self):

@property
def zntrack(self) -> ZnTrackInfo:
"""Get a ZnTrackInfo object"""
return ZnTrackInfo(parent=self)

@property
Expand Down
6 changes: 3 additions & 3 deletions zntrack/core/dvcgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ def dvc_args(self) -> list:
["--no-commit", "--external"]
"""
out = []
for field_name in self.__dataclass_fields__:
value = getattr(self, field_name)
for datacls_field in dataclasses.fields(self):
value = getattr(self, datacls_field.name)
if value:
out.append(f"--{field_name.replace('_', '-')}")
out.append(f"--{datacls_field.name.replace('_', '-')}")
return out


Expand Down
34 changes: 17 additions & 17 deletions zntrack/core/functions/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ class NodeConfig:
plots_no_cache: UnionDictListOfStrPath = None

def __post_init__(self):
for option_name in self.__dataclass_fields__:
for datacls_field in dataclasses.fields(self):
# type checking
option_value = getattr(self, option_name)
if option_name == "params":
option_value = getattr(self, datacls_field.name)
if datacls_field.name == "params":
# params does not have to be a string
if not isinstance(option_value, dict) and option_value is not None:
raise ValueError("Parameter must be dict or dot4dict.dotdict.")
Expand All @@ -65,10 +65,10 @@ def __post_init__(self):

def convert_fields_to_dotdict(self):
"""Update all fields to dotdict, if they are of type dict"""
for option_name in self.__dataclass_fields__:
option_value = getattr(self, option_name)
for datacls_field in dataclasses.fields(self):
option_value = getattr(self, datacls_field.name)
if isinstance(option_value, dict):
setattr(self, option_name, dot4dict.dotdict(option_value))
setattr(self, datacls_field.name, dot4dict.dotdict(option_value))

def write_dvc_command(self, node_name: str) -> list:
"""Collect dvc commands
Expand All @@ -89,25 +89,25 @@ def write_dvc_command(self, node_name: str) -> list:
if self.params is not None:
if len(self.params) > 0:
script += ["--params", f"{utils.Files.params}:{node_name}"]
for field in self.__dataclass_fields__:
if field == "params":
for datacls_field in dataclasses.fields(self):
if datacls_field.name == "params":
continue
if isinstance(getattr(self, field), (list, tuple)):
for element in getattr(self, field):
if isinstance(getattr(self, datacls_field.name), (list, tuple)):
for element in getattr(self, datacls_field.name):
script += [
f"--{field.replace('_', '-')}",
f"--{datacls_field.name.replace('_', '-')}",
pathlib.Path(element).as_posix(),
]
elif isinstance(getattr(self, field), dict):
for element in getattr(self, field).values():
elif isinstance(getattr(self, datacls_field.name), dict):
for element in getattr(self, datacls_field.name).values():
script += [
f"--{field.replace('_', '-')}",
f"--{datacls_field.name.replace('_', '-')}",
pathlib.Path(element).as_posix(),
]
elif getattr(self, field) is not None:
elif getattr(self, datacls_field.name) is not None:
script += [
f"--{field.replace('_', '-')}",
pathlib.Path(getattr(self, field)).as_posix(),
f"--{datacls_field.name.replace('_', '-')}",
pathlib.Path(getattr(self, datacls_field.name)).as_posix(),
]

return script
Expand Down
22 changes: 11 additions & 11 deletions zntrack/dvc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# detailed explanations on https://dvc.org/doc/command-reference/run#options


class params(DVCOption):
class params(DVCOption): # pylint: disable=invalid-name
"""Identify DVC option
See https://dvc.org/doc/command-reference/run#options for more information
Expand All @@ -32,7 +32,7 @@ class params(DVCOption):
file = utils.Files.zntrack


class deps(DVCOption):
class deps(DVCOption): # pylint: disable=invalid-name
"""Identify DVC option
See https://dvc.org/doc/command-reference/run#options for more information
Expand All @@ -52,7 +52,7 @@ def __get__(self, instance, owner=None):
return value


class outs(DVCOption):
class outs(DVCOption): # pylint: disable=invalid-name
"""Identify DVC option
See https://dvc.org/doc/command-reference/run#options for more information
Expand All @@ -63,7 +63,7 @@ class outs(DVCOption):
file = utils.Files.zntrack


class checkpoints(DVCOption):
class checkpoints(DVCOption): # pylint: disable=invalid-name
"""Identify DVC option
See https://dvc.org/doc/command-reference/run#options for more information
Expand All @@ -74,7 +74,7 @@ class checkpoints(DVCOption):
file = utils.Files.zntrack


class outs_no_cache(DVCOption):
class outs_no_cache(DVCOption): # pylint: disable=invalid-name
"""Identify DVC option
See https://dvc.org/doc/command-reference/run#options for more information
Expand All @@ -85,7 +85,7 @@ class outs_no_cache(DVCOption):
file = utils.Files.zntrack


class outs_persistent(DVCOption):
class outs_persistent(DVCOption): # pylint: disable=invalid-name
"""Identify DVC option
See https://dvc.org/doc/command-reference/run#options for more information
Expand All @@ -96,7 +96,7 @@ class outs_persistent(DVCOption):
file = utils.Files.zntrack


class metrics(DVCOption):
class metrics(DVCOption): # pylint: disable=invalid-name
"""Identify DVC option
See https://dvc.org/doc/command-reference/run#options for more information
Expand All @@ -107,7 +107,7 @@ class metrics(DVCOption):
file = utils.Files.zntrack


class metrics_no_cache(DVCOption):
class metrics_no_cache(DVCOption): # pylint: disable=invalid-name
"""Identify DVC option
See https://dvc.org/doc/command-reference/run#options for more information
Expand All @@ -118,7 +118,7 @@ class metrics_no_cache(DVCOption):
file = utils.Files.zntrack


class plots(PlotsModifyOption):
class plots(PlotsModifyOption): # pylint: disable=invalid-name
"""Identify DVC option
See https://dvc.org/doc/command-reference/run#options for more information
Expand All @@ -129,7 +129,7 @@ class plots(PlotsModifyOption):
file = utils.Files.zntrack


class plots_no_cache(PlotsModifyOption):
class plots_no_cache(PlotsModifyOption): # pylint: disable=invalid-name
"""Identify DVC option
See https://dvc.org/doc/command-reference/run#options for more information
Expand All @@ -140,7 +140,7 @@ class plots_no_cache(PlotsModifyOption):
file = utils.Files.zntrack


__all__ = [
options = [
params,
deps,
outs,
Expand Down
1 change: 1 addition & 0 deletions zntrack/interface/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""DVC Interface through ZnTrack"""
import copy
import dataclasses
import json
Expand Down
1 change: 0 additions & 1 deletion zntrack/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
decode_dict,
deprecated,
encode_dict,
get_auto_init,
get_python_interpreter,
module_handler,
module_to_path,
Expand Down
5 changes: 4 additions & 1 deletion zntrack/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ def isnode(node) -> bool:
from zntrack.core.base import Node

if isinstance(node, (list, tuple)):
return any([isnode(x) for x in node])
for x in node:
if isnode(x):
return True
return False
else:
try:
if isinstance(node, Node) or issubclass(node, Node):
Expand Down
2 changes: 1 addition & 1 deletion zntrack/utils/nwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def replace_nwd_placeholder(
_path = path

try:
replace = True if nwd in _path else False
replace = nwd in _path
except TypeError:
# argument of type <Node> is not iterable. This can happen when you use
# e.g. dvc.deps(Node) (deprecated)
Expand Down
Loading

0 comments on commit ef90816

Please sign in to comment.