Skip to content

Commit

Permalink
Add mypy and more flake checks
Browse files Browse the repository at this point in the history
  • Loading branch information
mdeweerd committed Feb 18, 2022
1 parent 515f4a1 commit 69081b1
Show file tree
Hide file tree
Showing 13 changed files with 118 additions and 52 deletions.
16 changes: 15 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,23 @@ repos:
hooks:
- id: flake8
entry: pflake8
additional_dependencies: ['pyproject-flake8==0.0.1a2']
additional_dependencies:
- pyproject-flake8==0.0.1a2
- flake8-bugbear==22.1.11
- flake8-comprehensions==3.8.0
- flake8_2020==1.6.1
- mccabe==0.6.1
- pycodestyle==2.8.0
- pyflakes==2.4.0

- repo: https://github.com/PyCQA/isort
rev: 5.10.1
hooks:
- id: isort

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.931
hooks:
- id: mypy
additional_dependencies:
- zigpy==0.43.0
20 changes: 20 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,23 @@ testing =

[coverage:run]
source = zigpy_znp

[flake8]
max-line-length = 88

[mypy]
ignore_missing_imports = True
install_types = True
non_interactive = True
check_untyped_defs = True
show_error_codes = True
show_error_context = True
disable_error_code =
attr-defined,
arg-type,
type-var,
var-annotated,
assignment,
call-overload,
name-defined,
union-attr
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# Python 3.8 already has this
from unittest.mock import AsyncMock as CoroutineMock # noqa: F401
except ImportError:
from asynctest import CoroutineMock # noqa: F401
from asynctest import CoroutineMock # type:ignore[no-redef] # noqa: F401

import zigpy.endpoint
import zigpy.zdo.types as zdo_t
Expand Down Expand Up @@ -69,7 +69,7 @@ def write(self, data):
assert self._is_connected
self.protocol.data_received(data)

def close(self, *, error=ValueError("Connection was closed")):
def close(self, *, error=ValueError("Connection was closed")): # noqa:
LOGGER.debug("Closing %s", self)
if not self._is_connected:
return
Expand Down
10 changes: 5 additions & 5 deletions tests/test_types_cstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,15 +296,15 @@ class OldNIB(t.CStruct):
SecurityLevel: t.uint8_t
SymLink: t.uint8_t
CapabilityFlags: t.uint8_t
PaddingByte0: PaddingByte
PaddingByte0: PaddingByte # type:ignore[valid-type]
TransactionPersistenceTime: t.uint16_t
nwkProtocolVersion: t.uint8_t
RouteDiscoveryTime: t.uint8_t
RouteExpiryTime: t.uint8_t
PaddingByte1: PaddingByte
PaddingByte1: PaddingByte # type:ignore[valid-type]
nwkDevAddress: t.NWK
nwkLogicalChannel: t.uint8_t
PaddingByte2: PaddingByte
PaddingByte2: PaddingByte # type:ignore[valid-type]
nwkCoordAddress: t.NWK
nwkCoordExtAddress: t.EUI64
nwkPanId: t.uint16_t
Expand All @@ -330,11 +330,11 @@ class OldNIB(t.CStruct):
nwkConcentratorDiscoveryTime: t.uint8_t
nwkConcentratorRadius: t.uint8_t
nwkAllFresh: t.uint8_t
PaddingByte3: PaddingByte
PaddingByte3: PaddingByte # type:ignore[valid-type]
nwkManagerAddr: t.NWK
nwkTotalTransmissions: t.uint16_t
nwkUpdateId: t.uint8_t
PaddingByte4: PaddingByte
PaddingByte4: PaddingByte # type:ignore[valid-type]

nib = t.NIB(
SequenceNum=54,
Expand Down
38 changes: 28 additions & 10 deletions zigpy_znp/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
import itertools
import contextlib
import dataclasses
from typing import Union, overload
from collections import Counter, defaultdict

import zigpy.state
import async_timeout
import zigpy.zdo.types as zdo_t
from typing_extensions import Literal

import zigpy_znp.const as const
import zigpy_znp.types as t
Expand Down Expand Up @@ -50,8 +52,8 @@ def __init__(self, config: conf.ConfigType):
self._listeners = defaultdict(list)
self._sync_request_lock = asyncio.Lock()

self.capabilities = None
self.version = None
self.capabilities = None # type: int
self.version = None # type: float

self.nvram = NVRAMHelper(self)
self.network_info: zigpy.state.NetworkInformation = None
Expand Down Expand Up @@ -542,7 +544,7 @@ async def ping_task():

try:
async with async_timeout.timeout(CONNECT_PING_TIMEOUT):
result = await ping_task
result = await ping_task # type:ignore[misc]
except asyncio.TimeoutError:
ping_task.cancel()

Expand Down Expand Up @@ -609,7 +611,7 @@ def close(self) -> None:

self._app = None

for header, listeners in self._listeners.items():
for _header, listeners in self._listeners.items():
for listener in listeners:
listener.cancel()

Expand Down Expand Up @@ -659,7 +661,7 @@ def remove_listener(self, listener: BaseResponseListener) -> None:
counts[OneShotResponseListener],
)

def frame_received(self, frame: GeneralFrame) -> bool:
def frame_received(self, frame: GeneralFrame) -> bool | None:
"""
Called when a frame has been received. Returns whether or not the frame was
handled by any listener.
Expand All @@ -669,7 +671,7 @@ def frame_received(self, frame: GeneralFrame) -> bool:

if frame.header not in c.COMMANDS_BY_ID:
LOGGER.error("Received an unknown frame: %s", frame)
return
return None

command_cls = c.COMMANDS_BY_ID[frame.header]

Expand All @@ -680,7 +682,7 @@ def frame_received(self, frame: GeneralFrame) -> bool:
# https://github.com/home-assistant/core/issues/50005
if command_cls == c.ZDO.ParentAnnceRsp.Callback:
LOGGER.warning("Failed to parse broken %s as %s", frame, command_cls)
return
return None

raise

Expand Down Expand Up @@ -760,7 +762,21 @@ def callback_for_response(

return self.callback_for_responses([response], callback)

def wait_for_responses(self, responses, *, context=False) -> asyncio.Future:
@overload
def wait_for_responses(
self, responses, *, context: Literal[False] = ...
) -> asyncio.Future:
...

@overload
def wait_for_responses(
self, responses, *, context: Literal[True]
) -> tuple[asyncio.Future, OneShotResponseListener]:
...

def wait_for_responses(
self, responses, *, context: bool = False
) -> Union[asyncio.Future | tuple[asyncio.Future, OneShotResponseListener]]:
"""
Creates a one-shot listener that matches any *one* of the given responses.
"""
Expand All @@ -787,7 +803,9 @@ def wait_for_response(self, response: t.CommandBase) -> asyncio.Future:

return self.wait_for_responses([response])

async def request(self, request: t.CommandBase, **response_params) -> t.CommandBase:
async def request(
self, request: t.CommandBase, **response_params
) -> t.CommandBase | None:
"""
Sends a SREQ/AREQ request and returns its SRSP (only for SREQ), failing if any
of the SRSP's parameters don't match `response_params`.
Expand Down Expand Up @@ -827,7 +845,7 @@ async def request(self, request: t.CommandBase, **response_params) -> t.CommandB
if not request.Rsp:
LOGGER.debug("Request has no response, not waiting for one.")
self._uart.send(frame)
return
return None

# We need to create the response listener before we send the request
response_future = self.wait_for_responses(
Expand Down
5 changes: 4 additions & 1 deletion zigpy_znp/tools/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
import logging
import argparse
from typing import Optional, Sequence

import jsonschema
import coloredlogs
Expand Down Expand Up @@ -116,7 +117,9 @@ def validate_backup_json(backup: t.JSONType) -> None:


class CustomArgumentParser(argparse.ArgumentParser):
def parse_args(self, args: list[str] = None, namespace=None):
def parse_args(
self, args: Optional[Sequence[str]] = None, namespace=None
): # type:ignore[override]
args = super().parse_args(args, namespace)

# Since we're running as a CLI tool, install our own log level and color logger
Expand Down
42 changes: 23 additions & 19 deletions zigpy_znp/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ def serialize_list(objects) -> Bytes:

class FixedIntType(int):
_signed = None
_size = None
_size = None # type:int

@classmethod
def __new__(cls, *args, **kwargs):
if cls._signed is None or cls._size is None:
raise TypeError(f"{cls} is abstract and cannot be created")
Expand All @@ -48,6 +49,7 @@ def __new__(cls, *args, **kwargs):

return instance

@classmethod
def __init_subclass__(cls, signed=None, size=None, hex_repr=None) -> None:
super().__init_subclass__()

Expand All @@ -58,7 +60,7 @@ def __init_subclass__(cls, signed=None, size=None, hex_repr=None) -> None:
cls._size = size

if hex_repr:
fmt = f"0x{{:0{cls._size * 2}X}}"
fmt = f"0x{{:0{cls._size * 2}X}}" # type:ignore[operator]
cls.__str__ = cls.__repr__ = lambda self: fmt.format(self)
elif hex_repr is not None and not hex_repr:
cls.__str__ = super().__str__
Expand All @@ -77,13 +79,15 @@ def serialize(self) -> bytes:
raise ValueError(str(e)) from e

@classmethod
def deserialize(cls, data: bytes) -> tuple[FixedIntType, bytes]:
if len(data) < cls._size:
def deserialize(
cls, data: bytes
) -> tuple[FixedIntType, bytes]: # type:ignore[return-value]
if len(data) < cls._size: # type:ignore[operator]
raise ValueError(f"Data is too short to contain {cls._size} bytes")

r = cls.from_bytes(data[: cls._size], "little", signed=cls._signed)
data = data[cls._size :]
return r, data
return r, data # type:ignore[return-value]


class uint_t(FixedIntType, signed=False):
Expand Down Expand Up @@ -162,7 +166,7 @@ class ShortBytes(Bytes):
_header = uint8_t

def serialize(self) -> Bytes:
return self._header(len(self)).serialize() + self
return self._header(len(self)).serialize() + self # type:ignore[return-value]

@classmethod
def deserialize(cls, data: bytes) -> tuple[Bytes, bytes]:
Expand All @@ -182,7 +186,7 @@ class BaseListType(list):
@classmethod
def _serialize_item(cls, item, *, align):
if not isinstance(item, cls._item_type):
item = cls._item_type(item)
item = cls._item_type(item) # type:ignore[misc]

if issubclass(cls._item_type, CStruct):
return item.serialize(align=align)
Expand Down Expand Up @@ -215,7 +219,7 @@ def serialize(self, *, align=False) -> bytes:
def deserialize(cls, data: bytes, *, align=False) -> tuple[LVList, bytes]:
length, data = cls._header.deserialize(data)
r = cls()
for i in range(length):
for _i in range(length):
item, data = cls._deserialize_item(data, align=align)
r.append(item)
return r, data
Expand All @@ -242,7 +246,7 @@ def serialize(self, *, align=False) -> bytes:
@classmethod
def deserialize(cls, data: bytes, *, align=False) -> tuple[FixedList, bytes]:
r = cls()
for i in range(cls._length):
for _i in range(cls._length):
item, data = cls._deserialize_item(data, align=align)
r.append(item)
return r, data
Expand Down Expand Up @@ -271,7 +275,7 @@ def enum_flag_factory(int_type: FixedIntType) -> enum.Flag:
appropriate methods but with only one non-Enum parent class.
"""

class _NewEnum(int_type, enum.Flag):
class _NewEnum(int_type, enum.Flag): # type:ignore[misc,valid-type]
# Rebind classmethods to our own class
_missing_ = classmethod(enum.IntFlag._missing_.__func__)
_create_pseudo_member_ = classmethod(
Expand All @@ -286,7 +290,7 @@ class _NewEnum(int_type, enum.Flag):
__rxor__ = enum.IntFlag.__rxor__
__invert__ = enum.IntFlag.__invert__

return _NewEnum
return _NewEnum # type:ignore[return-value]


class enum_uint8(uint8_t, enum.Enum):
Expand Down Expand Up @@ -321,33 +325,33 @@ class enum_uint64(uint64_t, enum.Enum):
pass


class enum_flag_uint8(enum_flag_factory(uint8_t)):
class enum_flag_uint8(enum_flag_factory(uint8_t)): # type:ignore[misc]
pass


class enum_flag_uint16(enum_flag_factory(uint16_t)):
class enum_flag_uint16(enum_flag_factory(uint16_t)): # type:ignore[misc]
pass


class enum_flag_uint24(enum_flag_factory(uint24_t)):
class enum_flag_uint24(enum_flag_factory(uint24_t)): # type:ignore[misc]
pass


class enum_flag_uint32(enum_flag_factory(uint32_t)):
class enum_flag_uint32(enum_flag_factory(uint32_t)): # type:ignore[misc]
pass


class enum_flag_uint40(enum_flag_factory(uint40_t)):
class enum_flag_uint40(enum_flag_factory(uint40_t)): # type:ignore[misc]
pass


class enum_flag_uint48(enum_flag_factory(uint48_t)):
class enum_flag_uint48(enum_flag_factory(uint48_t)): # type:ignore[misc]
pass


class enum_flag_uint56(enum_flag_factory(uint56_t)):
class enum_flag_uint56(enum_flag_factory(uint56_t)): # type:ignore[misc]
pass


class enum_flag_uint64(enum_flag_factory(uint64_t)):
class enum_flag_uint64(enum_flag_factory(uint64_t)): # type:ignore[misc]
pass
6 changes: 4 additions & 2 deletions zigpy_znp/types/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ class Req(CommandBase, header=header, schema=definition.req_schema):
req_header = header
rsp_header = CommandHeader(0x0040 + req_header)

class Req(
class Req( # type:ignore[no-redef]
CommandBase, header=req_header, schema=definition.req_schema
):
pass
Expand Down Expand Up @@ -261,7 +261,9 @@ class Callback(
) # pragma: no cover

# If there is no request, this is a just a response
class Rsp(CommandBase, header=header, schema=definition.rsp_schema):
class Rsp( # type:ignore[no-redef]
CommandBase, header=header, schema=definition.rsp_schema
):
pass

Rsp.__qualname__ = qualname + ".Rsp"
Expand Down
Loading

0 comments on commit 69081b1

Please sign in to comment.