Skip to content

Commit

Permalink
add typing to aiokafka/protocol/* (#999)
Browse files Browse the repository at this point in the history
* add typing to aiokafka/protocol/*

* fix review

* fix VarInt64

* fix review tuple -> list

* fix review

* fix review

* move ALL_TOPICS/NO_TOPICS to docs

* remove default values from Message()

* fix checking abstractproperty in test

* fix review

* fix review (from docstrings to comments)

* fix: collections.abc.Sequence -> typing.Sequence

* fix review: Message

* add FIXME

* fix review: Message

* use NotImplemented instead of False
  • Loading branch information
dimastbk authored Apr 21, 2024
1 parent 2bba153 commit 1862620
Show file tree
Hide file tree
Showing 11 changed files with 415 additions and 219 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ FORMATTED_AREAS=\
aiokafka/helpers.py \
aiokafka/structs.py \
aiokafka/util.py \
aiokafka/protocol/ \
tests/test_codec.py \
tests/test_helpers.py

Expand Down
12 changes: 8 additions & 4 deletions aiokafka/protocol/abstract.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import abc
from io import BytesIO
from typing import Generic, TypeVar

T = TypeVar("T")

class AbstractType(metaclass=abc.ABCMeta):

class AbstractType(Generic[T], metaclass=abc.ABCMeta):
@classmethod
@abc.abstractmethod
def encode(cls, value): ...
def encode(cls, value: T) -> bytes: ...

@classmethod
@abc.abstractmethod
def decode(cls, data): ...
def decode(cls, data: BytesIO) -> T: ...

@classmethod
def repr(cls, value):
def repr(cls, value: T) -> str:
return repr(value)
13 changes: 10 additions & 3 deletions aiokafka/protocol/admin.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Dict, Iterable, Optional, Tuple

from .api import Request, Response
from .types import (
Array,
Expand Down Expand Up @@ -429,8 +431,8 @@ class DescribeGroupsResponse_v3(Response):
("member_assignment", Bytes),
),
),
("authorized_operations", Int32),
),
("authorized_operations", Int32),
),
)

Expand Down Expand Up @@ -1119,7 +1121,7 @@ class DeleteGroupsRequest_v1(Request):
DeleteGroupsResponse = [DeleteGroupsResponse_v0, DeleteGroupsResponse_v1]


class DescribeClientQuotasResponse_v0(Request):
class DescribeClientQuotasResponse_v0(Response):
API_KEY = 48
API_VERSION = 0
SCHEMA = Schema(
Expand Down Expand Up @@ -1385,7 +1387,12 @@ class DeleteRecordsRequest_v2(Request):
("tags", TaggedFields),
)

def __init__(self, topics, timeout_ms, tags=None):
def __init__(
self,
topics: Iterable[Tuple[str, Iterable[Tuple[int, int]]]],
timeout_ms: int,
tags: Optional[Dict[int, bytes]] = None,
) -> None:
super().__init__(
[
(
Expand Down
77 changes: 50 additions & 27 deletions aiokafka/protocol/api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from __future__ import annotations

import abc
from io import BytesIO
from typing import Any, ClassVar, Dict, Optional, Type, Union

from .struct import Struct
from .types import Array, Int16, Int32, Schema, String, TaggedFields
Expand All @@ -12,7 +16,9 @@ class RequestHeader_v0(Struct):
("client_id", String("utf-8")),
)

def __init__(self, request, correlation_id=0, client_id="aiokafka"):
def __init__(
self, request: Request, correlation_id: int = 0, client_id: str = "aiokafka"
) -> None:
super().__init__(
request.API_KEY, request.API_VERSION, correlation_id, client_id
)
Expand All @@ -28,7 +34,13 @@ class RequestHeader_v1(Struct):
("tags", TaggedFields),
)

def __init__(self, request, correlation_id=0, client_id="aiokafka", tags=None):
def __init__(
self,
request: Request,
correlation_id: int = 0,
client_id: str = "aiokafka",
tags: Optional[Dict[int, bytes]] = None,
):
super().__init__(
request.API_KEY, request.API_VERSION, correlation_id, client_id, tags or {}
)
Expand All @@ -48,32 +60,38 @@ class ResponseHeader_v1(Struct):


class Request(Struct, metaclass=abc.ABCMeta):
FLEXIBLE_VERSION = False
FLEXIBLE_VERSION: ClassVar[bool] = False

@abc.abstractproperty
def API_KEY(self):
@property
@abc.abstractmethod
def API_KEY(self) -> int:
"""Integer identifier for api request"""

@abc.abstractproperty
def API_VERSION(self):
@property
@abc.abstractmethod
def API_VERSION(self) -> int:
"""Integer of api request version"""

@abc.abstractproperty
def SCHEMA(self):
"""An instance of Schema() representing the request structure"""

@abc.abstractproperty
def RESPONSE_TYPE(self):
@property
@abc.abstractmethod
def RESPONSE_TYPE(self) -> Type[Response]:
"""The Response class associated with the api request"""

def expect_response(self):
@property
@abc.abstractmethod
def SCHEMA(self) -> Schema:
"""An instance of Schema() representing the request structure"""

def expect_response(self) -> bool:
"""Override this method if an api request does not always generate a response"""
return True

def to_object(self):
def to_object(self) -> Dict[str, Any]:
return _to_object(self.SCHEMA, self)

def build_request_header(self, correlation_id, client_id):
def build_request_header(
self, correlation_id: int, client_id: str
) -> Union[RequestHeader_v0, RequestHeader_v1]:
if self.FLEXIBLE_VERSION:
return RequestHeader_v1(
self, correlation_id=correlation_id, client_id=client_id
Expand All @@ -82,31 +100,36 @@ def build_request_header(self, correlation_id, client_id):
self, correlation_id=correlation_id, client_id=client_id
)

def parse_response_header(self, read_buffer):
def parse_response_header(
self, read_buffer: Union[BytesIO, bytes]
) -> Union[ResponseHeader_v0, ResponseHeader_v1]:
if self.FLEXIBLE_VERSION:
return ResponseHeader_v1.decode(read_buffer)
return ResponseHeader_v0.decode(read_buffer)


class Response(Struct, metaclass=abc.ABCMeta):
@abc.abstractproperty
def API_KEY(self):
@property
@abc.abstractmethod
def API_KEY(self) -> int:
"""Integer identifier for api request/response"""

@abc.abstractproperty
def API_VERSION(self):
@property
@abc.abstractmethod
def API_VERSION(self) -> int:
"""Integer of api request/response version"""

@abc.abstractproperty
def SCHEMA(self):
@property
@abc.abstractmethod
def SCHEMA(self) -> Schema:
"""An instance of Schema() representing the response structure"""

def to_object(self):
def to_object(self) -> Dict[str, Any]:
return _to_object(self.SCHEMA, self)


def _to_object(schema, data):
obj = {}
def _to_object(schema: Schema, data: Union[Struct, Dict[int, Any]]) -> Dict[str, Any]:
obj: Dict[str, Any] = {}
for idx, (name, _type) in enumerate(zip(schema.names, schema.fields)):
if isinstance(data, Struct):
val = data.get_item(name)
Expand All @@ -116,7 +139,7 @@ def _to_object(schema, data):
if isinstance(_type, Schema):
obj[name] = _to_object(_type, val)
elif isinstance(_type, Array):
if isinstance(_type.array_of, (Array, Schema)):
if isinstance(_type.array_of, Schema):
obj[name] = [_to_object(_type.array_of, x) for x in val]
else:
obj[name] = val
Expand Down
6 changes: 3 additions & 3 deletions aiokafka/protocol/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ class FetchRequest_v7(Request):
),
(
"forgotten_topics_data",
Array(("topic", String), ("partitions", Array(Int32))),
Array(("topic", String("utf-8")), ("partitions", Array(Int32))),
),
)

Expand Down Expand Up @@ -428,7 +428,7 @@ class FetchRequest_v9(Request):
(
"forgotten_topics_data",
Array(
("topic", String),
("topic", String("utf-8")),
("partitions", Array(Int32)),
),
),
Expand Down Expand Up @@ -480,7 +480,7 @@ class FetchRequest_v11(Request):
),
(
"forgotten_topics_data",
Array(("topic", String), ("partitions", Array(Int32))),
Array(("topic", String("utf-8")), ("partitions", Array(Int32))),
),
("rack_id", String("utf-8")),
)
Expand Down
Loading

0 comments on commit 1862620

Please sign in to comment.