Skip to content

Commit

Permalink
add typing to aiokafka/protocol/*
Browse files Browse the repository at this point in the history
  • Loading branch information
dimastbk committed Apr 13, 2024
1 parent 1855cde commit f261e92
Show file tree
Hide file tree
Showing 11 changed files with 389 additions and 207 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ FORMATTED_AREAS=\
aiokafka/helpers.py \
aiokafka/structs.py \
aiokafka/util.py \
aiokafka/protocol/ \
tests/test_codec.py \
tests/test_helpers.py

Expand Down
12 changes: 8 additions & 4 deletions aiokafka/protocol/abstract.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import abc
from io import BytesIO
from typing import Generic, TypeVar

T = TypeVar("T")

class AbstractType(metaclass=abc.ABCMeta):

class AbstractType(Generic[T], metaclass=abc.ABCMeta):
@classmethod
@abc.abstractmethod
def encode(cls, value): ...
def encode(cls, value: T) -> bytes: ...

@classmethod
@abc.abstractmethod
def decode(cls, data): ...
def decode(cls, data: BytesIO) -> T: ...

@classmethod
def repr(cls, value):
def repr(cls, value: T) -> str:
return repr(value)
14 changes: 11 additions & 3 deletions aiokafka/protocol/admin.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from collections.abc import Iterable
from typing import Dict, Optional, Tuple

from .api import Request, Response
from .types import (
Array,
Expand Down Expand Up @@ -429,8 +432,8 @@ class DescribeGroupsResponse_v3(Response):
("member_assignment", Bytes),
),
),
("authorized_operations", Int32),
),
("authorized_operations", Int32),
),
)

Expand Down Expand Up @@ -1119,7 +1122,7 @@ class DeleteGroupsRequest_v1(Request):
DeleteGroupsResponse = [DeleteGroupsResponse_v0, DeleteGroupsResponse_v1]


class DescribeClientQuotasResponse_v0(Request):
class DescribeClientQuotasResponse_v0(Response):
API_KEY = 48
API_VERSION = 0
SCHEMA = Schema(
Expand Down Expand Up @@ -1385,7 +1388,12 @@ class DeleteRecordsRequest_v2(Request):
("tags", TaggedFields),
)

def __init__(self, topics, timeout_ms, tags=None):
def __init__(
self,
topics: Iterable[Tuple[str, Iterable[Tuple[int, int]]]],
timeout_ms: int,
tags: Optional[Dict[int, bytes]] = None,
) -> None:
super().__init__(
[
(
Expand Down
87 changes: 61 additions & 26 deletions aiokafka/protocol/api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import abc
from io import BytesIO
from typing import Any, ClassVar, Dict, Optional, Type, Union

from .struct import Struct
from .types import Array, Int16, Int32, Schema, String, TaggedFields
Expand All @@ -12,7 +14,9 @@ class RequestHeader_v0(Struct):
("client_id", String("utf-8")),
)

def __init__(self, request, correlation_id=0, client_id="aiokafka"):
def __init__(
self, request: "Request", correlation_id: int = 0, client_id: str = "aiokafka"
) -> None:
super().__init__(
request.API_KEY, request.API_VERSION, correlation_id, client_id
)
Expand All @@ -28,7 +32,13 @@ class RequestHeader_v1(Struct):
("tags", TaggedFields),
)

def __init__(self, request, correlation_id=0, client_id="aiokafka", tags=None):
def __init__(
self,
request: "Request",
correlation_id: int = 0,
client_id: str = "aiokafka",
tags: Optional[Dict[int, bytes]] = None,
):
super().__init__(
request.API_KEY, request.API_VERSION, correlation_id, client_id, tags or {}
)
Expand All @@ -48,32 +58,46 @@ class ResponseHeader_v1(Struct):


class Request(Struct, metaclass=abc.ABCMeta):
FLEXIBLE_VERSION = False
FLEXIBLE_VERSION: ClassVar[bool] = False

@abc.abstractproperty
def API_KEY(self):
@property
@abc.abstractmethod
def API_KEY(self) -> int: # pyright:ignore[reportRedeclaration]
"""Integer identifier for api request"""

@abc.abstractproperty
def API_VERSION(self):
API_KEY: ClassVar[int] # type: ignore[no-redef]

@property
@abc.abstractmethod
def API_VERSION(self) -> int: # pyright:ignore[reportRedeclaration]
"""Integer of api request version"""

@abc.abstractproperty
def SCHEMA(self):
"""An instance of Schema() representing the request structure"""
API_VERSION: ClassVar[int] # type: ignore[no-redef]

@abc.abstractproperty
def RESPONSE_TYPE(self):
@property
@abc.abstractmethod
def RESPONSE_TYPE(self) -> Type["Response"]: # pyright:ignore[reportRedeclaration]
"""The Response class associated with the api request"""

def expect_response(self):
RESPONSE_TYPE: ClassVar[Type["Response"]] # type: ignore[no-redef]

@property
@abc.abstractmethod
def SCHEMA(self) -> Schema: # pyright:ignore[reportRedeclaration]
"""An instance of Schema() representing the request structure"""

SCHEMA: ClassVar[Schema] # type: ignore[no-redef]

def expect_response(self) -> bool:
"""Override this method if an api request does not always generate a response"""
return True

def to_object(self):
def to_object(self) -> Dict[str, Any]:
return _to_object(self.SCHEMA, self)

def build_request_header(self, correlation_id, client_id):
def build_request_header(
self, correlation_id: int, client_id: str
) -> Union[RequestHeader_v0, RequestHeader_v1]:
if self.FLEXIBLE_VERSION:
return RequestHeader_v1(
self, correlation_id=correlation_id, client_id=client_id
Expand All @@ -82,31 +106,42 @@ def build_request_header(self, correlation_id, client_id):
self, correlation_id=correlation_id, client_id=client_id
)

def parse_response_header(self, read_buffer):
def parse_response_header(
self, read_buffer: Union[BytesIO, bytes]
) -> Union[ResponseHeader_v0, ResponseHeader_v1]:
if self.FLEXIBLE_VERSION:
return ResponseHeader_v1.decode(read_buffer)
return ResponseHeader_v0.decode(read_buffer)


class Response(Struct, metaclass=abc.ABCMeta):
@abc.abstractproperty
def API_KEY(self):
@property
@abc.abstractmethod
def API_KEY(self) -> int: # pyright:ignore[reportRedeclaration]
"""Integer identifier for api request/response"""

@abc.abstractproperty
def API_VERSION(self):
API_KEY: ClassVar[int] # type: ignore[no-redef]

@property
@abc.abstractmethod
def API_VERSION(self) -> int: # pyright:ignore[reportRedeclaration]
"""Integer of api request/response version"""

@abc.abstractproperty
def SCHEMA(self):
API_VERSION: ClassVar[int] # type: ignore[no-redef]

@property
@abc.abstractmethod
def SCHEMA(self) -> Schema: # pyright:ignore[reportRedeclaration]
"""An instance of Schema() representing the response structure"""

def to_object(self):
SCHEMA: ClassVar[Schema] # type: ignore[no-redef]

def to_object(self) -> Dict[str, Any]:
return _to_object(self.SCHEMA, self)


def _to_object(schema, data):
obj = {}
def _to_object(schema: Schema, data: Union[Struct, Dict[int, Any]]) -> Dict[str, Any]:
obj: Dict[str, Any] = {}
for idx, (name, _type) in enumerate(zip(schema.names, schema.fields)):
if isinstance(data, Struct):
val = data.get_item(name)
Expand All @@ -116,7 +151,7 @@ def _to_object(schema, data):
if isinstance(_type, Schema):
obj[name] = _to_object(_type, val)
elif isinstance(_type, Array):
if isinstance(_type.array_of, (Array, Schema)):
if isinstance(_type.array_of, Schema):
obj[name] = [_to_object(_type.array_of, x) for x in val]
else:
obj[name] = val
Expand Down
6 changes: 3 additions & 3 deletions aiokafka/protocol/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ class FetchRequest_v7(Request):
),
(
"forgotten_topics_data",
Array(("topic", String), ("partitions", Array(Int32))),
Array(("topic", String("utf-8")), ("partitions", Array(Int32))),
),
)

Expand Down Expand Up @@ -428,7 +428,7 @@ class FetchRequest_v9(Request):
(
"forgotten_topics_data",
Array(
("topic", String),
("topic", String("utf-8")),
("partitions", Array(Int32)),
),
),
Expand Down Expand Up @@ -480,7 +480,7 @@ class FetchRequest_v11(Request):
),
(
"forgotten_topics_data",
Array(("topic", String), ("partitions", Array(Int32))),
Array(("topic", String("utf-8")), ("partitions", Array(Int32))),
),
("rack_id", String("utf-8")),
)
Expand Down
Loading

0 comments on commit f261e92

Please sign in to comment.