Skip to content

Commit

Permalink
add typing to aiokafka/coordinator/* (#1006)
Browse files Browse the repository at this point in the history
* add typing to aiokafka/record/*

* add some annotations to tests/record

* fix almost all errors

* test w/o protocols

* Revert "test w/o protocols"

This reverts commit 7fa1efa.

* use TypeIs

* use dataclass

* remove timestamp/timestamp_type from cython DefaultRecord

* sync cython stubs with code

* simplify types

* add typing to aiokafka/coordinator/*

* fix review

* fix format

* fix review

* fix type errors

* fix review

* fix review

* assert consumer is not None

* fix review (continue is consumer is None)
  • Loading branch information
dimastbk authored Jun 29, 2024
1 parent 14aa358 commit 4cba502
Show file tree
Hide file tree
Showing 11 changed files with 382 additions and 179 deletions.
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ DOCKER_IMAGE=aiolibs/kafka:$(SCALA_VERSION)_$(KAFKA_VERSION)
DIFF_BRANCH=origin/master
FORMATTED_AREAS=\
aiokafka/codec.py \
aiokafka/coordinator/ \
aiokafka/errors.py \
aiokafka/helpers.py \
aiokafka/structs.py \
Expand All @@ -17,6 +18,7 @@ FORMATTED_AREAS=\
tests/test_helpers.py \
tests/test_protocol.py \
tests/test_protocol_object_conversion.py \
tests/coordinator/ \
tests/record/

.PHONY: setup
Expand Down
3 changes: 2 additions & 1 deletion aiokafka/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import threading
import time
from concurrent.futures import Future
from typing import Optional, Set

from aiokafka import errors as Errors
from aiokafka.conn import collect_hosts
Expand Down Expand Up @@ -103,7 +104,7 @@ def broker_metadata(self, broker_id):
or self._coordinator_brokers.get(broker_id)
)

def partitions_for_topic(self, topic):
def partitions_for_topic(self, topic: str) -> Optional[Set[int]]:
"""Return set of all partitions for topic (whether available or not)
Arguments:
Expand Down
27 changes: 21 additions & 6 deletions aiokafka/coordinator/assignors/abstract.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,33 @@
import abc
import logging
from typing import Dict, Iterable, Mapping

from aiokafka.cluster import ClusterMetadata
from aiokafka.coordinator.protocol import (
ConsumerProtocolMemberAssignment,
ConsumerProtocolMemberMetadata,
)

log = logging.getLogger(__name__)


class AbstractPartitionAssignor:
class AbstractPartitionAssignor(abc.ABC):
"""Abstract assignor implementation which does some common grunt work (in particular
collecting partition counts which are always needed in assignors).
"""

@abc.abstractproperty
def name(self):
@property
@abc.abstractmethod
def name(self) -> str:
""".name should be a string identifying the assignor"""

@classmethod
@abc.abstractmethod
def assign(self, cluster, members):
def assign(
cls,
cluster: ClusterMetadata,
members: Mapping[str, ConsumerProtocolMemberMetadata],
) -> Dict[str, ConsumerProtocolMemberAssignment]:
"""Perform group assignment given cluster metadata and member subscriptions
Arguments:
Expand All @@ -26,8 +39,9 @@ def assign(self, cluster, members):
dict: {member_id: MemberAssignment}
"""

@classmethod
@abc.abstractmethod
def metadata(self, topics):
def metadata(cls, topics: Iterable[str]) -> ConsumerProtocolMemberMetadata:
"""Generate ProtocolMetadata to be submitted via JoinGroupRequest.
Arguments:
Expand All @@ -37,8 +51,9 @@ def metadata(self, topics):
MemberMetadata struct
"""

@classmethod
@abc.abstractmethod
def on_assignment(self, assignment):
def on_assignment(cls, assignment: ConsumerProtocolMemberAssignment) -> None:
"""Callback that runs on each assignment.
This method can be used to update internal state, if any, of the
Expand Down
30 changes: 18 additions & 12 deletions aiokafka/coordinator/assignors/range.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import collections
import logging
from typing import Dict, Iterable, List, Mapping

from aiokafka.cluster import ClusterMetadata
from aiokafka.coordinator.assignors.abstract import AbstractPartitionAssignor
from aiokafka.coordinator.protocol import (
ConsumerProtocolMemberAssignment,
Expand Down Expand Up @@ -32,45 +34,49 @@ class RangePartitionAssignor(AbstractPartitionAssignor):
version = 0

@classmethod
def assign(cls, cluster, member_metadata):
consumers_per_topic = collections.defaultdict(list)
for member, metadata in member_metadata.items():
def assign(
cls,
cluster: ClusterMetadata,
members: Mapping[str, ConsumerProtocolMemberMetadata],
) -> Dict[str, ConsumerProtocolMemberAssignment]:
consumers_per_topic: Dict[str, List[str]] = collections.defaultdict(list)
for member, metadata in members.items():
for topic in metadata.subscription:
consumers_per_topic[topic].append(member)

# construct {member_id: {topic: [partition, ...]}}
assignment = collections.defaultdict(dict)
assignment: Dict[str, Dict[str, List[int]]] = collections.defaultdict(dict)

for topic, consumers_for_topic in consumers_per_topic.items():
partitions = cluster.partitions_for_topic(topic)
if partitions is None:
log.warning("No partition metadata for topic %s", topic)
continue
partitions = sorted(partitions)
partitions_list = sorted(partitions)
consumers_for_topic.sort()

partitions_per_consumer = len(partitions) // len(consumers_for_topic)
consumers_with_extra = len(partitions) % len(consumers_for_topic)
partitions_per_consumer = len(partitions_list) // len(consumers_for_topic)
consumers_with_extra = len(partitions_list) % len(consumers_for_topic)

for i, member in enumerate(consumers_for_topic):
start = partitions_per_consumer * i
start += min(i, consumers_with_extra)
length = partitions_per_consumer
if not i + 1 > consumers_with_extra:
length += 1
assignment[member][topic] = partitions[start : start + length]
assignment[member][topic] = partitions_list[start : start + length]

protocol_assignment = {}
for member_id in member_metadata:
protocol_assignment: Dict[str, ConsumerProtocolMemberAssignment] = {}
for member_id in members:
protocol_assignment[member_id] = ConsumerProtocolMemberAssignment(
cls.version, sorted(assignment[member_id].items()), b""
)
return protocol_assignment

@classmethod
def metadata(cls, topics):
def metadata(cls, topics: Iterable[str]) -> ConsumerProtocolMemberMetadata:
return ConsumerProtocolMemberMetadata(cls.version, list(topics), b"")

@classmethod
def on_assignment(cls, assignment):
def on_assignment(cls, assignment: ConsumerProtocolMemberAssignment) -> None:
pass
26 changes: 17 additions & 9 deletions aiokafka/coordinator/assignors/roundrobin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import collections
import itertools
import logging
from typing import Dict, Iterable, List, Mapping

from aiokafka.cluster import ClusterMetadata
from aiokafka.coordinator.assignors.abstract import AbstractPartitionAssignor
from aiokafka.coordinator.protocol import (
ConsumerProtocolMemberAssignment,
Expand Down Expand Up @@ -49,12 +51,16 @@ class RoundRobinPartitionAssignor(AbstractPartitionAssignor):
version = 0

@classmethod
def assign(cls, cluster, member_metadata):
def assign(
cls,
cluster: ClusterMetadata,
members: Mapping[str, ConsumerProtocolMemberMetadata],
) -> Dict[str, ConsumerProtocolMemberAssignment]:
all_topics = set()
for metadata in member_metadata.values():
for metadata in members.values():
all_topics.update(metadata.subscription)

all_topic_partitions = []
all_topic_partitions: List[TopicPartition] = []
for topic in all_topics:
partitions = cluster.partitions_for_topic(topic)
if partitions is None:
Expand All @@ -66,31 +72,33 @@ def assign(cls, cluster, member_metadata):
all_topic_partitions.sort()

# construct {member_id: {topic: [partition, ...]}}
assignment = collections.defaultdict(lambda: collections.defaultdict(list))
assignment: Dict[str, Dict[str, List[int]]] = collections.defaultdict(
lambda: collections.defaultdict(list)
)

member_iter = itertools.cycle(sorted(member_metadata.keys()))
member_iter = itertools.cycle(sorted(members.keys()))
for partition in all_topic_partitions:
member_id = next(member_iter)

# Because we constructed all_topic_partitions from the set of
# member subscribed topics, we should be safe assuming that
# each topic in all_topic_partitions is in at least one member
# subscription; otherwise this could yield an infinite loop
while partition.topic not in member_metadata[member_id].subscription:
while partition.topic not in members[member_id].subscription:
member_id = next(member_iter)
assignment[member_id][partition.topic].append(partition.partition)

protocol_assignment = {}
for member_id in member_metadata:
for member_id in members:
protocol_assignment[member_id] = ConsumerProtocolMemberAssignment(
cls.version, sorted(assignment[member_id].items()), b""
)
return protocol_assignment

@classmethod
def metadata(cls, topics):
def metadata(cls, topics: Iterable[str]) -> ConsumerProtocolMemberMetadata:
return ConsumerProtocolMemberMetadata(cls.version, list(topics), b"")

@classmethod
def on_assignment(cls, assignment):
def on_assignment(cls, assignment: ConsumerProtocolMemberAssignment) -> None:
pass
49 changes: 33 additions & 16 deletions aiokafka/coordinator/assignors/sticky/partition_movements.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import logging
from collections import defaultdict, namedtuple
from collections import defaultdict
from copy import deepcopy
from typing import Any, Dict, List, NamedTuple, Sequence, Set, Tuple

from aiokafka.structs import TopicPartition

log = logging.getLogger(__name__)


ConsumerPair = namedtuple("ConsumerPair", ["src_member_id", "dst_member_id"])
class ConsumerPair(NamedTuple):
src_member_id: str
dst_member_id: str


"""
Represents a pair of Kafka consumer ids involved in a partition reassignment.
Each ConsumerPair corresponds to a particular partition or topic, indicates that the
Expand All @@ -16,7 +23,7 @@
"""


def is_sublist(source, target):
def is_sublist(source: Sequence[Any], target: Sequence[Any]) -> bool:
"""Checks if one list is a sublist of another.
Arguments:
Expand All @@ -40,11 +47,13 @@ class PartitionMovements:
form a ConsumerPair object) for each partition.
"""

def __init__(self):
self.partition_movements_by_topic = defaultdict(lambda: defaultdict(set))
self.partition_movements = {}
def __init__(self) -> None:
self.partition_movements_by_topic: Dict[str, Dict[ConsumerPair, Set[TopicPartition]]] = defaultdict(lambda: defaultdict(set)) # fmt: skip # noqa: E501
self.partition_movements: Dict[TopicPartition, ConsumerPair] = {}

def move_partition(self, partition, old_consumer, new_consumer):
def move_partition(
self, partition: TopicPartition, old_consumer: str, new_consumer: str
) -> None:
pair = ConsumerPair(src_member_id=old_consumer, dst_member_id=new_consumer)
if partition in self.partition_movements:
# this partition has previously moved
Expand All @@ -62,7 +71,9 @@ def move_partition(self, partition, old_consumer, new_consumer):
else:
self._add_partition_movement_record(partition, pair)

def get_partition_to_be_moved(self, partition, old_consumer, new_consumer):
def get_partition_to_be_moved(
self, partition: TopicPartition, old_consumer: str, new_consumer: str
) -> TopicPartition:
if partition.topic not in self.partition_movements_by_topic:
return partition
if partition in self.partition_movements:
Expand All @@ -79,7 +90,7 @@ def get_partition_to_be_moved(self, partition, old_consumer, new_consumer):
iter(self.partition_movements_by_topic[partition.topic][reverse_pair])
)

def are_sticky(self):
def are_sticky(self) -> bool:
for topic, movements in self.partition_movements_by_topic.items():
movement_pairs = set(movements.keys())
if self._has_cycles(movement_pairs):
Expand All @@ -93,7 +104,9 @@ def are_sticky(self):
return False
return True

def _remove_movement_record_of_partition(self, partition):
def _remove_movement_record_of_partition(
self, partition: TopicPartition
) -> ConsumerPair:
pair = self.partition_movements[partition]
del self.partition_movements[partition]

Expand All @@ -105,16 +118,18 @@ def _remove_movement_record_of_partition(self, partition):

return pair

def _add_partition_movement_record(self, partition, pair):
def _add_partition_movement_record(
self, partition: TopicPartition, pair: ConsumerPair
) -> None:
self.partition_movements[partition] = pair
self.partition_movements_by_topic[partition.topic][pair].add(partition)

def _has_cycles(self, consumer_pairs):
cycles = set()
def _has_cycles(self, consumer_pairs: Set[ConsumerPair]) -> bool:
cycles: Set[Tuple[str, ...]] = set()
for pair in consumer_pairs:
reduced_pairs = deepcopy(consumer_pairs)
reduced_pairs.remove(pair)
path = [pair.src_member_id]
path: List[str] = [pair.src_member_id]
if self._is_linked(
pair.dst_member_id, pair.src_member_id, reduced_pairs, path
) and not self._is_subcycle(path, cycles):
Expand All @@ -132,7 +147,7 @@ def _has_cycles(self, consumer_pairs):
)

@staticmethod
def _is_subcycle(cycle, cycles):
def _is_subcycle(cycle: List[str], cycles: Set[Tuple[str, ...]]) -> bool:
super_cycle = deepcopy(cycle)
super_cycle = super_cycle[:-1]
super_cycle.extend(cycle)
Expand All @@ -141,7 +156,9 @@ def _is_subcycle(cycle, cycles):
return True
return False

def _is_linked(self, src, dst, pairs, current_path):
def _is_linked(
self, src: str, dst: str, pairs: Set[ConsumerPair], current_path: List[str]
) -> bool:
if src == dst:
return False
if not pairs:
Expand Down
Loading

0 comments on commit 4cba502

Please sign in to comment.