From e4b20ab5900690e4854ec320a7de6934f3b4222f Mon Sep 17 00:00:00 2001 From: Dmitriy Date: Sat, 13 Apr 2024 23:28:26 +0500 Subject: [PATCH] add typing to aiokafka/protocol/* --- Makefile | 1 + aiokafka/protocol/abstract.py | 12 +- aiokafka/protocol/admin.py | 126 +++++++++--------- aiokafka/protocol/api.py | 87 +++++++++---- aiokafka/protocol/commit.py | 20 +-- aiokafka/protocol/coordination.py | 4 +- aiokafka/protocol/fetch.py | 14 +- aiokafka/protocol/group.py | 20 +-- aiokafka/protocol/message.py | 159 ++++++++++++++++------- aiokafka/protocol/metadata.py | 34 ++--- aiokafka/protocol/offset.py | 8 +- aiokafka/protocol/produce.py | 55 ++++---- aiokafka/protocol/struct.py | 24 ++-- aiokafka/protocol/transaction.py | 20 +-- aiokafka/protocol/types.py | 207 ++++++++++++++++++------------ tests/test_protocol.py | 16 +-- 16 files changed, 493 insertions(+), 314 deletions(-) diff --git a/Makefile b/Makefile index ec298661..ca394ed2 100644 --- a/Makefile +++ b/Makefile @@ -11,6 +11,7 @@ FORMATTED_AREAS=\ aiokafka/helpers.py \ aiokafka/structs.py \ aiokafka/util.py \ + aiokafka/protocol/ \ tests/test_codec.py \ tests/test_helpers.py diff --git a/aiokafka/protocol/abstract.py b/aiokafka/protocol/abstract.py index 117d058e..c466357e 100644 --- a/aiokafka/protocol/abstract.py +++ b/aiokafka/protocol/abstract.py @@ -1,15 +1,19 @@ import abc +from io import BytesIO +from typing import Generic, TypeVar +T = TypeVar("T") -class AbstractType(metaclass=abc.ABCMeta): + +class AbstractType(Generic[T], metaclass=abc.ABCMeta): @classmethod @abc.abstractmethod - def encode(cls, value): ... + def encode(cls, value: T) -> bytes: ... @classmethod @abc.abstractmethod - def decode(cls, data): ... + def decode(cls, data: BytesIO) -> T: ... @classmethod - def repr(cls, value): + def repr(cls, value: T) -> str: return repr(value) diff --git a/aiokafka/protocol/admin.py b/aiokafka/protocol/admin.py index 2bb17eeb..ace55553 100644 --- a/aiokafka/protocol/admin.py +++ b/aiokafka/protocol/admin.py @@ -1,3 +1,6 @@ +from collections.abc import Iterable +from typing import Dict, Optional, Tuple + from .api import Request, Response from .types import ( Array, @@ -68,16 +71,16 @@ class ApiVersionRequest_v2(Request): SCHEMA = ApiVersionRequest_v0.SCHEMA -ApiVersionRequest = [ +ApiVersionRequest = ( ApiVersionRequest_v0, ApiVersionRequest_v1, ApiVersionRequest_v2, -] -ApiVersionResponse = [ +) +ApiVersionResponse = ( ApiVersionResponse_v0, ApiVersionResponse_v1, ApiVersionResponse_v2, -] +) class CreateTopicsResponse_v0(Response): @@ -196,18 +199,18 @@ class CreateTopicsRequest_v3(Request): SCHEMA = CreateTopicsRequest_v1.SCHEMA -CreateTopicsRequest = [ +CreateTopicsRequest = ( CreateTopicsRequest_v0, CreateTopicsRequest_v1, CreateTopicsRequest_v2, CreateTopicsRequest_v3, -] -CreateTopicsResponse = [ +) +CreateTopicsResponse = ( CreateTopicsResponse_v0, CreateTopicsResponse_v1, CreateTopicsResponse_v2, CreateTopicsResponse_v3, -] +) class DeleteTopicsResponse_v0(Response): @@ -267,18 +270,18 @@ class DeleteTopicsRequest_v3(Request): SCHEMA = DeleteTopicsRequest_v0.SCHEMA -DeleteTopicsRequest = [ +DeleteTopicsRequest = ( DeleteTopicsRequest_v0, DeleteTopicsRequest_v1, DeleteTopicsRequest_v2, DeleteTopicsRequest_v3, -] -DeleteTopicsResponse = [ +) +DeleteTopicsResponse = ( DeleteTopicsResponse_v0, DeleteTopicsResponse_v1, DeleteTopicsResponse_v2, DeleteTopicsResponse_v3, -] +) class ListGroupsResponse_v0(Response): @@ -333,16 +336,16 @@ class ListGroupsRequest_v2(Request): SCHEMA = ListGroupsRequest_v0.SCHEMA -ListGroupsRequest = [ +ListGroupsRequest = ( ListGroupsRequest_v0, ListGroupsRequest_v1, ListGroupsRequest_v2, -] -ListGroupsResponse = [ +) +ListGroupsResponse = ( ListGroupsResponse_v0, ListGroupsResponse_v1, ListGroupsResponse_v2, -] +) class DescribeGroupsResponse_v0(Response): @@ -429,8 +432,8 @@ class DescribeGroupsResponse_v3(Response): ("member_assignment", Bytes), ), ), + ("authorized_operations", Int32), ), - ("authorized_operations", Int32), ), ) @@ -465,18 +468,18 @@ class DescribeGroupsRequest_v3(Request): ) -DescribeGroupsRequest = [ +DescribeGroupsRequest = ( DescribeGroupsRequest_v0, DescribeGroupsRequest_v1, DescribeGroupsRequest_v2, DescribeGroupsRequest_v3, -] -DescribeGroupsResponse = [ +) +DescribeGroupsResponse = ( DescribeGroupsResponse_v0, DescribeGroupsResponse_v1, DescribeGroupsResponse_v2, DescribeGroupsResponse_v3, -] +) class SaslHandShakeResponse_v0(Response): @@ -507,8 +510,8 @@ class SaslHandShakeRequest_v1(Request): SCHEMA = SaslHandShakeRequest_v0.SCHEMA -SaslHandShakeRequest = [SaslHandShakeRequest_v0, SaslHandShakeRequest_v1] -SaslHandShakeResponse = [SaslHandShakeResponse_v0, SaslHandShakeResponse_v1] +SaslHandShakeRequest = (SaslHandShakeRequest_v0, SaslHandShakeRequest_v1) +SaslHandShakeResponse = (SaslHandShakeResponse_v0, SaslHandShakeResponse_v1) class DescribeAclsResponse_v0(Response): @@ -610,8 +613,8 @@ class DescribeAclsRequest_v2(Request): SCHEMA = DescribeAclsRequest_v1.SCHEMA -DescribeAclsRequest = [DescribeAclsRequest_v0, DescribeAclsRequest_v1] -DescribeAclsResponse = [DescribeAclsResponse_v0, DescribeAclsResponse_v1] +DescribeAclsRequest = (DescribeAclsRequest_v0, DescribeAclsRequest_v1) +DescribeAclsResponse = (DescribeAclsResponse_v0, DescribeAclsResponse_v1) class CreateAclsResponse_v0(Response): @@ -671,8 +674,8 @@ class CreateAclsRequest_v1(Request): ) -CreateAclsRequest = [CreateAclsRequest_v0, CreateAclsRequest_v1] -CreateAclsResponse = [CreateAclsResponse_v0, CreateAclsResponse_v1] +CreateAclsRequest = (CreateAclsRequest_v0, CreateAclsRequest_v1) +CreateAclsResponse = (CreateAclsResponse_v0, CreateAclsResponse_v1) class DeleteAclsResponse_v0(Response): @@ -771,8 +774,8 @@ class DeleteAclsRequest_v1(Request): ) -DeleteAclsRequest = [DeleteAclsRequest_v0, DeleteAclsRequest_v1] -DeleteAclsResponse = [DeleteAclsResponse_v0, DeleteAclsResponse_v1] +DeleteAclsRequest = (DeleteAclsRequest_v0, DeleteAclsRequest_v1) +DeleteAclsResponse = (DeleteAclsResponse_v0, DeleteAclsResponse_v1) class AlterConfigsResponse_v0(Response): @@ -828,8 +831,8 @@ class AlterConfigsRequest_v1(Request): SCHEMA = AlterConfigsRequest_v0.SCHEMA -AlterConfigsRequest = [AlterConfigsRequest_v0, AlterConfigsRequest_v1] -AlterConfigsResponse = [AlterConfigsResponse_v0, AlterConfigsRequest_v1] +AlterConfigsRequest = (AlterConfigsRequest_v0, AlterConfigsRequest_v1) +AlterConfigsResponse = (AlterConfigsResponse_v0, AlterConfigsRequest_v1) class DescribeConfigsResponse_v0(Response): @@ -969,16 +972,16 @@ class DescribeConfigsRequest_v2(Request): SCHEMA = DescribeConfigsRequest_v1.SCHEMA -DescribeConfigsRequest = [ +DescribeConfigsRequest = ( DescribeConfigsRequest_v0, DescribeConfigsRequest_v1, DescribeConfigsRequest_v2, -] -DescribeConfigsResponse = [ +) +DescribeConfigsResponse = ( DescribeConfigsResponse_v0, DescribeConfigsResponse_v1, DescribeConfigsResponse_v2, -] +) class SaslAuthenticateResponse_v0(Response): @@ -1016,14 +1019,14 @@ class SaslAuthenticateRequest_v1(Request): SCHEMA = SaslAuthenticateRequest_v0.SCHEMA -SaslAuthenticateRequest = [ +SaslAuthenticateRequest = ( SaslAuthenticateRequest_v0, SaslAuthenticateRequest_v1, -] -SaslAuthenticateResponse = [ +) +SaslAuthenticateResponse = ( SaslAuthenticateResponse_v0, SaslAuthenticateResponse_v1, -] +) class CreatePartitionsResponse_v0(Response): @@ -1075,14 +1078,14 @@ class CreatePartitionsRequest_v1(Request): RESPONSE_TYPE = CreatePartitionsResponse_v1 -CreatePartitionsRequest = [ +CreatePartitionsRequest = ( CreatePartitionsRequest_v0, CreatePartitionsRequest_v1, -] -CreatePartitionsResponse = [ +) +CreatePartitionsResponse = ( CreatePartitionsResponse_v0, CreatePartitionsResponse_v1, -] +) class DeleteGroupsResponse_v0(Response): @@ -1114,12 +1117,12 @@ class DeleteGroupsRequest_v1(Request): SCHEMA = DeleteGroupsRequest_v0.SCHEMA -DeleteGroupsRequest = [DeleteGroupsRequest_v0, DeleteGroupsRequest_v1] +DeleteGroupsRequest = (DeleteGroupsRequest_v0, DeleteGroupsRequest_v1) -DeleteGroupsResponse = [DeleteGroupsResponse_v0, DeleteGroupsResponse_v1] +DeleteGroupsResponse = (DeleteGroupsResponse_v0, DeleteGroupsResponse_v1) -class DescribeClientQuotasResponse_v0(Request): +class DescribeClientQuotasResponse_v0(Response): API_KEY = 48 API_VERSION = 0 SCHEMA = Schema( @@ -1159,13 +1162,9 @@ class DescribeClientQuotasRequest_v0(Request): ) -DescribeClientQuotasRequest = [ - DescribeClientQuotasRequest_v0, -] +DescribeClientQuotasRequest = (DescribeClientQuotasRequest_v0,) -DescribeClientQuotasResponse = [ - DescribeClientQuotasResponse_v0, -] +DescribeClientQuotasResponse = (DescribeClientQuotasResponse_v0,) class AlterPartitionReassignmentsResponse_v0(Response): @@ -1221,9 +1220,9 @@ class AlterPartitionReassignmentsRequest_v0(Request): ) -AlterPartitionReassignmentsRequest = [AlterPartitionReassignmentsRequest_v0] +AlterPartitionReassignmentsRequest = (AlterPartitionReassignmentsRequest_v0,) -AlterPartitionReassignmentsResponse = [AlterPartitionReassignmentsResponse_v0] +AlterPartitionReassignmentsResponse = (AlterPartitionReassignmentsResponse_v0,) class ListPartitionReassignmentsResponse_v0(Response): @@ -1273,9 +1272,9 @@ class ListPartitionReassignmentsRequest_v0(Request): ) -ListPartitionReassignmentsRequest = [ListPartitionReassignmentsRequest_v0] +ListPartitionReassignmentsRequest = (ListPartitionReassignmentsRequest_v0,) -ListPartitionReassignmentsResponse = [ListPartitionReassignmentsResponse_v0] +ListPartitionReassignmentsResponse = (ListPartitionReassignmentsResponse_v0,) class DeleteRecordsResponse_v0(Response): @@ -1385,7 +1384,12 @@ class DeleteRecordsRequest_v2(Request): ("tags", TaggedFields), ) - def __init__(self, topics, timeout_ms, tags=None): + def __init__( + self, + topics: Iterable[Tuple[str, Iterable[Tuple[int, int]]]], + timeout_ms: int, + tags: Optional[Dict[int, bytes]] = None, + ) -> None: super().__init__( [ ( @@ -1403,14 +1407,14 @@ def __init__(self, topics, timeout_ms, tags=None): ) -DeleteRecordsRequest = [ +DeleteRecordsRequest = ( DeleteRecordsRequest_v0, DeleteRecordsRequest_v1, DeleteRecordsRequest_v2, -] +) -DeleteRecordsResponse = [ +DeleteRecordsResponse = ( DeleteRecordsResponse_v0, DeleteRecordsResponse_v1, DeleteRecordsResponse_v2, -] +) diff --git a/aiokafka/protocol/api.py b/aiokafka/protocol/api.py index 77a7a485..bc5b2fc9 100644 --- a/aiokafka/protocol/api.py +++ b/aiokafka/protocol/api.py @@ -1,4 +1,6 @@ import abc +from io import BytesIO +from typing import Any, ClassVar, Dict, Optional, Type, Union from .struct import Struct from .types import Array, Int16, Int32, Schema, String, TaggedFields @@ -12,7 +14,9 @@ class RequestHeader_v0(Struct): ("client_id", String("utf-8")), ) - def __init__(self, request, correlation_id=0, client_id="aiokafka"): + def __init__( + self, request: "Request", correlation_id: int = 0, client_id: str = "aiokafka" + ) -> None: super().__init__( request.API_KEY, request.API_VERSION, correlation_id, client_id ) @@ -28,7 +32,13 @@ class RequestHeader_v1(Struct): ("tags", TaggedFields), ) - def __init__(self, request, correlation_id=0, client_id="aiokafka", tags=None): + def __init__( + self, + request: "Request", + correlation_id: int = 0, + client_id: str = "aiokafka", + tags: Optional[Dict[int, bytes]] = None, + ): super().__init__( request.API_KEY, request.API_VERSION, correlation_id, client_id, tags or {} ) @@ -48,32 +58,46 @@ class ResponseHeader_v1(Struct): class Request(Struct, metaclass=abc.ABCMeta): - FLEXIBLE_VERSION = False + FLEXIBLE_VERSION: ClassVar[bool] = False - @abc.abstractproperty - def API_KEY(self): + @property + @abc.abstractmethod + def API_KEY(self) -> int: # pyright:ignore[reportRedeclaration] """Integer identifier for api request""" - @abc.abstractproperty - def API_VERSION(self): + API_KEY: ClassVar[int] # type: ignore[no-redef] + + @property + @abc.abstractmethod + def API_VERSION(self) -> int: # pyright:ignore[reportRedeclaration] """Integer of api request version""" - @abc.abstractproperty - def SCHEMA(self): - """An instance of Schema() representing the request structure""" + API_VERSION: ClassVar[int] # type: ignore[no-redef] - @abc.abstractproperty - def RESPONSE_TYPE(self): + @property + @abc.abstractmethod + def RESPONSE_TYPE(self) -> Type["Response"]: # pyright:ignore[reportRedeclaration] """The Response class associated with the api request""" - def expect_response(self): + RESPONSE_TYPE: ClassVar[Type["Response"]] # type: ignore[no-redef] + + @property + @abc.abstractmethod + def SCHEMA(self) -> Schema: # pyright:ignore[reportRedeclaration] + """An instance of Schema() representing the request structure""" + + SCHEMA: ClassVar[Schema] # type: ignore[no-redef] + + def expect_response(self) -> bool: """Override this method if an api request does not always generate a response""" return True - def to_object(self): + def to_object(self) -> Dict[str, Any]: return _to_object(self.SCHEMA, self) - def build_request_header(self, correlation_id, client_id): + def build_request_header( + self, correlation_id: int, client_id: str + ) -> Union[RequestHeader_v0, RequestHeader_v1]: if self.FLEXIBLE_VERSION: return RequestHeader_v1( self, correlation_id=correlation_id, client_id=client_id @@ -82,31 +106,42 @@ def build_request_header(self, correlation_id, client_id): self, correlation_id=correlation_id, client_id=client_id ) - def parse_response_header(self, read_buffer): + def parse_response_header( + self, read_buffer: Union[BytesIO, bytes] + ) -> Union[ResponseHeader_v0, ResponseHeader_v1]: if self.FLEXIBLE_VERSION: return ResponseHeader_v1.decode(read_buffer) return ResponseHeader_v0.decode(read_buffer) class Response(Struct, metaclass=abc.ABCMeta): - @abc.abstractproperty - def API_KEY(self): + @property + @abc.abstractmethod + def API_KEY(self) -> int: # pyright:ignore[reportRedeclaration] """Integer identifier for api request/response""" - @abc.abstractproperty - def API_VERSION(self): + API_KEY: ClassVar[int] # type: ignore[no-redef] + + @property + @abc.abstractmethod + def API_VERSION(self) -> int: # pyright:ignore[reportRedeclaration] """Integer of api request/response version""" - @abc.abstractproperty - def SCHEMA(self): + API_VERSION: ClassVar[int] # type: ignore[no-redef] + + @property + @abc.abstractmethod + def SCHEMA(self) -> Schema: # pyright:ignore[reportRedeclaration] """An instance of Schema() representing the response structure""" - def to_object(self): + SCHEMA: ClassVar[Schema] # type: ignore[no-redef] + + def to_object(self) -> Dict[str, Any]: return _to_object(self.SCHEMA, self) -def _to_object(schema, data): - obj = {} +def _to_object(schema: Schema, data: Union[Struct, Dict[int, Any]]) -> Dict[str, Any]: + obj: Dict[str, Any] = {} for idx, (name, _type) in enumerate(zip(schema.names, schema.fields)): if isinstance(data, Struct): val = data.get_item(name) @@ -116,7 +151,7 @@ def _to_object(schema, data): if isinstance(_type, Schema): obj[name] = _to_object(_type, val) elif isinstance(_type, Array): - if isinstance(_type.array_of, (Array, Schema)): + if isinstance(_type.array_of, Schema): obj[name] = [_to_object(_type.array_of, x) for x in val] else: obj[name] = val diff --git a/aiokafka/protocol/commit.py b/aiokafka/protocol/commit.py index b0fda8c3..6ba69f7a 100644 --- a/aiokafka/protocol/commit.py +++ b/aiokafka/protocol/commit.py @@ -127,18 +127,18 @@ class OffsetCommitRequest_v3(Request): SCHEMA = OffsetCommitRequest_v2.SCHEMA -OffsetCommitRequest = [ +OffsetCommitRequest = ( OffsetCommitRequest_v0, OffsetCommitRequest_v1, OffsetCommitRequest_v2, OffsetCommitRequest_v3, -] -OffsetCommitResponse = [ +) +OffsetCommitResponse = ( OffsetCommitResponse_v0, OffsetCommitResponse_v1, OffsetCommitResponse_v2, OffsetCommitResponse_v3, -] +) class OffsetFetchResponse_v0(Response): @@ -251,18 +251,18 @@ class OffsetFetchRequest_v3(Request): SCHEMA = OffsetFetchRequest_v2.SCHEMA -OffsetFetchRequest = [ +OffsetFetchRequest = ( OffsetFetchRequest_v0, OffsetFetchRequest_v1, OffsetFetchRequest_v2, OffsetFetchRequest_v3, -] -OffsetFetchResponse = [ +) +OffsetFetchResponse = ( OffsetFetchResponse_v0, OffsetFetchResponse_v1, OffsetFetchResponse_v2, OffsetFetchResponse_v3, -] +) class GroupCoordinatorResponse_v0(Response): @@ -308,5 +308,5 @@ class GroupCoordinatorRequest_v1(Request): ) -GroupCoordinatorRequest = [GroupCoordinatorRequest_v0, GroupCoordinatorRequest_v1] -GroupCoordinatorResponse = [GroupCoordinatorResponse_v0, GroupCoordinatorResponse_v1] +GroupCoordinatorRequest = (GroupCoordinatorRequest_v0, GroupCoordinatorRequest_v1) +GroupCoordinatorResponse = (GroupCoordinatorResponse_v0, GroupCoordinatorResponse_v1) diff --git a/aiokafka/protocol/coordination.py b/aiokafka/protocol/coordination.py index 9bf086ac..3e0a9088 100644 --- a/aiokafka/protocol/coordination.py +++ b/aiokafka/protocol/coordination.py @@ -40,5 +40,5 @@ class FindCoordinatorRequest_v1(Request): SCHEMA = Schema(("coordinator_key", String("utf-8")), ("coordinator_type", Int8)) -FindCoordinatorRequest = [FindCoordinatorRequest_v0, FindCoordinatorRequest_v1] -FindCoordinatorResponse = [FindCoordinatorResponse_v0, FindCoordinatorResponse_v1] +FindCoordinatorRequest = (FindCoordinatorRequest_v0, FindCoordinatorRequest_v1) +FindCoordinatorResponse = (FindCoordinatorResponse_v0, FindCoordinatorResponse_v1) diff --git a/aiokafka/protocol/fetch.py b/aiokafka/protocol/fetch.py index 56cbdd73..ccbddceb 100644 --- a/aiokafka/protocol/fetch.py +++ b/aiokafka/protocol/fetch.py @@ -376,7 +376,7 @@ class FetchRequest_v7(Request): ), ( "forgotten_topics_data", - Array(("topic", String), ("partitions", Array(Int32))), + Array(("topic", String("utf-8")), ("partitions", Array(Int32))), ), ) @@ -428,7 +428,7 @@ class FetchRequest_v9(Request): ( "forgotten_topics_data", Array( - ("topic", String), + ("topic", String("utf-8")), ("partitions", Array(Int32)), ), ), @@ -480,13 +480,13 @@ class FetchRequest_v11(Request): ), ( "forgotten_topics_data", - Array(("topic", String), ("partitions", Array(Int32))), + Array(("topic", String("utf-8")), ("partitions", Array(Int32))), ), ("rack_id", String("utf-8")), ) -FetchRequest = [ +FetchRequest = ( FetchRequest_v0, FetchRequest_v1, FetchRequest_v2, @@ -499,8 +499,8 @@ class FetchRequest_v11(Request): FetchRequest_v9, FetchRequest_v10, FetchRequest_v11, -] -FetchResponse = [ +) +FetchResponse = ( FetchResponse_v0, FetchResponse_v1, FetchResponse_v2, @@ -513,4 +513,4 @@ class FetchRequest_v11(Request): FetchResponse_v9, FetchResponse_v10, FetchResponse_v11, -] +) diff --git a/aiokafka/protocol/group.py b/aiokafka/protocol/group.py index 9e4efb41..a31ec04d 100644 --- a/aiokafka/protocol/group.py +++ b/aiokafka/protocol/group.py @@ -119,18 +119,18 @@ class JoinGroupRequest_v5(Request): UNKNOWN_MEMBER_ID = "" -JoinGroupRequest = [ +JoinGroupRequest = ( JoinGroupRequest_v0, JoinGroupRequest_v1, JoinGroupRequest_v2, JoinGroupRequest_v5, -] -JoinGroupResponse = [ +) +JoinGroupResponse = ( JoinGroupResponse_v0, JoinGroupResponse_v1, JoinGroupResponse_v2, JoinGroupResponse_v5, -] +) class ProtocolMetadata(Struct): @@ -199,8 +199,8 @@ class SyncGroupRequest_v3(Request): ) -SyncGroupRequest = [SyncGroupRequest_v0, SyncGroupRequest_v1, SyncGroupRequest_v3] -SyncGroupResponse = [SyncGroupResponse_v0, SyncGroupResponse_v1, SyncGroupResponse_v3] +SyncGroupRequest = (SyncGroupRequest_v0, SyncGroupRequest_v1, SyncGroupRequest_v3) +SyncGroupResponse = (SyncGroupResponse_v0, SyncGroupResponse_v1, SyncGroupResponse_v3) class MemberAssignment(Struct): @@ -241,8 +241,8 @@ class HeartbeatRequest_v1(Request): SCHEMA = HeartbeatRequest_v0.SCHEMA -HeartbeatRequest = [HeartbeatRequest_v0, HeartbeatRequest_v1] -HeartbeatResponse = [HeartbeatResponse_v0, HeartbeatResponse_v1] +HeartbeatRequest = (HeartbeatRequest_v0, HeartbeatRequest_v1) +HeartbeatResponse = (HeartbeatResponse_v0, HeartbeatResponse_v1) class LeaveGroupResponse_v0(Response): @@ -271,5 +271,5 @@ class LeaveGroupRequest_v1(Request): SCHEMA = LeaveGroupRequest_v0.SCHEMA -LeaveGroupRequest = [LeaveGroupRequest_v0, LeaveGroupRequest_v1] -LeaveGroupResponse = [LeaveGroupResponse_v0, LeaveGroupResponse_v1] +LeaveGroupRequest = (LeaveGroupRequest_v0, LeaveGroupRequest_v1) +LeaveGroupResponse = (LeaveGroupResponse_v0, LeaveGroupResponse_v1) diff --git a/aiokafka/protocol/message.py b/aiokafka/protocol/message.py index 31993fe6..42b81078 100644 --- a/aiokafka/protocol/message.py +++ b/aiokafka/protocol/message.py @@ -1,6 +1,10 @@ import io import time from binascii import crc32 +from collections.abc import Iterable +from typing import List, Literal, Optional, Tuple, Union, cast, overload + +from typing_extensions import Self from aiokafka.codec import ( gzip_decode, @@ -15,7 +19,7 @@ from aiokafka.errors import UnsupportedCodecError from .struct import Struct -from .types import AbstractType, Bytes, Int8, Int32, Int64, Schema, UInt32 +from .types import Bytes, Int8, Int32, Int64, Schema, Type, UInt32 class Message(Struct): @@ -47,7 +51,40 @@ class Message(Struct): 22 # crc(4), magic(1), attributes(1), timestamp(8), key+value size(4*2) ) - def __init__(self, value, key=None, magic=0, attributes=0, crc=0, timestamp=None): + @overload + def __init__( + self, + *, + value: Optional[bytes], + key: Optional[bytes] = None, + magic: Literal[0], + attributes: int = 0, + crc: int = 0, + timestamp: None = None, + ) -> None: ... + + @overload + def __init__( + self, + *, + value: Optional[bytes], + key: Optional[bytes] = None, + magic: Literal[1], + attributes: int = 0, + crc: int = 0, + timestamp: Optional[int] = None, + ) -> None: ... + + def __init__( + self, + *, + value: Optional[bytes], + key: Optional[bytes] = None, + magic: Literal[0, 1] = 0, + attributes: int = 0, + crc: int = 0, + timestamp: Optional[int] = None, + ) -> None: assert value is None or isinstance(value, bytes), "value must be bytes" assert key is None or isinstance(key, bytes), "key must be bytes" assert magic > 0 or timestamp is None, "timestamp not supported in v0" @@ -57,14 +94,14 @@ def __init__(self, value, key=None, magic=0, attributes=0, crc=0, timestamp=None timestamp = int(time.time() * 1000) self.timestamp = timestamp self.crc = crc - self._validated_crc = None + self._validated_crc: Optional[int] = None self.magic = magic self.attributes = attributes self.key = key self.value = value @property - def timestamp_type(self): + def timestamp_type(self) -> Optional[Literal[0, 1]]: """0 for CreateTime; 1 for LogAppendTime; None if unsupported. Value is determined by broker; produced messages should always set to 0 @@ -77,55 +114,71 @@ def timestamp_type(self): else: return 0 - def encode(self, recalc_crc=True): + def encode(self, recalc_crc: bool = True) -> bytes: version = self.magic if version == 1: - fields = ( - self.crc, - self.magic, - self.attributes, - self.timestamp, - self.key, - self.value, + message = Message.SCHEMAS[version].encode( + ( + self.crc, + self.magic, + self.attributes, + self.timestamp, + self.key, + self.value, + ) ) elif version == 0: - fields = (self.crc, self.magic, self.attributes, self.key, self.value) + message = Message.SCHEMAS[version].encode( + (self.crc, self.magic, self.attributes, self.key, self.value) + ) else: raise ValueError(f"Unrecognized message version: {version}") - message = Message.SCHEMAS[version].encode(fields) if not recalc_crc: return message self.crc = crc32(message[4:]) - crc_field = self.SCHEMAS[version].fields[0] + crc_field = cast(Type[UInt32], self.SCHEMAS[version].fields[0]) return crc_field.encode(self.crc) + message[4:] @classmethod - def decode(cls, data): - _validated_crc = None + def decode(cls, data: Union[io.BytesIO, bytes]) -> Self: + _validated_crc: Optional[int] = None if isinstance(data, bytes): _validated_crc = crc32(data[4:]) data = io.BytesIO(data) # Partial decode required to determine message version - base_fields = cls.SCHEMAS[0].fields[0:3] + base_fields = cast( + Tuple[Type[UInt32], Type[Int8], Type[Int8]], cls.SCHEMAS[0].fields[0:3] + ) crc, magic, attributes = (field.decode(data) for field in base_fields) remaining = cls.SCHEMAS[magic].fields[3:] - fields = [field.decode(data) for field in remaining] + fields = tuple(field.decode(data) for field in remaining) + magic = cast(Literal[0, 1], magic) if magic == 1: - timestamp = fields[0] + fields = cast(Tuple[int, Optional[bytes], Optional[bytes]], fields) + msg = cls( + value=fields[-1], + key=fields[-2], + magic=magic, + attributes=attributes, + crc=crc, + timestamp=fields[0], + ) + elif magic == 0: + fields = cast(Tuple[Optional[bytes], Optional[bytes]], fields) + msg = cls( + value=fields[-1], + key=fields[-2], + magic=magic, + attributes=attributes, + crc=crc, + ) else: - timestamp = None - msg = cls( - fields[-1], - key=fields[-2], - magic=magic, - attributes=attributes, - crc=crc, - timestamp=timestamp, - ) + raise ValueError(f"Unrecognized message version: {magic}") + msg._validated_crc = _validated_crc return msg - def validate_crc(self): + def validate_crc(self) -> bool: if self._validated_crc is None: raw_msg = self.encode(recalc_crc=False) self._validated_crc = crc32(raw_msg[4:]) @@ -133,10 +186,13 @@ def validate_crc(self): return True return False - def is_compressed(self): + def is_compressed(self) -> bool: return self.attributes & self.CODEC_MASK != 0 - def decompress(self): + def decompress( + self, + ) -> List[Union[Tuple[int, int, "Message"], Tuple[None, None, "PartialMessage"]]]: + assert self.value is not None codec = self.attributes & self.CODEC_MASK assert codec in ( self.CODEC_GZIP, @@ -167,21 +223,25 @@ def decompress(self): return MessageSet.decode(raw_bytes, bytes_to_read=len(raw_bytes)) - def __hash__(self): + def __hash__(self) -> int: return hash(self.encode(recalc_crc=False)) class PartialMessage(bytes): - def __repr__(self): - return f"PartialMessage({self})" + def __repr__(self) -> str: + return f"PartialMessage({self!r})" -class MessageSet(AbstractType): +class MessageSet: ITEM = Schema(("offset", Int64), ("message", Bytes)) HEADER_SIZE = 12 # offset + message_size @classmethod - def encode(cls, items, prepend_size=True): + def encode( + cls, + items: Union[io.BytesIO, Iterable[Tuple[int, bytes]]], + prepend_size: bool = True, + ) -> bytes: # RecordAccumulator encodes messagesets internally if isinstance(items, io.BytesIO): size = Int32.decode(items) @@ -191,7 +251,7 @@ def encode(cls, items, prepend_size=True): size += 4 return items.read(size) - encoded_values = [] + encoded_values: List[bytes] = [] for offset, message in items: encoded_values.append(Int64.encode(offset)) encoded_values.append(Bytes.encode(message)) @@ -202,7 +262,9 @@ def encode(cls, items, prepend_size=True): return encoded @classmethod - def decode(cls, data, bytes_to_read=None): + def decode( + cls, data: Union[io.BytesIO, bytes], bytes_to_read: Optional[int] = None + ) -> List[Union[Tuple[int, int, Message], Tuple[None, None, PartialMessage]]]: """Compressed messages should pass in bytes_to_read (via message size) otherwise, we decode from data as Int32 """ @@ -216,11 +278,14 @@ def decode(cls, data, bytes_to_read=None): # So create an internal buffer to avoid over-reading raw = io.BytesIO(data.read(bytes_to_read)) - items = [] + items: List[ + Union[Tuple[int, int, Message], Tuple[None, None, PartialMessage]] + ] = [] try: while bytes_to_read: offset = Int64.decode(raw) msg_bytes = Bytes.decode(raw) + assert msg_bytes is not None bytes_to_read -= 8 + 4 + len(msg_bytes) items.append( (offset, len(msg_bytes), Message.decode(msg_bytes)), @@ -233,10 +298,18 @@ def decode(cls, data, bytes_to_read=None): return items @classmethod - def repr(cls, messages): + def repr( + cls, + messages: Union[ + io.BytesIO, + List[Union[Tuple[int, int, Message], Tuple[None, None, PartialMessage]]], + ], + ) -> str: if isinstance(messages, io.BytesIO): offset = messages.tell() decoded = cls.decode(messages) messages.seek(offset) - messages = decoded - return str([cls.ITEM.repr(m) for m in messages]) + decoded_messages = decoded + else: + decoded_messages = messages + return str([cls.ITEM.repr(m) for m in decoded_messages]) diff --git a/aiokafka/protocol/metadata.py b/aiokafka/protocol/metadata.py index 79a5600a..202fc2bb 100644 --- a/aiokafka/protocol/metadata.py +++ b/aiokafka/protocol/metadata.py @@ -1,3 +1,5 @@ +from typing import Optional + from .api import Request, Response from .types import Array, Boolean, Int16, Int32, Schema, String @@ -187,7 +189,9 @@ class MetadataRequest_v0(Request): API_VERSION = 0 RESPONSE_TYPE = MetadataResponse_v0 SCHEMA = Schema(("topics", Array(String("utf-8")))) - ALL_TOPICS = None # Empty Array (len 0) for topics returns all topics + ALL_TOPICS: Optional[int] = ( + None # Empty Array (len 0) for topics returns all topics + ) class MetadataRequest_v1(Request): @@ -195,8 +199,8 @@ class MetadataRequest_v1(Request): API_VERSION = 1 RESPONSE_TYPE = MetadataResponse_v1 SCHEMA = MetadataRequest_v0.SCHEMA - ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics - NO_TOPICS = None # Empty array (len 0) for topics returns no topics + ALL_TOPICS: Optional[int] = -1 # Null Array (len -1) for topics returns all topics + NO_TOPICS: None = None # Empty array (len 0) for topics returns no topics class MetadataRequest_v2(Request): @@ -204,8 +208,8 @@ class MetadataRequest_v2(Request): API_VERSION = 2 RESPONSE_TYPE = MetadataResponse_v2 SCHEMA = MetadataRequest_v1.SCHEMA - ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics - NO_TOPICS = None # Empty array (len 0) for topics returns no topics + ALL_TOPICS: Optional[int] = -1 # Null Array (len -1) for topics returns all topics + NO_TOPICS: None = None # Empty array (len 0) for topics returns no topics class MetadataRequest_v3(Request): @@ -213,8 +217,8 @@ class MetadataRequest_v3(Request): API_VERSION = 3 RESPONSE_TYPE = MetadataResponse_v3 SCHEMA = MetadataRequest_v1.SCHEMA - ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics - NO_TOPICS = None # Empty array (len 0) for topics returns no topics + ALL_TOPICS: Optional[int] = -1 # Null Array (len -1) for topics returns all topics + NO_TOPICS: None = None # Empty array (len 0) for topics returns no topics class MetadataRequest_v4(Request): @@ -224,8 +228,8 @@ class MetadataRequest_v4(Request): SCHEMA = Schema( ("topics", Array(String("utf-8"))), ("allow_auto_topic_creation", Boolean) ) - ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics - NO_TOPICS = None # Empty array (len 0) for topics returns no topics + ALL_TOPICS: Optional[int] = -1 # Null Array (len -1) for topics returns all topics + NO_TOPICS: None = None # Empty array (len 0) for topics returns no topics class MetadataRequest_v5(Request): @@ -238,23 +242,23 @@ class MetadataRequest_v5(Request): API_VERSION = 5 RESPONSE_TYPE = MetadataResponse_v5 SCHEMA = MetadataRequest_v4.SCHEMA - ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics - NO_TOPICS = None # Empty array (len 0) for topics returns no topics + ALL_TOPICS: Optional[int] = -1 # Null Array (len -1) for topics returns all topics + NO_TOPICS: None = None # Empty array (len 0) for topics returns no topics -MetadataRequest = [ +MetadataRequest = ( MetadataRequest_v0, MetadataRequest_v1, MetadataRequest_v2, MetadataRequest_v3, MetadataRequest_v4, MetadataRequest_v5, -] -MetadataResponse = [ +) +MetadataResponse = ( MetadataResponse_v0, MetadataResponse_v1, MetadataResponse_v2, MetadataResponse_v3, MetadataResponse_v4, MetadataResponse_v5, -] +) diff --git a/aiokafka/protocol/offset.py b/aiokafka/protocol/offset.py index 3d8d659a..7a06bd3b 100644 --- a/aiokafka/protocol/offset.py +++ b/aiokafka/protocol/offset.py @@ -228,19 +228,19 @@ class OffsetRequest_v5(Request): DEFAULTS = {"replica_id": -1} -OffsetRequest = [ +OffsetRequest = ( OffsetRequest_v0, OffsetRequest_v1, OffsetRequest_v2, OffsetRequest_v3, OffsetRequest_v4, OffsetRequest_v5, -] -OffsetResponse = [ +) +OffsetResponse = ( OffsetResponse_v0, OffsetResponse_v1, OffsetResponse_v2, OffsetResponse_v3, OffsetResponse_v4, OffsetResponse_v5, -] +) diff --git a/aiokafka/protocol/produce.py b/aiokafka/protocol/produce.py index e55f616a..181ffe2f 100644 --- a/aiokafka/protocol/produce.py +++ b/aiokafka/protocol/produce.py @@ -148,17 +148,15 @@ class ProduceResponse_v8(Response): ("offset", Int64), ("timestamp", Int64), ("log_start_offset", Int64), - ), - ( - "record_errors", ( + "record_errors", Array( ("batch_index", Int32), ("batch_index_error_message", String("utf-8")), - ) + ), ), + ("error_message", String("utf-8")), ), - ("error_message", String("utf-8")), ), ), ), @@ -166,16 +164,11 @@ class ProduceResponse_v8(Response): ) -class ProduceRequest(Request): +class ProduceRequestBase(Request): API_KEY = 0 - def expect_response(self): - if self.required_acks == 0: - return False - return True - -class ProduceRequest_v0(ProduceRequest): +class ProduceRequest_v0(ProduceRequestBase): API_VERSION = 0 RESPONSE_TYPE = ProduceResponse_v0 SCHEMA = Schema( @@ -190,20 +183,27 @@ class ProduceRequest_v0(ProduceRequest): ), ) + required_acks: int -class ProduceRequest_v1(ProduceRequest): + def expect_response(self) -> bool: + if self.required_acks == 0: + return False + return True + + +class ProduceRequest_v1(ProduceRequestBase): API_VERSION = 1 RESPONSE_TYPE = ProduceResponse_v1 SCHEMA = ProduceRequest_v0.SCHEMA -class ProduceRequest_v2(ProduceRequest): +class ProduceRequest_v2(ProduceRequestBase): API_VERSION = 2 RESPONSE_TYPE = ProduceResponse_v2 SCHEMA = ProduceRequest_v1.SCHEMA -class ProduceRequest_v3(ProduceRequest): +class ProduceRequest_v3(ProduceRequestBase): API_VERSION = 3 RESPONSE_TYPE = ProduceResponse_v3 SCHEMA = Schema( @@ -219,8 +219,15 @@ class ProduceRequest_v3(ProduceRequest): ), ) + required_acks: int + + def expect_response(self) -> bool: + if self.required_acks == 0: + return False + return True + -class ProduceRequest_v4(ProduceRequest): +class ProduceRequest_v4(ProduceRequestBase): """ The version number is bumped up to indicate that the client supports KafkaStorageException. The KafkaStorageException will be translated to @@ -232,7 +239,7 @@ class ProduceRequest_v4(ProduceRequest): SCHEMA = ProduceRequest_v3.SCHEMA -class ProduceRequest_v5(ProduceRequest): +class ProduceRequest_v5(ProduceRequestBase): """ Same as v4. The version number is bumped since the v5 response includes an additional partition level field: the log_start_offset. @@ -243,7 +250,7 @@ class ProduceRequest_v5(ProduceRequest): SCHEMA = ProduceRequest_v4.SCHEMA -class ProduceRequest_v6(ProduceRequest): +class ProduceRequest_v6(ProduceRequestBase): """ The version number is bumped to indicate that on quota violation brokers send out responses before throttling. @@ -254,7 +261,7 @@ class ProduceRequest_v6(ProduceRequest): SCHEMA = ProduceRequest_v5.SCHEMA -class ProduceRequest_v7(ProduceRequest): +class ProduceRequest_v7(ProduceRequestBase): """ V7 bumped up to indicate ZStandard capability. (see KIP-110) """ @@ -264,7 +271,7 @@ class ProduceRequest_v7(ProduceRequest): SCHEMA = ProduceRequest_v6.SCHEMA -class ProduceRequest_v8(ProduceRequest): +class ProduceRequest_v8(ProduceRequestBase): """ V8 bumped up to add two new fields record_errors offset list and error_message to PartitionResponse (See KIP-467) @@ -275,7 +282,7 @@ class ProduceRequest_v8(ProduceRequest): SCHEMA = ProduceRequest_v7.SCHEMA -ProduceRequest = [ +ProduceRequest = ( ProduceRequest_v0, ProduceRequest_v1, ProduceRequest_v2, @@ -285,8 +292,8 @@ class ProduceRequest_v8(ProduceRequest): ProduceRequest_v6, ProduceRequest_v7, ProduceRequest_v8, -] -ProduceResponse = [ +) +ProduceResponse = ( ProduceResponse_v0, ProduceResponse_v1, ProduceResponse_v2, @@ -296,4 +303,4 @@ class ProduceRequest_v8(ProduceRequest): ProduceResponse_v6, ProduceResponse_v7, ProduceResponse_v8, -] +) diff --git a/aiokafka/protocol/struct.py b/aiokafka/protocol/struct.py index ee99c75a..fbccb1a8 100644 --- a/aiokafka/protocol/struct.py +++ b/aiokafka/protocol/struct.py @@ -1,13 +1,15 @@ from io import BytesIO +from typing import Any, ClassVar, List, Union + +from typing_extensions import Self -from .abstract import AbstractType from .types import Schema -class Struct(AbstractType): - SCHEMA = Schema() +class Struct: + SCHEMA: ClassVar = Schema() - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: if len(args) == len(self.SCHEMA.fields): for i, name in enumerate(self.SCHEMA.names): self.__dict__[name] = args[i] @@ -23,27 +25,29 @@ def __init__(self, *args, **kwargs): ) ) - def encode(self): + def encode(self) -> bytes: return self.SCHEMA.encode([self.__dict__[name] for name in self.SCHEMA.names]) @classmethod - def decode(cls, data): + def decode(cls, data: Union[BytesIO, bytes]) -> Self: if isinstance(data, bytes): data = BytesIO(data) return cls(*[field.decode(data) for field in cls.SCHEMA.fields]) - def get_item(self, name): + def get_item(self, name: str) -> Any: if name not in self.SCHEMA.names: raise KeyError("%s is not in the schema" % name) return self.__dict__[name] - def __repr__(self): - key_vals = [] + def __repr__(self) -> str: + key_vals: List[str] = [] for name, field in zip(self.SCHEMA.names, self.SCHEMA.fields): key_vals.append(f"{name}={field.repr(self.__dict__[name])}") return self.__class__.__name__ + "(" + ", ".join(key_vals) + ")" - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + if not isinstance(other, Struct): + return False if self.SCHEMA != other.SCHEMA: return False for attr in self.SCHEMA.names: diff --git a/aiokafka/protocol/transaction.py b/aiokafka/protocol/transaction.py index ea9734ab..3d36c9b0 100644 --- a/aiokafka/protocol/transaction.py +++ b/aiokafka/protocol/transaction.py @@ -132,19 +132,19 @@ class TxnOffsetCommitRequest_v0(Request): ) -InitProducerIdRequest = [InitProducerIdRequest_v0] -InitProducerIdResponse = [InitProducerIdResponse_v0] +InitProducerIdRequest = (InitProducerIdRequest_v0,) +InitProducerIdResponse = (InitProducerIdResponse_v0,) -AddPartitionsToTxnRequest = [AddPartitionsToTxnRequest_v0] -AddPartitionsToTxnResponse = [AddPartitionsToTxnResponse_v0] +AddPartitionsToTxnRequest = (AddPartitionsToTxnRequest_v0,) +AddPartitionsToTxnResponse = (AddPartitionsToTxnResponse_v0,) -AddOffsetsToTxnRequest = [AddOffsetsToTxnRequest_v0] -AddOffsetsToTxnResponse = [AddOffsetsToTxnResponse_v0] +AddOffsetsToTxnRequest = (AddOffsetsToTxnRequest_v0,) +AddOffsetsToTxnResponse = (AddOffsetsToTxnResponse_v0,) -EndTxnRequest = [EndTxnRequest_v0] +EndTxnRequest = (EndTxnRequest_v0,) -EndTxnResponse = [EndTxnResponse_v0] +EndTxnResponse = (EndTxnResponse_v0,) -TxnOffsetCommitResponse = [TxnOffsetCommitResponse_v0] +TxnOffsetCommitResponse = (TxnOffsetCommitResponse_v0,) -TxnOffsetCommitRequest = [TxnOffsetCommitRequest_v0] +TxnOffsetCommitRequest = (TxnOffsetCommitRequest_v0,) diff --git a/aiokafka/protocol/types.py b/aiokafka/protocol/types.py index 7eadf7fb..7f43214d 100644 --- a/aiokafka/protocol/types.py +++ b/aiokafka/protocol/types.py @@ -1,10 +1,31 @@ import struct +from collections.abc import Sequence +from io import BytesIO from struct import error +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, + cast, + overload, +) + +from typing_extensions import Buffer, TypeAlias from .abstract import AbstractType +T = TypeVar("T") -def _pack(f, value): +ValueT: TypeAlias = Union[Type[AbstractType[Any]], "String", "Array", "Schema"] + + +def _pack(f: Callable[[T], bytes], value: T) -> bytes: try: return f(value) except error as e: @@ -14,7 +35,7 @@ def _pack(f, value): ) from e -def _unpack(f, data): +def _unpack(f: Callable[[Buffer], Tuple[T, ...]], data: Buffer) -> T: try: (value,) = f(data) except error as e: @@ -26,95 +47,95 @@ def _unpack(f, data): return value -class Int8(AbstractType): +class Int8(AbstractType[int]): _pack = struct.Struct(">b").pack _unpack = struct.Struct(">b").unpack @classmethod - def encode(cls, value): + def encode(cls, value: int) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data): + def decode(cls, data: BytesIO) -> int: return _unpack(cls._unpack, data.read(1)) -class Int16(AbstractType): +class Int16(AbstractType[int]): _pack = struct.Struct(">h").pack _unpack = struct.Struct(">h").unpack @classmethod - def encode(cls, value): + def encode(cls, value: int) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data): + def decode(cls, data: BytesIO) -> int: return _unpack(cls._unpack, data.read(2)) -class Int32(AbstractType): +class Int32(AbstractType[int]): _pack = struct.Struct(">i").pack _unpack = struct.Struct(">i").unpack @classmethod - def encode(cls, value): + def encode(cls, value: int) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data): + def decode(cls, data: BytesIO) -> int: return _unpack(cls._unpack, data.read(4)) -class UInt32(AbstractType): +class UInt32(AbstractType[int]): _pack = struct.Struct(">I").pack _unpack = struct.Struct(">I").unpack @classmethod - def encode(cls, value): + def encode(cls, value: int) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data): + def decode(cls, data: BytesIO) -> int: return _unpack(cls._unpack, data.read(4)) -class Int64(AbstractType): +class Int64(AbstractType[int]): _pack = struct.Struct(">q").pack _unpack = struct.Struct(">q").unpack @classmethod - def encode(cls, value): + def encode(cls, value: int) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data): + def decode(cls, data: BytesIO) -> int: return _unpack(cls._unpack, data.read(8)) -class Float64(AbstractType): +class Float64(AbstractType[float]): _pack = struct.Struct(">d").pack _unpack = struct.Struct(">d").unpack @classmethod - def encode(cls, value): + def encode(cls, value: float) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data): + def decode(cls, data: BytesIO) -> float: return _unpack(cls._unpack, data.read(8)) -class String(AbstractType): - def __init__(self, encoding="utf-8"): +class String: + def __init__(self, encoding: str = "utf-8"): self.encoding = encoding - def encode(self, value): + def encode(self, value: Optional[str]) -> bytes: if value is None: return Int16.encode(-1) - value = str(value).encode(self.encoding) - return Int16.encode(len(value)) + value + encoded_value = str(value).encode(self.encoding) + return Int16.encode(len(encoded_value)) + encoded_value - def decode(self, data): + def decode(self, data: BytesIO) -> Optional[str]: length = Int16.decode(data) if length < 0: return None @@ -123,17 +144,21 @@ def decode(self, data): raise ValueError("Buffer underrun decoding string") return value.decode(self.encoding) + @classmethod + def repr(cls, value: str) -> str: + return repr(value) + -class Bytes(AbstractType): +class Bytes(AbstractType[Optional[bytes]]): @classmethod - def encode(cls, value): + def encode(cls, value: Optional[bytes]) -> bytes: if value is None: return Int32.encode(-1) else: return Int32.encode(len(value)) + value @classmethod - def decode(cls, data): + def decode(cls, data: BytesIO) -> Optional[bytes]: length = Int32.decode(data) if length < 0: return None @@ -143,45 +168,50 @@ def decode(cls, data): return value @classmethod - def repr(cls, value): + def repr(cls, value: Optional[bytes]) -> str: return repr( value[:100] + b"..." if value is not None and len(value) > 100 else value ) -class Boolean(AbstractType): +class Boolean(AbstractType[bool]): _pack = struct.Struct(">?").pack _unpack = struct.Struct(">?").unpack @classmethod - def encode(cls, value): + def encode(cls, value: bool) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data): + def decode(cls, data: BytesIO) -> bool: return _unpack(cls._unpack, data.read(1)) -class Schema(AbstractType): - def __init__(self, *fields): +class Schema: + names: Tuple[str, ...] + fields: Tuple[ValueT, ...] + + def __init__(self, *fields: Tuple[str, ValueT]): if fields: self.names, self.fields = zip(*fields) else: self.names, self.fields = (), () - def encode(self, item): + def encode(self, item: Sequence[Any]) -> bytes: if len(item) != len(self.fields): raise ValueError("Item field count does not match Schema") return b"".join(field.encode(item[i]) for i, field in enumerate(self.fields)) - def decode(self, data): + def decode( + self, data: BytesIO + ) -> Tuple[Union[Any, str, None, List[Union[Any, Tuple[Any, ...]]]], ...]: return tuple(field.decode(data) for field in self.fields) - def __len__(self): + def __len__(self) -> int: return len(self.fields) - def repr(self, value): - key_vals = [] + def repr(self, value: Any) -> str: + key_vals: List[str] = [] try: for i in range(len(self)): try: @@ -194,19 +224,35 @@ def repr(self, value): return repr(value) -class Array(AbstractType): - def __init__(self, *array_of): - if len(array_of) > 1: - self.array_of = Schema(*array_of) - elif len(array_of) == 1 and ( - isinstance(array_of[0], AbstractType) - or issubclass(array_of[0], AbstractType) - ): - self.array_of = array_of[0] - else: - raise ValueError("Array instantiated with no array_of type") +class Array: + array_of: ValueT - def encode(self, items): + @overload + def __init__(self, array_of_0: ValueT): ... + + @overload + def __init__( + self, array_of_0: Tuple[str, ValueT], *array_of: Tuple[str, ValueT] + ): ... + + def __init__( + self, + array_of_0: Union[ValueT, Tuple[str, ValueT]], + *array_of: Tuple[str, ValueT], + ) -> None: + if array_of: + array_of_0 = cast(Tuple[str, ValueT], array_of_0) + self.array_of = Schema(array_of_0, *array_of) + else: + array_of_0 = cast(ValueT, array_of_0) + if isinstance(array_of_0, (String, Array, Schema)) or issubclass( + array_of_0, AbstractType + ): + self.array_of = array_of_0 + else: + raise ValueError("Array instantiated with no array_of type") + + def encode(self, items: Optional[Sequence[Any]]) -> bytes: if items is None: return Int32.encode(-1) encoded_items = (self.array_of.encode(item) for item in items) @@ -214,22 +260,23 @@ def encode(self, items): (Int32.encode(len(items)), *encoded_items), ) - def decode(self, data): + def decode(self, data: BytesIO) -> Optional[List[Union[Any, Tuple[Any, ...]]]]: length = Int32.decode(data) if length == -1: return None return [self.array_of.decode(data) for _ in range(length)] - def repr(self, list_of_items): + def repr(self, list_of_items: Optional[Sequence[Any]]) -> str: if list_of_items is None: return "NULL" return "[" + ", ".join(self.array_of.repr(item) for item in list_of_items) + "]" -class UnsignedVarInt32(AbstractType): +class UnsignedVarInt32(AbstractType[int]): @classmethod - def decode(cls, data): + def decode(cls, data: BytesIO) -> int: value, i = 0, 0 + b: int while True: (b,) = struct.unpack("B", data.read(1)) if not (b & 0x80): @@ -242,7 +289,7 @@ def decode(cls, data): return value @classmethod - def encode(cls, value): + def encode(cls, value: int) -> bytes: value &= 0xFFFFFFFF ret = b"" while (value & 0xFFFFFF80) != 0: @@ -253,36 +300,36 @@ def encode(cls, value): return ret -class VarInt32(AbstractType): +class VarInt32(AbstractType[int]): @classmethod - def decode(cls, data): + def decode(cls, data: BytesIO) -> int: value = UnsignedVarInt32.decode(data) return (value >> 1) ^ -(value & 1) @classmethod - def encode(cls, value): + def encode(cls, value: int) -> bytes: # bring it in line with the java binary repr value &= 0xFFFFFFFF return UnsignedVarInt32.encode((value << 1) ^ (value >> 31)) -class VarInt64(AbstractType): +class VarInt64(AbstractType[int]): @classmethod - def decode(cls, data): + def decode(cls, data: BytesIO) -> int: value, i = 0, 0 while True: b = data.read(1) - if not (b & 0x80): + if not (b & 0x80): # type: ignore[operator] break - value |= (b & 0x7F) << i + value |= (b & 0x7F) << i # type: ignore[operator] i += 7 if i > 63: raise ValueError(f"Invalid value {value}") - value |= b << i + value |= b << i # type: ignore[operator] return (value >> 1) ^ -(value & 1) @classmethod - def encode(cls, value): + def encode(cls, value: int) -> bytes: # bring it in line with the java binary repr value &= 0xFFFFFFFFFFFFFFFF v = (value << 1) ^ (value >> 63) @@ -296,7 +343,7 @@ def encode(cls, value): class CompactString(String): - def decode(self, data): + def decode(self, data: BytesIO) -> Optional[str]: length = UnsignedVarInt32.decode(data) - 1 if length < 0: return None @@ -305,18 +352,18 @@ def decode(self, data): raise ValueError("Buffer underrun decoding string") return value.decode(self.encoding) - def encode(self, value): + def encode(self, value: Optional[str]) -> bytes: if value is None: return UnsignedVarInt32.encode(0) - value = str(value).encode(self.encoding) - return UnsignedVarInt32.encode(len(value) + 1) + value + encoded_value = str(value).encode(self.encoding) + return UnsignedVarInt32.encode(len(encoded_value) + 1) + encoded_value -class TaggedFields(AbstractType): +class TaggedFields(AbstractType[Dict[int, bytes]]): @classmethod - def decode(cls, data): + def decode(cls, data: BytesIO) -> Dict[int, bytes]: num_fields = UnsignedVarInt32.decode(data) - ret = {} + ret: Dict[int, bytes] = {} if not num_fields: return ret prev_tag = -1 @@ -331,20 +378,20 @@ def decode(cls, data): return ret @classmethod - def encode(cls, value): + def encode(cls, value: Dict[int, bytes]) -> bytes: ret = UnsignedVarInt32.encode(len(value)) for k, v in value.items(): # do we allow for other data types ?? It could get complicated really fast - assert isinstance(v, bytes), f"Value {v} is not a byte array" + assert isinstance(v, bytes), f"Value {v!r} is not a byte array" assert isinstance(k, int) and k > 0, f"Key {k} is not a positive integer" ret += UnsignedVarInt32.encode(k) ret += v return ret -class CompactBytes(AbstractType): +class CompactBytes(AbstractType[Optional[bytes]]): @classmethod - def decode(cls, data): + def decode(cls, data: BytesIO) -> Optional[bytes]: length = UnsignedVarInt32.decode(data) - 1 if length < 0: return None @@ -354,7 +401,7 @@ def decode(cls, data): return value @classmethod - def encode(cls, value): + def encode(cls, value: Optional[bytes]) -> bytes: if value is None: return UnsignedVarInt32.encode(0) else: @@ -362,7 +409,7 @@ def encode(cls, value): class CompactArray(Array): - def encode(self, items): + def encode(self, items: Optional[Sequence[Any]]) -> bytes: if items is None: return UnsignedVarInt32.encode(0) encoded_items = (self.array_of.encode(item) for item in items) @@ -370,7 +417,7 @@ def encode(self, items): (UnsignedVarInt32.encode(len(items) + 1), *encoded_items), ) - def decode(self, data): + def decode(self, data: BytesIO) -> Optional[List[Union[Any, Tuple[Any, ...]]]]: length = UnsignedVarInt32.decode(data) - 1 if length == -1: return None diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 56680f07..01c8a0dc 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -24,7 +24,7 @@ def test_create_message(): payload = b"test" key = b"key" - msg = Message(payload, key=key) + msg = Message(value=payload, key=key) assert msg.magic == 0 assert msg.attributes == 0 assert msg.key == key @@ -32,7 +32,7 @@ def test_create_message(): def test_encode_message_v0(): - message = Message(b"test", key=b"key") + message = Message(value=b"test", key=b"key") encoded = message.encode() expect = b"".join( [ @@ -48,7 +48,7 @@ def test_encode_message_v0(): def test_encode_message_v1(): - message = Message(b"test", key=b"key", magic=1, timestamp=1234) + message = Message(value=b"test", key=b"key", magic=1, timestamp=1234) encoded = message.encode() expect = b"".join( [ @@ -76,7 +76,7 @@ def test_decode_message(): ] ) decoded_message = Message.decode(encoded) - msg = Message(b"test", key=b"key") + msg = Message(value=b"test", key=b"key") msg.encode() # crc is recalculated during encoding assert decoded_message == msg @@ -110,7 +110,7 @@ def test_decode_message_validate_crc(): def test_encode_message_set(): - messages = [Message(b"v1", key=b"k1"), Message(b"v2", key=b"k2")] + messages = [Message(value=b"v1", key=b"k1"), Message(value=b"v2", key=b"k2")] encoded = MessageSet.encode([(0, msg.encode()) for msg in messages]) expect = b"".join( [ @@ -166,12 +166,12 @@ def test_decode_message_set(): returned_offset2, message2_size, decoded_message2 = msg2 assert returned_offset1 == 0 - message1 = Message(b"v1", key=b"k1") + message1 = Message(value=b"v1", key=b"k1") message1.encode() assert decoded_message1 == message1 assert returned_offset2 == 1 - message2 = Message(b"v2", key=b"k2") + message2 = Message(value=b"v2", key=b"k2") message2.encode() assert decoded_message2 == message2 @@ -222,7 +222,7 @@ def test_decode_message_set_partial(): returned_offset2, message2_size, decoded_message2 = msg2 assert returned_offset1 == 0 - message1 = Message(b"v1", key=b"k1") + message1 = Message(value=b"v1", key=b"k1") message1.encode() assert decoded_message1 == message1