From f061bb39dee996db7f2e14a1e089da6f6c295f4b Mon Sep 17 00:00:00 2001 From: Chengxu Bian Date: Wed, 3 Apr 2024 02:15:09 +0000 Subject: [PATCH 1/2] stash save --- aiokafka/protocol/abstract.py | 15 +++-- aiokafka/protocol/message.py | 19 ++++-- aiokafka/protocol/types.py | 112 +++++++++++++++++++--------------- 3 files changed, 87 insertions(+), 59 deletions(-) diff --git a/aiokafka/protocol/abstract.py b/aiokafka/protocol/abstract.py index 117d058e..1b89f09e 100644 --- a/aiokafka/protocol/abstract.py +++ b/aiokafka/protocol/abstract.py @@ -1,15 +1,22 @@ import abc +import io +from typing import Generic, Optional, TypeVar +from typing_extensions import TypeAlias -class AbstractType(metaclass=abc.ABCMeta): +T = TypeVar("T") +RawData: TypeAlias = io.BytesIO + + +class AbstractType(Generic[T], metaclass=abc.ABCMeta): @classmethod @abc.abstractmethod - def encode(cls, value): ... + def encode(self, value: Optional[T]) -> bytes: ... @classmethod @abc.abstractmethod - def decode(cls, data): ... + def decode(self, data: RawData) -> Optional[T]: ... @classmethod - def repr(cls, value): + def repr(self, value: T) -> str: return repr(value) diff --git a/aiokafka/protocol/message.py b/aiokafka/protocol/message.py index 31993fe6..d2e038e4 100644 --- a/aiokafka/protocol/message.py +++ b/aiokafka/protocol/message.py @@ -1,6 +1,7 @@ import io import time from binascii import crc32 +from typing import Optional from aiokafka.codec import ( gzip_decode, @@ -47,7 +48,15 @@ class Message(Struct): 22 # crc(4), magic(1), attributes(1), timestamp(8), key+value size(4*2) ) - def __init__(self, value, key=None, magic=0, attributes=0, crc=0, timestamp=None): + def __init__( + self, + value: Optional[bytes], + key: Optional[bytes] = None, + magic: int = 0, + attributes: int = 0, + crc: int = 0, + timestamp: Optional[int] = None, + ): assert value is None or isinstance(value, bytes), "value must be bytes" assert key is None or isinstance(key, bytes), "key must be bytes" assert magic > 0 or timestamp is None, "timestamp not supported in v0" @@ -64,7 +73,7 @@ def __init__(self, value, key=None, magic=0, attributes=0, crc=0, timestamp=None self.value = value @property - def timestamp_type(self): + def timestamp_type(self) -> Optional[int]: """0 for CreateTime; 1 for LogAppendTime; None if unsupported. Value is determined by broker; produced messages should always set to 0 @@ -77,7 +86,7 @@ def timestamp_type(self): else: return 0 - def encode(self, recalc_crc=True): + def encode(self, recalc_crc: bool = True): version = self.magic if version == 1: fields = ( @@ -125,7 +134,7 @@ def decode(cls, data): msg._validated_crc = _validated_crc return msg - def validate_crc(self): + def validate_crc(self) -> bool: if self._validated_crc is None: raw_msg = self.encode(recalc_crc=False) self._validated_crc = crc32(raw_msg[4:]) @@ -133,7 +142,7 @@ def validate_crc(self): return True return False - def is_compressed(self): + def is_compressed(self) -> bool: return self.attributes & self.CODEC_MASK != 0 def decompress(self): diff --git a/aiokafka/protocol/types.py b/aiokafka/protocol/types.py index 7eadf7fb..4b5b6441 100644 --- a/aiokafka/protocol/types.py +++ b/aiokafka/protocol/types.py @@ -1,10 +1,17 @@ +from __future__ import annotations + import struct from struct import error +from typing import Callable, Iterable, Optional, Tuple, TypeVar, Union + +from _typeshed import ReadableBuffer + +from .abstract import AbstractType, RawData -from .abstract import AbstractType +T = TypeVar("T") -def _pack(f, value): +def _pack(f: Callable[[T], bytes], value: T) -> bytes: try: return f(value) except error as e: @@ -14,7 +21,7 @@ def _pack(f, value): ) from e -def _unpack(f, data): +def _unpack(f: Callable[[ReadableBuffer], tuple[T, ...]], data: ReadableBuffer) -> T: try: (value,) = f(data) except error as e: @@ -26,95 +33,95 @@ def _unpack(f, data): return value -class Int8(AbstractType): +class Int8(AbstractType[int]): _pack = struct.Struct(">b").pack _unpack = struct.Struct(">b").unpack @classmethod - def encode(cls, value): + def encode(cls, value: int) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data): + def decode(cls, data: RawData) -> int: return _unpack(cls._unpack, data.read(1)) -class Int16(AbstractType): +class Int16(AbstractType[int]): _pack = struct.Struct(">h").pack _unpack = struct.Struct(">h").unpack @classmethod - def encode(cls, value): + def encode(cls, value: int) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data): + def decode(cls, data: RawData) -> int: return _unpack(cls._unpack, data.read(2)) -class Int32(AbstractType): +class Int32(AbstractType[int]): _pack = struct.Struct(">i").pack _unpack = struct.Struct(">i").unpack @classmethod - def encode(cls, value): + def encode(cls, value: int) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data): + def decode(cls, data: RawData) -> int: return _unpack(cls._unpack, data.read(4)) -class UInt32(AbstractType): +class UInt32(AbstractType[int]): _pack = struct.Struct(">I").pack _unpack = struct.Struct(">I").unpack @classmethod - def encode(cls, value): + def encode(cls, value: int) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data): + def decode(cls, data: RawData) -> int: return _unpack(cls._unpack, data.read(4)) -class Int64(AbstractType): +class Int64(AbstractType[int]): _pack = struct.Struct(">q").pack _unpack = struct.Struct(">q").unpack @classmethod - def encode(cls, value): + def encode(cls, value: int) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data): + def decode(cls, data: RawData) -> int: return _unpack(cls._unpack, data.read(8)) -class Float64(AbstractType): +class Float64(AbstractType[float]): _pack = struct.Struct(">d").pack _unpack = struct.Struct(">d").unpack @classmethod - def encode(cls, value): + def encode(cls, value: float) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data): + def decode(cls, data: RawData) -> float: return _unpack(cls._unpack, data.read(8)) -class String(AbstractType): - def __init__(self, encoding="utf-8"): +class String(AbstractType[str]): + def __init__(self, encoding="utf-8") -> None: self.encoding = encoding - def encode(self, value): + def encode(self, value: Optional[str]) -> bytes: if value is None: return Int16.encode(-1) value = str(value).encode(self.encoding) return Int16.encode(len(value)) + value - def decode(self, data): + def decode(self, data: RawData) -> str: length = Int16.decode(data) if length < 0: return None @@ -124,16 +131,16 @@ def decode(self, data): return value.decode(self.encoding) -class Bytes(AbstractType): +class Bytes(AbstractType[bytes]): @classmethod - def encode(cls, value): + def encode(cls, value: Optional[bytes]) -> bytes: if value is None: return Int32.encode(-1) else: return Int32.encode(len(value)) + value @classmethod - def decode(cls, data): + def decode(cls, data: RawData) -> Optional[bytes]: length = Int32.decode(data) if length < 0: return None @@ -143,33 +150,36 @@ def decode(cls, data): return value @classmethod - def repr(cls, value): + def repr(cls, value: Optional[bytes]) -> str: return repr( value[:100] + b"..." if value is not None and len(value) > 100 else value ) -class Boolean(AbstractType): +class Boolean(AbstractType[bool]): _pack = struct.Struct(">?").pack _unpack = struct.Struct(">?").unpack @classmethod - def encode(cls, value): + def encode(cls, value: bool) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data): + def decode(cls, data: RawData) -> bool: return _unpack(cls._unpack, data.read(1)) class Schema(AbstractType): - def __init__(self, *fields): + names: Tuple[str] + fields: Tuple[AbstractType] + + def __init__(self, *fields: Tuple[str, AbstractType]): if fields: self.names, self.fields = zip(*fields) else: self.names, self.fields = (), () - def encode(self, item): + def encode(self, item) -> bytes: if len(item) != len(self.fields): raise ValueError("Item field count does not match Schema") return b"".join(field.encode(item[i]) for i, field in enumerate(self.fields)) @@ -177,10 +187,10 @@ def encode(self, item): def decode(self, data): return tuple(field.decode(data) for field in self.fields) - def __len__(self): + def __len__(self) -> int: return len(self.fields) - def repr(self, value): + def repr(self, value) -> str: key_vals = [] try: for i in range(len(self)): @@ -195,7 +205,9 @@ def repr(self, value): class Array(AbstractType): - def __init__(self, *array_of): + array_of: Union[Schema, AbstractType] + + def __init__(self, *array_of: Tuple[str, AbstractType]): if len(array_of) > 1: self.array_of = Schema(*array_of) elif len(array_of) == 1 and ( @@ -206,7 +218,7 @@ def __init__(self, *array_of): else: raise ValueError("Array instantiated with no array_of type") - def encode(self, items): + def encode(self, items: Optional[Iterable[Tuple[str, AbstractType]]]) -> bytes: if items is None: return Int32.encode(-1) encoded_items = (self.array_of.encode(item) for item in items) @@ -214,13 +226,13 @@ def encode(self, items): (Int32.encode(len(items)), *encoded_items), ) - def decode(self, data): + def decode(self, data: RawData) -> Optional[list[AbstractType]]: length = Int32.decode(data) if length == -1: return None return [self.array_of.decode(data) for _ in range(length)] - def repr(self, list_of_items): + def repr(self, list_of_items: Optional[list[AbstractType]]) -> str: if list_of_items is None: return "NULL" return "[" + ", ".join(self.array_of.repr(item) for item in list_of_items) + "]" @@ -242,7 +254,7 @@ def decode(cls, data): return value @classmethod - def encode(cls, value): + def encode(cls, value) -> bytes: value &= 0xFFFFFFFF ret = b"" while (value & 0xFFFFFF80) != 0: @@ -260,7 +272,7 @@ def decode(cls, data): return (value >> 1) ^ -(value & 1) @classmethod - def encode(cls, value): + def encode(cls, value) -> bytes: # bring it in line with the java binary repr value &= 0xFFFFFFFF return UnsignedVarInt32.encode((value << 1) ^ (value >> 31)) @@ -282,7 +294,7 @@ def decode(cls, data): return (value >> 1) ^ -(value & 1) @classmethod - def encode(cls, value): + def encode(cls, value) -> bytes: # bring it in line with the java binary repr value &= 0xFFFFFFFFFFFFFFFF v = (value << 1) ^ (value >> 63) @@ -296,7 +308,7 @@ def encode(cls, value): class CompactString(String): - def decode(self, data): + def decode(self, data: RawData) -> Optional[bytes]: length = UnsignedVarInt32.decode(data) - 1 if length < 0: return None @@ -305,7 +317,7 @@ def decode(self, data): raise ValueError("Buffer underrun decoding string") return value.decode(self.encoding) - def encode(self, value): + def encode(self, value: Optional[str]) -> bytes: if value is None: return UnsignedVarInt32.encode(0) value = str(value).encode(self.encoding) @@ -331,7 +343,7 @@ def decode(cls, data): return ret @classmethod - def encode(cls, value): + def encode(cls, value) -> bytes: ret = UnsignedVarInt32.encode(len(value)) for k, v in value.items(): # do we allow for other data types ?? It could get complicated really fast @@ -344,7 +356,7 @@ def encode(cls, value): class CompactBytes(AbstractType): @classmethod - def decode(cls, data): + def decode(cls, data: RawData) -> Optional[bytes]: length = UnsignedVarInt32.decode(data) - 1 if length < 0: return None @@ -354,7 +366,7 @@ def decode(cls, data): return value @classmethod - def encode(cls, value): + def encode(cls, value: Optional[bytes]) -> bytes: if value is None: return UnsignedVarInt32.encode(0) else: @@ -362,7 +374,7 @@ def encode(cls, value): class CompactArray(Array): - def encode(self, items): + def encode(self, items: Optional[list[AbstractType]]) -> bytes: if items is None: return UnsignedVarInt32.encode(0) encoded_items = (self.array_of.encode(item) for item in items) @@ -370,7 +382,7 @@ def encode(self, items): (UnsignedVarInt32.encode(len(items) + 1), *encoded_items), ) - def decode(self, data): + def decode(self, data: RawData) -> Optional[list[Optional[AbstractType]]]: length = UnsignedVarInt32.decode(data) - 1 if length == -1: return None From 593d324c161f5136896d19bb3b941059232930e4 Mon Sep 17 00:00:00 2001 From: Chengxu Bian Date: Mon, 8 Apr 2024 21:15:46 -0400 Subject: [PATCH 2/2] add basic type hints --- aiokafka/protocol/abstract.py | 4 ++-- aiokafka/protocol/types.py | 43 ++++++++++++++++------------------- 2 files changed, 22 insertions(+), 25 deletions(-) diff --git a/aiokafka/protocol/abstract.py b/aiokafka/protocol/abstract.py index 1b89f09e..953a76c3 100644 --- a/aiokafka/protocol/abstract.py +++ b/aiokafka/protocol/abstract.py @@ -5,7 +5,7 @@ from typing_extensions import TypeAlias T = TypeVar("T") -RawData: TypeAlias = io.BytesIO +BytesIO: TypeAlias = io.BytesIO class AbstractType(Generic[T], metaclass=abc.ABCMeta): @@ -15,7 +15,7 @@ def encode(self, value: Optional[T]) -> bytes: ... @classmethod @abc.abstractmethod - def decode(self, data: RawData) -> Optional[T]: ... + def decode(self, data: BytesIO) -> Optional[T]: ... @classmethod def repr(self, value: T) -> str: diff --git a/aiokafka/protocol/types.py b/aiokafka/protocol/types.py index 4b5b6441..03aaf554 100644 --- a/aiokafka/protocol/types.py +++ b/aiokafka/protocol/types.py @@ -2,11 +2,11 @@ import struct from struct import error -from typing import Callable, Iterable, Optional, Tuple, TypeVar, Union +from typing import Callable, Optional, TypeVar from _typeshed import ReadableBuffer -from .abstract import AbstractType, RawData +from .abstract import AbstractType, BytesIO T = TypeVar("T") @@ -42,7 +42,7 @@ def encode(cls, value: int) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data: RawData) -> int: + def decode(cls, data: BytesIO) -> int: return _unpack(cls._unpack, data.read(1)) @@ -55,7 +55,7 @@ def encode(cls, value: int) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data: RawData) -> int: + def decode(cls, data: BytesIO) -> int: return _unpack(cls._unpack, data.read(2)) @@ -68,7 +68,7 @@ def encode(cls, value: int) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data: RawData) -> int: + def decode(cls, data: BytesIO) -> int: return _unpack(cls._unpack, data.read(4)) @@ -81,7 +81,7 @@ def encode(cls, value: int) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data: RawData) -> int: + def decode(cls, data: BytesIO) -> int: return _unpack(cls._unpack, data.read(4)) @@ -94,7 +94,7 @@ def encode(cls, value: int) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data: RawData) -> int: + def decode(cls, data: BytesIO) -> int: return _unpack(cls._unpack, data.read(8)) @@ -107,7 +107,7 @@ def encode(cls, value: float) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data: RawData) -> float: + def decode(cls, data: BytesIO) -> float: return _unpack(cls._unpack, data.read(8)) @@ -121,7 +121,7 @@ def encode(self, value: Optional[str]) -> bytes: value = str(value).encode(self.encoding) return Int16.encode(len(value)) + value - def decode(self, data: RawData) -> str: + def decode(self, data: BytesIO) -> Optional[str]: length = Int16.decode(data) if length < 0: return None @@ -140,7 +140,7 @@ def encode(cls, value: Optional[bytes]) -> bytes: return Int32.encode(len(value)) + value @classmethod - def decode(cls, data: RawData) -> Optional[bytes]: + def decode(cls, data: BytesIO) -> Optional[bytes]: length = Int32.decode(data) if length < 0: return None @@ -165,15 +165,13 @@ def encode(cls, value: bool) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data: RawData) -> bool: + def decode(cls, data: BytesIO) -> bool: return _unpack(cls._unpack, data.read(1)) class Schema(AbstractType): - names: Tuple[str] - fields: Tuple[AbstractType] - def __init__(self, *fields: Tuple[str, AbstractType]): + def __init__(self, *fields): if fields: self.names, self.fields = zip(*fields) else: @@ -205,9 +203,8 @@ def repr(self, value) -> str: class Array(AbstractType): - array_of: Union[Schema, AbstractType] - def __init__(self, *array_of: Tuple[str, AbstractType]): + def __init__(self, *array_of): if len(array_of) > 1: self.array_of = Schema(*array_of) elif len(array_of) == 1 and ( @@ -218,7 +215,7 @@ def __init__(self, *array_of: Tuple[str, AbstractType]): else: raise ValueError("Array instantiated with no array_of type") - def encode(self, items: Optional[Iterable[Tuple[str, AbstractType]]]) -> bytes: + def encode(self, items) -> bytes: if items is None: return Int32.encode(-1) encoded_items = (self.array_of.encode(item) for item in items) @@ -226,7 +223,7 @@ def encode(self, items: Optional[Iterable[Tuple[str, AbstractType]]]) -> bytes: (Int32.encode(len(items)), *encoded_items), ) - def decode(self, data: RawData) -> Optional[list[AbstractType]]: + def decode(self, data: BytesIO) -> Optional[list[AbstractType]]: length = Int32.decode(data) if length == -1: return None @@ -308,7 +305,7 @@ def encode(cls, value) -> bytes: class CompactString(String): - def decode(self, data: RawData) -> Optional[bytes]: + def decode(self, data: BytesIO) -> Optional[bytes]: length = UnsignedVarInt32.decode(data) - 1 if length < 0: return None @@ -326,7 +323,7 @@ def encode(self, value: Optional[str]) -> bytes: class TaggedFields(AbstractType): @classmethod - def decode(cls, data): + def decode(cls, data: bytes): num_fields = UnsignedVarInt32.decode(data) ret = {} if not num_fields: @@ -356,7 +353,7 @@ def encode(cls, value) -> bytes: class CompactBytes(AbstractType): @classmethod - def decode(cls, data: RawData) -> Optional[bytes]: + def decode(cls, data: BytesIO) -> Optional[bytes]: length = UnsignedVarInt32.decode(data) - 1 if length < 0: return None @@ -374,7 +371,7 @@ def encode(cls, value: Optional[bytes]) -> bytes: class CompactArray(Array): - def encode(self, items: Optional[list[AbstractType]]) -> bytes: + def encode(self, items) -> bytes: if items is None: return UnsignedVarInt32.encode(0) encoded_items = (self.array_of.encode(item) for item in items) @@ -382,7 +379,7 @@ def encode(self, items: Optional[list[AbstractType]]) -> bytes: (UnsignedVarInt32.encode(len(items) + 1), *encoded_items), ) - def decode(self, data: RawData) -> Optional[list[Optional[AbstractType]]]: + def decode(self, data: BytesIO): length = UnsignedVarInt32.decode(data) - 1 if length == -1: return None