Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add typing to aiokafka/protocol/* #999

Merged
merged 16 commits into from
Apr 21, 2024
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ FORMATTED_AREAS=\
aiokafka/helpers.py \
aiokafka/structs.py \
aiokafka/util.py \
aiokafka/protocol/ \
tests/test_codec.py \
tests/test_helpers.py

Expand Down
12 changes: 8 additions & 4 deletions aiokafka/protocol/abstract.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import abc
from io import BytesIO
from typing import Generic, TypeVar

T = TypeVar("T")

class AbstractType(metaclass=abc.ABCMeta):

class AbstractType(Generic[T], metaclass=abc.ABCMeta):
@classmethod
@abc.abstractmethod
def encode(cls, value): ...
def encode(cls, value: T) -> bytes: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.

@classmethod
@abc.abstractmethod
def decode(cls, data): ...
def decode(cls, data: BytesIO) -> T: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.

@classmethod
def repr(cls, value):
def repr(cls, value: T) -> str:
return repr(value)
13 changes: 10 additions & 3 deletions aiokafka/protocol/admin.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Dict, Iterable, Optional, Tuple

from .api import Request, Response
from .types import (
Array,
Expand Down Expand Up @@ -429,8 +431,8 @@ class DescribeGroupsResponse_v3(Response):
("member_assignment", Bytes),
),
),
("authorized_operations", Int32),
),
("authorized_operations", Int32),
dimastbk marked this conversation as resolved.
Show resolved Hide resolved
),
)

Expand Down Expand Up @@ -1119,7 +1121,7 @@ class DeleteGroupsRequest_v1(Request):
DeleteGroupsResponse = [DeleteGroupsResponse_v0, DeleteGroupsResponse_v1]


class DescribeClientQuotasResponse_v0(Request):
class DescribeClientQuotasResponse_v0(Response):
API_KEY = 48
API_VERSION = 0
SCHEMA = Schema(
Expand Down Expand Up @@ -1385,7 +1387,12 @@ class DeleteRecordsRequest_v2(Request):
("tags", TaggedFields),
)

def __init__(self, topics, timeout_ms, tags=None):
def __init__(
self,
topics: Iterable[Tuple[str, Iterable[Tuple[int, int]]]],
timeout_ms: int,
tags: Optional[Dict[int, bytes]] = None,
) -> None:
super().__init__(
[
(
Expand Down
77 changes: 50 additions & 27 deletions aiokafka/protocol/api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from __future__ import annotations

import abc
from io import BytesIO
from typing import Any, ClassVar, Dict, Optional, Type, Union

from .struct import Struct
from .types import Array, Int16, Int32, Schema, String, TaggedFields
Expand All @@ -12,7 +16,9 @@ class RequestHeader_v0(Struct):
("client_id", String("utf-8")),
)

def __init__(self, request, correlation_id=0, client_id="aiokafka"):
def __init__(
self, request: Request, correlation_id: int = 0, client_id: str = "aiokafka"
) -> None:
super().__init__(
request.API_KEY, request.API_VERSION, correlation_id, client_id
)
Expand All @@ -28,7 +34,13 @@ class RequestHeader_v1(Struct):
("tags", TaggedFields),
)

def __init__(self, request, correlation_id=0, client_id="aiokafka", tags=None):
def __init__(
self,
request: Request,
correlation_id: int = 0,
client_id: str = "aiokafka",
tags: Optional[Dict[int, bytes]] = None,
):
super().__init__(
request.API_KEY, request.API_VERSION, correlation_id, client_id, tags or {}
)
Expand All @@ -48,32 +60,38 @@ class ResponseHeader_v1(Struct):


class Request(Struct, metaclass=abc.ABCMeta):
FLEXIBLE_VERSION = False
FLEXIBLE_VERSION: ClassVar[bool] = False

@abc.abstractproperty
def API_KEY(self):
@property
@abc.abstractmethod
def API_KEY(self) -> int:
"""Integer identifier for api request"""

@abc.abstractproperty
def API_VERSION(self):
@property
@abc.abstractmethod
def API_VERSION(self) -> int:
"""Integer of api request version"""

@abc.abstractproperty
def SCHEMA(self):
"""An instance of Schema() representing the request structure"""

@abc.abstractproperty
def RESPONSE_TYPE(self):
@property
@abc.abstractmethod
def RESPONSE_TYPE(self) -> Type[Response]:
"""The Response class associated with the api request"""

def expect_response(self):
@property
@abc.abstractmethod
def SCHEMA(self) -> Schema:
"""An instance of Schema() representing the request structure"""

def expect_response(self) -> bool:
"""Override this method if an api request does not always generate a response"""
return True

def to_object(self):
def to_object(self) -> Dict[str, Any]:
return _to_object(self.SCHEMA, self)

def build_request_header(self, correlation_id, client_id):
def build_request_header(
self, correlation_id: int, client_id: str
) -> Union[RequestHeader_v0, RequestHeader_v1]:
if self.FLEXIBLE_VERSION:
return RequestHeader_v1(
self, correlation_id=correlation_id, client_id=client_id
Expand All @@ -82,31 +100,36 @@ def build_request_header(self, correlation_id, client_id):
self, correlation_id=correlation_id, client_id=client_id
)

def parse_response_header(self, read_buffer):
def parse_response_header(
self, read_buffer: Union[BytesIO, bytes]
) -> Union[ResponseHeader_v0, ResponseHeader_v1]:
if self.FLEXIBLE_VERSION:
return ResponseHeader_v1.decode(read_buffer)
return ResponseHeader_v0.decode(read_buffer)


class Response(Struct, metaclass=abc.ABCMeta):
@abc.abstractproperty
def API_KEY(self):
@property
@abc.abstractmethod
def API_KEY(self) -> int:
"""Integer identifier for api request/response"""

@abc.abstractproperty
def API_VERSION(self):
@property
@abc.abstractmethod
def API_VERSION(self) -> int:
"""Integer of api request/response version"""

@abc.abstractproperty
def SCHEMA(self):
@property
@abc.abstractmethod
def SCHEMA(self) -> Schema:
"""An instance of Schema() representing the response structure"""

def to_object(self):
def to_object(self) -> Dict[str, Any]:
return _to_object(self.SCHEMA, self)


def _to_object(schema, data):
obj = {}
def _to_object(schema: Schema, data: Union[Struct, Dict[int, Any]]) -> Dict[str, Any]:
obj: Dict[str, Any] = {}
for idx, (name, _type) in enumerate(zip(schema.names, schema.fields)):
if isinstance(data, Struct):
val = data.get_item(name)
Expand All @@ -116,7 +139,7 @@ def _to_object(schema, data):
if isinstance(_type, Schema):
obj[name] = _to_object(_type, val)
elif isinstance(_type, Array):
if isinstance(_type.array_of, (Array, Schema)):
if isinstance(_type.array_of, Schema):
dimastbk marked this conversation as resolved.
Show resolved Hide resolved
obj[name] = [_to_object(_type.array_of, x) for x in val]
else:
obj[name] = val
Expand Down
6 changes: 3 additions & 3 deletions aiokafka/protocol/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ class FetchRequest_v7(Request):
),
(
"forgotten_topics_data",
Array(("topic", String), ("partitions", Array(Int32))),
Array(("topic", String("utf-8")), ("partitions", Array(Int32))),
ods marked this conversation as resolved.
Show resolved Hide resolved
),
)

Expand Down Expand Up @@ -428,7 +428,7 @@ class FetchRequest_v9(Request):
(
"forgotten_topics_data",
Array(
("topic", String),
("topic", String("utf-8")),
("partitions", Array(Int32)),
),
),
Expand Down Expand Up @@ -480,7 +480,7 @@ class FetchRequest_v11(Request):
),
(
"forgotten_topics_data",
Array(("topic", String), ("partitions", Array(Int32))),
Array(("topic", String("utf-8")), ("partitions", Array(Int32))),
),
("rack_id", String("utf-8")),
)
Expand Down
Loading
Loading